24.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. from copy import deepcopy
  2. from math import floor
  3. from itertools import product
  4. from util import get_input
  5. input = get_input("24.input")
  6. # Yes, you have to adjust these manually.
  7. # Yes, it can be done automatically.
  8. # No, I don't want to.
  9. # 1111
  10. # 01234567890123
  11. part1 = "93959993429899"
  12. part2 = "11815671117121"
  13. lookup = {k: int(part2[k]) for k in range(14)}
  14. def paren(exp):
  15. return "(" + exp + ")"
  16. def pprint(exp):
  17. if exp[0] == 'c':
  18. return str(exp[1])
  19. elif exp[0] == 'in':
  20. return 'in' + str(exp[1])
  21. elif exp[0] == 'add':
  22. return paren(pprint(exp[1]) + ' + ' + pprint(exp[2]))
  23. elif exp[0] == 'mul':
  24. return paren(pprint(exp[1]) + ' * ' + pprint(exp[2]))
  25. elif exp[0] == 'div':
  26. return paren(pprint(exp[1]) + ' / ' + pprint(exp[2]))
  27. elif exp[0] == 'mod':
  28. return paren(pprint(exp[1]) + ' % ' + pprint(exp[2]))
  29. elif exp[0] == 'eql':
  30. return paren(pprint(exp[1]) + ' == ' + pprint(exp[2]))
  31. elif exp[0] == 'sum':
  32. return paren(" + ".join(pprint(t) for t in exp[1]))
  33. raise Exception("Not allowed: {}".format(exp[0]))
  34. def evl(exp, lookup):
  35. if exp[0] == 'c':
  36. return exp[1]
  37. elif exp[0] == 'in':
  38. return lookup[exp[1]]
  39. elif exp[0] == 'add':
  40. return evl(exp[1], lookup) + evl(exp[2], lookup)
  41. elif exp[0] == 'mul':
  42. return evl(exp[1], lookup) * evl(exp[2], lookup)
  43. elif exp[0] == 'div':
  44. return floor(evl(exp[1], lookup) / evl(exp[2], lookup))
  45. elif exp[0] == 'mod':
  46. return evl(exp[1], lookup) % evl(exp[2], lookup)
  47. elif exp[0] == 'eql':
  48. return 1 if evl(exp[1], lookup) == evl(exp[2], lookup) else 0
  49. elif exp[0] == 'sum':
  50. return sum(evl(t, lookup) for t in exp[1])
  51. raise Exception("Not allowed: {}".format(exp[0]))
  52. def max_val(exp, minv=False):
  53. if exp[0] == 'c':
  54. return exp[1]
  55. elif exp[0] == 'in':
  56. return 1 if minv else 9
  57. elif exp[0] == 'add':
  58. return max_val(exp[1], minv) + max_val(exp[2], minv)
  59. elif exp[0] == 'mul':
  60. return max_val(exp[1], minv) * max_val(exp[2], minv)
  61. elif exp[0] == 'div':
  62. return floor(max_val(exp[1], minv) / max_val(exp[2], not minv))
  63. elif exp[0] == 'mod':
  64. return (max_val(exp[2]) - 1) if not minv else 0
  65. elif exp[0] == 'eql':
  66. return 0 if minv else 1
  67. elif exp[0] == 'sum':
  68. return sum(max_val(t, minv) for t in exp[1])
  69. raise Exception("Not allowed: {}".format(exp[0]))
  70. def reduce_sum(terms):
  71. indices = [i for (i, t) in enumerate(terms) if t[0] == 'c']
  72. if len(indices) == 2:
  73. terms[indices[0]] = ('c', terms[indices[0]][1] + terms[indices[1]][1])
  74. terms.pop(indices[1])
  75. return [t for t in terms if t != ('c', 0)]
  76. def simplify(exp):
  77. global lookup
  78. if exp[0] == 'sum':
  79. s = reduce_sum([simplify(t) for t in exp[1]])
  80. if len(s) == 1:
  81. return s[0]
  82. else:
  83. return ('sum', s)
  84. elif exp[0] == 'add':
  85. if exp[1][0] == 'sum' and exp[2][0] == 'sum':
  86. return simplify(('sum', exp[1][1] + exp[2][1]))
  87. elif exp[1][0] == 'sum':
  88. return simplify(('sum', exp[1][1] + [exp[2]]))
  89. elif exp[2][0] == 'sum':
  90. return simplify(('sum', [exp[1]] + exp[2][1]))
  91. else:
  92. return simplify(('sum', [exp[1], exp[2]]))
  93. elif exp[0] == 'mul':
  94. if exp[2] == ('c', 0) or exp[1] == ('c', 0):
  95. return ('c', 0)
  96. if exp[2] == ('c', 1):
  97. return exp[1]
  98. if exp[1] == ('c', 1):
  99. return exp[2]
  100. if exp[1][0] == 'div' and exp[1][2] == exp[2]:
  101. return exp[1][1]
  102. if exp[1][0] == 'c' and exp[2][0] == 'c':
  103. return ('c', exp[1][1] * exp[2][1])
  104. elif exp[0] == 'div':
  105. if exp[2] == ('c', 1):
  106. return exp[1]
  107. if exp[1][0] == 'c' and exp[2][0] == 'c':
  108. return ('c', floor(exp[1][1] / exp[2][1]))
  109. if exp[2][0] == 'c' and max_val(exp[1]) < exp[2][1]:
  110. return ('c', 0)
  111. if exp[1][0] == 'mul' and exp[1][2] == exp[2]:
  112. return exp[1][1]
  113. if exp[1][0] == 'sum':
  114. indices = [i for i, t in enumerate(exp[1][1]) if t[0] == 'mul' and t[2] == exp[2]]
  115. if len(indices) > 0:
  116. terms = exp[1][1]
  117. term = terms[indices[0]][1]
  118. terms.pop(indices[0])
  119. return simplify(('sum', [term] + [('div', ('sum', terms), exp[2])]))
  120. elif exp[0] == 'mod':
  121. if exp[1][0] == 'c' and exp[2][0] == 'c':
  122. return ('c', exp[1][1] % exp[2][1])
  123. if exp[2][0] == 'c' and max_val(exp[1]) < exp[2][1]:
  124. return exp[1]
  125. if exp[1][0] == 'mul' and exp[1][1] == exp[2]:
  126. return ('c', 0)
  127. if exp[1][0] == 'mul' and exp[1][2] == exp[2]:
  128. return ('c', 0)
  129. if exp[1][0] == 'sum':
  130. return simplify(('sum', [('mod', t, exp[2]) for t in exp[1][1] if not (t[0] == 'mul' and t[2] == exp[2])]))
  131. elif exp[0] == 'eql':
  132. if exp[1] == exp[2]:
  133. return ('c', 1)
  134. try:
  135. if evl(exp[1], lookup) == evl(exp[2], lookup):
  136. print(pprint(exp[1]), 'EQUALS', pprint(exp[2]))
  137. return ('c', 1)
  138. except KeyError:
  139. pass
  140. if max_val(exp[1]) < max_val(exp[2], True):
  141. #print("1Reduced", pprint(exp[1]), "<>", pprint(exp[2]))
  142. return ('c', 0)
  143. if max_val(exp[2]) < max_val(exp[1], True):
  144. #print("2Reduced", pprint(exp[1]), "<>", pprint(exp[2]))
  145. return ('c', 0)
  146. #print(pprint(exp))
  147. return exp
  148. def run(input):
  149. var = {
  150. 'x': ("c", 0),
  151. 'y': ("c", 0),
  152. 'z': ("c", 0),
  153. 'w': ("c", 0),
  154. }
  155. input_idx = 0
  156. for i, line in enumerate(input[:270]):
  157. #print(i + 1, "/", len(input))
  158. l = line.split()
  159. if l[0] == 'inp':
  160. var[l[1]] = ("in", input_idx)
  161. input_idx += 1
  162. else:
  163. try:
  164. val = ("c", int(l[2]))
  165. except ValueError:
  166. val = var[l[2]]
  167. var[l[1]] = simplify((l[0], deepcopy(var[l[1]]), deepcopy(val)))
  168. print(pprint(var['z']).replace("*", "*\n"))
  169. print(max_val(var['z'], True))
  170. print(evl(var['z'], lookup))
  171. #print(max_val(var['z']))
  172. run(input)
  173. print("Part 1:", part1)
  174. print("Part 2:", part2)