"""Decode scrambled seven-segment display readings from ``8.input``.

Each entry pairs ten scrambled signal patterns with scrambled output digits.
Part 1 counts output digits whose segment count alone identifies them
(1, 7, 4, 8); part 2 recovers each entry's wire mapping by brute force and
prints the sum of the decoded output values.
"""
from itertools import permutations

from util import get_input

# Each entry is the [patterns, output] pair produced by splitting a line on
# " | " (assumes get_input applies the lambda once per line -- confirm in util).
entries = get_input("8.input", lambda a: a.split(" | "))

# Segment counts that uniquely identify 1 (2), 7 (3), 4 (4) and 8 (7).
UNIQUE_LENGTHS = {2, 3, 4, 7}

print(sum(
    1
    for _, out in entries
    for word in out.split()
    if len(word) in UNIQUE_LENGTHS
))

# Lit segments of each digit; the list index is the digit's value.
segments = [
    'abcefg', 'cf', 'acdeg', 'acdfg', 'bcdf',
    'abdfg', 'abdefg', 'acf', 'abcdefg', 'abcdfg',
]
segments = [set(a) for a in segments]

# Hashable copy of the digit shapes so candidate decodings can be matched
# and removed with O(1) set operations instead of list scans.
_SEGMENT_SETS = frozenset(frozenset(s) for s in segments)


def find_key(digits):
    """Return the wire-to-segment mapping that decodes *digits*.

    Tries every permutation of the seven wires (brute force is life) and
    returns the first one under which the scrambled patterns in *digits*
    map onto all ten canonical digit shapes.

    :param digits: iterable of scrambled segment strings for one entry.
    :returns: dict mapping each scrambled wire letter to its real segment.
    :raises ValueError: if no permutation decodes the patterns.
    """
    digits = list(digits)  # iterated once per permutation
    for perm in permutations("abcdefg"):
        segmap = dict(zip("abcdefg", perm))
        remaining = set(_SEGMENT_SETS)
        for digit in digits:
            mapped = frozenset(segmap[c] for c in digit)
            if mapped not in remaining:
                break  # this permutation cannot be right; try the next one
            remaining.remove(mapped)
        if not remaining:
            return segmap
    # Previously fell through and returned None, which only surfaced later
    # as an opaque TypeError inside decode().
    raise ValueError(f"no wire permutation decodes {digits!r}")


def decode_digit(digit, key):
    """Return the value 0-9 of scrambled *digit* under wire mapping *key*.

    :raises ValueError: if the mapped pattern is not a valid digit shape.
    """
    return segments.index({key[c] for c in digit})


def decode(digits, out):
    """Decode the output digits *out* using the mapping found from *digits*.

    :param digits: scrambled signal patterns used to recover the mapping.
    :param out: scrambled output digits, most significant first.
    :returns: the output digits read as one decimal integer.
    """
    key = find_key(digits)
    return int("".join(str(decode_digit(d, key)) for d in out))


print(sum(decode(d.split(), o.split()) for d, o in entries))