from queue import PriorityQueue

from more_itertools import flatten

from util import get_input

# Position of each cell in a part-1 state tuple (hex digits), following the
# order of ``state_map`` below: room cells column by column from the top,
# then the hallway stops from left to right.
#
# #############
# #89.A.B.C.DE#
# ###0#2#4#6###
#   #1#3#5#7#
#   #########

blank_map = """#############
#...........#
###.#.#.#.###
  #.#.#.#.#
  #########"""

final_map = """#############
#...........#
###A#B#C#D###
  #A#B#C#D#
  #########"""

puzzle_input = get_input("23.input", lambda a: a)

# Hallway cells an amphipod may stop on; the cells directly outside the
# rooms (x = 3, 5, 7, 9) are deliberately absent.
top_row = [(1, 1), (2, 1), (4, 1), (6, 1), (8, 1), (10, 1), (11, 1)]

# Room cells belonging to each amphipod type, ordered from the hallway down.
columns = {
    'A': [(3, 2), (3, 3)],
    'B': [(5, 2), (5, 3)],
    'C': [(7, 2), (7, 3)],
    'D': [(9, 2), (9, 3)],
}

state_map = list(flatten(columns.values())) + top_row

# Energy spent per step by each amphipod type.
cost = {'A': 1, 'B': 10, 'C': 100, 'D': 1000}


def map_to_state(m, state_map):
    """Read map ``m`` (indexed as ``m[y][x]``) into a state tuple.

    The result holds one character per ``(x, y)`` entry of ``state_map``,
    in the same order.
    """
    return tuple(m[y][x] for (x, y) in state_map)


def state_to_map(state, state_map, blank_map):
    """Draw ``state`` onto a fresh grid built from ``blank_map``.

    Returns a list of rows, each a list of characters, so that cells can be
    read as ``m[y][x]``.
    """
    m = [list(line) for line in blank_map.splitlines()]
    for i, (x, y) in enumerate(state_map):
        m[y][x] = state[i]
    return m


def neighbors(pos, m):
    """Return the orthogonally adjacent cells of ``pos`` that are empty ('.').

    No bounds checks are made: this relies on the walkable area of ``m``
    being surrounded by wall/space characters.
    """
    (x, y) = pos
    return [
        (x + dx, y + dy)
        for (dx, dy) in [(-1, 0), (1, 0), (0, -1), (0, 1)]
        if m[y + dy][x + dx] == '.'
    ]


def try_walk(pos, m, dests, visited, dist):
    """Walk depth-first over empty cells starting at ``pos``.

    Returns ``(cell, steps)`` for each cell of ``dests`` that is reached,
    excluding the starting cell itself (``dist > 0``). ``visited`` is the
    path walked so far. It is copied for each branch, so the caller's list
    is never mutated.
    """
    res = []
    if pos in dests and dist > 0:
        res.append((pos, dist))
    for n in neighbors(pos, m):
        if n in visited:
            continue
        res += try_walk(n, m, dests, visited + [n], dist + 1)
    return res


def all_moves(pos, m, columns):
    """List the legal ``(destination, steps)`` moves for the amphipod at ``pos``.

    Rules applied:

    * an empty cell has no moves;
    * an amphipod in its own room, with only its own kind below it, stays;
    * an amphipod in any room may move to a reachable stop in ``top_row``;
    * an amphipod may move into the deepest empty cell of its own room, but
      only while that room contains nothing but its own kind and empties.
    """
    (x, y) = pos
    a = m[y][x]
    if a == '.':
        return []
    col = columns[a]
    if pos in col and all(m[cy][cx] == a for (cx, cy) in col[col.index(pos):]):
        return []
    dests = []
    if any(pos in c for c in columns.values()):
        dests += top_row
    if all(m[cy][cx] in (a, '.') for (cx, cy) in col):
        # Non-empty: if the room were full of ``a`` this amphipod would be
        # inside it with only ``a`` below, and it returned early above.
        dests.append([(cx, cy) for (cx, cy) in col if m[cy][cx] == '.'][-1])
    return try_walk(pos, m, dests, [pos], 0)


def next_states(state, state_map, blank_map, columns):
    """Return ``(new_state, energy)`` for every single move possible from ``state``."""
    m = state_to_map(state, state_map, blank_map)
    res = []
    for i, a in enumerate(state):
        for dest, dist in all_moves(state_map[i], m, columns):
            new_state = list(state)
            new_state[i] = '.'
            new_state[state_map.index(dest)] = a
            res.append((tuple(new_state), dist * cost[a]))
    return res


def state_search(state, blank_map, final_map, state_map, columns):
    """Return the least total energy needed to turn ``state`` into ``final_map``.

    Runs Dijkstra's algorithm over burrow states. Edge weights come from
    ``next_states`` and are never negative, so the first time the goal is
    popped from the queue its distance is optimal.

    Raises:
        KeyError: if the final state is unreachable from ``state``.
    """
    final_state = map_to_state(final_map.splitlines(), state_map)
    dist_map = {state: 0}
    queue = PriorityQueue()
    queue.put((0, state))
    while not queue.empty():
        d, current = queue.get()
        if d > dist_map[current]:
            # Stale entry: a cheaper route to this state was queued later.
            continue
        if current == final_state:
            # Stop here instead of exploring the rest of the state space.
            return d
        for next_state, step_cost in next_states(current, state_map, blank_map, columns):
            nd = d + step_cost
            if nd < dist_map.get(next_state, float('inf')):
                dist_map[next_state] = nd
                queue.put((nd, next_state))
    # Queue exhausted without reaching the goal (same error as a failed lookup).
    raise KeyError(final_state)


state = map_to_state(puzzle_input, state_map)
print("Part 1:", state_search(state, blank_map, final_map, state_map, columns))

# Part 2: two extra rows are unfolded into the burrow, making each room four
# cells deep. State indices become 0-15 for room cells, then 16-22 for the
# hallway stops.
columns = {
    'A': [(3, 2), (3, 3), (3, 4), (3, 5)],
    'B': [(5, 2), (5, 3), (5, 4), (5, 5)],
    'C': [(7, 2), (7, 3), (7, 4), (7, 5)],
    'D': [(9, 2), (9, 3), (9, 4), (9, 5)],
}

state_map = list(flatten(columns.values())) + top_row

blank_map = """#############
#...........#
###.#.#.#.###
  #.#.#.#.#
  #.#.#.#.#
  #.#.#.#.#
  #########"""

final_map = """#############
#...........#
###A#B#C#D###
  #A#B#C#D#
  #A#B#C#D#
  #A#B#C#D#
  #########"""

# Rows inserted between the first and second room rows of the input.
extra_in = """  #D#C#B#A#
  #D#B#A#C#""".splitlines()

state = map_to_state(puzzle_input[0:3] + extra_in + puzzle_input[3:], state_map)
print("Part 2:", state_search(state, blank_map, final_map, state_map, columns))