Commit

simplify code
lhoestq committed Feb 11, 2025
1 parent 7a62630 commit 1f6a69a
Showing 1 changed file with 23 additions and 80 deletions.
103 changes: 23 additions & 80 deletions src/datasets/arrow_dataset.py
@@ -3281,7 +3281,7 @@ def _map_single(
     class NumExamplesMismatchError(Exception):
         pass
 
-    def validate_function_output(processed_inputs, indices):
+    def validate_function_output(processed_inputs):
         """Validate output of the map function."""
         allowed_processed_inputs_types = (Mapping, pa.Table, pd.DataFrame)
         if config.POLARS_AVAILABLE and "polars" in sys.modules:
@@ -3292,7 +3292,7 @@ def validate_function_output(processed_inputs, indices):
             raise TypeError(
                 f"Provided `function` which is applied to all elements of table returns a variable of type {type(processed_inputs)}. Make sure provided `function` returns a variable of type `dict` (or a pyarrow table) to update the dataset or `None` if you are only interested in side effects."
             )
-        elif isinstance(indices, list) and isinstance(processed_inputs, Mapping):
+        if batched and isinstance(processed_inputs, Mapping):
             allowed_batch_return_types = (list, np.ndarray, pd.Series)
             if config.POLARS_AVAILABLE and "polars" in sys.modules:
                 import polars as pl
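
The substantive change in this hunk: batch mode is now read from the `batched` flag closed over from `_map_single`, rather than inferred from `isinstance(indices, list)`. A minimal standalone sketch of the batched-output check, with hypothetical names and without the pyarrow/polars branches:

from collections.abc import Mapping

import numpy as np
import pandas as pd

def validate_batched_output(processed_inputs, batched):
    # With batched=True, each returned column must be list-like: one value per example.
    allowed_batch_return_types = (list, np.ndarray, pd.Series)
    if batched and isinstance(processed_inputs, Mapping):
        for key, value in processed_inputs.items():
            if not isinstance(value, allowed_batch_return_types):
                raise TypeError(
                    f"Column {key!r} has type {type(value)}; with batched=True, "
                    f"columns must be one of {allowed_batch_return_types}."
                )

validate_batched_output({"text": ["a", "b"]}, batched=True)   # ok
# validate_batched_output({"text": "a"}, batched=True)        # would raise TypeError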
@@ -3318,9 +3318,8 @@ def validate_function_output(processed_inputs, indices):
                     f"Provided `function` which is applied to all elements of table returns a `dict` of types {[type(x) for x in processed_inputs.values()]}. When using `batched=True`, make sure provided `function` returns a `dict` of types like `{allowed_batch_return_types}`."
                 )
 
-    def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples=False, offset=0):
+    def prepare_inputs(pa_inputs, indices, offset=0):
         """Utility to apply the function on a selection of columns."""
-        nonlocal update_data
         inputs = format_table(
             pa_inputs,
             0 if not batched else range(pa_inputs.num_rows),
@@ -3337,25 +3336,20 @@ def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples=False, offset=0):
             additional_args += (effective_indices,)
         if with_rank:
             additional_args += (rank,)
-        processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
+        return inputs, fn_args, additional_args, fn_kwargs
+
+    def prepare_outputs(pa_inputs, inputs, processed_inputs, check_same_num_examples=False):
+        nonlocal update_data
+        if not (update_data := (processed_inputs is not None)):
+            return None
         if isinstance(processed_inputs, LazyDict):
             processed_inputs = {
                 k: v for k, v in processed_inputs.data.items() if k not in processed_inputs.keys_to_format
             }
             returned_lazy_dict = True
         else:
             returned_lazy_dict = False
-        if update_data is None:
-            # Check if the function returns updated examples
-            updatable_types = (Mapping, pa.Table, pd.DataFrame)
-            if config.POLARS_AVAILABLE and "polars" in sys.modules:
-                import polars as pl
-
-                updatable_types += (pl.DataFrame,)
-            update_data = isinstance(processed_inputs, updatable_types)
-            validate_function_output(processed_inputs, indices)
-        if not update_data:
-            return None  # Nothing to update, let's move on
+        validate_function_output(processed_inputs)
         if shard._format_type or input_columns:
             # TODO(QL, MS): ideally the behavior should be the same even if the dataset is formatted (may require major release)
             inputs_to_merge = dict(zip(pa_inputs.column_names, pa_inputs.itercolumns()))
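
Most of what this hunk deletes is the lazy `update_data is None` probe, which inspected the first output's type to decide whether `.map()` writes new data; the replacement folds that decision into a single walrus assignment on `processed_inputs is not None`. A toy sketch of the `nonlocal` + walrus early-exit idiom (all names here are hypothetical):

def make_prepare_outputs():
    update_data = None  # stands in for the flag living in the _map_single scope

    def prepare_outputs(processed_inputs):
        nonlocal update_data
        # Record whether the function returned anything; bail out early when it
        # was called for side effects only.
        if not (update_data := (processed_inputs is not None)):
            return None
        return dict(processed_inputs)

    return prepare_outputs

prepare_outputs = make_prepare_outputs()
assert prepare_outputs(None) is None              # side-effect-only function
assert prepare_outputs({"a": [1]}) == {"a": [1]}  # data to merge back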
@@ -3385,72 +3379,21 @@ def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples=False, offset=0):
         else:
             return processed_inputs
 
+    def apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples=False, offset=0):
+        """Utility to apply the function on a selection of columns."""
+        inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(pa_inputs, indices, offset=offset)
+        processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
+        return prepare_outputs(
+            pa_inputs, inputs, processed_inputs, check_same_num_examples=check_same_num_examples
+        )
+
     async def async_apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples=False, offset=0):
         """Utility to apply the function on a selection of columns. Same code but async"""
-        nonlocal update_data
-        inputs = format_table(
-            pa_inputs,
-            0 if not batched else range(pa_inputs.num_rows),
-            format_columns=input_columns,
-            formatter=input_formatter,
-        )
-        fn_args = [inputs] if input_columns is None else [inputs[col] for col in input_columns]
-        if offset == 0:
-            effective_indices = indices
-        else:
-            effective_indices = [i + offset for i in indices] if isinstance(indices, list) else indices + offset
-        additional_args = ()
-        if with_indices:
-            additional_args += (effective_indices,)
-        if with_rank:
-            additional_args += (rank,)
+        inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(pa_inputs, indices, offset=offset)
         processed_inputs = await function(*fn_args, *additional_args, **fn_kwargs)
-        if isinstance(processed_inputs, LazyDict):
-            processed_inputs = {
-                k: v for k, v in processed_inputs.data.items() if k not in processed_inputs.keys_to_format
-            }
-            returned_lazy_dict = True
-        else:
-            returned_lazy_dict = False
-        if update_data is None:
-            # Check if the function returns updated examples
-            updatable_types = (Mapping, pa.Table, pd.DataFrame)
-            if config.POLARS_AVAILABLE and "polars" in sys.modules:
-                import polars as pl
-
-                updatable_types += (pl.DataFrame,)
-            update_data = isinstance(processed_inputs, updatable_types)
-            validate_function_output(processed_inputs, indices)
-        if not update_data:
-            return None  # Nothing to update, let's move on
-        if shard._format_type or input_columns:
-            # TODO(QL, MS): ideally the behavior should be the same even if the dataset is formatted (may require major release)
-            inputs_to_merge = dict(zip(pa_inputs.column_names, pa_inputs.itercolumns()))
-        elif isinstance(inputs, LazyDict):
-            inputs_to_merge = {
-                k: (v if k not in inputs.keys_to_format else pa_inputs[k]) for k, v in inputs.data.items()
-            }
-        else:
-            inputs_to_merge = inputs
-        if remove_columns is not None:
-            for column in remove_columns:
-                # `function` can modify input in-place causing column to be already removed.
-                if column in inputs_to_merge:
-                    inputs_to_merge.pop(column)
-                if returned_lazy_dict and column in processed_inputs:
-                    processed_inputs.pop(column)
-        if check_same_num_examples:
-            input_num_examples = len(pa_inputs)
-            processed_inputs_num_examples = len(processed_inputs[next(iter(processed_inputs.keys()))])
-            if input_num_examples != processed_inputs_num_examples:
-                raise NumExamplesMismatchError()
-        if isinstance(inputs, Mapping) and isinstance(processed_inputs, Mapping):
-            # The .map() transform *updates* the dataset:
-            # the output dictionary contains both the input data and the output data.
-            # The output dictionary may contain Arrow values from `inputs_to_merge` so that we can re-write them efficiently.
-            return {**inputs_to_merge, **processed_inputs}
-        else:
-            return processed_inputs
+        return prepare_outputs(
+            pa_inputs, inputs, processed_inputs, check_same_num_examples=check_same_num_examples
+        )
 
     def init_buffer_and_writer():
         # Prepare output buffer and batched writer in memory or on file if we update the table
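
Net effect of this hunk: the sync and async entry points become thin wrappers over the shared `prepare_inputs`/`prepare_outputs` helpers, so the async body shrinks from roughly sixty duplicated lines to four. This is the machinery that lets a coroutine function be passed to `Dataset.map`; a hedged usage sketch (the mapped function is our own example, and we assume the rest of this PR wires async support through `map`):

import asyncio

from datasets import Dataset

async def add_length(example):
    await asyncio.sleep(0)  # stand-in for real async I/O such as an HTTP call
    return {"length": len(example["text"])}

ds = Dataset.from_dict({"text": ["hello", "hi"]})
ds = ds.map(add_length)  # tasks are capped by config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL
print(ds[0])  # {'text': 'hello', 'length': 5}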
@@ -3495,7 +3438,7 @@ def iter_output_examples(shard_iterable):
             for i, example in shard_iterable:
                 indices.append(i)
                 tasks.append(loop.create_task(async_apply_function_on_filtered_inputs(example, i, offset=offset)))
-                # keep the total active tasks under 30
+                # keep the total active tasks under a certain number
                 if len(tasks) >= config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL:
                     done, pending = loop.run_until_complete(
                         asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
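
The comment fix above also points at the real throttle: the cap is `config.MAX_NUM_RUNNING_ASYNC_MAP_FUNCTIONS_IN_PARALLEL`, not a hardcoded 30. A self-contained sketch of the same bounded-concurrency pattern (`bounded_gather` and `work` are hypothetical names):

import asyncio

async def bounded_gather(coros, limit):
    # Same throttling pattern as the loop above: keep at most `limit` tasks in
    # flight, draining finished ones with FIRST_COMPLETED before adding more.
    # (Results arrive in completion order, not submission order.)
    tasks, results = set(), []
    for coro in coros:
        tasks.add(asyncio.ensure_future(coro))
        if len(tasks) >= limit:
            done, tasks = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            results.extend(t.result() for t in done)
    if tasks:
        done, _ = await asyncio.wait(tasks)
        results.extend(t.result() for t in done)
    return results

async def work(i):
    await asyncio.sleep(0.01)
    return i

print(asyncio.run(bounded_gather([work(i) for i in range(10)], limit=3)))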
