"""Advent of Code day 16: decode BITS transmissions.

A transmission is a hex string expanded into a list of "0"/"1" characters.
The list is consumed destructively from the front by ``take_input`` while
``parse_packet`` builds a nested tuple representation:

* literal packet:  ``(version, "val", value)``
* operator packet: ``(version, "op", type_id, [subpackets])``
"""
from functools import reduce
import operator


def hextobin(a):
    """Return the binary expansion of hex string *a*, 4 bits per hex digit.

    Leading zeros are preserved, e.g. ``hextobin("0") == "0000"`` and
    ``hextobin("0F") == "00001111"``.
    """
    i = int(a, 16)
    # Pad to 4 bits per input digit so leading zero nibbles are not lost.
    return "{:0{}b}".format(i, 4 * len(a))


def take_input(input, n):
    """Remove the first *n* bits from the list *input* and return them as an int.

    *input* is a list of "0"/"1" characters and is mutated in place.

    Raises:
        IndexError: if fewer than *n* bits remain.
    """
    if len(input) < n:
        raise IndexError(
            "need {} bits but only {} remain".format(n, len(input))
        )
    chunk = input[:n]
    # One slice deletion instead of n ``pop(0)`` calls: each pop(0) shifts the
    # whole remaining list, which made parsing quadratic in the input length.
    del input[:n]
    return int("".join(chunk), 2)


def parse_packet(input):
    """Parse one packet (recursively including subpackets) from the bit list.

    Consumes exactly the bits belonging to the packet from the front of
    *input*; any trailing padding bits are left in the list.

    Returns:
        ``(version, "val", value)`` for a literal packet (type id 4), or
        ``(version, "op", type_id, subpackets)`` for an operator packet.
    """
    v = take_input(input, 3)
    t = take_input(input, 3)
    if t == 4:
        # Literal: groups of 5 bits, a continuation flag followed by a nibble.
        value = 0
        more = True
        while more:
            more = take_input(input, 1) == 1
            value = (value << 4) | take_input(input, 4)
        return (v, "val", value)

    lt = take_input(input, 1)
    packets = []
    if lt == 0:
        # Length type 0: the next 15 bits give the total bit length of the
        # subpackets, so parse until that many bits have been consumed.
        ln = take_input(input, 15)
        start_len = len(input)
        while start_len - len(input) < ln:
            packets.append(parse_packet(input))
    else:
        # Length type 1: the next 11 bits give the number of subpackets.
        ln = take_input(input, 11)
        packets = [parse_packet(input) for _ in range(ln)]
    return (v, "op", t, packets)


def v_sum(pkt):
    """Return the sum of the version numbers of *pkt* and all nested subpackets."""
    if pkt[1] == "val":
        return pkt[0]
    return pkt[0] + sum(v_sum(p) for p in pkt[3])


# Operator type ids that fold over all subpacket values.
_REDUCERS = {
    0: sum,
    1: lambda values: reduce(operator.mul, values, 1),
    2: min,
    3: max,
}

# Operator type ids that compare exactly the first two subpacket values.
_COMPARISONS = {
    5: operator.gt,
    6: operator.lt,
    7: operator.eq,
}


def eval(pkt):
    """Evaluate the expression tree rooted at *pkt* and return its int value.

    Note: this name shadows the ``eval`` builtin inside this module; it is kept
    for compatibility with existing callers.

    Raises:
        ValueError: for an operator packet with an unknown type id.
    """
    if pkt[1] == "val":
        return pkt[2]
    (_, _, t, pkts) = pkt
    if t in _COMPARISONS:
        # Only the first two operands are evaluated, matching the spec.
        return int(_COMPARISONS[t](eval(pkts[0]), eval(pkts[1])))
    if t not in _REDUCERS:
        raise ValueError("Invalid packet type {}".format(t))
    return _REDUCERS[t](eval(p) for p in pkts)


def main():
    """Solve both parts for every transmission line in ``16.input``."""
    # Project-local / third-party helpers are imported here so the parsing
    # functions above can be imported without them.
    from util import get_input
    from more_itertools import flatten

    for line in get_input("16.input"):
        # NOTE(review): assumes get_input yields one hex string per line;
        # surrounding whitespace and blank lines are ignored.
        line = line.strip()
        if not line:
            continue
        bits = list(flatten([hextobin(a) for a in line]))
        pkt = parse_packet(bits)
        print("Part 1:", v_sum(pkt))
        print("Part 2:", eval(pkt))
        print("-" * 20)


if __name__ == "__main__":
    main()