8.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. from util import get_input
  2. from itertools import permutations
  3. input = get_input("8.input", lambda a: a.split(" | "))
  4. sum = 0
  5. for [digi, out] in input:
  6. sum += len([a for a in out.split() if len(a) in [2, 4, 3, 7]])
  7. print(sum)
  8. segments = ['abcefg', 'cf', 'acdeg', 'acdfg', 'bcdf', 'abdfg', 'abdefg', 'acf', 'abcdefg', 'abcdfg']
  9. segments = [set(a) for a in segments]
  10. def find_key(digits):
  11. perms = permutations("abcdefg", 7)
  12. # Brute force is life
  13. for perm in perms:
  14. segmap = {a: b for (a, b) in zip("abcdefg", perm)}
  15. segmentss = list(segments)
  16. for digit in digits:
  17. digit = [segmap[a] for a in digit]
  18. if set(digit) in segmentss:
  19. segmentss.remove(set(digit))
  20. else:
  21. break
  22. if len(segmentss) == 0:
  23. return segmap
  24. def decode_digit(digit, key):
  25. return segments.index(set([key[a] for a in digit]))
  26. def decode(digits, out):
  27. key = find_key(digits)
  28. out_digits = [str(decode_digit(d, key)) for d in out]
  29. return int("".join(out_digits))
  30. nums = [decode(d.split(), o.split()) for [d, o] in input]
  31. sum = 0
  32. for num in nums:
  33. sum += num
  34. print(sum)