import heapq
from queue import PriorityQueue

from util import get_input


# Puzzle input read from "15.input"; presumably one string of digits per grid
# row (the parser further down calls int() on every character of each line).
# NOTE(review): this name shadows the builtin input(); renaming it requires
# updating every place in this file that reads `input`.
input = get_input("15.input")

def add(a, b):
    """Return the element-wise sum of ``a`` and ``b`` as a tuple.

    Pairs are formed positionally, so the result is as long as the shorter
    of the two inputs.
    """
    return tuple(left + right for left, right in zip(a, b))

def grid_size(grid):
    """Return ``(width, height)`` of a grid dict keyed by ``(x, y)``.

    Each dimension is the largest coordinate seen on that axis plus one.
    An empty grid raises ``ValueError`` (from ``max``).
    """
    width = 1 + max(x for x, _ in grid)
    height = 1 + max(y for _, y in grid)
    return (width, height)

def print_grid(grid):
    """Print the grid one row per line, concatenating ``str()`` of each cell."""
    width, height = grid_size(grid)
    for row in range(height):
        cells = [str(grid[(col, row)]) for col in range(width)]
        print("".join(cells))

def neighbors(pos, sz):
    """Return the orthogonal neighbours of ``pos`` that lie inside the grid.

    Candidates are produced in the order +x, -x, +y, -y and kept only when
    ``0 <= x < sz[0]`` and ``0 <= y < sz[1]``.
    """
    width = sz[0]
    height = sz[1]
    result = []
    for step in ((1, 0), (-1, 0), (0, 1), (0, -1)):
        cand = add(pos, step)
        if 0 <= cand[0] < width and 0 <= cand[1] < height:
            result.append(cand)
    return result

def cheapest_path(grid):
    """Return the lowest total risk of a path from (0, 0) to the far corner.

    Runs Dijkstra's algorithm moving orthogonally between cells. The start
    cell's own risk is not counted; every cell entered adds its value.
    Dijkstra requires non-negative cell values (true for the digit grid
    parsed by this script).

    Args:
        grid: dict mapping (x, y) to the risk of entering that cell.

    Returns:
        The minimal accumulated risk to the bottom-right cell, or None if it
        cannot be reached.
    """
    sz = grid_size(grid)
    target = (sz[0] - 1, sz[1] - 1)
    cost = {(0, 0): 0}
    heap = [(0, (0, 0))]
    while heap:
        total, current = heapq.heappop(heap)
        if current == target:
            # First pop of the target is its final (minimal) cost.
            return total
        # Lazy deletion: a cheaper entry for this cell was already expanded.
        if total > cost[current]:
            continue
        for n in neighbors(current, sz):
            candidate = total + grid[n]
            if candidate < cost.get(n, float("inf")):
                cost[n] = candidate
                heapq.heappush(heap, (candidate, n))
    return None

def expand_grid(grid, times=5):
    """Return a new grid made of ``times`` x ``times`` tiled copies of ``grid``.

    The copy at tile (tx, ty) has every value increased by tx + ty, with
    values above 9 wrapping back around (10 -> 1, 11 -> 2, ...). Tile (0, 0)
    is the original grid unchanged; ``grid`` itself is not modified.

    Args:
        grid: dict mapping (x, y) to a single-digit risk level.
        times: number of tiles along each axis (the puzzle uses 5).

    Returns:
        A new dict keyed by (x, y) covering the full tiled area.
    """
    width, height = grid_size(grid)
    expanded = {}
    for (x, y), risk in grid.items():
        for tx in range(times):
            for ty in range(times):
                val = risk + tx + ty
                # Wrap 10..18 (and beyond for larger `times`) back into 1..9.
                while val > 9:
                    val -= 9
                expanded[(x + tx * width, y + ty * height)] = val
    return expanded


# Parse the digit map: grid[(x, y)] is the risk at column x of line y.
grid = {
    (x, y): int(c)
    for y, line in enumerate(input)
    for x, c in enumerate(line)
}

print("Part 1:", cheapest_path(grid))

grid = expand_grid(grid)

print("Part 2:", cheapest_path(grid))