123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186 |
- from copy import deepcopy
- from math import floor
- from itertools import product
- from util import get_input
# NOTE: this name shadows the builtin input(); it is kept because the
# top-level run(input) call reads it.
input = get_input("24.input")

# The answers are not computed by this script: they are hard-coded and must
# be edited by hand (automating the search was deliberately left out).
# Digit-position ruler for the answer strings:
#                  1111
#        01234567890123
part1 = "93959993429899"
part2 = "11815671117121"

# Digit position -> digit of part2; simplify() and run() evaluate against it.
lookup = dict(enumerate(int(digit) for digit in part2))
def paren(exp):
    """Return the string ``exp`` wrapped in one pair of parentheses."""
    return "".join(("(", exp, ")"))
def pprint(exp):
    """Render the expression tree ``exp`` as a fully parenthesised string.

    Constants print as their value and input leaves as ``in<index>``.
    Binary nodes print as ``(left <symbol> right)`` and ``sum`` nodes as
    ``(t1 + t2 + ...)``. Any other node tag raises ``Exception``.
    """
    tag = exp[0]
    if tag == 'c':
        return str(exp[1])
    if tag == 'in':
        return 'in' + str(exp[1])
    if tag == 'sum':
        return paren(" + ".join(pprint(t) for t in exp[1]))
    for name, symbol in (('add', ' + '), ('mul', ' * '), ('div', ' / '),
                         ('mod', ' % '), ('eql', ' == ')):
        if tag == name:
            return paren(pprint(exp[1]) + symbol + pprint(exp[2]))
    raise Exception("Not allowed: {}".format(tag))
def evl(exp, lookup):
    """Evaluate the expression tree ``exp`` to an integer.

    Node forms:
        ``('c', value)`` is a constant.
        ``('in', index)`` is an input leaf whose value is ``lookup[index]``.
        ``(op, left, right)`` is a binary node, with ``op`` one of
        add/mul/div/mod/eql.
        ``('sum', [terms])`` is an n-ary sum.

    :param exp: expression tree to evaluate.
    :param lookup: mapping from input index to that input's value.
    :returns: the integer value of ``exp``; ``eql`` yields 1 or 0.
    :raises KeyError: if an input index is missing from ``lookup``.
    :raises Exception: on an unknown node tag.
    """
    op = exp[0]
    if op == 'c':
        return exp[1]
    if op == 'in':
        return lookup[exp[1]]
    if op == 'sum':
        return sum(evl(t, lookup) for t in exp[1])
    if op in ('add', 'mul', 'div', 'mod', 'eql'):
        a = evl(exp[1], lookup)
        b = evl(exp[2], lookup)
        if op == 'add':
            return a + b
        if op == 'mul':
            return a * b
        if op == 'div':
            # Exact integer floor division. The former floor(a / b) went
            # through a float and silently lost precision once |a| > 2**53.
            # NOTE(review): floor semantics are kept. If the instruction set
            # truncates toward zero, negative quotients differ; confirm.
            return a // b
        if op == 'mod':
            return a % b
        return 1 if a == b else 0
    raise Exception("Not allowed: {}".format(op))
def _bounds(exp):
    """Return an inclusive ``(low, high)`` interval containing every value of ``exp``.

    Assumptions:
        Input leaves are digits 1..9.
        ``eql`` yields 0 or 1.
        ``mod`` has a positive divisor, so its result lies in
        ``[0, max(divisor) - 1]``.

    ``mul`` and ``div`` take the extremes over all four endpoint combinations.
    This stays correct when an operand interval contains negative values.

    :raises ZeroDivisionError: if a ``div`` divisor's interval contains 0.
    :raises Exception: on an unknown node tag.
    """
    op = exp[0]
    if op == 'c':
        return exp[1], exp[1]
    if op == 'in':
        return 1, 9
    if op == 'eql':
        return 0, 1
    if op == 'sum':
        low = high = 0
        for term in exp[1]:
            t_low, t_high = _bounds(term)
            low += t_low
            high += t_high
        return low, high
    if op == 'mod':
        return 0, _bounds(exp[2])[1] - 1
    if op in ('add', 'mul', 'div'):
        a_low, a_high = _bounds(exp[1])
        b_low, b_high = _bounds(exp[2])
        if op == 'add':
            return a_low + b_low, a_high + b_high
        if op == 'mul':
            corners = [a * b for a in (a_low, a_high) for b in (b_low, b_high)]
        else:
            if b_low <= 0 <= b_high:
                raise ZeroDivisionError(
                    "divisor range {}..{} includes 0".format(b_low, b_high))
            # Floor division is monotonic in each argument while the divisor
            # keeps one sign, so the extremes sit at the corners.
            corners = [a // b for a in (a_low, a_high) for b in (b_low, b_high)]
        return min(corners), max(corners)
    raise Exception("Not allowed: {}".format(op))


def max_val(exp, minv=False):
    """Return an upper bound on the value of ``exp``, or a lower bound if ``minv``.

    The bounds come from interval arithmetic over the expression tree (see
    ``_bounds`` for the assumptions on leaves, ``eql`` and ``mod``).

    :param exp: expression tree to bound.
    :param minv: when true, return the lower bound instead of the upper bound.
    """
    low, high = _bounds(exp)
    return low if minv else high
def reduce_sum(terms):
    """Fold all constant terms of a sum into one and drop a zero constant.

    The folded constant takes the position of the first constant term; the
    other terms keep their order. Previously only exactly two constants were
    merged, and the caller's list was mutated.

    :param terms: list of expression nodes forming a sum.
    :returns: a new list; ``terms`` is left unchanged.
    """
    out = []
    const_pos = None
    for term in terms:
        if term[0] != 'c':
            out.append(term)
        elif const_pos is None:
            const_pos = len(out)
            out.append(term)
        else:
            out[const_pos] = ('c', out[const_pos][1] + term[1])
    return [t for t in out if t != ('c', 0)]
def simplify(exp):
    """Return an expression equal in value to ``exp``, rewritten into a smaller form.

    ``run`` calls this on every freshly built node, so the children of ``exp``
    are normally already simplified; ``sum`` nodes are re-simplified term by
    term.

    Rewrites applied:
        ``add`` nodes are flattened into n-ary ``sum`` nodes, whose constant
        terms ``reduce_sum`` folds together.
        ``mul``/``div``/``mod`` fold constant operands and drop identities
        and zeros. ``(q*b) // b``, ``(q*b) % b`` and ``(q*b + rest) // b``
        are rewritten exactly.
        ``div``/``mod`` by a constant ``c`` collapse when ``max_val`` bounds
        prove the dividend lies in ``[0, c)``.
        ``mod`` of a sum drops multiples of the modulus. It distributes over
        the remaining terms only when the bounds prove the sum of remainders
        stays below the modulus.
        ``eql`` becomes ``('c', 0)`` when the operand bounds cannot overlap.

    NOTE: ``eql`` is also forced to ``('c', 1)`` (with a debug print) whenever
    both sides evaluate equal under the module-level ``lookup`` digits. That
    is a deliberate heuristic tied to the hand-entered answer, not an
    identity that holds for every input.

    ``exp`` itself is never mutated.
    """
    op = exp[0]

    if op == 'sum':
        terms = reduce_sum([simplify(t) for t in exp[1]])
        return terms[0] if len(terms) == 1 else ('sum', terms)

    if op == 'add':
        left, right = exp[1], exp[2]
        left_terms = list(left[1]) if left[0] == 'sum' else [left]
        right_terms = list(right[1]) if right[0] == 'sum' else [right]
        return simplify(('sum', left_terms + right_terms))

    if op == 'mul':
        left, right = exp[1], exp[2]
        if left == ('c', 0) or right == ('c', 0):
            return ('c', 0)
        if right == ('c', 1):
            return left
        if left == ('c', 1):
            return right
        # The former rewrite (a // b) * b -> a was removed: it is only valid
        # when b divides a, because floor division discards the remainder.
        if left[0] == 'c' and right[0] == 'c':
            return ('c', left[1] * right[1])
        return exp

    if op == 'div':
        num, den = exp[1], exp[2]
        if den == ('c', 1):
            return num
        if num[0] == 'c' and den[0] == 'c':
            return ('c', num[1] // den[1])
        # 0 <= num < c implies num // c == 0. A negative num would give -1
        # or less, so the lower bound is checked too.
        if den[0] == 'c' and 0 <= max_val(num, True) and max_val(num) < den[1]:
            return ('c', 0)
        if num[0] == 'mul' and num[2] == den:
            return num[1]
        if num[0] == 'sum':
            for i, term in enumerate(num[1]):
                if term[0] == 'mul' and term[2] == den:
                    # (q*b + rest) // b == q + rest // b for integer q.
                    # Build a new list instead of popping from the input.
                    rest = num[1][:i] + num[1][i + 1:]
                    return simplify(('sum', [term[1], ('div', ('sum', rest), den)]))
        return exp

    if op == 'mod':
        num, modulus = exp[1], exp[2]
        if num[0] == 'c' and modulus[0] == 'c':
            return ('c', num[1] % modulus[1])
        # x % c == x only when 0 <= x < c.
        if (modulus[0] == 'c' and 0 <= max_val(num, True)
                and max_val(num) < modulus[1]):
            return num
        if num[0] == 'mul' and modulus in (num[1], num[2]):
            return ('c', 0)
        if num[0] == 'sum':
            # Multiples of the modulus contribute nothing and can be dropped.
            kept = [t for t in num[1]
                    if not (t[0] == 'mul' and t[2] == modulus)]
            distributed = simplify(('sum', [('mod', t, modulus) for t in kept]))
            # sum(t % m) == sum(t) % m only while sum(t % m) < m. Each t % m
            # is >= 0 for m > 0, so an upper bound below min(m) suffices.
            lowest_modulus = max_val(modulus, True)
            if lowest_modulus > 0 and max_val(distributed) < lowest_modulus:
                return distributed
            return ('mod', simplify(('sum', kept)), modulus)
        return exp

    if op == 'eql':
        left, right = exp[1], exp[2]
        if left == right:
            return ('c', 1)
        try:
            same_under_lookup = evl(left, lookup) == evl(right, lookup)
        except KeyError:
            # An input index is missing from ``lookup``; skip the heuristic.
            same_under_lookup = False
        if same_under_lookup:
            print(pprint(left), 'EQUALS', pprint(right))
            return ('c', 1)
        if (max_val(left) < max_val(right, True)
                or max_val(right) < max_val(left, True)):
            return ('c', 0)
        return exp

    return exp
def run(input, limit=270):
    """Symbolically execute program lines and print a summary of register z.

    Registers x, y, z and w start as ``('c', 0)``. ``inp r`` binds register
    ``r`` to the next input leaf ``('in', k)``, for k = 0, 1, 2, ... Every
    other line ``op r v`` replaces ``r`` with ``simplify((op, r, v))``, where
    ``v`` is an integer literal or another register.

    Afterwards it prints three things:
        the simplified z expression, with a line break after every ``*``;
        its lower bound, ``max_val(z, True)``;
        z evaluated under the module-level ``lookup`` digits.

    :param input: sequence of whitespace-separated instruction lines. The name
        shadows the builtin but is kept for callers.
    :param limit: execute only the first ``limit`` lines; ``None`` runs all of
        them. The default of 270 is the cap the script has always applied.
    """
    registers = {'x': ("c", 0), 'y': ("c", 0), 'z': ("c", 0), 'w': ("c", 0)}
    next_input = 0
    for line in input[:limit]:
        fields = line.split()
        if not fields:
            # Blank lines (e.g. a trailing newline in the input file) carry no
            # instruction; previously they raised IndexError.
            continue
        op, target = fields[0], fields[1]
        if op == 'inp':
            registers[target] = ("in", next_input)
            next_input += 1
            continue
        try:
            operand = ("c", int(fields[2]))
        except ValueError:
            operand = registers[fields[2]]
        # Hand simplify() private copies so no subtree is shared between
        # registers.
        registers[target] = simplify(
            (op, deepcopy(registers[target]), deepcopy(operand)))
    z = registers['z']
    print(pprint(z).replace("*", "*\n"))
    print(max_val(z, True))
    print(evl(z, lookup))
# Print the symbolic summary of the program, then the hand-entered answers.
run(input)
for label, answer in (("Part 1:", part1), ("Part 2:", part2)):
    print(label, answer)
|