Commit 2f625d3

address david's comments

ahmeda14960 committed Oct 9, 2024
1 parent 1063fd8
Showing 2 changed files with 19 additions and 41 deletions.
43 changes: 3 additions & 40 deletions src/levanter/data/text.py
@@ -10,7 +10,7 @@
 from itertools import chain
 from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union

-import braceexpand
+
 import datasets
 import equinox as eqx
 import fsspec
@@ -38,6 +38,7 @@
 from levanter.store.jagged_array import JaggedArrayStore
 from levanter.store.tree_store import TreeStore
 from levanter.utils.hf_utils import num_cpus_used_by_tokenizer
+from levanter.utils.fsspec_utils import fsspec_expand_glob


 silence_transformer_nag()  # noqa
@@ -378,20 +379,6 @@ def num_gpus(self) -> int:
         return 0


-def fsspec_expand_glob(url):
-    expanded_urls = braceexpand.braceexpand(url)
-    for expanded_url in expanded_urls:
-        if "*" in expanded_url:
-            fs = fsspec.core.url_to_fs(expanded_url)[0]
-            globbed = fs.glob(expanded_url)
-            # have to append the fs prefix back on
-            protocol, _ = fsspec.core.split_protocol(expanded_url)
-            if protocol is None:
-                yield from globbed
-            else:
-                yield from [f"{protocol}://{path}" for path in globbed]
-        else:
-            yield expanded_url


 def concatenate_and_group_texts(
@@ -578,7 +565,7 @@ def tagged_eval_sets(


 @dataclass
-class LMSupervisedDatasetConfig(LMDatasetSourceConfig):
+class LMSupervisedDatasetConfig:
     """This class represents a dataset source with URLs or hf name/id."""

     cache_dir: str = "cache/"
@@ -589,30 +576,6 @@ class LMSupervisedDatasetConfig(LMDatasetSourceConfig):

     validation_urls: List[str] = ()  # type:ignore

-    # def token_seq_dataset(
-    #     self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
-    # ) -> Optional[TokenSeqDataset]:
-    #     cache = self.build_or_load_cache(split, monitors=monitors)
-    #     if cache is None:
-    #         return None
-    #     return TokenSeqDataset(cache, seq_len)
-
-    # def validation_set(
-    #     self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
-    # ) -> Optional[TokenSeqDataset]:
-    #     return self.token_seq_dataset("validation", seq_len, monitors)
-
-    # def validation_sets(
-    #     self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
-    # ) -> Mapping[str, AsyncDataset[np.ndarray]]:
-    #     validation_set = self.validation_set(seq_len, monitors)
-    #     if validation_set is not None:
-    #         return {"": validation_set}
-    #     else:
-    #         return {}
-
-    # Add tagged eval set with split for auxiliary and validation dataset

 def preprocess_supervised_example(batch, tokenizer: PreTrainedTokenizerBase):
     sources = [example["input"] for example in batch]
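Aside (a minimal sketch, not part of the diff; values are hypothetical): with the LMDatasetSourceConfig base dropped, LMSupervisedDatasetConfig is a standalone dataclass constructed directly from its own fields, e.g.

    from levanter.data.text import LMSupervisedDatasetConfig

    # hypothetical values, for illustration only
    supervised = LMSupervisedDatasetConfig(
        cache_dir="cache/supervised",
        validation_urls=["gs://my-bucket/val/*.jsonl"],  # presumably expanded via fsspec_expand_glob downstream
    )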
17 changes: 16 additions & 1 deletion src/levanter/utils/fsspec_utils.py
@@ -1,5 +1,5 @@
 import fsspec
-
+import braceexpand

 def exists(url, **kwargs) -> bool:
     """Check if a file exists on a remote filesystem."""
@@ -11,3 +11,18 @@ def mkdirs(path):
     """Create a directory and any necessary parent directories."""
     fs, path = fsspec.core.url_to_fs(path)
     fs.makedirs(path, exist_ok=True)
+
+def fsspec_expand_glob(url):
+    expanded_urls = braceexpand.braceexpand(url)
+    for expanded_url in expanded_urls:
+        if "*" in expanded_url:
+            fs = fsspec.core.url_to_fs(expanded_url)[0]
+            globbed = fs.glob(expanded_url)
+            # have to append the fs prefix back on
+            protocol, _ = fsspec.core.split_protocol(expanded_url)
+            if protocol is None:
+                yield from globbed
+            else:
+                yield from [f"{protocol}://{path}" for path in globbed]
+        else:
+            yield expanded_url

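Usage sketch for the relocated helper (bucket and file names are hypothetical): brace patterns are expanded first; any expanded URL containing "*" is then globbed on its resolved filesystem, and because fs.glob returns protocol-less paths, the original protocol prefix is re-attached:

    from levanter.utils.fsspec_utils import fsspec_expand_glob

    # suppose gs://my-bucket/data/train/ and .../val/ each hold .jsonl shards
    urls = list(fsspec_expand_glob("gs://my-bucket/data/{train,val}/*.jsonl"))
    # e.g. ["gs://my-bucket/data/train/00.jsonl", "gs://my-bucket/data/val/00.jsonl", ...]

    # with no "*", each brace-expanded URL is yielded unchanged, without touching the filesystem
    assert list(fsspec_expand_glob("data/train-{0,1}.jsonl")) == [
        "data/train-0.jsonl",
        "data/train-1.jsonl",
    ]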