diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index b4406fa95..4fa15cf2d 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -53,7 +53,7 @@ from .schema import C, UDFParamSpec, normalize_param from .session import Session -from .udf import UDFBase, UDFClassWrapper, UDFFactory, UDFType +from .udf import UDFBase if TYPE_CHECKING: from sqlalchemy.sql.elements import ClauseElement @@ -364,7 +364,7 @@ def get_generated_callback(is_generator: bool = False) -> Callback: @frozen class UDFStep(Step, ABC): - udf: UDFType + udf: UDFBase catalog: "Catalog" partition_by: Optional[PartitionByType] = None parallel: Optional[int] = None @@ -470,12 +470,6 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: else: # Otherwise process single-threaded (faster for smaller UDFs) - # Optionally instantiate the UDF instance if a class is provided. - if isinstance(self.udf, UDFFactory): - udf: UDFBase = self.udf() - else: - udf = self.udf - warehouse = self.catalog.warehouse with contextlib.closing( @@ -485,7 +479,7 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: processed_cb = get_processed_callback() generated_cb = get_generated_callback(self.is_generator) try: - udf_results = udf.run( + udf_results = self.udf.run( udf_fields, udf_inputs, self.catalog, @@ -498,7 +492,7 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None: warehouse, udf_table, udf_results, - udf, + self.udf, cb=generated_cb, ) finally: @@ -1471,7 +1465,7 @@ def chunk(self, index: int, total: int) -> "Self": @detach def add_signals( self, - udf: UDFType, + udf: UDFBase, parallel: Optional[int] = None, workers: Union[bool, int] = False, min_task_size: Optional[int] = None, @@ -1492,9 +1486,6 @@ def add_signals( at least that minimum number of rows to each distributed worker, mostly useful if there are a very large number of small tasks to process. """ - if isinstance(udf, UDFClassWrapper): # type: ignore[unreachable] - # This is a bare decorated class, "instantiate" it now. - udf = udf() # type: ignore[unreachable] query = self.clone() query.steps.append( UDFSignal( @@ -1518,16 +1509,13 @@ def subtract(self, dq: "DatasetQuery", on: Sequence[tuple[str, str]]) -> "Self": @detach def generate( self, - udf: UDFType, + udf: UDFBase, parallel: Optional[int] = None, workers: Union[bool, int] = False, min_task_size: Optional[int] = None, partition_by: Optional[PartitionByType] = None, cache: bool = False, ) -> "Self": - if isinstance(udf, UDFClassWrapper): # type: ignore[unreachable] - # This is a bare decorated class, "instantiate" it now. - udf = udf() # type: ignore[unreachable] query = self.clone() steps = query.steps steps.append( diff --git a/src/datachain/query/dispatch.py b/src/datachain/query/dispatch.py index c4a053b62..458aa4e41 100644 --- a/src/datachain/query/dispatch.py +++ b/src/datachain/query/dispatch.py @@ -27,7 +27,7 @@ put_into_queue, unmarshal, ) -from datachain.query.udf import UDFBase, UDFFactory, UDFResult +from datachain.query.udf import UDFBase, UDFResult from datachain.utils import batched_it DEFAULT_BATCH_SIZE = 10000 @@ -156,8 +156,6 @@ def __init__( @property def batch_size(self): - if not self.udf: - self.udf = self.udf_factory() if self._batch_size is None: if hasattr(self.udf, "properties") and hasattr( self.udf.properties, "batch" @@ -181,18 +179,7 @@ def _create_worker(self) -> "UDFWorker": self.catalog = Catalog( id_generator, metastore, warehouse, **self.catalog_init_params ) - udf = loads(self.udf_data) - # isinstance cannot be used here, as cloudpickle packages the entire class - # definition, and so these two types are not considered exactly equal, - # even if they have the same import path. - if full_module_type_path(type(udf)) != full_module_type_path(UDFFactory): - self.udf = udf - else: - self.udf = None - self.udf_factory = udf - if not self.udf: - self.udf = self.udf_factory() - + self.udf = loads(self.udf_data) return UDFWorker( self.catalog, self.udf, diff --git a/src/datachain/query/udf.py b/src/datachain/query/udf.py index d92d8858f..cf9bf3d51 100644 --- a/src/datachain/query/udf.py +++ b/src/datachain/query/udf.py @@ -1,13 +1,9 @@ import typing from collections.abc import Iterable, Iterator, Sequence from dataclasses import dataclass -from functools import WRAPPER_ASSIGNMENTS from typing import ( TYPE_CHECKING, Any, - Callable, - Optional, - Union, ) from fsspec.callbacks import DEFAULT_CALLBACK, Callback @@ -128,105 +124,3 @@ def _process_results( for row_id, signals in zip(row_ids, results) if signals is not None # skip rows with no output ] - - -class UDFClassWrapper: - """ - A wrapper for class-based (stateful) UDFs. - """ - - def __init__( - self, - udf_class: type, - properties: UDFProperties, - method: Optional[str] = None, - ): - self.udf_class = udf_class - self.udf_method = method - self.properties = properties - self.output = properties.output - - def __call__(self, *args, **kwargs) -> "UDFFactory": - return UDFFactory( - self.udf_class, - args, - kwargs, - self.properties, - self.udf_method, - ) - - -class UDFWrapper(UDFBase): - """A wrapper class for function UDFs to be used in custom signal generation.""" - - def __init__( - self, - func: Callable, - properties: UDFProperties, - ): - self.func = func - super().__init__(properties) - # This emulates the behavior of functools.wraps for a class decorator - for attr in WRAPPER_ASSIGNMENTS: - if hasattr(func, attr): - setattr(self, attr, getattr(func, attr)) - - def run_once( - self, - catalog: "Catalog", - arg: "UDFInput", - is_generator: bool = False, - cache: bool = False, - cb: Callback = DEFAULT_CALLBACK, - ) -> Iterable[UDFResult]: - if isinstance(arg, UDFInputBatch): - udf_inputs = [ - self.bind_parameters(catalog, row, cache=cache, cb=cb) - for row in arg.rows - ] - udf_outputs = self.func(udf_inputs) - return self._process_results(arg.rows, udf_outputs, is_generator) - if isinstance(arg, RowDict): - udf_inputs = self.bind_parameters(catalog, arg, cache=cache, cb=cb) - udf_outputs = self.func(*udf_inputs) - if not is_generator: - # udf_outputs is generator already if is_generator=True - udf_outputs = [udf_outputs] - return self._process_results([arg], udf_outputs, is_generator) - raise ValueError(f"Unexpected UDF argument: {arg}") - - # This emulates the behavior of functools.wraps for a class decorator - def __repr__(self): - return repr(self.func) - - -class UDFFactory: - """ - A wrapper for late instantiation of UDF classes, primarily for use in parallelized - execution. - """ - - def __init__( - self, - udf_class: type, - args, - kwargs, - properties: UDFProperties, - method: Optional[str] = None, - ): - self.udf_class = udf_class - self.udf_method = method - self.args = args - self.kwargs = kwargs - self.properties = properties - self.output = properties.output - - def __call__(self) -> UDFWrapper: - udf_func = self.udf_class(*self.args, **self.kwargs) - if self.udf_method: - udf_func = getattr(udf_func, self.udf_method) - - return UDFWrapper(udf_func, self.properties) - - -UDFType = Union[UDFBase, UDFFactory]