From 93a8aa95a06e312439bd457ae326e661fbfc52e2 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 8 Jan 2025 10:32:31 -0800 Subject: [PATCH] Add sequence packing for lm-eval-harness (#850) Speedup was less than I hoped, but I think that can be improved with better packing strategies (and also removing some batching overhead. the bottleneck is now data loading??) --------- Co-authored-by: Nikil Ravi <55033516+nikil-ravi@users.noreply.github.com> --- config/harness/harness_nano.yaml | 5 +- scripts/gcs_bulk_delete.py | 12 +- src/levanter/data/loader.py | 8 +- src/levanter/data/packing.py | 213 ++++++++ src/levanter/eval_harness.py | 611 +++++++++++++--------- src/levanter/models/lm_model.py | 46 +- src/levanter/utils/background_iterable.py | 11 +- src/levanter/utils/jax_utils.py | 62 ++- src/levanter/utils/py_utils.py | 32 ++ tests/test_packing.py | 223 ++++++++ 10 files changed, 940 insertions(+), 283 deletions(-) create mode 100644 src/levanter/data/packing.py create mode 100644 tests/test_packing.py diff --git a/config/harness/harness_nano.yaml b/config/harness/harness_nano.yaml index 833291e5c..109cfd152 100644 --- a/config/harness/harness_nano.yaml +++ b/config/harness/harness_nano.yaml @@ -1,5 +1,8 @@ eval_harness: - task_spec: ["hellaswag"] +# task_spec: ["hellaswag"] + task_spec: + - task: commonsense_qa # 5-way multiple-choice questions based on common-sense, everyday scenarios + num_fewshot: 1 tokenizer: "gpt2" model: type: gpt2 diff --git a/scripts/gcs_bulk_delete.py b/scripts/gcs_bulk_delete.py index 564e3cd60..d8de46f63 100644 --- a/scripts/gcs_bulk_delete.py +++ b/scripts/gcs_bulk_delete.py @@ -1,13 +1,10 @@ import re import sys import time -from datetime import datetime import google.auth from google.api_core import operations_v1 from google.cloud import storage_transfer_v1 -from google.type.date_pb2 import Date -from google.type.timeofday_pb2 import TimeOfDay EMPTY_BUCKET = "levanter-empty" @@ -33,14 +30,6 @@ def schedule_gcs_deletion_job(project_id, gcs_bucket_name, path_to_delete): gcs_data_source=storage_transfer_v1.types.GcsData(bucket_name=EMPTY_BUCKET), transfer_options=storage_transfer_v1.types.TransferOptions(delete_objects_unique_in_sink=True), ), - schedule=storage_transfer_v1.types.Schedule( - schedule_start_date=Date( - year=datetime.utcnow().year, month=datetime.utcnow().month, day=datetime.utcnow().day - ), - start_time_of_day=TimeOfDay( - hours=datetime.utcnow().hour, minutes=datetime.utcnow().minute + 2 # Start in 2 minutes - ), - ), status=storage_transfer_v1.types.TransferJob.Status.ENABLED, description=f"Delete all files in {gcs_bucket_name}/{path_to_delete}", ) @@ -48,6 +37,7 @@ def schedule_gcs_deletion_job(project_id, gcs_bucket_name, path_to_delete): # Create the transfer job response = client.create_transfer_job(request={"transfer_job": transfer_job}) print(f"Created transfer job: {response.name}") + client.run_transfer_job({"job_name": response.name, "project_id": project_id}) # Wait for job completion wait_for_transfer_job(response.name, timeout=3600, poll_interval=2, project_id=project_id) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index 928c9456c..5db9b96b9 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -1,9 +1,9 @@ -import functools import logging import time from collections import defaultdict from typing import AsyncIterator, Callable, Iterable, Iterator, Optional, Tuple, TypeVar +import equinox import jax from jax import Array from jax import numpy as jnp @@ -180,7 +180,7 @@ def 
get_local_batch(begin: int, end: int) -> list: # TODO: if we ever do "big data" (i.e. huge examples) we might want to be able to load part of an example # which will require support from the datastore (i.e. tensorstore) - device_batch = _stack_tree(self.dl.Batch.name, [data_for_this_batch[i] for i in range(begin, end)]) + device_batch = stack_tree(self.dl.Batch.name, [data_for_this_batch[i] for i in range(begin, end)]) batch_leaves = hax.tree_util.tree_leaves(device_batch) cache[(begin, end)] = batch_leaves @@ -267,8 +267,8 @@ def _fill_queue_with_batches(self): super()._fill_queue_with_batches() -@functools.partial(jax.jit, static_argnums=(0,)) -def _stack_tree(batch_name, individual_datums): +@equinox.filter_jit +def stack_tree(batch_name, individual_datums): def _stack_leaves_unchecked(*leaves): if is_named_array(leaves[0]): return hax.stack(batch_name, leaves) diff --git a/src/levanter/data/packing.py b/src/levanter/data/packing.py new file mode 100644 index 000000000..aa6028690 --- /dev/null +++ b/src/levanter/data/packing.py @@ -0,0 +1,213 @@ +""" +Implements sequence packing, mostly for doing evaluation on lots of short sequences. + +Our strategy is basically to maintain a pool of SequencePackers, each of which can hold a fixed number of tokens +(and a maximum number of segments). We then iterate over the sequences, adding them to the packers if they fit, and +yielding the packed examples when they are full. + +This achieves about a 90% "real token" rate, compared to like 10% without packing. +""" +from dataclasses import dataclass +from typing import Iterable, Iterator + +import jax.numpy as jnp +import numpy as np + +import haliax as hax + +from levanter.models.attention import AttentionMask +from levanter.models.lm_model import LmExample +from levanter.utils.jax_utils import local_cpu_mesh + + +# cf https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/data_generators/generator_utils.py#L623 + +# todo should we use something like this: https://arxiv.org/pdf/2107.02027? + + +class SequencePacker: + """ + Packs sequences into a single LmExample. 
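+
+    A minimal usage sketch (token ids and axis size here are made up, for illustration only):
+
+        packer = SequencePacker(Pos=hax.Axis("pos", 8), max_pack_size=2, pad_token=0)
+        if packer.can_pack([1, 2, 3]):
+            packer.add_example([1, 2, 3], loss_mask=[0, 1, 1])
+        packed = packer.pack()  # an LmExample with a causal mask and per-segment ids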
+ """ + + def __init__(self, Pos: hax.Axis, max_pack_size: int, pad_token: int): + self.Pos = Pos + self._ids: list[int] = [] + self._segment_ids: list[int] = [] + self._loss_mask: list[int] = [] + self.num_segments = 0 + self.pad_token = pad_token + self.max_pack_size = max_pack_size + assert pad_token is not None, "pad_token must be set" + + def can_pack(self, ids: list[int]) -> bool: + return len(ids) + len(self._ids) <= self.Pos.size and self.num_segments < self.max_pack_size + + def add_example(self, ids: list[int], loss_mask: list[int] | np.ndarray, segment_id: int | None = None): + if len(ids) != len(loss_mask): + raise ValueError("ids and loss_mask must have the same length") + + if len(ids) == 0: + return + + if len(ids) + len(self._ids) > self.Pos.size: + raise ValueError("Too many tokens") + + if self.num_segments >= self.max_pack_size: + raise ValueError("Too many segments") + + self._ids.extend(ids) + if segment_id is None: + segment_id = self.num_segments + + self.num_segments += 1 + + self._segment_ids.extend([segment_id] * len(ids)) + + self._loss_mask.extend(loss_mask) + + def pack(self) -> LmExample: + ids = self._ids + [self.pad_token] * (self.Pos.size - len(self._ids)) + + segment_ids = self._segment_ids + [-1] * (self.Pos.size - len(self._segment_ids)) + + loss_mask = self._loss_mask + [0] * (self.Pos.size - len(self._loss_mask)) + + with local_cpu_mesh(): + tokens = hax.named(ids, self.Pos).astype(jnp.int32) + segment_ids = hax.named(segment_ids, self.Pos).astype(jnp.int32) + loss_mask = hax.named(loss_mask, self.Pos).astype(jnp.int32) + + attn_mask = AttentionMask.causal().with_segment_ids(segment_ids) + + return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) + + +@dataclass(frozen=True) +class PromptCompletion: + ids: list[int] + prompt_length: int + segment_id: int | None = None + + +def pack_prompt_completions( + Pos: hax.Axis, + sequences: Iterable[PromptCompletion], + pad_token: int, + max_segments_per_example: int = 64, + max_buffered_examples: int = 64, +) -> Iterator[LmExample]: + """ + Packs a list of prompt completions into LmExamples using the SequencePacker + """ + + packers = [SequencePacker(Pos, max_segments_per_example, pad_token)] + + for sequence in sequences: + loss_mask = np.arange(len(sequence.ids)) >= sequence.prompt_length - 1 + loss_mask[-1] = 0 + assert np.any(loss_mask) + + for packer in packers: + if packer.can_pack(sequence.ids): + packer.add_example(sequence.ids, loss_mask, sequence.segment_id) + + if packer.num_segments == max_segments_per_example: + yield packer.pack() + packers.remove(packer) + break + else: + # no packer could fit the example, create a new one + packer = SequencePacker(Pos, max_segments_per_example, pad_token) + packer.add_example(sequence.ids, loss_mask, sequence.segment_id) + packers.append(packer) + + while len(packers) >= max_buffered_examples: + yield packers.pop(0).pack() + + for packer in packers: + yield packer.pack() + + +def per_segment_loss( + packed_example: LmExample, losses: hax.NamedArray, max_Segments: hax.Axis +) -> tuple[hax.NamedArray, hax.NamedArray]: + """ + Returns a pair of arrays of shape (Segments,), where: + + * the first array is segment ids + * the second is loss per segment. 
+ + This code is designed to run in a jit-compiled function, meaning we have to careful of shapes + """ + + assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask" + + segment_ids = packed_example.attn_mask.segment_ids + assert ( + segment_ids.ndim == 1 + ), f"Expected segment_ids to be 1D, got {segment_ids.ndim}. Use vmap if you have multiple examples" + Pos = packed_example.tokens.axes[0] + + # mask out padding etc + masked_losses = losses * packed_example.loss_mask + + # sum the losses for each segment + unique_segment_ids = _unique_segment_ids(max_Segments, segment_ids) + + # Create a mask matrix where each row corresponds to a unique segment + segment_mask = unique_segment_ids == segment_ids.broadcast_axis(max_Segments) + + segment_mask = segment_mask.astype(masked_losses.dtype) + + segment_losses = hax.dot(segment_mask, masked_losses, axis=Pos) + + return unique_segment_ids, segment_losses + + +def _unique_segment_ids(max_Segments, segment_ids): + # Extract unique segment IDs with padding + # TODO: add unique to haliax + unique_segment_ids = jnp.unique(segment_ids.array, size=max_Segments.size, fill_value=-1) + unique_segment_ids = hax.named(unique_segment_ids, max_Segments) + return unique_segment_ids + + +def per_segment_correct( + packed_example: LmExample, correct: hax.NamedArray, max_Segments: hax.Axis +) -> tuple[hax.NamedArray, hax.NamedArray]: + """ + Returns a pair of arrays of shape (max_segments,), where: + + * the first array is segment ids + * the second is whether all tokens in the segment are correct. + + This code is designed to run in a jit-compiled function, meaning we have to careful of shapes + + correct is a boolean array of the same shape as the losses array indicating whether the token was correct + """ + + assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask" + + segment_ids = packed_example.attn_mask.segment_ids + assert ( + segment_ids.ndim == 1 + ), f"Expected segment_ids to be 1D, got {segment_ids.ndim}. 
Use vmap if you have multiple examples" + + Pos = packed_example.tokens.axes[0] + + # mask out padding etc + masked_correct = hax.logical_or(correct, hax.logical_not(packed_example.loss_mask)) + + # sum the losses for each segment + # Extract unique segment IDs with padding + unique_segment_ids = _unique_segment_ids(max_Segments, segment_ids) + + # Create a mask matrix where each row corresponds to a unique segment + segment_mask = unique_segment_ids == segment_ids.broadcast_axis(max_Segments) + + segment_mask = segment_mask.astype(masked_correct.dtype) + + segment_correct = hax.all(hax.where(segment_mask, masked_correct, True), axis=Pos) + + return unique_segment_ids, segment_correct diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index 4821f4d82..c3c33ac74 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -17,29 +17,35 @@ """ import dataclasses -import functools import json import logging import tempfile import typing from dataclasses import dataclass from functools import cached_property -from typing import List, Optional, Sequence, Tuple +from typing import Iterator, List, Optional, Tuple import equinox as eqx import jax import jax.numpy as jnp import jmp import numpy as np +from jax.sharding import PartitionSpec +from optax.tree_utils import tree_zeros_like import haliax from haliax import NamedArray import levanter.tracker from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer +from levanter.data.loader import stack_tree +from levanter.data.packing import PromptCompletion, pack_prompt_completions, per_segment_correct, per_segment_loss +from levanter.models.attention import AttentionMask from levanter.models.gpt2 import Gpt2Config from levanter.models.loss import next_token_loss +from levanter.utils.background_iterable import BackgroundIterator from levanter.utils.hf_utils import HfTokenizer +from levanter.utils.py_utils import set_global_rng_seeds try: @@ -58,116 +64,60 @@ import levanter.config from levanter.checkpoint import load_checkpoint -from levanter.data import AsyncDataset, DataLoader +from levanter.data import batched from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel from levanter.trainer import StepInfo, TrainerConfig -from levanter.utils.jax_utils import use_cpu_device +from levanter.utils.jax_utils import broadcast_shard, use_cpu_device from levanter.utils.tree_utils import inference_mode logger = logging.getLogger(__name__) -class EvalDataset(AsyncDataset[LmExample]): - def __init__(self, Pos, tokenizer, examples: list[Instance]): - super().__init__() - self.examples = examples - self.Pos = Pos - self.tokenizer = tokenizer - - async def async_len(self) -> int: - return len(self.examples) - - async def final_length_is_known(self) -> bool: - return True - - def is_finite(self) -> bool: - return True - - async def current_len(self) -> Optional[int]: - return len(self.examples) - - async def get_batch(self, indices: Sequence[int]) -> List[LmExample]: - out = [] - pad_token_id = self.tokenizer.pad_token_id - - # lm-harness specs that args are (context, completion) - reqs = [(self.examples[i].args[0], self.examples[i].args[1]) for i in indices] +# OK, so LM-Eval-Harness is not deterministic. This means we can't just run it on different workers and expect the +# order of requests to be the same. Sorting doesn't even seem to be correct (?!?!?) so we need to only run it on one +# process. +# This is our design: +# 1. Process 0 creates an LevanterHarnessLM object. +# 2. 
On all processes, we start a loop that waits for a request using jnp_broadcast_one_to_all +# 3. When a request is received (and it's not STOP) we process the request. The results are broadcast to all +# devices, and process 0 records htem. +# 4. When a STOP request is received, we stop the loop and process 0 returns the results. - for context, completion in reqs: - # it's kinda annoying we run tokenization twice, but it's the easiest way to get the prompt length - # CF: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/api/model.py#L354 - whole_enc = self.tokenizer(context + completion) - context_enc = self.tokenizer(context) - - context_enc_len = len(context_enc["input_ids"]) - - tokens, length = self._truncate_or_pad(whole_enc, context_enc_len) - example = _jit_create_example(self.Pos, tokens, length, pad_token_id) - - out.append(example) - - return out - - def _truncate_or_pad(self, encoded: list[int], prompt_length: int): - """ - Truncate or pad the encoded sequence to the maximum length of the model. - Truncates from the beginning of the sequence, so that the completion is preserved. - - Returns: - Truncated or padded sequence and the prompt length. The prompt length can be shorter than the original - length if the input was truncated. - """ - if self.tokenizer.pad_token_id is None: - self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - - ex_pad = self.tokenizer.pad( - encoded, - padding="max_length", - max_length=self.Pos.size, - return_tensors="np", - ) - - truncated = ex_pad["input_ids"][-self.Pos.size :] - # if we truncated the prompt, we need to adjust the prompt length - if len(truncated) < len(encoded): - prompt_length -= len(encoded) - len(truncated) - if prompt_length < 0: - prompt_length = 0 - logger.warning("Prompt length is negative after truncation. Setting to 0.") - - return truncated, prompt_length +class _LmEvalHarnessWorker: + """ + Worker for running the LM Eval Harness. Each worker process will run a copy of this class. + The head process will run the main harness and dispatch requests to the workers while the + others run in a loop waiting for requests. + """ -class LevanterHarnessLM(LM): - def __init__( - self, - EvalBatch: hax.Axis, - EvalPos: hax.Axis, - model: LmHeadModel, - axis_resources, - tokenizer, - mp: jmp.Policy | None, - ): - super().__init__() + def __init__(self, EvalBatch, EvalPos, model, axis_resources, tokenizer, mp, max_packed_segments): + self.tokenizer = tokenizer + self.max_packed_segments = max_packed_segments self.EvalBatch = EvalBatch self.EvalPos = EvalPos self.model = model self.axis_resources = axis_resources - self.tokenizer = tokenizer self.mp = mp + self.max_packed_segments = max_packed_segments + + self._dummy_batch = _make_dummy_batch(EvalBatch, EvalPos) - def _eval_loglikelihood(model: LmHeadModel, example: LmExample) -> tuple[NamedArray, NamedArray]: + def _eval_loglikelihood( + model: LmHeadModel, packed_example: LmExample + ) -> tuple[NamedArray, NamedArray, NamedArray]: """ Returns: - - loss: The negative log-likelihood of the completion. - - correct: Whether the completion is correct + - segments: The segment IDs of the completions. (shape: (Segments,)) + - loss: The log-likelihood of the completion. (shape: (Segments,)) + - correct: Whether the completion is correct or not. 
(shape: (Segments,)) """ if self.mp is not None: model = self.mp.cast_to_compute(model) - logits = model(example.tokens, attn_mask=example.attn_mask) + logits = model(packed_example.tokens, attn_mask=packed_example.attn_mask) logits = logits.astype(jnp.float32) Pos = logits.resolve_axis(self.EvalPos.name) @@ -175,61 +125,214 @@ def _eval_loglikelihood(model: LmHeadModel, example: LmExample) -> tuple[NamedAr Pos=Pos, Vocab=model.Vocab, logits=logits, - true_ids=example.tokens, - loss_mask=example.loss_mask, - reduction=hax.sum, - reduction_axis=Pos, + true_ids=packed_example.tokens, + loss_mask=packed_example.loss_mask, + reduction=None, ) - not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=bool) + # We need to compute losses and also whether or not the completion is correct + # (i.e. the greedy prediction is the target) pred_targets = hax.argmax(logits, axis=model.Vocab) - targets = hax.roll(example.tokens, -1, axis=Pos) - # "freebie" is the positions we don't need to predict (prompt or final token's next token) - freebie = hax.logical_not(example.loss_mask * not_last_loss_mask) - correct = hax.all(hax.equal(pred_targets, targets) + freebie, axis=Pos) + targets = hax.roll(packed_example.tokens, -1, axis=Pos) + is_correct = targets == pred_targets + + # we need + 1 because we use -1 as a padding value for segments + max_Segments = hax.Axis("Segments", size=self.max_packed_segments + 1) + + batched_segment_ids, batched_per_segment_losses = hax.vmap(per_segment_loss, self.EvalBatch)( + packed_example, loss, max_Segments + ) + + _, batched_per_segment_correct = hax.vmap(per_segment_correct, self.EvalBatch)( + packed_example, is_correct, max_Segments + ) - return -loss, correct + segments = hax.flatten(batched_segment_ids, "segment") + losses = hax.flatten(batched_per_segment_losses, "segment") + correct = hax.flatten(batched_per_segment_correct, "segment") + + return segments, -losses, correct # no sharded outputs self._jit_loglikelihood = hax.named_jit( _eval_loglikelihood, axis_resources=axis_resources, out_axis_resources={} ) + def make_harness_lm(self): + if jax.process_index() == 0: + return LevanterHarnessLM(self) + else: + raise ValueError("Only process 0 can create the harness") + + def worker_message_loop(self): + while True: + message = self._receive_message() + + if message == _Message.STOP: + return + elif message == _Message.LOGLIKELIHOOD: + payload = self._receive_payload() + self.process_loglikelihood(payload) + else: + raise ValueError(f"Unknown message type: {message}") + + def _receive_message(self): + stop_message = jnp.array(_Message.STOP) + message = broadcast_shard(stop_message, PartitionSpec()) + return message.item() + + def _receive_payload(self): + payload = broadcast_shard( + self._dummy_batch, + hax.partitioning.infer_resource_partitions(self._dummy_batch, preserve_existing_shardings=False), + ) + return payload + + def _send_message(self, message): + assert jax.process_index() == 0 + out = broadcast_shard(jnp.array(message), PartitionSpec()) + return out + + def _send_payload(self, payload): + assert jax.process_index() == 0 + out = broadcast_shard( + payload, hax.partitioning.infer_resource_partitions(payload, preserve_existing_shardings=False) + ) + return out + + def process_loglikelihood(self, packed_request): + out = self._jit_loglikelihood(self.model, packed_request) + return out + + def dispatch_loglikelihood(self, packed_request): + self._send_message(_Message.LOGLIKELIHOOD) + self._send_payload(packed_request) + return 
self.process_loglikelihood(packed_request) + + def stop(self): + self._send_message(_Message.STOP) + + +class _Message: + STOP = 0 + LOGLIKELIHOOD = 1 + + +def _get_segments_this_batch(batch, max_segments_per_ex): + unique_segs = np.unique(batch.attn_mask.segment_ids.array).tolist() + # + 1 because we use -1 as a padding value for segments and allow that + if len(unique_segs) > max_segments_per_ex + 1: + raise ValueError(f"Too many segments in batch: {len(unique_segs)}") + if -1 in unique_segs: + unique_segs.remove(-1) + + return unique_segs + + +def _get_padding_count(batch, pad_token_id): + # returns the total amount of padding in the batch + padding_count = np.sum(batch.tokens.array == pad_token_id) + total_tokens = batch.tokens.size + return padding_count, total_tokens + + +class LevanterHarnessLM(LM): + def __init__(self, leader: _LmEvalHarnessWorker): + super().__init__() + self.leader = leader + + tokenizer = property(lambda self: self.leader.tokenizer) + EvalBatch = property(lambda self: self.leader.EvalBatch) + EvalPos = property(lambda self: self.leader.EvalPos) + def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: """ Compute log-likelihood of generating a continuation from a context. Downstream tasks should attempt to use loglikelihood instead of other LM calls whenever possible. """ - # pad requests to be a multiple of the batch size - initial_length = len(requests) - dataset = self._pad_dataset_to_batch_size(requests) + if self.tokenizer.pad_token_id is None: + logger.warning("No pad token set. Setting to eos token.") + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - mesh = haliax.partitioning._get_mesh() + packed_iterator = _pack_requests(requests, self.tokenizer, self.EvalPos, self.leader.max_packed_segments) + packed_iterator = self.stack_batches(packed_iterator, self.EvalBatch) + packed_iterator = BackgroundIterator(packed_iterator, max_capacity=1024) - loader = DataLoader( - self.EvalBatch, dataset, max_buffered_batches=1024, mesh=mesh, axis_resources=self.axis_resources - ) + result_probs = np.zeros(len(requests)) + result_greedy = np.zeros(len(requests)) + covered_points = np.zeros(len(requests), dtype=bool) - result: list[tuple[float, bool]] = [] - for batch in tqdm(loader, desc="Loglikelihood", unit="ba"): - out_lls, out_correct = self._jit_loglikelihood(self.model, batch) - result.extend((ll.item(), correct.item()) for ll, correct in zip(out_lls.array, out_correct.array)) + total_padding = 0 + total_tokens = 0 + pbar = tqdm(total=len(requests), desc="Loglikelihood", unit="req") + for q, batch in enumerate(packed_iterator): + segments_this_batch = _get_segments_this_batch( + batch, self.leader.max_packed_segments * self.EvalBatch.size + ) + + padding_count, batch_tokens = _get_padding_count(batch, self.tokenizer.pad_token_id) + + out_ids, out_lls, out_correct = self.leader.dispatch_loglikelihood(batch) + + out_ids = np.array(out_ids.array) + out_lls = np.array(out_lls.array) + out_correct = np.array(out_correct.array) + # -1's are going to be where we had too few sequences to fill a batch + valid_indices = out_ids != -1 - assert len(result) >= initial_length - # skip padding - result = result[:initial_length] + out_ids_this_batch = out_ids[valid_indices].tolist() + missing_ids = set(segments_this_batch) - set(out_ids_this_batch) + extra_ids = set(out_ids_this_batch) - set(segments_this_batch) + assert len(missing_ids) == 0, f"Missing segments: {missing_ids}" + assert len(extra_ids) == 0, f"Extra segments: {extra_ids}" + + 
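+            # segment ids are the original request indices (assigned in _iterate_tokenized_requests),
+            # so each segment's loss / correctness can be scattered straight back to its request slot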
result_probs[out_ids[valid_indices]] = out_lls[valid_indices] + result_greedy[out_ids[valid_indices]] = out_correct[valid_indices] + covered_points[out_ids[valid_indices]] = True + + pbar.set_postfix( + padding=( + f"{total_padding + padding_count}/{total_tokens + batch_tokens} =" + f" {(total_padding + padding_count) / (total_tokens + batch_tokens):.2f}" + ), + this_padding=f"{padding_count}/{batch_tokens}= {padding_count / batch_tokens:.2f}", + ) + pbar.update(len(segments_this_batch)) + + missing_points = np.where(~covered_points)[0] + assert len(missing_points) == 0, f"Missing points: {missing_points}" + + result = list(zip(result_probs, result_greedy)) logger.info(f"Finished running {len(requests)} loglikelihoods.") return result - def _pad_dataset_to_batch_size(self, requests): - dummy_instance = dataclasses.replace(requests[0], arguments=("hello", " there"), idx=len(requests)) - requests = requests + [dummy_instance] * (self.EvalBatch.size - len(requests) % self.EvalBatch.size) - assert len(requests) % self.EvalBatch.size == 0 - dataset = EvalDataset(self.EvalPos, self.tokenizer, requests) - return dataset + def stack_batches(self, example_iterator, EvalBatch): + """ + Stack examples from an iterator into a batch. + + Args: + EvalBatch: The batch axis. + example_iterator: An iterator of examples. + + Returns: + A batch of examples. + """ + with use_cpu_device(): + for batch in batched(example_iterator, EvalBatch.size): + if len(batch) < EvalBatch.size: + dummy_instance = self._make_dummy_instance(batch) + batch.extend([dummy_instance] * (EvalBatch.size - len(batch))) + yield stack_tree(EvalBatch, batch) + + def _make_dummy_instance(self, batch): + dummy_instance: LmExample = tree_zeros_like(batch[0]) + dummy_segment_mask = hax.full(self.EvalPos, -1, dtype=jnp.int32) + dummy_attn = AttentionMask.causal().with_segment_ids(dummy_segment_mask) + dummy_instance = dataclasses.replace(dummy_instance, attn_mask=dummy_attn) + return dummy_instance def loglikelihood_rolling(self, requests) -> List[Tuple[float]]: raise NotImplementedError() @@ -238,12 +341,6 @@ def generate_until(self, requests) -> List[str]: raise NotImplementedError() -@functools.partial(jax.jit, static_argnums=(0, 3)) -def _jit_create_example(Pos, tokens, prompt_len, pad_token_id): - tokens = hax.named(tokens, Pos) - return LmExample.from_prompt_and_completion(Pos, tokens, prompt_len, ignore_id=pad_token_id) - - @dataclass(frozen=True) class TaskConfig: """ @@ -373,7 +470,15 @@ def run_lm_eval_harness( EvalBatch, axis_resources, mp: jmp.Policy | None, -) -> dict: +) -> dict | None: + """ + Run the LM Eval Harness on the given model and tasks. + + Returns: + If running on process 0, returns the outputs of the LM Eval Harness with the following extra keys. + - "averages": A dictionary with macro and micro averages for all metrics. + Otherwise, returns None. + """ tasks_to_run = config.to_task_dict() outputs = _actually_run_eval_harness(config, model, tasks_to_run, tokenizer, EvalBatch, axis_resources, mp) @@ -389,7 +494,7 @@ def _actually_run_eval_harness( EvalBatch: haliax.Axis, axis_resources: ResourceMapping, mp: jmp.Policy | None, -): +) -> dict | None: """ Actually run the LM Eval Harness on the given model and tasks. This is a separate function so that it can be used by the main function and the callback function. @@ -408,21 +513,37 @@ def _actually_run_eval_harness( f"Evaluating with max eval length {EvalPos.size} and batch size {EvalBatch.size}. There are" f" {num_parameters} parameters in the model." 
) - harness = LevanterHarnessLM(EvalBatch, EvalPos, model, axis_resources, tokenizer, mp) logger.info("Running eval harness...") - outputs = evaluator.evaluate( - harness, - tasks_to_run, - limit=max_examples, - log_samples=config.log_samples, - bootstrap_iters=config.bootstrap_iters, - ) - logger.info("Finished running eval harness.") - averages = _compute_averages(outputs) - outputs["averages"] = averages + worker = _LmEvalHarnessWorker(EvalBatch, EvalPos, model, axis_resources, tokenizer, mp, max_packed_segments=64) + + if jax.process_index() == 0: + logger.info("Process 0 is running the eval harness.") + harness = worker.make_harness_lm() + + # eval_harness only sets seeds in simple_evaluate, which we can't use (I think?) + tasks_to_run = _adjust_config(tasks_to_run, 0) + with set_global_rng_seeds(0): + outputs = evaluator.evaluate( + harness, + tasks_to_run, + limit=max_examples, + log_samples=config.log_samples, + bootstrap_iters=config.bootstrap_iters, + ) + worker.stop() + + averages = _compute_averages(outputs) + outputs["averages"] = averages - return outputs + return outputs + else: + logger.info(f"Process {jax.process_index()} is waiting for eval harness requests from process 0.") + worker.worker_message_loop() + + logger.info("Finished running eval harness.") + + return None def _compute_averages(outputs): @@ -471,100 +592,6 @@ def _compute_averages(outputs): return averages -BITS_PER_NAT = 1 / np.log(2) - -# eval_harness isn't consistent enough for this to actually be workable -# def _compute_extra_metrics(samples): -# """ -# Compute a few "soft" measures of accuracy for each task, based on the outputs of the eval harness. -# -# Specifically, we compute: -# - "bpb": bits per byte of the correct completion -# - "logprob": log probability of the correct completion -# - "choice_logprob": log probability of the correct choice normalized w.r.t. the other choices -# - "choice_prob_norm": probability of the length-normalized correct choice normalized w.r.t. the other choices -# -# Args: -# samples: Dictionary with task data, where each task has a list of samples. Each sample contains: -# - "doc": The original task document (can include metadata such as the answer key) -# - "target": Index of the correct answer (0-indexed), or -# "doc.answer" for 1-indexed answers. -# - "arguments": List of [input, completion] pairs -# - "resps": List of [log probability, is_correct] pairs for completions -# -# Returns: -# A dictionary with per-task aggregated metrics. -# """ -# # TODO: move to eval harness and use more sane logic -# # uses the samples which has one of two structures (that I've seen) -# # { "": [ {"doc": {...,}, "target": <0-indexed answer>, "arguments": [[input, completion], "resps": [[score, is_correct], ...], ...}, ...] } -# # { "": [ {"doc": {..., "answer": "[1-indexed answer]"}, "target": "", "arguments": [input, completion], "resps": [[score, is_correct], ...], ...}, ...] 
} -# metrics = {} -# -# for task, samples in samples.items(): -# bpb_list = [] -# logprob_list = [] -# choice_logprob_list = [] -# choice_prob_norm_list = [] -# -# for sample in samples: -# # Extract the correct answer index (supporting both 0-indexed `target` and 1-indexed `doc.answer`) -# if "answer" in sample["doc"]: -# target = int(sample["doc"]["answer"]) - 1 # Convert 1-indexed to 0-indexed -# elif "label" in sample["doc"]: -# target = int(sample["doc"]["label"]) -# elif "target" in sample and isinstance(sample["target"], int): -# target = sample["target"] # 0-indexed target -# elif "target" in sample and isinstance(sample["target"], str): -# # see if it's A-Z: -# if len(sample["target"]) == 1 and "A" <= sample["target"] <= "Z": -# target = ord(sample["target"]) - ord("A") -# else: -# raise ValueError(f"Invalid target: {sample['target']}. {sample}") -# elif "target" in sample and isinstance(sample["target"], list): -# target = sample["target"][0] -# else: -# raise KeyError(f"Missing `target` or `doc.answer` in sample. doc id: {sample['doc_id']}. Hash: {sample['doc_hash']}\n\n{sample}") -# -# resps = sample["filtered_resps"] # List of [log probability, is_correct] -# arguments = sample["arguments"] # [input, completion] pairs -# -# # Compute byte lengths for each choice -# byte_lengths = [max(1, len(completion.encode("utf-8"))) for _, completion in arguments] -# -# # Compute log probabilities for each choice -# log_probs = np.array([resp[0] for resp in resps]) # Extract log probabilities -# assert log_probs.shape == (len(arguments),), f"Log probs shape: {log_probs.shape}, arguments: {len(arguments)}. doc: {sample}" -# normalized_log_probs = log_probs - np.logaddexp.reduce(log_probs) -# -# # Metrics for the correct answer -# correct_logprob = log_probs[target] -# correct_bpb = -correct_logprob / byte_lengths[target] * NAT_TO_BIT -# correct_choice_logprob = normalized_log_probs[target] -# -# # Compute length-normalized weights (w_i) -# bpb_values = -log_probs / np.array(byte_lengths) * NAT_TO_BIT -# bpb_weights = np.exp(-bpb_values) -# bpb_weights /= max(bpb_weights.sum(), 1e-8) # Avoid division by zero -# correct_choice_prob_norm = bpb_weights[target] -# -# # Append metrics -# bpb_list.append(correct_bpb) -# logprob_list.append(correct_logprob) -# choice_logprob_list.append(correct_choice_logprob) -# choice_prob_norm_list.append(correct_choice_prob_norm) -# -# # Aggregate metrics for the task -# metrics[task] = { -# "bpb": np.mean(bpb_list) if bpb_list else 0.0, -# "logprob": np.mean(logprob_list) if logprob_list else 0.0, -# "choice_logprob": np.mean(choice_logprob_list) if choice_logprob_list else 0.0, -# "choice_prob_norm": np.mean(choice_prob_norm_list) if choice_prob_norm_list else 0.0, -# } -# -# return metrics - - def run_eval_harness_main(config: EvalHarnessMainConfig): config.trainer.initialize() tokenizer = config.the_tokenizer @@ -611,20 +638,20 @@ def run_eval_harness_main(config: EvalHarnessMainConfig): logger.info("Finished running LM eval harness") # log the results - logger.info("Logging results to tracker") - log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker()) - logger.info("Finished logging results to tracker") - - # log the results as json - logger.info("uploading artifacts...") - with open("lm_eval_harness_results.json", "w") as f: - json.dump(outputs, f, indent=2) - f.flush() - f_path = f.name - levanter.tracker.current_tracker().log_artifact(f_path, name="lm_eval_harness_results") - - # also write to stdout if jax.process_index() == 0: + 
logger.info("Logging results to tracker") + assert outputs is not None + log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker()) + logger.info("Finished logging results to tracker") + + # log the results as json + logger.info("uploading artifacts...") + with open("lm_eval_harness_results.json", "w") as f: + json.dump(outputs, f, indent=2) + f.flush() + f_path = f.name + levanter.tracker.current_tracker().log_artifact(f_path, name="lm_eval_harness_results") + print(json.dumps(outputs, indent=2), flush=True) return outputs @@ -660,7 +687,7 @@ def lm_eval_harness(config: LmEvalHarnessConfig, tokenizer, EvalBatch, axis_reso def lm_eval_harness(step: StepInfo, force=False): if step.step == 0 and not force: - return # don't run eval on the first step + return model = inference_mode(step.model, True) logger.info("Running eval harness...") @@ -675,10 +702,11 @@ def lm_eval_harness(step: StepInfo, force=False): ) logger.info("Finished running eval harness.") - log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker()) - logger.info("Logged report to tracker") - if jax.process_index() == 0: + assert outputs is not None + log_report_to_tracker("lm_eval", outputs, levanter.tracker.current_tracker()) + logger.info("Logged report to tracker") + # don't delete b/c wandb will sometimes defer upload with tempfile.NamedTemporaryFile("w", delete=False, suffix=".json") as f: import json @@ -693,6 +721,91 @@ def lm_eval_harness(step: StepInfo, force=False): return lm_eval_harness +# lifted from lm-eval simple_evaluate +def _adjust_config(task_dict, fewshot_random_seed=0): + adjusted_task_dict = {} + for task_name, task_obj in task_dict.items(): + if isinstance(task_obj, dict): + adjusted_task_dict = { + **adjusted_task_dict, + **{task_name: _adjust_config(task_obj, fewshot_random_seed=fewshot_random_seed)}, + } + + else: + # fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file) + task_obj.set_fewshot_seed(seed=fewshot_random_seed) + + adjusted_task_dict[task_name] = task_obj + + return adjusted_task_dict + + +def _iterate_tokenized_requests( + requests: list[Instance], tokenizer: HfTokenizer, max_len: int, batch_size: int +) -> Iterator[PromptCompletion]: + """ + Tokenize the requests and yield them as PromptCompletions, for packing into LmExamples. + """ + # Separate contexts and completions + contexts = [request.args[0] for request in requests] + completions = [request.args[1] for request in requests] + + # Combine contexts and completions for full tokenization + combined_texts = [context + completion for context, completion in zip(contexts, completions)] + + # Batch tokenization for combined and context separately + for batch_indices in batched(range(len(requests)), batch_size): + # Extract batch data + combined_batch = [combined_texts[i] for i in batch_indices] + context_batch = [contexts[i] for i in batch_indices] + + # Tokenize batched inputs + combined_encodings = tokenizer(combined_batch, truncation=False, padding=False) + context_encodings = tokenizer(context_batch, truncation=False, padding=False) + + for off in range(len(batch_indices)): + i = batch_indices[off] + context_enc = context_encodings["input_ids"][off] + whole_ids = combined_encodings["input_ids"][off] + + context_enc_len = len(context_enc) + + if len(whole_ids) > max_len: + logger.warning(f"Request {i} is too long. 
Truncating.") + # Truncate from the left + whole_ids = whole_ids[-max_len:] + context_enc_len = max_len - len(whole_ids) + context_enc_len + if context_enc_len < 0: + context_enc_len = 0 + logger.warning("Prompt length is negative after truncation. Setting to 0.") + + yield PromptCompletion(ids=whole_ids, prompt_length=context_enc_len, segment_id=i) + + +def _pack_requests( + requests: list[Instance], tokenizer: HfTokenizer, Pos: hax.Axis, max_pack_size: int +) -> Iterator[LmExample]: + packed_iterator = _iterate_tokenized_requests(requests, tokenizer, Pos.size, batch_size=128) + # TODO: use a better packing algorithm? + yield from pack_prompt_completions( + Pos, + packed_iterator, + max_segments_per_example=max_pack_size, + pad_token=tokenizer.pad_token_id, + max_buffered_examples=16, + ) + + +def _make_dummy_batch(EvalBatch, EvalPos): + dummy_batch = hax.vmap(LmExample.causal, EvalBatch)( + hax.zeros(EvalPos, dtype=jnp.int32), + loss_mask=hax.zeros(EvalPos, dtype=jnp.int32), + segment_ids=hax.zeros(EvalPos, dtype=jnp.int32), + ) + out = hax.shard(dummy_batch, {}) + return out + + if __name__ == "__main__": levanter.config.main(run_eval_harness_main)() print("Done", flush=True) diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 7f5c0e3d8..7f92221ed 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -4,7 +4,6 @@ import draccus import equinox as eqx -import jax import jax.numpy as jnp from jax.random import PRNGKey @@ -31,6 +30,7 @@ def causal( loss_mask: Optional[hax.NamedArray] = None, ignore_id: Optional[int] = None, eos_id: Optional[int] = None, + segment_ids: Optional[hax.NamedArray] = None, ) -> "LmExample": if tokens.ndim != 1: raise ValueError("tokens must be a 1D array") @@ -40,24 +40,31 @@ def causal( Pos = tokens.axes[0] - # don't predict the last token. - if loss_mask is None: - loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) + causal_loss_mask = LmExample.causal_loss_mask(Pos) + + if loss_mask is not None: + loss_mask = loss_mask & causal_loss_mask.astype(loss_mask.dtype) + else: + loss_mask = causal_loss_mask if ignore_id is not None: # we don't compute loss for any tokens matching the ignore index ignore_mask = hax.roll(tokens, -1, Pos) != ignore_id loss_mask = loss_mask * ignore_mask + loss_mask = loss_mask.astype(jnp.int32) + attn_mask = AttentionMask.causal() - if eos_id is not None: + if eos_id is not None and segment_ids is None: # the next token after an eos token is in a new segment eos_mask = hax.roll(tokens, 1, Pos) == eos_id # first token is always in segment 0 eos_mask = eos_mask.at[Pos, 0].set(False).astype(jnp.int32) segment_ids = hax.cumsum(eos_mask, axis=Pos) attn_mask = attn_mask.with_segment_ids(segment_ids) + elif segment_ids is not None: + attn_mask = attn_mask.with_segment_ids(segment_ids) return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) @@ -70,24 +77,33 @@ def from_prompt_and_completion( ignore_id: Optional[int] = None, all_causal: bool = True, ) -> "LmExample": - # mask out the prompt tokens - loss_mask = hax.arange(Pos) >= prompt_length - 1 - # don't predict the padding - if ignore_id is not None: - targets = hax.roll(tokens, -1, Pos) - loss_mask = loss_mask & (targets != ignore_id) - - # don't predict the last token - loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_)) - if all_causal: attn_mask = AttentionMask.causal() else: # causal just for the completion part. 
We don't have a special structured mask for this, so we just raise NotImplementedError("Not implemented yet") + # mask out the prompt tokens + loss_mask = LmExample.causal_loss_mask(Pos, prompt_length=prompt_length) + + if ignore_id is not None: + # we don't compute loss for any tokens matching the ignore index + ignore_mask = hax.roll(tokens, -1, Pos) != ignore_id + loss_mask = loss_mask * ignore_mask + return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) + @staticmethod + def causal_loss_mask(Pos: Axis, prompt_length: Optional[int] = None) -> NamedArray: + loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) + + if prompt_length is not None: + # don't predict the prompt tokens + prompt_mask = hax.arange(Pos) >= prompt_length - 1 + loss_mask = loss_mask * prompt_mask + + return loss_mask + # TODO: for some reason, mypy doesn't like the discover_packages_path argument? @dataclass(frozen=True) diff --git a/src/levanter/utils/background_iterable.py b/src/levanter/utils/background_iterable.py index 11a80f8ec..7e51efb10 100644 --- a/src/levanter/utils/background_iterable.py +++ b/src/levanter/utils/background_iterable.py @@ -34,9 +34,16 @@ def __iter__(self): class BackgroundIterator(Iterator[Ex]): - def __init__(self, producer_fn: Callable[[], Union[Iterator[Ex], AsyncIterator[Ex]]], max_capacity: Optional[int]): + def __init__( + self, + producer_fn: Callable[[], Iterator[Ex] | AsyncIterator[Ex]] | Iterator[Ex] | AsyncIterator[Ex], + max_capacity: Optional[int], + ): self.max_capacity = max_capacity - self._producer_fn = producer_fn + if not callable(producer_fn): + self._producer_fn = lambda: producer_fn + else: + self._producer_fn = producer_fn self._stop_event = threading.Event() if self.max_capacity is None or self.max_capacity >= 0: diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index be77a1d99..5f7b7a798 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -15,6 +15,7 @@ import haliax as hax from haliax import is_named_array +from haliax._src.util import index_where from haliax.jax_utils import is_jax_array_like from haliax.partitioning import ResourceAxis, ResourceMapping @@ -42,7 +43,7 @@ def use_cpu_device(): @contextlib.contextmanager def local_cpu_mesh(): - """Temporarily sets the default device to CPU""" + """Temporarily sets the default device to CPU and creates a mesh with a single CPU device""" cpu = jax.local_devices(backend="cpu")[0] mesh = jax.sharding.Mesh( np.array([cpu]).reshape(1, 1, 1), (ResourceAxis.REPLICA, ResourceAxis.DATA, ResourceAxis.MODEL) @@ -369,3 +370,62 @@ def _zeros_like(mapping, dtype, n): return n - n else: return jnp.zeros((), dtype=dtype) + + +def broadcast_shard(x: T, out_axis_specs: Any, source: int = 0) -> T: + """ + Given a tree of arrays that are on a single source host, and other data (e.g. zeros) with + the same structure, broadcast and shard the data to all hosts, using the axis mapping provided. + + For some reason, I had a ton of trouble figuring this out. + + Our strategy is, for each leaf: + 1. create a host_local_array_to_global_array with the data if we're the source, or zeros if we're not. + This gives us an array [num_devices, ...] + 2. 
Then, inside jit, we select the source'th element of the array, then reshard with the out_axis_specs + + """ + if jax.process_count() == 1: + return x + + current_mesh: jax.sharding.Mesh = hax.partitioning._get_mesh() + + axis_names = current_mesh.axis_names + + valid_device_for_process = index_where(lambda d: d.host_id == source, current_mesh.devices.flatten()) + sharding = NamedSharding( + current_mesh, + PartitionSpec( + axis_names, + ), + ) + + def pre_jit(x): + if jax.process_index() == source: + inp = np.array(x) + else: + inp = jnp.zeros(x.shape, dtype=x.dtype) + + shape = (len(jax.devices()),) + inp.shape + inp = jnp.expand_dims(inp, axis=0) + out = jax.make_array_from_callback(shape, sharding, lambda _: inp) + + return out + + def in_jit(x, pspec): + if isinstance(x, hax.NamedArray): + arr = x.array + else: + arr = x + arr = jax.lax.with_sharding_constraint(arr[valid_device_for_process], pspec) + + if isinstance(x, hax.NamedArray): + return hax.named(arr, x.axes) + else: + return arr + + x = jax.tree.map(pre_jit, x) + # q = eqx.filter_jit(jax.tree.map).lower(in_jit, x, out_axis_specs, is_leaf=is_named_array).as_text() + out = eqx.filter_jit(jax.tree.map)(in_jit, x, out_axis_specs, is_leaf=is_named_array) + + return out diff --git a/src/levanter/utils/py_utils.py b/src/levanter/utils/py_utils.py index dab038452..f3fd4f631 100644 --- a/src/levanter/utils/py_utils.py +++ b/src/levanter/utils/py_utils.py @@ -1,3 +1,4 @@ +import contextlib import os import sys import time @@ -121,3 +122,34 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.stop() + + +@contextlib.contextmanager +def set_global_rng_seeds(seed): + import numpy as np + + current_np_seed = np.random.get_state() + np.random.seed(seed) + + import random + + current_random_seed = random.getstate() + random.seed(seed) + + try: + import torch + + current_torch_seed = torch.random.get_rng_state() + torch.manual_seed(seed) + except ImportError: + torch = None + current_torch_seed = None + pass + + try: + yield + finally: + np.random.set_state(current_np_seed) + random.setstate(current_random_seed) + if current_torch_seed is not None: + torch.random.set_rng_state(current_torch_seed) diff --git a/tests/test_packing.py b/tests/test_packing.py new file mode 100644 index 000000000..02e3fffd6 --- /dev/null +++ b/tests/test_packing.py @@ -0,0 +1,223 @@ +import jax.numpy as jnp +import numpy as np +import pytest + +import haliax as hax + +from levanter.data.packing import ( + PromptCompletion, + SequencePacker, + pack_prompt_completions, + per_segment_correct, + per_segment_loss, +) +from levanter.models.attention import AttentionMask +from levanter.models.lm_model import LmExample + + +def test_per_segment_loss(): + Pos = hax.Axis("pos", size=10) + packer = SequencePacker(Pos=Pos, max_pack_size=10, pad_token=0) + + # Add two sequences + packer.add_example(ids=[1, 2, 3], loss_mask=[1, 1, 1], segment_id=None) + packer.add_example(ids=[4, 5], loss_mask=[1, 1], segment_id=None) + + # Pack into LmExample + packed = packer.pack() + + losses = hax.named(jnp.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0]), Pos) + + Segments = hax.Axis("segments", size=3) + + unique_ids, segment_losses = per_segment_loss(packed, losses, max_Segments=Segments) + + assert list(unique_ids.array) == [-1, 0, 1] + assert list(segment_losses.array) == [0.0, 0.6, 0.9] + + +def test_can_pack_simple_case(): + Pos = hax.Axis("pos", size=10) + packer = SequencePacker(Pos=Pos, max_pack_size=2, pad_token=0) + + assert packer.can_pack([1, 2, 
3]) is True + packer.add_example(ids=[1, 2, 3], loss_mask=[1, 1, 1]) + assert packer.can_pack([4, 5]) is True + assert packer.can_pack(list(range(6, 16))) is False # Exceeds Pos size + + +def test_add_example_and_pack(): + Pos = hax.Axis("pos", size=10) + packer = SequencePacker(Pos=Pos, max_pack_size=2, pad_token=0) + + packer.add_example([1, 2, 3], [1, 1, 1]) + packed = packer.pack() + + expected_tokens = [1, 2, 3, 0, 0, 0, 0, 0, 0, 0] + expected_segment_ids = [0, 0, 0, -1, -1, -1, -1, -1, -1, -1] + expected_loss_mask = [1, 1, 1, 0, 0, 0, 0, 0, 0, 0] + + np.testing.assert_array_equal(packed.tokens.array, expected_tokens) + np.testing.assert_array_equal(packed.attn_mask.segment_ids.array, expected_segment_ids) + np.testing.assert_array_equal(packed.loss_mask.array, expected_loss_mask) + + +def test_exceed_max_pack_size(): + Pos = hax.Axis("pos", size=10) + packer = SequencePacker(Pos=Pos, max_pack_size=2, pad_token=0) + + packer.add_example([1, 2, 3], [1, 1, 1]) + packer.add_example([4, 5, 6], [1, 1, 1]) + + with pytest.raises(ValueError, match="Too many segments"): + packer.add_example([7, 8], [1, 1]) # Exceeds max pack size + + +def test_empty_sequence(): + Pos = hax.Axis("pos", size=10) + packer = SequencePacker(Pos=Pos, max_pack_size=2, pad_token=0) + + with pytest.raises(ValueError, match="ids and loss_mask must have the same length"): + packer.add_example([], [1]) # Mismatched lengths + + packer.add_example([], []) # Adding an empty sequence is allowed but does nothing + packed = packer.pack() + + expected_tokens = [0] * 10 + expected_segment_ids = [-1] * 10 + expected_loss_mask = [0] * 10 + + np.testing.assert_array_equal(packed.tokens.array, expected_tokens) + np.testing.assert_array_equal(packed.attn_mask.segment_ids.array, expected_segment_ids) + np.testing.assert_array_equal(packed.loss_mask.array, expected_loss_mask) + + +def test_packing_multiple_examples(): + Pos = hax.Axis("pos", size=10) + packer = SequencePacker(Pos=Pos, max_pack_size=2, pad_token=0) + + # First example + packer.add_example([1, 2], [1, 1]) + # Second example + packer.add_example([3, 4, 5], [1, 1, 1]) + + packed = packer.pack() + + expected_tokens = [1, 2, 3, 4, 5, 0, 0, 0, 0, 0] + expected_segment_ids = [0, 0, 1, 1, 1, -1, -1, -1, -1, -1] + expected_loss_mask = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0] + + np.testing.assert_array_equal(packed.tokens.array, expected_tokens) + np.testing.assert_array_equal(packed.attn_mask.segment_ids.array, expected_segment_ids) + np.testing.assert_array_equal(packed.loss_mask.array, expected_loss_mask) + + +def test_pack_prompt_completions_simple(): + Pos = hax.Axis("pos", size=10) + pad_token = 0 + max_pack_size = 2 + max_buffered_examples = 2 + + sequences = [ + PromptCompletion(ids=[1, 2, 3], prompt_length=2), + PromptCompletion(ids=[4, 5], prompt_length=1), + PromptCompletion(ids=[6, 7, 8], prompt_length=1), + ] + + results = list(pack_prompt_completions(Pos, sequences, pad_token, max_pack_size, max_buffered_examples)) + + assert len(results) == 2 # Expect two packed LmExamples + + # Check the first packed example + packed_1 = results[0] + expected_tokens_1 = [1, 2, 3, 4, 5, 0, 0, 0, 0, 0] + expected_segment_ids_1 = [0, 0, 0, 1, 1, -1, -1, -1, -1, -1] + expected_loss_mask_1 = [0, 1, 0, 1, 0, 0, 0, 0, 0, 0] + + np.testing.assert_array_equal(packed_1.tokens.array, expected_tokens_1) + np.testing.assert_array_equal(packed_1.attn_mask.segment_ids.array, expected_segment_ids_1) + np.testing.assert_array_equal(packed_1.loss_mask.array, expected_loss_mask_1) + + # Check the second 
packed example + packed_2 = results[1] + expected_tokens_2 = [6, 7, 8, 0, 0, 0, 0, 0, 0, 0] + expected_segment_ids_2 = [0, 0, 0, -1, -1, -1, -1, -1, -1, -1] + expected_loss_mask_2 = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0] + + np.testing.assert_array_equal(packed_2.tokens.array, expected_tokens_2) + np.testing.assert_array_equal(packed_2.attn_mask.segment_ids.array, expected_segment_ids_2) + np.testing.assert_array_equal(packed_2.loss_mask.array, expected_loss_mask_2) + + +def test_pack_prompt_completions_exceed_max_buffered_examples(): + Pos = hax.Axis("pos", size=10) + pad_token = 0 + max_pack_size = 1 + max_buffered_examples = 1 + + sequences = [ + PromptCompletion(ids=[1, 2, 3], prompt_length=2), + PromptCompletion(ids=[4, 5], prompt_length=1), + PromptCompletion(ids=[6, 7, 8], prompt_length=1), + ] + + results = list(pack_prompt_completions(Pos, sequences, pad_token, max_pack_size, max_buffered_examples)) + + assert len(results) == 3 + + # Check the first packed example + packed_1 = results[0] + expected_tokens_1 = [1, 2, 3, 0, 0, 0, 0, 0, 0, 0] + expected_segment_ids_1 = [0, 0, 0, -1, -1, -1, -1, -1, -1, -1] + expected_loss_mask_1 = [0, 1, 0, 0, 0, 0, 0, 0, 0, 0] + + np.testing.assert_array_equal(packed_1.tokens.array, expected_tokens_1) + np.testing.assert_array_equal(packed_1.attn_mask.segment_ids.array, expected_segment_ids_1) + np.testing.assert_array_equal(packed_1.loss_mask.array, expected_loss_mask_1) + + # Check the second packed example + packed_2 = results[1] + expected_tokens_2 = [4, 5, 0, 0, 0, 0, 0, 0, 0, 0] + expected_segment_ids_2 = [0, 0, -1, -1, -1, -1, -1, -1, -1, -1] + expected_loss_mask_2 = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0] + + np.testing.assert_array_equal(packed_2.tokens.array, expected_tokens_2) + np.testing.assert_array_equal(packed_2.attn_mask.segment_ids.array, expected_segment_ids_2) + np.testing.assert_array_equal(packed_2.loss_mask.array, expected_loss_mask_2) + + # Check the third packed example + packed_3 = results[2] + expected_tokens_3 = [6, 7, 8, 0, 0, 0, 0, 0, 0, 0] + expected_segment_ids_3 = [0, 0, 0, -1, -1, -1, -1, -1, -1, -1] + expected_loss_mask_3 = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0] + + np.testing.assert_array_equal(packed_3.tokens.array, expected_tokens_3) + np.testing.assert_array_equal(packed_3.attn_mask.segment_ids.array, expected_segment_ids_3) + np.testing.assert_array_equal(packed_3.loss_mask.array, expected_loss_mask_3) + + +def test_segment_correct(): + # Mock segment_ids and loss_mask + Pos = hax.Axis("pos", size=10) + tokens = hax.named(jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), Pos) + segment_ids = hax.named(jnp.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2]), Pos) + loss_mask = hax.named(jnp.array([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]), Pos) + + # Create a packed example + attn_mask = AttentionMask.causal().with_segment_ids(segment_ids=segment_ids) + packed_example = LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) + + # Mock correctness array (True for correct, False for incorrect) + correct = hax.named(jnp.array([True, True, True, False, True, False, True, True, True, True]), Pos) + + max_Segments = hax.Axis("segments", size=4) + + # Call the function + unique_ids, segment_correct = per_segment_correct(packed_example, correct, max_Segments) + + assert list(unique_ids.array) == [0, 1, 2, -1] + assert list(segment_correct.array) == [True, False, True, True] + + +if __name__ == "__main__": + pytest.main()
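For reference, a minimal end-to-end sketch of how the new packing utilities fit together (not part of the patch; axis sizes, token ids, and the zero losses below are placeholders):

    import haliax as hax

    from levanter.data.packing import PromptCompletion, pack_prompt_completions, per_segment_loss

    Pos = hax.Axis("pos", size=16)

    requests = [
        PromptCompletion(ids=[5, 6, 7, 8], prompt_length=2, segment_id=0),
        PromptCompletion(ids=[9, 10, 11], prompt_length=1, segment_id=1),
    ]

    # pack both requests; with these sizes they fit in a single LmExample
    packed = list(pack_prompt_completions(Pos, requests, pad_token=0, max_segments_per_example=4))
    example = packed[0]

    # placeholder for per-token losses (e.g. from next_token_loss with reduction=None)
    losses = hax.zeros(Pos)

    # +1 because -1 is reserved as the padding segment id, mirroring the harness code above
    Segments = hax.Axis("segments", size=4 + 1)
    segment_ids, segment_losses = per_segment_loss(example, losses, Segments)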