From 02e6d00dd89e4f77cdbf4b99c7ef756cfd17bbad Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 12 Dec 2024 13:12:18 -0800 Subject: [PATCH 01/18] switch to doing lm eval harness evals in bf16 --- src/levanter/eval_harness.py | 35 ++++++++++++++++++++++++++--------- src/levanter/main/train_lm.py | 4 +++- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index b9f4381fe..c679eab66 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -29,6 +29,7 @@ import equinox as eqx import jax import jax.numpy as jnp +import jmp import numpy as np import haliax @@ -139,13 +140,22 @@ def _truncate_or_pad(self, encoded: list[int], prompt_length: int): class LevanterHarnessLM(LM): - def __init__(self, EvalBatch: hax.Axis, EvalPos: hax.Axis, model: LmHeadModel, axis_resources, tokenizer): + def __init__( + self, + EvalBatch: hax.Axis, + EvalPos: hax.Axis, + model: LmHeadModel, + axis_resources, + tokenizer, + mp: jmp.Policy | None, + ): super().__init__() self.EvalBatch = EvalBatch self.EvalPos = EvalPos self.model = model self.axis_resources = axis_resources self.tokenizer = tokenizer + self.mp = mp def _eval_loglikelihood(model: LmHeadModel, example: LmExample) -> tuple[NamedArray, NamedArray]: """ @@ -153,6 +163,10 @@ def _eval_loglikelihood(model: LmHeadModel, example: LmExample) -> tuple[NamedAr - loss: The negative log-likelihood of the completion. - correct: Whether the completion is correct """ + + if self.mp is not None: + model = self.mp.cast_to_compute(model) + logits = model(example.tokens, attn_mask=example.attn_mask) logits = logits.astype(jnp.float32) Pos = logits.resolve_axis(self.EvalPos.name) @@ -352,11 +366,11 @@ def run_lm_eval_harness( tokenizer, EvalBatch, axis_resources, + mp: jmp.Policy | None, ) -> dict: - # tasks_to_run = tasks.get_task_dict(config.task_spec_or_default(), tasks.TaskManager()) tasks_to_run = config.to_task_dict() - outputs = _actually_run_eval_harness(config, model, tasks_to_run, tokenizer, EvalBatch, axis_resources) + outputs = _actually_run_eval_harness(config, model, tasks_to_run, tokenizer, EvalBatch, axis_resources, mp) return outputs @@ -368,6 +382,7 @@ def _actually_run_eval_harness( tokenizer: HfTokenizer, EvalBatch: haliax.Axis, axis_resources: ResourceMapping, + mp: jmp.Policy | None, ): """ Actually run the LM Eval Harness on the given model and tasks. This is a separate function so that it can be used @@ -382,7 +397,7 @@ def _actually_run_eval_harness( max_eval_length = config.max_eval_length EvalPos = model.Pos if max_eval_length is None else model.Pos.resize(max_eval_length) - harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer) + harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer, mp) # we always set log_samples here and filter out the samples later if we don't want them outputs = evaluator.evaluate(harness, tasks_to_run, limit=max_examples, log_samples=True) @@ -572,6 +587,7 @@ def run_eval_harness_main(config: EvalHarnessMainConfig): tokenizer, config.EvalBatch, axis_resources=compute_axis_mapping, + mp=config.trainer.mp, ) logger.info("Finished running LM eval harness") @@ -615,12 +631,12 @@ def log_report_to_tracker(prefix: str, report: dict, tracker: Optional[levanter. 
tracker.log(to_log, step=None) -def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_resources): +def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_resources, mp: jmp.Policy | None): tasks_to_run = config.to_task_dict() def lm_eval_harness(step: StepInfo, force=False): - if step.step == 0 and not force: - return # don't run eval on the first step + # if step.step == 0 and not force: + # return # don't run eval on the first step model = inference_mode(step.model, True) outputs = _actually_run_eval_harness( @@ -630,9 +646,12 @@ def lm_eval_harness(step: StepInfo, force=False): tokenizer, EvalBatch, axis_resources, + mp, ) if jax.process_index() == 0: + log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker()) + # don't delete b/c wandb will sometimes defer upload with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f: import json @@ -642,8 +661,6 @@ def lm_eval_harness(step: StepInfo, force=False): f.name, name=f"lm_eval_harness_results.{step.step}.json", type="lm_eval_output" ) - log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker()) - return lm_eval_harness diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 75be8d206..15da92cfa 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -263,7 +263,9 @@ def main(config: TrainLmConfig): if config.eval_harness is not None: eval_harness = config.eval_harness trainer.add_hook( - levanter.eval_harness.lm_eval_harness(eval_harness, tokenizer, EvalBatch, compute_axis_mapping), + levanter.eval_harness.lm_eval_harness( + eval_harness, tokenizer, EvalBatch, compute_axis_mapping, trainer.mp + ), every=config.eval_harness_steps, ) From ef0df266b2402ed4a05bd4322f496019df052410 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 17 Dec 2024 11:20:58 -0800 Subject: [PATCH 02/18] inital cut at sequence packing --- src/levanter/data/packing.py | 152 ++++++++++++++++++++++++++++++++ src/levanter/models/lm_model.py | 42 +++++---- tests/test_packing.py | 143 ++++++++++++++++++++++++++++++ 3 files changed, 322 insertions(+), 15 deletions(-) create mode 100644 src/levanter/data/packing.py create mode 100644 tests/test_packing.py diff --git a/src/levanter/data/packing.py b/src/levanter/data/packing.py new file mode 100644 index 000000000..9709a98ba --- /dev/null +++ b/src/levanter/data/packing.py @@ -0,0 +1,152 @@ +# Implements sequence packing +from dataclasses import dataclass +from typing import Iterator + +import jax.numpy as jnp +import numpy as np + +import haliax as hax + +from levanter.models.attention import AttentionMask +from levanter.models.lm_model import LmExample + + +# cf https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/data_generators/generator_utils.py#L623 + +# todo should we use something like this: https://arxiv.org/pdf/2107.02027? + + +class SequencePacker: + """ + Packs sequences into a single LmExample. 
+ """ + + def __init__(self, Pos: hax.Axis, max_pack_size: int, pad_token: int): + self.Pos = Pos + self._ids: list[int] = [] + self._segment_ids: list[int] = [] + self._loss_mask: list[int] = [] + self.num_segments = 0 + self.pad_token = pad_token + self.max_pack_size = max_pack_size + + def can_pack(self, ids: list[int]) -> bool: + return len(ids) + len(self._ids) <= self.Pos.size and self.num_segments < self.max_pack_size + + def add_example(self, ids: list[int], loss_mask: list[int] | np.ndarray, segment_id: int | None = None): + if len(ids) != len(loss_mask): + raise ValueError("ids and loss_mask must have the same length") + + if len(ids) == 0: + return + + if len(ids) + len(self._ids) > self.Pos.size: + raise ValueError("Too many tokens") + + if self.num_segments >= self.max_pack_size: + raise ValueError("Too many segments") + + self._ids.extend(ids) + if segment_id is None: + segment_id = self.num_segments + + self.num_segments += 1 + + self._segment_ids.extend([segment_id] * len(ids)) + + self._loss_mask.extend(loss_mask) + + def pack(self) -> LmExample: + ids = self._ids + [self.pad_token] * (self.Pos.size - len(self._ids)) + + segment_ids = self._segment_ids + [-1] * (self.Pos.size - len(self._segment_ids)) + + loss_mask = self._loss_mask + [0] * (self.Pos.size - len(self._loss_mask)) + + tokens = hax.named(ids, self.Pos) + segment_ids = hax.named(segment_ids, self.Pos) + loss_mask = hax.named(loss_mask, self.Pos) + + attn_mask = AttentionMask.causal().with_segment_ids(segment_ids) + + return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) + + def per_segment_loss( + self, packed_example: LmExample, losses: hax.NamedArray, max_segments: int + ) -> tuple[jnp.ndarray, jnp.ndarray]: + """ + Returns a pair of arrays of shape (max_segments,), where the first array is segment ids + and the second is loss per segment. + + This code is designed to run in a jit-compiled function, meaning we have to careful of shapes + """ + + assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask" + + segment_ids = packed_example.attn_mask.segment_ids.array + assert ( + segment_ids.ndim == 1 + ), f"Expected segment_ids to be 1D, got {segment_ids.ndim}. 
Use vmap if you have multiple examples" + + # mask out padding etc + masked_losses = losses * packed_example.loss_mask + + # sum the losses for each segment + # Extract unique segment IDs with padding + unique_segment_ids = jnp.unique(segment_ids, size=max_segments, fill_value=-1) + + # Create a mask matrix where each row corresponds to a unique segment + segment_mask = unique_segment_ids[:, None] == segment_ids[None, :] # [segment, len] + + segment_mask = segment_mask.astype(masked_losses.dtype) + + # segment_losses = jnp.esum(losses * segment_mask, axis=1) # [segment] + segment_losses = jnp.einsum("ij,j->i", segment_mask, masked_losses.array) + + return unique_segment_ids, segment_losses + + +@dataclass +class PromptCompletion: + ids: list[int] + prompt_length: int + segment_id: int | None = None + + +def pack_prompt_completions( + Pos: hax.Axis, + sequences: list[PromptCompletion], + pad_token: int, + max_pack_size: int = 64, + max_buffered_examples: int = 64, +) -> Iterator[LmExample]: + """ + Packs a list of prompt completions into LmExamples using the SequencePacker + """ + + packers = [SequencePacker(Pos, max_pack_size, pad_token)] + + for sequence in sequences: + loss_mask = np.arange(len(sequence.ids)) >= sequence.prompt_length - 1 + loss_mask[-1] = 0 + + for packer in packers: + if packer.can_pack(sequence.ids): + packer.add_example(sequence.ids, loss_mask, sequence.segment_id) + + if packer.num_segments == max_pack_size: + yield packer.pack() + packers.remove(packer) + break + else: + # no packer could fit the example, create a new one + packer = SequencePacker(Pos, max_pack_size, pad_token) + packer.add_example(sequence.ids, loss_mask, sequence.segment_id) + packers.append(packer) + + while len(packers) >= max_buffered_examples: + yield packer.pack() + packers.pop(0) + + for packer in packers: + yield packer.pack() diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 7f5c0e3d8..495a0dfea 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -4,7 +4,6 @@ import draccus import equinox as eqx -import jax import jax.numpy as jnp from jax.random import PRNGKey @@ -31,6 +30,7 @@ def causal( loss_mask: Optional[hax.NamedArray] = None, ignore_id: Optional[int] = None, eos_id: Optional[int] = None, + segment_ids: Optional[hax.NamedArray] = None, ) -> "LmExample": if tokens.ndim != 1: raise ValueError("tokens must be a 1D array") @@ -40,9 +40,12 @@ def causal( Pos = tokens.axes[0] - # don't predict the last token. 
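# Illustrative sketch (not part of the patch): pack_prompt_completions above uses a greedy
# first-fit strategy: keep a list of open packers, place each sequence into the first packer
# with room, and open a new packer when none fits. The plain-Python version below shows just
# that control flow with hypothetical names (items/bins/capacity); it does not use haliax or
# the real SequencePacker API.
def first_fit_pack(items: list[list[int]], capacity: int) -> list[list[list[int]]]:
    bins: list[list[list[int]]] = []
    for item in items:
        for b in bins:
            if sum(len(x) for x in b) + len(item) <= capacity:
                b.append(item)  # first open bin with enough room wins
                break
        else:
            bins.append([item])  # no open bin had room: start a new one
    return bins

# first_fit_pack([[1, 2, 3], [4, 5], [6, 7, 8, 9]], capacity=5)
# -> [[[1, 2, 3], [4, 5]], [[6, 7, 8, 9]]]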
- if loss_mask is None: - loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) + causal_loss_mask = LmExample.causal_loss_mask(Pos) + + if loss_mask is not None: + loss_mask = loss_mask & causal_loss_mask + else: + loss_mask = causal_loss_mask if ignore_id is not None: # we don't compute loss for any tokens matching the ignore index @@ -51,7 +54,7 @@ def causal( attn_mask = AttentionMask.causal() - if eos_id is not None: + if eos_id is not None and segment_ids is None: # the next token after an eos token is in a new segment eos_mask = hax.roll(tokens, 1, Pos) == eos_id # first token is always in segment 0 @@ -70,24 +73,33 @@ def from_prompt_and_completion( ignore_id: Optional[int] = None, all_causal: bool = True, ) -> "LmExample": - # mask out the prompt tokens - loss_mask = hax.arange(Pos) >= prompt_length - 1 - # don't predict the padding - if ignore_id is not None: - targets = hax.roll(tokens, -1, Pos) - loss_mask = loss_mask & (targets != ignore_id) - - # don't predict the last token - loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_)) - if all_causal: attn_mask = AttentionMask.causal() else: # causal just for the completion part. We don't have a special structured mask for this, so we just raise NotImplementedError("Not implemented yet") + # mask out the prompt tokens + loss_mask = LmExample.causal_loss_mask(Pos, prompt_length=prompt_length) + + if ignore_id is not None: + # we don't compute loss for any tokens matching the ignore index + ignore_mask = hax.roll(tokens, -1, Pos) != ignore_id + loss_mask = loss_mask * ignore_mask + return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) + @staticmethod + def causal_loss_mask(Pos: Axis, prompt_length: Optional[int] = None) -> NamedArray: + loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) + + if prompt_length is not None: + # don't predict the prompt tokens + prompt_mask = hax.arange(Pos) >= prompt_length - 1 + loss_mask = loss_mask * prompt_mask + + return loss_mask + # TODO: for some reason, mypy doesn't like the discover_packages_path argument? 
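# Illustrative sketch (not part of the patch): causal_loss_mask above marks which positions
# contribute to the next-token loss: every position except the last, and, when prompt_length
# is given, nothing before prompt_length - 1. Below is a plain-numpy rendering of that rule
# for sanity checking; the function name is hypothetical and it ignores haliax axes.
import numpy as np

def sketch_causal_loss_mask(seq_len: int, prompt_length: int | None = None) -> np.ndarray:
    mask = np.ones(seq_len, dtype=np.float32)
    mask[-1] = 0.0  # the final token has no next-token target
    if prompt_length is not None:
        mask[: max(prompt_length - 1, 0)] = 0.0  # don't score predictions of the prompt itself
    return mask

# sketch_causal_loss_mask(6, prompt_length=3) -> array([0., 0., 1., 1., 1., 0.], dtype=float32)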
@dataclass(frozen=True) diff --git a/tests/test_packing.py b/tests/test_packing.py new file mode 100644 index 000000000..88b3572e6 --- /dev/null +++ b/tests/test_packing.py @@ -0,0 +1,143 @@ +import jax.numpy as jnp +import numpy as np +import pytest + +import haliax as hax + +from levanter.data.packing import PromptCompletion, SequencePacker, pack_prompt_completions + + +def test_per_segment_loss(): + Pos = hax.Axis("pos", size=10) + packer = SequencePacker(Pos=Pos, max_pack_size=10, pad_token=0) + + # Add two sequences + packer.add_example(ids=[1, 2, 3], loss_mask=[1, 1, 1], segment_id=None) + packer.add_example(ids=[4, 5], loss_mask=[1, 1], segment_id=None) + + # Pack into LmExample + packed = packer.pack() + + losses = hax.named(jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0]), Pos) + + unique_ids, segment_losses = packer.per_segment_loss(packed, losses, max_segments=3) + + assert list(unique_ids) == [-1, 0, 1] + assert list(segment_losses) == [0.0, 0.6, 0.9] + + +def test_can_pack_simple_case(): + Pos = hax.Axis("pos", size=10) + packer = SequencePacker(Pos=Pos, max_pack_size=2, pad_token=0) + + assert packer.can_pack([1, 2, 3]) is True + packer.add_example(ids=[1, 2, 3], loss_mask=[1, 1, 1]) + assert packer.can_pack([4, 5]) is True + assert packer.can_pack(list(range(6, 16))) is False # Exceeds Pos size + + +def test_add_example_and_pack(): + Pos = hax.Axis("pos", size=10) + packer = SequencePacker(Pos=Pos, max_pack_size=2, pad_token=0) + + packer.add_example([1, 2, 3], [1, 1, 1]) + packed = packer.pack() + + expected_tokens = [1, 2, 3, 0, 0, 0, 0, 0, 0, 0] + expected_segment_ids = [0, 0, 0, -1, -1, -1, -1, -1, -1, -1] + expected_loss_mask = [1, 1, 1, 0, 0, 0, 0, 0, 0, 0] + + np.testing.assert_array_equal(packed.tokens.array, expected_tokens) + np.testing.assert_array_equal(packed.attn_mask.segment_ids.array, expected_segment_ids) + np.testing.assert_array_equal(packed.loss_mask.array, expected_loss_mask) + + +def test_exceed_max_pack_size(): + Pos = hax.Axis("pos", size=10) + packer = SequencePacker(Pos=Pos, max_pack_size=2, pad_token=0) + + packer.add_example([1, 2, 3], [1, 1, 1]) + packer.add_example([4, 5, 6], [1, 1, 1]) + + with pytest.raises(ValueError, match="Too many segments"): + packer.add_example([7, 8], [1, 1]) # Exceeds max pack size + + +def test_empty_sequence(): + Pos = hax.Axis("pos", size=10) + packer = SequencePacker(Pos=Pos, max_pack_size=2, pad_token=0) + + with pytest.raises(ValueError, match="ids and loss_mask must have the same length"): + packer.add_example([], [1]) # Mismatched lengths + + packer.add_example([], []) # Adding an empty sequence is allowed but does nothing + packed = packer.pack() + + expected_tokens = [0] * 10 + expected_segment_ids = [-1] * 10 + expected_loss_mask = [0] * 10 + + np.testing.assert_array_equal(packed.tokens.array, expected_tokens) + np.testing.assert_array_equal(packed.attn_mask.segment_ids.array, expected_segment_ids) + np.testing.assert_array_equal(packed.loss_mask.array, expected_loss_mask) + + +def test_packing_multiple_examples(): + Pos = hax.Axis("pos", size=10) + packer = SequencePacker(Pos=Pos, max_pack_size=2, pad_token=0) + + # First example + packer.add_example([1, 2], [1, 1]) + # Second example + packer.add_example([3, 4, 5], [1, 1, 1]) + + packed = packer.pack() + + expected_tokens = [1, 2, 3, 4, 5, 0, 0, 0, 0, 0] + expected_segment_ids = [0, 0, 1, 1, 1, -1, -1, -1, -1, -1] + expected_loss_mask = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0] + + np.testing.assert_array_equal(packed.tokens.array, expected_tokens) + 
np.testing.assert_array_equal(packed.attn_mask.segment_ids.array, expected_segment_ids) + np.testing.assert_array_equal(packed.loss_mask.array, expected_loss_mask) + + +def test_pack_prompt_completions_simple(): + Pos = hax.Axis("pos", size=10) + pad_token = 0 + max_pack_size = 2 + max_buffered_examples = 2 + + sequences = [ + PromptCompletion(ids=[1, 2, 3], prompt_length=2), + PromptCompletion(ids=[4, 5], prompt_length=1), + PromptCompletion(ids=[6, 7, 8], prompt_length=1), + ] + + results = list(pack_prompt_completions(Pos, sequences, pad_token, max_pack_size, max_buffered_examples)) + + assert len(results) == 2 # Expect two packed LmExamples + + # Check the first packed example + packed_1 = results[0] + expected_tokens_1 = [1, 2, 3, 4, 5, 0, 0, 0, 0, 0] + expected_segment_ids_1 = [0, 0, 0, 1, 1, -1, -1, -1, -1, -1] + expected_loss_mask_1 = [0, 1, 0, 1, 0, 0, 0, 0, 0, 0] + + np.testing.assert_array_equal(packed_1.tokens.array, expected_tokens_1) + np.testing.assert_array_equal(packed_1.attn_mask.segment_ids.array, expected_segment_ids_1) + np.testing.assert_array_equal(packed_1.loss_mask.array, expected_loss_mask_1) + + # Check the second packed example + packed_2 = results[1] + expected_tokens_2 = [6, 7, 8, 0, 0, 0, 0, 0, 0, 0] + expected_segment_ids_2 = [0, 0, 0, -1, -1, -1, -1, -1, -1, -1] + expected_loss_mask_2 = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0] + + np.testing.assert_array_equal(packed_2.tokens.array, expected_tokens_2) + np.testing.assert_array_equal(packed_2.attn_mask.segment_ids.array, expected_segment_ids_2) + np.testing.assert_array_equal(packed_2.loss_mask.array, expected_loss_mask_2) + + +if __name__ == "__main__": + pytest.main() From c25ee9104b20c49ee421e3403b47fc87f8cf3d5c Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 17 Dec 2024 11:22:01 -0800 Subject: [PATCH 03/18] inital cut at sequence packing --- src/levanter/data/packing.py | 69 ++++++++++++++++++------------------ tests/test_packing.py | 2 +- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/src/levanter/data/packing.py b/src/levanter/data/packing.py index 9709a98ba..47510536e 100644 --- a/src/levanter/data/packing.py +++ b/src/levanter/data/packing.py @@ -71,40 +71,6 @@ def pack(self) -> LmExample: return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) - def per_segment_loss( - self, packed_example: LmExample, losses: hax.NamedArray, max_segments: int - ) -> tuple[jnp.ndarray, jnp.ndarray]: - """ - Returns a pair of arrays of shape (max_segments,), where the first array is segment ids - and the second is loss per segment. - - This code is designed to run in a jit-compiled function, meaning we have to careful of shapes - """ - - assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask" - - segment_ids = packed_example.attn_mask.segment_ids.array - assert ( - segment_ids.ndim == 1 - ), f"Expected segment_ids to be 1D, got {segment_ids.ndim}. 
Use vmap if you have multiple examples" - - # mask out padding etc - masked_losses = losses * packed_example.loss_mask - - # sum the losses for each segment - # Extract unique segment IDs with padding - unique_segment_ids = jnp.unique(segment_ids, size=max_segments, fill_value=-1) - - # Create a mask matrix where each row corresponds to a unique segment - segment_mask = unique_segment_ids[:, None] == segment_ids[None, :] # [segment, len] - - segment_mask = segment_mask.astype(masked_losses.dtype) - - # segment_losses = jnp.esum(losses * segment_mask, axis=1) # [segment] - segment_losses = jnp.einsum("ij,j->i", segment_mask, masked_losses.array) - - return unique_segment_ids, segment_losses - @dataclass class PromptCompletion: @@ -150,3 +116,38 @@ def pack_prompt_completions( for packer in packers: yield packer.pack() + + +def per_segment_loss( + packed_example: LmExample, losses: hax.NamedArray, max_segments: int +) -> tuple[jnp.ndarray, jnp.ndarray]: + """ + Returns a pair of arrays of shape (max_segments,), where the first array is segment ids + and the second is loss per segment. + + This code is designed to run in a jit-compiled function, meaning we have to careful of shapes + """ + + assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask" + + segment_ids = packed_example.attn_mask.segment_ids.array + assert ( + segment_ids.ndim == 1 + ), f"Expected segment_ids to be 1D, got {segment_ids.ndim}. Use vmap if you have multiple examples" + + # mask out padding etc + masked_losses = losses * packed_example.loss_mask + + # sum the losses for each segment + # Extract unique segment IDs with padding + unique_segment_ids = jnp.unique(segment_ids, size=max_segments, fill_value=-1) + + # Create a mask matrix where each row corresponds to a unique segment + segment_mask = unique_segment_ids[:, None] == segment_ids[None, :] # [segment, len] + + segment_mask = segment_mask.astype(masked_losses.dtype) + + # segment_losses = jnp.esum(losses * segment_mask, axis=1) # [segment] + segment_losses = jnp.einsum("ij,j->i", segment_mask, masked_losses.array) + + return unique_segment_ids, segment_losses diff --git a/tests/test_packing.py b/tests/test_packing.py index 88b3572e6..cc0341338 100644 --- a/tests/test_packing.py +++ b/tests/test_packing.py @@ -20,7 +20,7 @@ def test_per_segment_loss(): losses = hax.named(jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0]), Pos) - unique_ids, segment_losses = packer.per_segment_loss(packed, losses, max_segments=3) + unique_ids, segment_losses = per_segment_loss(packed, losses, max_segments=3) assert list(unique_ids) == [-1, 0, 1] assert list(segment_losses) == [0.0, 0.6, 0.9] From 9738a9362ee1d6f788ddad84d35dad2f12a721a7 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 17 Dec 2024 13:46:54 -0800 Subject: [PATCH 04/18] add is_correct checking --- src/levanter/data/packing.py | 44 ++++++++++++++++++++- tests/test_packing.py | 77 +++++++++++++++++++++++++++++++++++- 2 files changed, 118 insertions(+), 3 deletions(-) diff --git a/src/levanter/data/packing.py b/src/levanter/data/packing.py index 47510536e..aad48d031 100644 --- a/src/levanter/data/packing.py +++ b/src/levanter/data/packing.py @@ -122,8 +122,10 @@ def per_segment_loss( packed_example: LmExample, losses: hax.NamedArray, max_segments: int ) -> tuple[jnp.ndarray, jnp.ndarray]: """ - Returns a pair of arrays of shape (max_segments,), where the first array is segment ids - and the second is loss per segment. 
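# Illustrative sketch (not part of the patch): per_segment_loss sums per-token losses by segment
# while keeping shapes static: pad the set of segment ids to a fixed size, build a
# (segments x positions) membership mask, and take a matrix-vector product. The numpy version
# below mirrors that idea outside of jit/haliax; all names are hypothetical.
import numpy as np

def sketch_per_segment_loss(token_losses, loss_mask, segment_ids, max_segments):
    masked = np.asarray(token_losses) * np.asarray(loss_mask)
    seg = np.asarray(segment_ids)
    uniq = np.full(max_segments, -1)  # pad unused slots with -1, like jnp.unique(..., fill_value=-1)
    present = np.unique(seg)
    k = min(len(present), max_segments)
    uniq[:k] = present[:k]
    membership = (uniq[:, None] == seg[None, :]).astype(masked.dtype)  # [segments, positions]
    return uniq, membership @ masked  # summed loss per segment

# sketch_per_segment_loss([0.1, 0.2, 0.3, 0.4, 0.5], [1, 1, 1, 1, 1], [0, 0, 0, 1, 1], 3)
# -> ids [0, 1, -1], per-segment losses approximately [0.6, 0.9, 0.0]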
+ Returns a pair of arrays of shape (max_segments,), where: + + * the first array is segment ids + * the second is loss per segment. This code is designed to run in a jit-compiled function, meaning we have to careful of shapes """ @@ -151,3 +153,41 @@ def per_segment_loss( segment_losses = jnp.einsum("ij,j->i", segment_mask, masked_losses.array) return unique_segment_ids, segment_losses + + +def per_segment_correct( + packed_example: LmExample, correct: hax.NamedArray, max_segments: int +) -> tuple[jnp.ndarray, jnp.ndarray]: + """ + Returns a pair of arrays of shape (max_segments,), where: + + * the first array is segment ids + * the second is whether all tokens in the segment are correct. + + This code is designed to run in a jit-compiled function, meaning we have to careful of shapes + + correct is a boolean array of the same shape as the losses array indicating whether the token was correct + """ + + assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask" + + segment_ids = packed_example.attn_mask.segment_ids.array + assert ( + segment_ids.ndim == 1 + ), f"Expected segment_ids to be 1D, got {segment_ids.ndim}. Use vmap if you have multiple examples" + + # mask out padding etc + masked_correct = hax.logical_or(correct, hax.logical_not(packed_example.loss_mask)) + + # sum the losses for each segment + # Extract unique segment IDs with padding + unique_segment_ids = jnp.unique(segment_ids, size=max_segments, fill_value=-1) + + # Create a mask matrix where each row corresponds to a unique segment + segment_mask = unique_segment_ids[:, None] == segment_ids[None, :] # [segment, len] + + segment_mask = segment_mask.astype(masked_correct.dtype) + + segment_correct = jnp.all(jnp.where(segment_mask, masked_correct.array, True), axis=1) + + return unique_segment_ids, segment_correct diff --git a/tests/test_packing.py b/tests/test_packing.py index cc0341338..b916eb688 100644 --- a/tests/test_packing.py +++ b/tests/test_packing.py @@ -4,7 +4,15 @@ import haliax as hax -from levanter.data.packing import PromptCompletion, SequencePacker, pack_prompt_completions +from levanter.data.packing import ( + PromptCompletion, + SequencePacker, + pack_prompt_completions, + per_segment_correct, + per_segment_loss, +) +from levanter.models.attention import AttentionMask +from levanter.models.lm_model import LmExample def test_per_segment_loss(): @@ -139,5 +147,72 @@ def test_pack_prompt_completions_simple(): np.testing.assert_array_equal(packed_2.loss_mask.array, expected_loss_mask_2) +def test_pack_prompt_completions_exceed_max_buffered_examples(): + Pos = hax.Axis("pos", size=10) + pad_token = 0 + max_pack_size = 1 + max_buffered_examples = 1 + + sequences = [ + PromptCompletion(ids=[1, 2, 3], prompt_length=2), + PromptCompletion(ids=[4, 5], prompt_length=1), + PromptCompletion(ids=[6, 7, 8], prompt_length=1), + ] + + results = list(pack_prompt_completions(Pos, sequences, pad_token, max_pack_size, max_buffered_examples)) + + assert len(results) == 3 + + # Check the first packed example + packed_1 = results[0] + expected_tokens_1 = [1, 2, 3, 0, 0, 0, 0, 0, 0, 0] + expected_segment_ids_1 = [0, 0, 0, -1, -1, -1, -1, -1, -1, -1] + expected_loss_mask_1 = [0, 1, 0, 0, 0, 0, 0, 0, 0, 0] + + np.testing.assert_array_equal(packed_1.tokens.array, expected_tokens_1) + np.testing.assert_array_equal(packed_1.attn_mask.segment_ids.array, expected_segment_ids_1) + np.testing.assert_array_equal(packed_1.loss_mask.array, expected_loss_mask_1) + + # Check the second packed example + 
packed_2 = results[1] + expected_tokens_2 = [4, 5, 0, 0, 0, 0, 0, 0, 0, 0] + expected_segment_ids_2 = [0, 0, -1, -1, -1, -1, -1, -1, -1, -1] + expected_loss_mask_2 = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0] + + np.testing.assert_array_equal(packed_2.tokens.array, expected_tokens_2) + np.testing.assert_array_equal(packed_2.attn_mask.segment_ids.array, expected_segment_ids_2) + np.testing.assert_array_equal(packed_2.loss_mask.array, expected_loss_mask_2) + + # Check the third packed example + packed_3 = results[2] + expected_tokens_3 = [6, 7, 8, 0, 0, 0, 0, 0, 0, 0] + expected_segment_ids_3 = [0, 0, 0, -1, -1, -1, -1, -1, -1, -1] + expected_loss_mask_3 = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0] + + np.testing.assert_array_equal(packed_3.tokens.array, expected_tokens_3) + np.testing.assert_array_equal(packed_3.attn_mask.segment_ids.array, expected_segment_ids_3) + np.testing.assert_array_equal(packed_3.loss_mask.array, expected_loss_mask_3) + + +def test_segment_correct(): + # Mock segment_ids and loss_mask + Pos = hax.Axis("pos", size=10) + segment_ids = hax.named(jnp.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2]), Pos) + loss_mask = hax.named(jnp.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]), Pos) + + # Create a packed example + attn_mask = AttentionMask.causal().with_segment_ids(segment_ids=segment_ids) + packed_example = LmExample(tokens=None, loss_mask=loss_mask, attn_mask=attn_mask) + + # Mock correctness array (True for correct, False for incorrect) + correct = hax.named(jnp.array([True, True, True, False, True, False, True, True, True, True]), Pos) + + # Call the function + unique_ids, segment_correct = per_segment_correct(packed_example, correct, max_segments=4) + + assert list(unique_ids) == [0, 1, 2, -1] + assert list(segment_correct) == [True, False, True, True] + + if __name__ == "__main__": pytest.main() From 3ef5dcb5e1b6f528fd664fd1dbc2b3e13c9edb52 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 17 Dec 2024 16:31:59 -0800 Subject: [PATCH 05/18] wip --- src/levanter/data/packing.py | 62 ++++++----- src/levanter/eval_harness.py | 119 +++++++++++++++++----- src/levanter/utils/background_iterable.py | 7 +- src/levanter/utils/jax_utils.py | 2 +- tests/test_packing.py | 19 ++-- 5 files changed, 145 insertions(+), 64 deletions(-) diff --git a/src/levanter/data/packing.py b/src/levanter/data/packing.py index aad48d031..e3f385a1a 100644 --- a/src/levanter/data/packing.py +++ b/src/levanter/data/packing.py @@ -1,6 +1,6 @@ # Implements sequence packing from dataclasses import dataclass -from typing import Iterator +from typing import Iterable, Iterator import jax.numpy as jnp import numpy as np @@ -9,6 +9,7 @@ from levanter.models.attention import AttentionMask from levanter.models.lm_model import LmExample +from levanter.utils.jax_utils import local_cpu_mesh # cf https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/data_generators/generator_utils.py#L623 @@ -29,6 +30,7 @@ def __init__(self, Pos: hax.Axis, max_pack_size: int, pad_token: int): self.num_segments = 0 self.pad_token = pad_token self.max_pack_size = max_pack_size + assert pad_token is not None, "pad_token must be set" def can_pack(self, ids: list[int]) -> bool: return len(ids) + len(self._ids) <= self.Pos.size and self.num_segments < self.max_pack_size @@ -63,13 +65,14 @@ def pack(self) -> LmExample: loss_mask = self._loss_mask + [0] * (self.Pos.size - len(self._loss_mask)) - tokens = hax.named(ids, self.Pos) - segment_ids = hax.named(segment_ids, self.Pos) - loss_mask = hax.named(loss_mask, self.Pos) + with 
local_cpu_mesh(): + tokens = hax.named(ids, self.Pos) + segment_ids = hax.named(segment_ids, self.Pos) + loss_mask = hax.named(loss_mask, self.Pos) - attn_mask = AttentionMask.causal().with_segment_ids(segment_ids) + attn_mask = AttentionMask.causal().with_segment_ids(segment_ids) - return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) + return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) @dataclass @@ -81,16 +84,16 @@ class PromptCompletion: def pack_prompt_completions( Pos: hax.Axis, - sequences: list[PromptCompletion], + sequences: Iterable[PromptCompletion], pad_token: int, - max_pack_size: int = 64, + max_segments_per_example: int = 64, max_buffered_examples: int = 64, ) -> Iterator[LmExample]: """ Packs a list of prompt completions into LmExamples using the SequencePacker """ - packers = [SequencePacker(Pos, max_pack_size, pad_token)] + packers = [SequencePacker(Pos, max_segments_per_example, pad_token)] for sequence in sequences: loss_mask = np.arange(len(sequence.ids)) >= sequence.prompt_length - 1 @@ -100,13 +103,13 @@ def pack_prompt_completions( if packer.can_pack(sequence.ids): packer.add_example(sequence.ids, loss_mask, sequence.segment_id) - if packer.num_segments == max_pack_size: + if packer.num_segments == max_segments_per_example: yield packer.pack() packers.remove(packer) break else: # no packer could fit the example, create a new one - packer = SequencePacker(Pos, max_pack_size, pad_token) + packer = SequencePacker(Pos, max_segments_per_example, pad_token) packer.add_example(sequence.ids, loss_mask, sequence.segment_id) packers.append(packer) @@ -119,10 +122,10 @@ def pack_prompt_completions( def per_segment_loss( - packed_example: LmExample, losses: hax.NamedArray, max_segments: int -) -> tuple[jnp.ndarray, jnp.ndarray]: + packed_example: LmExample, losses: hax.NamedArray, max_Segments: hax.Axis +) -> tuple[hax.NamedArray, hax.NamedArray]: """ - Returns a pair of arrays of shape (max_segments,), where: + Returns a pair of arrays of shape (Segments,), where: * the first array is segment ids * the second is loss per segment. @@ -132,32 +135,37 @@ def per_segment_loss( assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask" - segment_ids = packed_example.attn_mask.segment_ids.array + segment_ids = packed_example.attn_mask.segment_ids assert ( segment_ids.ndim == 1 ), f"Expected segment_ids to be 1D, got {segment_ids.ndim}. 
Use vmap if you have multiple examples" + Pos = packed_example.tokens.axes[0] # mask out padding etc masked_losses = losses * packed_example.loss_mask # sum the losses for each segment - # Extract unique segment IDs with padding - unique_segment_ids = jnp.unique(segment_ids, size=max_segments, fill_value=-1) + unique_segment_ids = _unique_segment_ids(max_Segments, segment_ids) # Create a mask matrix where each row corresponds to a unique segment - segment_mask = unique_segment_ids[:, None] == segment_ids[None, :] # [segment, len] + segment_mask = unique_segment_ids == segment_ids.broadcast_axis(max_Segments) segment_mask = segment_mask.astype(masked_losses.dtype) - # segment_losses = jnp.esum(losses * segment_mask, axis=1) # [segment] - segment_losses = jnp.einsum("ij,j->i", segment_mask, masked_losses.array) + segment_losses = hax.dot(segment_mask, masked_losses, axis=Pos) return unique_segment_ids, segment_losses +def _unique_segment_ids(max_Segments, segment_ids): + # Extract unique segment IDs with padding + # TODO: add unique to haliax + unique_segment_ids = jnp.unique(segment_ids.array, size=max_Segments.size, fill_value=-1) + unique_segment_ids = hax.named(unique_segment_ids, max_Segments) + return unique_segment_ids def per_segment_correct( - packed_example: LmExample, correct: hax.NamedArray, max_segments: int -) -> tuple[jnp.ndarray, jnp.ndarray]: + packed_example: LmExample, correct: hax.NamedArray, max_Segments: hax.Axis +) -> tuple[hax.NamedArray, hax.NamedArray]: """ Returns a pair of arrays of shape (max_segments,), where: @@ -171,23 +179,25 @@ def per_segment_correct( assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask" - segment_ids = packed_example.attn_mask.segment_ids.array + segment_ids = packed_example.attn_mask.segment_ids assert ( segment_ids.ndim == 1 ), f"Expected segment_ids to be 1D, got {segment_ids.ndim}. 
Use vmap if you have multiple examples" + Pos = packed_example.tokens.axes[0] + # mask out padding etc masked_correct = hax.logical_or(correct, hax.logical_not(packed_example.loss_mask)) # sum the losses for each segment # Extract unique segment IDs with padding - unique_segment_ids = jnp.unique(segment_ids, size=max_segments, fill_value=-1) + unique_segment_ids = _unique_segment_ids(max_Segments, segment_ids) # Create a mask matrix where each row corresponds to a unique segment - segment_mask = unique_segment_ids[:, None] == segment_ids[None, :] # [segment, len] + segment_mask = unique_segment_ids == segment_ids.broadcast_axis(max_Segments) segment_mask = segment_mask.astype(masked_correct.dtype) - segment_correct = jnp.all(jnp.where(segment_mask, masked_correct.array, True), axis=1) + segment_correct = hax.all(hax.where(segment_mask, masked_correct, True), axis=Pos) return unique_segment_ids, segment_correct diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index c679eab66..407d312df 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -24,7 +24,7 @@ import typing from dataclasses import dataclass from functools import cached_property -from typing import List, Optional, Sequence, Tuple +from typing import Iterator, List, Optional, Sequence, Tuple import equinox as eqx import jax @@ -37,8 +37,10 @@ import levanter.tracker from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer +from levanter.data.packing import per_segment_correct, per_segment_loss, PromptCompletion, pack_prompt_completions from levanter.models.gpt2 import Gpt2Config from levanter.models.loss import next_token_loss +from levanter.utils.background_iterable import BackgroundIterator from levanter.utils.hf_utils import HfTokenizer @@ -67,6 +69,39 @@ logger = logging.getLogger(__name__) +def _iterate_tokenized_requests(requests: list[Instance], tokenizer: HfTokenizer, max_len: int) -> Iterator[PromptCompletion]: + """ + Tokenize the requests and yield them as PromptCompletions, for packing into LmExamples. + """ + for i, request in enumerate(requests): + # it's kinda annoying we run tokenization twice, but it's the easiest way to get the prompt length + # CF: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/api/model.py#L354 + context, completion = request.args + whole_enc = tokenizer(context + completion) + context_enc = tokenizer(context) + + context_enc_len = len(context_enc["input_ids"]) + whole_ids = whole_enc["input_ids"] + if len(whole_ids) > max_len: + logger.warning(f"Request {i} is too long. Truncating.") + # truncate from the left + whole_ids = whole_ids[-max_len:] + context_enc_len = max_len - len(completion) + if context_enc_len < 0: + context_enc_len = 0 + logger.warning("Prompt length is negative after truncation. 
Setting to 0.") + + yield PromptCompletion(ids=whole_ids, prompt_length=context_enc_len, segment_id=i) + + +def _pack_requests(requests: list[Instance], tokenizer: HfTokenizer, Pos: hax.Axis, max_pack_size: int) -> Iterator[LmExample]: + packed_iterator = _iterate_tokenized_requests(requests, tokenizer, Pos.size) + yield from pack_prompt_completions( + Pos, packed_iterator, + max_segments_per_example=max_pack_size, + pad_token=tokenizer.pad_token_id + ) + class EvalDataset(AsyncDataset[LmExample]): def __init__(self, Pos, tokenizer, examples: list[Instance]): @@ -148,6 +183,7 @@ def __init__( axis_resources, tokenizer, mp: jmp.Policy | None, + max_packed_segments: int = 64, ): super().__init__() self.EvalBatch = EvalBatch @@ -156,18 +192,20 @@ def __init__( self.axis_resources = axis_resources self.tokenizer = tokenizer self.mp = mp + self.max_packed_segments = max_packed_segments - def _eval_loglikelihood(model: LmHeadModel, example: LmExample) -> tuple[NamedArray, NamedArray]: + def _eval_loglikelihood(model: LmHeadModel, packed_example: LmExample) -> tuple[NamedArray, NamedArray, NamedArray]: """ Returns: - - loss: The negative log-likelihood of the completion. - - correct: Whether the completion is correct + - segments: The segment IDs of the completions. (shape: (Segments,)) + - loss: The log-likelihood of the completion. (shape: (Segments,)) + - correct: Whether the completion is correct or not. (shape: (Segments,)) """ if self.mp is not None: model = self.mp.cast_to_compute(model) - logits = model(example.tokens, attn_mask=example.attn_mask) + logits = model(packed_example.tokens, attn_mask=packed_example.attn_mask) logits = logits.astype(jnp.float32) Pos = logits.resolve_axis(self.EvalPos.name) @@ -175,20 +213,32 @@ def _eval_loglikelihood(model: LmHeadModel, example: LmExample) -> tuple[NamedAr Pos=Pos, Vocab=model.Vocab, logits=logits, - true_ids=example.tokens, - loss_mask=example.loss_mask, - reduction=hax.sum, - reduction_axis=Pos, + true_ids=packed_example.tokens, + loss_mask=packed_example.loss_mask, + reduction=None, ) - not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=bool) + # We need to compute losses and also whether or not the completion is correct + # (i.e. 
the greedy prediction is the target) pred_targets = hax.argmax(logits, axis=model.Vocab) - targets = hax.roll(example.tokens, -1, axis=Pos) - # "freebie" is the positions we don't need to predict (prompt or final token's next token) - freebie = hax.logical_not(example.loss_mask * not_last_loss_mask) - correct = hax.all(hax.equal(pred_targets, targets) + freebie, axis=Pos) + targets = hax.roll(packed_example.tokens, -1, axis=Pos) + is_correct = targets == pred_targets + + max_Segments = hax.Axis("Segments", size=self.max_packed_segments) - return -loss, correct + batched_segment_ids, batched_per_segment_losses = ( + hax.vmap(per_segment_loss, self.EvalBatch)(packed_example, loss, max_Segments) + ) + + _, batched_per_segment_correct = ( + hax.vmap(per_segment_correct, self.EvalBatch)(packed_example, is_correct, max_Segments) + ) + + segments = hax.flatten(batched_segment_ids, "segment") + losses = hax.flatten(batched_per_segment_losses, "segment") + correct = hax.flatten(batched_per_segment_correct, "segment") + + return segments, -losses, correct # no sharded outputs self._jit_loglikelihood = hax.named_jit( @@ -203,22 +253,34 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: """ # pad requests to be a multiple of the batch size initial_length = len(requests) - dataset = self._pad_dataset_to_batch_size(requests) + # mesh = haliax.partitioning._get_mesh() - mesh = haliax.partitioning._get_mesh() + if self.tokenizer.pad_token_id is None: + logger.warning("No pad token set. Setting to eos token.") + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - loader = DataLoader( - self.EvalBatch, dataset, max_buffered_batches=1024, mesh=mesh, axis_resources=self.axis_resources - ) + packed_iterator = _pack_requests(requests, self.tokenizer, self.EvalPos, self.max_packed_segments) + # packed_iterator = BackgroundIterator(packed_iterator, max_capacity=1024) - result: list[tuple[float, bool]] = [] - for batch in tqdm(loader, desc="Loglikelihood", unit="ba"): - out_lls, out_correct = self._jit_loglikelihood(self.model, batch) - result.extend((ll.item(), correct.item()) for ll, correct in zip(out_lls.array, out_correct.array)) + # loader = DataLoader( + # self.EvalBatch, dataset, max_buffered_batches=1024, mesh=mesh, axis_resources=self.axis_resources + # ) - assert len(result) >= initial_length - # skip padding - result = result[:initial_length] + result_probs = np.zeros(len(requests)) + result_greedy = np.zeros(len(requests)) + + for batch in tqdm(packed_iterator, desc="Loglikelihood", unit="packed"): + out_ids, out_lls, out_correct = self._jit_loglikelihood(self.model, batch) + # result.extend((ll.item(), correct.item()) for ll, correct in zip(out_lls.array, out_correct.array)) + # -1's are going to be where we had too few sequences to fill a batch + out_ids = np.array(out_ids.array) + out_lls = np.array(out_lls.array) + out_correct = np.array(out_correct.array) + valid_indices = out_ids != -1 + result_probs[out_ids[valid_indices]] = out_lls[valid_indices] + result_greedy[out_ids[valid_indices]] = out_correct[valid_indices] + + result = list(zip(result_probs[:initial_length], result_greedy[:initial_length])) return result @@ -663,7 +725,8 @@ def lm_eval_harness(step: StepInfo, force=False): return lm_eval_harness - if __name__ == "__main__": levanter.config.main(run_eval_harness_main)() print("Done", flush=True) + + diff --git a/src/levanter/utils/background_iterable.py b/src/levanter/utils/background_iterable.py index 11a80f8ec..1a2ec53df 100644 --- 
a/src/levanter/utils/background_iterable.py +++ b/src/levanter/utils/background_iterable.py @@ -34,9 +34,12 @@ def __iter__(self): class BackgroundIterator(Iterator[Ex]): - def __init__(self, producer_fn: Callable[[], Union[Iterator[Ex], AsyncIterator[Ex]]], max_capacity: Optional[int]): + def __init__(self, producer_fn: Callable[[], Iterator[Ex]|AsyncIterator[Ex]]| Iterator[Ex] | AsyncIterator[Ex], max_capacity: Optional[int]): self.max_capacity = max_capacity - self._producer_fn = producer_fn + if not callable(producer_fn): + self._producer_fn = lambda: producer_fn + else: + self._producer_fn = producer_fn self._stop_event = threading.Event() if self.max_capacity is None or self.max_capacity >= 0: diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index be77a1d99..1540353bf 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -42,7 +42,7 @@ def use_cpu_device(): @contextlib.contextmanager def local_cpu_mesh(): - """Temporarily sets the default device to CPU""" + """Temporarily sets the default device to CPU and creates a mesh with a single CPU device""" cpu = jax.local_devices(backend="cpu")[0] mesh = jax.sharding.Mesh( np.array([cpu]).reshape(1, 1, 1), (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL) diff --git a/tests/test_packing.py b/tests/test_packing.py index b916eb688..02e3fffd6 100644 --- a/tests/test_packing.py +++ b/tests/test_packing.py @@ -28,10 +28,12 @@ def test_per_segment_loss(): losses = hax.named(jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0]), Pos) - unique_ids, segment_losses = per_segment_loss(packed, losses, max_segments=3) + Segments = hax.Axis("segments", size=3) - assert list(unique_ids) == [-1, 0, 1] - assert list(segment_losses) == [0.0, 0.6, 0.9] + unique_ids, segment_losses = per_segment_loss(packed, losses, max_Segments=Segments) + + assert list(unique_ids.array) == [-1, 0, 1] + assert list(segment_losses.array) == [0.0, 0.6, 0.9] def test_can_pack_simple_case(): @@ -197,21 +199,24 @@ def test_pack_prompt_completions_exceed_max_buffered_examples(): def test_segment_correct(): # Mock segment_ids and loss_mask Pos = hax.Axis("pos", size=10) + tokens = hax.named(jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), Pos) segment_ids = hax.named(jnp.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2]), Pos) loss_mask = hax.named(jnp.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]), Pos) # Create a packed example attn_mask = AttentionMask.causal().with_segment_ids(segment_ids=segment_ids) - packed_example = LmExample(tokens=None, loss_mask=loss_mask, attn_mask=attn_mask) + packed_example = LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) # Mock correctness array (True for correct, False for incorrect) correct = hax.named(jnp.array([True, True, True, False, True, False, True, True, True, True]), Pos) + max_Segments = hax.Axis("segments", size=4) + # Call the function - unique_ids, segment_correct = per_segment_correct(packed_example, correct, max_segments=4) + unique_ids, segment_correct = per_segment_correct(packed_example, correct, max_Segments) - assert list(unique_ids) == [0, 1, 2, -1] - assert list(segment_correct) == [True, False, True, True] + assert list(unique_ids.array) == [0, 1, 2, -1] + assert list(segment_correct.array) == [True, False, True, True] if __name__ == "__main__": From d4c2d2bfcea52087362ebcaafd3442890997ea01 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 19 Dec 2024 10:47:06 -0800 Subject: [PATCH 06/18] wip --- config/harness/eval_llama3.yaml | 54 
+++++++++++------------ config/harness/harness_nano.yaml | 5 ++- src/levanter/data/loader.py | 4 +- src/levanter/data/packing.py | 3 ++ src/levanter/eval_harness.py | 73 +++++++++++++++++++++++++++++--- 5 files changed, 102 insertions(+), 37 deletions(-) diff --git a/config/harness/eval_llama3.yaml b/config/harness/eval_llama3.yaml index 260620102..d15b84b74 100644 --- a/config/harness/eval_llama3.yaml +++ b/config/harness/eval_llama3.yaml @@ -2,32 +2,32 @@ eval_harness: task_spec: - task: commonsense_qa # 5-way multiple-choice questions based on common-sense, everyday scenarios num_fewshot: 10 - - task: agieval_lsat_ar # 3-shot tests in legal domain - num_fewshot: 3 - - task: arc_easy # 10-shot, four-way MCQ questions involving grade 3-9 basic science - num_fewshot: 10 - - task: arc_challenge # a (harder) version of arc_easy - num_fewshot: 10 - - task: boolq # answer yes/no questions based on a passage - num_fewshot: 10 - - task: copa # use causal reasoning to predict the correct outcome of a given scenario - num_fewshot: 0 - - task: hellaswag # 4-way multiple choice commonsense reasoning dataset - num_fewshot: 0 - task_alias: hellaswag_0shot - - task: hellaswag # 4-way multiple choice commonsense reasoning dataset - num_fewshot: 10 - task_alias: hellaswag_10shot - - task: lambada # predict the endings of text passages - num_fewshot: 0 - - task: openbookqa # 4-way multiple choice question answering task that requires multi-step reasoning - num_fewshot: 0 - - task: piqa # answer questions based on a passage - num_fewshot: 10 - - task: wsc273 # Winograd Schema Challenge - num_fewshot: 0 - - task: winogrande # Winograd challenge, extended to more domains - num_fewshot: 0 +# - task: agieval_lsat_ar # 3-shot tests in legal domain +# num_fewshot: 3 +# - task: arc_easy # 10-shot, four-way MCQ questions involving grade 3-9 basic science +# num_fewshot: 10 +# - task: arc_challenge # a (harder) version of arc_easy +# num_fewshot: 10 +# - task: boolq # answer yes/no questions based on a passage +# num_fewshot: 10 +# - task: copa # use causal reasoning to predict the correct outcome of a given scenario +# num_fewshot: 0 +# - task: hellaswag # 4-way multiple choice commonsense reasoning dataset +# num_fewshot: 0 +# task_alias: hellaswag_0shot +# - task: hellaswag # 4-way multiple choice commonsense reasoning dataset +# num_fewshot: 10 +# task_alias: hellaswag_10shot +# - task: lambada # predict the endings of text passages +# num_fewshot: 0 +# - task: openbookqa # 4-way multiple choice question answering task that requires multi-step reasoning +# num_fewshot: 0 +# - task: piqa # answer questions based on a passage +# num_fewshot: 10 +# - task: wsc273 # Winograd Schema Challenge +# num_fewshot: 0 +# - task: winogrande # Winograd challenge, extended to more domains +# num_fewshot: 0 # requires generation ## - task: squadv2 # reading comprehension benchmark # num_fewshot: 10 @@ -39,7 +39,7 @@ model: checkpoint_path: meta-llama/Meta-Llama-3-8B checkpoint_is_hf: true trainer: - mp: f32 + mp: p=f32,c=bfloat16 profiler: true per_device_parallelism: -1 diff --git a/config/harness/harness_nano.yaml b/config/harness/harness_nano.yaml index 833291e5c..109cfd152 100644 --- a/config/harness/harness_nano.yaml +++ b/config/harness/harness_nano.yaml @@ -1,5 +1,8 @@ eval_harness: - task_spec: ["hellaswag"] +# task_spec: ["hellaswag"] + task_spec: + - task: commonsense_qa # 5-way multiple-choice questions based on common-sense, everyday scenarios + num_fewshot: 1 tokenizer: "gpt2" model: type: gpt2 diff --git 
a/src/levanter/data/loader.py b/src/levanter/data/loader.py index 928c9456c..dc87e549d 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -180,7 +180,7 @@ def get_local_batch(begin: int, end: int) -> list: # TODO: if we ever do "big data" (i.e. huge examples) we might want to be able to load part of an example # which will require support from the datastore (i.e. tensorstore) - device_batch = _stack_tree(self.dl.Batch.name, [data_for_this_batch[i] for i in range(begin, end)]) + device_batch = stack_tree(self.dl.Batch.name, [data_for_this_batch[i] for i in range(begin, end)]) batch_leaves = hax.tree_util.tree_leaves(device_batch) cache[(begin, end)] = batch_leaves @@ -268,7 +268,7 @@ def _fill_queue_with_batches(self): @functools.partial(jax.jit, static_argnums=(0,)) -def _stack_tree(batch_name, individual_datums): +def stack_tree(batch_name, individual_datums): def _stack_leaves_unchecked(*leaves): if is_named_array(leaves[0]): return hax.stack(batch_name, leaves) diff --git a/src/levanter/data/packing.py b/src/levanter/data/packing.py index e3f385a1a..71bc8b35e 100644 --- a/src/levanter/data/packing.py +++ b/src/levanter/data/packing.py @@ -98,6 +98,7 @@ def pack_prompt_completions( for sequence in sequences: loss_mask = np.arange(len(sequence.ids)) >= sequence.prompt_length - 1 loss_mask[-1] = 0 + assert np.any(loss_mask) for packer in packers: if packer.can_pack(sequence.ids): @@ -156,6 +157,7 @@ def per_segment_loss( return unique_segment_ids, segment_losses + def _unique_segment_ids(max_Segments, segment_ids): # Extract unique segment IDs with padding # TODO: add unique to haliax @@ -163,6 +165,7 @@ def _unique_segment_ids(max_Segments, segment_ids): unique_segment_ids = hax.named(unique_segment_ids, max_Segments) return unique_segment_ids + def per_segment_correct( packed_example: LmExample, correct: hax.NamedArray, max_Segments: hax.Axis ) -> tuple[hax.NamedArray, hax.NamedArray]: diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index b495607e6..49941e5c8 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -21,6 +21,7 @@ import json import logging import tempfile +import time import typing from dataclasses import dataclass from functools import cached_property @@ -31,13 +32,16 @@ import jax.numpy as jnp import jmp import numpy as np +from optax.tree_utils import tree_zeros_like import haliax from haliax import NamedArray import levanter.tracker from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer +from levanter.data.loader import stack_tree from levanter.data.packing import PromptCompletion, pack_prompt_completions, per_segment_correct, per_segment_loss +from levanter.models.attention import AttentionMask from levanter.models.gpt2 import Gpt2Config from levanter.models.loss import next_token_loss from levanter.utils.background_iterable import BackgroundIterator @@ -60,7 +64,7 @@ import levanter.config from levanter.checkpoint import load_checkpoint -from levanter.data import AsyncDataset +from levanter.data import AsyncDataset, batched from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel from levanter.trainer import StepInfo, TrainerConfig from levanter.utils.jax_utils import use_cpu_device @@ -243,6 +247,11 @@ def _eval_loglikelihood( losses = hax.flatten(batched_per_segment_losses, "segment") correct = hax.flatten(batched_per_segment_correct, "segment") + jax.debug.inspect_array_sharding( + batched_segment_ids, callback=lambda x: print(f"batched Segment ids: 
{x}") + ) + jax.debug.inspect_array_sharding(batched_segment_ids, callback=lambda x: print(f"Segment ids: {x}")) + return segments, -losses, correct # no sharded outputs @@ -258,7 +267,6 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: """ # pad requests to be a multiple of the batch size initial_length = len(requests) - # mesh = haliax.partitioning._get_mesh() if self.tokenizer.pad_token_id is None: logger.warning("No pad token set. Setting to eos token.") @@ -269,17 +277,61 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: result_probs = np.zeros(len(requests)) result_greedy = np.zeros(len(requests)) + covered_points = np.zeros(len(requests), dtype=bool) + + time_in = time.time() + for q, batch in enumerate( + tqdm(batched(packed_iterator, self.EvalBatch.size), desc="Loglikelihood", unit="ba") + ): + segments_this_batch = set() + for i in range(len(batch)): + segments_this_batch.update(np.unique(batch[i].attn_mask.segment_ids.array).tolist()) + + try: + segments_this_batch.remove(-1) + except KeyError: + pass + + orig_batch_len = len(batch) + print(f"{q} {jax.process_index()} tokens: {np.array(batch[0].tokens.array).tolist()}") + # print(f"{q} {jax.process_index()} mask: {np.array(batch[0].loss_mask.array).tolist()}") + print( + f"{q} {jax.process_index()} attn: {np.unique(np.array(batch[0].attn_mask.segment_ids.array)).tolist()}" + ) + if len(batch) < self.EvalBatch.size: + dummy_instance = self._make_dummy_instance(batch) + batch.extend([dummy_instance] * (self.EvalBatch.size - len(batch))) + + stacked = stack_tree(self.EvalBatch, batch) + stacked = hax.shard(stacked, self.axis_resources) + time_batch = time.time() - for batch in tqdm(packed_iterator, desc="Loglikelihood", unit="packed"): - out_ids, out_lls, out_correct = self._jit_loglikelihood(self.model, batch) + out_ids, out_lls, out_correct = self._jit_loglikelihood(self.model, stacked) # result.extend((ll.item(), correct.item()) for ll, correct in zip(out_lls.array, out_correct.array)) # -1's are going to be where we had too few sequences to fill a batch - out_ids = np.array(out_ids.array) - out_lls = np.array(out_lls.array) - out_correct = np.array(out_correct.array) + out_ids = np.array(out_ids.array)[0 : orig_batch_len * self.max_packed_segments] + out_lls = np.array(out_lls.array)[0 : orig_batch_len * self.max_packed_segments] + out_correct = np.array(out_correct.array)[0 : orig_batch_len * self.max_packed_segments] valid_indices = out_ids != -1 + + out_ids_this_batch = out_ids[valid_indices].tolist() + + assert len(out_ids_this_batch) == len( + segments_this_batch + ), f"Batch {q} had {len(segments_this_batch)} segments, but {len(out_ids_this_batch)} loglikelihoods" + result_probs[out_ids[valid_indices]] = out_lls[valid_indices] result_greedy[out_ids[valid_indices]] = out_correct[valid_indices] + covered_points[out_ids[valid_indices]] = True + + time_ll = time.time() + + if jax.process_index() == 0: + print(f"Batch time: {time_batch - time_in}, LL time: {time_ll - time_batch}") + time_in = time.time() + + missing_points = np.where(~covered_points)[0] + assert len(missing_points) == 0, f"Missing points: {missing_points}" result = list(zip(result_probs[:initial_length], result_greedy[:initial_length])) @@ -287,6 +339,13 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: return result + def _make_dummy_instance(self, batch): + dummy_instance: LmExample = tree_zeros_like(batch[0]) + dummy_segment_mask = hax.full(self.EvalPos, -1, 
dtype=jnp.int32) + dummy_attn = AttentionMask.causal().with_segment_ids(dummy_segment_mask) + dummy_instance = dataclasses.replace(dummy_instance, attn_mask=dummy_attn) + return dummy_instance + def _pad_dataset_to_batch_size(self, requests): dummy_instance = dataclasses.replace(requests[0], arguments=("hello", " there"), idx=len(requests)) requests = requests + [dummy_instance] * (self.EvalBatch.size - len(requests) % self.EvalBatch.size) From 3103c1df40dbe53deee10da57c71845b1772b2d2 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 23 Dec 2024 15:55:58 -0800 Subject: [PATCH 07/18] wip --- scripts/gcs_bulk_delete.py | 17 +++++++------ src/levanter/eval_harness.py | 47 +++++++++++++++++++++++++++--------- 2 files changed, 45 insertions(+), 19 deletions(-) diff --git a/scripts/gcs_bulk_delete.py b/scripts/gcs_bulk_delete.py index 564e3cd60..4a3b2d546 100644 --- a/scripts/gcs_bulk_delete.py +++ b/scripts/gcs_bulk_delete.py @@ -33,14 +33,14 @@ def schedule_gcs_deletion_job(project_id, gcs_bucket_name, path_to_delete): gcs_data_source=storage_transfer_v1.types.GcsData(bucket_name=EMPTY_BUCKET), transfer_options=storage_transfer_v1.types.TransferOptions(delete_objects_unique_in_sink=True), ), - schedule=storage_transfer_v1.types.Schedule( - schedule_start_date=Date( - year=datetime.utcnow().year, month=datetime.utcnow().month, day=datetime.utcnow().day - ), - start_time_of_day=TimeOfDay( - hours=datetime.utcnow().hour, minutes=datetime.utcnow().minute + 2 # Start in 2 minutes - ), - ), + # schedule=storage_transfer_v1.types.Schedule( + # schedule_start_date=Date( + # year=datetime.utcnow().year, month=datetime.utcnow().month, day=datetime.utcnow().day + # ), + # start_time_of_day=TimeOfDay( + # hours=datetime.utcnow().hour, minutes=datetime.utcnow().minute + 2 # Start in 2 minutes + # ), + # ), status=storage_transfer_v1.types.TransferJob.Status.ENABLED, description=f"Delete all files in {gcs_bucket_name}/{path_to_delete}", ) @@ -48,6 +48,7 @@ def schedule_gcs_deletion_job(project_id, gcs_bucket_name, path_to_delete): # Create the transfer job response = client.create_transfer_job(request={"transfer_job": transfer_job}) print(f"Created transfer job: {response.name}") + client.run_transfer_job({"job_name": response.name, "project_id": project_id}) # Wait for job completion wait_for_transfer_job(response.name, timeout=3600, poll_interval=2, project_id=project_id) diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index 49941e5c8..7c7016154 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -250,7 +250,7 @@ def _eval_loglikelihood( jax.debug.inspect_array_sharding( batched_segment_ids, callback=lambda x: print(f"batched Segment ids: {x}") ) - jax.debug.inspect_array_sharding(batched_segment_ids, callback=lambda x: print(f"Segment ids: {x}")) + jax.debug.inspect_array_sharding(segments, callback=lambda x: print(f"Segment ids: {x}")) return segments, -losses, correct @@ -268,11 +268,26 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: # pad requests to be a multiple of the batch size initial_length = len(requests) + # so, infuriatingly, lm_eval_harness (or maybe it's hf datasets) isn't deterministic + # and so when this gets called from different workers, we get different request orderings. + # (Requests should be the same?!?) + # So we need to sort them to make sure they're in the same order. + # we can't trust the idx, so we hash the args. 
we have to be able to unsort, so we need an argsort + # built-in hash isn't deterministic, so we have to hash it ourselves + import hashlib + indices = np.argsort([hashlib.md5(json.dumps(req.args).encode()).digest() for req in requests]) + inverse = np.argsort(indices) + print(f"{jax.process_index()} {indices}") + + requests_for_packing = [requests[i] for i in indices] + + print(f"{jax.process_index()} {indices}") + if self.tokenizer.pad_token_id is None: logger.warning("No pad token set. Setting to eos token.") self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - packed_iterator = _pack_requests(requests, self.tokenizer, self.EvalPos, self.max_packed_segments) + packed_iterator = _pack_requests(requests_for_packing, self.tokenizer, self.EvalPos, self.max_packed_segments) packed_iterator = BackgroundIterator(packed_iterator, max_capacity=1024) result_probs = np.zeros(len(requests)) @@ -292,33 +307,40 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: except KeyError: pass + # compute a checksum of this batch so we can compare to the other workers + batch_checksum = hashlib.md5(json.dumps([np.array(batch[i].tokens.array).tolist() for i in range(len(batch))]).encode()).digest() + print(f"check {q} {jax.process_index()} {batch_checksum}") + orig_batch_len = len(batch) - print(f"{q} {jax.process_index()} tokens: {np.array(batch[0].tokens.array).tolist()}") + # print(f"{q} {jax.process_index()} tokens: {np.array(batch[0].tokens.array).tolist()}") # print(f"{q} {jax.process_index()} mask: {np.array(batch[0].loss_mask.array).tolist()}") print( - f"{q} {jax.process_index()} attn: {np.unique(np.array(batch[0].attn_mask.segment_ids.array)).tolist()}" + f"{q} {jax.process_index()} seg: {segments_this_batch} {len(segments_this_batch)} {len(batch)} {len(requests)}" ) if len(batch) < self.EvalBatch.size: dummy_instance = self._make_dummy_instance(batch) batch.extend([dummy_instance] * (self.EvalBatch.size - len(batch))) stacked = stack_tree(self.EvalBatch, batch) - stacked = hax.shard(stacked, self.axis_resources) + # stacked = hax.shard(stacked, self.axis_resources) time_batch = time.time() out_ids, out_lls, out_correct = self._jit_loglikelihood(self.model, stacked) # result.extend((ll.item(), correct.item()) for ll, correct in zip(out_lls.array, out_correct.array)) # -1's are going to be where we had too few sequences to fill a batch - out_ids = np.array(out_ids.array)[0 : orig_batch_len * self.max_packed_segments] - out_lls = np.array(out_lls.array)[0 : orig_batch_len * self.max_packed_segments] - out_correct = np.array(out_correct.array)[0 : orig_batch_len * self.max_packed_segments] + out_ids = np.array(out_ids.array) + out_lls = np.array(out_lls.array) + out_correct = np.array(out_correct.array) valid_indices = out_ids != -1 out_ids_this_batch = out_ids[valid_indices].tolist() - assert len(out_ids_this_batch) == len( - segments_this_batch - ), f"Batch {q} had {len(segments_this_batch)} segments, but {len(out_ids_this_batch)} loglikelihoods" + missing_ids = set(segments_this_batch) - set(out_ids_this_batch) + + # assert len(out_ids_this_batch) == len( + # segments_this_batch + # ), f"Batch {q} had {len(segments_this_batch)} segments, but {len(out_ids_this_batch)} loglikelihoods" + assert len(missing_ids) == 0, f"Missing segments: {missing_ids}" result_probs[out_ids[valid_indices]] = out_lls[valid_indices] result_greedy[out_ids[valid_indices]] = out_correct[valid_indices] @@ -335,6 +357,9 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, 
bool]]: result = list(zip(result_probs[:initial_length], result_greedy[:initial_length])) + # unsort the results + result = [result[i] for i in inverse] + logger.info(f"Finished running {len(requests)} loglikelihoods.") return result From 721a5aa428003cd7a9b60f6eaea51b98de0177b3 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 26 Dec 2024 10:09:06 -0800 Subject: [PATCH 08/18] wip got a crash from jax --- src/levanter/eval_harness.py | 343 +++++++++++------------------------ 1 file changed, 109 insertions(+), 234 deletions(-) diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index 7c7016154..135dde6e0 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -17,7 +17,6 @@ """ import dataclasses -import functools import json import logging import tempfile @@ -25,7 +24,7 @@ import typing from dataclasses import dataclass from functools import cached_property -from typing import Iterator, List, Optional, Sequence, Tuple +from typing import Iterator, List, Optional, Tuple import equinox as eqx import jax @@ -64,10 +63,10 @@ import levanter.config from levanter.checkpoint import load_checkpoint -from levanter.data import AsyncDataset, batched +from levanter.data import batched from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel from levanter.trainer import StepInfo, TrainerConfig -from levanter.utils.jax_utils import use_cpu_device +from levanter.utils.jax_utils import jnp_broadcast_one_to_all, use_cpu_device from levanter.utils.tree_utils import inference_mode @@ -110,97 +109,37 @@ def _pack_requests( ) -class EvalDataset(AsyncDataset[LmExample]): - def __init__(self, Pos, tokenizer, examples: list[Instance]): - super().__init__() - self.examples = examples - self.Pos = Pos - self.tokenizer = tokenizer - - async def async_len(self) -> int: - return len(self.examples) - - async def final_length_is_known(self) -> bool: - return True - - def is_finite(self) -> bool: - return True - - async def current_len(self) -> Optional[int]: - return len(self.examples) - - async def get_batch(self, indices: Sequence[int]) -> List[LmExample]: - out = [] - pad_token_id = self.tokenizer.pad_token_id - - # lm-harness specs that args are (context, completion) - reqs = [(self.examples[i].args[0], self.examples[i].args[1]) for i in indices] - - for context, completion in reqs: - # it's kinda annoying we run tokenization twice, but it's the easiest way to get the prompt length - # CF: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/api/model.py#L354 - whole_enc = self.tokenizer(context + completion) - context_enc = self.tokenizer(context) - - context_enc_len = len(context_enc["input_ids"]) - - tokens, length = self._truncate_or_pad(whole_enc, context_enc_len) - example = _jit_create_example(self.Pos, tokens, length, pad_token_id) - - out.append(example) - - return out - - def _truncate_or_pad(self, encoded: list[int], prompt_length: int): - """ - Truncate or pad the encoded sequence to the maximum length of the model. - Truncates from the beginning of the sequence, so that the completion is preserved. - - Returns: - Truncated or padded sequence and the prompt length. The prompt length can be shorter than the original - length if the input was truncated. - """ - if self.tokenizer.pad_token_id is None: - self.tokenizer.pad_token_id = self.tokenizer.eos_token_id +# OK, so LM-Eval-Harness is not deterministic. This means we can't just run it on different workers and expect the +# order of requests to be the same. 
Sorting doesn't even seem to be correct (?!?!?) so we need to only run it on one +# process. +# This is our design: +# 1. Process 0 creates an LevanterHarnessLM object. +# 2. On all processes, we start a loop that waits for a request using jnp_broadcast_one_to_all +# 3. When a request is received (and it's not STOP) we process the request. The results are broadcast to all +# devices, and process 0 records htem. +# 4. When a STOP request is received, we stop the loop and process 0 returns the results. - ex_pad = self.tokenizer.pad( - encoded, - padding="max_length", - max_length=self.Pos.size, - return_tensors="np", - ) - - truncated = ex_pad["input_ids"][-self.Pos.size :] - # if we truncated the prompt, we need to adjust the prompt length - if len(truncated) < len(encoded): - prompt_length -= len(encoded) - len(truncated) - if prompt_length < 0: - prompt_length = 0 - logger.warning("Prompt length is negative after truncation. Setting to 0.") - return truncated, prompt_length +def _make_dummy_batch(EvalBatch, EvalPos): + dummy_batch = hax.vmap(LmExample.causal, EvalBatch)(hax.zeros(EvalPos, dtype=jnp.int32), + loss_mask=hax.zeros(EvalPos, dtype=jnp.int32), + segment_ids=hax.zeros(EvalPos, dtype=jnp.int32)) + return dummy_batch -class LevanterHarnessLM(LM): - def __init__( - self, - EvalBatch: hax.Axis, - EvalPos: hax.Axis, - model: LmHeadModel, - axis_resources, - tokenizer, - mp: jmp.Policy | None, - max_packed_segments: int = 64, - ): - super().__init__() +class _LmEvalHarnessWorker: + def __init__(self, EvalBatch, EvalPos, model, axis_resources, tokenizer, mp, max_packed_segments): + self.tokenizer = tokenizer + self.max_packed_segments = max_packed_segments self.EvalBatch = EvalBatch self.EvalPos = EvalPos self.model = model self.axis_resources = axis_resources - self.tokenizer = tokenizer self.mp = mp self.max_packed_segments = max_packed_segments + self._dummy_batch = _make_dummy_batch(EvalBatch, EvalPos) + def _eval_loglikelihood( model: LmHeadModel, packed_example: LmExample ) -> tuple[NamedArray, NamedArray, NamedArray]: @@ -254,40 +193,89 @@ def _eval_loglikelihood( return segments, -losses, correct + # def _do_message(model, message, maybe_batch): + # dummy_result = eqx.filter_eval_shape(_eval_loglikelihood, self._dummy_batch) + # dummy_result = tree_zeros_like(dummy_result) + # + # message, my_batch = + # + # jax.lax.cond(message == _Message.LOGLIKELIHOOD, + # self._jit_loglikelihood, + # lambda *args: dummy_result, + # model, maybe_batch + # ) + + + # no sharded outputs self._jit_loglikelihood = hax.named_jit( _eval_loglikelihood, axis_resources=axis_resources, out_axis_resources={} ) + def make_harness_lm(self): + if jax.process_index() == 0: + return LevanterHarnessLM(self) + else: + raise ValueError("Only process 0 can create the harness") + + def worker_message_loop(self): + while True: + message, payload = self._receive_message() + + if message == _Message.STOP: + return + elif message == _Message.LOGLIKELIHOOD: + self.process_loglikelihood(payload) + else: + raise ValueError(f"Unknown message type: {message}") + + def _receive_message(self): + stop_message = jnp.array(_Message.STOP) + mesh = hax.partitioning._get_mesh() + stop_message = jax.device_put(stop_message, jax.NamedSharding(mesh, jax.sharding.PartitionSpec())) + message, payload = jnp_broadcast_one_to_all((stop_message, self._dummy_batch)) + return message.item(), payload + + def _send_message(self, message, payload): + assert jax.process_index() == 0 + return jnp_broadcast_one_to_all((message, payload)) + + def 
process_loglikelihood(self, packed_request): + return self._jit_loglikelihood(self.model, packed_request) + + def dispatch_loglikelihood(self, packed_request): + self._send_message(_Message.LOGLIKELIHOOD, packed_request) + return self.process_loglikelihood(packed_request) + + def stop(self): + self._send_message(_Message.STOP, self._dummy_batch) + + +class _Message: + STOP = 0 + LOGLIKELIHOOD = 1 + + +class LevanterHarnessLM(LM): + def __init__(self, leader: _LmEvalHarnessWorker): + super().__init__() + self.leader = leader + + tokenizer = property(lambda self: self.leader.tokenizer) + EvalBatch = property(lambda self: self.leader.EvalBatch) + EvalPos = property(lambda self: self.leader.EvalPos) + def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: """ Compute log-likelihood of generating a continuation from a context. Downstream tasks should attempt to use loglikelihood instead of other LM calls whenever possible. """ - # pad requests to be a multiple of the batch size - initial_length = len(requests) - - # so, infuriatingly, lm_eval_harness (or maybe it's hf datasets) isn't deterministic - # and so when this gets called from different workers, we get different request orderings. - # (Requests should be the same?!?) - # So we need to sort them to make sure they're in the same order. - # we can't trust the idx, so we hash the args. we have to be able to unsort, so we need an argsort - # built-in hash isn't deterministic, so we have to hash it ourselves - import hashlib - indices = np.argsort([hashlib.md5(json.dumps(req.args).encode()).digest() for req in requests]) - inverse = np.argsort(indices) - print(f"{jax.process_index()} {indices}") - - requests_for_packing = [requests[i] for i in indices] - - print(f"{jax.process_index()} {indices}") - if self.tokenizer.pad_token_id is None: logger.warning("No pad token set. 
Setting to eos token.") self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - packed_iterator = _pack_requests(requests_for_packing, self.tokenizer, self.EvalPos, self.max_packed_segments) + packed_iterator = _pack_requests(requests, self.tokenizer, self.EvalPos, self.leader.max_packed_segments) packed_iterator = BackgroundIterator(packed_iterator, max_capacity=1024) result_probs = np.zeros(len(requests)) @@ -307,16 +295,6 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: except KeyError: pass - # compute a checksum of this batch so we can compare to the other workers - batch_checksum = hashlib.md5(json.dumps([np.array(batch[i].tokens.array).tolist() for i in range(len(batch))]).encode()).digest() - print(f"check {q} {jax.process_index()} {batch_checksum}") - - orig_batch_len = len(batch) - # print(f"{q} {jax.process_index()} tokens: {np.array(batch[0].tokens.array).tolist()}") - # print(f"{q} {jax.process_index()} mask: {np.array(batch[0].loss_mask.array).tolist()}") - print( - f"{q} {jax.process_index()} seg: {segments_this_batch} {len(segments_this_batch)} {len(batch)} {len(requests)}" - ) if len(batch) < self.EvalBatch.size: dummy_instance = self._make_dummy_instance(batch) batch.extend([dummy_instance] * (self.EvalBatch.size - len(batch))) @@ -325,7 +303,7 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: # stacked = hax.shard(stacked, self.axis_resources) time_batch = time.time() - out_ids, out_lls, out_correct = self._jit_loglikelihood(self.model, stacked) + out_ids, out_lls, out_correct = self.leader.dispatch_loglikelihood(stacked) # result.extend((ll.item(), correct.item()) for ll, correct in zip(out_lls.array, out_correct.array)) # -1's are going to be where we had too few sequences to fill a batch out_ids = np.array(out_ids.array) @@ -355,11 +333,7 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: missing_points = np.where(~covered_points)[0] assert len(missing_points) == 0, f"Missing points: {missing_points}" - result = list(zip(result_probs[:initial_length], result_greedy[:initial_length])) - - # unsort the results - result = [result[i] for i in inverse] - + result = list(zip(result_probs, result_greedy)) logger.info(f"Finished running {len(requests)} loglikelihoods.") return result @@ -371,13 +345,6 @@ def _make_dummy_instance(self, batch): dummy_instance = dataclasses.replace(dummy_instance, attn_mask=dummy_attn) return dummy_instance - def _pad_dataset_to_batch_size(self, requests): - dummy_instance = dataclasses.replace(requests[0], arguments=("hello", " there"), idx=len(requests)) - requests = requests + [dummy_instance] * (self.EvalBatch.size - len(requests) % self.EvalBatch.size) - assert len(requests) % self.EvalBatch.size == 0 - dataset = EvalDataset(self.EvalPos, self.tokenizer, requests) - return dataset - def loglikelihood_rolling(self, requests) -> List[Tuple[float]]: raise NotImplementedError() @@ -385,12 +352,6 @@ def generate_until(self, requests) -> List[str]: raise NotImplementedError() -@functools.partial(jax.jit, static_argnums=(0, 3)) -def _jit_create_example(Pos, tokens, prompt_len, pad_token_id): - tokens = hax.named(tokens, Pos) - return LmExample.from_prompt_and_completion(Pos, tokens, prompt_len, ignore_id=pad_token_id) - - @dataclass(frozen=True) class TaskConfig: """ @@ -555,15 +516,23 @@ def _actually_run_eval_harness( f"Evaluating with max eval length {EvalPos.size} and batch size {EvalBatch.size}. 
There are" f" {num_parameters} parameters in the model." ) - harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer, mp) logger.info("Running eval harness...") - outputs = evaluator.evaluate( - harness, - tasks_to_run, - limit=max_examples, - log_samples=config.log_samples, - bootstrap_iters=config.bootstrap_iters, - ) + + worker = _LmEvalHarnessWorker(EvalBatch, EvalPos, model, axis_resources, tokenizer, mp, max_packed_segments=64) + + if jax.process_index() == 0: + harness = worker.make_harness_lm() + outputs = evaluator.evaluate( + harness, + tasks_to_run, + limit=max_examples, + log_samples=config.log_samples, + bootstrap_iters=config.bootstrap_iters, + ) + worker.stop() + else: + worker.worker_message_loop() + logger.info("Finished running eval harness.") averages = _compute_averages(outputs) @@ -618,100 +587,6 @@ def _compute_averages(outputs): return averages -BITS_PER_NAT = 1 / np.log(2) - -# eval_harness isn't consistent enough for this to actually be workable -# def _compute_extra_metrics(samples): -# """ -# Compute a few "soft" measures of accuracy for each task, based on the outputs of the eval harness. -# -# Specifically, we compute: -# - "bpb": bits per byte of the correct completion -# - "logprob": log probability of the correct completion -# - "choice_logprob": log probability of the correct choice normalized w.r.t. the other choices -# - "choice_prob_norm": probability of the length-normalized correct choice normalized w.r.t. the other choices -# -# Args: -# samples: Dictionary with task data, where each task has a list of samples. Each sample contains: -# - "doc": The original task document (can include metadata such as the answer key) -# - "target": Index of the correct answer (0-indexed), or -# "doc.answer" for 1-indexed answers. -# - "arguments": List of [input, completion] pairs -# - "resps": List of [log probability, is_correct] pairs for completions -# -# Returns: -# A dictionary with per-task aggregated metrics. -# """ -# # TODO: move to eval harness and use more sane logic -# # uses the samples which has one of two structures (that I've seen) -# # { "": [ {"doc": {...,}, "target": <0-indexed answer>, "arguments": [[input, completion], "resps": [[score, is_correct], ...], ...}, ...] } -# # { "": [ {"doc": {..., "answer": "[1-indexed answer]"}, "target": "", "arguments": [input, completion], "resps": [[score, is_correct], ...], ...}, ...] } -# metrics = {} -# -# for task, samples in samples.items(): -# bpb_list = [] -# logprob_list = [] -# choice_logprob_list = [] -# choice_prob_norm_list = [] -# -# for sample in samples: -# # Extract the correct answer index (supporting both 0-indexed `target` and 1-indexed `doc.answer`) -# if "answer" in sample["doc"]: -# target = int(sample["doc"]["answer"]) - 1 # Convert 1-indexed to 0-indexed -# elif "label" in sample["doc"]: -# target = int(sample["doc"]["label"]) -# elif "target" in sample and isinstance(sample["target"], int): -# target = sample["target"] # 0-indexed target -# elif "target" in sample and isinstance(sample["target"], str): -# # see if it's A-Z: -# if len(sample["target"]) == 1 and "A" <= sample["target"] <= "Z": -# target = ord(sample["target"]) - ord("A") -# else: -# raise ValueError(f"Invalid target: {sample['target']}. {sample}") -# elif "target" in sample and isinstance(sample["target"], list): -# target = sample["target"][0] -# else: -# raise KeyError(f"Missing `target` or `doc.answer` in sample. doc id: {sample['doc_id']}. 
Hash: {sample['doc_hash']}\n\n{sample}") -# -# resps = sample["filtered_resps"] # List of [log probability, is_correct] -# arguments = sample["arguments"] # [input, completion] pairs -# -# # Compute byte lengths for each choice -# byte_lengths = [max(1, len(completion.encode("utf-8"))) for _, completion in arguments] -# -# # Compute log probabilities for each choice -# log_probs = np.array([resp[0] for resp in resps]) # Extract log probabilities -# assert log_probs.shape == (len(arguments),), f"Log probs shape: {log_probs.shape}, arguments: {len(arguments)}. doc: {sample}" -# normalized_log_probs = log_probs - np.logaddexp.reduce(log_probs) -# -# # Metrics for the correct answer -# correct_logprob = log_probs[target] -# correct_bpb = -correct_logprob / byte_lengths[target] * NAT_TO_BIT -# correct_choice_logprob = normalized_log_probs[target] -# -# # Compute length-normalized weights (w_i) -# bpb_values = -log_probs / np.array(byte_lengths) * NAT_TO_BIT -# bpb_weights = np.exp(-bpb_values) -# bpb_weights /= max(bpb_weights.sum(), 1e-8) # Avoid division by zero -# correct_choice_prob_norm = bpb_weights[target] -# -# # Append metrics -# bpb_list.append(correct_bpb) -# logprob_list.append(correct_logprob) -# choice_logprob_list.append(correct_choice_logprob) -# choice_prob_norm_list.append(correct_choice_prob_norm) -# -# # Aggregate metrics for the task -# metrics[task] = { -# "bpb": np.mean(bpb_list) if bpb_list else 0.0, -# "logprob": np.mean(logprob_list) if logprob_list else 0.0, -# "choice_logprob": np.mean(choice_logprob_list) if choice_logprob_list else 0.0, -# "choice_prob_norm": np.mean(choice_prob_norm_list) if choice_prob_norm_list else 0.0, -# } -# -# return metrics - - def run_eval_harness_main(config: EvalHarnessMainConfig): config.trainer.initialize() tokenizer = config.the_tokenizer From a4d1df9c31790ff619dcacb4b580440ba35c9867 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 29 Dec 2024 22:02:42 -0800 Subject: [PATCH 09/18] ok maybe this works? 
--- src/levanter/data/packing.py | 10 +- src/levanter/eval_harness.py | 156 +++++++++++++++++++------------- src/levanter/models/lm_model.py | 4 +- src/levanter/utils/jax_utils.py | 60 ++++++++++++ 4 files changed, 164 insertions(+), 66 deletions(-) diff --git a/src/levanter/data/packing.py b/src/levanter/data/packing.py index 71bc8b35e..612afec15 100644 --- a/src/levanter/data/packing.py +++ b/src/levanter/data/packing.py @@ -92,6 +92,7 @@ def pack_prompt_completions( """ Packs a list of prompt completions into LmExamples using the SequencePacker """ + in_ids = set() packers = [SequencePacker(Pos, max_segments_per_example, pad_token)] @@ -99,12 +100,14 @@ def pack_prompt_completions( loss_mask = np.arange(len(sequence.ids)) >= sequence.prompt_length - 1 loss_mask[-1] = 0 assert np.any(loss_mask) + in_ids.add(sequence.segment_id) for packer in packers: if packer.can_pack(sequence.ids): packer.add_example(sequence.ids, loss_mask, sequence.segment_id) if packer.num_segments == max_segments_per_example: + in_ids -= set(packers[0]._segment_ids) yield packer.pack() packers.remove(packer) break @@ -115,12 +118,15 @@ def pack_prompt_completions( packers.append(packer) while len(packers) >= max_buffered_examples: - yield packer.pack() - packers.pop(0) + in_ids -= set(packers[0]._segment_ids) + yield packers.pop(0).pack() for packer in packers: + in_ids -= set(packer._segment_ids) yield packer.pack() + assert not in_ids, "Some segments were not packed" + def per_segment_loss( packed_example: LmExample, losses: hax.NamedArray, max_Segments: hax.Axis diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index 135dde6e0..c4baa518d 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -31,6 +31,7 @@ import jax.numpy as jnp import jmp import numpy as np +from jax.sharding import PartitionSpec from optax.tree_utils import tree_zeros_like import haliax @@ -66,7 +67,7 @@ from levanter.data import batched from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel from levanter.trainer import StepInfo, TrainerConfig -from levanter.utils.jax_utils import jnp_broadcast_one_to_all, use_cpu_device +from levanter.utils.jax_utils import broadcast_shard, use_cpu_device from levanter.utils.tree_utils import inference_mode @@ -121,10 +122,13 @@ def _pack_requests( def _make_dummy_batch(EvalBatch, EvalPos): - dummy_batch = hax.vmap(LmExample.causal, EvalBatch)(hax.zeros(EvalPos, dtype=jnp.int32), - loss_mask=hax.zeros(EvalPos, dtype=jnp.int32), - segment_ids=hax.zeros(EvalPos, dtype=jnp.int32)) - return dummy_batch + dummy_batch = hax.vmap(LmExample.causal, EvalBatch)( + hax.zeros(EvalPos, dtype=jnp.int32), + loss_mask=hax.zeros(EvalPos, dtype=jnp.int32), + segment_ids=hax.zeros(EvalPos, dtype=jnp.int32), + ) + out = hax.shard(dummy_batch, {}) + return out class _LmEvalHarnessWorker: @@ -186,11 +190,6 @@ def _eval_loglikelihood( losses = hax.flatten(batched_per_segment_losses, "segment") correct = hax.flatten(batched_per_segment_correct, "segment") - jax.debug.inspect_array_sharding( - batched_segment_ids, callback=lambda x: print(f"batched Segment ids: {x}") - ) - jax.debug.inspect_array_sharding(segments, callback=lambda x: print(f"Segment ids: {x}")) - return segments, -losses, correct # def _do_message(model, message, maybe_batch): @@ -205,8 +204,6 @@ def _eval_loglikelihood( # model, maybe_batch # ) - - # no sharded outputs self._jit_loglikelihood = hax.named_jit( _eval_loglikelihood, axis_resources=axis_resources, out_axis_resources={} @@ -220,35 +217,51 @@ 
def make_harness_lm(self): def worker_message_loop(self): while True: - message, payload = self._receive_message() + message = self._receive_message() if message == _Message.STOP: return elif message == _Message.LOGLIKELIHOOD: + payload = self._receive_payload() self.process_loglikelihood(payload) else: raise ValueError(f"Unknown message type: {message}") def _receive_message(self): stop_message = jnp.array(_Message.STOP) - mesh = hax.partitioning._get_mesh() - stop_message = jax.device_put(stop_message, jax.NamedSharding(mesh, jax.sharding.PartitionSpec())) - message, payload = jnp_broadcast_one_to_all((stop_message, self._dummy_batch)) - return message.item(), payload + message = broadcast_shard(stop_message, PartitionSpec()) + return message.item() + + def _receive_payload(self): + payload = broadcast_shard( + self._dummy_batch, + hax.partitioning.infer_resource_partitions(self._dummy_batch, preserve_existing_shardings=False), + ) + return payload + + def _send_message(self, message): + assert jax.process_index() == 0 + out = broadcast_shard(jnp.array(message), PartitionSpec()) + return out - def _send_message(self, message, payload): + def _send_payload(self, payload): assert jax.process_index() == 0 - return jnp_broadcast_one_to_all((message, payload)) + out = broadcast_shard( + payload, hax.partitioning.infer_resource_partitions(payload, preserve_existing_shardings=False) + ) + return out def process_loglikelihood(self, packed_request): - return self._jit_loglikelihood(self.model, packed_request) + out = self._jit_loglikelihood(self.model, packed_request) + return out def dispatch_loglikelihood(self, packed_request): - self._send_message(_Message.LOGLIKELIHOOD, packed_request) + self._send_message(_Message.LOGLIKELIHOOD) + self._send_payload(packed_request) return self.process_loglikelihood(packed_request) def stop(self): - self._send_message(_Message.STOP, self._dummy_batch) + self._send_message(_Message.STOP) class _Message: @@ -256,6 +269,17 @@ class _Message: LOGLIKELIHOOD = 1 +def _get_segments_this_batch(batch): + segments_this_batch = set() + for i in range(len(batch)): + segments_this_batch.update(np.unique(batch[i].attn_mask.segment_ids.array).tolist()) + try: + segments_this_batch.remove(-1) + except KeyError: + pass + return segments_this_batch + + class LevanterHarnessLM(LM): def __init__(self, leader: _LmEvalHarnessWorker): super().__init__() @@ -283,42 +307,32 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: covered_points = np.zeros(len(requests), dtype=bool) time_in = time.time() - for q, batch in enumerate( - tqdm(batched(packed_iterator, self.EvalBatch.size), desc="Loglikelihood", unit="ba") - ): - segments_this_batch = set() - for i in range(len(batch)): - segments_this_batch.update(np.unique(batch[i].attn_mask.segment_ids.array).tolist()) - - try: - segments_this_batch.remove(-1) - except KeyError: - pass + pbar = tqdm(total=len(requests), desc="Loglikelihood", unit="req") + for q, batch in enumerate(batched(packed_iterator, self.EvalBatch.size)): + segments_this_batch = _get_segments_this_batch(batch) if len(batch) < self.EvalBatch.size: dummy_instance = self._make_dummy_instance(batch) batch.extend([dummy_instance] * (self.EvalBatch.size - len(batch))) - stacked = stack_tree(self.EvalBatch, batch) - # stacked = hax.shard(stacked, self.axis_resources) + with use_cpu_device(): + stacked = stack_tree(self.EvalBatch, batch) time_batch = time.time() out_ids, out_lls, out_correct = self.leader.dispatch_loglikelihood(stacked) - # 
result.extend((ll.item(), correct.item()) for ll, correct in zip(out_lls.array, out_correct.array)) - # -1's are going to be where we had too few sequences to fill a batch + out_ids = np.array(out_ids.array) out_lls = np.array(out_lls.array) out_correct = np.array(out_correct.array) + # -1's are going to be where we had too few sequences to fill a batch valid_indices = out_ids != -1 out_ids_this_batch = out_ids[valid_indices].tolist() missing_ids = set(segments_this_batch) - set(out_ids_this_batch) - - # assert len(out_ids_this_batch) == len( - # segments_this_batch - # ), f"Batch {q} had {len(segments_this_batch)} segments, but {len(out_ids_this_batch)} loglikelihoods" + extra_ids = set(out_ids_this_batch) - set(segments_this_batch) assert len(missing_ids) == 0, f"Missing segments: {missing_ids}" + assert len(extra_ids) == 0, f"Extra segments: {extra_ids}" result_probs[out_ids[valid_indices]] = out_lls[valid_indices] result_greedy[out_ids[valid_indices]] = out_correct[valid_indices] @@ -326,8 +340,11 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: time_ll = time.time() + pbar.update(len(segments_this_batch)) + if jax.process_index() == 0: print(f"Batch time: {time_batch - time_in}, LL time: {time_ll - time_batch}") + time_in = time.time() missing_points = np.where(~covered_points)[0] @@ -481,7 +498,15 @@ def run_lm_eval_harness( EvalBatch, axis_resources, mp: jmp.Policy | None, -) -> dict: +) -> dict | None: + """ + Run the LM Eval Harness on the given model and tasks. + + Returns: + If running on process 0, returns the outputs of the LM Eval Harness with the following extra keys. + - "averages": A dictionary with macro and micro averages for all metrics. + Otherwise, returns None. + """ tasks_to_run = config.to_task_dict() outputs = _actually_run_eval_harness(config, model, tasks_to_run, tokenizer, EvalBatch, axis_resources, mp) @@ -497,7 +522,7 @@ def _actually_run_eval_harness( EvalBatch: haliax.Axis, axis_resources: ResourceMapping, mp: jmp.Policy | None, -): +) -> dict | None: """ Actually run the LM Eval Harness on the given model and tasks. This is a separate function so that it can be used by the main function and the callback function. 
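To make the per-segment scatter in `loglikelihood` above concrete, here is a toy, self-contained sketch (the numbers and sizes are made up for illustration; `-1` marks unused segment slots, exactly as in the code above):

import numpy as np

out_ids = np.array([3, 0, 7, -1, -1])        # segment ids returned for one packed batch
out_lls = np.array([-1.2, -0.4, -3.3, 0.0, 0.0])
out_correct = np.array([1, 1, 0, 0, 0])

result_probs = np.zeros(8)                   # one slot per original request
result_greedy = np.zeros(8)

valid = out_ids != -1
result_probs[out_ids[valid]] = out_lls[valid]       # scatter back into request order
result_greedy[out_ids[valid]] = out_correct[valid]
# result_probs[[0, 3, 7]] now holds -0.4, -1.2, -3.3 respectively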
@@ -521,6 +546,7 @@ def _actually_run_eval_harness( worker = _LmEvalHarnessWorker(EvalBatch, EvalPos, model, axis_resources, tokenizer, mp, max_packed_segments=64) if jax.process_index() == 0: + print("Running eval harness on process 0", flush=True) harness = worker.make_harness_lm() outputs = evaluator.evaluate( harness, @@ -530,15 +556,18 @@ def _actually_run_eval_harness( bootstrap_iters=config.bootstrap_iters, ) worker.stop() + + averages = _compute_averages(outputs) + outputs["averages"] = averages + + return outputs else: + print("Running worker message loop", flush=True) worker.worker_message_loop() - logger.info("Finished running eval harness.") - - averages = _compute_averages(outputs) - outputs["averages"] = averages + logger.info("Finished running eval harness.") - return outputs + return None def _compute_averages(outputs): @@ -633,20 +662,20 @@ def run_eval_harness_main(config: EvalHarnessMainConfig): logger.info("Finished running LM eval harness") # log the results - logger.info("Logging results to tracker") - log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker()) - logger.info("Finished logging results to tracker") - - # log the results as json - logger.info("uploading artifacts...") - with open("lm_eval_harness_results.json", "w") as f: - json.dump(outputs, f, indent=2) - f.flush() - f_path = f.name - levanter.tracker.current_tracker().log_artifact(f_path, name="lm_eval_harness_results") - - # also write to stdout if jax.process_index() == 0: + logger.info("Logging results to tracker") + assert outputs is not None + log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker()) + logger.info("Finished logging results to tracker") + + # log the results as json + logger.info("uploading artifacts...") + with open("lm_eval_harness_results.json", "w") as f: + json.dump(outputs, f, indent=2) + f.flush() + f_path = f.name + levanter.tracker.current_tracker().log_artifact(f_path, name="lm_eval_harness_results") + print(json.dumps(outputs, indent=2), flush=True) return outputs @@ -681,8 +710,8 @@ def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_reso tasks_to_run = config.to_task_dict() def lm_eval_harness(step: StepInfo, force=False): - # if step.step == 0 and not force: - # return # don't run eval on the first step + if step.step == 0 and not force: + return model = inference_mode(step.model, True) logger.info("Running eval harness...") @@ -698,6 +727,7 @@ def lm_eval_harness(step: StepInfo, force=False): logger.info("Finished running eval harness.") if jax.process_index() == 0: + assert outputs is not None log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker()) logger.info("Logged report to tracker") diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 495a0dfea..3a9797190 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -43,7 +43,7 @@ def causal( causal_loss_mask = LmExample.causal_loss_mask(Pos) if loss_mask is not None: - loss_mask = loss_mask & causal_loss_mask + loss_mask = loss_mask & causal_loss_mask.astype(loss_mask.dtype) else: loss_mask = causal_loss_mask @@ -61,6 +61,8 @@ def causal( eos_mask = eos_mask.at[Pos, 0].set(False).astype(jnp.int32) segment_ids = hax.cumsum(eos_mask, axis=Pos) attn_mask = attn_mask.with_segment_ids(segment_ids) + elif segment_ids is not None: + attn_mask = attn_mask.with_segment_ids(segment_ids) return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) diff --git 
a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 1540353bf..5f7b7a798 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -15,6 +15,7 @@ import haliax as hax from haliax import is_named_array +from haliax._src.util import index_where from haliax.jax_utils import is_jax_array_like from haliax.partitioning import ResourceAxis, ResourceMapping @@ -369,3 +370,62 @@ def _zeros_like(mapping, dtype, n): return n - n else: return jnp.zeros((), dtype=dtype) + + +def broadcast_shard(x: T, out_axis_specs: Any, source: int = 0) -> T: + """ + Given a tree of arrays that are on a single source host, and other data (e.g. zeros) with + the same structure, broadcast and shard the data to all hosts, using the axis mapping provided. + + For some reason, I had a ton of trouble figuring this out. + + Our strategy is, for each leaf: + 1. create a host_local_array_to_global_array with the data if we're the source, or zeros if we're not. + This gives us an array [num_devices, ...] + 2. Then, inside jit, we select the source'th element of the array, then reshard with the out_axis_specs + + """ + if jax.process_count() == 1: + return x + + current_mesh: jax.sharding.Mesh = hax.partitioning._get_mesh() + + axis_names = current_mesh.axis_names + + valid_device_for_process = index_where(lambda d: d.host_id == source, current_mesh.devices.flatten()) + sharding = NamedSharding( + current_mesh, + PartitionSpec( + axis_names, + ), + ) + + def pre_jit(x): + if jax.process_index() == source: + inp = np.array(x) + else: + inp = jnp.zeros(x.shape, dtype=x.dtype) + + shape = (len(jax.devices()),) + inp.shape + inp = jnp.expand_dims(inp, axis=0) + out = jax.make_array_from_callback(shape, sharding, lambda _: inp) + + return out + + def in_jit(x, pspec): + if isinstance(x, hax.NamedArray): + arr = x.array + else: + arr = x + arr = jax.lax.with_sharding_constraint(arr[valid_device_for_process], pspec) + + if isinstance(x, hax.NamedArray): + return hax.named(arr, x.axes) + else: + return arr + + x = jax.tree.map(pre_jit, x) + # q = eqx.filter_jit(jax.tree.map).lower(in_jit, x, out_axis_specs, is_leaf=is_named_array).as_text() + out = eqx.filter_jit(jax.tree.map)(in_jit, x, out_axis_specs, is_leaf=is_named_array) + + return out From 0616db0cafec06babe6419e55a9a3994d709745b Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 29 Dec 2024 22:40:07 -0800 Subject: [PATCH 10/18] solved nondeterminism i think --- src/levanter/eval_harness.py | 42 ++++++++++++++++++++++++++-------- src/levanter/utils/py_utils.py | 32 ++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 9 deletions(-) diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index c4baa518d..8549576c0 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -46,6 +46,7 @@ from levanter.models.loss import next_token_loss from levanter.utils.background_iterable import BackgroundIterator from levanter.utils.hf_utils import HfTokenizer +from levanter.utils.py_utils import set_global_rng_seeds try: @@ -546,15 +547,19 @@ def _actually_run_eval_harness( worker = _LmEvalHarnessWorker(EvalBatch, EvalPos, model, axis_resources, tokenizer, mp, max_packed_segments=64) if jax.process_index() == 0: - print("Running eval harness on process 0", flush=True) + logger.info("Process 0 is running the eval harness.") harness = worker.make_harness_lm() - outputs = evaluator.evaluate( - harness, - tasks_to_run, - limit=max_examples, - log_samples=config.log_samples, - 
bootstrap_iters=config.bootstrap_iters, - ) + + # eval_harness only sets seeds in simple_evaluate, which we can't use (I think?) + tasks_to_run = _adjust_config(tasks_to_run, 0) + with set_global_rng_seeds(0): + outputs = evaluator.evaluate( + harness, + tasks_to_run, + limit=max_examples, + log_samples=config.log_samples, + bootstrap_iters=config.bootstrap_iters, + ) worker.stop() averages = _compute_averages(outputs) @@ -562,7 +567,7 @@ def _actually_run_eval_harness( return outputs else: - print("Running worker message loop", flush=True) + logger.info(f"Process {jax.process_index()} is waiting for eval harness requests from process 0.") worker.worker_message_loop() logger.info("Finished running eval harness.") @@ -745,6 +750,25 @@ def lm_eval_harness(step: StepInfo, force=False): return lm_eval_harness +# lifted from lm-eval simple_evaluate +def _adjust_config(task_dict, fewshot_random_seed=0): + adjusted_task_dict = {} + for task_name, task_obj in task_dict.items(): + if isinstance(task_obj, dict): + adjusted_task_dict = { + **adjusted_task_dict, + **{task_name: _adjust_config(task_obj, fewshot_random_seed=fewshot_random_seed)}, + } + + else: + # fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file) + task_obj.set_fewshot_seed(seed=fewshot_random_seed) + + adjusted_task_dict[task_name] = task_obj + + return adjusted_task_dict + + if __name__ == "__main__": levanter.config.main(run_eval_harness_main)() print("Done", flush=True) diff --git a/src/levanter/utils/py_utils.py b/src/levanter/utils/py_utils.py index dab038452..f3fd4f631 100644 --- a/src/levanter/utils/py_utils.py +++ b/src/levanter/utils/py_utils.py @@ -1,3 +1,4 @@ +import contextlib import os import sys import time @@ -121,3 +122,34 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.stop() + + +@contextlib.contextmanager +def set_global_rng_seeds(seed): + import numpy as np + + current_np_seed = np.random.get_state() + np.random.seed(seed) + + import random + + current_random_seed = random.getstate() + random.seed(seed) + + try: + import torch + + current_torch_seed = torch.random.get_rng_state() + torch.manual_seed(seed) + except ImportError: + torch = None + current_torch_seed = None + pass + + try: + yield + finally: + np.random.set_state(current_np_seed) + random.setstate(current_random_seed) + if current_torch_seed is not None: + torch.random.set_rng_state(current_torch_seed) From a350b86443622f0845ad3ceef6fcaef44548ba29 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 30 Dec 2024 22:39:17 -0800 Subject: [PATCH 11/18] ok this feels pretty good --- config/harness/eval_llama3.yaml | 54 ++++++++++++++++----------------- src/levanter/data/packing.py | 7 ----- src/levanter/eval_harness.py | 33 ++++++++++---------- 3 files changed, 43 insertions(+), 51 deletions(-) diff --git a/config/harness/eval_llama3.yaml b/config/harness/eval_llama3.yaml index d15b84b74..260620102 100644 --- a/config/harness/eval_llama3.yaml +++ b/config/harness/eval_llama3.yaml @@ -2,32 +2,32 @@ eval_harness: task_spec: - task: commonsense_qa # 5-way multiple-choice questions based on common-sense, everyday scenarios num_fewshot: 10 -# - task: agieval_lsat_ar # 3-shot tests in legal domain -# num_fewshot: 3 -# - task: arc_easy # 10-shot, four-way MCQ questions involving grade 3-9 basic science -# num_fewshot: 10 -# - task: arc_challenge # a (harder) version of arc_easy -# num_fewshot: 10 -# - task: boolq # answer yes/no questions based on a passage -# num_fewshot: 10 -# - task: copa # use 
causal reasoning to predict the correct outcome of a given scenario -# num_fewshot: 0 -# - task: hellaswag # 4-way multiple choice commonsense reasoning dataset -# num_fewshot: 0 -# task_alias: hellaswag_0shot -# - task: hellaswag # 4-way multiple choice commonsense reasoning dataset -# num_fewshot: 10 -# task_alias: hellaswag_10shot -# - task: lambada # predict the endings of text passages -# num_fewshot: 0 -# - task: openbookqa # 4-way multiple choice question answering task that requires multi-step reasoning -# num_fewshot: 0 -# - task: piqa # answer questions based on a passage -# num_fewshot: 10 -# - task: wsc273 # Winograd Schema Challenge -# num_fewshot: 0 -# - task: winogrande # Winograd challenge, extended to more domains -# num_fewshot: 0 + - task: agieval_lsat_ar # 3-shot tests in legal domain + num_fewshot: 3 + - task: arc_easy # 10-shot, four-way MCQ questions involving grade 3-9 basic science + num_fewshot: 10 + - task: arc_challenge # a (harder) version of arc_easy + num_fewshot: 10 + - task: boolq # answer yes/no questions based on a passage + num_fewshot: 10 + - task: copa # use causal reasoning to predict the correct outcome of a given scenario + num_fewshot: 0 + - task: hellaswag # 4-way multiple choice commonsense reasoning dataset + num_fewshot: 0 + task_alias: hellaswag_0shot + - task: hellaswag # 4-way multiple choice commonsense reasoning dataset + num_fewshot: 10 + task_alias: hellaswag_10shot + - task: lambada # predict the endings of text passages + num_fewshot: 0 + - task: openbookqa # 4-way multiple choice question answering task that requires multi-step reasoning + num_fewshot: 0 + - task: piqa # answer questions based on a passage + num_fewshot: 10 + - task: wsc273 # Winograd Schema Challenge + num_fewshot: 0 + - task: winogrande # Winograd challenge, extended to more domains + num_fewshot: 0 # requires generation ## - task: squadv2 # reading comprehension benchmark # num_fewshot: 10 @@ -39,7 +39,7 @@ model: checkpoint_path: meta-llama/Meta-Llama-3-8B checkpoint_is_hf: true trainer: - mp: p=f32,c=bfloat16 + mp: f32 profiler: true per_device_parallelism: -1 diff --git a/src/levanter/data/packing.py b/src/levanter/data/packing.py index 612afec15..557bf4665 100644 --- a/src/levanter/data/packing.py +++ b/src/levanter/data/packing.py @@ -92,7 +92,6 @@ def pack_prompt_completions( """ Packs a list of prompt completions into LmExamples using the SequencePacker """ - in_ids = set() packers = [SequencePacker(Pos, max_segments_per_example, pad_token)] @@ -100,14 +99,12 @@ def pack_prompt_completions( loss_mask = np.arange(len(sequence.ids)) >= sequence.prompt_length - 1 loss_mask[-1] = 0 assert np.any(loss_mask) - in_ids.add(sequence.segment_id) for packer in packers: if packer.can_pack(sequence.ids): packer.add_example(sequence.ids, loss_mask, sequence.segment_id) if packer.num_segments == max_segments_per_example: - in_ids -= set(packers[0]._segment_ids) yield packer.pack() packers.remove(packer) break @@ -118,15 +115,11 @@ def pack_prompt_completions( packers.append(packer) while len(packers) >= max_buffered_examples: - in_ids -= set(packers[0]._segment_ids) yield packers.pop(0).pack() for packer in packers: - in_ids -= set(packer._segment_ids) yield packer.pack() - assert not in_ids, "Some segments were not packed" - def per_segment_loss( packed_example: LmExample, losses: hax.NamedArray, max_Segments: hax.Axis diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index 8549576c0..80f8bb7d4 100644 --- a/src/levanter/eval_harness.py +++ 
b/src/levanter/eval_harness.py @@ -106,8 +106,14 @@ def _pack_requests( requests: list[Instance], tokenizer: HfTokenizer, Pos: hax.Axis, max_pack_size: int ) -> Iterator[LmExample]: packed_iterator = _iterate_tokenized_requests(requests, tokenizer, Pos.size) + # max_capacity shouln't be too big or we spend all our time lookign for packing + # TODO: use a better packing algorithm yield from pack_prompt_completions( - Pos, packed_iterator, max_segments_per_example=max_pack_size, pad_token=tokenizer.pad_token_id + Pos, + packed_iterator, + max_segments_per_example=max_pack_size, + pad_token=tokenizer.pad_token_id, + max_buffered_examples=16, ) @@ -177,7 +183,8 @@ def _eval_loglikelihood( targets = hax.roll(packed_example.tokens, -1, axis=Pos) is_correct = targets == pred_targets - max_Segments = hax.Axis("Segments", size=self.max_packed_segments) + # we need + 1 because we use -1 as a padding value for segments + max_Segments = hax.Axis("Segments", size=self.max_packed_segments + 1) batched_segment_ids, batched_per_segment_losses = hax.vmap(per_segment_loss, self.EvalBatch)( packed_example, loss, max_Segments @@ -193,18 +200,6 @@ def _eval_loglikelihood( return segments, -losses, correct - # def _do_message(model, message, maybe_batch): - # dummy_result = eqx.filter_eval_shape(_eval_loglikelihood, self._dummy_batch) - # dummy_result = tree_zeros_like(dummy_result) - # - # message, my_batch = - # - # jax.lax.cond(message == _Message.LOGLIKELIHOOD, - # self._jit_loglikelihood, - # lambda *args: dummy_result, - # model, maybe_batch - # ) - # no sharded outputs self._jit_loglikelihood = hax.named_jit( _eval_loglikelihood, axis_resources=axis_resources, out_axis_resources={} @@ -270,10 +265,14 @@ class _Message: LOGLIKELIHOOD = 1 -def _get_segments_this_batch(batch): +def _get_segments_this_batch(batch, max_segments_per_ex): segments_this_batch = set() for i in range(len(batch)): - segments_this_batch.update(np.unique(batch[i].attn_mask.segment_ids.array).tolist()) + unique_segs = np.unique(batch[i].attn_mask.segment_ids.array).tolist() + # + 1 because we use -1 as a padding value for segments and allow that + if len(unique_segs) > max_segments_per_ex + 1: + raise ValueError(f"Too many segments in example {i}: {len(unique_segs)}") + segments_this_batch.update(unique_segs) try: segments_this_batch.remove(-1) except KeyError: @@ -310,7 +309,7 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: time_in = time.time() pbar = tqdm(total=len(requests), desc="Loglikelihood", unit="req") for q, batch in enumerate(batched(packed_iterator, self.EvalBatch.size)): - segments_this_batch = _get_segments_this_batch(batch) + segments_this_batch = _get_segments_this_batch(batch, self.leader.max_packed_segments) if len(batch) < self.EvalBatch.size: dummy_instance = self._make_dummy_instance(batch) From ec7b0411e4377e175f1d0ec6cba2dc7aa2c124e5 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 6 Jan 2025 13:28:03 -0800 Subject: [PATCH 12/18] batch together tokenization to improve throughput somewhat --- src/levanter/eval_harness.py | 60 ++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 23 deletions(-) diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index 80f8bb7d4..8920825dc 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -76,38 +76,52 @@ def _iterate_tokenized_requests( - requests: list[Instance], tokenizer: HfTokenizer, max_len: int + requests: list[Instance], tokenizer: HfTokenizer, max_len: int, 
batch_size: int ) -> Iterator[PromptCompletion]: """ Tokenize the requests and yield them as PromptCompletions, for packing into LmExamples. """ - for i, request in enumerate(requests): - # it's kinda annoying we run tokenization twice, but it's the easiest way to get the prompt length - # CF: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/api/model.py#L354 - context, completion = request.args - whole_enc = tokenizer(context + completion) - context_enc = tokenizer(context) - - context_enc_len = len(context_enc["input_ids"]) - whole_ids = whole_enc["input_ids"] - if len(whole_ids) > max_len: - logger.warning(f"Request {i} is too long. Truncating.") - # truncate from the left - whole_ids = whole_ids[-max_len:] - context_enc_len = max_len - len(completion) - if context_enc_len < 0: - context_enc_len = 0 - logger.warning("Prompt length is negative after truncation. Setting to 0.") - - yield PromptCompletion(ids=whole_ids, prompt_length=context_enc_len, segment_id=i) + # Separate contexts and completions + contexts = [request.args[0] for request in requests] + completions = [request.args[1] for request in requests] + + # Combine contexts and completions for full tokenization + combined_texts = [context + completion for context, completion in zip(contexts, completions)] + + # Batch tokenization for combined and context separately + for batch_indices in batched(range(len(requests)), batch_size): + # Extract batch data + combined_batch = [combined_texts[i] for i in batch_indices] + context_batch = [contexts[i] for i in batch_indices] + + # Tokenize batched inputs + combined_encodings = tokenizer(combined_batch, truncation=False, padding=False) + context_encodings = tokenizer(context_batch, truncation=False, padding=False) + + for off in range(len(batch_indices)): + i = batch_indices[off] + context_enc = context_encodings["input_ids"][off] + whole_ids = combined_encodings["input_ids"][off] + + context_enc_len = len(context_enc) + + if len(whole_ids) > max_len: + logger.warning(f"Request {i} is too long. Truncating.") + # Truncate from the left + whole_ids = whole_ids[-max_len:] + context_enc_len = max_len - len(whole_ids) + context_enc_len + if context_enc_len < 0: + context_enc_len = 0 + logger.warning("Prompt length is negative after truncation. Setting to 0.") + + yield PromptCompletion(ids=whole_ids, prompt_length=context_enc_len, segment_id=i) def _pack_requests( requests: list[Instance], tokenizer: HfTokenizer, Pos: hax.Axis, max_pack_size: int ) -> Iterator[LmExample]: - packed_iterator = _iterate_tokenized_requests(requests, tokenizer, Pos.size) - # max_capacity shouln't be too big or we spend all our time lookign for packing - # TODO: use a better packing algorithm + packed_iterator = _iterate_tokenized_requests(requests, tokenizer, Pos.size, batch_size=128) + # TODO: use a better packing algorithm? 
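    # The packing above is greedy first-fit over a small buffer of open packers
    # (each new sequence goes into the first packer that still has room); sorting
    # requests by token length first (first-fit decreasing) would likely reduce
    # padding, at the cost of buffering more requests up front.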
yield from pack_prompt_completions( Pos, packed_iterator, From 605e7438210eaeb0cd42cead953a8db2d0fb4993 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 7 Jan 2025 09:14:49 -0800 Subject: [PATCH 13/18] precommit --- scripts/gcs_bulk_delete.py | 11 ----------- src/levanter/utils/background_iterable.py | 6 +++++- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/scripts/gcs_bulk_delete.py b/scripts/gcs_bulk_delete.py index 4a3b2d546..d8de46f63 100644 --- a/scripts/gcs_bulk_delete.py +++ b/scripts/gcs_bulk_delete.py @@ -1,13 +1,10 @@ import re import sys import time -from datetime import datetime import google.auth from google.api_core import operations_v1 from google.cloud import storage_transfer_v1 -from google.type.date_pb2 import Date -from google.type.timeofday_pb2 import TimeOfDay EMPTY_BUCKET = "levanter-empty" @@ -33,14 +30,6 @@ def schedule_gcs_deletion_job(project_id, gcs_bucket_name, path_to_delete): gcs_data_source=storage_transfer_v1.types.GcsData(bucket_name=EMPTY_BUCKET), transfer_options=storage_transfer_v1.types.TransferOptions(delete_objects_unique_in_sink=True), ), - # schedule=storage_transfer_v1.types.Schedule( - # schedule_start_date=Date( - # year=datetime.utcnow().year, month=datetime.utcnow().month, day=datetime.utcnow().day - # ), - # start_time_of_day=TimeOfDay( - # hours=datetime.utcnow().hour, minutes=datetime.utcnow().minute + 2 # Start in 2 minutes - # ), - # ), status=storage_transfer_v1.types.TransferJob.Status.ENABLED, description=f"Delete all files in {gcs_bucket_name}/{path_to_delete}", ) diff --git a/src/levanter/utils/background_iterable.py b/src/levanter/utils/background_iterable.py index 1a2ec53df..7e51efb10 100644 --- a/src/levanter/utils/background_iterable.py +++ b/src/levanter/utils/background_iterable.py @@ -34,7 +34,11 @@ def __iter__(self): class BackgroundIterator(Iterator[Ex]): - def __init__(self, producer_fn: Callable[[], Iterator[Ex]|AsyncIterator[Ex]]| Iterator[Ex] | AsyncIterator[Ex], max_capacity: Optional[int]): + def __init__( + self, + producer_fn: Callable[[], Iterator[Ex] | AsyncIterator[Ex]] | Iterator[Ex] | AsyncIterator[Ex], + max_capacity: Optional[int], + ): self.max_capacity = max_capacity if not callable(producer_fn): self._producer_fn = lambda: producer_fn From 90be3f3a4660b6b4355c5155aa578baffbc1e075 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 7 Jan 2025 17:29:57 -0800 Subject: [PATCH 14/18] wip --- src/levanter/data/loader.py | 3 +- src/levanter/data/packing.py | 6 +- src/levanter/eval_harness.py | 207 +++++++++++++++++++------------- src/levanter/main/train_lm.py | 12 +- src/levanter/models/lm_model.py | 2 + 5 files changed, 139 insertions(+), 91 deletions(-) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index dc87e549d..5aa8516ed 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -4,6 +4,7 @@ from collections import defaultdict from typing import AsyncIterator, Callable, Iterable, Iterator, Optional, Tuple, TypeVar +import equinox import jax from jax import Array from jax import numpy as jnp @@ -267,7 +268,7 @@ def _fill_queue_with_batches(self): super()._fill_queue_with_batches() -@functools.partial(jax.jit, static_argnums=(0,)) +@equinox.filter_jit def stack_tree(batch_name, individual_datums): def _stack_leaves_unchecked(*leaves): if is_named_array(leaves[0]): diff --git a/src/levanter/data/packing.py b/src/levanter/data/packing.py index 557bf4665..cd959fd58 100644 --- a/src/levanter/data/packing.py +++ b/src/levanter/data/packing.py @@ -66,9 
+66,9 @@ def pack(self) -> LmExample: loss_mask = self._loss_mask + [0] * (self.Pos.size - len(self._loss_mask)) with local_cpu_mesh(): - tokens = hax.named(ids, self.Pos) - segment_ids = hax.named(segment_ids, self.Pos) - loss_mask = hax.named(loss_mask, self.Pos) + tokens = hax.named(ids, self.Pos).astype(jnp.int32) + segment_ids = hax.named(segment_ids, self.Pos).astype(jnp.int32) + loss_mask = hax.named(loss_mask, self.Pos).astype(jnp.int32) attn_mask = AttentionMask.causal().with_segment_ids(segment_ids) diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index 8920825dc..f5f4ace7f 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -75,62 +75,6 @@ logger = logging.getLogger(__name__) -def _iterate_tokenized_requests( - requests: list[Instance], tokenizer: HfTokenizer, max_len: int, batch_size: int -) -> Iterator[PromptCompletion]: - """ - Tokenize the requests and yield them as PromptCompletions, for packing into LmExamples. - """ - # Separate contexts and completions - contexts = [request.args[0] for request in requests] - completions = [request.args[1] for request in requests] - - # Combine contexts and completions for full tokenization - combined_texts = [context + completion for context, completion in zip(contexts, completions)] - - # Batch tokenization for combined and context separately - for batch_indices in batched(range(len(requests)), batch_size): - # Extract batch data - combined_batch = [combined_texts[i] for i in batch_indices] - context_batch = [contexts[i] for i in batch_indices] - - # Tokenize batched inputs - combined_encodings = tokenizer(combined_batch, truncation=False, padding=False) - context_encodings = tokenizer(context_batch, truncation=False, padding=False) - - for off in range(len(batch_indices)): - i = batch_indices[off] - context_enc = context_encodings["input_ids"][off] - whole_ids = combined_encodings["input_ids"][off] - - context_enc_len = len(context_enc) - - if len(whole_ids) > max_len: - logger.warning(f"Request {i} is too long. Truncating.") - # Truncate from the left - whole_ids = whole_ids[-max_len:] - context_enc_len = max_len - len(whole_ids) + context_enc_len - if context_enc_len < 0: - context_enc_len = 0 - logger.warning("Prompt length is negative after truncation. Setting to 0.") - - yield PromptCompletion(ids=whole_ids, prompt_length=context_enc_len, segment_id=i) - - -def _pack_requests( - requests: list[Instance], tokenizer: HfTokenizer, Pos: hax.Axis, max_pack_size: int -) -> Iterator[LmExample]: - packed_iterator = _iterate_tokenized_requests(requests, tokenizer, Pos.size, batch_size=128) - # TODO: use a better packing algorithm? - yield from pack_prompt_completions( - Pos, - packed_iterator, - max_segments_per_example=max_pack_size, - pad_token=tokenizer.pad_token_id, - max_buffered_examples=16, - ) - - # OK, so LM-Eval-Harness is not deterministic. This means we can't just run it on different workers and expect the # order of requests to be the same. Sorting doesn't even seem to be correct (?!?!?) so we need to only run it on one # process. @@ -142,17 +86,12 @@ def _pack_requests( # 4. When a STOP request is received, we stop the loop and process 0 returns the results. 
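# To make that flow concrete, here is a self-contained toy simulation of the
# leader/worker protocol sketched above. Threads and a Queue stand in for the
# JAX multihost dispatch that _LmEvalHarnessWorker actually uses, and the
# helper names below are illustrative only, not part of this patch.
import queue
import threading

_STOP = object()


def _toy_worker_loop(inbox: queue.Queue, outbox: queue.Queue) -> None:
    # non-leader "processes" just wait for batches until the leader sends STOP
    while True:
        batch = inbox.get()
        if batch is _STOP:
            break
        # stand-in for the real loglikelihood computation on a packed batch
        outbox.put([len(request) for request in batch])


def _toy_leader_loop(requests: list[str], inbox: queue.Queue, outbox: queue.Queue) -> list[int]:
    # "process 0" runs the harness, dispatches every batch, then stops the workers
    results: list[int] = []
    for start in range(0, len(requests), 2):
        inbox.put(requests[start : start + 2])
        results.extend(outbox.get())
    inbox.put(_STOP)
    return results


if __name__ == "__main__":
    inbox, outbox = queue.Queue(), queue.Queue()
    worker = threading.Thread(target=_toy_worker_loop, args=(inbox, outbox))
    worker.start()
    print(_toy_leader_loop(["first context", "second", "third one"], inbox, outbox))
    worker.join()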
-def _make_dummy_batch(EvalBatch, EvalPos): - dummy_batch = hax.vmap(LmExample.causal, EvalBatch)( - hax.zeros(EvalPos, dtype=jnp.int32), - loss_mask=hax.zeros(EvalPos, dtype=jnp.int32), - segment_ids=hax.zeros(EvalPos, dtype=jnp.int32), - ) - out = hax.shard(dummy_batch, {}) - return out - - class _LmEvalHarnessWorker: + """ + Worker for running the LM Eval Harness. Each worker process will run a copy of this class. + The head process will run the main harness and dispatch requests to the workers while the + others run in a loop waiting for requests. + """ def __init__(self, EvalBatch, EvalPos, model, axis_resources, tokenizer, mp, max_packed_segments): self.tokenizer = tokenizer self.max_packed_segments = max_packed_segments @@ -280,18 +219,21 @@ class _Message: def _get_segments_this_batch(batch, max_segments_per_ex): - segments_this_batch = set() - for i in range(len(batch)): - unique_segs = np.unique(batch[i].attn_mask.segment_ids.array).tolist() - # + 1 because we use -1 as a padding value for segments and allow that - if len(unique_segs) > max_segments_per_ex + 1: - raise ValueError(f"Too many segments in example {i}: {len(unique_segs)}") - segments_this_batch.update(unique_segs) - try: - segments_this_batch.remove(-1) - except KeyError: - pass - return segments_this_batch + unique_segs = np.unique(batch.attn_mask.segment_ids.array).tolist() + # + 1 because we use -1 as a padding value for segments and allow that + if len(unique_segs) > max_segments_per_ex + 1: + raise ValueError(f"Too many segments in batch: {len(unique_segs)}") + if -1 in unique_segs: + unique_segs.remove(-1) + + return unique_segs + + +def _get_padding_count(batch, pad_token_id): + # returns the total amount of padding in the batch + padding_count = np.sum(batch.tokens.array == pad_token_id) + total_tokens = batch.tokens.size + return padding_count, total_tokens class LevanterHarnessLM(LM): @@ -314,26 +256,27 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: self.tokenizer.pad_token_id = self.tokenizer.eos_token_id packed_iterator = _pack_requests(requests, self.tokenizer, self.EvalPos, self.leader.max_packed_segments) + packed_iterator = self.stack_batches(packed_iterator, self.EvalBatch) packed_iterator = BackgroundIterator(packed_iterator, max_capacity=1024) result_probs = np.zeros(len(requests)) result_greedy = np.zeros(len(requests)) covered_points = np.zeros(len(requests), dtype=bool) + total_padding = 0 + total_tokens = 0 time_in = time.time() pbar = tqdm(total=len(requests), desc="Loglikelihood", unit="req") - for q, batch in enumerate(batched(packed_iterator, self.EvalBatch.size)): - segments_this_batch = _get_segments_this_batch(batch, self.leader.max_packed_segments) + for q, batch in enumerate(packed_iterator): + time_data_available = time.time() + segments_this_batch = _get_segments_this_batch(batch, self.leader.max_packed_segments * self.EvalBatch.size) + time_segments = time.time() - if len(batch) < self.EvalBatch.size: - dummy_instance = self._make_dummy_instance(batch) - batch.extend([dummy_instance] * (self.EvalBatch.size - len(batch))) + padding_count, batch_tokens = _get_padding_count(batch, self.tokenizer.pad_token_id) - with use_cpu_device(): - stacked = stack_tree(self.EvalBatch, batch) time_batch = time.time() - out_ids, out_lls, out_correct = self.leader.dispatch_loglikelihood(stacked) + out_ids, out_lls, out_correct = self.leader.dispatch_loglikelihood(batch) out_ids = np.array(out_ids.array) out_lls = np.array(out_lls.array) @@ -354,10 +297,15 @@ def 
loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: time_ll = time.time() + pbar.set_postfix( + padding=f"{total_padding + padding_count}/{total_tokens + batch_tokens} = {(total_padding + padding_count) / (total_tokens + batch_tokens):.2f}", + this_padding=f"{padding_count}/{batch_tokens}= {padding_count / batch_tokens:.2f}", + ) pbar.update(len(segments_this_batch)) if jax.process_index() == 0: print(f"Batch time: {time_batch - time_in}, LL time: {time_ll - time_batch}") + print(f"Data available: {time_data_available - time_in}, Segments: {time_segments - time_data_available}, Stack: {time_batch - time_segments}") time_in = time.time() @@ -369,6 +317,24 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: return result + def stack_batches(self, example_iterator, EvalBatch): + """ + Stack examples from an iterator into a batch. + + Args: + EvalBatch: The batch axis. + example_iterator: An iterator of examples. + + Returns: + A batch of examples. + """ + with use_cpu_device(): + for batch in batched(example_iterator, EvalBatch.size): + if len(batch) < EvalBatch.size: + dummy_instance = self._make_dummy_instance(batch) + batch.extend([dummy_instance] * (EvalBatch.size - len(batch))) + yield stack_tree(EvalBatch, batch) + def _make_dummy_instance(self, batch): dummy_instance: LmExample = tree_zeros_like(batch[0]) dummy_segment_mask = hax.full(self.EvalPos, -1, dtype=jnp.int32) @@ -782,6 +748,75 @@ def _adjust_config(task_dict, fewshot_random_seed=0): return adjusted_task_dict +def _iterate_tokenized_requests( + requests: list[Instance], tokenizer: HfTokenizer, max_len: int, batch_size: int +) -> Iterator[PromptCompletion]: + """ + Tokenize the requests and yield them as PromptCompletions, for packing into LmExamples. + """ + # Separate contexts and completions + contexts = [request.args[0] for request in requests] + completions = [request.args[1] for request in requests] + + # Combine contexts and completions for full tokenization + combined_texts = [context + completion for context, completion in zip(contexts, completions)] + + # Batch tokenization for combined and context separately + for batch_indices in batched(range(len(requests)), batch_size): + # Extract batch data + combined_batch = [combined_texts[i] for i in batch_indices] + context_batch = [contexts[i] for i in batch_indices] + + # Tokenize batched inputs + combined_encodings = tokenizer(combined_batch, truncation=False, padding=False) + context_encodings = tokenizer(context_batch, truncation=False, padding=False) + + for off in range(len(batch_indices)): + i = batch_indices[off] + context_enc = context_encodings["input_ids"][off] + whole_ids = combined_encodings["input_ids"][off] + + context_enc_len = len(context_enc) + + if len(whole_ids) > max_len: + logger.warning(f"Request {i} is too long. Truncating.") + # Truncate from the left + whole_ids = whole_ids[-max_len:] + context_enc_len = max_len - len(whole_ids) + context_enc_len + if context_enc_len < 0: + context_enc_len = 0 + logger.warning("Prompt length is negative after truncation. Setting to 0.") + + yield PromptCompletion(ids=whole_ids, prompt_length=context_enc_len, segment_id=i) + + +def _pack_requests( + requests: list[Instance], tokenizer: HfTokenizer, Pos: hax.Axis, max_pack_size: int +) -> Iterator[LmExample]: + packed_iterator = _iterate_tokenized_requests(requests, tokenizer, Pos.size, batch_size=128) + # TODO: use a better packing algorithm? 
+ yield from pack_prompt_completions( + Pos, + packed_iterator, + max_segments_per_example=max_pack_size, + pad_token=tokenizer.pad_token_id, + max_buffered_examples=16, + ) + + +def _make_dummy_batch(EvalBatch, EvalPos): + dummy_batch = hax.vmap(LmExample.causal, EvalBatch)( + hax.zeros(EvalPos, dtype=jnp.int32), + loss_mask=hax.zeros(EvalPos, dtype=jnp.int32), + segment_ids=hax.zeros(EvalPos, dtype=jnp.int32), + ) + out = hax.shard(dummy_batch, {}) + return out + + + + + if __name__ == "__main__": levanter.config.main(run_eval_harness_main)() print("Done", flush=True) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 15da92cfa..69275e871 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -165,7 +165,17 @@ def main(config: TrainLmConfig): ) trainer.add_hook(epoch_checkpointer, every=1) - state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) + def init_model(): + return config.model.build(Vocab, key=model_key) + + if use_mup: + def mupify_model(): + model = init_model() + return mupify(model, Pos, KeyPos) + + init_model = mupify_model + + state = trainer.initial_state(training_key, model_init=lambda: init_model) seek_dataloader = True if int(state.step) == 0 and config.initialize_from_checkpoint_path is not None: diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 3a9797190..7f92221ed 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -52,6 +52,8 @@ def causal( ignore_mask = hax.roll(tokens, -1, Pos) != ignore_id loss_mask = loss_mask * ignore_mask + loss_mask = loss_mask.astype(jnp.int32) + attn_mask = AttentionMask.causal() if eos_id is not None and segment_ids is None: From 2b5d52c0b9ab42875aab07c8fa92ae9865881416 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 8 Jan 2025 08:40:45 -0800 Subject: [PATCH 15/18] Apply suggestions from code review Co-authored-by: Nikil Ravi <55033516+nikil-ravi@users.noreply.github.com> --- src/levanter/data/packing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/data/packing.py b/src/levanter/data/packing.py index 557bf4665..2c690563b 100644 --- a/src/levanter/data/packing.py +++ b/src/levanter/data/packing.py @@ -75,7 +75,7 @@ def pack(self) -> LmExample: return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) -@dataclass +@dataclass(frozen=True) class PromptCompletion: ids: list[int] prompt_length: int From af5fb7e13d737d501fe55859f9f4d5f067d56af3 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 8 Jan 2025 08:56:26 -0800 Subject: [PATCH 16/18] comment --- src/levanter/data/packing.py | 10 +++++++++- src/levanter/eval_harness.py | 27 ++++++++------------------- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/src/levanter/data/packing.py b/src/levanter/data/packing.py index eba3fd71c..aa6028690 100644 --- a/src/levanter/data/packing.py +++ b/src/levanter/data/packing.py @@ -1,4 +1,12 @@ -# Implements sequence packing +""" +Implements sequence packing, mostly for doing evaluation on lots of short sequences. + +Our strategy is basically to maintain a pool of SequencePackers, each of which can hold a fixed number of tokens +(and a maximum number of segments). We then iterate over the sequences, adding them to the packers if they fit, and +yielding the packed examples when they are full. + +This achieves about a 90% "real token" rate, compared to like 10% without packing. 
+""" from dataclasses import dataclass from typing import Iterable, Iterator diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index f5f4ace7f..c3c33ac74 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -20,7 +20,6 @@ import json import logging import tempfile -import time import typing from dataclasses import dataclass from functools import cached_property @@ -92,6 +91,7 @@ class _LmEvalHarnessWorker: The head process will run the main harness and dispatch requests to the workers while the others run in a loop waiting for requests. """ + def __init__(self, EvalBatch, EvalPos, model, axis_resources, tokenizer, mp, max_packed_segments): self.tokenizer = tokenizer self.max_packed_segments = max_packed_segments @@ -265,17 +265,14 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: total_padding = 0 total_tokens = 0 - time_in = time.time() pbar = tqdm(total=len(requests), desc="Loglikelihood", unit="req") for q, batch in enumerate(packed_iterator): - time_data_available = time.time() - segments_this_batch = _get_segments_this_batch(batch, self.leader.max_packed_segments * self.EvalBatch.size) - time_segments = time.time() + segments_this_batch = _get_segments_this_batch( + batch, self.leader.max_packed_segments * self.EvalBatch.size + ) padding_count, batch_tokens = _get_padding_count(batch, self.tokenizer.pad_token_id) - time_batch = time.time() - out_ids, out_lls, out_correct = self.leader.dispatch_loglikelihood(batch) out_ids = np.array(out_ids.array) @@ -295,20 +292,15 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: result_greedy[out_ids[valid_indices]] = out_correct[valid_indices] covered_points[out_ids[valid_indices]] = True - time_ll = time.time() - pbar.set_postfix( - padding=f"{total_padding + padding_count}/{total_tokens + batch_tokens} = {(total_padding + padding_count) / (total_tokens + batch_tokens):.2f}", + padding=( + f"{total_padding + padding_count}/{total_tokens + batch_tokens} =" + f" {(total_padding + padding_count) / (total_tokens + batch_tokens):.2f}" + ), this_padding=f"{padding_count}/{batch_tokens}= {padding_count / batch_tokens:.2f}", ) pbar.update(len(segments_this_batch)) - if jax.process_index() == 0: - print(f"Batch time: {time_batch - time_in}, LL time: {time_ll - time_batch}") - print(f"Data available: {time_data_available - time_in}, Segments: {time_segments - time_data_available}, Stack: {time_batch - time_segments}") - - time_in = time.time() - missing_points = np.where(~covered_points)[0] assert len(missing_points) == 0, f"Missing points: {missing_points}" @@ -814,9 +806,6 @@ def _make_dummy_batch(EvalBatch, EvalPos): return out - - - if __name__ == "__main__": levanter.config.main(run_eval_harness_main)() print("Done", flush=True) From 5c0fca2da148a4d4f4e37d11b1edd51e17bc1a1e Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 8 Jan 2025 10:28:27 -0800 Subject: [PATCH 17/18] dumb --- src/levanter/main/train_lm.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 69275e871..15da92cfa 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -165,17 +165,7 @@ def main(config: TrainLmConfig): ) trainer.add_hook(epoch_checkpointer, every=1) - def init_model(): - return config.model.build(Vocab, key=model_key) - - if use_mup: - def mupify_model(): - model = init_model() - return mupify(model, Pos, KeyPos) - - init_model = 
mupify_model
-
-    state = trainer.initial_state(training_key, model_init=lambda: init_model)
+    state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))

     seek_dataloader = True
     if int(state.step) == 0 and config.initialize_from_checkpoint_path is not None:

From f87a8d7020414b5de5dc85d25be442f859939d94 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Wed, 8 Jan 2025 10:31:52 -0800
Subject: [PATCH 18/18] sigh

---
 src/levanter/data/loader.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py
index 5aa8516ed..5db9b96b9 100644
--- a/src/levanter/data/loader.py
+++ b/src/levanter/data/loader.py
@@ -1,4 +1,3 @@
-import functools
 import logging
 import time
 from collections import defaultdict
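For reference, the packing helpers introduced earlier in this series are used
roughly as follows. This is a sketch based only on the calls visible in the
diffs above (notably _pack_requests in eval_harness.py); the axis size, token
ids, and pad token are made-up example data.

# Rough usage sketch of the sequence-packing utilities added in this series.
# Mirrors the calls shown in eval_harness._pack_requests; all concrete values
# here are illustrative.
import haliax as hax

from levanter.data.packing import PromptCompletion, pack_prompt_completions

Pos = hax.Axis("position", 64)

# one PromptCompletion per (context, completion) request; prompt_length marks
# where the completion, and therefore the loss mask, begins
requests = [
    PromptCompletion(ids=[5, 6, 7, 8, 9], prompt_length=3, segment_id=0),
    PromptCompletion(ids=[11, 12, 13], prompt_length=1, segment_id=1),
]

# greedily packs many short requests into few LmExamples, padding the remainder
# and tagging each request with its own segment id in the attention mask
for example in pack_prompt_completions(
    Pos,
    requests,
    max_segments_per_example=4,
    pad_token=0,
    max_buffered_examples=16,
):
    print(example.tokens.axes)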