Feature proposal: Stacking, potentially heterogeneous, datasets #7279

Open · wants to merge 5 commits into base: main
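The proposal in a nutshell: `stack_datasets` takes a dict of datasets and returns a dataset whose examples are meta-examples, one example per source keyed by the source name. A minimal usage sketch, assuming the API introduced in this PR (the toy datasets and column names are made up for illustration):

```python
from datasets import Dataset, stack_datasets

# Two heterogeneous sources with different schemas (toy columns, for illustration only).
captions = Dataset.from_dict({"text": ["a cat", "a dog", "a bird"]})
images = Dataset.from_dict({"image_id": [101, 102, 103]})

stacked = stack_datasets({"captions": captions, "images": images})

# Each example combines one example from every source, nested under the source name.
print(stacked[0])
# {'captions': {'text': 'a cat'}, 'images': {'image_id': 101}}
```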
2 changes: 1 addition & 1 deletion src/datasets/__init__.py
@@ -17,7 +17,7 @@
from .arrow_dataset import Dataset
from .arrow_reader import ReadInstruction
from .builder import ArrowBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
from .combine import concatenate_datasets, interleave_datasets
from .combine import concatenate_datasets, interleave_datasets, stack_datasets
from .dataset_dict import DatasetDict, IterableDatasetDict
from .download import *
from .features import *
90 changes: 90 additions & 0 deletions src/datasets/arrow_dataset.py
@@ -6325,3 +6325,93 @@ def get_indices_from_mask_function(
indices_array = indices_mapping.column(0).take(indices_array)
indices_array = indices_array.to_pylist()
return {"indices": indices_array}


def _stack_map_style_datasets(
datasets: Dict[str, "Dataset"],
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
) -> "Dataset":
"""
Stack several map-style datasets (sources) into a single map-style dataset.
The new dataset is constructed by combining examples from each source dataset into a single example.

Args:
datasets (`Dict[str, Dataset]`): Dictionary of datasets to stack.
info ([`DatasetInfo`], *optional*, defaults to `None`):
Dataset information, like description, citation, etc.
split ([`NamedSplit`], *optional*, defaults to `None`):
Name of the dataset split.
stopping_strategy (`Literal["first_exhausted", "all_exhausted"]`, *optional*, defaults to `first_exhausted`):
If undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted.
If oversampling ("all_exhausted"), we stop as soon as every dataset is exhausted,
i.e as soon as every samples of every dataset has been visited at least once.
"all_exhausted" means that the examples of smaller datasets may be visited multiple times.

Returns:
[`Dataset`]: A [`Dataset`] that returns examples where each example is a dictionary
with keys corresponding to the keys of the input dictionary of datasets, and values corresponding
to the examples of the respective dataset.
"""
if stopping_strategy not in ["first_exhausted", "all_exhausted"]:
raise ValueError(
f"{stopping_strategy} stopping strategy in `stack_datasets` is not implemented yet with a dict of {type(next(iter(datasets.values())))}"
)

if any(dataset.num_rows > 0 for dataset in datasets.values()):
datasets = {name: dataset for name, dataset in datasets.items() if dataset.num_rows > 0}
else:
# Return first dataset if all datasets are empty
return next(iter(datasets.values()))

# the strategy is: 1. pad or truncate -> 2. join by concatenating along axis=1 -> 3. nest the columns (stacking)

d2len = {name: len(dataset) for name, dataset in datasets.items()}
if stopping_strategy == "first_exhausted": # truncate all datasets to the length of the shortest one
min_len = min(d2len.values())
indices = range(min_len)
datasets = {name: dataset.select(indices) for name, dataset in datasets.items()}
else: # "all_exhausted" -> "pad" all datasets to the length of the longest one
max_len = max(d2len.values())
for name in list(datasets.keys()):
dataset = datasets[name]
rows_remaining = max_len - d2len[name]
if rows_remaining == 0:
continue
n_repeat = max_len // d2len[name]
extra_rows = max_len % d2len[name]
to_concat = [dataset] * n_repeat
if extra_rows != 0:
to_concat.append(dataset.select(range(extra_rows)))
datasets[name] = _concatenate_map_style_datasets(
to_concat, info=dataset.info, split=dataset.split
) # concat it with itself

# rename columns to avoid conflicts and to later distinguish the source (dataset) of each column
for name in list(datasets.keys()):
dataset = datasets[name]
datasets[name] = dataset.rename_columns({k: f"{name}.{k}" for k in dataset.column_names})

# create a joint dataset with all columns -> the columns will be named "dataset_name.column_name", so the structure is flattened
concatenated_dataset = _concatenate_map_style_datasets(list(datasets.values()), info=info, split=split, axis=1)

def structure_example(example): # here we nest the columns again -> this gives us the stacked structure
"""
```python
>>> example = {"dataset1.column1": 1, "dataset2.column2": {'a': 1, 'b': 2}}
>>> structure_example(example) == {"dataset1": {"column1": 1}, "dataset2": {"column2": {'a': 1, 'b': 2}}}
```
"""
new_example = {name: {} for name in datasets.keys()}
for key, value in example.items():
# key.split(".", 1): "dataset_name.column_name" -> ("dataset_name", "column_name"),
# "dataset_name.column_name.etc" -> ("dataset_name", "column_name.etc")
key_dataset, key_actual = key.split(".", 1)
new_example[key_dataset][key_actual] = value
return new_example

# because we generated new columns, we need to remove the old ones -> "remove_columns=concatenated_dataset.column_names"
return concatenated_dataset.map(structure_example, remove_columns=concatenated_dataset.column_names)
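For context, the pad/truncate → join along axis=1 → re-nest strategy implemented above can be sketched with existing library primitives; the toy datasets and the `d1.`/`d2.` prefixes below are illustrative assumptions, not part of this PR:

```python
from datasets import Dataset, concatenate_datasets

d1 = Dataset.from_dict({"a": [1, 2, 3]})
d2 = Dataset.from_dict({"b": ["x", "y", "z"]})

# 1. pad or truncate: both sources already have equal length here, so nothing to do
# 2. prefix column names with the source name and join column-wise (axis=1)
renamed = [
    d1.rename_columns({k: f"d1.{k}" for k in d1.column_names}),
    d2.rename_columns({k: f"d2.{k}" for k in d2.column_names}),
]
joined = concatenate_datasets(renamed, axis=1)

# 3. nest the flat "source.column" keys back into one dict per source
def nest(example):
    out = {"d1": {}, "d2": {}}
    for key, value in example.items():
        source, column = key.split(".", 1)
        out[source][column] = value
    return out

stacked = joined.map(nest, remove_columns=joined.column_names)
print(stacked[0])  # {'d1': {'a': 1}, 'd2': {'b': 'x'}}
```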
91 changes: 88 additions & 3 deletions src/datasets/combine.py
@@ -1,9 +1,19 @@
from typing import List, Optional, TypeVar
from typing import Dict, List, Optional, TypeVar

from .arrow_dataset import Dataset, _concatenate_map_style_datasets, _interleave_map_style_datasets
from .arrow_dataset import (
Dataset,
_concatenate_map_style_datasets,
_interleave_map_style_datasets,
_stack_map_style_datasets,
)
from .dataset_dict import DatasetDict, IterableDatasetDict
from .info import DatasetInfo
from .iterable_dataset import IterableDataset, _concatenate_iterable_datasets, _interleave_iterable_datasets
from .iterable_dataset import (
IterableDataset,
_concatenate_iterable_datasets,
_interleave_iterable_datasets,
_stack_iterable_datasets,
)
from .splits import NamedSplit
from .utils import logging
from .utils.py_utils import Literal
@@ -213,3 +223,78 @@ def concatenate_datasets(
return _concatenate_map_style_datasets(dsets, info=info, split=split, axis=axis)
else:
return _concatenate_iterable_datasets(dsets, info=info, split=split, axis=axis)


def stack_datasets(
datasets: Dict[str, DatasetType],
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
) -> DatasetType:
"""
Stack multiple datasets into a single dataset. Examples returned are meta-examples containing
one example from each source dataset. Useful when each item (or, after batching, each batch)
should combine different (possibly complex) feature types from different sources that cannot simply be concatenated.
Inspired by `torch.utils.data.StackDataset`.

Args:
datasets (`Dict[str, Dataset]` or `Dict[str, IterableDataset]`): Dictionary of datasets to stack.
info ([`DatasetInfo`], *optional*, defaults to `None`):
Dataset information, like description, citation, etc.
split ([`NamedSplit`], *optional*, defaults to `None`):
Name of the dataset split.
stopping_strategy (`Literal["first_exhausted", "all_exhausted"]`, *optional*, defaults to `first_exhausted`):
If undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted.
If oversampling ("all_exhausted"), we stop as soon as every dataset is exhausted,
i.e as soon as every samples of every dataset has been visited at least once.
"all_exhausted" means that the examples of smaller datasets may be visited multiple times.

Returns:
[`Dataset`] or [`IterableDataset`]: Return type depends on the input `datasets`
parameter. Returns a `Dataset` if the input is a dict of `Dataset`,
or an `IterableDataset` if the input is a dict of `IterableDataset`.

Example:

```python
>>> datasets = {'d1': dataset1, 'd2': dataset2}
>>> stacked = stack_datasets(datasets)
>>> next(iter(stacked))
{'d1': <dataset1_example1>, 'd2': <dataset2_example1>}
```
"""

if not datasets:
raise ValueError("Unable to stack an empty dict of datasets.")
for i, (name, dataset) in enumerate(datasets.items()):
if not isinstance(dataset, (Dataset, IterableDataset)):
if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
if not dataset:
raise ValueError(
f"Expected a dict of Dataset objects or a dict of IterableDataset objects, but value of key '{name}' "
"is an empty dataset dictionary."
)
raise ValueError(
f"Dataset of key '{name}' has at least one split: {list(dataset)}\n"
f"Please pick one to stack with the other datasets, for example: dataset['{next(iter(dataset))}']"
)
raise ValueError(
f"Expected a dict of Dataset objects or a dict of IterableDataset objects, but value of key '{name}' is a {type(dataset).__name__}."
)
if i == 0:
dataset_type, other_type = (
(Dataset, IterableDataset) if isinstance(dataset, Dataset) else (IterableDataset, Dataset)
)
dataset_key = name
elif not isinstance(dataset, dataset_type):
raise ValueError(
f"Unable to stack a {dataset_type.__name__} (key '{dataset_key}') with a {other_type.__name__} (key '{name}'). Expected a dict of Dataset objects or a dict of IterableDataset objects."
)
if stopping_strategy not in ["first_exhausted", "all_exhausted"]:
raise ValueError(f"{stopping_strategy} is not supported. Please enter a valid stopping_strategy.")
if dataset_type is Dataset:
return _stack_map_style_datasets(
datasets=datasets, info=info, split=split, stopping_strategy=stopping_strategy
)
else:
return _stack_iterable_datasets(datasets=datasets, info=info, split=split, stopping_strategy=stopping_strategy)
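A short sketch of the two stopping strategies, following the semantics documented above (the toy datasets are assumptions; `stack_datasets` is the function proposed in this PR):

```python
from datasets import Dataset, stack_datasets

small = Dataset.from_dict({"x": [0, 1]})
big = Dataset.from_dict({"y": [0, 1, 2, 3, 4]})

# "first_exhausted" (default): truncate every source to the shortest one -> 2 rows
truncated = stack_datasets({"small": small, "big": big})
print(len(truncated))  # 2

# "all_exhausted": oversample shorter sources up to the longest one -> 5 rows,
# so examples of `small` are visited more than once
oversampled = stack_datasets({"small": small, "big": big}, stopping_strategy="all_exhausted")
print(len(oversampled))  # 5
```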