"""Symbolic execution of the day-24 ALU program read from "24.input".

Every register holds an expression tree instead of a number, and the result
of each instruction is passed through ``simplify``, which rewrites it using
constant folding and value-range reasoning (``max_val``).

Expression trees are tuples tagged by their first element:

* ``('c', n)``                 -- integer constant ``n``
* ``('in', k)``                -- the k-th ``inp`` value read (0-based)
* ``(op, a, b)``               -- ``op`` in add/mul/div/mod/eql on subtrees
* ``('sum', [t0, t1, ...])``   -- n-ary addition produced by ``simplify``
"""
from copy import deepcopy
from math import floor
from itertools import product

# Yes, you have to adjust these manually.
# Yes, it can be done automatically.
# No, I don't want to.
#           1111
# 01234567890123
part1 = "93959993429899"
part2 = "11815671117121"

# Concrete digit for each ('in', k) leaf. Used by simplify()'s 'eql' shortcut
# and by run() for the final concrete evaluation of register z.
lookup = {k: int(part2[k]) for k in range(14)}

# Infix spelling of each binary operator, used by pprint().
_INFIX = {'add': ' + ', 'mul': ' * ', 'div': ' / ', 'mod': ' % ', 'eql': ' == '}


def _alu_div(a, b):
    """Integer-divide ``a`` by ``b``, truncating toward zero.

    Pure integer arithmetic is exact for arbitrarily large operands, unlike
    ``floor(a / b)``, which goes through a float and loses precision above
    2**53 (and floors, rather than truncates, negative quotients).

    Raises:
        ZeroDivisionError: if ``b`` is 0.
    """
    quotient = abs(a) // abs(b)
    return quotient if (a >= 0) == (b > 0) else -quotient


def paren(exp):
    """Wrap the string ``exp`` in parentheses."""
    return "(" + exp + ")"


def pprint(exp):
    """Render an expression tree as a fully parenthesised infix string.

    Raises:
        Exception: if a node has an unknown tag.
    """
    op = exp[0]
    if op == 'c':
        return str(exp[1])
    if op == 'in':
        return 'in' + str(exp[1])
    if op in _INFIX:
        return paren(pprint(exp[1]) + _INFIX[op] + pprint(exp[2]))
    if op == 'sum':
        return paren(" + ".join(pprint(t) for t in exp[1]))
    raise Exception("Not allowed: {}".format(op))


def evl(exp, lookup):
    """Evaluate an expression tree to an int.

    Args:
        exp: expression tree.
        lookup: mapping from input index ``k`` to the value of ``('in', k)``.

    Raises:
        KeyError: if an ``('in', k)`` leaf has no entry in ``lookup``.
        Exception: if a node has an unknown tag.
    """
    op = exp[0]
    if op == 'c':
        return exp[1]
    if op == 'in':
        return lookup[exp[1]]
    if op == 'add':
        return evl(exp[1], lookup) + evl(exp[2], lookup)
    if op == 'mul':
        return evl(exp[1], lookup) * evl(exp[2], lookup)
    if op == 'div':
        return _alu_div(evl(exp[1], lookup), evl(exp[2], lookup))
    if op == 'mod':
        return evl(exp[1], lookup) % evl(exp[2], lookup)
    if op == 'eql':
        return 1 if evl(exp[1], lookup) == evl(exp[2], lookup) else 0
    if op == 'sum':
        return sum(evl(t, lookup) for t in exp[1])
    raise Exception("Not allowed: {}".format(op))


def max_val(exp, minv=False):
    """Return an upper bound on ``exp`` (a lower bound when ``minv`` is true).

    Each input digit is taken to lie in 1..9. The 'mul' bounds combine the
    operands' bounds in the same direction, and the 'div' bounds use the
    opposite-direction bound of the divisor. Both are only reliable for
    non-negative operands. A 'mod' node is bounded by 0 and (modulus - 1).

    Raises:
        Exception: if a node has an unknown tag.
    """
    op = exp[0]
    if op == 'c':
        return exp[1]
    if op == 'in':
        return 1 if minv else 9
    if op == 'add':
        return max_val(exp[1], minv) + max_val(exp[2], minv)
    if op == 'mul':
        return max_val(exp[1], minv) * max_val(exp[2], minv)
    if op == 'div':
        return _alu_div(max_val(exp[1], minv), max_val(exp[2], not minv))
    if op == 'mod':
        return 0 if minv else max_val(exp[2]) - 1
    if op == 'eql':
        return 0 if minv else 1
    if op == 'sum':
        return sum(max_val(t, minv) for t in exp[1])
    raise Exception("Not allowed: {}".format(op))


def reduce_sum(terms):
    """Fold all constant terms of a sum into one and drop zero constants.

    The merged constant takes the position of the first constant in
    ``terms``. A new list is returned; ``terms`` itself is not modified.
    """
    total = sum(t[1] for t in terms if t[0] == 'c')
    merged = []
    placed = False
    for t in terms:
        if t[0] != 'c':
            merged.append(t)
        elif not placed:
            merged.append(('c', total))
            placed = True
    return [t for t in merged if t != ('c', 0)]


def _other_factor(term, factor):
    """Return ``x`` if ``term`` is ``('mul', x, factor)`` or ``('mul', factor, x)``, else None."""
    if term[0] != 'mul':
        return None
    if term[2] == factor:
        return term[1]
    if term[1] == factor:
        return term[2]
    return None


def simplify(exp):
    """Return a simplified tree for ``exp``.

    Only the top node is rewritten, except that 'sum' nodes re-simplify
    their terms. The children of other nodes are assumed to be simplified
    already (run() builds trees bottom-up). ``exp`` is never mutated.
    """
    op = exp[0]
    if op == 'sum':
        terms = reduce_sum([simplify(t) for t in exp[1]])
        return terms[0] if len(terms) == 1 else ('sum', terms)

    if op == 'add':
        # Flatten both operands into a single n-ary sum.
        left = exp[1][1] if exp[1][0] == 'sum' else [exp[1]]
        right = exp[2][1] if exp[2][0] == 'sum' else [exp[2]]
        return simplify(('sum', left + right))

    if op == 'mul':
        a, b = exp[1], exp[2]
        if a == ('c', 0) or b == ('c', 0):
            return ('c', 0)
        if b == ('c', 1):
            return a
        if a == ('c', 1):
            return b
        # (x / d) * d is deliberately NOT rewritten to x: integer division
        # discards the remainder, so that identity does not hold.
        if a[0] == 'c' and b[0] == 'c':
            return ('c', a[1] * b[1])
        return exp

    if op == 'div':
        a, d = exp[1], exp[2]
        if d == ('c', 1):
            return a
        if a[0] == 'c' and d[0] == 'c':
            return ('c', _alu_div(a[1], d[1]))
        # |a| < d  ==>  a / d truncates to 0.
        if d[0] == 'c' and d[1] > 0 and max_val(a) < d[1] and max_val(a, True) > -d[1]:
            return ('c', 0)
        quotient = _other_factor(a, d)  # (x * d) / d == x
        if quotient is not None:
            return quotient
        if a[0] == 'sum':
            # (x * d + rest) / d  ==>  x + rest / d, built on a fresh list
            # so the caller's tree is left untouched.
            for i, t in enumerate(a[1]):
                quotient = _other_factor(t, d)
                if quotient is not None:
                    rest = a[1][:i] + a[1][i + 1:]
                    return simplify(('sum', [quotient, ('div', ('sum', rest), d)]))
        return exp

    if op == 'mod':
        a, m = exp[1], exp[2]
        if a[0] == 'c' and m[0] == 'c':
            return ('c', a[1] % m[1])
        # 0 <= a < m  ==>  a % m == a.
        if m[0] == 'c' and max_val(a) < m[1] and max_val(a, True) >= 0:
            return a
        if _other_factor(a, m) is not None:  # (x * m) % m == 0
            return ('c', 0)
        if a[0] == 'sum':
            # (k*m + rest) % m == rest % m. The remaining terms stay under a
            # single 'mod'; distributing % over + would be wrong in general.
            kept = [t for t in a[1] if _other_factor(t, m) is None]
            if len(kept) < len(a[1]):
                return simplify(('mod', simplify(('sum', kept)), m))
        return exp

    if op == 'eql':
        a, b = exp[1], exp[2]
        if a == b:
            return ('c', 1)
        # Deliberate heuristic: if both sides agree for the hand-picked digits
        # in `lookup`, assume they are always equal. This is how the manually
        # chosen digits steer the simplification; it is not sound in general.
        try:
            if evl(a, lookup) == evl(b, lookup):
                print(pprint(a), 'EQUALS', pprint(b))
                return ('c', 1)
        except KeyError:
            pass  # an input index missing from `lookup` just skips the shortcut
        # Disjoint value ranges can never be equal.
        if max_val(a) < max_val(b, True) or max_val(b) < max_val(a, True):
            return ('c', 0)
        return exp

    return exp


def run(input, max_lines=270):
    """Symbolically execute the ALU program and print diagnostics for z.

    Prints the simplified z expression, its lower bound from max_val(), and
    its concrete value under ``lookup``.

    Args:
        input: program text as a sequence of instruction lines.
        max_lines: only the first ``max_lines`` lines are executed. This is a
            debugging cut-off; the default 270 matches the original behavior.
    """
    var = {reg: ('c', 0) for reg in 'xyzw'}
    next_input = 0
    for line in input[:max_lines]:
        parts = line.split()
        if not parts:
            continue  # tolerate blank lines
        if parts[0] == 'inp':
            var[parts[1]] = ('in', next_input)
            next_input += 1
            continue
        try:
            operand = ('c', int(parts[2]))
        except ValueError:
            operand = var[parts[2]]
        # Defensive copies: the new tree must not alias lists that are still
        # held by other registers.
        var[parts[1]] = simplify(
            (parts[0], deepcopy(var[parts[1]]), deepcopy(operand))
        )
    print(pprint(var['z']).replace("*", "*\n"))
    print(max_val(var['z'], True))
    print(evl(var['z'], lookup))


def main():
    """Run the symbolic analysis on "24.input" and print both answers."""
    # Imported here, not at module level, so the expression helpers above
    # can be imported without the project's util module or the input file.
    from util import get_input

    program = get_input("24.input")
    run(program)
    print("Part 1:", part1)
    print("Part 2:", part2)


if __name__ == "__main__":
    main()