diff --git a/examples/sft/sft.py b/examples/sft/sft.py index 2ced8591c..629b556c2 100644 --- a/examples/sft/sft.py +++ b/examples/sft/sft.py @@ -97,6 +97,8 @@ def train(config: SFTConfig): # Create supervised dataset using generic machinery logger.info("Creating supervised dataset") if config.dataset_type == DatasetType.CHAT_JSONL: + assert config.chat_train_urls is not None + assert config.supervised_data is not None chat_config = ChatSFTDatasetConfig( cache_dir=config.supervised_data.cache_dir, train_urls=config.chat_train_urls, # No validation in this config @@ -106,6 +108,7 @@ def train(config: SFTConfig): ) train_dataset = mk_chat_sft_dataset(chat_config, tokenizer) else: + assert config.supervised_data is not None train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer) logger.info("Supervised dataset created") train_dataset = PermutationDataset(train_dataset, data_key) @@ -122,7 +125,7 @@ def train(config: SFTConfig): # 1. Sets the device mesh # 2. Sets the axis mapping (for fsdp) # 3. Sets the global metrics tracker - with Trainer(config.trainer, optimizer, loss_fn=compute_next_token_loss) as trainer: + with Trainer(config.trainer, optimizer, loss_fn=compute_next_token_loss) as trainer: # type: ignore parameter_axis_mapping = trainer.parameter_axis_mapping # We have two axis_mappings: one for storing the model and optimizer states, and one for compute @@ -141,7 +144,7 @@ def train(config: SFTConfig): logger.info(f"Loading pretrained model from {converter.reference_checkpoint}") model: LmHeadModel = converter.load_pretrained( model_config.model_type, axis_mapping=parameter_axis_mapping, dtype=trainer.mp.param_dtype - ) + ) # type: ignore model = hax.named_jit(lambda m: m.resize_vocab(len(tokenizer)))(model) state = trainer.initial_state(training_key, model=model) else: @@ -163,10 +166,14 @@ def train(config: SFTConfig): next(loader) if config.hf_save_path is not None: - full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + # bit gross to reach this far into the config, but it's fine + if config.trainer.checkpointer.append_run_id_to_base_path: + full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + else: + full_save_path = config.hf_save_path trainer.add_hook( - save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload), + save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload or False), every=config.hf_save_steps, ) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index e0bf93466..0654d1dfa 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -35,6 +35,7 @@ from levanter.store.cache import CacheOptions, TreeCache from levanter.store.jagged_array import JaggedArrayStore from levanter.store.tree_store import TreeStore +from levanter.utils import fsspec_utils from levanter.utils.fsspec_utils import expand_glob from levanter.utils.hf_utils import num_cpus_used_by_tokenizer @@ -616,7 +617,12 @@ def the_tokenizer(self) -> PreTrainedTokenizerBase: @abc.abstractmethod def train_set( - self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] + self, + seq_len: int, + monitors: Union[bool, List[MetricsMonitor]] = True, + *, + key: Optional[PRNGKeyArray], + epochs: Optional[int] = None, ) -> AsyncDataset[np.ndarray]: pass @@ -717,7 +723,7 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain dataset = levanter.data.datasource_from_hf(config.hf_dataset_name, split=config.hf_dataset_split) else: # Using local files - validation_urls = [url for url_pat in config.validation_urls for url in fsspec_expand_glob(url_pat)] + validation_urls = [url for url_pat in config.validation_urls for url in fsspec_utils.expand_glob(url_pat)] if not validation_urls: raise ValueError("Must specify either hf_dataset_name or validation_urls") dataset = levanter.data.datasource_from_jsonl(validation_urls) @@ -735,12 +741,12 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain output_exemplar=output_exemplar, ) - dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) + cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(config.cache_dir, await_finished=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) + return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) @dataclass @@ -811,14 +817,14 @@ def mk_chat_sft_dataset(config: ChatSFTDatasetConfig, tokenizer: PreTrainedToken ) # Cache the processed data - dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) + cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(config.cache_dir, await_finished=True) # Ensure padding token is set (needed by _prepare_supervised_example) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Reuse the supervised prepare function directly - return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) + return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) @dataclass @@ -833,18 +839,19 @@ def train_set( monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None, - epochs: int = 0, + epochs: Optional[int] = None, ) -> AsyncDataset[np.ndarray]: - ds = self.token_seq_dataset("train", seq_len, monitors) - if epochs: - logger.info("Wrapping dataset in epoch dataset") - ds = EpochDataset(ds, max_epochs=epochs) + ds: AsyncDataset[np.ndarray] | None = self.token_seq_dataset("train", seq_len, monitors) # add epoch flag here. if ds is None: raise ValueError("No training set!") + if epochs: + logger.info("Wrapping dataset in epoch dataset") + ds = EpochDataset(ds, max_epochs=epochs) + if self.shuffle is True: ds = ds.shuffle(key) elif isinstance(self.shuffle, int) and self.shuffle > 0: @@ -989,11 +996,19 @@ def __post_init__(self): ) def train_set( - self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] + self, + seq_len: int, + monitors: Union[bool, List[MetricsMonitor]] = True, + *, + key: Optional[PRNGKeyArray], + epochs: Optional[int] = None, ) -> AsyncDataset[np.ndarray]: doc_caches = self.build_caches("train", monitors=monitors) token_datasets = {name: TokenSeqDataset(cache, seq_len) for name, cache in doc_caches.items()} + if epochs: + raise ValueError("Epochs are not supported for mixture datasets") + if key is None: key = jax.random.PRNGKey(0)