Skip to content

Commit

Permalink
If there is only one config, load it by default
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 684813556
  • Loading branch information
Alfonso Castaño authored and The TensorFlow Datasets Authors committed Oct 14, 2024
1 parent bc48d05 commit e839a2a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
9 changes: 0 additions & 9 deletions tensorflow_datasets/core/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1982,15 +1982,6 @@ def _save_default_config_name(
tmp_config_path.write_text(json.dumps(data))


def load_default_config_name(builder_dir: epath.Path) -> str | None:
"""Load `builder_cls` metadata (common to all builder configs)."""
config_path = builder_dir / ".config" / constants.METADATA_FILENAME
if not config_path.exists():
return None
data = json.loads(config_path.read_text())
return data.get("default_config_name")


def canonical_version_for_config(
instance_or_cls: Union[DatasetBuilder, Type[DatasetBuilder]],
config: Optional[BuilderConfig] = None,
Expand Down
29 changes: 24 additions & 5 deletions tensorflow_datasets/core/read_only_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from collections.abc import Sequence
import concurrent.futures
import functools
import json
import os
import typing
from typing import Any, Type
Expand All @@ -38,6 +39,8 @@
from tensorflow_datasets.core import registered
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core import utils
from tensorflow_datasets.core import constants

from tensorflow_datasets.core.features import feature as feature_lib
from tensorflow_datasets.core.proto import dataset_info_pb2
from tensorflow_datasets.core.utils import error_utils
Expand Down Expand Up @@ -436,10 +439,10 @@ def _list_possible_configs(
configs = []
for data_dir in all_data_dirs:
builder_dir = epath.Path(data_dir) / builder_name
if builder_dir.exists():
for path in builder_dir.iterdir():
if path.is_dir():
configs.append(path.name)
variants = file_utils.list_dataset_variants(
dataset_dir=builder_dir, include_versions=False
)
configs.extend(variant.config for variant in variants)
return configs


Expand Down Expand Up @@ -537,7 +540,7 @@ def _get_default_config_name(
return cls.default_builder_config.name

# Otherwise, try to load default config from common metadata
return dataset_builder.load_default_config_name(builder_dir)
return load_default_config_name(builder_dir)


def _get_version(
Expand Down Expand Up @@ -577,3 +580,19 @@ def _get_version(
)
error_utils.add_context(error_msg)
return None


def load_default_config_name(dataset_dir: epath.Path) -> str | None:
"""Load `builder_cls` metadata (common to all builder configs)."""
config_path = dataset_dir / '.config' / constants.METADATA_FILENAME
if config_path.exists():
data = json.loads(config_path.read_text())
return data.get('default_config_name')
variants = list(
file_utils.list_dataset_variants(
dataset_dir=dataset_dir, include_versions=False
)
)
if len(variants) == 1:
return variants[0].config
return None

0 comments on commit e839a2a

Please sign in to comment.