diff --git a/py/aoc2022/day16.py b/py/aoc2022/day16.py index fe5f6af..268b612 100644 --- a/py/aoc2022/day16.py +++ b/py/aoc2022/day16.py @@ -3,7 +3,6 @@ """ import heapq -import itertools import math import re from collections import defaultdict @@ -117,28 +116,35 @@ def _solve(lines, num_agents, total_time): continue for delta, moves_by_agent in moves_by_time.items(): - for size in range(1, num_agents + 1): - for combo in itertools.combinations(moves_by_agent.items(), size): - indices = [i for i, _ in combo] - for valves in itertools.product(*(valves for _, valves in combo)): - if len(set(valves)) != size: - continue - new_rooms = [ - (room, age + delta + 1) for room, age in state.rooms - ] - for i, valve in zip(indices, valves): - new_rooms[i] = valve, 0 - rate = sum(graph[valve][0] for valve in valves) - new_state = _State( - rooms=tuple(sorted(new_rooms)), - valves=state.valves - set(valve for valve in valves), - flow=state.flow + rate, - total=state.total + state.flow * (delta + 1), - time=state.time - delta - 1, - ) - heapq.heappush( - heap, (-estimate - rate * new_state.time, new_state) - ) + indices = [None] * num_agents + while True: + for i, index in enumerate(indices): + index = 0 if index is None else index + 1 + if index < len(moves_by_agent[i]): + indices[i] = index + break + indices[i] = None + else: + break + valves = [ + (i, moves_by_agent[i][index]) + for i, index in enumerate(indices) + if index is not None + ] + if len(valves) != len(set(valve for _, valve in valves)): + continue + new_rooms = [(room, age + delta + 1) for room, age in state.rooms] + for i, valve in valves: + new_rooms[i] = valve, 0 + rate = sum(graph[valve][0] for _, valve in valves) + new_state = _State( + rooms=tuple(sorted(new_rooms)), + valves=state.valves - set(valve for _, valve in valves), + flow=state.flow + rate, + total=state.total + state.flow * (delta + 1), + time=state.time - delta - 1, + ) + heapq.heappush(heap, (-estimate - rate * new_state.time, new_state)) return max_seen