From 2f625d3c952eb3e933a9834b8beef3a9bf9aafc0 Mon Sep 17 00:00:00 2001
From: Ahmed Ahmed
Date: Wed, 9 Oct 2024 14:02:17 -0700
Subject: [PATCH] address david's comments

---
 src/levanter/data/text.py          | 43 +++---------------------
 src/levanter/utils/fsspec_utils.py | 17 +++++++++++-
 2 files changed, 19 insertions(+), 41 deletions(-)

diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 861a017b0..fdd935d82 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -10,7 +10,7 @@
 from itertools import chain
 from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union
 
-import braceexpand
+
 import datasets
 import equinox as eqx
 import fsspec
@@ -38,6 +38,7 @@
 from levanter.store.jagged_array import JaggedArrayStore
 from levanter.store.tree_store import TreeStore
 from levanter.utils.hf_utils import num_cpus_used_by_tokenizer
+from levanter.utils.fsspec_utils import fsspec_expand_glob
 
 silence_transformer_nag()  # noqa
 
@@ -378,20 +379,6 @@ def num_gpus(self) -> int:
         return 0
 
 
-def fsspec_expand_glob(url):
-    expanded_urls = braceexpand.braceexpand(url)
-    for expanded_url in expanded_urls:
-        if "*" in expanded_url:
-            fs = fsspec.core.url_to_fs(expanded_url)[0]
-            globbed = fs.glob(expanded_url)
-            # have to append the fs prefix back on
-            protocol, _ = fsspec.core.split_protocol(expanded_url)
-            if protocol is None:
-                yield from globbed
-            else:
-                yield from [f"{protocol}://{path}" for path in globbed]
-        else:
-            yield expanded_url
 
 
 def concatenate_and_group_texts(
@@ -578,7 +565,7 @@ def tagged_eval_sets(
 
 
 @dataclass
-class LMSupervisedDatasetConfig(LMDatasetSourceConfig):
+class LMSupervisedDatasetConfig:
     """This class represents a dataset source with URLs or hf name/id."""
 
     cache_dir: str = "cache/"
@@ -589,30 +576,6 @@ class LMSupervisedDatasetConfig:
 
     validation_urls: List[str] = ()  # type:ignore
 
-    # def token_seq_dataset(
-    #     self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
-    # ) -> Optional[TokenSeqDataset]:
-    #     cache = self.build_or_load_cache(split, monitors=monitors)
-    #     if cache is None:
-    #         return None
-    #     return TokenSeqDataset(cache, seq_len)
-
-    # def validation_set(
-    #     self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
-    # ) -> Optional[TokenSeqDataset]:
-    #     return self.token_seq_dataset("validation", seq_len, monitors)
-
-    # def validation_sets(
-    #     self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
-    # ) -> Mapping[str, AsyncDataset[np.ndarray]]:
-    #     validation_set = self.validation_set(seq_len, monitors)
-    #     if validation_set is not None:
-    #         return {"": validation_set}
-    #     else:
-    #         return {}
-
-    # Add tagged eval set with split for auxiliary and validation dataset
-
 
 def preprocess_supervised_example(batch, tokenizer: PreTrainedTokenizerBase):
     sources = [example["input"] for example in batch]
diff --git a/src/levanter/utils/fsspec_utils.py b/src/levanter/utils/fsspec_utils.py
index 896ea8450..6a1341bff 100644
--- a/src/levanter/utils/fsspec_utils.py
+++ b/src/levanter/utils/fsspec_utils.py
@@ -1,5 +1,5 @@
 import fsspec
-
+import braceexpand
 
 def exists(url, **kwargs) -> bool:
     """Check if a file exists on a remote filesystem."""
@@ -11,3 +11,18 @@ def mkdirs(path):
     """Create a directory and any necessary parent directories."""
     fs, path = fsspec.core.url_to_fs(path)
     fs.makedirs(path, exist_ok=True)
+
+def fsspec_expand_glob(url):
+    expanded_urls = braceexpand.braceexpand(url)
+    for expanded_url in expanded_urls:
+        if "*" in expanded_url:
+            fs = fsspec.core.url_to_fs(expanded_url)[0]
+            globbed = fs.glob(expanded_url)
+            # have to append the fs prefix back on
+            protocol, _ = fsspec.core.split_protocol(expanded_url)
+            if protocol is None:
+                yield from globbed
+            else:
+                yield from [f"{protocol}://{path}" for path in globbed]
+        else:
+            yield expanded_url
\ No newline at end of file
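
Note (not part of the patch): a minimal usage sketch for the relocated
fsspec_expand_glob helper, assuming a made-up bucket layout. braceexpand
expands each brace alternative first; any expanded URL containing "*" is
then globbed against whichever filesystem fsspec resolves for it, with the
protocol prefix re-attached to the globbed results:

    from levanter.utils.fsspec_utils import fsspec_expand_glob

    # Hypothetical shard layout: the brace expands to train-*.jsonl and
    # val-*.jsonl, and each pattern is globbed on the gs:// filesystem
    # (listing a real bucket would need gcsfs installed and credentials).
    for url in fsspec_expand_glob("gs://my-bucket/data/{train,val}-*.jsonl"):
        print(url)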