Make tfds.data_source pickable.
PiperOrigin-RevId: 636516631
marcenacp authored and The TensorFlow Datasets Authors committed May 23, 2024
1 parent e3d499f commit 6bbba45
Showing 7 changed files with 69 additions and 109 deletions.
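The gist of the change: data sources now hold a picklable DatasetInfo instead of the full DatasetBuilder, so a tfds.data_source survives a pickle round trip, which PyGrain relies on when it ships data sources to worker processes. A minimal sketch of that property, modelled on the test_data_source_is_picklable_after_use test shown in base_test.py below; the DummyDataset setup mirrors that test and is not new API:

import pickle

import tensorflow_datasets as tfds
from tensorflow_datasets.core import file_adapters

with tfds.testing.tmp_dir() as data_dir:
  builder = tfds.testing.DummyDataset(data_dir=data_dir)
  builder.download_and_prepare(file_format=file_adapters.FileFormat.ARRAY_RECORD)
  data_source = builder.as_data_source(split='train')
  # The unpickled copy still serves decoded examples.
  restored = pickle.loads(pickle.dumps(data_source))
  assert restored[0] == {'id': 0}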
18 changes: 16 additions & 2 deletions tensorflow_datasets/core/data_sources/array_record.py
@@ -20,8 +20,13 @@
"""

import dataclasses
from typing import Any, Optional

from tensorflow_datasets.core import dataset_info as dataset_info_lib
from tensorflow_datasets.core import decode
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core.data_sources import base
from tensorflow_datasets.core.utils import type_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_data_source


@@ -37,9 +42,18 @@ class ArrayRecordDataSource(base.BaseDataSource):
source.
"""

dataset_info: dataset_info_lib.DatasetInfo
split: splits_lib.Split = None
decoders: Optional[type_utils.TreeDict[decode.partial_decode.DecoderArg]] = (
None
)
# In order to lazy load array_record, we don't load
# `array_record_data_source.ArrayRecordDataSource` here.
data_source: Any = dataclasses.field(init=False)
length: int = dataclasses.field(init=False)

def __post_init__(self):
dataset_info = self.dataset_builder.info
file_instructions = base.file_instructions(dataset_info, self.split)
file_instructions = base.file_instructions(self.dataset_info, self.split)
self.data_source = array_record_data_source.ArrayRecordDataSource(
file_instructions
)
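For illustration, a hedged sketch of direct construction under the new field layout: the first argument is now a DatasetInfo rather than the DatasetBuilder (normal usage still goes through builder.as_data_source). The temporary directory and the direct instantiation are assumptions of the sketch, not new API:

import tempfile

from tensorflow_datasets import testing
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core.data_sources import array_record

builder = testing.DummyDataset(data_dir=tempfile.mkdtemp())
builder.download_and_prepare(file_format=file_adapters.FileFormat.ARRAY_RECORD)

# The source is now fully described by (dataset_info, split, decoders); the
# ArrayRecord reader itself is built from file instructions in __post_init__.
source = array_record.ArrayRecordDataSource(builder.info, split='train')
assert source[0] == {'id': 0}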
32 changes: 8 additions & 24 deletions tensorflow_datasets/core/data_sources/base.py
@@ -17,14 +17,12 @@

from collections.abc import MappingView, Sequence
import dataclasses
import functools
import typing
from typing import Any, Generic, Iterable, Protocol, SupportsIndex, TypeVar

from tensorflow_datasets.core import dataset_info as dataset_info_lib
from tensorflow_datasets.core import decode
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core.features import top_level_feature
from tensorflow_datasets.core.utils import shard_utils
from tensorflow_datasets.core.utils import type_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import tree
@@ -56,14 +54,6 @@ def file_instructions(
return split_dict[split].file_instructions


class _DatasetBuilder(Protocol):
"""Protocol for the DatasetBuilder to avoid cyclic imports."""

@property
def info(self) -> dataset_info_lib.DatasetInfo:
...


@dataclasses.dataclass
class BaseDataSource(MappingView, Sequence):
"""Base DataSource to override all dunder methods with the deserialization.
@@ -74,28 +64,22 @@ class BaseDataSource(MappingView, Sequence):
deserialization/decoding.
Attributes:
dataset_builder: The dataset builder.
dataset_info: The DatasetInfo of the dataset.
split: The split to load in the data source.
decoders: Optional decoders for decoding.
data_source: The underlying data source to initialize in the __post_init__.
"""

dataset_builder: _DatasetBuilder
dataset_info: dataset_info_lib.DatasetInfo
split: splits_lib.Split | None = None
decoders: type_utils.TreeDict[decode.partial_decode.DecoderArg] | None = None
data_source: DataSource[Any] = dataclasses.field(init=False)

@functools.cached_property
def _features(self) -> top_level_feature.TopLevelFeature:
"""Caches features because we log the use of dataset_builder.info."""
features = self.dataset_builder.info.features
if not features:
raise ValueError('No feature defined in the dataset buidler.')
return features

def __getitem__(self, key: SupportsIndex) -> Any:
record = self.data_source[key.__index__()]
return self._features.deserialize_example_np(record, decoders=self.decoders)
return self.dataset_info.features.deserialize_example_np(
record, decoders=self.decoders
)

def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
"""Retrieves items by batch.
@@ -114,24 +98,24 @@ def __getitems__(self, keys: Sequence[int]) -> Sequence[Any]:
if not keys:
return []
records = self.data_source.__getitems__(keys)
features = self.dataset_info.features
if len(keys) != len(records):
raise IndexError(
f'Requested {len(keys)} records but got'
f' {len(records)} records.'
f'{keys=}, {records=}'
)
return [
self._features.deserialize_example_np(record, decoders=self.decoders)
features.deserialize_example_np(record, decoders=self.decoders)
for record in records
]

def __repr__(self) -> str:
decoders_repr = (
tree.map_structure(type, self.decoders) if self.decoders else None
)
name = self.dataset_builder.info.name
return (
f'{self.__class__.__name__}(name={name}, '
f'{self.__class__.__name__}(name={self.dataset_info.name}, '
f'split={self.split!r}, '
f'decoders={decoders_repr})'
)
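Spelled out by hand for illustration, reusing the source from the array_record.py sketch above: after the refactor, decoding needs nothing beyond the data source's own fields, which is exactly what the new __getitem__ does.

# Undecoded record from the underlying reader, then the same decode step that
# BaseDataSource.__getitem__ performs.
raw = source.data_source[0]
decoded = source.dataset_info.features.deserialize_example_np(
    raw, decoders=source.decoders
)
assert decoded == source[0]
# The empty-key fast path in __getitems__ returns before touching the reader.
assert source.__getitems__([]) == []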
47 changes: 10 additions & 37 deletions tensorflow_datasets/core/data_sources/base_test.py
@@ -15,15 +15,13 @@

"""Tests for all data sources."""

import pickle
from unittest import mock

import cloudpickle
from etils import epath
import pytest
import tensorflow_datasets as tfds
from tensorflow_datasets import testing
from tensorflow_datasets.core import dataset_builder as dataset_builder_lib
from tensorflow_datasets.core import dataset_builder
from tensorflow_datasets.core import dataset_info as dataset_info_lib
from tensorflow_datasets.core import decode
from tensorflow_datasets.core import file_adapters
@@ -79,7 +77,7 @@ def mocked_parquet_dataset():
)
def test_read_write(
tmp_path: epath.Path,
builder_cls: dataset_builder_lib.DatasetBuilder,
builder_cls: dataset_builder.DatasetBuilder,
file_format: file_adapters.FileFormat,
):
builder = builder_cls(data_dir=tmp_path, file_format=file_format)
@@ -108,34 +106,28 @@ def test_read_write(
]


def create_dataset_builder(file_format: file_adapters.FileFormat):
def create_dataset_info(file_format: file_adapters.FileFormat):
with mock.patch.object(splits_lib, 'SplitInfo') as split_mock:
split_mock.return_value.name = 'train'
split_mock.return_value.file_instructions = _FILE_INSTRUCTIONS
dataset_info = mock.create_autospec(dataset_info_lib.DatasetInfo)
dataset_info.file_format = file_format
dataset_info.splits = {'train': split_mock()}
dataset_info.name = 'dataset_name'

dataset_builder = mock.create_autospec(dataset_builder_lib.DatasetBuilder)
dataset_builder.info = dataset_info

return dataset_builder
return dataset_info


@pytest.mark.parametrize(
'data_source_cls',
_DATA_SOURCE_CLS,
)
def test_missing_split_raises_error(data_source_cls):
dataset_builder = create_dataset_builder(
file_adapters.FileFormat.ARRAY_RECORD
)
dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
with pytest.raises(
ValueError,
match="Unknown split 'doesnotexist'.",
):
data_source_cls(dataset_builder, split='doesnotexist')
data_source_cls(dataset_info, split='doesnotexist')


@pytest.mark.usefixtures(*_FIXTURES)
@@ -144,10 +136,8 @@ def test_missing_split_raises_error(data_source_cls):
_DATA_SOURCE_CLS,
)
def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
dataset_builder = create_dataset_builder(
file_adapters.FileFormat.ARRAY_RECORD
)
source = data_source_cls(dataset_builder, split='train')
dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
source = data_source_cls(dataset_info, split='train')
name = data_source_cls.__name__
assert (
repr(source) == f"{name}(name=dataset_name, split='train', decoders=None)"
@@ -160,11 +150,9 @@ def test_repr_returns_meaningful_string_without_decoders(data_source_cls):
_DATA_SOURCE_CLS,
)
def test_repr_returns_meaningful_string_with_decoders(data_source_cls):
dataset_builder = create_dataset_builder(
file_adapters.FileFormat.ARRAY_RECORD
)
dataset_info = create_dataset_info(file_adapters.FileFormat.ARRAY_RECORD)
source = data_source_cls(
dataset_builder,
dataset_info,
split='train',
decoders={'my_feature': decode.SkipDecoding()},
)
@@ -193,18 +181,3 @@ def test_data_source_is_sliceable():
file_instructions = mock_array_record_data_source.call_args_list[1].args[0]
assert file_instructions[0].skip == 0
assert file_instructions[0].take == 30000


# PyGrain requires that data sources are picklable.
@pytest.mark.parametrize(
'file_format',
file_adapters.FileFormat.with_random_access(),
)
@pytest.mark.parametrize('pickle_module', [pickle, cloudpickle])
def test_data_source_is_picklable_after_use(file_format, pickle_module):
with tfds.testing.tmp_dir() as data_dir:
builder = tfds.testing.DummyDataset(data_dir=data_dir)
builder.download_and_prepare(file_format=file_format)
data_source = builder.as_data_source(split='train')
assert data_source[0] == {'id': 0}
assert pickle_module.loads(pickle_module.dumps(data_source))[0] == {'id': 0}
3 changes: 1 addition & 2 deletions tensorflow_datasets/core/data_sources/parquet.py
@@ -57,8 +57,7 @@ class ParquetDataSource(base.BaseDataSource):
"""ParquetDataSource to read from a ParquetDataset."""

def __post_init__(self):
dataset_info = self.dataset_builder.info
file_instructions = base.file_instructions(dataset_info, self.split)
file_instructions = base.file_instructions(self.dataset_info, self.split)
filenames = [
file_instruction.filename for file_instruction in file_instructions
]
4 changes: 2 additions & 2 deletions tensorflow_datasets/core/dataset_builder.py
@@ -774,13 +774,13 @@ def build_single_data_source(
file_format = self.info.file_format
if file_format == file_adapters.FileFormat.ARRAY_RECORD:
return array_record.ArrayRecordDataSource(
self,
self.info,
split=split,
decoders=decoders,
)
elif file_format == file_adapters.FileFormat.PARQUET:
return parquet.ParquetDataSource(
self,
self.info,
split=split,
decoders=decoders,
)
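A hedged usage sketch of the dispatch above: as_data_source forwards self.info, the split and the decoders to the file-format-specific source, choosing ArrayRecordDataSource or ParquetDataSource from info.file_format. The dataset name and the 'image' feature below are illustrative, not part of this commit:

import tensorflow_datasets as tfds
from tensorflow_datasets.core import decode
from tensorflow_datasets.core import file_adapters

builder = tfds.builder('mnist')  # illustrative dataset name
builder.download_and_prepare(file_format=file_adapters.FileFormat.ARRAY_RECORD)

# Decoders are passed through untouched; SkipDecoding keeps the feature's raw
# encoded bytes instead of decoding them.
source = builder.as_data_source(
    split='train',
    decoders={'image': decode.SkipDecoding()},
)
print(type(source).__name__)  # ArrayRecordDataSource, per info.file_format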
65 changes: 32 additions & 33 deletions tensorflow_datasets/testing/mocking.py
@@ -83,25 +83,13 @@ class PickableDataSourceMock(mock.MagicMock):
"""Makes MagicMock pickable in order to work with multiprocessing in Grain."""

def __getstate__(self):
return {
'num_examples': len(self),
'generator': self._generator,
'serialize_example': self._serialize_example,
}
return {'num_examples': len(self), 'generator': self._generator}

def __setstate__(self, state):
num_examples, generator, serialize_example = (
state['num_examples'],
state['generator'],
state['serialize_example'],
)
num_examples, generator = state['num_examples'], state['generator']
self.__len__.return_value = num_examples
self.__getitem__ = functools.partial(
_getitem, generator=generator, serialize_example=serialize_example
)
self.__getitems__ = functools.partial(
_getitems, generator=generator, serialize_example=serialize_example
)
self.__getitem__ = functools.partial(_getitem, generator=generator)
self.__getitems__ = functools.partial(_getitems, generator=generator)

def __reduce__(self):
return (PickableDataSourceMock, (), self.__getstate__())
@@ -111,33 +99,50 @@ def _getitem(
self,
record_key: int,
generator: RandomFakeGenerator,
serialize_example=None,
serialized: bool = False,
) -> Any:
"""Function to overwrite __getitem__ in data sources."""
del self
example = generator[record_key]
if serialize_example:
if serialized:
# Return serialized raw bytes
return serialize_example(example)
return self.dataset_info.features.serialize_example(example)
return example


def _getitems(
self,
record_keys: Sequence[int],
generator: RandomFakeGenerator,
serialize_example=None,
serialized: bool = False,
) -> Sequence[Any]:
"""Function to overwrite __getitems__ in data sources."""
items = [
_getitem(self, record_key, generator, serialize_example=serialize_example)
_getitem(self, record_key, generator, serialized=serialized)
for record_key in record_keys
]
if serialize_example:
if serialized:
return np.array(items)
return items


def _deserialize_example_np(serialized_example, *, decoders=None):
"""Function to overwrite dataset_info.features.deserialize_example_np.
Warning: this has to be defined in the outer scope in order for the function
to be pickable.
Args:
serialized_example: the example to deserialize.
decoders: optional decoders.
Returns:
The serialized example, because deserialization is taken care by
RandomFakeGenerator.
"""
del decoders
return serialized_example


class MockPolicy(enum.Enum):
"""Strategy to use with `tfds.testing.mock_data` to mock the dataset.
Expand Down Expand Up @@ -380,27 +385,21 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
# Force ARRAY_RECORD as the default file_format.
return_value=file_adapters.FileFormat.ARRAY_RECORD,
):
# Make mock_data_source pickable with a given len:
self.info.features.deserialize_example_np = _deserialize_example_np
mock_data_source.return_value.__len__.return_value = num_examples
# Make mock_data_source pickable with a given generator:
mock_data_source.return_value._generator = ( # pylint:disable=protected-access
generator
)
# Make mock_data_source pickable with a given serialize_example:
mock_data_source.return_value._serialize_example = ( # pylint:disable=protected-access
self.info.features.serialize_example
)
serialize_example = self.info.features.serialize_example
mock_data_source.return_value.__getitem__ = functools.partial(
_getitem, generator=generator, serialize_example=serialize_example
_getitem, generator=generator
)
mock_data_source.return_value.__getitems__ = functools.partial(
_getitems, generator=generator, serialize_example=serialize_example
_getitems, generator=generator
)

def build_single_data_source(split):
single_data_source = array_record.ArrayRecordDataSource(
dataset_builder=self, split=split, decoders=decoders
dataset_info=self.info, split=split, decoders=decoders
)
return single_data_source

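The mocking changes keep the fake reader picklable with less captured state: serialization no longer travels with the mock, and deserialize_example_np is swapped for the module-level _deserialize_example_np pass-through so the whole object can be pickled. A sketch in the spirit of test_mocked_data_source_is_pickable from mocking_test.py below:

import pickle

import numpy as np
import tensorflow_datasets as tfds

with tfds.testing.mock_data(num_examples=2):
  data_source = tfds.data_source('imagenet2012', split='train')
  restored = pickle.loads(pickle.dumps(data_source))
  assert len(restored) == 2
  # Examples are still generated (not read from disk) after the round trip.
  assert isinstance(restored[0]['image'], np.ndarray)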
9 changes: 0 additions & 9 deletions tensorflow_datasets/testing/mocking_test.py
@@ -392,12 +392,3 @@ def test_as_data_source_fn():
assert imagenet[0] == 'foo'
assert imagenet[1] == 'bar'
assert imagenet[2] == 'baz'


# PyGrain requires that data sources are picklable.
def test_mocked_data_source_is_pickable():
with tfds.testing.mock_data(num_examples=2):
data_source = tfds.data_source('imagenet2012', split='train')
pickled_and_unpickled_data_source = pickle.loads(pickle.dumps(data_source))
assert len(pickled_and_unpickled_data_source) == 2
assert isinstance(pickled_and_unpickled_data_source[0]['image'], np.ndarray)
