Skip to content

Commit

Permalink
[Data] Support class constructor args for filter() (#50245)
Browse files Browse the repository at this point in the history
<!-- Thank you for your contribution! Please review
https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before
opening a pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have the access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed?

<!-- Please give a short summary of the change and the problem this
solves. -->

Received a user request to support class constructor args for
`dataset.filter()`, similar to `flat_map`, `map`, and `map_batches`.

## Related issue number

<!-- For example: "Closes #1234" -->

## Checks

- [ ] I've signed off every commit (by using the -s flag, i.e., `git
commit -s`) in this PR.
- [ ] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for
https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I
added a
method in Tune, I've added it in `doc/source/tune/api/` under the
           corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a
few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [ ] Unit tests
   - [ ] Release tests
   - [ ] This PR is not tested :(

Signed-off-by: liuxsh9 <[email protected]>
  • Loading branch information
liuxsh9 authored Feb 11, 2025
1 parent 5632a4b commit 499838a
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 0 deletions.
8 changes: 8 additions & 0 deletions python/ray/data/_internal/logical/operators/map_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,10 @@ def __init__(
self,
input_op: LogicalOperator,
fn: Optional[UserDefinedFunction] = None,
fn_args: Optional[Iterable[Any]] = None,
fn_kwargs: Optional[Dict[str, Any]] = None,
fn_constructor_args: Optional[Iterable[Any]] = None,
fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
filter_expr: Optional["pa.dataset.Expression"] = None,
compute: Optional[ComputeStrategy] = None,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
Expand All @@ -236,6 +240,10 @@ def __init__(
"Filter",
input_op,
fn=fn,
fn_args=fn_args,
fn_kwargs=fn_kwargs,
fn_constructor_args=fn_constructor_args,
fn_constructor_kwargs=fn_constructor_kwargs,
compute=compute,
ray_remote_args_fn=ray_remote_args_fn,
ray_remote_args=ray_remote_args,
Expand Down
28 changes: 28 additions & 0 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1242,6 +1242,10 @@ def filter(
expr: Optional[str] = None,
*,
compute: Union[str, ComputeStrategy] = None,
fn_args: Optional[Iterable[Any]] = None,
fn_kwargs: Optional[Dict[str, Any]] = None,
fn_constructor_args: Optional[Iterable[Any]] = None,
fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
concurrency: Optional[Union[int, Tuple[int, int]]] = None,
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
**ray_remote_args,
Expand Down Expand Up @@ -1272,6 +1276,16 @@ def filter(
that can be instantiated to create such a callable.
expr: An expression string needs to be a valid Python expression that
will be converted to ``pyarrow.dataset.Expression`` type.
fn_args: Positional arguments to pass to ``fn`` after the first argument.
These arguments are top-level arguments to the underlying Ray task.
fn_kwargs: Keyword arguments to pass to ``fn``. These arguments are
top-level arguments to the underlying Ray task.
fn_constructor_args: Positional arguments to pass to ``fn``'s constructor.
You can only provide this if ``fn`` is a callable class. These arguments
are top-level arguments in the underlying Ray actor construction task.
fn_constructor_kwargs: Keyword arguments to pass to ``fn``'s constructor.
This can only be provided if ``fn`` is a callable class. These arguments
are top-level arguments in the underlying Ray actor construction task.
compute: This argument is deprecated. Use ``concurrency`` argument.
concurrency: The number of Ray workers to use concurrently. For a
fixed-sized worker pool of size ``n``, specify ``concurrency=n``.
Expand All @@ -1292,6 +1306,15 @@ def filter(
if not ((fn is None) ^ (expr is None)):
raise ValueError("Exactly one of 'fn' or 'expr' must be provided.")
elif expr is not None:
if (
fn_args is not None
or fn_kwargs is not None
or fn_constructor_args is not None
or fn_constructor_kwargs is not None
):
raise ValueError(
"when 'expr' is used, 'fn_args/fn_kwargs' or 'fn_constructor_args/fn_constructor_kwargs' can not be used."
)
from ray.data._internal.compute import TaskPoolStrategy
from ray.data._internal.planner.plan_expression.expression_evaluator import ( # noqa: E501
ExpressionEvaluator,
Expand All @@ -1311,6 +1334,7 @@ def filter(
if callable(fn):
compute = get_compute_strategy(
fn=fn,
fn_constructor_args=fn_constructor_args,
compute=compute,
concurrency=concurrency,
)
Expand All @@ -1324,6 +1348,10 @@ def filter(
op = Filter(
input_op=self._logical_plan.dag,
fn=fn,
fn_args=fn_args,
fn_kwargs=fn_kwargs,
fn_constructor_args=fn_constructor_args,
fn_constructor_kwargs=fn_constructor_kwargs,
filter_expr=resolved_expr,
compute=compute,
ray_remote_args_fn=ray_remote_args_fn,
Expand Down
21 changes: 21 additions & 0 deletions python/ray/data/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,27 @@ def __call__(self, x, arg, kwarg):
).take()
assert sorted(extract_values("id", result)) == list(range(10)), result

class StatefulFilterFnWithArgs:
    """Callable-class filter UDF that validates the arguments it receives.

    Both the constructor and ``__call__`` assert that ``arg == 1`` and
    ``kwarg == 2``; ``__call__`` then keeps every row by returning ``True``.
    """

    def __init__(self, arg, kwarg):
        """Check the constructor arguments; raises AssertionError on mismatch."""
        assert (arg, kwarg) == (1, 2)

    def __call__(self, x, arg, kwarg):
        """Check the per-call arguments and keep the row ``x`` unconditionally."""
        assert (arg, kwarg) == (1, 2)
        return True

# filter with args & kwargs
result = ds.filter(
StatefulFilterFnWithArgs,
concurrency=1,
fn_args=(1,),
fn_kwargs={"kwarg": 2},
fn_constructor_args=(1,),
fn_constructor_kwargs={"kwarg": 2},
).take()
assert sorted(extract_values("id", result)) == list(range(10)), result


def test_concurrent_callable_classes(shutdown_only):
"""Test that concurrent actor pool runs user UDF in a separate thread."""
Expand Down

0 comments on commit 499838a

Please sign in to comment.