Skip to content

Commit

Permalink
Remove obsolete UDF code
Browse files Browse the repository at this point in the history
  • Loading branch information
rlamy committed Sep 17, 2024
1 parent 0abafcd commit 80f4fbe
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 139 deletions.
24 changes: 6 additions & 18 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@

from .schema import C, UDFParamSpec, normalize_param
from .session import Session
from .udf import UDFBase, UDFClassWrapper, UDFFactory, UDFType
from .udf import UDFBase

if TYPE_CHECKING:
from sqlalchemy.sql.elements import ClauseElement
Expand Down Expand Up @@ -364,7 +364,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback:

@frozen
class UDFStep(Step, ABC):
udf: UDFType
udf: UDFBase
catalog: "Catalog"
partition_by: Optional[PartitionByType] = None
parallel: Optional[int] = None
Expand Down Expand Up @@ -470,12 +470,6 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:

else:
# Otherwise process single-threaded (faster for smaller UDFs)
# Optionally instantiate the UDF instance if a class is provided.
if isinstance(self.udf, UDFFactory):
udf: UDFBase = self.udf()
else:
udf = self.udf

warehouse = self.catalog.warehouse

with contextlib.closing(
Expand All @@ -485,7 +479,7 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
processed_cb = get_processed_callback()
generated_cb = get_generated_callback(self.is_generator)
try:
udf_results = udf.run(
udf_results = self.udf.run(
udf_fields,
udf_inputs,
self.catalog,
Expand All @@ -498,7 +492,7 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
warehouse,
udf_table,
udf_results,
udf,
self.udf,
cb=generated_cb,
)
finally:
Expand Down Expand Up @@ -1471,7 +1465,7 @@ def chunk(self, index: int, total: int) -> "Self":
@detach
def add_signals(
self,
udf: UDFType,
udf: UDFBase,
parallel: Optional[int] = None,
workers: Union[bool, int] = False,
min_task_size: Optional[int] = None,
Expand All @@ -1492,9 +1486,6 @@ def add_signals(
at least that minimum number of rows to each distributed worker, mostly useful
if there are a very large number of small tasks to process.
"""
if isinstance(udf, UDFClassWrapper): # type: ignore[unreachable]
# This is a bare decorated class, "instantiate" it now.
udf = udf() # type: ignore[unreachable]
query = self.clone()
query.steps.append(
UDFSignal(
Expand All @@ -1518,16 +1509,13 @@ def subtract(self, dq: "DatasetQuery", on: Sequence[tuple[str, str]]) -> "Self":
@detach
def generate(
self,
udf: UDFType,
udf: UDFBase,
parallel: Optional[int] = None,
workers: Union[bool, int] = False,
min_task_size: Optional[int] = None,
partition_by: Optional[PartitionByType] = None,
cache: bool = False,
) -> "Self":
if isinstance(udf, UDFClassWrapper): # type: ignore[unreachable]
# This is a bare decorated class, "instantiate" it now.
udf = udf() # type: ignore[unreachable]
query = self.clone()
steps = query.steps
steps.append(
Expand Down
17 changes: 2 additions & 15 deletions src/datachain/query/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
put_into_queue,
unmarshal,
)
from datachain.query.udf import UDFBase, UDFFactory, UDFResult
from datachain.query.udf import UDFBase, UDFResult
from datachain.utils import batched_it

DEFAULT_BATCH_SIZE = 10000
Expand Down Expand Up @@ -156,8 +156,6 @@ def __init__(

@property
def batch_size(self):
if not self.udf:
self.udf = self.udf_factory()
if self._batch_size is None:
if hasattr(self.udf, "properties") and hasattr(
self.udf.properties, "batch"
Expand All @@ -181,18 +179,7 @@ def _create_worker(self) -> "UDFWorker":
self.catalog = Catalog(
id_generator, metastore, warehouse, **self.catalog_init_params
)
udf = loads(self.udf_data)
# isinstance cannot be used here, as cloudpickle packages the entire class
# definition, and so these two types are not considered exactly equal,
# even if they have the same import path.
if full_module_type_path(type(udf)) != full_module_type_path(UDFFactory):
self.udf = udf
else:
self.udf = None
self.udf_factory = udf
if not self.udf:
self.udf = self.udf_factory()

self.udf = loads(self.udf_data)
return UDFWorker(
self.catalog,
self.udf,
Expand Down
106 changes: 0 additions & 106 deletions src/datachain/query/udf.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import typing
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import dataclass
from functools import WRAPPER_ASSIGNMENTS
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
Union,
)

from fsspec.callbacks import DEFAULT_CALLBACK, Callback
Expand Down Expand Up @@ -128,105 +124,3 @@ def _process_results(
for row_id, signals in zip(row_ids, results)
if signals is not None # skip rows with no output
]


class UDFClassWrapper:
    """
    A wrapper for class-based (stateful) UDFs.

    Holds the decorated UDF class together with its properties; calling the
    wrapper does NOT instantiate the class — it returns a UDFFactory that
    captures the constructor arguments for later (lazy) instantiation.
    """

    def __init__(
        self,
        udf_class: type,
        properties: UDFProperties,
        method: Optional[str] = None,
    ):
        # Only a reference to the class is stored here; instantiation is
        # deferred until the factory produced by __call__ is invoked.
        self.udf_class = udf_class
        self.udf_method = method
        self.properties = properties
        self.output = properties.output

    def __call__(self, *args, **kwargs) -> "UDFFactory":
        # Capture the constructor arguments so the UDF class can be built
        # lazily, e.g. inside a parallel worker process.
        factory = UDFFactory(
            self.udf_class,
            args,
            kwargs,
            self.properties,
            self.udf_method,
        )
        return factory


class UDFWrapper(UDFBase):
    """A wrapper class for function UDFs to be used in custom signal generation."""

    def __init__(
        self,
        func: Callable,
        properties: UDFProperties,
    ):
        self.func = func
        super().__init__(properties)
        # Emulate functools.wraps for a class decorator: copy the wrapped
        # function's metadata (__name__, __doc__, ...) onto this instance.
        for attr in WRAPPER_ASSIGNMENTS:
            try:
                value = getattr(func, attr)
            except AttributeError:
                continue
            setattr(self, attr, value)

    def run_once(
        self,
        catalog: "Catalog",
        arg: "UDFInput",
        is_generator: bool = False,
        cache: bool = False,
        cb: Callback = DEFAULT_CALLBACK,
    ) -> Iterable[UDFResult]:
        """Run the wrapped function once for a single row or a row batch."""
        if isinstance(arg, RowDict):
            # Single row: bind parameters, then call positionally.
            params = self.bind_parameters(catalog, arg, cache=cache, cb=cb)
            outputs = self.func(*params)
            if not is_generator:
                # outputs is a generator already if is_generator=True;
                # otherwise wrap the single return value for processing.
                outputs = [outputs]
            return self._process_results([arg], outputs, is_generator)
        if isinstance(arg, UDFInputBatch):
            # Batch: bind parameters per row and pass the whole list at once.
            params = [
                self.bind_parameters(catalog, row, cache=cache, cb=cb)
                for row in arg.rows
            ]
            outputs = self.func(params)
            return self._process_results(arg.rows, outputs, is_generator)
        raise ValueError(f"Unexpected UDF argument: {arg}")

    def __repr__(self):
        # Emulate functools.wraps for a class decorator.
        return repr(self.func)


class UDFFactory:
    """
    A wrapper for late instantiation of UDF classes, primarily for use in
    parallelized execution.

    Stores the UDF class plus its constructor arguments; calling the factory
    instantiates the class and wraps the result in a UDFWrapper.
    """

    def __init__(
        self,
        udf_class: type,
        args,
        kwargs,
        properties: UDFProperties,
        method: Optional[str] = None,
    ):
        self.udf_class = udf_class
        self.udf_method = method
        self.args = args
        self.kwargs = kwargs
        self.properties = properties
        self.output = properties.output

    def __call__(self) -> UDFWrapper:
        # Build the UDF instance now, with the captured constructor args.
        instance = self.udf_class(*self.args, **self.kwargs)
        # If a specific method was requested, wrap that bound method;
        # otherwise the (callable) instance itself is the UDF function.
        if self.udf_method:
            target = getattr(instance, self.udf_method)
        else:
            target = instance
        return UDFWrapper(target, self.properties)


# Either a ready-to-run UDF instance or a factory that builds one on demand.
UDFType = Union[UDFBase, UDFFactory]

0 comments on commit 80f4fbe

Please sign in to comment.