diff --git a/.github/workflows/awkward-main.yml b/.github/workflows/awkward-main.yml
index b2ce95db..f1c97d05 100644
--- a/.github/workflows/awkward-main.yml
+++ b/.github/workflows/awkward-main.yml
@@ -29,7 +29,7 @@ jobs:
run: |
python3 -m pip install pip wheel -U
python3 -m pip install -q --no-cache-dir -e .[complete,test]
- python3 -m pip uninstall -y awkward && pip install git+https://github.com/scikit-hep/awkward.git@main
+ python3 -m pip uninstall -y awkward && pip install git+https://github.com/scikit-hep/awkward.git@main --no-deps
- name: Run tests
run: |
python3 -m pytest
diff --git a/.gitignore b/.gitignore
index 81601c75..77df5369 100644
--- a/.gitignore
+++ b/.gitignore
@@ -130,3 +130,6 @@ venv.bak/
# mypy
.mypy_cache/
+
+# pyright lsp
+pyrightconfig.json
diff --git a/docs/api/mapfilter.rst b/docs/api/mapfilter.rst
new file mode 100644
index 00000000..1e572d27
--- /dev/null
+++ b/docs/api/mapfilter.rst
@@ -0,0 +1,15 @@
+mapfilter
+---------
+
+.. currentmodule:: dask_awkward
+
+.. autosummary::
+ :toctree: generated/
+
+ mapfilter
+ prerun
+
diff --git a/docs/how-to/mapfilter.rst b/docs/how-to/mapfilter.rst
new file mode 100644
index 00000000..94c9057f
--- /dev/null
+++ b/docs/how-to/mapfilter.rst
@@ -0,0 +1,210 @@
+mapfilter
+---------
+
+:func:`dask_awkward.mapfilter` applies a given function to each partition of
+dask-awkward collections (:class:`dask_awkward.Array`). It maps the function
+over the partitions of the provided collections in an embarrassingly parallel way; the input
+collections must all have the same number of partitions.
+
+An example is shown below:
+
+.. code-block:: python
+
+ import dask_awkward as dak
+ import awkward as ak
+
+ # Create a dask-awkward array
+ x = ak.Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
+ dak_array = dak.from_awkward(x, npartitions=2)
+
+
+ # Define a function to apply to each partition
+ @dak.mapfilter
+ def add_one(array):
+ return array + 1
+
+
+ # Apply the function to each partition
+ result = add_one(dak_array)
+
+ # Compute the result
+ result.compute()
+    # <Array [2, 3, 4, 5, 6, 7, 8, 9, 10, 11] type='10 * int64'>
+
+
+Here, ``dak_array`` has two partitions, and :func:`dask_awkward.mapfilter` will
+apply the ``add_one`` function to each partition in parallel, resulting in two tasks in total in the low-level graph.
+
+.. warning::
+
+ Since the mapped function is applied to each partition, the function must use eager awkward-array operations.
+ It is not possible to use (lazy) dask-awkward operations inside.
+
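+For example, a minimal sketch of this constraint (``good`` and ``bad`` are illustrative names):
+
+.. code-block:: python
+
+    import awkward as ak
+    import dask_awkward as dak
+
+    @dak.mapfilter
+    def good(array):
+        # eager awkward operation on the concrete partition: fine
+        return ak.sum(array, axis=0, keepdims=True)
+
+    @dak.mapfilter
+    def bad(array):
+        # lazy dask-awkward operation: fails, since each partition
+        # arrives as a concrete (or typetracer) ak.Array
+        return dak.sum(array, axis=0)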
+
+Collapsing Lazy Operations
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The main purpose of :func:`dask_awkward.mapfilter` is to collapse multiple operations into a single node
+in the high-level Dask graph. This keeps the graph small and avoids unnecessary scheduling overhead.
+
+*Any* function given to :func:`dask_awkward.mapfilter` becomes a *single* node in the high-level Dask graph,
+no matter how many operations are performed inside it.
+
+An example is given in the following:
+
+.. code-block:: python
+
+    import dask_awkward as dak
+    import awkward as ak
+    import numpy as np
+
+ # Create a dask-awkward array
+ x = ak.Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
+ dak_array = dak.from_awkward(x, npartitions=2)
+
+
+ # Define a function to apply to each partition
+ @dak.mapfilter
+ def fun(array):
+ return np.sin(array**2 + 1)
+
+
+ # Apply the function to each partition
+ result = fun(dak_array)
+
+ # Inspect the graph
+ result.dask
+ # HighLevelGraph with 2 layers.
+ #
+ # 0. from-awkward-25967e11ca4677388b80cfb6f556d752
+    # 1. fun-...
+
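+Here, although ``fun`` performs three operations (squaring, adding one, ``np.sin``), it appears as a
+*single* layer in the high-level graph. Without :func:`dask_awkward.mapfilter`, each operation would
+become its own layer; a minimal sketch of the comparison (the exact layer names are illustrative):
+
+.. code-block:: python
+
+    # the same computation without mapfilter
+    unfused = np.sin(dak_array**2 + 1)
+
+    # Inspect the graph
+    unfused.dask
+    # HighLevelGraph with 4 layers.
+    #
+    # 0. from-awkward-...
+    # 1. power-...
+    # 2. add-...
+    # 3. sin-...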
+
+Untraceable Functions
+^^^^^^^^^^^^^^^^^^^^^
+
+Sometimes one needs to leave the awkward-array world and use operations that are not traceable
+by awkward's typetracer. In this case :func:`dask_awkward.mapfilter` can still be used to apply the function
+to each partition, but the ``meta`` and ``needs`` arguments must be provided
+to enable this:
+
+* ``meta``: the metadata of the output value(s); for arrays, only the shape and type matter
+* ``needs``: a mapping from names of :class:`dask_awkward.Array` input arguments to the iterables of columns that should be touched explicitly
+
+An example is given in the following:
+
+.. code-block:: python
+
+    import awkward as ak
+    import dask_awkward as dak
+    import numpy as np
+
+    ak_array = ak.zip(
+ {
+ "x": ak.zip({"foo": [10, 20, 30, 40], "bar": [10, 20, 30, 40]}),
+ }
+ )
+ dak_array = dak.from_awkward(ak_array, 2)
+
+
+ def untraceable_fun(array):
+ foo = ak.to_numpy(array.x.foo)
+ return ak.Array([np.sum(foo)])
+
+
+ dak.mapfilter(untraceable_fun)(dak_array)
+ # ...
+ # TypeError: Converting from an nplike without known data to an nplike with known data is not supported
+ #
+ # This error occurred while calling
+ #
+ # ak.to_numpy(
+ #
+ # )
+ #
+ # The above exception was the direct cause of the following exception:
+ # ...
+
+ # Now let's add `meta` and `needs` arguments
+ from functools import partial
+
+ mapf = partial(dak.mapfilter, needs={"array": [("x", "foo")]}, meta=ak.Array([0, 0]))
+
+ # It works now!
+ mapf(untraceable_fun)(dak_array).compute()
+    # <Array [30, 70] type='2 * int64'>
+
+In fact, providing ``meta`` and ``needs`` skips the tracing step entirely, since both arguments together already provide all the necessary information.
+For functions that are more complex and not traceable, it can be helpful to run the tracing step manually:
+
+.. code-block:: python
+
+ meta, needs = dak.prerun(untraceable_fun, array=dak_array)
+ # ...
+    # UntraceableFunctionError: '<function untraceable_fun at 0x...>' is not traceable, an error occurred at line 9. 'dak.mapfilter' can circumvent this by providing 'needs' and 'meta' arguments to it.
+ #
+ # - 'needs': mapping where the keys point to input argument dask_awkward arrays and the values to columns that should be touched explicitly. The typetracing step could determine the following necessary columns until the exception occurred:
+ #
+ # needs={'array': [('x', 'foo')]}
+ #
+ # - 'meta': value(s) of what the wrapped function would return. For arrays, only the shape and type matter.
+
+Here, :func:`dask_awkward.prerun` tries to trace the function once and returns the information (``meta`` and ``needs``) to provide to :func:`dask_awkward.mapfilter`.
+In this case the function is untraceable, so :func:`dask_awkward.prerun` reports at least the ``needs`` collected up to the point where tracing fails.
+
+.. tip::
+
+    For traceable but long-running functions (e.g. if they contain the evaluation of a neural network), it is recommended to use :func:`dask_awkward.prerun` to infer ``meta`` and ``needs`` once,
+    and provide them to all subsequent :func:`dask_awkward.mapfilter` calls. This way, the tracing step is only performed once.
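+
+A minimal sketch of this pattern (``expensive_fn`` and ``other_dak_array`` are illustrative placeholders):
+
+.. code-block:: python
+
+    from functools import partial
+
+    import dask_awkward as dak
+
+    # trace once up front (assumes expensive_fn takes a single 'array' argument)
+    meta, needs = dak.prerun(expensive_fn, array=dak_array)
+
+    # reuse the inferred information; no tracing happens in these calls
+    mapf = partial(dak.mapfilter, meta=meta, needs=needs)
+    out1 = mapf(expensive_fn)(dak_array)
+    out2 = mapf(expensive_fn)(other_dak_array)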
diff --git a/docs/index.rst b/docs/index.rst
index 815d85fd..f91fa23d 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -32,6 +32,7 @@ Table of Contents
how-to/configuration.rst
how-to/io.rst
how-to/behaviors.rst
+ how-to/mapfilter.rst
.. toctree::
:maxdepth: 1
@@ -47,6 +48,7 @@ Table of Contents
api/collections.rst
api/inspect.rst
api/io.rst
+ api/mapfilter.rst
api/reducers.rst
api/structure.rst
api/behavior.rst
diff --git a/src/dask_awkward/__init__.py b/src/dask_awkward/__init__.py
index 34b5c4f5..c4c5396e 100644
--- a/src/dask_awkward/__init__.py
+++ b/src/dask_awkward/__init__.py
@@ -24,6 +24,7 @@
report_necessary_columns,
sample,
)
+from dask_awkward.lib.mapfilter import mapfilter, prerun
necessary_columns = report_necessary_columns # Export for backwards compatibility.
diff --git a/src/dask_awkward/lib/__init__.py b/src/dask_awkward/lib/__init__.py
index 74d16d6c..7be61ce9 100644
--- a/src/dask_awkward/lib/__init__.py
+++ b/src/dask_awkward/lib/__init__.py
@@ -27,6 +27,7 @@
from dask_awkward.lib.io.json import from_json, to_json
from dask_awkward.lib.io.parquet import from_parquet, to_parquet
from dask_awkward.lib.io.text import from_text
+from dask_awkward.lib.mapfilter import mapfilter, prerun
from dask_awkward.lib.operations import concatenate
from dask_awkward.lib.reducers import (
all,
diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py
index d7c1a4e0..650749aa 100644
--- a/src/dask_awkward/lib/core.py
+++ b/src/dask_awkward/lib/core.py
@@ -1982,7 +1982,21 @@ def __call__(self, *args_deps_expanded):
return self.fn(*args, **kwargs)
-def _map_partitions(
+def _new_dak_array_divisions(
+ dak_array: Array, output_divisions: int | None = None
+) -> tuple:
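+    """Scale the input array's divisions by ``output_divisions``; with
+    ``output_divisions=1`` or ``None`` the divisions pass through unchanged."""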
+ in_divisions = dak_array.divisions
+ if output_divisions is not None:
+ if output_divisions == 1:
+ new_divisions = dak_array.divisions
+ else:
+ new_divisions = tuple(map(lambda x: x * output_divisions, in_divisions)) # type: ignore[operator]
+ else:
+ new_divisions = in_divisions
+ return new_divisions
+
+
+def _map_partitions_prepare(
fn: Callable,
*args: Any,
label: str | None = None,
@@ -1990,23 +2004,13 @@ def _map_partitions(
meta: Any | None = None,
output_divisions: int | None = None,
**kwargs: Any,
-) -> Array:
- """Map a callable across all partitions of any number of collections.
- No wrapper is used to flatten the function arguments. This is meant for
- dask-awkward internal use or in situations where input data are sanitized.
-
- The parameters of this function are otherwise the same as map_partitions,
- but the limitation that args, kwargs must be non-nested and flat. They
- will not be traversed to extract all dask collections, except those in
- the first dimension of args or kwargs.
- """
+) -> tuple:
token = token or tokenize(fn, *args, output_divisions, **kwargs)
label = hyphenize(label or funcname(fn))
name = f"{label}-{token}"
deps = [a for a in args if is_dask_collection(a)] + [
v for v in kwargs.values() if is_dask_collection(v)
]
- dak_arrays = tuple(filter(lambda x: isinstance(x, Array), deps))
if name in dak_cache:
hlg, meta = dak_cache[name]
@@ -2027,22 +2031,46 @@ def _map_partitions(
dependencies=deps,
)
- if len(dak_arrays) == 0:
- raise TypeError(
- "at least one argument passed to map_partitions "
- "should be a dask_awkward.Array collection."
- )
dak_cache[name] = hlg, meta
- in_npartitions = dak_arrays[0].npartitions
- in_divisions = dak_arrays[0].divisions
- if output_divisions is not None:
- if output_divisions == 1:
- new_divisions = dak_arrays[0].divisions
- else:
- new_divisions = tuple(map(lambda x: x * output_divisions, in_divisions))
- else:
- new_divisions = in_divisions
+ return hlg, meta, deps, name
+
+def _map_partitions(
+ fn: Callable,
+ *args: Any,
+ label: str | None = None,
+ token: str | None = None,
+ meta: Any | None = None,
+ output_divisions: int | None = None,
+ **kwargs: Any,
+) -> Array:
+ """Map a callable across all partitions of any number of collections.
+ No wrapper is used to flatten the function arguments. This is meant for
+ dask-awkward internal use or in situations where input data are sanitized.
+
+ The parameters of this function are otherwise the same as map_partitions,
+ but the limitation that args, kwargs must be non-nested and flat. They
+ will not be traversed to extract all dask collections, except those in
+ the first dimension of args or kwargs.
+ """
+ hlg, meta, deps, name = _map_partitions_prepare(
+ fn,
+ *args,
+ label=label,
+ token=token,
+ meta=meta,
+ output_divisions=output_divisions,
+ **kwargs,
+ )
+ dak_arrays = tuple(filter(lambda x: isinstance(x, Array), deps))
+ if len(dak_arrays) == 0:
+ raise TypeError(
+ "at least one argument passed to map_partitions "
+ "should be a dask_awkward.Array collection."
+ )
+ first = dak_arrays[0]
+ new_divisions = _new_dak_array_divisions(first, output_divisions)
if output_divisions is not None:
return new_array_object(
hlg,
@@ -2055,10 +2083,62 @@ def _map_partitions(
hlg,
name=name,
meta=meta,
- npartitions=in_npartitions,
+ npartitions=first.npartitions,
)
+def _to_packed_fn_args(
+ base_fn: Callable,
+ *args: Any,
+ traverse: bool = True,
+ **kwargs: Any,
+) -> tuple:
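+    """Flatten all dask collections out of ``args``/``kwargs`` and wrap ``base_fn``
+    in an ``ArgsKwargsPackedFunction`` that repacks the flat dependencies into the
+    original call structure at execution time."""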
+ opt_touch_all = kwargs.pop("opt_touch_all", None)
+ if opt_touch_all is not None:
+ warnings.warn(
+ "The opt_touch_all argument does nothing.\n"
+ "This warning will be removed in a future version of dask-awkward "
+ "and the function call will likely fail."
+ )
+
+ kwarg_flat_deps, kwarg_repacker = unpack_collections(kwargs, traverse=traverse)
+ flat_deps, _ = unpack_collections(*args, *kwargs.values(), traverse=traverse)
+
+ if len(flat_deps) == 0:
+ message = (
+ "map_partitions expects at least one Dask collection instance, "
+ "you are passing non-Dask collections to dask-awkward code.\n"
+ "observed argument types:\n"
+ )
+ for arg in args:
+ message += f"- {type(arg)}"
+ raise TypeError(message)
+
+ arg_flat_deps_expanded = []
+ arg_repackers = []
+ arg_lens_for_repackers = []
+ for arg in args:
+ this_arg_flat_deps, repacker = unpack_collections(arg, traverse=traverse)
+ if (
+ len(this_arg_flat_deps) > 0
+ ): # if the deps list is empty this arg does not contain any dask collection, no need to repack!
+ arg_flat_deps_expanded.extend(this_arg_flat_deps)
+ arg_repackers.append(repacker)
+ arg_lens_for_repackers.append(len(this_arg_flat_deps))
+ else:
+ arg_flat_deps_expanded.append(arg)
+ arg_repackers.append(None)
+ arg_lens_for_repackers.append(1)
+
+ packed_fn = ArgsKwargsPackedFunction(
+ base_fn,
+ arg_repackers,
+ kwarg_repacker,
+ arg_lens_for_repackers,
+ )
+ return packed_fn, arg_flat_deps_expanded, kwarg_flat_deps
+
+
def map_partitions(
base_fn: Callable,
*args: Any,
@@ -2139,49 +2219,8 @@ def map_partitions(
This is effectively the same as `d = c * a`
"""
-
- opt_touch_all = kwargs.pop("opt_touch_all", None)
- if opt_touch_all is not None:
- warnings.warn(
- "The opt_touch_all argument does nothing.\n"
- "This warning will be removed in a future version of dask-awkward "
- "and the function call will likely fail."
- )
-
- kwarg_flat_deps, kwarg_repacker = unpack_collections(kwargs, traverse=traverse)
- flat_deps, _ = unpack_collections(*args, *kwargs.values(), traverse=traverse)
-
- if len(flat_deps) == 0:
- message = (
- "map_partitions expects at least one Dask collection instance, "
- "you are passing non-Dask collections to dask-awkward code.\n"
- "observed argument types:\n"
- )
- for arg in args:
- message += f"- {type(arg)}"
- raise TypeError(message)
-
- arg_flat_deps_expanded = []
- arg_repackers = []
- arg_lens_for_repackers = []
- for arg in args:
- this_arg_flat_deps, repacker = unpack_collections(arg, traverse=traverse)
- if (
- len(this_arg_flat_deps) > 0
- ): # if the deps list is empty this arg does not contain any dask collection, no need to repack!
- arg_flat_deps_expanded.extend(this_arg_flat_deps)
- arg_repackers.append(repacker)
- arg_lens_for_repackers.append(len(this_arg_flat_deps))
- else:
- arg_flat_deps_expanded.append(arg)
- arg_repackers.append(None)
- arg_lens_for_repackers.append(1)
-
- fn = ArgsKwargsPackedFunction(
- base_fn,
- arg_repackers,
- kwarg_repacker,
- arg_lens_for_repackers,
+ fn, arg_flat_deps_expanded, kwarg_flat_deps = _to_packed_fn_args(
+ base_fn, *args, traverse=traverse, **kwargs
)
return _map_partitions(
fn,
diff --git a/src/dask_awkward/lib/io/parquet.py b/src/dask_awkward/lib/io/parquet.py
index 2fc47c25..9ff2a58d 100644
--- a/src/dask_awkward/lib/io/parquet.py
+++ b/src/dask_awkward/lib/io/parquet.py
@@ -483,6 +483,7 @@ def __init__(
npartitions: int,
prefix: str | None = None,
storage_options: dict | None = None,
+ write_metadata: bool = False,
**kwargs: Any,
):
self.fs = fs
@@ -496,6 +497,7 @@ def __init__(
if isinstance(self.fs.protocol, str)
else self.fs.protocol[0]
)
+ self.write_metadata = write_metadata
self.kwargs = kwargs
def __call__(self, data, block_index):
@@ -503,9 +505,11 @@ def __call__(self, data, block_index):
if self.prefix is not None:
filename = f"{self.prefix}-{filename}"
filename = self.fs.unstrip_protocol(f"{self.path}{self.fs.sep}{filename}")
- return ak.to_parquet(
+ out = ak.to_parquet(
data, filename, **self.kwargs, storage_options=self.storage_options
)
+ if self.write_metadata:
+ return out
def to_parquet(
@@ -597,7 +601,10 @@ def to_parquet(
storage_options
Storage options passed to ``fsspec``.
write_metadata
- Write Parquet metadata.
+        Write Parquet metadata. Note that when this is ``True``, all the
+        metadata pieces will be pulled into a single finalizer task. When
+        ``False``, the whole write graph can be evaluated as a more efficient
+        tree reduction.
compute
If ``True``, immediately compute the result (write data to
disk). If ``False`` a Scalar collection will be returned such
@@ -667,6 +674,7 @@ def to_parquet(
parquet_old_int96_timestamps=parquet_old_int96_timestamps,
parquet_compliant_nested=parquet_compliant_nested,
parquet_extra_options=parquet_extra_options,
+ write_metadata=write_metadata,
),
array,
BlockIndex((array.npartitions,)),
@@ -681,17 +689,38 @@ def to_parquet(
dsk[(final_name, 0)] = (_metadata_file_from_metas, fs, path) + tuple(
map_res.__dask_keys__()
)
+ graph = HighLevelGraph.from_collections(
+ final_name,
+ AwkwardMaterializedLayer(dsk, previous_layer_names=[map_res.name]),
+ dependencies=[map_res],
+ )
+ out = new_scalar_object(graph, final_name, dtype="f8")
else:
final_name = name + "-finalize"
- dsk[(final_name, 0)] = (lambda *_: None, map_res.__dask_keys__())
- graph = HighLevelGraph.from_collections(
- final_name,
- AwkwardMaterializedLayer(dsk, previous_layer_names=[map_res.name]),
- dependencies=[map_res],
- )
- out = new_scalar_object(graph, final_name, dtype="f8")
+ from dask_awkward.layers import AwkwardTreeReductionLayer
+
+ layer = AwkwardTreeReductionLayer(
+ name=final_name,
+ concat_func=none_to_none,
+ tree_node_func=none_to_none,
+ name_input=map_res.name,
+ npartitions_input=map_res.npartitions,
+ finalize_func=none_to_none,
+ )
+ graph = HighLevelGraph.from_collections(
+ final_name,
+ layer,
+ dependencies=[map_res],
+ )
+ out = new_scalar_object(graph, final_name, dtype="f8")
+
if compute:
out.compute()
return None
else:
return out
+
+
+def none_to_none(*_):
+ """Dummy reduction function where write tasks produce no metadata"""
+ return None
diff --git a/src/dask_awkward/lib/mapfilter.py b/src/dask_awkward/lib/mapfilter.py
new file mode 100644
index 00000000..f6603474
--- /dev/null
+++ b/src/dask_awkward/lib/mapfilter.py
@@ -0,0 +1,430 @@
+from __future__ import annotations
+
+import typing as tp
+from dataclasses import dataclass
+
+import awkward as ak
+from dask.highlevelgraph import HighLevelGraph
+from dask.typing import DaskCollection
+
+from dask_awkward.lib.core import Array as DakArray
+from dask_awkward.lib.core import (
+ _map_partitions_prepare,
+ _to_packed_fn_args,
+ dak_cache,
+ empty_typetracer,
+ new_array_object,
+ partitionwise_layer,
+ to_meta,
+ typetracer_array,
+)
+from dask_awkward.utils import DaskAwkwardNotImplemented
+
+
+def _single_return_map_partitions(
+ hlg: HighLevelGraph,
+ name: str,
+ meta: tp.Any,
+ npartitions: int,
+) -> tp.Any:
+ from dask.utils import (
+ is_arraylike,
+ is_dataframe_like,
+ is_index_like,
+ is_series_like,
+ )
+
+ # ak.Array (this is dak.map_partitions case)
+ if isinstance(meta, ak.Array):
+ # convert to typetracer if not already
+ # this happens when the user provides a concrete array (e.g. np.array)
+ # and then wraps it with ak.Array as a return type
+ if not ak.backend(meta) == "typetracer":
+ meta = ak.to_backend(meta, "typetracer")
+ return new_array_object(
+ hlg,
+ name=name,
+ meta=meta,
+ npartitions=npartitions,
+ )
+ # TODO: array, dataframe, series, index
+ elif (
+ is_arraylike(meta)
+ or is_dataframe_like(meta)
+ or is_series_like(meta)
+ or is_index_like(meta)
+ ):
+ msg = (
+ f"{meta=} is not (yet) supported as return type. If possible, "
+ "you can convert it to ak.Array, or wrap it with a python container."
+ )
+ raise NotImplementedError(msg)
+ # don't know? -> put it in a bag
+ else:
+ from dask.bag.core import Bag
+
+ return Bag(dsk=hlg, name=name, npartitions=npartitions)
+
+
+def _multi_return_map_partitions(
+ hlg: HighLevelGraph,
+ name: str,
+ meta: tp.Any,
+ npartitions: int,
+) -> tp.Any:
+ # single-return case, this is equal to `dak.map_partitions`
+ # but supports other DaskCollections in addition
+ if not isinstance(meta, tuple):
+ return _single_return_map_partitions(
+ hlg=hlg,
+ name=name,
+ meta=meta,
+ npartitions=npartitions,
+ )
+ # multi-return case
+ else:
+ from operator import itemgetter
+ from typing import cast
+
+ # create tmp dask collection for HLG creation
+ tmp = new_array_object(
+ hlg, name=name, meta=empty_typetracer(), npartitions=npartitions
+ )
+
+ ret = []
+ for i, m_pick in enumerate(meta):
+ # add a "select/pick" layer
+ # to get the ith element of the output
+ ith_name = f"{name}-pick-{i}th"
+
+ if ith_name in dak_cache:
+ hlg_pick, m_pick = dak_cache[ith_name]
+ else:
+ lay_pick = partitionwise_layer(itemgetter(i), ith_name, tmp)
+ hlg_pick = HighLevelGraph.from_collections(
+ name=ith_name,
+ layer=lay_pick,
+ dependencies=[cast(DaskCollection, tmp)],
+ )
+ dak_cache[ith_name] = hlg_pick, m_pick
+ ret.append(
+ _single_return_map_partitions(
+ hlg=hlg_pick,
+ name=ith_name,
+ meta=m_pick,
+ npartitions=npartitions,
+ )
+ )
+ return tuple(ret)
+
+
+class UntraceableFunctionError(Exception): ...
+
+
+def _func_args(fn: tp.Callable, *args: tp.Any, **kwargs: tp.Any) -> tp.Mapping:
+ import inspect
+
+ ba = inspect.signature(fn).bind(*args, **kwargs)
+ return ba.arguments
+
+
+def _reports2needs(reports: tp.Mapping) -> dict:
+ import ast
+ from collections import defaultdict
+
+ needs = defaultdict(list)
+ for arg, report in reports.items():
+ # this should maybe be differently treated?
+ keys = set(report.shape_touched) | set(report.data_touched)
+ for key in keys:
+ slce = ast.literal_eval(key)
+ # only strings are actual slice paths to columns,
+ # `None` or `ints` are path values to non-record array types,
+ # see: https://github.com/scikit-hep/awkward/pull/3311
+ slce = tuple(it for it in slce if isinstance(it, str))
+ needs[arg].append(slce)
+ return needs
+
+
+def _replace_arrays_with_typetracers(meta: tp.Any) -> tp.Any:
+ def _to_tracer(meta: tp.Any) -> tp.Any:
+ if isinstance(meta, ak.Array):
+ if not ak.backend(meta) == "typetracer":
+ meta = typetracer_array(meta)
+ elif isinstance(meta, DakArray):
+ meta = to_meta([meta])[0]
+ return meta
+
+ if isinstance(meta, tuple):
+ meta = tuple(map(_to_tracer, meta))
+ else:
+ meta = _to_tracer(meta)
+ return meta
+
+
+def prerun(
+ fn: tp.Callable, *args: tp.Any, **kwargs: tp.Any
+) -> tuple[tp.Any, tp.Mapping]:
+ """
+ Pre-runs the provided function with typetracer arrays to determine the necessary columns
+ that should be touched explicitly and to infer the metadata of the function's output.
+
+ Parameters
+ ----------
+ fn : Callable
+ The function to be pre-run.
+ *args : Any
+ Positional arguments to be passed to the function.
+ **kwargs : Any
+ Keyword arguments to be passed to the function.
+
+ Returns
+ -------
+ tuple[Any, Mapping]
+ A tuple containing the output of the function when run with typetracer arrays and
+ a mapping of the touched columns (prepared to use with ``mapfilter(needs=...)``) generated during the typetracing step.
+
+ Examples
+ --------
+ >>> import awkward as ak
+ >>> import dask_awkward as dak
+ >>>
+ >>> ak_array = ak.zip({"foo": [1, 2, 3, 4], "bar": [1, 1, 1, 1]})
+ >>> dak_array = dak.from_awkward(ak_array, 2)
+ >>>
+    >>> def process(array: ak.Array) -> ak.Array:
+    ...     return array.foo + array.bar
+    >>>
+    >>> meta, needs = dak.prerun(process, dak_array)
+    >>> print(meta)
+    <Array-typetracer [...] type='## * int64'>
+ >>> print(needs)
+ {'array': [('bar',), ('foo',)]}
+ """
+ # unwrap `mapfilter`
+ if isinstance(fn, mapfilter):
+ fn = fn.fn
+
+ in_arguments = _func_args(fn, *args, **kwargs)
+
+ # replace ak.Arrays with typetracers and store the reports
+ reports = {}
+ fun_kwargs = {}
+ args_metas = {
+ arg: _replace_arrays_with_typetracers(val) for arg, val in in_arguments.items()
+ }
+
+ # can't typetrace if no ak.Arrays are present
+ ak_arrays = tuple(filter(lambda x: isinstance(x, ak.Array), args_metas.values()))
+ if not ak_arrays:
+ return None, {}
+
+ def _render_buffer_key(
+ form: ak.forms.Form,
+ form_key: str,
+ attribute: str,
+ ) -> str:
+ return form_key
+
+ # prepare function arguments
+ for arg, val in args_metas.items():
+ if isinstance(val, ak.Array):
+ if not ak.backend(val) == "typetracer":
+ val = typetracer_array(val)
+ tracer, report = ak.typetracer.typetracer_with_report(
+ val.layout.form_with_key_path(root=()),
+ highlevel=True,
+ behavior=val.behavior,
+ attrs=val.attrs,
+ buffer_key=_render_buffer_key,
+ )
+ reports[arg] = report
+ fun_kwargs[arg] = tracer
+ else:
+ fun_kwargs[arg] = val
+
+ # try to run the function once with typetracers
+ try:
+ out = fn(**fun_kwargs)
+ except Exception as err:
+ import traceback
+
+ # get line number of where the error occurred in the provided function
+ # traceback 0: this function, 1: the provided function, >1: the rest of the stack
+ tb = traceback.extract_tb(err.__traceback__)
+ line_number = tb[1].lineno
+
+ # add also the reports of the typetracer to the error message,
+ # and format them as 'needs' wants it to be
+ needs = dict(_reports2needs(reports=reports))
+
+ msg = (
+ f"'{fn}' is not traceable, an error occurred at line {line_number}. "
+ "'dak.mapfilter' can circumvent this by providing 'needs' and "
+ "'meta' arguments to it.\n\n"
+ "- 'needs': mapping where the keys point to input argument "
+ "dask_awkward arrays and the values to columns that should "
+ "be touched explicitly. The typetracing step could determine "
+ "the following necessary columns until the exception occurred:\n\n"
+ f"{needs=}\n\n"
+ "- 'meta': value(s) of what the wrapped function would "
+ "return. For arrays, only the shape and type matter."
+ )
+ raise UntraceableFunctionError(msg) from err
+ return out, dict(_reports2needs(reports))
+
+
+@dataclass
+class mapfilter:
+ """
+ A decorator to map a callable across all partitions of any number of collections.
+ The function will be treated as a single node in the Dask graph.
+
+ Parameters
+ ----------
+ fn : Callable
+ The function to apply on all partitions. This will get wrapped to handle kwargs, including Dask collections.
+ label : str, optional
+ Label for the Dask graph layer; if left as ``None`` (default), the name of the function will be used.
+ token : str, optional
+ Provide an already defined token. If ``None``, a new token will be generated.
+ meta : Any, optional
+ Metadata for the result (if known). If unknown, `fn` will be applied to the metadata of the `args`.
+ If provided, the tracing step will be skipped and the provided metadata is used as return value(s) of `fn`.
+ traverse : bool
+ Unpack basic Python containers to find Dask collections.
+ needs : dict, optional
+ If ``None`` (the default), nothing is touched in addition to the standard typetracer report.
+ In certain cases, it is necessary to touch additional objects **explicitly** to get the correct typetracer report.
+ For this, provide a dictionary that maps input arguments that are arrays to the columns of that array that should be touched.
+ If ``needs`` is used together with ``meta``, **only** the columns provided by the ``needs`` argument will be touched explicitly.
+
+ Examples
+ --------
+
+ .. code-block:: ipython
+
+ import awkward as ak
+ import dask_awkward as dak
+
+ ak_array = ak.zip({"foo": [1, 2, 3, 4], "bar": [1, 1, 1, 1]})
+ dak_array = dak.from_awkward(ak_array, 2)
+
+ @dak.mapfilter
+ def process(array: ak.Array) -> ak.Array:
+ return array.foo * 2
+
+ out = process(dak_array)
+ print(out.compute())
+        # [2, 4, 6, 8]
+ """
+
+ fn: tp.Callable
+ label: str | None = None
+ token: str | None = None
+ meta: tp.Any | None = None
+ traverse: bool = True
+ needs: tp.Mapping | None = None
+
+ def __post_init__(self) -> None:
+ if self.needs is not None and not isinstance(self.needs, tp.Mapping):
+ msg = ( # type: ignore[unreachable]
+ "'needs' argument must be a mapping where the keys "
+ "point to input argument dask_awkward arrays and the values "
+ "to columns that should be touched explicitly, "
+ f"got {self.needs!r} instead.\n\n"
+ "Exemplary usage:\n"
+ "\n@partial(mapfilter, needs={'array': ['col1', 'col2']})"
+ "\ndef process(array: ak.Array) -> ak.Array:"
+ "\n return array.col1 + array.col2"
+ )
+ raise ValueError(msg)
+
+ def wrapped_fn(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any:
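+        """Wrapper around ``self.fn`` that, on typetracer inputs, explicitly
+        touches the columns given in ``needs`` and, if ``meta`` is provided,
+        skips execution by returning it directly."""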
+ in_arguments = _func_args(self.fn, *args, **kwargs)
+ if self.needs is not None:
+ tobe_touched = set()
+ for arg in self.needs.keys():
+ if arg in in_arguments:
+ tobe_touched.add(arg)
+ else:
+ msg = f"Argument '{arg}' is not present in the function signature."
+ raise ValueError(msg)
+ for arg in tobe_touched:
+ array = in_arguments[arg]
+ if not isinstance(array, ak.Array):
+ raise ValueError(
+ f"Can only touch columns of an awkward array, got {array}."
+ )
+ if ak.backend(array) == "typetracer":
+ for slce in self.needs[arg]:
+ ak.typetracer.touch_data(array[slce])
+
+ if self.meta is not None:
+ ak_arrays = [
+ arg for arg in in_arguments.values() if isinstance(arg, ak.Array)
+ ]
+ if all(ak.backend(arr) == "typetracer" for arr in ak_arrays):
+ return self.meta
+ return self.fn(*args, **kwargs)
+
+ def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Any:
+ fn, arg_flat_deps_expanded, kwarg_flat_deps = _to_packed_fn_args(
+ self.wrapped_fn, *args, traverse=self.traverse, **kwargs
+ )
+
+ arg_flat_deps_expanded = _replace_arrays_with_typetracers(
+ arg_flat_deps_expanded
+ )
+ kwarg_flat_deps = _replace_arrays_with_typetracers(kwarg_flat_deps)
+ meta = _replace_arrays_with_typetracers(self.meta)
+ in_typetracing_mode = arg_flat_deps_expanded or kwarg_flat_deps or meta
+
+ try:
+ hlg, meta, deps, name = _map_partitions_prepare(
+ fn,
+ *arg_flat_deps_expanded,
+ *kwarg_flat_deps,
+ label=self.label,
+ token=self.token,
+ meta=meta,
+ output_divisions=None,
+ )
+ # handle the case where the function is not implemented for Dask arrays in dask-awkward
+ except DaskAwkwardNotImplemented as err:
+ raise err from None
+ # handle the case where the function is not traceable - for whatever reason
+ except Exception as err:
+ if in_typetracing_mode:
+ fn_args = _func_args(self.fn, *args, **kwargs)
+ sig_str = ", ".join(f"{k}={v}" for k, v in fn_args.items())
+ msg = (
+ f"Failed to trace the function '{self.fn}'. "
+ "You can use 'needs' and 'meta' to circumvent this step. "
+ "For this, it might be helpful to do a pre-run of the function:"
+ f"\n\n\tmeta, needs = dak.prerun({self.fn.__name__}, {sig_str})"
+ f"\n\nThis may help to infer the correct `needs` for `mapfilter`."
+ )
+ raise UntraceableFunctionError(msg) from err
+ else:
+ raise err from None
+
+        if len(deps) == 0:
+            raise ValueError("Need at least one input that is a dask collection.")
+        npart = deps[0].npartitions
+        if not all(dep.npartitions == npart for dep in deps):
+            msg = "All inputs must have the same number of partitions, got:"
+            for dep in deps:
+                npartitions = dep.npartitions
+                msg += f"\n{dep}: {npartitions=}"
+            raise ValueError(msg)
+
+ return _multi_return_map_partitions(
+ hlg=hlg,
+ name=name,
+ meta=meta,
+ npartitions=npart,
+ )
diff --git a/tests/test_mapfilter.py b/tests/test_mapfilter.py
new file mode 100644
index 00000000..83e6fa34
--- /dev/null
+++ b/tests/test_mapfilter.py
@@ -0,0 +1,100 @@
+from functools import partial
+
+import awkward as ak
+import numpy as np
+import pytest
+
+import dask_awkward as dak
+
+
+def test_mapfilter_single_return():
+ ak_array = ak.zip({"foo": [1, 2, 3, 4], "bar": [1, 1, 1, 1]})
+ dak_array = dak.from_awkward(ak_array, 2)
+
+ @dak.mapfilter
+ def fun(x):
+ y = x.foo + 1
+ return y
+
+ assert ak.all(
+ fun(dak_array).compute()
+ == dak.map_partitions(fun.wrapped_fn, dak_array).compute()
+ )
+
+
+def test_mapfilter_multiple_return():
+ ak_array = ak.zip({"foo": [1, 2, 3, 4], "bar": [1, 1, 1, 1]})
+ dak_array = dak.from_awkward(ak_array, 2)
+
+ class some: ...
+
+ @dak.mapfilter
+ def fun(x):
+ y = x.foo + 1
+ return y, (np.sum(y),), some(), ak.Array(np.ones(4))
+
+ y, y_sum, something, arr = fun(dak_array)
+
+ assert ak.all(y.compute() == ak_array.foo + 1)
+ assert np.all(y_sum.compute() == [np.array(5), np.array(9)])
+ something = something.compute()
+ assert len(something) == 2
+ assert all(isinstance(s, some) for s in something)
+ array = arr.compute()
+ assert len(array) == 8
+ assert array.ndim == 1
+ assert ak.all(array == ak.Array(np.ones(8)))
+
+
+def test_mapfilter_needs_meta():
+ ak_array = ak.zip(
+ {
+ "x": ak.zip({"foo": [10, 20, 30, 40], "bar": [10, 20, 30, 40]}),
+ "y": ak.zip({"foo": [1, 1, 1, 1], "bar": [1, 1, 1, 1]}),
+ "z": ak.zip({"a": [0, 0, 0, 0], "b": [2, 2, 2, 2]}),
+ }
+ )
+ dak_array = dak.from_awkward(ak_array, 2)
+
+ def untraceable_fun(muons):
+        # a computation that ak.typetracer cannot trace:
+        # it converts the "x.foo" column to NumPy and returns a 1-element array
+ muons.y.bar[...]
+ muons.z[...]
+ pt = ak.to_numpy(muons.x.foo)
+ return ak.Array([np.sum(pt)])
+
+ # first check that the function is not traceable
+ with pytest.raises(TypeError):
+ dak.map_partitions(untraceable_fun, dak_array)
+
+ # now check that the necessary columns are reported correctly
+ wrap = partial(
+ dak.mapfilter,
+ needs={"muons": [("x", "foo"), ("z",), ("y", "bar")]},
+ meta=ak.Array([0.0]),
+ )
+ out = wrap(untraceable_fun)(dak_array)
+ cols = next(iter(dak.report_necessary_columns(out).values()))
+ assert cols == frozenset({"x.foo", "y.bar", "z.a", "z.b"})
+
+
+def test_mapfilter_multi_dak_array_inputs():
+ ak_array1 = ak.zip({"foo": [1, 2, 3, 4], "bar": [1, 1, 1, 1]})
+ ak_array2 = ak.zip({"foo": [1, 2, 3, 4], "bar": [1, 1, 1, 1]})
+
+ @dak.mapfilter
+ def fun(x, y):
+ return x.foo + y.bar
+
+ assert ak.all(
+ fun(
+ dak.from_awkward(ak_array1, 2),
+ dak.from_awkward(ak_array2, 2),
+ ).compute()
+ == ak.Array([2, 3, 4, 5])
+ )
+
+ # test incompatible partitioning
+ with pytest.raises(ValueError):
+ fun(dak.from_awkward(ak_array1, 1), dak.from_awkward(ak_array2, 2))