from util import get_input
from itertools import permutations

# Each item should be one puzzle line split on " | " into
# [signal_patterns, output_value] — presumably one item per input line; verify in util.
# Materialized with list() in case get_input returns a one-shot iterator: the
# data is iterated again for part 2 further down this script.
# NOTE(review): `input` shadows the builtin; it is read again for part 2, so a
# rename must be applied at both sites together.
input = list(get_input("8.input", lambda a: a.split(" | ")))

# Part 1: digits 1, 7, 4 and 8 are the only ones lit by a unique number of
# segments (2, 3, 4 and 7 respectively), so count output patterns of those lengths.
UNIQUE_LENGTHS = {2, 3, 4, 7}
part1 = sum(
    1
    for _, out in input
    for pattern in out.split()
    if len(pattern) in UNIQUE_LENGTHS
)

print(part1)

# Canonical seven-segment layout: entry i is the set of segments lit for digit i.
segments = [
    set(pattern)
    for pattern in (
        "abcefg",   # 0
        "cf",       # 1
        "acdeg",    # 2
        "acdfg",    # 3
        "bcdf",     # 4
        "abdfg",    # 5
        "abdefg",   # 6
        "acf",      # 7
        "abcdefg",  # 8
        "abcdfg",   # 9
    )
]

def find_key(digits):
    """Find the wire-to-segment mapping that turns every pattern into a real digit.

    Tries every ordering of the seven letters ``a``-``g`` (7! = 5040 candidates)
    and returns the first mapping under which each pattern in ``digits``
    translates to a distinct entry of the module-level ``segments`` list, with
    all of ``segments`` consumed.

    Args:
        digits: the scrambled signal patterns for one display, as strings of
            letters ``a``-``g``.

    Returns:
        dict mapping each scrambled letter to its true segment letter.

    Raises:
        ValueError: if no permutation maps ``digits`` onto ``segments``
            (previously this returned ``None`` and failed later, obscurely).
    """
    # Brute force is life
    for perm in permutations("abcdefg", 7):
        segmap = dict(zip("abcdefg", perm))
        remaining = list(segments)
        for digit in digits:
            decoded = {segmap[a] for a in digit}
            if decoded not in remaining:
                break
            # Each canonical digit may be matched only once.
            remaining.remove(decoded)
        if not remaining:
            return segmap
    raise ValueError(f"no segment mapping fits patterns {digits!r}")

def decode_digit(digit, key):
    """Return the digit value of one scrambled pattern, translated through ``key``.

    ``key`` maps scrambled letters to true segment letters; the translated set
    is looked up in ``segments``, whose position is the digit value.
    """
    lit = {key[wire] for wire in digit}
    return segments.index(lit)

def decode(digits, out):
    """Decode one display's output value.

    The wiring key is derived from the scrambled ``digits`` patterns. Each
    pattern in ``out`` is then decoded in order, and the results are read as a
    base-10 number.
    """
    key = find_key(digits)
    return int("".join(str(decode_digit(pattern, key)) for pattern in out))


# Part 2: decode every display's output value and total them.
nums = [decode(d.split(), o.split()) for d, o in input]

# Accumulate into `total` rather than rebinding the builtin `sum`. The builtin
# is not called here because `sum` may already be rebound at module level.
total = 0
for num in nums:
    total += num
print(total)