diff --git a/python/cudf_polars/cudf_polars/experimental/parallel.py b/python/cudf_polars/cudf_polars/experimental/parallel.py index a783f91fc73..8d08fbd0166 100644 --- a/python/cudf_polars/cudf_polars/experimental/parallel.py +++ b/python/cudf_polars/cudf_polars/experimental/parallel.py @@ -60,8 +60,7 @@ def _lower_ir_single( partition_info = reduce(operator.or_, _partition_info) # Check that child partitioning is supported - count = max(partition_info[c].count for c in children) - if count > 1: + if any(partition_info[c].count > 1 for c in children): raise NotImplementedError( f"Class {type(ir)} does not support multiple partitions." ) # pragma: no cover @@ -276,7 +275,7 @@ def _( return rec.state["default_mapper"](ir) # pragma: no cover # Lower children - children, _partition_info = zip(*(rec(c) for c in ir.children), strict=False) + children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True) partition_info = reduce(operator.or_, _partition_info) # Partition count is the sum of all child partitions @@ -294,11 +293,12 @@ def _( ) -> MutableMapping[Any, Any]: part_out = 0 key_name = get_key_name(ir) - graph: MutableMapping[Any, Any] = {} - for child in ir.children: - for i in range(partition_info[child].count): - graph[(key_name, part_out)] = (get_key_name(child), i) - part_out += 1 + partition = itertools.count() + graph = { + (key_name, next(partition)): (get_key_name(child), i) + for i in range(partition_info[child].count) + for child in ir.children + } return graph