from util import get_input
from more_itertools import sliding_window as sw
from more_itertools import flatten
from itertools import product

# Puzzle input, read through the project helper. NOTE(review): presumably a
# list of newline-stripped lines -- confirm against util.get_input.
input = get_input("14.input")

# The first line is the starting polymer template.
template = input[0]

# Insertion rules start on the third line, one "<pair> -> <element>" per line,
# stored as {pair: element}.
rules = {pair: element for pair, element in (line.split(" -> ") for line in input[2:])}

# Every element that can occur: those in the template and in rule pairs as
# well as those the rules insert.  Deriving the alphabet from rule outputs
# alone leaves out elements that only appear in the template or in a rule
# key, and indexing the pair table with such an element raises KeyError.
# Sorting makes the table order independent of set iteration order.
elements = sorted(set(template) | set(flatten(rules)) | set(flatten(rules.values())))

# All ordered two-element strings over that alphabet.
all_pairs = [a + b for a, b in product(elements, repeat=2)]

def step(pairs, rules, counts):
    """Apply one round of pair insertion and return the new pair tallies.

    Args:
        pairs: mapping of two-element string -> number of occurrences.
            It is not modified.
        rules: mapping of two-element string -> element inserted between
            the two.
        counts: mapping of element -> occurrences.  Updated in place with
            every element inserted during this round; an element not yet
            present is added.

    Returns:
        A new dict of pair -> occurrences after the round.  Only pairs
        that were carried over or produced get an entry.
    """
    following = {}
    for pair, occurrences in pairs.items():
        if pair not in rules:
            # No rule for this pair: it survives the round unchanged.
            following[pair] = following.get(pair, 0) + occurrences
            continue

        left, right = pair
        inserted = rules[pair]
        # Each occurrence of "LR" turns into "L" + inserted + "R", i.e. one
        # "L+inserted" pair and one "inserted+R" pair.
        for child in (left + inserted, inserted + right):
            following[child] = following.get(child, 0) + occurrences
        counts[inserted] = counts.get(inserted, 0) + occurrences
    return following

def run(template, rules, steps):
    """Grow the polymer for ``steps`` rounds and score the result.

    Args:
        template: the starting polymer, a string of elements.
        rules: mapping of two-element string -> element inserted between
            the two.
        steps: number of insertion rounds to apply.

    Returns:
        The count of the most common element minus the count of the least
        common one.  Every element that some rule can insert takes part in
        the comparison, even if it never occurs.  Returns 0 when there are
        no elements at all (empty template and no rules).
    """
    # Tally adjacent element pairs of the template.
    pairs = {}
    for a, b in sw(template, 2):
        pairs[a + b] = pairs.get(a + b, 0) + 1

    # Start every rule output at zero, then add the template's elements.
    # Using .get() lets a template element that no rule produces be counted
    # instead of raising KeyError.
    counts = dict.fromkeys(rules.values(), 0)
    for c in template:
        counts[c] = counts.get(c, 0) + 1

    for _ in range(steps):
        pairs = step(pairs, rules, counts)

    if not counts:
        # max()/min() raise on an empty sequence.
        return 0
    return max(counts.values()) - min(counts.values())

# Report both puzzle answers: the same simulation run for 10 and for 40 rounds.
for label, rounds in (("Part 1:", 10), ("Part 2:", 40)):
    print(label, run(template, rules, rounds))