diff --git a/.github/workflows/awkward-main.yml b/.github/workflows/awkward-main.yml
index b2ce95db..f1c97d05 100644
--- a/.github/workflows/awkward-main.yml
+++ b/.github/workflows/awkward-main.yml
@@ -29,7 +29,7 @@ jobs:
       run: |
         python3 -m pip install pip wheel -U
         python3 -m pip install -q --no-cache-dir -e .[complete,test]
-        python3 -m pip uninstall -y awkward && pip install git+https://github.com/scikit-hep/awkward.git@main
+        python3 -m pip uninstall -y awkward && pip install git+https://github.com/scikit-hep/awkward.git@main --no-deps
     - name: Run tests
       run: |
        python3 -m pytest
diff --git a/.gitignore b/.gitignore
index 81601c75..77df5369 100644
--- a/.gitignore
+++ b/.gitignore
@@ -130,3 +130,6 @@ venv.bak/
 
 # mypy
 .mypy_cache/
+
+# pyright lsp
+pyrightconfig.json
diff --git a/docs/api/mapfilter.rst b/docs/api/mapfilter.rst
new file mode 100644
index 00000000..1e572d27
--- /dev/null
+++ b/docs/api/mapfilter.rst
@@ -0,0 +1,15 @@
+mapfilter
+---------
+
+.. currentmodule:: dask_awkward
+
+.. autosummary::
+   :toctree: generated/
+
+   mapfilter
+   prerun
+
+.. raw:: html
+
diff --git a/docs/how-to/mapfilter.rst b/docs/how-to/mapfilter.rst
new file mode 100644
index 00000000..94c9057f
--- /dev/null
+++ b/docs/how-to/mapfilter.rst
@@ -0,0 +1,210 @@
+mapfilter
+---------
+
+:func:`dask_awkward.mapfilter` applies a given function to each partition of
+dask-awkward collections (:class:`dask_awkward.Array`) in an embarrassingly
+parallel way. The input collections must all have the same number of partitions.
+
+An example is shown below:
+
+.. code-block:: python
+
+    import dask_awkward as dak
+    import awkward as ak
+
+    # Create a dask-awkward array
+    x = ak.Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
+    dak_array = dak.from_awkward(x, npartitions=2)
+
+
+    # Define a function to apply to each partition
+    @dak.mapfilter
+    def add_one(array):
+        return array + 1
+
+
+    # Apply the function to each partition
+    result = add_one(dak_array)
+
+    # Compute the result
+    result.compute()
+    # <Array [2, 3, 4, 5, 6, 7, 8, 9, 10, 11] type='10 * int64'>
+
+Here, ``dak_array`` has two partitions, and :func:`dask_awkward.mapfilter` will
+apply the ``add_one`` function to each partition in parallel - resulting in two
+tasks in total (for the low-level graph).
+
+.. warning::
+
+    Since the mapped function is applied to each partition, the function must use eager awkward-array operations.
+    It is not possible to use (lazy) dask-awkward operations inside.
+
+
+Collapsing Lazy Operations
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The main purpose of :func:`dask_awkward.mapfilter` is to merge multiple operations into a single node
+in the high-level dask graph. This keeps the graph small and avoids unnecessary scheduling overhead.
+
+*Any* function given to :func:`dask_awkward.mapfilter` becomes a *single* node in the high-level dask graph,
+no matter how many operations are performed inside it.
+
+An example is given in the following:
+
+.. code-block:: python
+
+    import dask_awkward as dak
+    import awkward as ak
+    import numpy as np
+
+    # Create a dask-awkward array
+    x = ak.Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
+    dak_array = dak.from_awkward(x, npartitions=2)
+
+
+    # Define a function to apply to each partition
+    @dak.mapfilter
+    def fun(array):
+        return np.sin(array**2 + 1)
+
+
+    # Apply the function to each partition
+    result = fun(dak_array)
+
+    # Inspect the graph
+    result.dask
+    # HighLevelGraph with 2 layers.
+    #
+    # 0. from-awkward-25967e11ca4677388b80cfb6f556d752
+    # 1.
+    b.compute()
+    #
+
+    c.compute()
+    # (<__main__.some at 0x10b5819c0>, <__main__.some at 0x10b580dc0>)
+
+
+Untraceable Functions
+^^^^^^^^^^^^^^^^^^^^^
+
+Sometimes one needs to leave the awkward-array world and use operations that are not traceable
+by awkward's typetracer. In this case :func:`dask_awkward.mapfilter` can still be used to apply
+the function to each partition; one needs to provide the ``meta`` and ``needs`` arguments to
+:func:`dask_awkward.mapfilter` to enable this:
+
+* ``meta``: The meta information of the output values
+* ``needs``: A mapping from :class:`dask_awkward.Array` input arguments to iterables of
+  columns that should be touched explicitly
+
+An example is given in the following:
+
+.. code-block:: python
+
+    import awkward as ak
+    import dask_awkward as dak
+    import numpy as np
+
+    ak_array = ak.zip(
+        {
+            "x": ak.zip({"foo": [10, 20, 30, 40], "bar": [10, 20, 30, 40]}),
+        }
+    )
+    dak_array = dak.from_awkward(ak_array, 2)
+
+
+    def untraceable_fun(array):
+        foo = ak.to_numpy(array.x.foo)
+        return ak.Array([np.sum(foo)])
+
+
+    dak.mapfilter(untraceable_fun)(dak_array)
+    # ...
+    # TypeError: Converting from an nplike without known data to an nplike with known data is not supported
+    #
+    # This error occurred while calling
+    #
+    #     ak.to_numpy(
+    #
+    #     )
+    #
+    # The above exception was the direct cause of the following exception:
+    # ...
+
+    # Now let's add the `meta` and `needs` arguments
+    from functools import partial
+
+    mapf = partial(dak.mapfilter, needs={"array": [("x", "foo")]}, meta=ak.Array([0, 0]))
+
+    # It works now!
+    mapf(untraceable_fun)(dak_array).compute()
+    # <Array [30, 70] type='2 * int64'>
+
+In fact, providing ``meta`` and ``needs`` skips the tracing step entirely, since both arguments
+together already provide all the necessary information.
+In cases where the function is much more complex and not traceable, it can be helpful to run the
+tracing step manually:
+
+.. code-block:: python
+
+    meta, needs = dak.prerun(untraceable_fun, array=dak_array)
+    # ...
+    # UntraceableFunctionError: '<function untraceable_fun at 0x...>' is not traceable, an error occurred at line 9. 'dak.mapfilter' can circumvent this by providing 'needs' and 'meta' arguments to it.
+    #
+    # - 'needs': mapping where the keys point to input argument dask_awkward arrays and the values to columns that should be touched explicitly. The typetracing step could determine the following necessary columns until the exception occurred:
+    #
+    # needs={'array': [('x', 'foo')]}
+    #
+    # - 'meta': value(s) of what the wrapped function would return. For arrays, only the shape and type matter.
+
+Here, :func:`dask_awkward.prerun` will try to trace the function once and return the necessary
+information (``meta`` and ``needs``) to provide to :func:`dask_awkward.mapfilter`.
+In this case the function is untraceable, so :func:`dask_awkward.prerun` will report at least
+the part of ``needs`` that could be determined before the function stopped being traceable.
+
+.. tip::
+
+    For traceable but long-running functions (e.g. if they contain the evaluation of a neural
+    network), it is recommended to use :func:`dask_awkward.prerun` to infer ``meta`` and ``needs``
+    once, and provide them to all subsequent :func:`dask_awkward.mapfilter` calls. This way, the
+    tracing step is only performed once.
diff --git a/docs/index.rst b/docs/index.rst
index 815d85fd..f91fa23d 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -32,6 +32,7 @@ Table of Contents
    how-to/configuration.rst
    how-to/io.rst
    how-to/behaviors.rst
+   how-to/mapfilter.rst
 
 .. toctree::
    :maxdepth: 1
@@ -47,6 +48,7 @@ Table of Contents
    api/collections.rst
    api/inspect.rst
    api/io.rst
+   api/mapfilter.rst
    api/reducers.rst
    api/structure.rst
    api/behavior.rst
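To make the ``needs`` path-tuple convention from the new how-to page concrete: each tuple is a column path that ``mapfilter`` selects on the named argument and touches during the typetracer pass. A minimal standalone sketch of that mechanism follows; the array contents are illustrative and `tracer` is a hypothetical stand-in for one partition:

```python
import awkward as ak

# a typetracer stand-in for one partition of a collection like the docs example
layout = ak.zip({"x": ak.zip({"foo": [10], "bar": [10]})}).layout
tracer = ak.Array(layout.to_typetracer(forget_length=True))

# needs={"array": [("x", "foo")]} means: select array["x", "foo"] and touch it
# explicitly, which is what mapfilter does internally via ak.typetracer.touch_data
ak.typetracer.touch_data(tracer[("x", "foo")])
```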
diff --git a/src/dask_awkward/__init__.py b/src/dask_awkward/__init__.py
index 34b5c4f5..c4c5396e 100644
--- a/src/dask_awkward/__init__.py
+++ b/src/dask_awkward/__init__.py
@@ -24,6 +24,7 @@
     report_necessary_columns,
     sample,
 )
+from dask_awkward.lib.mapfilter import mapfilter, prerun
 
 necessary_columns = report_necessary_columns  # Export for backwards compatibility.
diff --git a/src/dask_awkward/lib/__init__.py b/src/dask_awkward/lib/__init__.py
index 74d16d6c..7be61ce9 100644
--- a/src/dask_awkward/lib/__init__.py
+++ b/src/dask_awkward/lib/__init__.py
@@ -27,6 +27,7 @@
 from dask_awkward.lib.io.json import from_json, to_json
 from dask_awkward.lib.io.parquet import from_parquet, to_parquet
 from dask_awkward.lib.io.text import from_text
+from dask_awkward.lib.mapfilter import mapfilter, prerun
 from dask_awkward.lib.operations import concatenate
 from dask_awkward.lib.reducers import (
     all,
diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py
index d7c1a4e0..650749aa 100644
--- a/src/dask_awkward/lib/core.py
+++ b/src/dask_awkward/lib/core.py
@@ -1982,7 +1982,21 @@ def __call__(self, *args_deps_expanded):
         return self.fn(*args, **kwargs)
 
 
-def _map_partitions(
+def _new_dak_array_divisions(
+    dak_array: Array, output_divisions: int | None = None
+) -> tuple:
+    in_divisions = dak_array.divisions
+    if output_divisions is not None:
+        if output_divisions == 1:
+            new_divisions = dak_array.divisions
+        else:
+            new_divisions = tuple(map(lambda x: x * output_divisions, in_divisions))  # type: ignore[operator]
+    else:
+        new_divisions = in_divisions
+    return new_divisions
+
+
+def _map_partitions_prepare(
     fn: Callable,
     *args: Any,
     label: str | None = None,
@@ -1990,23 +2004,13 @@
     meta: Any | None = None,
     output_divisions: int | None = None,
     **kwargs: Any,
-) -> Array:
-    """Map a callable across all partitions of any number of collections.
-    No wrapper is used to flatten the function arguments. This is meant for
-    dask-awkward internal use or in situations where input data are sanitized.
-
-    The parameters of this function are otherwise the same as map_partitions,
-    but the limitation that args, kwargs must be non-nested and flat. They
-    will not be traversed to extract all dask collections, except those in
-    the first dimension of args or kwargs.
-    """
+) -> tuple:
     token = token or tokenize(fn, *args, output_divisions, **kwargs)
     label = hyphenize(label or funcname(fn))
     name = f"{label}-{token}"
     deps = [a for a in args if is_dask_collection(a)] + [
         v for v in kwargs.values() if is_dask_collection(v)
     ]
-    dak_arrays = tuple(filter(lambda x: isinstance(x, Array), deps))
 
     if name in dak_cache:
         hlg, meta = dak_cache[name]
@@ -2027,22 +2031,46 @@
             dependencies=deps,
         )
 
-        if len(dak_arrays) == 0:
-            raise TypeError(
-                "at least one argument passed to map_partitions "
-                "should be a dask_awkward.Array collection."
-            )
         dak_cache[name] = hlg, meta
-    in_npartitions = dak_arrays[0].npartitions
-    in_divisions = dak_arrays[0].divisions
-    if output_divisions is not None:
-        if output_divisions == 1:
-            new_divisions = dak_arrays[0].divisions
-        else:
-            new_divisions = tuple(map(lambda x: x * output_divisions, in_divisions))
-    else:
-        new_divisions = in_divisions
+    return hlg, meta, deps, name
+
+
+def _map_partitions(
+    fn: Callable,
+    *args: Any,
+    label: str | None = None,
+    token: str | None = None,
+    meta: Any | None = None,
+    output_divisions: int | None = None,
+    **kwargs: Any,
+) -> Array:
+    """Map a callable across all partitions of any number of collections.
+    No wrapper is used to flatten the function arguments. This is meant for
+    dask-awkward internal use or in situations where input data are sanitized.
+
+    The parameters of this function are otherwise the same as map_partitions,
+    but with the limitation that args and kwargs must be non-nested and flat. They
+    will not be traversed to extract all dask collections, except those in
+    the first dimension of args or kwargs.
+    """
+    hlg, meta, deps, name = _map_partitions_prepare(
+        fn,
+        *args,
+        label=label,
+        token=token,
+        meta=meta,
+        output_divisions=output_divisions,
+        **kwargs,
+    )
+    dak_arrays = tuple(filter(lambda x: isinstance(x, Array), deps))
+    if len(dak_arrays) == 0:
+        raise TypeError(
+            "at least one argument passed to map_partitions "
+            "should be a dask_awkward.Array collection."
+        )
+    first = dak_arrays[0]
+    new_divisions = _new_dak_array_divisions(first, output_divisions)
     if output_divisions is not None:
         return new_array_object(
             hlg,
             name=name,
             meta=meta,
@@ -2055,10 +2083,62 @@
         hlg,
         name=name,
         meta=meta,
-        npartitions=in_npartitions,
+        npartitions=first.npartitions,
     )
 
 
+def _to_packed_fn_args(
+    base_fn: Callable,
+    *args: Any,
+    traverse: bool = True,
+    **kwargs: Any,
+) -> tuple:
+    opt_touch_all = kwargs.pop("opt_touch_all", None)
+    if opt_touch_all is not None:
+        warnings.warn(
+            "The opt_touch_all argument does nothing.\n"
+            "This warning will be removed in a future version of dask-awkward "
+            "and the function call will likely fail."
+        )
+
+    kwarg_flat_deps, kwarg_repacker = unpack_collections(kwargs, traverse=traverse)
+    flat_deps, _ = unpack_collections(*args, *kwargs.values(), traverse=traverse)
+
+    if len(flat_deps) == 0:
+        message = (
+            "map_partitions expects at least one Dask collection instance, "
+            "you are passing non-Dask collections to dask-awkward code.\n"
+            "observed argument types:\n"
+        )
+        for arg in args:
+            message += f"- {type(arg)}"
+        raise TypeError(message)
+
+    arg_flat_deps_expanded = []
+    arg_repackers = []
+    arg_lens_for_repackers = []
+    for arg in args:
+        this_arg_flat_deps, repacker = unpack_collections(arg, traverse=traverse)
+        if (
+            len(this_arg_flat_deps) > 0
+        ):  # if the deps list is empty this arg does not contain any dask collection, no need to repack!
+            arg_flat_deps_expanded.extend(this_arg_flat_deps)
+            arg_repackers.append(repacker)
+            arg_lens_for_repackers.append(len(this_arg_flat_deps))
+        else:
+            arg_flat_deps_expanded.append(arg)
+            arg_repackers.append(None)
+            arg_lens_for_repackers.append(1)
+
+    packed_fn = ArgsKwargsPackedFunction(
+        base_fn,
+        arg_repackers,
+        kwarg_repacker,
+        arg_lens_for_repackers,
+    )
+    return packed_fn, arg_flat_deps_expanded, kwarg_flat_deps
+
+
 def map_partitions(
     base_fn: Callable,
     *args: Any,
@@ -2139,49 +2219,8 @@
     This is effectively the same as `d = c * a`
 
     """
-
-    opt_touch_all = kwargs.pop("opt_touch_all", None)
-    if opt_touch_all is not None:
-        warnings.warn(
-            "The opt_touch_all argument does nothing.\n"
-            "This warning will be removed in a future version of dask-awkward "
-            "and the function call will likely fail."
-        )
-
-    kwarg_flat_deps, kwarg_repacker = unpack_collections(kwargs, traverse=traverse)
-    flat_deps, _ = unpack_collections(*args, *kwargs.values(), traverse=traverse)
-
-    if len(flat_deps) == 0:
-        message = (
-            "map_partitions expects at least one Dask collection instance, "
-            "you are passing non-Dask collections to dask-awkward code.\n"
-            "observed argument types:\n"
-        )
-        for arg in args:
-            message += f"- {type(arg)}"
-        raise TypeError(message)
-
-    arg_flat_deps_expanded = []
-    arg_repackers = []
-    arg_lens_for_repackers = []
-    for arg in args:
-        this_arg_flat_deps, repacker = unpack_collections(arg, traverse=traverse)
-        if (
-            len(this_arg_flat_deps) > 0
-        ):  # if the deps list is empty this arg does not contain any dask collection, no need to repack!
-            arg_flat_deps_expanded.extend(this_arg_flat_deps)
-            arg_repackers.append(repacker)
-            arg_lens_for_repackers.append(len(this_arg_flat_deps))
-        else:
-            arg_flat_deps_expanded.append(arg)
-            arg_repackers.append(None)
-            arg_lens_for_repackers.append(1)
-
-    fn = ArgsKwargsPackedFunction(
-        base_fn,
-        arg_repackers,
-        kwarg_repacker,
-        arg_lens_for_repackers,
+    fn, arg_flat_deps_expanded, kwarg_flat_deps = _to_packed_fn_args(
+        base_fn, *args, traverse=traverse, **kwargs
     )
     return _map_partitions(
         fn,
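The hunk above is a pure refactor: the argument-packing logic moves out of `map_partitions` into `_to_packed_fn_args` so that `mapfilter` can reuse it, and the public entry point should behave as before. A quick sanity sketch, with arbitrary array values:

```python
import awkward as ak
import dask_awkward as dak

a = dak.from_awkward(ak.Array([[1, 2], [3]]), npartitions=2)

# same public entry point as before the refactor
doubled = dak.map_partitions(lambda x: x * 2, a)
assert doubled.compute().to_list() == [[2, 4], [6]]
```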
diff --git a/src/dask_awkward/lib/io/parquet.py b/src/dask_awkward/lib/io/parquet.py
index 2fc47c25..9ff2a58d 100644
--- a/src/dask_awkward/lib/io/parquet.py
+++ b/src/dask_awkward/lib/io/parquet.py
@@ -483,6 +483,7 @@ def __init__(
         npartitions: int,
         prefix: str | None = None,
         storage_options: dict | None = None,
+        write_metadata: bool = False,
         **kwargs: Any,
     ):
         self.fs = fs
@@ -496,6 +497,7 @@ def __init__(
             if isinstance(self.fs.protocol, str)
             else self.fs.protocol[0]
         )
+        self.write_metadata = write_metadata
         self.kwargs = kwargs
 
     def __call__(self, data, block_index):
@@ -503,9 +505,11 @@ def __call__(self, data, block_index):
         if self.prefix is not None:
             filename = f"{self.prefix}-{filename}"
         filename = self.fs.unstrip_protocol(f"{self.path}{self.fs.sep}{filename}")
-        return ak.to_parquet(
+        out = ak.to_parquet(
             data, filename, **self.kwargs, storage_options=self.storage_options
         )
+        if self.write_metadata:
+            return out
 
 
 def to_parquet(
@@ -597,7 +601,10 @@
     storage_options
         Storage options passed to ``fsspec``.
     write_metadata
-        Write Parquet metadata.
+        Write Parquet metadata. Note that when this is True, all the
+        metadata pieces will be pulled into a single finalizer task. When
+        False, the whole write graph can be evaluated as a more efficient
+        tree reduction.
     compute
         If ``True``, immediately compute the result (write data to disk).
         If ``False`` a Scalar collection will be returned such
@@ -667,6 +674,7 @@
             parquet_old_int96_timestamps=parquet_old_int96_timestamps,
             parquet_compliant_nested=parquet_compliant_nested,
             parquet_extra_options=parquet_extra_options,
+            write_metadata=write_metadata,
         ),
         array,
         BlockIndex((array.npartitions,)),
@@ -681,17 +689,38 @@
         dsk[(final_name, 0)] = (_metadata_file_from_metas, fs, path) + tuple(
             map_res.__dask_keys__()
         )
+        graph = HighLevelGraph.from_collections(
+            final_name,
+            AwkwardMaterializedLayer(dsk, previous_layer_names=[map_res.name]),
+            dependencies=[map_res],
+        )
+        out = new_scalar_object(graph, final_name, dtype="f8")
     else:
         final_name = name + "-finalize"
-        dsk[(final_name, 0)] = (lambda *_: None, map_res.__dask_keys__())
-    graph = HighLevelGraph.from_collections(
-        final_name,
-        AwkwardMaterializedLayer(dsk, previous_layer_names=[map_res.name]),
-        dependencies=[map_res],
-    )
-    out = new_scalar_object(graph, final_name, dtype="f8")
+        from dask_awkward.layers import AwkwardTreeReductionLayer
+
+        layer = AwkwardTreeReductionLayer(
+            name=final_name,
+            concat_func=none_to_none,
+            tree_node_func=none_to_none,
+            name_input=map_res.name,
+            npartitions_input=map_res.npartitions,
+            finalize_func=none_to_none,
+        )
+        graph = HighLevelGraph.from_collections(
+            final_name,
+            layer,
+            dependencies=[map_res],
+        )
+        out = new_scalar_object(graph, final_name, dtype="f8")
+
     if compute:
         out.compute()
         return None
     else:
         return out
+
+
+def none_to_none(*_):
+    """Dummy reduction function where write tasks produce no metadata"""
+    return None
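A usage sketch of the two finalize strategies documented above. The paths and the toy array are illustrative, and this assumes `write_metadata` defaults to ``False`` in `to_parquet`, matching the writer-class default added in this hunk:

```python
import awkward as ak
import dask_awkward as dak

arr = dak.from_awkward(
    ak.Array([{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}]), npartitions=4
)

# default: no _metadata file; the finalize step is a tree reduction over the
# per-partition write tasks, avoiding one giant gather node
dak.to_parquet(arr, "/tmp/out")

# opt in: each write task returns its metadata piece, and a single finalizer
# task pulls them together to write the _metadata file
dak.to_parquet(arr, "/tmp/out-with-metadata", write_metadata=True)
```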
diff --git a/src/dask_awkward/lib/mapfilter.py b/src/dask_awkward/lib/mapfilter.py
new file mode 100644
index 00000000..f6603474
--- /dev/null
+++ b/src/dask_awkward/lib/mapfilter.py
@@ -0,0 +1,430 @@
+from __future__ import annotations
+
+import typing as tp
+from dataclasses import dataclass
+
+import awkward as ak
+from dask.highlevelgraph import HighLevelGraph
+from dask.typing import DaskCollection
+
+from dask_awkward.lib.core import Array as DakArray
+from dask_awkward.lib.core import (
+    _map_partitions_prepare,
+    _to_packed_fn_args,
+    dak_cache,
+    empty_typetracer,
+    new_array_object,
+    partitionwise_layer,
+    to_meta,
+    typetracer_array,
+)
+from dask_awkward.utils import DaskAwkwardNotImplemented
+
+
+def _single_return_map_partitions(
+    hlg: HighLevelGraph,
+    name: str,
+    meta: tp.Any,
+    npartitions: int,
+) -> tp.Any:
+    from dask.utils import (
+        is_arraylike,
+        is_dataframe_like,
+        is_index_like,
+        is_series_like,
+    )
+
+    # ak.Array (this is the dak.map_partitions case)
+    if isinstance(meta, ak.Array):
+        # convert to typetracer if not already done;
+        # this happens when the user provides a concrete array (e.g. np.array)
+        # and then wraps it with ak.Array as a return type
+        if not ak.backend(meta) == "typetracer":
+            meta = ak.to_backend(meta, "typetracer")
+        return new_array_object(
+            hlg,
+            name=name,
+            meta=meta,
+            npartitions=npartitions,
+        )
+    # TODO: array, dataframe, series, index
+    elif (
+        is_arraylike(meta)
+        or is_dataframe_like(meta)
+        or is_series_like(meta)
+        or is_index_like(meta)
+    ):
+        msg = (
+            f"{meta=} is not (yet) supported as return type. If possible, "
+            "you can convert it to ak.Array, or wrap it with a python container."
+        )
+        raise NotImplementedError(msg)
+    # don't know? -> put it in a bag
+    else:
+        from dask.bag.core import Bag
+
+        return Bag(dsk=hlg, name=name, npartitions=npartitions)
+
+
+def _multi_return_map_partitions(
+    hlg: HighLevelGraph,
+    name: str,
+    meta: tp.Any,
+    npartitions: int,
+) -> tp.Any:
+    # single-return case: this is equal to `dak.map_partitions`,
+    # but additionally supports other DaskCollections
+    if not isinstance(meta, tuple):
+        return _single_return_map_partitions(
+            hlg=hlg,
+            name=name,
+            meta=meta,
+            npartitions=npartitions,
+        )
+    # multi-return case
+    else:
+        from operator import itemgetter
+        from typing import cast
+
+        # create tmp dask collection for HLG creation
+        tmp = new_array_object(
+            hlg, name=name, meta=empty_typetracer(), npartitions=npartitions
+        )
+
+        ret = []
+        for i, m_pick in enumerate(meta):
+            # add a "select/pick" layer
+            # to get the ith element of the output
+            ith_name = f"{name}-pick-{i}th"
+
+            if ith_name in dak_cache:
+                hlg_pick, m_pick = dak_cache[ith_name]
+            else:
+                lay_pick = partitionwise_layer(itemgetter(i), ith_name, tmp)
+                hlg_pick = HighLevelGraph.from_collections(
+                    name=ith_name,
+                    layer=lay_pick,
+                    dependencies=[cast(DaskCollection, tmp)],
+                )
+                dak_cache[ith_name] = hlg_pick, m_pick
+            ret.append(
+                _single_return_map_partitions(
+                    hlg=hlg_pick,
+                    name=ith_name,
+                    meta=m_pick,
+                    npartitions=npartitions,
+                )
+            )
+        return tuple(ret)
+
+
+class UntraceableFunctionError(Exception): ...
+
+
+def _func_args(fn: tp.Callable, *args: tp.Any, **kwargs: tp.Any) -> tp.Mapping:
+    import inspect
+
+    ba = inspect.signature(fn).bind(*args, **kwargs)
+    return ba.arguments
+
+
+def _reports2needs(reports: tp.Mapping) -> dict:
+    import ast
+    from collections import defaultdict
+
+    needs = defaultdict(list)
+    for arg, report in reports.items():
+        # this should maybe be treated differently?
+        keys = set(report.shape_touched) | set(report.data_touched)
+        for key in keys:
+            slce = ast.literal_eval(key)
+            # only strings are actual slice paths to columns;
+            # `None` or `ints` are path values to non-record array types,
+            # see: https://github.com/scikit-hep/awkward/pull/3311
+            slce = tuple(it for it in slce if isinstance(it, str))
+            needs[arg].append(slce)
+    return needs
+
+
+def _replace_arrays_with_typetracers(meta: tp.Any) -> tp.Any:
+    def _to_tracer(meta: tp.Any) -> tp.Any:
+        if isinstance(meta, ak.Array):
+            if not ak.backend(meta) == "typetracer":
+                meta = typetracer_array(meta)
+        elif isinstance(meta, DakArray):
+            meta = to_meta([meta])[0]
+        return meta
+
+    if isinstance(meta, tuple):
+        meta = tuple(map(_to_tracer, meta))
+    else:
+        meta = _to_tracer(meta)
+    return meta
+
+
+def prerun(
+    fn: tp.Callable, *args: tp.Any, **kwargs: tp.Any
+) -> tuple[tp.Any, tp.Mapping]:
+    """
+    Pre-runs the provided function with typetracer arrays to determine the necessary columns
+    that should be touched explicitly and to infer the metadata of the function's output.
+
+    Parameters
+    ----------
+    fn : Callable
+        The function to be pre-run.
+    *args : Any
+        Positional arguments to be passed to the function.
+    **kwargs : Any
+        Keyword arguments to be passed to the function.
+
+    Returns
+    -------
+    tuple[Any, Mapping]
+        A tuple containing the output of the function when run with typetracer arrays and
+        a mapping of the touched columns (prepared to be used with ``mapfilter(needs=...)``)
+        generated during the typetracing step.
+
+    Examples
+    --------
+    >>> import awkward as ak
+    >>> import dask_awkward as dak
+    >>>
+    >>> ak_array = ak.zip({"foo": [1, 2, 3, 4], "bar": [1, 1, 1, 1]})
+    >>> dak_array = dak.from_awkward(ak_array, 2)
+    >>>
+    >>> def process(array: ak.Array) -> ak.Array:
+    ...     return array.foo + array.bar
+    >>>
+    >>> meta, needs = dak.prerun(process, dak_array)
+    >>> print(meta)
+
+    >>> print(needs)
+    {'array': [('bar',), ('foo',)]}
+    """
+    # unwrap `mapfilter`
+    if isinstance(fn, mapfilter):
+        fn = fn.fn
+
+    in_arguments = _func_args(fn, *args, **kwargs)
+
+    # replace ak.Arrays with typetracers and store the reports
+    reports = {}
+    fun_kwargs = {}
+    args_metas = {
+        arg: _replace_arrays_with_typetracers(val) for arg, val in in_arguments.items()
+    }
+
+    # can't typetrace if no ak.Arrays are present
+    ak_arrays = tuple(filter(lambda x: isinstance(x, ak.Array), args_metas.values()))
+    if not ak_arrays:
+        return None, {}
+
+    def _render_buffer_key(
+        form: ak.forms.Form,
+        form_key: str,
+        attribute: str,
+    ) -> str:
+        return form_key
+
+    # prepare function arguments
+    for arg, val in args_metas.items():
+        if isinstance(val, ak.Array):
+            if not ak.backend(val) == "typetracer":
+                val = typetracer_array(val)
+            tracer, report = ak.typetracer.typetracer_with_report(
+                val.layout.form_with_key_path(root=()),
+                highlevel=True,
+                behavior=val.behavior,
+                attrs=val.attrs,
+                buffer_key=_render_buffer_key,
+            )
+            reports[arg] = report
+            fun_kwargs[arg] = tracer
+        else:
+            fun_kwargs[arg] = val
+
+    # try to run the function once with typetracers
+    try:
+        out = fn(**fun_kwargs)
+    except Exception as err:
+        import traceback
+
+        # get the line number of where the error occurred in the provided function
+        # traceback 0: this function, 1: the provided function, >1: the rest of the stack
+        tb = traceback.extract_tb(err.__traceback__)
+        line_number = tb[1].lineno
+
+        # also add the reports of the typetracer to the error message,
+        # formatted the way 'needs' wants them to be
+        needs = dict(_reports2needs(reports=reports))
+
+        msg = (
+            f"'{fn}' is not traceable, an error occurred at line {line_number}. "
+            "'dak.mapfilter' can circumvent this by providing 'needs' and "
+            "'meta' arguments to it.\n\n"
+            "- 'needs': mapping where the keys point to input argument "
+            "dask_awkward arrays and the values to columns that should "
+            "be touched explicitly. The typetracing step could determine "
+            "the following necessary columns until the exception occurred:\n\n"
+            f"{needs=}\n\n"
+            "- 'meta': value(s) of what the wrapped function would "
+            "return. For arrays, only the shape and type matter."
+        )
+        raise UntraceableFunctionError(msg) from err
+
+    return out, dict(_reports2needs(reports))
+
+
+@dataclass
+class mapfilter:
+    """
+    A decorator to map a callable across all partitions of any number of collections.
+    The function will be treated as a single node in the Dask graph.
+
+    Parameters
+    ----------
+    fn : Callable
+        The function to apply on all partitions. This will get wrapped to
+        handle kwargs, including Dask collections.
+    label : str, optional
+        Label for the Dask graph layer; if left as ``None`` (default), the
+        name of the function will be used.
+    token : str, optional
+        Provide an already defined token. If ``None``, a new token will be generated.
+    meta : Any, optional
+        Metadata for the result (if known). If unknown, `fn` will be applied to
+        the metadata of the `args`. If provided, the tracing step will be skipped
+        and the provided metadata is used as the return value(s) of `fn`.
+    traverse : bool
+        Unpack basic Python containers to find Dask collections.
+    needs : dict, optional
+        If ``None`` (the default), nothing is touched in addition to the standard
+        typetracer report. In certain cases, it is necessary to touch additional
+        objects **explicitly** to get the correct typetracer report. For this,
+        provide a dictionary that maps input arguments that are arrays to the
+        columns of that array that should be touched. If ``needs`` is used
+        together with ``meta``, **only** the columns provided by the ``needs``
+        argument will be touched explicitly.
+
+    Examples
+    --------
+
+    .. code-block:: ipython
+
+        import awkward as ak
+        import dask_awkward as dak
+
+        ak_array = ak.zip({"foo": [1, 2, 3, 4], "bar": [1, 1, 1, 1]})
+        dak_array = dak.from_awkward(ak_array, 2)
+
+        @dak.mapfilter
+        def process(array: ak.Array) -> ak.Array:
+            return array.foo * 2
+
+        out = process(dak_array)
+        print(out.compute())
+        # <Array [2, 4, 6, 8] type='4 * int64'>
+    """
+
+    fn: tp.Callable
+    label: str | None = None
+    token: str | None = None
+    meta: tp.Any | None = None
+    traverse: bool = True
+    needs: tp.Mapping | None = None
+
+    def __post_init__(self) -> None:
+        if self.needs is not None and not isinstance(self.needs, tp.Mapping):
+            msg = (  # type: ignore[unreachable]
+                "'needs' argument must be a mapping where the keys "
+                "point to input argument dask_awkward arrays and the values "
+                "to columns that should be touched explicitly, "
+                f"got {self.needs!r} instead.\n\n"
+                "Exemplary usage:\n"
+                "\n@partial(mapfilter, needs={'array': ['col1', 'col2']})"
+                "\ndef process(array: ak.Array) -> ak.Array:"
+                "\n    return array.col1 + array.col2"
+            )
+            raise ValueError(msg)
+
+    def wrapped_fn(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any:
+        in_arguments = _func_args(self.fn, *args, **kwargs)
+        if self.needs is not None:
+            tobe_touched = set()
+            for arg in self.needs.keys():
+                if arg in in_arguments:
+                    tobe_touched.add(arg)
+                else:
+                    msg = f"Argument '{arg}' is not present in the function signature."
+                    raise ValueError(msg)
+            for arg in tobe_touched:
+                array = in_arguments[arg]
+                if not isinstance(array, ak.Array):
+                    raise ValueError(
+                        f"Can only touch columns of an awkward array, got {array}."
+                    )
+                if ak.backend(array) == "typetracer":
+                    for slce in self.needs[arg]:
+                        ak.typetracer.touch_data(array[slce])
+
+        if self.meta is not None:
+            ak_arrays = [
+                arg for arg in in_arguments.values() if isinstance(arg, ak.Array)
+            ]
+            if all(ak.backend(arr) == "typetracer" for arr in ak_arrays):
+                return self.meta
+        return self.fn(*args, **kwargs)
+
+    def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any:
+        fn, arg_flat_deps_expanded, kwarg_flat_deps = _to_packed_fn_args(
+            self.wrapped_fn, *args, traverse=self.traverse, **kwargs
+        )
+
+        arg_flat_deps_expanded = _replace_arrays_with_typetracers(
+            arg_flat_deps_expanded
+        )
+        kwarg_flat_deps = _replace_arrays_with_typetracers(kwarg_flat_deps)
+        meta = _replace_arrays_with_typetracers(self.meta)
+        in_typetracing_mode = arg_flat_deps_expanded or kwarg_flat_deps or meta
+
+        try:
+            hlg, meta, deps, name = _map_partitions_prepare(
+                fn,
+                *arg_flat_deps_expanded,
+                *kwarg_flat_deps,
+                label=self.label,
+                token=self.token,
+                meta=meta,
+                output_divisions=None,
+            )
+        # handle the case where the function is not implemented for Dask arrays in dask-awkward
+        except DaskAwkwardNotImplemented as err:
+            raise err from None
+        # handle the case where the function is not traceable - for whatever reason
+        except Exception as err:
+            if in_typetracing_mode:
+                fn_args = _func_args(self.fn, *args, **kwargs)
+                sig_str = ", ".join(f"{k}={v}" for k, v in fn_args.items())
+                msg = (
+                    f"Failed to trace the function '{self.fn}'. "
+                    "You can use 'needs' and 'meta' to circumvent this step. "
+                    "For this, it might be helpful to do a pre-run of the function:"
+                    f"\n\n\tmeta, needs = dak.prerun({self.fn.__name__}, {sig_str})"
+                    f"\n\nThis may help to infer the correct `needs` for `mapfilter`."
+                )
+                raise UntraceableFunctionError(msg) from err
+            else:
+                raise err from None
+
+        if len(deps) == 0:
+            raise ValueError("Need at least one input that is a dask collection.")
+        elif len(deps) == 1:
+            npart = deps[0].npartitions
+        else:
+            npart = deps[0].npartitions
+            if not all(dep.npartitions == npart for dep in deps):
+                msg = "All inputs must have the same number of partitions, got:"
+                for dep in deps:
+                    npartitions = dep.npartitions
+                    msg += f"\n{dep}: {npartitions=}"
+                raise ValueError(msg)
+
+        return _multi_return_map_partitions(
+            hlg=hlg,
+            name=name,
+            meta=meta,
+            npartitions=npart,
+        )
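For context on the multi-return path above: a tuple return is split into one collection per element, and any element whose meta is not an `ak.Array` is gathered through a `dask.bag.Bag`. A small sketch with illustrative values, mirroring the test below:

```python
import awkward as ak
import numpy as np
import dask_awkward as dak

dak_array = dak.from_awkward(ak.zip({"foo": [1, 2, 3, 4]}), npartitions=2)

@dak.mapfilter
def fun(x):
    y = x.foo + 1
    return y, (np.sum(y),)  # tuple return -> one collection per element

arrays, sums = fun(dak_array)
assert arrays.compute().to_list() == [2, 3, 4, 5]  # a dask_awkward.Array
assert list(sums.compute()) == [5, 9]  # per-partition sums, via a dask.bag.Bag
```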
diff --git a/tests/test_mapfilter.py b/tests/test_mapfilter.py
new file mode 100644
index 00000000..83e6fa34
--- /dev/null
+++ b/tests/test_mapfilter.py
@@ -0,0 +1,100 @@
+from functools import partial
+
+import awkward as ak
+import numpy as np
+import pytest
+
+import dask_awkward as dak
+
+
+def test_mapfilter_single_return():
+    ak_array = ak.zip({"foo": [1, 2, 3, 4], "bar": [1, 1, 1, 1]})
+    dak_array = dak.from_awkward(ak_array, 2)
+
+    @dak.mapfilter
+    def fun(x):
+        y = x.foo + 1
+        return y
+
+    assert ak.all(
+        fun(dak_array).compute()
+        == dak.map_partitions(fun.wrapped_fn, dak_array).compute()
+    )
+
+
+def test_mapfilter_multiple_return():
+    ak_array = ak.zip({"foo": [1, 2, 3, 4], "bar": [1, 1, 1, 1]})
+    dak_array = dak.from_awkward(ak_array, 2)
+
+    class some: ...
+
+    @dak.mapfilter
+    def fun(x):
+        y = x.foo + 1
+        return y, (np.sum(y),), some(), ak.Array(np.ones(4))
+
+    y, y_sum, something, arr = fun(dak_array)
+
+    assert ak.all(y.compute() == ak_array.foo + 1)
+    assert np.all(y_sum.compute() == [np.array(5), np.array(9)])
+    something = something.compute()
+    assert len(something) == 2
+    assert all(isinstance(s, some) for s in something)
+    array = arr.compute()
+    assert len(array) == 8
+    assert array.ndim == 1
+    assert ak.all(array == ak.Array(np.ones(8)))
+
+
+def test_mapfilter_needs_meta():
+    ak_array = ak.zip(
+        {
+            "x": ak.zip({"foo": [10, 20, 30, 40], "bar": [10, 20, 30, 40]}),
+            "y": ak.zip({"foo": [1, 1, 1, 1], "bar": [1, 1, 1, 1]}),
+            "z": ak.zip({"a": [0, 0, 0, 0], "b": [2, 2, 2, 2]}),
+        }
+    )
+    dak_array = dak.from_awkward(ak_array, 2)
+
+    def untraceable_fun(muons):
+        # a computation that is not traceable by ak's typetracer;
+        # it needs the "x.foo" column from muons and returns a 1-element array
+        muons.y.bar[...]
+        muons.z[...]
+        pt = ak.to_numpy(muons.x.foo)
+        return ak.Array([np.sum(pt)])
+
+    # first check that the function is not traceable
+    with pytest.raises(TypeError):
+        dak.map_partitions(untraceable_fun, dak_array)
+
+    # now check that the necessary columns are reported correctly
+    wrap = partial(
+        dak.mapfilter,
+        needs={"muons": [("x", "foo"), ("z",), ("y", "bar")]},
+        meta=ak.Array([0.0]),
+    )
+    out = wrap(untraceable_fun)(dak_array)
+    cols = next(iter(dak.report_necessary_columns(out).values()))
+    assert cols == frozenset({"x.foo", "y.bar", "z.a", "z.b"})
+
+
+def test_mapfilter_multi_dak_array_inputs():
+    ak_array1 = ak.zip({"foo": [1, 2, 3, 4], "bar": [1, 1, 1, 1]})
+    ak_array2 = ak.zip({"foo": [1, 2, 3, 4], "bar": [1, 1, 1, 1]})
+
+    @dak.mapfilter
+    def fun(x, y):
+        return x.foo + y.bar
+
+    assert ak.all(
+        fun(
+            dak.from_awkward(ak_array1, 2),
+            dak.from_awkward(ak_array2, 2),
+        ).compute()
+        == ak.Array([2, 3, 4, 5])
+    )
+
+    # test incompatible partitioning
+    with pytest.raises(ValueError):
+        fun(dak.from_awkward(ak_array1, 1), dak.from_awkward(ak_array2, 2))
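Finally, a minimal end-to-end sketch of the intended workflow: trace once with `prerun`, then reuse `meta` and `needs` so subsequent `mapfilter` calls skip the tracing step. The array contents are illustrative:

```python
import awkward as ak
import dask_awkward as dak

ak_array = ak.zip({"foo": [1, 2, 3, 4], "bar": [1, 1, 1, 1]})
dak_array = dak.from_awkward(ak_array, 2)

def process(array):
    return array.foo + array.bar

# trace once: `meta` is the typetracer output, `needs` the touched columns,
# e.g. {'array': [('bar',), ('foo',)]}
meta, needs = dak.prerun(process, array=dak_array)

# reuse both so this (and any later) mapfilter call skips tracing entirely
out = dak.mapfilter(process, meta=meta, needs=needs)(dak_array)
print(out.compute().to_list())  # [2, 3, 4, 5]
```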