# 16.py (2.1 KB)
  1. from util import get_input
  2. from more_itertools import flatten
  3. from functools import reduce
  4. import operator
# Puzzle input loaded via util.get_input from the file "16.input"; the loop
# at the bottom of the module iterates it, presumably one hex transmission
# per entry (TODO confirm against util.get_input).
# NOTE(review): this rebinds the builtin name `input` at module level; a rename
# (e.g. `lines`) must update the module-level loop in the same change.
input = get_input("16.input")
  6. def hextobin(a):
  7. i = int(a, 16)
  8. return "{:04b}".format(i)
  9. def take_input(input, n):
  10. x = ""
  11. for i in range(n):
  12. x += input.pop(0)
  13. return int("".join(x), 2)
  14. def parse_packet(input):
  15. v = take_input(input, 3)
  16. t = take_input(input, 3)
  17. if t == 4:
  18. value = ""
  19. while True:
  20. lead = take_input(input, 1)
  21. num = take_input(input, 4)
  22. value += "{:04b}".format(num)
  23. if lead == 0:
  24. break
  25. return (v, "val", int(value, 2))
  26. else:
  27. lt = take_input(input, 1)
  28. packets = []
  29. if lt == 0:
  30. ln = take_input(input, 15)
  31. start_len = len(input)
  32. while start_len - len(input) < ln:
  33. packets.append(parse_packet(input))
  34. else:
  35. ln = take_input(input, 11)
  36. packets = [parse_packet(input) for _ in range(ln)]
  37. return (v, "op", t, packets)
  38. def v_sum(pkt):
  39. if pkt[1] == "val":
  40. return pkt[0]
  41. else:
  42. return pkt[0] + sum(v_sum(p) for p in pkt[3])
  43. def eval(pkt):
  44. if pkt[1] == "val":
  45. return pkt[2]
  46. else:
  47. (_, _, t, pkts) = pkt
  48. if t == 0:
  49. return sum(eval(p) for p in pkts)
  50. elif t == 1:
  51. return reduce(operator.mul, [eval(p) for p in pkts], 1)
  52. elif t == 2:
  53. return min(eval(p) for p in pkts)
  54. elif t == 3:
  55. return max(eval(p) for p in pkts)
  56. elif t == 5:
  57. return 1 if eval(pkts[0]) > eval(pkts[1]) else 0
  58. elif t == 6:
  59. return 1 if eval(pkts[0]) < eval(pkts[1]) else 0
  60. elif t == 7:
  61. return 1 if eval(pkts[0]) == eval(pkts[1]) else 0
  62. else:
  63. raise ValueError("Invalid packet type {}".format(t))
  64. for line in input:
  65. line = list(flatten([hextobin(a) for a in line]))
  66. pkt = parse_packet(line)
  67. print("Part 1:", v_sum(pkt))
  68. print("Part 2:", eval(pkt))
  69. print("-" * 20)