diff --git a/kauldron/data/py/__init__.py b/kauldron/data/py/__init__.py index c9e117cb..35b93260 100644 --- a/kauldron/data/py/__init__.py +++ b/kauldron/data/py/__init__.py @@ -19,6 +19,7 @@ from kauldron.data.py.base import DataSourceBase from kauldron.data.py.data_sources import DataSource from kauldron.data.py.data_sources import Tfds +from kauldron.data.py.data_sources import Json from kauldron.data.py.mixtures import Mix # ***************************************************************************** diff --git a/kauldron/data/py/data_sources.py b/kauldron/data/py/data_sources.py index 6bdf8681..9653304d 100644 --- a/kauldron/data/py/data_sources.py +++ b/kauldron/data/py/data_sources.py @@ -19,6 +19,7 @@ from collections.abc import Mapping import dataclasses import functools +import json from typing import Any, Optional from etils import epath @@ -55,3 +56,38 @@ def data_source(self) -> grain.RandomAccessDataSource: data_dir=self.data_dir, decoders=self.decoders, ) + + +# Should this be part of Grain ? +@dataclasses.dataclass(frozen=True) +class JsonDataSource(grain.RandomAccessDataSource): + """Json data source. + + Assumes that the json file is a list of examples. + """ + + path: str + + @functools.cached_property + def data(self) -> Mapping[str, Any]: + return json.loads(epath.Path(self.path).read_text()) + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, record_key): + return self.data[record_key] + + +@dataclasses.dataclass(frozen=True) +class Json(base.DataSourceBase): + """Json pipeline. + + Assumes that the json file is a list of examples. + """ + + path: str + + @functools.cached_property + def data_source(self) -> grain.RandomAccessDataSource: + return JsonDataSource(path=self.path)