Skip to content

Commit

Permalink
Support SubSplitInfo in MultiSplitInfo.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715420599
  • Loading branch information
The TensorFlow Datasets Authors committed Jan 15, 2025
1 parent 1322866 commit bc7ec10
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tensorflow_datasets/core/splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,11 @@ class MultiSplitInfo(SplitInfo):
This should only be used to read data and not when producing data.
"""

split_infos: list[SplitInfo] = dataclasses.field(default_factory=list)
split_infos: list[SplitInfo | SubSplitInfo] = dataclasses.field(
default_factory=list
)

def __init__(self, name: str, split_infos: list[SplitInfo]):
def __init__(self, name: str, split_infos: list[SplitInfo | SubSplitInfo]):
if not split_infos:
raise ValueError('Need to pass a non-empty list of SplitInfos')
object.__setattr__(self, 'split_infos', split_infos)
Expand Down

0 comments on commit bc7ec10

Please sign in to comment.