From 5aa7e237172a60f776f6434b61c4c29e032ad519 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Thu, 7 Nov 2024 09:33:40 -0800
Subject: [PATCH] fix epochs in type signature, fix type checker

---
 src/levanter/data/text.py | 39 +++++++++++++++++++++++++++------------
 1 file changed, 27 insertions(+), 12 deletions(-)

diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index e0bf93466..0654d1dfa 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -35,6 +35,7 @@
 from levanter.store.cache import CacheOptions, TreeCache
 from levanter.store.jagged_array import JaggedArrayStore
 from levanter.store.tree_store import TreeStore
+from levanter.utils import fsspec_utils
 from levanter.utils.fsspec_utils import expand_glob
 from levanter.utils.hf_utils import num_cpus_used_by_tokenizer
 
@@ -616,7 +617,12 @@ def the_tokenizer(self) -> PreTrainedTokenizerBase:
 
     @abc.abstractmethod
     def train_set(
-        self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray]
+        self,
+        seq_len: int,
+        monitors: Union[bool, List[MetricsMonitor]] = True,
+        *,
+        key: Optional[PRNGKeyArray],
+        epochs: Optional[int] = None,
     ) -> AsyncDataset[np.ndarray]:
         pass
 
@@ -717,7 +723,7 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain
         dataset = levanter.data.datasource_from_hf(config.hf_dataset_name, split=config.hf_dataset_split)
     else:
         # Using local files
-        validation_urls = [url for url_pat in config.validation_urls for url in fsspec_expand_glob(url_pat)]
+        validation_urls = [url for url_pat in config.validation_urls for url in fsspec_utils.expand_glob(url_pat)]
         if not validation_urls:
             raise ValueError("Must specify either hf_dataset_name or validation_urls")
         dataset = levanter.data.datasource_from_jsonl(validation_urls)
@@ -735,12 +741,12 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain
         output_exemplar=output_exemplar,
     )
 
-    dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True)
+    cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(config.cache_dir, await_finished=True)
 
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-    return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
+    return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
 
 
 @dataclass
@@ -811,14 +817,14 @@ def mk_chat_sft_dataset(config: ChatSFTDatasetConfig, tokenizer: PreTrainedToken
     )
 
     # Cache the processed data
-    dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True)
+    cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(config.cache_dir, await_finished=True)
 
     # Ensure padding token is set (needed by _prepare_supervised_example)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
     # Reuse the supervised prepare function directly
-    return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
+    return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
 
 
 @dataclass
@@ -833,18 +839,19 @@ def train_set(
         monitors: Union[bool, List[MetricsMonitor]] = True,
         *,
         key: Optional[PRNGKeyArray] = None,
-        epochs: int = 0,
+        epochs: Optional[int] = None,
     ) -> AsyncDataset[np.ndarray]:
-        ds = self.token_seq_dataset("train", seq_len, monitors)
-        if epochs:
-            logger.info("Wrapping dataset in epoch dataset")
-            ds = EpochDataset(ds, max_epochs=epochs)
+        ds: AsyncDataset[np.ndarray] | None = self.token_seq_dataset("train", seq_len, monitors)  # add epoch flag here.
         if ds is None:
             raise ValueError("No training set!")
 
+        if epochs:
+            logger.info("Wrapping dataset in epoch dataset")
+            ds = EpochDataset(ds, max_epochs=epochs)
+
         if self.shuffle is True:
             ds = ds.shuffle(key)
         elif isinstance(self.shuffle, int) and self.shuffle > 0:
@@ -989,11 +996,19 @@ def __post_init__(self):
         )
 
     def train_set(
-        self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray]
+        self,
+        seq_len: int,
+        monitors: Union[bool, List[MetricsMonitor]] = True,
+        *,
+        key: Optional[PRNGKeyArray],
+        epochs: Optional[int] = None,
     ) -> AsyncDataset[np.ndarray]:
         doc_caches = self.build_caches("train", monitors=monitors)
         token_datasets = {name: TokenSeqDataset(cache, seq_len) for name, cache in doc_caches.items()}
 
+        if epochs:
+            raise ValueError("Epochs are not supported for mixture datasets")
+
         if key is None:
             key = jax.random.PRNGKey(0)
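
Note for reviewers: below is a minimal sketch of how the new keyword is meant to be exercised. The config class name and cache path are illustrative assumptions, not part of this diff; only the `epochs: Optional[int] = None` keyword, the `EpochDataset` wrapping, and the mixture-dataset `ValueError` come from the patch itself.

# Sketch only: assumes a single-corpus dataset config (hypothetical name
# LMDatasetConfig) whose token cache has already been built at cache_dir.
import jax.random

config = LMDatasetConfig(cache_dir="gs://my-bucket/cache")  # hypothetical
key = jax.random.PRNGKey(0)

# epochs=None (the default) keeps the previous single-pass behavior; a positive
# value wraps the dataset in EpochDataset(ds, max_epochs=epochs) after the
# None-check, which is what the reordering in this patch fixes.
train_ds = config.train_set(seq_len=2048, key=key, epochs=2)

# Mixture-style configs accept the same keyword but reject a non-None value:
# mixture_config.train_set(seq_len=2048, key=key, epochs=2)  # raises ValueError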