From 3ea327761c445e42df32c697a801e75f3144bde1 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 20 Nov 2024 14:16:25 -0800 Subject: [PATCH 01/10] fix for token bug that skips EOS --- src/levanter/data/text.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 0c9c6e9de..c0a86d830 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -196,7 +196,7 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: len = await self.wait_until_len_at_least(max(indices) + 1) if len is not None and len < max(indices) + 1: raise ValueError("Requested indices beyond the end of the dataset") - offsets = np.array(indices, dtype=np.int64) * self.seq_len + offsets = np.array(indices) * self.seq_len with ts.Batch(): out = [] for offset in offsets: @@ -929,15 +929,24 @@ def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: ) -def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase) -> dict: +def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase, should_append_eos: bool) -> dict: """ Preprocess chat examples to match the format of preprocess_supervised_example. Returns a dict with input_ids and sources_len like the supervised case. + + Args: + batch: List of dicts with input/output pairs + tokenizer: HuggingFace tokenizer + should_append_eos: Whether we need to manually add EOS (True if tokenizer doesn't do it automatically) """ # Get sources (inputs) and targets (outputs) from the batch sources = [example["input"] for example in batch] targets = [example["output"] for example in batch] + # Add EOS only if needed (tokenizer doesn't do it automatically) + if should_append_eos: + targets = [t + tokenizer.eos_token for t in targets] + # Tokenize sources alone first to get the source lengths sources_tokenized = tokenizer(sources, padding=False, truncation=True) @@ -945,7 +954,7 @@ def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase) -> dict: full_examples = [f"{s}{t}" for s, t in zip(sources, targets)] examples_tokenized = tokenizer(full_examples, padding=False, truncation=True) - # Get source lengths to mask loss appropriately + # Get source lengths to mask loss appropriately source_lens = [len(s) for s in sources_tokenized["input_ids"]] return { @@ -965,9 +974,13 @@ def mk_chat_sft_dataset( # Set up example structure matching supervised case output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} + input_ids = tokenizer("hi there")["input_ids"] + should_append_eos = input_ids[-1] != tokenizer.eos_token_id + logger.info(f"Manual EOS Needed: {should_append_eos}") + # Process the dataset dataset = source.map_batches( - lambda ex: preprocess_chat_example(ex, tokenizer), + lambda ex: preprocess_chat_example(ex, tokenizer, should_append_eos), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar, From 41b43c78090b69ec7058bec055c21967733e8093 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 20 Nov 2024 15:03:49 -0800 Subject: [PATCH 02/10] add back np.int64 --- src/levanter/data/text.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index c0a86d830..9843877e2 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -196,7 +196,7 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: len = await 
self.wait_until_len_at_least(max(indices) + 1) if len is not None and len < max(indices) + 1: raise ValueError("Requested indices beyond the end of the dataset") - offsets = np.array(indices) * self.seq_len + offsets = np.array(indices, dtype=np.int64) * self.seq_len with ts.Batch(): out = [] for offset in offsets: From 4eb4281be8e546e540818cd477c189afc9025fc5 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 20 Nov 2024 15:14:15 -0800 Subject: [PATCH 03/10] forgot to run precommit --- src/levanter/data/text.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 9843877e2..3e74c96b7 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -933,7 +933,7 @@ def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase, should_ap """ Preprocess chat examples to match the format of preprocess_supervised_example. Returns a dict with input_ids and sources_len like the supervised case. - + Args: batch: List of dicts with input/output pairs tokenizer: HuggingFace tokenizer @@ -954,7 +954,7 @@ def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase, should_ap full_examples = [f"{s}{t}" for s, t in zip(sources, targets)] examples_tokenized = tokenizer(full_examples, padding=False, truncation=True) - # Get source lengths to mask loss appropriately + # Get source lengths to mask loss appropriately source_lens = [len(s) for s in sources_tokenized["input_ids"]] return { From f44e5c8e53b6eed5ca476407fc920a6d2c1b614e Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 20 Nov 2024 16:50:16 -0800 Subject: [PATCH 04/10] jit and batch supervised data loading to speed it up (a lot) (#816) --- src/levanter/data/dataset.py | 76 ++++++++++++++++++++++++++++++++++-- src/levanter/data/text.py | 45 ++++++++++++--------- tests/test_supervised.py | 4 +- 3 files changed, 102 insertions(+), 23 deletions(-) diff --git a/src/levanter/data/dataset.py b/src/levanter/data/dataset.py index 4d71241d4..f448ed83b 100644 --- a/src/levanter/data/dataset.py +++ b/src/levanter/data/dataset.py @@ -2,7 +2,7 @@ import asyncio import logging from concurrent.futures import ThreadPoolExecutor -from typing import Callable, Generic, Optional, Sequence, TypeVar +from typing import Callable, Generic, Optional, Sequence, TypeAlias, TypeVar import jax.random import numpy as np @@ -18,6 +18,11 @@ T = TypeVar("T") U = TypeVar("U") +# When we decide to standardize on 3.12, we can use fancier things +# P = ParamSpec("P") + +MapFunction: TypeAlias = Callable[..., U] + _executor = ThreadPoolExecutor(max_workers=10) @@ -111,9 +116,12 @@ def as_sync_dataset(self): def as_async_dataset(self) -> "AsyncDataset[T_co]": return self - def map(self, fn: Callable[[T_co], U], *extra_args, **extra_kwargs) -> "MappedAsyncDataset[T_co, U]": + def map(self, fn: MapFunction[U], *extra_args, **extra_kwargs) -> "MappedAsyncDataset[T_co, U]": return MappedAsyncDataset(self, fn, *extra_args, **extra_kwargs) + def map_batches(self, fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs) -> "BatchMappedAsyncDataset[U]": + return BatchMappedAsyncDataset(self, fn, *extra_args, **extra_kwargs) + def shuffle(self, key: PRNGKey): import levanter.data.permutation as permutation @@ -321,7 +329,7 @@ class MappedAsyncDataset(AsyncDataset[U], Generic[T, U]): def __init__( self, dataset: AsyncDataset[T], - fn: Callable[[T], U] | Callable[[T, Optional[PRNGKey]], U], + fn: MapFunction[U], *extra_args, **extra_kwargs, ): @@ -365,3 +373,65 @@ def _call_fn(self, 
index, item): else: kwargs = self._extra_kwargs return self.fn(item, *self._extra_args, **kwargs) + + +class BatchMappedAsyncDataset(AsyncDataset[U]): + """ + A dataset that applies a function to each batch of items in the dataset. + You can pass extra arguments to the function using `*extra_args` and `**extra_kwargs`. + If a kwarg called `key` is passed, it will be treated as a PRNGKey and folded in with the index of the item + for each call to the function. The key will be split into a key for each item in the batch. + """ + + def __init__( + self, + dataset: AsyncDataset[T], + fn: MapFunction[Sequence[U]], + *extra_args, + **extra_kwargs, + ): + super().__init__() + self.dataset = dataset + self.fn = fn + self._extra_args = extra_args + self._extra_kwargs = extra_kwargs + + async def async_len(self) -> int: + return await self.dataset.async_len() + + async def final_length_is_known(self) -> bool: + return await self.dataset.final_length_is_known() + + def is_finite(self) -> bool: + return self.dataset.is_finite() + + async def current_len(self) -> Optional[int]: + return await self.dataset.current_len() + + def _maybe_fold_in_key(self, key, indices: Sequence[int]): + if key is not None: + key = _fold_in_key_vmap(key, np.array(indices)) + return key + + async def get_batch(self, indices: Sequence[int]) -> Sequence[U]: + items = await self.dataset.get_batch(indices) + return self._call_fn(indices, items) + + async def getitem_async(self, index: int) -> U: + return self._call_fn([index], [await self.dataset.getitem_async(index)])[0] + + async def wait_until_len_at_least(self, length: int) -> int: + return await self.dataset.wait_until_len_at_least(length) + + def _call_fn(self, indices: Sequence[int], items): + if "key" in self._extra_kwargs: + key = self._maybe_fold_in_key(self._extra_kwargs["key"], indices) + kwargs = {**self._extra_kwargs, "key": key} + else: + kwargs = self._extra_kwargs + return self.fn(items, *self._extra_args, **kwargs) + + +@jax.jit +def _fold_in_key_vmap(key, indices): + return jax.vmap(lambda i: jax.random.fold_in(key, i))(indices) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 3e74c96b7..053372207 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -779,9 +779,9 @@ def _preprocess_supervised_example( } -def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis) -> LmExample: +def _prepare_supervised_examples(ex: list[dict], tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis) -> list[LmExample]: """ - Prepare an example for training. This function converts the (cached) batch encoding into an LmExample. + Prepare examples for training. This function converts the (cached) encodings into an LmExample. It goes through the following steps: @@ -789,26 +789,35 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Po 2. Mask out the input and prompt if requested. 3. Create an LmExample with the input_ids as the input and the next token as the target. """ - # annoyingly, pad expects things to be batched so we have to prepend a batch axis - ex = tokenizer.pad( - {k: np.expand_dims(v, 0) for k, v in ex.items()}, - return_tensors="np", + lens = np.array([ex["sources_len"] for ex in ex]) + + ex_pad = tokenizer.pad( + ex, padding="max_length", max_length=Pos.size, ) - ex = {k: v[0] for k, v in ex.items()} - # padding doesn't do truncation, so we have to do it ourselves. 
- # Truncate from the left since we want to predict the last tokens - input_ids = hax.named(ex["input_ids"][-Pos.size :], Pos) - # mask out padding and anything before the start of the target - loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1 + input_ids = ex_pad["input_ids"] + truncated = [ids[-Pos.size :] for ids in input_ids] + + out = [] + for ids, len in zip(truncated, lens): + causal = _mk_sup_example_jit(Pos, hax.named(ids, Pos), len, tokenizer.pad_token_id) + + out.append(causal) + + return out + + +@functools.partial(jax.jit, static_argnums=(0, 3)) +def _mk_sup_example_jit(Pos, input_ids: hax.NamedArray, sources_len, pad_token_id): + # mask out padding and anything before the start of the target + loss_mask = hax.arange(Pos) >= sources_len - 1 # don't predict the padding targets = hax.roll(input_ids, -1, Pos) - loss_mask = loss_mask & (targets != tokenizer.pad_token_id) + loss_mask = loss_mask & (targets != pad_token_id) loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_)) - lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask) - return lm_ex + return LmExample.causal(input_ids, loss_mask=loss_mask) def mk_supervised_datasets( @@ -884,7 +893,7 @@ def mk_supervised_dataset( if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos)) + return cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos)) def _cache_supervised_set(source, cache_dir, tokenizer, Pos, input_field, output_field): @@ -899,7 +908,7 @@ def _cache_supervised_set(source, cache_dir, tokenizer, Pos, input_field, output output_exemplar=output_exemplar, ) cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(cache_dir, await_finished=True) - ds = cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos)) + ds = cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos)) return ds @@ -994,7 +1003,7 @@ def mk_chat_sft_dataset( tokenizer.pad_token = tokenizer.eos_token # Reuse the supervised prepare function directly - return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos)) + return cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos)) @dataclass diff --git a/tests/test_supervised.py b/tests/test_supervised.py index 23f9e240c..40a3d927b 100644 --- a/tests/test_supervised.py +++ b/tests/test_supervised.py @@ -4,7 +4,7 @@ import haliax from haliax import Axis -from levanter.data.text import _prepare_supervised_example, _preprocess_supervised_example +from levanter.data.text import _prepare_supervised_examples, _preprocess_supervised_example def test_supervised_eval(): @@ -77,7 +77,7 @@ def test_supervised_eval(): "sources_len": np.array(45, dtype=np.int32), } - lm_ex = _prepare_supervised_example(ex, tokenizer, Axis("position", 128)) + lm_ex = _prepare_supervised_examples([ex], tokenizer, Axis("position", 128))[0] assert lm_ex.loss_mask["position", 44] assert haliax.sum(lm_ex.loss_mask) == 1 From dcd487212276dc8411948b99900d4c66bb1cff4d Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Thu, 21 Nov 2024 23:29:32 -0800 Subject: [PATCH 05/10] Optim config drop stable and add decay (#818) --- docs/Configuration-Guide.md | 2 +- src/levanter/optim/config.py | 12 ++++++++---- tests/test_optimizer_config.py | 35 +++++++++++++++------------------- 3 files changed, 24 insertions(+), 25 deletions(-) diff --git a/docs/Configuration-Guide.md 
b/docs/Configuration-Guide.md index f20488ee2..0b00c0800 100644 --- a/docs/Configuration-Guide.md +++ b/docs/Configuration-Guide.md @@ -302,7 +302,7 @@ which are common to all optimizers (and most have to do with learning rate sched | `lr_schedule` | The type of learning rate schedule for decay. See below. | `cosine` | | `min_lr_ratio` | The minimum learning rate ratio. | `0.1` | | `warmup` | Warmup fraction or number of steps | `0.01` | -| `stable` | Stable fraction or number of steps | `0.0` | +| `decay` | Decay fraction or number of steps | `None` | | `cycles` | The number of cycles for the learning rate, or steps where cycles end | `None` | | `rewarmup` | The learning rate re-warmup, if using cycles. | `0.0` | diff --git a/src/levanter/optim/config.py b/src/levanter/optim/config.py index d814a6b64..7b684efeb 100644 --- a/src/levanter/optim/config.py +++ b/src/levanter/optim/config.py @@ -26,8 +26,8 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC): """The lr scheduler operates on 4 stages: [warmup] - {[stable] - [decay]} x haps - [cooldown]""" warmup: float = 0.01 """fraction of training steps to use as warmup, or steps to use. 0.0 means no warmup""" - stable: float = 0.00 - """fraction of training steps to use as cooldown, or steps to use. 0.0 means no cooldown""" + decay: Optional[float] = None + """fraction of training steps to use as decay, or steps to use. None means full decay""" rewarmup: float = 0.0 "If using a cycle, how much of the cycle to use as re-warmup. 0.0 means no re-warmup." cooldown: Optional[float] = None @@ -174,8 +174,12 @@ def lr_scheduler(self, num_train_steps): schedules.append(warmup) boundaries.append(start + warmup_steps) - stable_steps = _convert_ratio_or_steps(self.stable, cycle_steps) - lr_decay_steps = cycle_steps - stable_steps - warmup_steps + lr_decay_steps = ( + _convert_ratio_or_steps(self.decay, cycle_steps) + if self.decay is not None + else cycle_steps - warmup_steps + ) + stable_steps = cycle_steps - warmup_steps - lr_decay_steps if stable_steps != 0: stable = optax.constant_schedule(self.learning_rate) diff --git a/tests/test_optimizer_config.py b/tests/test_optimizer_config.py index 9c5b91d7c..70737df7c 100644 --- a/tests/test_optimizer_config.py +++ b/tests/test_optimizer_config.py @@ -8,11 +8,10 @@ def test_no_stable_weirdness(): learning_rate=2e-6, # 2x10^-6 weight_decay=0.0, warmup=0.03, - stable=0.0, min_lr_ratio=0.0, lr_schedule="linear", max_grad_norm=None, - haps=None, + cycles=None, weight_decay_modules=None, default_weight_decay_mask=None, ) @@ -33,10 +32,8 @@ def test_constant_schedule(): learning_rate=1e-3, weight_decay=0.0, warmup=0.0, - stable=0.0, min_lr_ratio=1.0, # No decay lr_schedule="constant", - haps=None, cycles=None, ) @@ -52,10 +49,8 @@ def test_warmup_and_cosine_decay(): learning_rate=1e-2, weight_decay=0.0, warmup=0.1, # 10% of steps - stable=0.0, min_lr_ratio=0.1, lr_schedule="cosine", - haps=None, cycles=None, ) @@ -75,7 +70,6 @@ def test_linear_schedule_with_cycles(): learning_rate=5e-4, weight_decay=0.0, warmup=50, - stable=0.0, min_lr_ratio=0.2, lr_schedule="linear", cycles=2, @@ -105,30 +99,33 @@ def test_linear_schedule_with_cycles(): assert np.isclose(sched_fn(999), 0.2 * 5e-4, atol=1e-5) -def test_haps_schedule(): +def test_wsds_schedule(): optimizer = AdamConfig( learning_rate=1e-3, weight_decay=0.0, warmup=0.0, - stable=0.0, + decay=0.1, min_lr_ratio=0.1, lr_schedule="cosine", - haps=[300, 700], + cycles=[300, 700], ) sched_fn = optimizer.lr_scheduler(1000) - # Before first haps + # First cycle 
assert np.isclose(sched_fn(0), 1e-3) + assert np.isclose(sched_fn(269), 1e-3) + assert sched_fn(271) < 1e-3 - # First haps + # Second cycle assert np.isclose(sched_fn(300), 1e-3) + assert np.isclose(sched_fn(659), 1e-3) + assert sched_fn(661) < 1e-3 - # After first haps - assert sched_fn(301) < 1e-3 - - # Before second haps - assert sched_fn(699) < sched_fn(301) + # Third cycle + assert np.isclose(sched_fn(701), 1e-3) + assert np.isclose(sched_fn(969), 1e-3) + assert sched_fn(971) < 1e-3 def test_inv_sqrt_decay_schedule(): @@ -136,10 +133,9 @@ def test_inv_sqrt_decay_schedule(): learning_rate=1e-3, weight_decay=0.0, warmup=0.1, - stable=0.0, min_lr_ratio=0.1, lr_schedule="inv_sqrt", - haps=None, + cycles=None, ) sched_fn = optimizer.lr_scheduler(100_000) @@ -157,7 +153,6 @@ def test_rewarmup_schedule(): learning_rate=1e-2, weight_decay=0.0, warmup=0.2, # 20% of cycle - stable=0.0, min_lr_ratio=0.2, lr_schedule="linear", cycles=2, From 5a4d3a685b3b3fc52279162423d63c74a1df9275 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 24 Nov 2024 00:31:24 -0800 Subject: [PATCH 06/10] Bump fsspec (#824) --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 29e54b9a3..a2f89cdeb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,17 +25,17 @@ dependencies = [ "equinox>=0.11.7", "jaxtyping>=0.2.34", "tokenizers>=0.15.2", - "transformers>=4.41.2,<4.46.0", + "transformers>=4.41.2,<4.47.0", "optax>=0.1.9", "wandb>=0.17.8", "draccus>=0.9.3", "pyarrow>=11.0.0", "zstandard>=0.20.0", "datasets>=3.1.0,<4.0", - "gcsfs>=2024.2,<2024.10", + "gcsfs>=2024.2,<2025", "braceexpand>=0.1.7", "jmp>=0.0.3", - "fsspec[http]>=2024.2,<2024.10", + "fsspec[http]>=2024.2,<2025", "tensorstore>=0.1.65", "pytimeparse>=1.1.8", "humanfriendly==10.0", From 4873d3ca38173bc6d136394eb92c912599149d69 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 24 Nov 2024 00:31:48 -0800 Subject: [PATCH 07/10] rename maybe_fused_next_token_loss (#823) --- src/levanter/models/lm_model.py | 4 +- src/levanter/models/loss.py | 77 +++++++++++++++++++++++++++-------- tests/test_text.py | 10 +++-- 3 files changed, 68 insertions(+), 23 deletions(-) diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 63cb2d4e3..7e51ecadd 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -11,7 +11,7 @@ from haliax import Axis, NamedArray from levanter.models.attention import AttentionMask -from levanter.models.loss import next_token_loss +from levanter.models.loss import maybe_fused_next_token_loss LmConfigT = TypeVar("LmConfigT", bound="LmConfig") @@ -190,7 +190,7 @@ def compute_next_token_loss( """ activations = model.activations(example.tokens, example.attn_mask, key=key) - loss = next_token_loss( + loss = maybe_fused_next_token_loss( model.Pos, model.Embed, model.Vocab, diff --git a/src/levanter/models/loss.py b/src/levanter/models/loss.py index d705eda4d..bf0bd380e 100644 --- a/src/levanter/models/loss.py +++ b/src/levanter/models/loss.py @@ -10,7 +10,7 @@ from haliax.nn import cross_entropy_loss_and_log_normalizers -def next_token_loss( +def maybe_fused_next_token_loss( Pos: hax.AxisSelector, Embed: hax.AxisSelector, Vocab: hax.AxisSelector, @@ -46,6 +46,15 @@ def maybe_fused_next_token_loss( Pos = pred_embeddings.resolve_axis(Pos) Vocab = pred_lm_head.resolve_axis(Vocab) + if block_size is None: + # Full softmax computation + logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed) + if dtype is not None: + logits =
logits.astype(dtype) + + # Shift target tokens to predict the next token + return next_token_loss(Pos, Vocab, logits, true_ids, loss_mask, reduction, reduction_axis, logsumexp_weight) + # Shift target tokens to predict the next token target_y = hax.roll(true_ids, -1, Pos) @@ -56,22 +65,6 @@ def next_token_loss( else: loss_mask = not_last_loss_mask - if block_size is None: - # Full softmax computation - logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed) - if dtype is not None: - logits = logits.astype(dtype) - target_y_full = hax.nn.one_hot(target_y, Vocab, dtype=pred_embeddings.dtype) - return cross_entropy_and_logsumexp_penalty( - logits, - Vocab, - target_y_full, - reduction=reduction, - reduction_axis=reduction_axis, - where=loss_mask, - logsumexp_weight=logsumexp_weight, - ) - # Compute the loss with optional block-wise processing return fused_cross_entropy_loss_and_logsumexp_penalty( pred_embeddings, @@ -88,9 +81,57 @@ def next_token_loss( ) +def next_token_loss( + Pos: hax.AxisSelector, + Vocab: hax.AxisSelector, + logits: NamedArray, + true_ids: NamedArray, + loss_mask: Optional[NamedArray] = None, + reduction: Optional[hax.ReductionFunction] = hax.mean, + reduction_axis: Optional[hax.AxisSelection] = None, + logsumexp_weight: Optional[float] = None, +): + """ + Compute the next token loss with optional logsumexp penalty. + + Args: + Pos: axis selector for the position axis + Vocab: axis selector for the vocabulary axis + logits: predicted logits + true_ids: true token IDs (not shifted) + loss_mask: mask to apply to the loss + reduction: reduction function or None to disable reduction + reduction_axis: axis to apply reduction. None means all axes + logsumexp_weight: weight for the logsumexp penalty + Returns: + NamedArray: computed loss + """ + Pos = logits.resolve_axis(Pos) + + target_y = hax.roll(true_ids, -1, Pos) + target_y_full = hax.nn.one_hot(target_y, Vocab, dtype=logits.dtype) + + # Create a mask that excludes the last token + not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) # type: ignore + if loss_mask is not None: + loss_mask = loss_mask * not_last_loss_mask + else: + loss_mask = not_last_loss_mask + + return cross_entropy_and_logsumexp_penalty( + Vocab=Vocab, + pred_y=logits, + target_y=target_y_full, + reduction=reduction, + reduction_axis=reduction_axis, + where=loss_mask, + logsumexp_weight=logsumexp_weight, + ) + + def cross_entropy_and_logsumexp_penalty( - pred_y: NamedArray, Vocab: hax.Axis, + pred_y: NamedArray, target_y: NamedArray, *, reduction: Optional[hax.ReductionFunction] = hax.mean, diff --git a/tests/test_text.py b/tests/test_text.py index e4e51acbc..f293a9429 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -7,7 +7,7 @@ from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.models.lm_model import LmExample -from levanter.models.loss import next_token_loss +from levanter.models.loss import maybe_fused_next_token_loss from tests.test_utils import skip_if_hf_model_not_accessible @@ -39,8 +39,12 @@ def test_lm_example_handles_ignore_id(): lm_head = hax.zeros((Embed, Vocab)) lm_head = lm_head.at[Vocab, ignore_id].set(-100) - ignored_loss = next_token_loss(Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_ignore.loss_mask) - no_ignore_loss = next_token_loss(Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_no_ignore.loss_mask) + ignored_loss = maybe_fused_next_token_loss( + Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_ignore.loss_mask + ) + no_ignore_loss = 
maybe_fused_next_token_loss( + Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_no_ignore.loss_mask + ) assert no_ignore_loss.item() >= ignored_loss.item() + 100 / Pos.size From 290ab806a4c696703a78ceb2408749c06fc7d278 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 24 Nov 2024 00:32:07 -0800 Subject: [PATCH 08/10] =?UTF-8?q?move=20logging=20and=20types=20to=20util?= =?UTF-8?q?=20to=20make=20python's=20module=20resolution=20hap=E2=80=A6=20?= =?UTF-8?q?(#820)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …pierq --- src/levanter/__init__.py | 1 - src/levanter/callbacks.py | 2 +- src/levanter/checkpoint.py | 2 +- src/levanter/compat/hf_checkpoints.py | 2 +- src/levanter/data/audio.py | 2 +- src/levanter/data/text.py | 6 +++--- src/levanter/doremi.py | 2 +- src/levanter/eval.py | 2 +- src/levanter/lora.py | 2 +- src/levanter/main/cache_dataset.py | 2 +- src/levanter/models/backpack.py | 2 +- src/levanter/models/gemma.py | 4 ++-- src/levanter/models/gpt2.py | 2 +- src/levanter/models/llama.py | 4 ++-- src/levanter/models/mistral.py | 2 +- src/levanter/models/mpt.py | 2 +- src/levanter/models/qwen.py | 4 ++-- src/levanter/models/whisper.py | 2 +- src/levanter/trainer.py | 6 +++--- src/levanter/trainer_state.py | 2 +- src/levanter/utils/hf_utils.py | 2 +- src/levanter/{ => utils}/logging.py | 0 src/levanter/{ => utils}/types.py | 0 23 files changed, 27 insertions(+), 28 deletions(-) rename src/levanter/{ => utils}/logging.py (100%) rename src/levanter/{ => utils}/types.py (100%) diff --git a/src/levanter/__init__.py b/src/levanter/__init__.py index b969828bc..f9570aaf7 100644 --- a/src/levanter/__init__.py +++ b/src/levanter/__init__.py @@ -3,7 +3,6 @@ import levanter.data as data import levanter.distributed as distributed import levanter.eval as eval -import levanter.logging as logging import levanter.models as models import levanter.optim as optim import levanter.tracker as tracker diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 897109ffc..983750685 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -19,12 +19,12 @@ import levanter.tracker from levanter.data import AsyncDataset, DataLoader -from levanter.logging import save_xla_dumps_to_wandb from levanter.tracker.helpers import log_optimizer_hyperparams from levanter.tracker.wandb import WandbConfig from levanter.trainer import StepInfo from levanter.utils import flop_utils from levanter.utils.jax_utils import barrier_sync, jnp_to_python +from levanter.utils.logging import save_xla_dumps_to_wandb from levanter.visualization import compute_and_visualize_log_probs as viz_probs diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 38b039f20..ba684b8e5 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -26,7 +26,7 @@ from haliax.jax_utils import is_in_jit, is_jax_array_like from levanter.tensorstore_serialization import tree_deserialize_leaves_tensorstore, tree_serialize_leaves_tensorstore -from levanter.types import FilterSpec +from levanter.utils.types import FilterSpec logger = logging.getLogger(__name__) diff --git a/src/levanter/compat/hf_checkpoints.py b/src/levanter/compat/hf_checkpoints.py index 5822c3fba..f4ad33757 100644 --- a/src/levanter/compat/hf_checkpoints.py +++ b/src/levanter/compat/hf_checkpoints.py @@ -33,7 +33,6 @@ from haliax.partitioning import ResourceMapping from haliax.state_dict import from_torch_compatible_state_dict, save_state_dict, to_torch_compatible_state_dict -from 
levanter.logging import silence_transformer_nag from levanter.models.asr_model import ASRMixin from levanter.models.lm_model import LmConfig, LmHeadModel from levanter.trainer import StepInfo @@ -41,6 +40,7 @@ from levanter.utils.cloud_utils import temp_dir_before_upload from levanter.utils.hf_utils import HfTokenizer from levanter.utils.jax_utils import best_effort_sharding, local_cpu_mesh, use_cpu_device +from levanter.utils.logging import silence_transformer_nag from levanter.utils.py_utils import dataclass_with_default_init, logical_cpu_memory_size diff --git a/src/levanter/data/audio.py b/src/levanter/data/audio.py index 9bfc1e142..b2235e863 100644 --- a/src/levanter/data/audio.py +++ b/src/levanter/data/audio.py @@ -30,10 +30,10 @@ from levanter.data.text import BatchTokenizer # intercept the logging nonsense here -from levanter.logging import silence_transformer_nag from levanter.models.asr_model import AudioTextExample from levanter.store.cache import CacheOptions, TreeCache, build_or_load_cache from levanter.utils.jax_utils import key_iterator +from levanter.utils.logging import silence_transformer_nag silence_transformer_nag() # noqa diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 053372207..1532a7d06 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -30,9 +30,6 @@ from levanter.data import AsyncDataset from levanter.data.dataset import MappedAsyncDataset from levanter.data.mixture import MixtureDataset, StopStrategy - -# intercept the logging nonsense here -from levanter.logging import silence_transformer_nag # noqa from levanter.models.attention import AttentionMask from levanter.models.lm_model import LmExample from levanter.store.cache import CacheOptions, TreeCache @@ -41,6 +38,9 @@ from levanter.utils.fsspec_utils import expand_glob from levanter.utils.hf_utils import HfTokenizer, num_cpus_used_by_tokenizer +# intercept the logging nonsense here +from levanter.utils.logging import silence_transformer_nag # noqa + silence_transformer_nag() # noqa from transformers import BatchEncoding, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast # noqa diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 6d9165cfc..c066e55d5 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -18,8 +18,8 @@ from levanter.data.mixture import MixtureDataset from levanter.tracker import capture_time from levanter.trainer import M, StepInfo, Trainer, TrainerConfig, TrainerState -from levanter.types import ComputeLossFunction from levanter.utils.tree_utils import inference_mode +from levanter.utils.types import ComputeLossFunction logger = logging.getLogger(__name__) diff --git a/src/levanter/eval.py b/src/levanter/eval.py index 9fe9ab0d7..6f40888cd 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -17,10 +17,10 @@ import levanter.tracker from levanter.data import AsyncDataset, DataLoader -from levanter.logging import LoadingTimeTrackerIterator from levanter.models.lm_model import LmExample, LmHeadModel, compute_next_token_loss from levanter.trainer import StepInfo from levanter.utils.hf_utils import HfTokenizer, byte_length_of_token +from levanter.utils.logging import LoadingTimeTrackerIterator from levanter.utils.stat_utils import Arrayish, RunningMean from levanter.utils.tree_utils import inference_mode diff --git a/src/levanter/lora.py b/src/levanter/lora.py index 1e0f37d67..cdabee3a5 100644 --- a/src/levanter/lora.py +++ b/src/levanter/lora.py @@ -64,10 +64,10 @@ ) from 
levanter.compat.hf_checkpoints import HFCheckpointConverter, RepoRef, upload_to_hub -from levanter.logging import silence_transformer_nag from levanter.trainer import StepInfo from levanter.utils.cloud_utils import temp_dir_before_upload from levanter.utils.jax_utils import join_key, key_iterator, leaf_key_paths +from levanter.utils.logging import silence_transformer_nag silence_transformer_nag() diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py index 92471e997..73eb518b2 100644 --- a/src/levanter/main/cache_dataset.py +++ b/src/levanter/main/cache_dataset.py @@ -6,9 +6,9 @@ from levanter.data.metrics_monitor import LoggingMetricsMonitor, RichMetricsMonitor from levanter.data.text import BatchTokenizer, LMDatasetConfig from levanter.distributed import RayConfig -from levanter.logging import init_logging from levanter.store.cache import build_or_load_cache from levanter.tracker import NoopConfig, TrackerConfig +from levanter.utils.logging import init_logging logger = logging.getLogger(__name__) diff --git a/src/levanter/models/backpack.py b/src/levanter/models/backpack.py index 715706f8e..42157f947 100644 --- a/src/levanter/models/backpack.py +++ b/src/levanter/models/backpack.py @@ -15,10 +15,10 @@ from haliax.state_dict import ModuleWithStateDictSerialization, StateDict, with_prefix from levanter.compat.hf_checkpoints import HFCheckpointConverter, LmWithHfSerializationMixin -from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionMask, materialize_mask from levanter.models.gpt2 import ACT2FN, Gpt2Config, Gpt2Transformer from levanter.models.lm_model import LmConfig +from levanter.utils.logging import silence_transformer_nag silence_transformer_nag() diff --git a/src/levanter/models/gemma.py b/src/levanter/models/gemma.py index 23e2bf6dc..93c360792 100644 --- a/src/levanter/models/gemma.py +++ b/src/levanter/models/gemma.py @@ -14,7 +14,6 @@ from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig -from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionBackend, AttentionMask from levanter.models.llama import ( # Gemma attention and MLP is identical to LLama LlamaAttention, @@ -23,8 +22,9 @@ ) from levanter.models.lm_model import LmConfig, LmHeadModel from levanter.models.rotary import DefaultRotaryEmbeddingsConfig, RotaryEmbeddingsConfig -from levanter.types import BlockFoldable from levanter.utils.flop_utils import lm_flops_per_token +from levanter.utils.logging import silence_transformer_nag +from levanter.utils.types import BlockFoldable silence_transformer_nag() diff --git a/src/levanter/models/gpt2.py b/src/levanter/models/gpt2.py index 1d2fe5892..db2ed693c 100644 --- a/src/levanter/models/gpt2.py +++ b/src/levanter/models/gpt2.py @@ -17,10 +17,10 @@ from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, LmWithHfSerializationMixin -from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionBackend, AttentionMask, dot_product_attention from levanter.models.lm_model import LmConfig from levanter.utils.flop_utils import lm_flops_per_token +from levanter.utils.logging import silence_transformer_nag silence_transformer_nag() diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 6b04ec540..76a786fd9 100644 --- 
a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -15,13 +15,13 @@ from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig -from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionBackend, AttentionMask, dot_product_attention from levanter.models.gpt2 import ACT2FN from levanter.models.lm_model import LmConfig, LmHeadModel from levanter.models.rotary import DefaultRotaryEmbeddingsConfig, RotaryEmbeddingsConfig -from levanter.types import BlockFoldable from levanter.utils.flop_utils import lm_flops_per_token +from levanter.utils.logging import silence_transformer_nag +from levanter.utils.types import BlockFoldable silence_transformer_nag() diff --git a/src/levanter/models/mistral.py b/src/levanter/models/mistral.py index b9f19ef41..d7ac00b83 100644 --- a/src/levanter/models/mistral.py +++ b/src/levanter/models/mistral.py @@ -11,11 +11,11 @@ from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter -from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionBackend, AttentionMask from levanter.models.llama import LlamaConfig, LlamaEmbedding, LlamaTransformer from levanter.models.lm_model import LmConfig, LmHeadModel from levanter.utils.flop_utils import lm_flops_per_token +from levanter.utils.logging import silence_transformer_nag silence_transformer_nag() diff --git a/src/levanter/models/mpt.py b/src/levanter/models/mpt.py index e77e967d7..8a2d6a1c5 100644 --- a/src/levanter/models/mpt.py +++ b/src/levanter/models/mpt.py @@ -19,11 +19,11 @@ import levanter.models.attention from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, LmWithHfSerializationMixin -from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionMask from levanter.models.lm_model import LmConfig from levanter.utils.flop_utils import lm_flops_per_token from levanter.utils.jax_utils import use_cpu_device +from levanter.utils.logging import silence_transformer_nag silence_transformer_nag() diff --git a/src/levanter/models/qwen.py b/src/levanter/models/qwen.py index 807a768ad..7f8afa951 100644 --- a/src/levanter/models/qwen.py +++ b/src/levanter/models/qwen.py @@ -13,13 +13,13 @@ from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter -from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionMask, dot_product_attention from levanter.models.llama import LlamaConfig, LlamaEmbedding, LlamaMlp, LlamaRMSNorm, LlamaTransformer from levanter.models.lm_model import LmConfig, LmHeadModel from levanter.models.rotary import RotaryEmbeddingsConfig -from levanter.types import BlockFoldable from levanter.utils.flop_utils import lm_flops_per_token +from levanter.utils.logging import silence_transformer_nag +from levanter.utils.types import BlockFoldable silence_transformer_nag() diff --git a/src/levanter/models/whisper.py b/src/levanter/models/whisper.py index 7239626f7..a9c5d528b 100644 --- a/src/levanter/models/whisper.py +++ b/src/levanter/models/whisper.py @@ -17,10 +17,10 @@ from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, ModelWithHfSerializationMixin -from levanter.logging import silence_transformer_nag 
from levanter.models.asr_model import ASRConfig, ASRMixin from levanter.models.attention import AttentionBackend, AttentionMask, dot_product_attention from levanter.models.lm_model import LmConfig +from levanter.utils.logging import silence_transformer_nag silence_transformer_nag() diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index eee27cdeb..fb353592d 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -42,9 +42,9 @@ from haliax.types import Scalar import levanter.checkpoint -import levanter.logging import levanter.tracker import levanter.tracker.wandb +import levanter.utils.logging from levanter import tracker from levanter.checkpoint import CheckpointerConfig, load_checkpoint_or_initialize from levanter.config import JsonAtom @@ -53,10 +53,10 @@ from levanter.grad_accum import microbatched from levanter.tracker import TrackerConfig, capture_time from levanter.trainer_state import TrainerState, saveable_training_mask -from levanter.types import ComputeLossFunction, FilterSpec from levanter.utils import cloud_utils, fsspec_utils from levanter.utils.jax_utils import create_fsdp_mesh from levanter.utils.tree_utils import inference_mode +from levanter.utils.types import ComputeLossFunction, FilterSpec logger = pylogging.getLogger(__name__) @@ -626,7 +626,7 @@ def initialize(self): self._validate_and_set_defaults() id = self._maybe_set_id() - levanter.logging.init_logging(self.log_dir, f"{id}.log") + levanter.utils.logging.init_logging(self.log_dir, f"{id}.log") _initialize_global_tracker(self.tracker, id) self.ray.initialize() diff --git a/src/levanter/trainer_state.py b/src/levanter/trainer_state.py index 15800bd17..549267681 100644 --- a/src/levanter/trainer_state.py +++ b/src/levanter/trainer_state.py @@ -12,8 +12,8 @@ from haliax.quantization import Fp8Config, apply_updates, fp8_linear_layers, partition_for_grad_overwrite from haliax.types import IntScalar, Scalar -from levanter.types import FilterTree from levanter.utils.jax_utils import is_inexact_arrayish +from levanter.utils.types import FilterTree M = TypeVar("M", bound=PyTree) diff --git a/src/levanter/utils/hf_utils.py b/src/levanter/utils/hf_utils.py index 41e4488d4..08205edf6 100644 --- a/src/levanter/utils/hf_utils.py +++ b/src/levanter/utils/hf_utils.py @@ -4,7 +4,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast -from levanter.logging import silence_transformer_nag +from levanter.utils.logging import silence_transformer_nag from levanter.utils.py_utils import logical_cpu_core_count diff --git a/src/levanter/logging.py b/src/levanter/utils/logging.py similarity index 100% rename from src/levanter/logging.py rename to src/levanter/utils/logging.py diff --git a/src/levanter/types.py b/src/levanter/utils/types.py similarity index 100% rename from src/levanter/types.py rename to src/levanter/utils/types.py From a472cd4730d010278f0e0c432595a553eec19e74 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 24 Nov 2024 00:33:17 -0800 Subject: [PATCH 09/10] hijack HF's download so it works with gcs etc. 
(#819) --- src/levanter/compat/hf_checkpoints.py | 201 ++++++++++++++------------ 1 file changed, 107 insertions(+), 94 deletions(-) diff --git a/src/levanter/compat/hf_checkpoints.py b/src/levanter/compat/hf_checkpoints.py index f4ad33757..dc2f0e16d 100644 --- a/src/levanter/compat/hf_checkpoints.py +++ b/src/levanter/compat/hf_checkpoints.py @@ -1,4 +1,5 @@ import abc +import contextlib import dataclasses import json import logging @@ -10,7 +11,6 @@ from dataclasses import dataclass from functools import cached_property from typing import Generic, Optional, Tuple, Type, TypeVar, Union, cast -from urllib.parse import urlparse import draccus import equinox as eqx @@ -21,8 +21,9 @@ import mergedeep import safetensors import safetensors.numpy +import transformers.utils.hub from huggingface_hub import HfApi, hf_hub_download, repo_exists, snapshot_download -from huggingface_hub.utils import EntryNotFoundError, GatedRepoError, HFValidationError +from huggingface_hub.utils import EntryNotFoundError, GatedRepoError, HFValidationError, RepositoryNotFoundError from jax.experimental.multihost_utils import sync_global_devices from jax.random import PRNGKey from jaxtyping import Array @@ -324,11 +325,8 @@ def _infer_config_class(hf_config_class, ref, trust_remote_code): if ref is None: raise ValueError("Must provide either config class or reference_checkpoint") path, rev = ref.model_name_or_path, ref.revision - config = AutoConfig.from_pretrained( - path, - revision=rev, - trust_remote_code=trust_remote_code, - ) + with _patch_hf_hub_download(): + config = AutoConfig.from_pretrained(path, revision=rev, trust_remote_code=trust_remote_code) clss = type(config) elif isinstance(hf_config_class, str): if ref is None: @@ -423,7 +421,9 @@ def config_from_hf_checkpoint(self, ref: Optional[Union[str, RepoRef]] = None) - def hf_config_from_hf_checkpoint(self, ref: Optional[Union[str, RepoRef]] = None) -> HfConfig: path, rev = self._get_ref(ref) - config = AutoConfig.from_pretrained(path, revision=rev, trust_remote_code=self.trust_remote_code) + + with _patch_hf_hub_download(): + config = AutoConfig.from_pretrained(path, revision=rev, trust_remote_code=self.trust_remote_code) return config def _get_ref(self, ref) -> Tuple[str, Optional[str]]: @@ -450,49 +450,51 @@ def load_state_dict(self, ref: Optional[Union[str, RepoRef]] = None, dtype: Opti except HFValidationError: pass - # TODO: load models from gcs etc. - if os.path.exists(os.path.join(id, SAFE_TENSORS_MODEL)): - state_dict = _load_safe_tensors(os.path.join(id, SAFE_TENSORS_MODEL), dtype) - elif os.path.exists(os.path.join(id, PYTORCH_MODEL)): - state_dict = _load_torch(os.path.join(id, PYTORCH_MODEL), dtype) - else: - try: - model_path = hf_hub_download(id, SAFE_TENSORS_MODEL, revision=rev) - state_dict = _load_safe_tensors(model_path, dtype) - except (EntryNotFoundError, HFValidationError): - model_path = hf_hub_download(id, PYTORCH_MODEL, revision=rev) - state_dict = _load_torch(model_path, dtype) + with _patch_hf_hub_download() as hf_hub_download: + # TODO: load models from gcs etc. 
+ if os.path.exists(os.path.join(id, SAFE_TENSORS_MODEL)): + state_dict = _load_safe_tensors(os.path.join(id, SAFE_TENSORS_MODEL), dtype) + elif os.path.exists(os.path.join(id, PYTORCH_MODEL)): + state_dict = _load_torch(os.path.join(id, PYTORCH_MODEL), dtype) + else: + try: + model_path = hf_hub_download(id, SAFE_TENSORS_MODEL, revision=rev) + state_dict = _load_safe_tensors(model_path, dtype) + except (EntryNotFoundError, HFValidationError): + model_path = hf_hub_download(id, PYTORCH_MODEL, revision=rev) + state_dict = _load_torch(model_path, dtype) - return state_dict + return state_dict def _load_shards(self, id: str, index_file: str, rev: Optional[str], dtype) -> dict: """Load model from sharded files based on the provided index.""" - index_path = os.path.join(id, index_file) - if not os.path.exists(index_path): - # Download the index file if not found locally - index_path = hf_hub_download(id, index_file, revision=rev) - - with open(index_path, "r", encoding="utf-8") as f: - index = json.load(f) - - shard_files = list(set(index["weight_map"].values())) - final_state_dict = {} - - # right now we do safe tensors thing - # where we load into memory then update some dict - if "safetensors" in index_file: - loader = _load_safe_tensors - else: - loader = _load_torch + with _patch_hf_hub_download() as hf_hub_download: + index_path = os.path.join(id, index_file) + if not os.path.exists(index_path): + # Download the index file if not found locally + index_path = hf_hub_download(id, index_file, revision=rev) + + with open(index_path, "r", encoding="utf-8") as f: + index = json.load(f) + + shard_files = list(set(index["weight_map"].values())) + final_state_dict = {} + + # right now we do safe tensors thing + # where we load into memory then update some dict + if "safetensors" in index_file: + loader = _load_safe_tensors + else: + loader = _load_torch - for shard_file in shard_files: - shard_path = os.path.join(id, shard_file) - if not os.path.exists(shard_path): - # Download the shard if not found locally - shard_path = hf_hub_download(id, shard_file, revision=rev) + for shard_file in shard_files: + shard_path = os.path.join(id, shard_file) + if not os.path.exists(shard_path): + # Download the shard if not found locally + shard_path = hf_hub_download(id, shard_file, revision=rev) - shard_state_dict = loader(shard_path, dtype) - final_state_dict.update(shard_state_dict) + shard_state_dict = loader(shard_path, dtype) + final_state_dict.update(shard_state_dict) return final_state_dict @@ -588,22 +590,6 @@ def load_from_state_dict(template, state_dict): lev_model = eqx.filter_eval_shape(lm_model_cls.init, Vocab, config, key=PRNGKey(0)) lev_model = load_from_state_dict(lev_model, state_dict) - # all_arrays: list[jax.Array] = get_backend().live_arrays() - # total_size = sum(a.size * a.itemsize for a in all_arrays) - # print(f"Total size of live arrays: {total_size / 1e9:.2f} GB") - # gc.collect() # sometimes takes a while to free buffers otherwise - # try: - # get_backend().defragment() - # except Exception as e: - # warnings.warn(f"Could not defragment because {e}") - # pass - # all_arrays = get_backend().live_arrays() - # total_size = sum(a.size * a.itemsize for a in all_arrays) - # print(f"Total size of live arrays: {total_size / 1e9:.2f} GB") - # all_arrays = get_backend().live_arrays() - # total_size = sum(a.size * a.itemsize for a in all_arrays) - # print(f"Total size of live arrays: {total_size / 1e9:.2f} GB") - return lev_model def _save_pretrained_local( @@ -874,45 +860,20 @@ def cb(step: 
StepInfo): return cb -def arbitrary_load_from_hf( - model_name_or_path, from_pretrained_lambda, revision=None, local_cache_dir=None, trust_remote_code=True -) -> Union[HfTokenizer | ProcessorMixin]: - is_url_like = urlparse(model_name_or_path).scheme != "" - if is_url_like: - if revision is not None: - raise ValueError("revision is not supported for URLs") - # tokenizers are directories, so we have to copy them locally - if local_cache_dir is None: - local_cache_dir = tempfile.mkdtemp() - - fs, path = fsspec.core.url_to_fs(model_name_or_path) - fs.get(path, local_cache_dir, recursive=True) - base_path = os.path.basename(path) - return from_pretrained_lambda(os.path.join(local_cache_dir, base_path), trust_remote_code=trust_remote_code) - else: - return from_pretrained_lambda(model_name_or_path, revision=revision, trust_remote_code=trust_remote_code) - - def load_tokenizer(model_name_or_path, revision=None, local_cache_dir=None, trust_remote_code=True) -> HfTokenizer: """Like AutoTokenizer.from_pretrained, but works with gs:// paths or anything on fsspec""" - return arbitrary_load_from_hf( - model_name_or_path, - AutoTokenizer.from_pretrained, - revision=revision, - local_cache_dir=local_cache_dir, - trust_remote_code=trust_remote_code, - ) + with _patch_hf_hub_download(): + return AutoTokenizer.from_pretrained( + model_name_or_path, revision=revision, cache_dir=local_cache_dir, trust_remote_code=trust_remote_code + ) def load_processor(model_name_or_path, revision=None, local_cache_dir=None, trust_remote_code=True) -> ProcessorMixin: """Like AutoProcessor.from_pretrained, but works with gs:// paths or anything on fsspec""" - return arbitrary_load_from_hf( - model_name_or_path, - AutoProcessor.from_pretrained, - revision=revision, - local_cache_dir=local_cache_dir, - trust_remote_code=trust_remote_code, - ) + with _patch_hf_hub_download(): + return AutoProcessor.from_pretrained( + model_name_or_path, revision=revision, cache_dir=local_cache_dir, trust_remote_code=trust_remote_code + ) _sync_count = 0 @@ -1111,3 +1072,55 @@ def _should_use_cpu_for_checkpoint_loading(): return False if sum(accel_memory) < cpu_memory: return True + + +def _is_hf_hub_model(ref: RepoRef): + api = HfApi() + + try: + api.model_info(repo_id=ref.model_name_or_path) + return True + except RepositoryNotFoundError: + return False + + +@contextlib.contextmanager +def _patch_hf_hub_download(): + """ + Temporarily monkeypatch `hf_hub_download` to handle fsspec URLs, ensuring the temporary directory + persists for the lifetime of the context manager. + """ + original_hf_hub_download = transformers.utils.hub.hf_hub_download + + # Create a temporary directory that persists through the context manager + with tempfile.TemporaryDirectory() as tmpdir: + + def custom_hf_hub_download(*args, **kwargs): + """ + Custom implementation of hf_hub_download to handle fsspec URLs. 
+ """ + repo_id = kwargs.get("repo_id", args[0] if len(args) > 0 else None) + filename = kwargs.get("filename", args[1] if len(args) > 1 else None) + + if repo_id and filename and _is_url_like(repo_id): + fs, path = fsspec.core.url_to_fs(repo_id) + remote_path = os.path.join(path, filename) + local_path = os.path.join(tmpdir, filename) + + if not fs.exists(remote_path): + raise EntryNotFoundError(f"File {remote_path} not found") + + fs.get(remote_path, local_path) + return local_path + + # Fallback to the original implementation + return original_hf_hub_download(*args, **kwargs) + + # Monkeypatch hf_hub_download + transformers.utils.hub.hf_hub_download = custom_hf_hub_download + + try: + yield custom_hf_hub_download + finally: + # Restore the original implementation + transformers.utils.hub.hf_hub_download = original_hf_hub_download From 574f9333e354c88d35608d21aa8ed8dfb8109e1d Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 24 Nov 2024 01:00:08 -0800 Subject: [PATCH 10/10] bump jax version (#822) --- docker/tpu/Dockerfile.base | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/tpu/Dockerfile.base b/docker/tpu/Dockerfile.base index d276c974d..e2e032e82 100644 --- a/docker/tpu/Dockerfile.base +++ b/docker/tpu/Dockerfile.base @@ -5,7 +5,7 @@ RUN pip install virtualenv # venv binaries encode their directory, so we need to setup the venv in the final location RUN virtualenv -p python3.10 /opt/levanter/.venv ENV PATH /opt/levanter/.venv/bin:$PATH -#RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]==0.4.30" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +#RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]==0.4.34" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # Install package dependencies to make incremental builds faster.