diff --git a/docs/Configuration-Guide.md b/docs/Configuration-Guide.md
index bdb09e4f1..d67997ea2 100644
--- a/docs/Configuration-Guide.md
+++ b/docs/Configuration-Guide.md
@@ -111,7 +111,6 @@ The following table lists some of the parameters that you might want to change.
 | Parameter      | Description                                                                    | Default |
 |----------------|--------------------------------------------------------------------------------|---------|
 | `log_dir`      | Where to save logs (python logger). `$run_id` will be appended                 | `logs/` |
-| `run_base_dir` | where to save run artifacts. not really used much. `$run_id` will be appended  | `runs/` |
diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py
index b102198d7..5bfb6be30 100644
--- a/src/levanter/checkpoint.py
+++ b/src/levanter/checkpoint.py
@@ -549,8 +549,12 @@ class CheckpointerConfig:
         default_factory=lambda: [dict(every=10000)]
     )  # list of dicts with two keys: every and until
 
+    append_run_id_to_base_path: bool = True
+
     def expanded_path(self, run_id) -> str:
-        return os.path.expanduser(os.path.join(self.base_path, run_id))
+        if self.append_run_id_to_base_path:
+            return os.path.expanduser(os.path.join(self.base_path, run_id))
+        return os.path.expanduser(self.base_path)
 
     def create(self, run_id) -> Checkpointer:
         keeps = [CheckpointInterval(**k) for k in self.keep]
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 5e595b2a1..bcfcad397 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -577,6 +577,8 @@ def tagged_eval_sets(
 class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
     """This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls"""
 
+    cache_dir: Optional[str] = "cache/"
+
     def train_set(
         self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None
     ) -> AsyncDataset[np.ndarray]:
@@ -705,6 +707,8 @@ def _convert_id_to_token(self, index: int) -> str:
 class LMMixtureDatasetConfig(LMTaskConfig):
     """This class represents a mixture of datasets with their associated weights."""
 
+    cache_dir: Optional[str] = "cache/"
+
     # data source configs and weights
     configs: Dict[str, LMDatasetSourceConfig] = field(default_factory=dict)
     """ configuration of each dataset source (urls, hf dataset id, etc.) """
diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py
index 6c96f8b62..c8316090a 100644
--- a/src/levanter/main/train_lm.py
+++ b/src/levanter/main/train_lm.py
@@ -188,7 +188,11 @@ def main(config: TrainLmConfig):
             callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1
         )
         if config.hf_save_path is not None:
-            full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
+            # bit gross to reach this far into the config, but it's fine
+            if config.trainer.checkpointer.append_run_id_to_base_path:
+                full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
+            else:
+                full_save_path = config.hf_save_path
 
             trainer.add_hook(
                 save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload or False),
diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py
index 69c932cd9..8e98eaedb 100644
--- a/src/levanter/trainer.py
+++ b/src/levanter/trainer.py
@@ -519,7 +519,6 @@ class TrainerConfig:
     wandb: Optional[tracker.wandb.WandbConfig] = None
     log_dir: Path = Path("logs/")
-    run_base_dir: Path = Path("runs/")
     id: Optional[str] = None  # run id. if None, will be set to a random string
 
     tracker: TrackerConfig | Tuple[TrackerConfig, ...] = field(default_factory=tracker.wandb.WandbConfig)
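For reference, a minimal sketch of how the new `append_run_id_to_base_path` flag changes checkpoint path resolution, assuming the patched `CheckpointerConfig` above. The bucket path and run id below are made up for illustration, not taken from this PR.

```python
# Illustrative only: base_path and run id values here are hypothetical.
from levanter.checkpoint import CheckpointerConfig

# Default behavior (unchanged): the run id is appended to base_path.
cfg = CheckpointerConfig(base_path="gs://my-bucket/ckpts")
print(cfg.expanded_path("run-123"))  # -> gs://my-bucket/ckpts/run-123

# New behavior: pin checkpoints to base_path itself, with no run id suffix.
cfg_pinned = CheckpointerConfig(base_path="gs://my-bucket/ckpts", append_run_id_to_base_path=False)
print(cfg_pinned.expanded_path("run-123"))  # -> gs://my-bucket/ckpts
```

As the `train_lm.py` hunk shows, the same flag also controls whether the run id is appended to `hf_save_path` when saving HF-format checkpoints.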