diff --git a/tensorflow_datasets/datasets/pg19/pg19_dataset_builder.py b/tensorflow_datasets/datasets/pg19/pg19_dataset_builder.py
index ef94c337c85..91912b75c98 100644
--- a/tensorflow_datasets/datasets/pg19/pg19_dataset_builder.py
+++ b/tensorflow_datasets/datasets/pg19/pg19_dataset_builder.py
@@ -17,6 +17,7 @@

 from __future__ import annotations

+from collections.abc import Mapping
 import os

 import numpy as np
@@ -44,13 +45,27 @@ def _info(self):
         homepage='https://github.com/deepmind/pg19',
     )

+  def _get_paths(self, data_dir: str) -> Mapping[str, str]:
+    return {
+        'metadata': tf.io.gfile.join(data_dir, 'metadata.csv'),
+        'train': tf.io.gfile.join(data_dir, 'train'),
+        'validation': tf.io.gfile.join(data_dir, 'validation'),
+        'test': tf.io.gfile.join(data_dir, 'test'),
+    }
+
   def _split_generators(self, dl_manager):
     """Returns SplitGenerators."""
     del dl_manager  # Unused
     metadata_dict = dict()
-    metadata_path = os.path.join(_DATA_DIR, 'metadata.csv')
-    metadata = tf.io.gfile.GFile(metadata_path).read().splitlines()
+    if self.data_dir and all(
+        map(tf.io.gfile.exists, self._get_paths(self.data_dir).values())
+    ):
+      data_dir = self._data_dir
+    else:
+      data_dir = _DATA_DIR
+    paths = self._get_paths(data_dir)
+    metadata = tf.io.gfile.GFile(paths['metadata']).read().splitlines()

     for row in metadata:
       row_split = row.split(',')
@@ -62,21 +77,21 @@ def _split_generators(self, dl_manager):
             name=tfds.Split.TRAIN,
             gen_kwargs={
                 'metadata': metadata_dict,
-                'filepath': os.path.join(_DATA_DIR, 'train'),
+                'filepath': paths['train'],
             },
         ),
         tfds.core.SplitGenerator(
             name=tfds.Split.VALIDATION,
             gen_kwargs={
                 'metadata': metadata_dict,
-                'filepath': os.path.join(_DATA_DIR, 'validation'),
+                'filepath': paths['validation'],
             },
         ),
         tfds.core.SplitGenerator(
             name=tfds.Split.TEST,
             gen_kwargs={
                 'metadata': metadata_dict,
-                'filepath': os.path.join(_DATA_DIR, 'test'),
+                'filepath': paths['test'],
             },
         ),
     ]
diff --git a/tensorflow_datasets/robotics/dataset_importer_builder.py b/tensorflow_datasets/robotics/dataset_importer_builder.py
index 162a8b69500..c93eae37e2c 100644
--- a/tensorflow_datasets/robotics/dataset_importer_builder.py
+++ b/tensorflow_datasets/robotics/dataset_importer_builder.py
@@ -72,6 +72,8 @@ def get_relative_dataset_location(self):
     pass

   def get_dataset_location(self):
+    if self._data_dir and tf.io.gfile.exists(self._data_dir):
+      return self._data_dir
     return os.path.join(
         str(self._GCS_BUCKET), self.get_relative_dataset_location()
     )
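
Illustrative usage sketch (not part of the patch): the pg19 hunks make _split_generators read from the builder's own data_dir when that directory already contains metadata.csv plus the train/, validation/ and test/ folders, falling back to the default _DATA_DIR bucket otherwise; the robotics hunk similarly makes get_dataset_location prefer an existing local _data_dir over the GCS bucket path. A minimal sketch of how a caller might exercise the pg19 path follows; the local root /tmp/pg19_local is a hypothetical example, and the directory actually probed is builder.data_dir (the versioned directory under that root), per the self._get_paths(self.data_dir) check above.

    import tensorflow_datasets as tfds

    # Hypothetical local root; nothing in the patch creates or requires it.
    builder = tfds.builder('pg19', data_dir='/tmp/pg19_local')

    # The patched _split_generators probes builder.data_dir for metadata.csv
    # and the train/, validation/ and test/ folders; if any is missing it
    # falls back to the default _DATA_DIR GCS bucket.
    print(builder.data_dir)
    builder.download_and_prepare()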