Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Lawrence Mitchell <[email protected]>
  • Loading branch information
rjzamora and wence- authored Nov 27, 2024
1 parent 0b63126 commit c6eb1b8
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions python/cudf_polars/cudf_polars/experimental/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def _lower_ir_single(
partition_info = reduce(operator.or_, _partition_info)

# Check that child partitioning is supported
count = max(partition_info[c].count for c in children)
if count > 1:
if any(partition_info[c].count > 1 for c in children):
raise NotImplementedError(
f"Class {type(ir)} does not support multiple partitions."
) # pragma: no cover
Expand Down Expand Up @@ -276,7 +275,7 @@ def _(
return rec.state["default_mapper"](ir) # pragma: no cover

# Lower children
children, _partition_info = zip(*(rec(c) for c in ir.children), strict=False)
children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True)
partition_info = reduce(operator.or_, _partition_info)

# Partition count is the sum of all child partitions
Expand All @@ -294,11 +293,12 @@ def _(
) -> MutableMapping[Any, Any]:
part_out = 0
key_name = get_key_name(ir)
graph: MutableMapping[Any, Any] = {}
for child in ir.children:
for i in range(partition_info[child].count):
graph[(key_name, part_out)] = (get_key_name(child), i)
part_out += 1
partition = itertools.count()
graph = {
(key_name, next(partition)): (get_key_name(child), i)
for i in range(partition_info[child].count)
for child in ir.children
}
return graph


Expand Down

0 comments on commit c6eb1b8

Please sign in to comment.