誰もプレイしたことがないのに、何故か見たことはあるというあのゲーム。
(最近は、そういう広告ゲーを集めたミニゲーム集としてリリースされてるのもあるけど)
まず、薬を飲む前に倒せる敵はできるだけ倒すのがよい。
強さ $T$ で敵 $i$ を倒す→薬 $j$ を飲むと、強さは $(T+g_i) \times g_j$
強さ $T$ で薬 $j$ を飲む→敵 $i$ を倒すと、強さは $T \times g_j + g_i$
制約上、$g_i \times g_j \ge g_i$ なので、先に敵を倒して必ず損しない
倒す順番は関係ないので、弱い敵から倒すとしてよい。
訪問済みの頂点に隣接する頂点にいる未討伐の敵を $s_i$ をキーに優先度付きキューで管理する。
で、倒せなくなったら、隣接頂点にある薬を飲むしかない。
この順番は、敵と違い、最適解は貪欲にはなかなか決めづらい。
なので薬を飲む順番は全探索する必要がある。
ただ、本当に全探索すると $10!$ のオーダーがかかるので無理。
メモ化再帰をおこなう。
「今の強さで倒せる敵は全部倒した」状態を「薬待ち状態」とする。
飲み終えた薬の集合を $S$ とする。$S$ が同じになる薬待ち状態には、
飲んだ順番によって高橋君の強さ(とそれに伴う残り敵)は様々な状態が考えられるが、
ということが言える。
なので、再帰関数で薬待ち状態になるたびに1つ選んで飲む(再帰する)よう処理する中で、
過去に探索された飲んだ薬の集合 $S$ ごとの最大強さをメモしておき、
現在の強さがメモ以下なら、探索する必要は無い。
(全ての敵を倒す道筋が1つ見つかればよいので、探索が続いているということは、メモられた強さの場合でも倒せなかった、ということになる。現在の強さがそれ以下なら、当然倒せない)
計算量がちゃんとわかっていないので、上手いことテストケースを作れば「ちょっとずつ強さが更新されていって、探索数がかさんでしまう」というケースも作れるかもしれない。
この解法では、「現在の強さが同じ $S$ の中での最大強さであることが確定していなくても、とりあえず最後まで探索してしまう」という実装なので危ないが、本当は、公式Editorialのように「$dp[S]$ ごとに最大強さを確定させていく。最大の時の隣接敵情報も残しておく」という実装だと、間に合うことが保証される。
Python3
$N$ が小さいので、再帰の前に現在の隣接敵の情報をまるっとコピーして残しておくなど、多少、楽な実装をしても間に合う。
from heapq import heappop, heappush
def add_next_room(u, links, types, next_enemy, next_portion):
for v in links[u]:
t, s, g = types[v]
if t == 1:
heappush(next_enemy, (s, g, v))
else:
next_portion |= 1 << portion_idx[v]
return next_portion
def bitset_iterate(bitset):
while bitset:
lsb = bitset & -bitset
yield lsb.bit_length() - 1
bitset ^= lsb
def solve(strength, links, types, next_enemy, next_portion, memo={}):
if memo.get(next_portion, -1) >= strength:
return False
orig_strength = strength
orig_next_portion = next_portion
while next_enemy and next_enemy[0][0] <= strength:
su, gu, u = heappop(next_enemy)
strength += gu
next_portion = add_next_room(u, links, types, next_enemy, next_portion)
if len(next_enemy) == 0 and next_portion == 0:
return True
for pi in bitset_iterate(next_portion):
v = portions[pi]
gv = types[v][2]
strength *= gv
if strength >= 1_000_000_000:
return True
cur_next_enemy = next_enemy.copy()
cur_next_portion = next_portion
cur_next_portion ^= 1 << pi
cur_next_portion = add_next_room(v, links, types, cur_next_enemy, cur_next_portion)
res = solve(strength, links, types, cur_next_enemy, cur_next_portion)
if res:
return True
strength //= gv
memo[orig_next_portion] = orig_strength
return False
n = int(input())
links = [[] for _ in range(n)]
types = [(0,)]
portions = []
for v in range(1, n):
p, t, s, g = map(int, input().split())
p -= 1
links[p].append(v)
types.append((t, s, g))
if t == 2:
portions.append(v)
portion_idx = {v: i for i, v in enumerate(portions)}
next_enemy = []
next_portion = 0
next_portion = add_next_room(0, links, types, next_enemy, next_portion)
ans = solve(1, links, types, next_enemy, next_portion)
print('Yes' if ans else 'No')
完全グラフの辺数は $O(N^2)$ のオーダーであるが、
削除できる辺数が限られているので、$N$ がでかい場合は最短距離はそこまで大きくできない。
具体的には、最短距離が $d$ の場合、最短パスを1つだけ取って考えても、
2つ以上離れた頂点間は辺を削除しないといけない。
少なくとも $\dfrac{d(d-1)}{2}$ の辺は削除されないといけない。
よって、だいたい $d \le \sqrt{2M}$ となる。
最短パスが複数あったり、$N$ が大きくて最短パスに含まれない頂点への辺を切断する必要があったりすると、さらに $d$ は小さくなる。
いま、頂点1からの最短距離が $k$ である頂点群 $S_k$ と、$S_k$ 内の頂点それぞれへの最短パス数 $P_v$ がわかっているとする。
ここから、最短距離が $k+1$ の頂点群 $S_{k+1}$ と最短パス数を求めたい。
これを最大 $d$ 回繰り返すと、答えが見つかるか、到達できないことが確定する。
普通の経路数数え上げのようにBFS的なことをすると、1頂点から出ている辺数が多すぎるのでTLE。
基本的に到達可能として、不可能なものは減算する、という逆のアプローチを取る。
今までに訪れたことのない頂点は、基本的に $S_k$ 内の全頂点から到達可能と仮定して、「最短パス数の総和 $\displaystyle \sum_{v \in S_k}P_v$」の最短パス数を持つ
そのうち、削除辺のうち、一方が $S_k$ に含まれ、一方が未訪問の辺は、(順に頂点 $u,v$ として)$v$ の最短パス数から、$P_u$ を減らす
最終的に正の最短パス数が残るものが、$S_{k+1}$ となる
$O((N+M)\sqrt{M})$ でとおる。
通す上では不要だが、$N$ の大小により場合分けして、小さい場合は普通にBFSによる経路数数え上げをすると高速になる。
Python3
def solve(n, m, del_edges):
checked = [False] * n
checked[0] = True
dp = {0: 1}
while dp:
total = sum(dp.values())
next_step = {u: total for u in range(n) if not checked[u]}
for u, v in del_edges:
if u in dp and not checked[v]:
next_step[v] -= dp[u]
if v in dp and not checked[u]:
next_step[u] -= dp[v]
dp = {u: pat for u, pat in next_step.items() if pat > 0}
if n - 1 in dp:
return dp[n - 1] % MOD
for u in dp.keys():
checked[u] = True
return -1
n, m = map(int, input().split())
MOD = 998244353
del_edges = []
for _ in range(m):
u, v = map(int, input().split())
u -= 1
v -= 1
del_edges.append(sorted([u, v]))
ans = solve(n, m, del_edges)
print(ans)