from util import get_input

# Puzzle input, presumably one row of single-digit energy levels per line
# (assumes get_input returns an iterable of line strings — confirm in util).
# Named `lines` rather than `input` so the builtin `input` is not shadowed.
lines = get_input("11.input")

# Strip stray whitespace/newlines and skip blank lines so every row is a
# real row of digits (an empty row would corrupt the cell count later).
grid = [[int(a) for a in line.strip()] for line in lines if line.strip()]

def run_steps(grid, n=100):
    """Simulate ``n`` steps of the flashing grid and return the total flash count.

    Each step adds one to every cell. Every cell whose value exceeds 9 then
    flashes exactly once, adding one to each of its (up to eight) neighbours.
    This can push those neighbours over 9 and cascade further. At the end of
    the step every cell that flashed is reset to 0.

    The first time every cell flashes on the same step, ``Part 2:`` and that
    1-based step number are printed. The simulation keeps running, so the
    returned total always covers all ``n`` steps.

    Args:
        grid: rows of integer energy levels. The caller's grid is not
            modified; the simulation runs on a copy.
        n: number of steps to simulate.

    Returns:
        The total number of flashes across the ``n`` steps.
    """
    neighbours = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]
    # Private copy: the module-level grid is reused for a second run.
    grid = [list(row) for row in grid]
    cell_count = sum(len(row) for row in grid)
    tot_flash = 0
    synced = False
    for i in range(n):
        for row in grid:
            row[:] = [a + 1 for a in row]
        # Seed the cascade with every cell already over the threshold.
        pending = [(x, y) for x, row in enumerate(grid) for y, a in enumerate(row) if a > 9]
        flashed = set()
        while pending:
            x, y = pending.pop()
            if (x, y) in flashed:
                continue  # a cell flashes at most once per step
            flashed.add((x, y))
            for dx, dy in neighbours:
                nx, ny = x + dx, y + dy
                # Explicit bounds check: negative indices would silently wrap.
                if 0 <= nx < len(grid) and 0 <= ny < len(grid[nx]):
                    grid[nx][ny] += 1
                    if grid[nx][ny] > 9 and (nx, ny) not in flashed:
                        pending.append((nx, ny))
        for x, y in flashed:
            grid[x][y] = 0
        tot_flash += len(flashed)
        # Report only the first fully synchronised step; an empty grid never syncs.
        if not synced and cell_count and len(flashed) == cell_count:
            print("Part 2:", i + 1)
            synced = True

    return tot_flash

# Part 1 is the flash total returned by run_steps; Part 2 is printed by
# run_steps itself when it reaches the first step on which every cell flashes.
print("Part 1:", run_steps(grid, n=100))
run_steps(grid, n=1000)