Add sequence packing for lm-eval-harness (#850)
The speedup was less than I hoped, but I think it can be improved with
better packing strategies (and also by removing some batching overhead;
the bottleneck now appears to be data loading).
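
For readers unfamiliar with the idea, here is a toy sketch of what packing does
(illustrative only; the real implementation is in src/levanter/data/packing.py below
and also handles loss masks, segment limits, and attention masking):

```python
# Toy sketch, not this PR's code: greedily pack short token sequences into
# fixed-length rows, tracking which sequence each token came from.
def toy_pack(sequences: list[list[int]], max_len: int, pad: int = 0):
    assert all(len(s) <= max_len for s in sequences)
    rows, row_segments = [], []
    cur_tokens, cur_segments = [], []
    for seg, seq in enumerate(sequences):
        if len(cur_tokens) + len(seq) > max_len:  # current row is full: flush it
            rows.append(cur_tokens + [pad] * (max_len - len(cur_tokens)))
            row_segments.append(cur_segments + [-1] * (max_len - len(cur_segments)))
            cur_tokens, cur_segments = [], []
        cur_tokens += seq
        cur_segments += [seg] * len(seq)
    rows.append(cur_tokens + [pad] * (max_len - len(cur_tokens)))
    row_segments.append(cur_segments + [-1] * (max_len - len(cur_segments)))
    return rows, row_segments

rows, segments = toy_pack([[1, 2, 3], [4, 5], [6, 7, 8, 9]], max_len=8)
# rows     == [[1, 2, 3, 4, 5, 0, 0, 0], [6, 7, 8, 9, 0, 0, 0, 0]]
# segments == [[0, 0, 0, 1, 1, -1, -1, -1], [2, 2, 2, 2, -1, -1, -1, -1]]
```

Without packing, each short sequence would occupy its own max-length row and most
positions would be padding; with packing plus per-token segment ids in the attention
mask, many short eval sequences share a single forward pass, which is where the
roughly 90% vs 10% "real token" rate quoted in the new module's docstring comes from.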

---------

Co-authored-by: Nikil Ravi <[email protected]>
dlwh and nikil-ravi authored Jan 8, 2025
1 parent 101324f commit 93a8aa9
Showing 10 changed files with 940 additions and 283 deletions.
5 changes: 4 additions & 1 deletion config/harness/harness_nano.yaml
@@ -1,5 +1,8 @@
eval_harness:
task_spec: ["hellaswag"]
# task_spec: ["hellaswag"]
task_spec:
- task: commonsense_qa # 5-way multiple-choice questions based on common-sense, everyday scenarios
num_fewshot: 1
tokenizer: "gpt2"
model:
type: gpt2
12 changes: 1 addition & 11 deletions scripts/gcs_bulk_delete.py
@@ -1,13 +1,10 @@
import re
import sys
import time
from datetime import datetime

import google.auth
from google.api_core import operations_v1
from google.cloud import storage_transfer_v1
from google.type.date_pb2 import Date
from google.type.timeofday_pb2 import TimeOfDay


EMPTY_BUCKET = "levanter-empty"
@@ -33,21 +30,14 @@ def schedule_gcs_deletion_job(project_id, gcs_bucket_name, path_to_delete):
gcs_data_source=storage_transfer_v1.types.GcsData(bucket_name=EMPTY_BUCKET),
transfer_options=storage_transfer_v1.types.TransferOptions(delete_objects_unique_in_sink=True),
),
-schedule=storage_transfer_v1.types.Schedule(
-    schedule_start_date=Date(
-        year=datetime.utcnow().year, month=datetime.utcnow().month, day=datetime.utcnow().day
-    ),
-    start_time_of_day=TimeOfDay(
-        hours=datetime.utcnow().hour, minutes=datetime.utcnow().minute + 2 # Start in 2 minutes
-    ),
-),
status=storage_transfer_v1.types.TransferJob.Status.ENABLED,
description=f"Delete all files in {gcs_bucket_name}/{path_to_delete}",
)

# Create the transfer job
response = client.create_transfer_job(request={"transfer_job": transfer_job})
print(f"Created transfer job: {response.name}")
+client.run_transfer_job({"job_name": response.name, "project_id": project_id})

# Wait for job completion
wait_for_transfer_job(response.name, timeout=3600, poll_interval=2, project_id=project_id)
8 changes: 4 additions & 4 deletions src/levanter/data/loader.py
@@ -1,9 +1,9 @@
import functools
import logging
import time
from collections import defaultdict
from typing import AsyncIterator, Callable, Iterable, Iterator, Optional, Tuple, TypeVar

import equinox
import jax
from jax import Array
from jax import numpy as jnp
@@ -180,7 +180,7 @@ def get_local_batch(begin: int, end: int) -> list:

# TODO: if we ever do "big data" (i.e. huge examples) we might want to be able to load part of an example
# which will require support from the datastore (i.e. tensorstore)
-device_batch = _stack_tree(self.dl.Batch.name, [data_for_this_batch[i] for i in range(begin, end)])
+device_batch = stack_tree(self.dl.Batch.name, [data_for_this_batch[i] for i in range(begin, end)])
batch_leaves = hax.tree_util.tree_leaves(device_batch)

cache[(begin, end)] = batch_leaves
@@ -267,8 +267,8 @@ def _fill_queue_with_batches(self):
super()._fill_queue_with_batches()


-@functools.partial(jax.jit, static_argnums=(0,))
-def _stack_tree(batch_name, individual_datums):
+@equinox.filter_jit
+def stack_tree(batch_name, individual_datums):
def _stack_leaves_unchecked(*leaves):
if is_named_array(leaves[0]):
return hax.stack(batch_name, leaves)
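
A note on the change above from `_stack_tree` under `jax.jit` to `stack_tree` under
`equinox.filter_jit`: filter_jit traces array arguments and holds everything else static,
so the string batch-axis name no longer has to be declared via `static_argnums`. A minimal
sketch of that behavior (illustrative, not code from this commit):

```python
import equinox as eqx
import jax.numpy as jnp


@eqx.filter_jit
def scale(mode: str, x):
    # `mode` is a plain Python string, so filter_jit holds it static automatically;
    # plain jax.jit would need static_argnums=(0,) to accept a non-array argument here.
    return x * 2 if mode == "double" else x


print(scale("double", jnp.arange(3)))  # [0 2 4]
```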
213 changes: 213 additions & 0 deletions src/levanter/data/packing.py
@@ -0,0 +1,213 @@
"""
Implements sequence packing, mostly for doing evaluation on lots of short sequences.
Our strategy is basically to maintain a pool of SequencePackers, each of which can hold a fixed number of tokens
(and a maximum number of segments). We then iterate over the sequences, adding them to the packers if they fit, and
yielding the packed examples when they are full.
This achieves about a 90% "real token" rate, compared to roughly 10% without packing.
"""
from dataclasses import dataclass
from typing import Iterable, Iterator

import jax.numpy as jnp
import numpy as np

import haliax as hax

from levanter.models.attention import AttentionMask
from levanter.models.lm_model import LmExample
from levanter.utils.jax_utils import local_cpu_mesh


# cf https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/data_generators/generator_utils.py#L623

# todo should we use something like this: https://arxiv.org/pdf/2107.02027?


class SequencePacker:
"""
Packs sequences into a single LmExample.
"""

def __init__(self, Pos: hax.Axis, max_pack_size: int, pad_token: int):
self.Pos = Pos
self._ids: list[int] = []
self._segment_ids: list[int] = []
self._loss_mask: list[int] = []
self.num_segments = 0
self.pad_token = pad_token
self.max_pack_size = max_pack_size
assert pad_token is not None, "pad_token must be set"

def can_pack(self, ids: list[int]) -> bool:
return len(ids) + len(self._ids) <= self.Pos.size and self.num_segments < self.max_pack_size

def add_example(self, ids: list[int], loss_mask: list[int] | np.ndarray, segment_id: int | None = None):
if len(ids) != len(loss_mask):
raise ValueError("ids and loss_mask must have the same length")

if len(ids) == 0:
return

if len(ids) + len(self._ids) > self.Pos.size:
raise ValueError("Too many tokens")

if self.num_segments >= self.max_pack_size:
raise ValueError("Too many segments")

self._ids.extend(ids)
if segment_id is None:
segment_id = self.num_segments

self.num_segments += 1

self._segment_ids.extend([segment_id] * len(ids))

self._loss_mask.extend(loss_mask)

def pack(self) -> LmExample:
ids = self._ids + [self.pad_token] * (self.Pos.size - len(self._ids))

segment_ids = self._segment_ids + [-1] * (self.Pos.size - len(self._segment_ids))

loss_mask = self._loss_mask + [0] * (self.Pos.size - len(self._loss_mask))

with local_cpu_mesh():
tokens = hax.named(ids, self.Pos).astype(jnp.int32)
segment_ids = hax.named(segment_ids, self.Pos).astype(jnp.int32)
loss_mask = hax.named(loss_mask, self.Pos).astype(jnp.int32)

attn_mask = AttentionMask.causal().with_segment_ids(segment_ids)

return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask)


@dataclass(frozen=True)
class PromptCompletion:
ids: list[int]
prompt_length: int
segment_id: int | None = None


def pack_prompt_completions(
Pos: hax.Axis,
sequences: Iterable[PromptCompletion],
pad_token: int,
max_segments_per_example: int = 64,
max_buffered_examples: int = 64,
) -> Iterator[LmExample]:
"""
Packs a list of prompt completions into LmExamples using the SequencePacker
"""

packers = [SequencePacker(Pos, max_segments_per_example, pad_token)]

for sequence in sequences:
loss_mask = np.arange(len(sequence.ids)) >= sequence.prompt_length - 1
loss_mask[-1] = 0
assert np.any(loss_mask)

for packer in packers:
if packer.can_pack(sequence.ids):
packer.add_example(sequence.ids, loss_mask, sequence.segment_id)

if packer.num_segments == max_segments_per_example:
yield packer.pack()
packers.remove(packer)
break
else:
# no packer could fit the example, create a new one
packer = SequencePacker(Pos, max_segments_per_example, pad_token)
packer.add_example(sequence.ids, loss_mask, sequence.segment_id)
packers.append(packer)

while len(packers) >= max_buffered_examples:
yield packers.pop(0).pack()

for packer in packers:
yield packer.pack()


def per_segment_loss(
packed_example: LmExample, losses: hax.NamedArray, max_Segments: hax.Axis
) -> tuple[hax.NamedArray, hax.NamedArray]:
"""
Returns a pair of arrays of shape (max_Segments,), where:
* the first array is segment ids
* the second is loss per segment.
This code is designed to run in a jit-compiled function, meaning we have to be careful of shapes.
"""

assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask"

segment_ids = packed_example.attn_mask.segment_ids
assert (
segment_ids.ndim == 1
), f"Expected segment_ids to be 1D, got {segment_ids.ndim}. Use vmap if you have multiple examples"
Pos = packed_example.tokens.axes[0]

# mask out padding etc
masked_losses = losses * packed_example.loss_mask

# sum the losses for each segment
unique_segment_ids = _unique_segment_ids(max_Segments, segment_ids)

# Create a mask matrix where each row corresponds to a unique segment
segment_mask = unique_segment_ids == segment_ids.broadcast_axis(max_Segments)

segment_mask = segment_mask.astype(masked_losses.dtype)

segment_losses = hax.dot(segment_mask, masked_losses, axis=Pos)

return unique_segment_ids, segment_losses


def _unique_segment_ids(max_Segments, segment_ids):
# Extract unique segment IDs with padding
# TODO: add unique to haliax
unique_segment_ids = jnp.unique(segment_ids.array, size=max_Segments.size, fill_value=-1)
unique_segment_ids = hax.named(unique_segment_ids, max_Segments)
return unique_segment_ids


def per_segment_correct(
packed_example: LmExample, correct: hax.NamedArray, max_Segments: hax.Axis
) -> tuple[hax.NamedArray, hax.NamedArray]:
"""
Returns a pair of arrays of shape (max_Segments,), where:
* the first array is segment ids
* the second is whether all tokens in the segment are correct.
This code is designed to run in a jit-compiled function, meaning we have to be careful of shapes.
correct is a boolean array of the same shape as the losses array indicating whether each token was correct.
"""

assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask"

segment_ids = packed_example.attn_mask.segment_ids
assert (
segment_ids.ndim == 1
), f"Expected segment_ids to be 1D, got {segment_ids.ndim}. Use vmap if you have multiple examples"

Pos = packed_example.tokens.axes[0]

# mask out padding etc
masked_correct = hax.logical_or(correct, hax.logical_not(packed_example.loss_mask))

# sum the losses for each segment
# Extract unique segment IDs with padding
unique_segment_ids = _unique_segment_ids(max_Segments, segment_ids)

# Create a mask matrix where each row corresponds to a unique segment
segment_mask = unique_segment_ids == segment_ids.broadcast_axis(max_Segments)

segment_mask = segment_mask.astype(masked_correct.dtype)

segment_correct = hax.all(hax.where(segment_mask, masked_correct, True), axis=Pos)

return unique_segment_ids, segment_correct
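
A hypothetical usage sketch for the new module (not part of this commit); it relies only
on the classes defined above plus haliax, and the token ids are made up:

```python
import haliax as hax

from levanter.data.packing import PromptCompletion, pack_prompt_completions

Pos = hax.Axis("position", 16)

sequences = [
    PromptCompletion(ids=[5, 6, 7, 8], prompt_length=2),        # 2 prompt + 2 completion tokens
    PromptCompletion(ids=[9, 10, 11], prompt_length=1),
    PromptCompletion(ids=[12, 13, 14, 15, 16], prompt_length=3),
]

# All three sequences fit into a single 16-token example; the rest is padding.
packed = list(pack_prompt_completions(Pos, sequences, pad_token=0, max_segments_per_example=4))[0]

print(packed.tokens.array)                 # [ 5  6  7  8  9 10 11 12 13 14 15 16  0  0  0  0]
print(packed.attn_mask.segment_ids.array)  # [ 0  0  0  0  1  1  1  2  2  2  2  2 -1 -1 -1 -1]
```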
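
Continuing the same sketch (same caveats), `per_segment_loss` then recovers a per-sequence
loss from the packed example, which is presumably how the eval harness scores each packed
sequence separately; `token_losses` below is just a stand-in for real per-token losses:

```python
from levanter.data.packing import per_segment_loss

max_Segments = hax.Axis("segments", 4)  # must be at least max_segments_per_example
token_losses = hax.ones(Pos)            # stand-in: one unit of loss per position

seg_ids, seg_losses = per_segment_loss(packed, token_losses, max_Segments)
# seg_ids:    unique segment ids padded with -1, i.e. [-1  0  1  2]
# seg_losses: loss summed over each segment's unmasked (completion) tokens;
#             with unit losses this is just the completion-token count, [0. 2. 2. 2.]
```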