Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AoC 2024: days 17-20 #538

Merged
merged 21 commits into from
Dec 25, 2024
Merged
151 changes: 151 additions & 0 deletions examples/aoc2024/day17/part1.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import "stdlib/io.jou"
import "stdlib/str.jou"


def combo_op_to_string(n: int) -> byte*:
assert 0 <= n and n <= 6
results = ["0", "1", "2", "3", "A", "B", "C"]
return results[n]


# Prints machine code and registers to stdout in human-readable form.
# In the code, 0-7 are instructions and operands and -1 = halt.
def print_code(code: int*) -> None:
for i = 0; code[i] != -1; i += 2:
printf("%4d ", i)

opcode = code[i]
operand = code[i+1]

if opcode == 0:
printf("A /= 2**%s\n", combo_op_to_string(operand))
elif opcode == 1:
printf("B ^= %d\n", operand)
elif opcode == 2:
printf("B = %s %% 8\n", combo_op_to_string(operand))
elif opcode == 3:
printf("if A != 0: jump to %d\n", operand)
elif opcode == 4:
printf("B ^= C\n")
elif opcode == 5:
printf("output %s %% 8\n", combo_op_to_string(operand))
elif opcode == 6:
printf("B = A / 2**%s\n", combo_op_to_string(operand))
elif opcode == 7:
printf("C = A / 2**%s\n", combo_op_to_string(operand))
else:
assert False


# TODO: xor operator
def xor(a: long, b: long) -> long:
assert a >= 0
assert b >= 0

result = 0L
power_of_two = 1L

while a != 0 or b != 0:
if a % 2 != b % 2:
result += power_of_two
a /= 2
b /= 2
power_of_two *= 2

return result


def do_combo_op(n: int, regs: long[3]) -> long:
assert 0 <= n and n <= 6
if n >= 4:
return regs[n - 4]
else:
return n


def run_code(code: int*, regs: long[3]) -> None:
output_started = False
ip = &code[0]

while *ip != -1:
opcode = *ip++
operand = *ip++

if opcode == 0: # adv = A DiVision
i = do_combo_op(operand, regs)
while i --> 0:
regs[0] /= 2 # TODO: add left shift operator to Jou
elif opcode == 1: # bxl = B bitwise Xor with Literal
regs[1] = xor(regs[1], operand)
elif opcode == 2: # bst = B Set value and Truncate to 3 bits
regs[1] = do_combo_op(operand, regs) % 8
elif opcode == 3: # jnz = Jump if NonZero
if regs[0] != 0:
ip = &code[operand]
elif opcode == 4: # bxc = B Xor C
regs[1] = xor(regs[1], regs[2])
elif opcode == 5: # out = OUTput value
if output_started:
putchar(',')
output_started = True
printf("%d", do_combo_op(operand, regs) % 8)
elif opcode == 6: # bdv = B DiVision
regs[1] = regs[0]
i = do_combo_op(operand, regs)
while i --> 0:
regs[1] /= 2
elif opcode == 7: # cdv = C DiVision
regs[2] = regs[0]
i = do_combo_op(operand, regs)
while i --> 0:
regs[2] /= 2

putchar('\n')


def main() -> int:
f = fopen("sampleinput1.txt", "r")
assert f != NULL

line: byte[1000]
code: int[100]
reg_a = 0L
reg_b = 0L
reg_c = 0L

while fgets(line, sizeof(line) as int, f) != NULL:
if starts_with(line, "Register A: "):
reg_a = atoll(&line[12])
elif starts_with(line, "Register B: "):
reg_b = atoll(&line[12])
elif starts_with(line, "Register C: "):
reg_c = atoll(&line[12])
elif starts_with(line, "Program: "):
p = &line[9]
code_len = 0
while True:
number = *p++
assert '0' <= number and number <= '7'
assert code_len < sizeof(code)/sizeof(code[0])
code[code_len++] = number - '0'
if *p++ != ',':
break
# Terminate with many halt instructions to reduce risk of overflow :)
for i = 0; i < 10; i++:
assert code_len < sizeof(code)/sizeof(code[0])
code[code_len++] = -1

fclose(f)

# Output: Registers: A=729 B=0 C=0
printf("Registers: A=%lld B=%lld C=%lld\n", reg_a, reg_b, reg_c)

# Output: 0 A /= 2**1
# Output: 2 output A % 8
# Output: 4 if A != 0: jump to 0
print_code(code)

# Output: 4,6,3,5,6,3,5,2,1,0
run_code(code, [reg_a, reg_b, reg_c])

return 0
177 changes: 177 additions & 0 deletions examples/aoc2024/day17/part2.jou
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import "stdlib/io.jou"
import "stdlib/mem.jou"
import "stdlib/str.jou"


# TODO: xor operator
def xor(a: long, b: long) -> long:
assert a >= 0
assert b >= 0

result = 0L
power_of_two = 1L

while a != 0 or b != 0:
if a % 2 != b % 2:
result += power_of_two
a /= 2
b /= 2
power_of_two *= 2

return result


def do_combo_op(n: int, regs: long[3]) -> long:
assert 0 <= n and n <= 6
if n >= 4:
return regs[n - 4]
else:
return n


# Return value is between 0 and 7 (3 bits), or -1 if the program doesn't output anything.
def get_first_output(code: int*, reg_a: long) -> int:
regs = [reg_a, 0L, 0L]
ip = &code[0]

while *ip != -1:
opcode = *ip++
operand = *ip++

if opcode == 0: # adv = A DiVision
i = do_combo_op(operand, regs)
while i --> 0:
regs[0] /= 2 # TODO: add left shift operator to Jou
elif opcode == 1: # bxl = B bitwise Xor with Literal
regs[1] = xor(regs[1], operand)
elif opcode == 2: # bst = B Set value and Truncate to 3 bits
regs[1] = do_combo_op(operand, regs) % 8
elif opcode == 3: # jnz = Jump if NonZero
if regs[0] != 0:
ip = &code[operand]
elif opcode == 4: # bxc = B Xor C
regs[1] = xor(regs[1], regs[2])
elif opcode == 5: # out = OUTput value
return (do_combo_op(operand, regs) % 8) as int
elif opcode == 6: # bdv = B DiVision
regs[1] = regs[0]
i = do_combo_op(operand, regs)
while i --> 0:
regs[1] /= 2
elif opcode == 7: # cdv = C DiVision
regs[2] = regs[0]
i = do_combo_op(operand, regs)
while i --> 0:
regs[2] /= 2

return -1 # no output


class List:
ptr: long*
len: int
alloc: int

def append(self, item: long) -> None:
if self->len == self->alloc:
if self->alloc == 0:
self->alloc = 4
else:
self->alloc *= 2
self->ptr = realloc(self->ptr, self->alloc * sizeof(self->ptr[0]))
assert self->ptr != NULL
self->ptr[self->len++] = item


def last_bits(n: long, how_many: int) -> long:
power_of_two = 1L
while how_many --> 0:
power_of_two *= 2
return xor(n, (n / power_of_two) * power_of_two)


# My input looked like this, as printed by code in part 1:
#
# Registers: A=... B=0 C=0
# 0 B = A % 8
# 2 B ^= 1
# 4 C = A / 2**B
# 6 B ^= 5
# 8 B ^= C
# 10 A /= 2**3
# 12 output B % 8
# 14 if A != 0: jump to 0
#
# The output depends only on the last 10 bits of A, because line 4 shifts away
# no more than 7 bits (7 = 111 binary), and the last 3 bits of C are used in
# the output. This means that we can do a simple depth-first search.
#
# This function finds all such x that:
# - when machine runs on x, it outputs the desired output
# - last 7 bits of x are as given (-1 means anything will do).
def find_matching_inputs(code: int*, desired_output: int*, desired_output_len: int, last7: int) -> List:
assert desired_output_len >= 0
results = List{}

if desired_output_len == 0:
# Ensure the code does nothing, imagining that the loop condition is at
# start of loop rather than end.
if last7 == -1 or last7 == 0:
# The input can actually be zero. It's not banned by the required
# last 7 bits.
results.append(0)
else:
for last10 = 0; last10 < 1024; last10++: # all 10 bit numbers
if (
(last7 == -1 or last_bits(last10, 7) == last7)
and get_first_output(code, last10) == desired_output[0]
):
# Recursively find more bits.
last7_before = last10 / 8 # shift right by 3
new_results = find_matching_inputs(code, &desired_output[1], desired_output_len - 1, last7_before)
for i = 0; i < new_results.len; i++:
x = new_results.ptr[i]
x *= 8 # shift left to make room
x = xor(x, last_bits(last10, 3))
assert last_bits(x, 10) == last10
results.append(x)
free(new_results.ptr)

return results


def main() -> int:
f = fopen("sampleinput2.txt", "r")
assert f != NULL

line: byte[1000]
code: int[100]
code_len = 0

while fgets(line, sizeof(line) as int, f) != NULL:
if starts_with(line, "Program: "):
p = &line[9]
while True:
number = *p++
assert '0' <= number and number <= '7'
assert code_len < sizeof(code)/sizeof(code[0])
code[code_len++] = number - '0'
if *p++ != ',':
break
# Terminate with many halt instructions to reduce risk of overflow :)
assert code_len + 10 <= sizeof(code)/sizeof(code[0])
for i = 0; i < 10; i++:
code[code_len + i] = -1

fclose(f)

results = find_matching_inputs(code, code, code_len, -1)

best = results.ptr[0]
for i = 0; i < results.len; i++:
if results.ptr[i] < best:
best = results.ptr[i]
printf("%lld\n", best) # Output: 117440

free(results.ptr)
return 0
5 changes: 5 additions & 0 deletions examples/aoc2024/day17/sampleinput1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Register A: 729
Register B: 0
Register C: 0

Program: 0,1,5,4,3,0
5 changes: 5 additions & 0 deletions examples/aoc2024/day17/sampleinput2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Register A: 2024
Register B: 0
Register C: 0

Program: 0,3,5,4,3,0
Loading
Loading