Skip to content

Commit

Permalink
Extract GPUEngine config options at translation time (#17339)
Browse files Browse the repository at this point in the history
Follow-up to #16944.
That PR added `config: GPUEngine` to the arguments of every `IR.do_evaluate` function. To simplify future multi-GPU development, this PR instead extracts the necessary configuration options at `IR` translation time.

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - https://github.com/brandon-b-miller
  - Lawrence Mitchell (https://github.com/wence-)

URL: #17339
  • Loading branch information
rjzamora authored Nov 20, 2024
1 parent 2e88835 commit f550ccc
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 61 deletions.
6 changes: 2 additions & 4 deletions python/cudf_polars/cudf_polars/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ def set_device(device: int | None) -> Generator[int, None, None]:

def _callback(
ir: IR,
config: GPUEngine,
with_columns: list[str] | None,
pyarrow_predicate: str | None,
n_rows: int | None,
Expand All @@ -146,7 +145,7 @@ def _callback(
set_device(device),
set_memory_resource(memory_resource),
):
return ir.evaluate(cache={}, config=config).to_polars()
return ir.evaluate(cache={}).to_polars()


def validate_config_options(config: dict) -> None:
Expand Down Expand Up @@ -201,7 +200,7 @@ def execute_with_cudf(nt: NodeTraverser, *, config: GPUEngine) -> None:
validate_config_options(config.config)

with nvtx.annotate(message="ConvertIR", domain="cudf_polars"):
translator = Translator(nt)
translator = Translator(nt, config)
ir = translator.translate_ir()
ir_translation_errors = translator.errors
if len(ir_translation_errors):
Expand All @@ -225,7 +224,6 @@ def execute_with_cudf(nt: NodeTraverser, *, config: GPUEngine) -> None:
partial(
_callback,
ir,
config,
device=device,
memory_resource=memory_resource,
)
Expand Down
62 changes: 21 additions & 41 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
from collections.abc import Callable, Hashable, MutableMapping, Sequence
from typing import Literal

from polars import GPUEngine

from cudf_polars.typing import Schema


Expand Down Expand Up @@ -182,9 +180,7 @@ def get_hashable(self) -> Hashable:
translation phase should fail earlier.
"""

def evaluate(
self, *, cache: MutableMapping[int, DataFrame], config: GPUEngine
) -> DataFrame:
def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
"""
Evaluate the node (recursively) and return a dataframe.
Expand All @@ -193,8 +189,6 @@ def evaluate(
cache
Mapping from cached node ids to constructed DataFrames.
Used to implement evaluation of the `Cache` node.
config
GPU engine configuration.
Notes
-----
Expand All @@ -214,9 +208,8 @@ def evaluate(
translation phase should fail earlier.
"""
return self.do_evaluate(
config,
*self._non_child_args,
*(child.evaluate(cache=cache, config=config) for child in self.children),
*(child.evaluate(cache=cache) for child in self.children),
)


Expand Down Expand Up @@ -263,6 +256,7 @@ class Scan(IR):
"typ",
"reader_options",
"cloud_options",
"config_options",
"paths",
"with_columns",
"skip_rows",
Expand All @@ -275,6 +269,7 @@ class Scan(IR):
"typ",
"reader_options",
"cloud_options",
"config_options",
"paths",
"with_columns",
"skip_rows",
Expand All @@ -288,6 +283,8 @@ class Scan(IR):
"""Reader-specific options, as dictionary."""
cloud_options: dict[str, Any] | None
"""Cloud-related authentication options, currently ignored."""
config_options: dict[str, Any]
"""GPU-specific configuration options"""
paths: list[str]
"""List of paths to read from."""
with_columns: list[str] | None
Expand All @@ -310,6 +307,7 @@ def __init__(
typ: str,
reader_options: dict[str, Any],
cloud_options: dict[str, Any] | None,
config_options: dict[str, Any],
paths: list[str],
with_columns: list[str] | None,
skip_rows: int,
Expand All @@ -321,6 +319,7 @@ def __init__(
self.typ = typ
self.reader_options = reader_options
self.cloud_options = cloud_options
self.config_options = config_options
self.paths = paths
self.with_columns = with_columns
self.skip_rows = skip_rows
Expand All @@ -331,6 +330,7 @@ def __init__(
schema,
typ,
reader_options,
config_options,
paths,
with_columns,
skip_rows,
Expand Down Expand Up @@ -412,6 +412,7 @@ def get_hashable(self) -> Hashable:
self.typ,
json.dumps(self.reader_options),
json.dumps(self.cloud_options),
json.dumps(self.config_options),
tuple(self.paths),
tuple(self.with_columns) if self.with_columns is not None else None,
self.skip_rows,
Expand All @@ -423,10 +424,10 @@ def get_hashable(self) -> Hashable:
@classmethod
def do_evaluate(
cls,
config: GPUEngine,
schema: Schema,
typ: str,
reader_options: dict[str, Any],
config_options: dict[str, Any],
paths: list[str],
with_columns: list[str] | None,
skip_rows: int,
Expand Down Expand Up @@ -509,7 +510,7 @@ def do_evaluate(
colnames[0],
)
elif typ == "parquet":
parquet_options = config.config.get("parquet_options", {})
parquet_options = config_options.get("parquet_options", {})
if parquet_options.get("chunked", True):
reader = plc.io.parquet.ChunkedParquetReader(
plc.io.SourceInfo(paths),
Expand Down Expand Up @@ -657,26 +658,22 @@ def __init__(self, schema: Schema, key: int, value: IR):

@classmethod
def do_evaluate(
cls, config: GPUEngine, key: int, df: DataFrame
cls, key: int, df: DataFrame
) -> DataFrame: # pragma: no cover; basic evaluation never calls this
"""Evaluate and return a dataframe."""
# Our value has already been computed for us, so let's just
# return it.
return df

def evaluate(
self, *, cache: MutableMapping[int, DataFrame], config: GPUEngine
) -> DataFrame:
def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
"""Evaluate and return a dataframe."""
# We must override the recursion scheme because we don't want
# to recurse if we're in the cache.
try:
return cache[self.key]
except KeyError:
(value,) = self.children
return cache.setdefault(
self.key, value.evaluate(cache=cache, config=config)
)
return cache.setdefault(self.key, value.evaluate(cache=cache))


class DataFrameScan(IR):
Expand Down Expand Up @@ -722,7 +719,6 @@ def get_hashable(self) -> Hashable:
@classmethod
def do_evaluate(
cls,
config: GPUEngine,
schema: Schema,
df: Any,
projection: tuple[str, ...] | None,
Expand Down Expand Up @@ -770,7 +766,6 @@ def __init__(
@classmethod
def do_evaluate(
cls,
config: GPUEngine,
exprs: tuple[expr.NamedExpr, ...],
should_broadcast: bool, # noqa: FBT001
df: DataFrame,
Expand Down Expand Up @@ -806,7 +801,6 @@ def __init__(
@classmethod
def do_evaluate(
cls,
config: GPUEngine,
exprs: tuple[expr.NamedExpr, ...],
df: DataFrame,
) -> DataFrame: # pragma: no cover; not exposed by polars yet
Expand Down Expand Up @@ -899,7 +893,6 @@ def check_agg(agg: expr.Expr) -> int:
@classmethod
def do_evaluate(
cls,
config: GPUEngine,
keys_in: Sequence[expr.NamedExpr],
agg_requests: Sequence[expr.NamedExpr],
maintain_order: bool, # noqa: FBT001
Expand Down Expand Up @@ -1021,7 +1014,6 @@ def __init__(
@classmethod
def do_evaluate(
cls,
config: GPUEngine,
predicate: plc.expressions.Expression,
zlice: tuple[int, int] | None,
suffix: str,
Expand Down Expand Up @@ -1194,7 +1186,6 @@ def _reorder_maps(
@classmethod
def do_evaluate(
cls,
config: GPUEngine,
left_on_exprs: Sequence[expr.NamedExpr],
right_on_exprs: Sequence[expr.NamedExpr],
options: tuple[
Expand Down Expand Up @@ -1318,7 +1309,6 @@ def __init__(
@classmethod
def do_evaluate(
cls,
config: GPUEngine,
exprs: Sequence[expr.NamedExpr],
should_broadcast: bool, # noqa: FBT001
df: DataFrame,
Expand Down Expand Up @@ -1381,7 +1371,6 @@ def __init__(
@classmethod
def do_evaluate(
cls,
config: GPUEngine,
keep: plc.stream_compaction.DuplicateKeepOption,
subset: frozenset[str] | None,
zlice: tuple[int, int] | None,
Expand Down Expand Up @@ -1471,7 +1460,6 @@ def __init__(
@classmethod
def do_evaluate(
cls,
config: GPUEngine,
by: Sequence[expr.NamedExpr],
order: Sequence[plc.types.Order],
null_order: Sequence[plc.types.NullOrder],
Expand Down Expand Up @@ -1527,9 +1515,7 @@ def __init__(self, schema: Schema, offset: int, length: int, df: IR):
self.children = (df,)

@classmethod
def do_evaluate(
cls, config: GPUEngine, offset: int, length: int, df: DataFrame
) -> DataFrame:
def do_evaluate(cls, offset: int, length: int, df: DataFrame) -> DataFrame:
"""Evaluate and return a dataframe."""
return df.slice((offset, length))

Expand All @@ -1549,9 +1535,7 @@ def __init__(self, schema: Schema, mask: expr.NamedExpr, df: IR):
self.children = (df,)

@classmethod
def do_evaluate(
cls, config: GPUEngine, mask_expr: expr.NamedExpr, df: DataFrame
) -> DataFrame:
def do_evaluate(cls, mask_expr: expr.NamedExpr, df: DataFrame) -> DataFrame:
"""Evaluate and return a dataframe."""
(mask,) = broadcast(mask_expr.evaluate(df), target_length=df.num_rows)
return df.filter(mask)
Expand All @@ -1569,7 +1553,7 @@ def __init__(self, schema: Schema, df: IR):
self.children = (df,)

@classmethod
def do_evaluate(cls, config: GPUEngine, schema: Schema, df: DataFrame) -> DataFrame:
def do_evaluate(cls, schema: Schema, df: DataFrame) -> DataFrame:
"""Evaluate and return a dataframe."""
# This can reorder things.
columns = broadcast(
Expand Down Expand Up @@ -1645,9 +1629,7 @@ def __init__(self, schema: Schema, name: str, options: Any, df: IR):
self._non_child_args = (name, self.options)

@classmethod
def do_evaluate(
cls, config: GPUEngine, name: str, options: Any, df: DataFrame
) -> DataFrame:
def do_evaluate(cls, name: str, options: Any, df: DataFrame) -> DataFrame:
"""Evaluate and return a dataframe."""
if name == "rechunk":
# No-op in our data model
Expand Down Expand Up @@ -1726,9 +1708,7 @@ def __init__(self, schema: Schema, zlice: tuple[int, int] | None, *children: IR)
raise NotImplementedError("Schema mismatch")

@classmethod
def do_evaluate(
cls, config: GPUEngine, zlice: tuple[int, int] | None, *dfs: DataFrame
) -> DataFrame:
def do_evaluate(cls, zlice: tuple[int, int] | None, *dfs: DataFrame) -> DataFrame:
"""Evaluate and return a dataframe."""
# TODO: only evaluate what we need if we have a slice?
return DataFrame.from_table(
Expand Down Expand Up @@ -1777,7 +1757,7 @@ def _extend_with_nulls(table: plc.Table, *, nrows: int) -> plc.Table:
)

@classmethod
def do_evaluate(cls, config: GPUEngine, *dfs: DataFrame) -> DataFrame:
def do_evaluate(cls, *dfs: DataFrame) -> DataFrame:
"""Evaluate and return a dataframe."""
max_rows = max(df.num_rows for df in dfs)
# Horizontal concatenation extends shorter tables with nulls
Expand Down
8 changes: 7 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from cudf_polars.utils import dtypes, sorting

if TYPE_CHECKING:
from polars import GPUEngine

from cudf_polars.typing import NodeTraverser

__all__ = ["Translator", "translate_named_expr"]
Expand All @@ -39,10 +41,13 @@ class Translator:
----------
visitor
Polars NodeTraverser object
config
GPU engine configuration.
"""

def __init__(self, visitor: NodeTraverser):
def __init__(self, visitor: NodeTraverser, config: GPUEngine):
self.visitor = visitor
self.config = config
self.errors: list[Exception] = []

def translate_ir(self, *, n: int | None = None) -> ir.IR:
Expand Down Expand Up @@ -228,6 +233,7 @@ def _(
typ,
reader_options,
cloud_options,
translator.config.config.copy(),
node.paths,
with_columns,
skip_rows,
Expand Down
2 changes: 1 addition & 1 deletion python/cudf_polars/cudf_polars/testing/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def assert_ir_translation_raises(q: pl.LazyFrame, *exceptions: type[Exception])
AssertionError
If the specified exceptions were not raised.
"""
translator = Translator(q._ldf.visit())
translator = Translator(q._ldf.visit(), GPUEngine())
translator.translate_ir()
if errors := translator.errors:
for err in errors:
Expand Down
3 changes: 2 additions & 1 deletion python/cudf_polars/docs/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -459,11 +459,12 @@ and convert back to polars:

```python
from cudf_polars.dsl.translate import Translator
import polars as pl

q = ...

# Convert to our IR
ir = Translator(q._ldf.visit()).translate_ir()
ir = Translator(q._ldf.visit(), pl.GPUEngine()).translate_ir()

# DataFrame living on the device
result = ir.evaluate(cache={})
Expand Down
4 changes: 2 additions & 2 deletions python/cudf_polars/tests/dsl/test_to_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ def df():
)
def test_compute_column(expr, df):
q = df.select(expr)
ir = Translator(q._ldf.visit()).translate_ir()
ir = Translator(q._ldf.visit(), pl.GPUEngine()).translate_ir()

assert isinstance(ir, ir_nodes.Select)
table = ir.children[0].evaluate(cache={}, config=pl.GPUEngine())
table = ir.children[0].evaluate(cache={})
name_to_index = {c.name: i for i, c in enumerate(table.columns)}

def compute_column(e):
Expand Down
Loading

0 comments on commit f550ccc

Please sign in to comment.