Skip to content

Commit

Permalink
Day 5 fast
Browse files Browse the repository at this point in the history
  • Loading branch information
Ted Cassirer committed Jan 2, 2024
1 parent 21229c7 commit ac94a1d
Show file tree
Hide file tree
Showing 2 changed files with 352 additions and 285 deletions.
156 changes: 103 additions & 53 deletions aoc_cas/aoc2023/day5.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import dataclasses
from functools import reduce
from typing import Self

from aocd import get_data


@dataclasses.dataclass(frozen=True)
class Map:
class Range:
source: int
destination: int
length: int
Expand All @@ -16,77 +17,126 @@ def __lt__(self, other: Self) -> bool:
@classmethod
def from_line(cls, line: str) -> Self:
destination, source, length = line.split(" ")
return Map(destination=int(destination), source=int(source), length=int(length))
return Range(destination=int(destination), source=int(source), length=int(length))

def overlaps(self, other: Self) -> bool:
if other.source <= self.destination < other.source + other.length:
return True
if other.source <= self.destination + self.length < other.source + other.length:
return True
return False

def cut(self, length: int) -> tuple[Self, Self]:
assert length < self.length
left = Range(source=self.source, destination=self.destination, length=length)
right = Range(
source=self.source + length,
destination=self.destination + length,
length=self.length - length,
)
return left, right


@dataclasses.dataclass(frozen=True)
class Transformation:
name: str
ranges: list[Map]
class RangeSet:
ranges: list[Range]

@classmethod
def from_transformation_data(cls, group: str) -> Self:
name, *ranges_str = group.splitlines()
ranges: list[Map] = sorted(map(Map.from_line, ranges_str))
return Transformation(name=name, ranges=ranges)

def map(self, input: int) -> int:
if input < self.ranges[0].source:
return input

for range in self.ranges:
if range.source > input:
break
if range.source + range.length > input:
d = input - range.source
if d <= range.length:
return range.destination + d
def from_data(cls, group: str) -> Self:
_, *ranges_str = group.splitlines()
ranges: list[Range] = sorted(map(Range.from_line, ranges_str))
return RangeSet(ranges=ranges)

def add(self, other: Self) -> Self:
ranges = sorted(self.ranges, key=lambda r: r.destination)
out: set[Range] = set()
i1 = 0
r1 = ranges[i1]
for r2 in other.ranges:
while r1.destination + r1.length <= r2.source:
out.add(r1)
i1 += 1
if i1 == len(ranges):
return RangeSet(ranges=sorted(out))
r1 = ranges[i1]

while r1.overlaps(r2):
if r1.destination < r2.source:
# .|----
# ..|---
r1_offset = r2.source - r1.destination
head = Range(
source=r1.source,
destination=r1.destination,
length=r1_offset,
)
out.add(head)
r1 = Range(
source=r1.source + r1_offset,
destination=r1.destination + r1_offset,
length=r1.length - r1_offset,
)
else:
break
return input


def parse(data: str) -> tuple[list[int], list[Transformation]]:
# ..|---
# .|----
r2_offset = r1.destination - r2.source
r2 = Range(
source=r2.source + r2_offset,
destination=r2.destination + r2_offset,
length=r2.length - r2_offset,
)

if r1.destination + r1.length <= r2.source + r2.length:
# .|---|...
# .|----|.
overlap = Range(source=r1.source, destination=r2.destination, length=r1.length)
out.add(overlap)
i1 += 1
if i1 == len(ranges):
return RangeSet(ranges=sorted(out))
r1 = ranges[i1]
else:
# .|----|.
# .|---|...
overlap = Range(source=r1.source, destination=r2.destination, length=r2.length)
r1 = Range(
source=r1.source + overlap.length,
destination=r1.destination + overlap.length,
length=r1.length - overlap.length,
)
out.add(overlap)
out.add(r1)
out.update(ranges[i1 + 1 :])
return RangeSet(ranges=sorted(out))


def parse(data: str) -> tuple[list[int], list[RangeSet]]:
seeds_str, *groups = data.split("\n\n")
seeds = [int(s) for s in seeds_str.split(" ")[1:]]
transformations = list(map(Transformation.from_transformation_data, groups))
return seeds, transformations
range_sets = list(map(RangeSet.from_data, groups))
return seeds, range_sets


def part_a(data: str) -> int:
seeds, transformations = parse(data)

lowest = 1 << 63
for x in seeds:
for transformation in transformations:
x = transformation.map(x)
lowest = min(x, lowest)
return lowest
seeds, range_sets = parse(data)
ranges = [Range(source=s, destination=s, length=1) for s in seeds]
initial_rs = RangeSet(ranges=ranges)
full_transformation = reduce(RangeSet.add, range_sets, initial_rs)
return min(r.destination for r in full_transformation.ranges)


def part_b(data: str) -> int:
seeds, transformations = parse(data)

lowest = 1 << 63
seeds_to_check = sum(seeds[1::2])
print("Seeds to check:", seeds_to_check)

checked = 0
seeds, range_sets = parse(data)
ranges = []
for a, b in zip(seeds[::2], seeds[1::2]):
for x in range(a, a+b):
checked += 1
if checked % 1000000 == 0:
print(checked, round(checked / seeds_to_check, 5))
for transformation in transformations:
x = transformation.map(x)
lowest = min(x, lowest)
return lowest
ranges.append(Range(source=a, destination=a, length=b))
initial_rs = RangeSet(ranges=ranges)
full_transformation = reduce(RangeSet.add, range_sets, initial_rs)
return min(r.destination for r in full_transformation.ranges)


if __name__ == "__main__":
from aoc_cas.util import solve_with_examples

solve_with_examples(year=2023, day=5)
data = get_data(year=2023, day=5)
print(part_a(data))
print(part_b(data))
Loading

0 comments on commit ac94a1d

Please sign in to comment.