diff --git a/config/llama_7b_tulu.yaml b/config/llama_7b_tulu.yaml
index 1c059a509..2cd9bf5a2 100644
--- a/config/llama_7b_tulu.yaml
+++ b/config/llama_7b_tulu.yaml
@@ -27,7 +27,7 @@ trainer:
   train_batch_size: 256
   num_train_steps: 750000  # 3,000,000,000,000 / 4,000,000 = 750,000
   steps_per_eval: 1000
-  tensor_parallel_axes: ["mlp", "heads"] 
+  tensor_parallel_axes: ["mlp", "heads"]
   fsdp_axis: "embed"
   batch_axis: "batch"
 optimizer:
@@ -36,4 +36,4 @@ optimizer:
   min_lr_ratio: 0.1
   warmup: 5000

-epoch: False
\ No newline at end of file
+epoch: False
diff --git a/config/llama_7b_with_olmo_config.yaml b/config/llama_7b_with_olmo_config.yaml
index e41f7dbc2..0b5bc4067 100644
--- a/config/llama_7b_with_olmo_config.yaml
+++ b/config/llama_7b_with_olmo_config.yaml
@@ -32,4 +32,4 @@ optimizer:
   min_lr_ratio: 0.1
   warmup: 0.01

- data_shuffle: true
+data_shuffle: true
diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py
index 97ef8d7ef..e8f805cde 100644
--- a/examples/alpaca/alpaca.py
+++ b/examples/alpaca/alpaca.py
@@ -162,7 +162,7 @@ def _prepare_example(ex: dict) -> LmExample:
         # mask out padding and anything before the start of the target
         Pos = input_ids.resolve_axis("position")
         if config.mask_inputs:
-            loss_mask = hax.arange(Pos) >= ex["source_lens"] - 1 # should be minus 1?
+            loss_mask = hax.arange(Pos) >= ex["source_lens"] - 1  # should be minus 1?

             # don't predict the padding
             targets = hax.roll(input_ids, -1, Pos)
diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py
index 9a5b290dc..a96f904ad 100644
--- a/src/levanter/callbacks.py
+++ b/src/levanter/callbacks.py
@@ -8,6 +8,7 @@
 import threading
 import time
 import warnings
+from concurrent.futures import ThreadPoolExecutor
 from datetime import timedelta
 from typing import Callable, Optional

@@ -18,6 +19,7 @@

 import levanter.tracker
 from levanter.data import DataLoader
+from levanter.data.text import TokenSeqEpochDataset
 from levanter.logging import save_xla_dumps_to_wandb
 from levanter.tracker.helpers import log_optimizer_hyperparams
 from levanter.tracker.wandb import WandbConfig
@@ -25,9 +27,6 @@
 from levanter.utils import flop_utils
 from levanter.utils.jax_utils import barrier_sync, jnp_to_python
 from levanter.visualization import compute_and_visualize_log_probs as viz_probs
-from levanter.data.text import TokenSeqEpochDataset
-from concurrent.futures import ThreadPoolExecutor
-


 logger = pylogging.getLogger(__name__)
@@ -53,6 +52,7 @@ def log_epoch(step_info: StepInfo):

     return log_epoch

+
 def get_total_dataset_tokens(ds: TokenSeqEpochDataset, seq_length: int):
     def log_length():
         # If ds.async_len() is the only option, run it in an event loop inside the thread
@@ -74,6 +74,7 @@ async def compute_length():
         future = executor.submit(log_length)
         return future

+
 def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None):
     total_loss = 0.0
     total_load_time = 0.0
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 9605ff74c..9f9a24a1b 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -63,6 +63,7 @@

 DEFAULT_IGNORE_INDEX = -100  # Mirrors pytorch's default ignore index

+
 class TokenSeqEpochDataset(AsyncDataset[np.ndarray]):
     def __init__(self, doc_cache: TreeCache[dict], seq_len: int):
         self.doc_cache = doc_cache
@@ -115,6 +116,7 @@ async def wait_until_len_at_least(self, length: int) -> int:
             self._cached_len = length
         return length

+
 class TokenSeqDataset(AsyncDataset[np.ndarray]):
     """
     A dataset that yields sequences of tokens of fixed length from an underlying TreeCache.
@@ -691,7 +693,12 @@ class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
     cache_dir: Optional[str] = "cache/"

     def train_set(
-        self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None, epochs: bool = False
+        self,
+        seq_len: int,
+        monitors: Union[bool, List[MetricsMonitor]] = True,
+        *,
+        key: Optional[PRNGKeyArray] = None,
+        epochs: bool = False,
     ) -> AsyncDataset[np.ndarray]:
         if epochs:
diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py
index 9134591f2..96323dc03 100644
--- a/src/levanter/main/train_lm.py
+++ b/src/levanter/main/train_lm.py
@@ -121,13 +121,18 @@ def main(config: TrainLmConfig):

     # TokenSeqDataset is config.data.train_set(Pos.size, key=data_key)
     train_dataset = CausalLmDataset(
-        config.data.train_set(Pos.size, key=data_key, epochs=config.epoch), Pos, KeyPos, ignore_index=config.data.ignore_token_id
+        config.data.train_set(Pos.size, key=data_key, epochs=config.epoch),
+        Pos,
+        KeyPos,
+        ignore_index=config.data.ignore_token_id,
     )

     if config.epoch:
         # add epoch logging
         total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len)
-        trainer.add_hook(callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1)
+        trainer.add_hook(
+            callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1
+        )

     # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to
     # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of
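Taken together, these changes add an opt-in epoch mode: LMDatasetConfig.train_set accepts an epochs flag, and train_lm wires up callbacks.log_epoch_progress when config.epoch is set. The excerpt below is only an illustrative sketch of how the touched keys fit together; the key names come from the diff above, but the file layout and values are placeholders, not part of the change.

# hypothetical excerpt in the style of config/llama_7b_tulu.yaml
trainer:
  train_batch_size: 256
  num_train_steps: 750000  # 3,000,000,000,000 / 4,000,000 = 750,000
optimizer:
  min_lr_ratio: 0.1
  warmup: 5000

epoch: True  # top-level flag read as config.epoch in train_lm.py; enables epoch-progress logging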