Commit 667a5a3: fix ci

ahmeda14960 committed Oct 16, 2024
1 parent c2ed3ee
Showing 6 changed files with 23 additions and 10 deletions.
4 changes: 2 additions & 2 deletions config/llama_7b_tulu.yaml
@@ -27,7 +27,7 @@ trainer:
train_batch_size: 256
num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000
steps_per_eval: 1000
-tensor_parallel_axes: ["mlp", "heads"]
+tensor_parallel_axes: ["mlp", "heads"]
fsdp_axis: "embed"
batch_axis: "batch"
optimizer:
@@ -36,4 +36,4 @@ optimizer:
min_lr_ratio: 0.1
warmup: 5000

-epoch: False
+epoch: False
2 changes: 1 addition & 1 deletion config/llama_7b_with_olmo_config.yaml
@@ -32,4 +32,4 @@ optimizer:
min_lr_ratio: 0.1
warmup: 0.01

-data_shuffle: true
+data_shuffle: true
2 changes: 1 addition & 1 deletion examples/alpaca/alpaca.py
@@ -162,7 +162,7 @@ def _prepare_example(ex: dict) -> LmExample:
# mask out padding and anything before the start of the target
Pos = input_ids.resolve_axis("position")
if config.mask_inputs:
-loss_mask = hax.arange(Pos) >= ex["source_lens"] - 1  # should be minus 1?
+loss_mask = hax.arange(Pos) >= ex["source_lens"] - 1  # should be minus 1?

# don't predict the padding
targets = hax.roll(input_ids, -1, Pos)
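
The changed line builds a loss mask so that only target tokens contribute to the loss. Because next-token prediction shifts labels by one position (the later `hax.roll(input_ids, -1, Pos)`), the last prompt position already predicts the first target token, which is what the `- 1` accounts for. The following is a minimal NumPy sketch of that masking logic, not the repository's code; `seq_len` and `source_len` are illustrative names.

    import numpy as np

    def make_loss_mask(seq_len: int, source_len: int) -> np.ndarray:
        # Position i is scored on token i + 1 after the shift, so the final
        # prompt position (source_len - 1) stays in the loss; everything
        # before it is masked out.
        return np.arange(seq_len) >= source_len - 1

    print(make_loss_mask(seq_len=8, source_len=3))
    # [False False  True  True  True  True  True  True]
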
7 changes: 4 additions & 3 deletions src/levanter/callbacks.py
@@ -8,6 +8,7 @@
import threading
import time
import warnings
+from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import Callable, Optional

@@ -18,16 +19,14 @@

import levanter.tracker
from levanter.data import DataLoader
+from levanter.data.text import TokenSeqEpochDataset
from levanter.logging import save_xla_dumps_to_wandb
from levanter.tracker.helpers import log_optimizer_hyperparams
from levanter.tracker.wandb import WandbConfig
from levanter.trainer import StepInfo
from levanter.utils import flop_utils
from levanter.utils.jax_utils import barrier_sync, jnp_to_python
from levanter.visualization import compute_and_visualize_log_probs as viz_probs
-from levanter.data.text import TokenSeqEpochDataset
-from concurrent.futures import ThreadPoolExecutor
-


logger = pylogging.getLogger(__name__)
@@ -53,6 +52,7 @@ def log_epoch(step_info: StepInfo):

return log_epoch

+
def get_total_dataset_tokens(ds: TokenSeqEpochDataset, seq_length: int):
def log_length():
# If ds.async_len() is the only option, run it in an event loop inside the thread
@@ -74,6 +74,7 @@ async def compute_length():
future = executor.submit(log_length)
return future

+
def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None):
total_loss = 0.0
total_load_time = 0.0
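
The new `get_total_dataset_tokens` helper needs the dataset's length, but that length is only exposed through an async API, so it runs the coroutine inside a worker thread and returns a future that the trainer can resolve later. Below is a minimal standard-library sketch of that pattern; `async_len` is a hypothetical stand-in for the dataset's async length query, and this is not the levanter implementation itself.

    import asyncio
    from concurrent.futures import Future, ThreadPoolExecutor

    async def async_len() -> int:
        # Hypothetical stand-in for an async dataset-length query.
        await asyncio.sleep(0.1)
        return 42

    def submit_total_tokens(executor: ThreadPoolExecutor, seq_length: int) -> Future:
        def compute() -> int:
            # Run the coroutine in its own event loop inside the worker thread,
            # so the synchronous caller never blocks on the async API.
            num_sequences = asyncio.run(async_len())
            return num_sequences * seq_length

        return executor.submit(compute)

    with ThreadPoolExecutor(max_workers=1) as executor:
        future = submit_total_tokens(executor, seq_length=2048)
        print(future.result())  # 86016
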
9 changes: 8 additions & 1 deletion src/levanter/data/text.py
@@ -63,6 +63,7 @@

DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index

+
class TokenSeqEpochDataset(AsyncDataset[np.ndarray]):
def __init__(self, doc_cache: TreeCache[dict], seq_len: int):
self.doc_cache = doc_cache
@@ -115,6 +116,7 @@ async def wait_until_len_at_least(self, length: int) -> int:
self._cached_len = length
return length

+
class TokenSeqDataset(AsyncDataset[np.ndarray]):
"""
A dataset that yields sequences of tokens of fixed length from an underlying TreeCache.
@@ -691,7 +693,12 @@ class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
cache_dir: Optional[str] = "cache/"

def train_set(
-self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None, epochs: bool = False
+self,
+seq_len: int,
+monitors: Union[bool, List[MetricsMonitor]] = True,
+*,
+key: Optional[PRNGKeyArray] = None,
+epochs: bool = False,
) -> AsyncDataset[np.ndarray]:

if epochs:
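
`TokenSeqEpochDataset` lets training keep drawing sequences past the end of the tokenized corpus when epoch-based training is requested. A common way to do that is to index modulo the underlying length; the toy sketch below illustrates that idea over a plain list and is an assumption about the behavior, not the class's actual implementation.

    from typing import Sequence, TypeVar

    T = TypeVar("T")

    class EpochDataset:
        """Wraps a finite dataset so that indexing wraps around, i.e. repeats each epoch."""

        def __init__(self, data: Sequence[T]):
            self.data = data

        def get(self, index: int) -> T:
            # Index modulo the underlying length: index len(data) starts epoch 2.
            return self.data[index % len(self.data)]

        def current_epoch(self, index: int) -> int:
            return index // len(self.data)

    ds = EpochDataset([10, 11, 12])
    print([ds.get(i) for i in range(7)])  # [10, 11, 12, 10, 11, 12, 10]
    print(ds.current_epoch(6))            # 2
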
9 changes: 7 additions & 2 deletions src/levanter/main/train_lm.py
@@ -121,13 +121,18 @@ def main(config: TrainLmConfig):
# TokenSeqDataset is config.data.train_set(Pos.size, key=data_key)

train_dataset = CausalLmDataset(
-config.data.train_set(Pos.size, key=data_key, epochs=config.epoch), Pos, KeyPos, ignore_index=config.data.ignore_token_id
+config.data.train_set(Pos.size, key=data_key, epochs=config.epoch),
+Pos,
+KeyPos,
+ignore_index=config.data.ignore_token_id,
)

if config.epoch:
# add epoch logging
total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len)
-trainer.add_hook(callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1)
+trainer.add_hook(
+callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1
+)

# to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to
# For most things, we just insist you specify the config right, but tokenizers often have strange numbers of
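
The epoch-progress hook only needs token arithmetic: after N steps the run has consumed roughly N * train_batch_size * Pos.size tokens, and dividing by the dataset's total token count gives a fractional epoch. The sketch below shows that arithmetic; the function name and the exact quantity the real `log_epoch_progress` callback reports are assumptions.

    def epoch_progress(step: int, train_batch_size: int, seq_len: int, total_dataset_tokens: int) -> float:
        """Fractional number of epochs completed after `step` optimizer steps (sketch)."""
        tokens_seen = step * train_batch_size * seq_len
        return tokens_seen / total_dataset_tokens

    # With batch size 256 and seq_len 4096, each step consumes ~1M tokens,
    # so 1000 steps over a 3B-token corpus is roughly a third of an epoch.
    print(epoch_progress(1000, 256, 4096, 3_000_000_000))  # ~0.35
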

0 comments on commit 667a5a3
