1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- from util import get_input
- from more_itertools import flatten
- from functools import reduce
- import operator
# Puzzle data for this script, loaded through the project helper `get_input`.
# The driver loop at the bottom of the file iterates over this value and feeds
# every character of each item to `hextobin`, so it presumably yields one
# hex-encoded transmission string per item — TODO confirm against util.get_input.
# NOTE(review): this name shadows the builtin `input` for the rest of the module.
input = get_input("16.input")
def hextobin(a):
    """Convert a string of hexadecimal digits to its binary representation.

    Every hex digit expands to exactly four bits, so leading zeros are kept
    for inputs of any length (``"0F"`` -> ``"00001111"``). A single-digit
    argument produces the same result as before.

    Args:
        a: One or more hexadecimal digits (upper or lower case).

    Returns:
        A string of ``"0"``/``"1"`` characters, four per input digit.

    Raises:
        ValueError: If ``a`` is empty or contains a character that is not a
            hexadecimal digit.
    """
    if not a:
        # Mirrors int("", 16): there is nothing to convert.
        raise ValueError("hextobin() requires at least one hexadecimal digit")
    # Converting digit by digit keeps each nibble padded to 4 bits; padding
    # the whole number to a fixed width of 4 dropped leading zero nibbles.
    return "".join("{:04b}".format(int(digit, 16)) for digit in a)
def take_input(input, n):
    """Consume the first ``n`` bits of ``input`` and return them as an integer.

    Args:
        input: Mutable list of ``"0"``/``"1"`` strings. The consumed bits are
            removed from the front of the list in place.
        n: Number of bits to consume; must be positive.

    Returns:
        The consumed bits interpreted as a base-2 number, most significant
        bit first.

    Raises:
        ValueError: If ``n`` is not positive, or the consumed items do not
            form a valid binary number.
        IndexError: If fewer than ``n`` bits remain. ``input`` is left
            untouched in that case.
    """
    if n <= 0:
        # With no digits there is nothing for int(..., 2) to parse.
        raise ValueError("cannot read {} bits".format(n))
    if len(input) < n:
        raise IndexError(
            "need {} bits but only {} remain".format(n, len(input))
        )
    bits = "".join(input[:n])
    # One slice deletion shifts the list once instead of once per popped bit.
    del input[:n]
    return int(bits, 2)
def parse_packet(input):
    """Parse one packet from the front of the bit list ``input``.

    The bits that make up the packet (including any nested packets) are
    consumed from ``input`` via ``take_input``.

    Args:
        input: Mutable list of bits, as accepted by ``take_input``.

    Returns:
        ``(version, "val", value)`` for a literal packet (type 4), or
        ``(version, "op", type_id, sub_packets)`` for any other type, where
        ``sub_packets`` is a list of tuples of the same two shapes.
    """
    version = take_input(input, 3)
    type_id = take_input(input, 3)

    if type_id == 4:
        # Literal: groups of 1 continuation bit + 4 payload bits; the group
        # whose continuation bit is 0 is the last one.
        literal = 0
        more = 1
        while more:
            more = take_input(input, 1)
            literal = (literal << 4) | take_input(input, 4)
        return (version, "val", literal)

    children = []
    if take_input(input, 1) == 0:
        # Length type 0: a 15-bit count of how many bits the sub-packets use.
        payload_bits = take_input(input, 15)
        stop_at = len(input) - payload_bits
        while len(input) > stop_at:
            children.append(parse_packet(input))
    else:
        # Length type 1: an 11-bit count of direct sub-packets.
        for _ in range(take_input(input, 11)):
            children.append(parse_packet(input))
    return (version, "op", type_id, children)
def v_sum(pkt):
    """Return the sum of the version numbers of ``pkt`` and all nested packets.

    Args:
        pkt: A packet tuple as produced by ``parse_packet``: index 0 is the
            version, index 1 is ``"val"`` or ``"op"``, and for operator
            packets index 3 is the list of sub-packets.

    Returns:
        The version of ``pkt`` plus the versions of every packet below it.
    """
    total = 0
    pending = [pkt]
    # Walk the packet tree with an explicit stack instead of recursion.
    while pending:
        current = pending.pop()
        total += current[0]
        if current[1] != "val":
            pending.extend(current[3])
    return total
def eval(pkt):
    """Compute the value of a parsed packet.

    Literal packets evaluate to their stored value. Operator packets combine
    the values of their sub-packets according to the type id: 0 sum,
    1 product, 2 minimum, 3 maximum, 5 greater-than, 6 less-than, 7 equal.
    The three comparisons use only the first two sub-packets and yield 1 or 0.

    Args:
        pkt: A packet tuple as produced by ``parse_packet``.

    Returns:
        The value of the packet.

    Raises:
        ValueError: If an operator packet has a type id not listed above.
    """
    if pkt[1] == "val":
        return pkt[2]

    _version, _kind, type_id, children = pkt

    # Operators that fold over every sub-packet value.
    folds = {
        0: sum,
        1: lambda values: reduce(operator.mul, values, 1),
        2: min,
        3: max,
    }
    # Operators that compare the first two sub-packet values.
    comparisons = {
        5: operator.gt,
        6: operator.lt,
        7: operator.eq,
    }

    if type_id in folds:
        return folds[type_id]([eval(child) for child in children])
    if type_id in comparisons:
        left = eval(children[0])
        right = eval(children[1])
        return 1 if comparisons[type_id](left, right) else 0
    raise ValueError("Invalid packet type {}".format(type_id))
# Decode every transmission in the puzzle input and print both answers.
for transmission in input:
    # Expand each hex digit to its bits, one list item per bit.
    bits = [bit for digit in transmission for bit in hextobin(digit)]
    packet = parse_packet(bits)
    print("Part 1:", v_sum(packet))
    print("Part 2:", eval(packet))
    print("-" * 20)
|