Add sequence packing for lm-eval-harness #850

Merged: 21 commits, Jan 8, 2025
5 changes: 4 additions & 1 deletion config/harness/harness_nano.yaml
@@ -1,5 +1,8 @@
 eval_harness:
-  task_spec: ["hellaswag"]
+  # task_spec: ["hellaswag"]
+  task_spec:
+    - task: commonsense_qa  # 5-way multiple-choice questions based on common-sense, everyday scenarios
+      num_fewshot: 1
   tokenizer: "gpt2"
 model:
   type: gpt2
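For reference, the structured `task_spec` above replaces bare task names with entries that carry their own options (here, `num_fewshot`). A rough Python illustration of the two shapes, not taken from the Levanter config code:

```python
# Illustrative only: structured task_spec entries carry per-task options.
new_task_spec = [
    {"task": "commonsense_qa", "num_fewshot": 1},
]

# The old form was just a list of task names:
old_task_spec = ["hellaswag"]
```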
12 changes: 1 addition & 11 deletions scripts/gcs_bulk_delete.py
@@ -1,13 +1,10 @@
 import re
 import sys
 import time
-from datetime import datetime

 import google.auth
 from google.api_core import operations_v1
 from google.cloud import storage_transfer_v1
-from google.type.date_pb2 import Date
-from google.type.timeofday_pb2 import TimeOfDay


 EMPTY_BUCKET = "levanter-empty"
@@ -33,21 +30,14 @@ def schedule_gcs_deletion_job(project_id, gcs_bucket_name, path_to_delete):
             gcs_data_source=storage_transfer_v1.types.GcsData(bucket_name=EMPTY_BUCKET),
             transfer_options=storage_transfer_v1.types.TransferOptions(delete_objects_unique_in_sink=True),
         ),
-        schedule=storage_transfer_v1.types.Schedule(
-            schedule_start_date=Date(
-                year=datetime.utcnow().year, month=datetime.utcnow().month, day=datetime.utcnow().day
-            ),
-            start_time_of_day=TimeOfDay(
-                hours=datetime.utcnow().hour, minutes=datetime.utcnow().minute + 2  # Start in 2 minutes
-            ),
-        ),
         status=storage_transfer_v1.types.TransferJob.Status.ENABLED,
         description=f"Delete all files in {gcs_bucket_name}/{path_to_delete}",
     )

     # Create the transfer job
     response = client.create_transfer_job(request={"transfer_job": transfer_job})
     print(f"Created transfer job: {response.name}")
+    client.run_transfer_job({"job_name": response.name, "project_id": project_id})

     # Wait for job completion
     wait_for_transfer_job(response.name, timeout=3600, poll_interval=2, project_id=project_id)
8 changes: 4 additions & 4 deletions src/levanter/data/loader.py
@@ -1,9 +1,9 @@
-import functools
 import logging
 import time
 from collections import defaultdict
 from typing import AsyncIterator, Callable, Iterable, Iterator, Optional, Tuple, TypeVar

+import equinox
 import jax
 from jax import Array
 from jax import numpy as jnp
@@ -180,7 +180,7 @@ def get_local_batch(begin: int, end: int) -> list:

             # TODO: if we ever do "big data" (i.e. huge examples) we might want to be able to load part of an example
             # which will require support from the datastore (i.e. tensorstore)
-            device_batch = _stack_tree(self.dl.Batch.name, [data_for_this_batch[i] for i in range(begin, end)])
+            device_batch = stack_tree(self.dl.Batch.name, [data_for_this_batch[i] for i in range(begin, end)])
             batch_leaves = hax.tree_util.tree_leaves(device_batch)

             cache[(begin, end)] = batch_leaves
@@ -267,8 +267,8 @@ def _fill_queue_with_batches(self):
         super()._fill_queue_with_batches()


-@functools.partial(jax.jit, static_argnums=(0,))
-def _stack_tree(batch_name, individual_datums):
+@equinox.filter_jit
+def stack_tree(batch_name, individual_datums):
     def _stack_leaves_unchecked(*leaves):
         if is_named_array(leaves[0]):
             return hax.stack(batch_name, leaves)
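The loader change above swaps `functools.partial(jax.jit, static_argnums=(0,))` for `equinox.filter_jit`, which traces only array-like arguments and treats everything else as static. That is why the renamed `stack_tree` can take the `batch_name` string directly. A minimal sketch of that behavior, assuming only standard equinox semantics (not code from this PR):

```python
import equinox
import jax.numpy as jnp


@equinox.filter_jit
def scale_by_name(name: str, x):
    # `name` is not an array, so filter_jit treats it as a static value,
    # the same role static_argnums=(0,) played under plain jax.jit.
    return x * (2.0 if name == "double" else 1.0)


print(scale_by_name("double", jnp.ones(3)))  # retraces only when `name` changes
```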
213 changes: 213 additions & 0 deletions src/levanter/data/packing.py
@@ -0,0 +1,213 @@
"""
Implements sequence packing, mostly for doing evaluation on lots of short sequences.
Our strategy is basically to maintain a pool of SequencePackers, each of which can hold a fixed number of tokens
(and a maximum number of segments). We then iterate over the sequences, adding them to the packers if they fit, and
yielding the packed examples when they are full.
This achieves about a 90% "real token" rate, compared to roughly 10% without packing.
"""
from dataclasses import dataclass
from typing import Iterable, Iterator

import jax.numpy as jnp
import numpy as np

import haliax as hax

from levanter.models.attention import AttentionMask
from levanter.models.lm_model import LmExample
from levanter.utils.jax_utils import local_cpu_mesh


# cf https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/data_generators/generator_utils.py#L623

# todo should we use something like this: https://arxiv.org/pdf/2107.02027?


class SequencePacker:
"""
Packs sequences into a single LmExample.
"""

def __init__(self, Pos: hax.Axis, max_pack_size: int, pad_token: int):
self.Pos = Pos
self._ids: list[int] = []
self._segment_ids: list[int] = []
self._loss_mask: list[int] = []
self.num_segments = 0
self.pad_token = pad_token
self.max_pack_size = max_pack_size
assert pad_token is not None, "pad_token must be set"

def can_pack(self, ids: list[int]) -> bool:
return len(ids) + len(self._ids) <= self.Pos.size and self.num_segments < self.max_pack_size

def add_example(self, ids: list[int], loss_mask: list[int] | np.ndarray, segment_id: int | None = None):
if len(ids) != len(loss_mask):
raise ValueError("ids and loss_mask must have the same length")

if len(ids) == 0:
return

if len(ids) + len(self._ids) > self.Pos.size:
raise ValueError("Too many tokens")

if self.num_segments >= self.max_pack_size:
raise ValueError("Too many segments")

self._ids.extend(ids)
if segment_id is None:
segment_id = self.num_segments

self.num_segments += 1

self._segment_ids.extend([segment_id] * len(ids))

self._loss_mask.extend(loss_mask)

def pack(self) -> LmExample:
ids = self._ids + [self.pad_token] * (self.Pos.size - len(self._ids))

segment_ids = self._segment_ids + [-1] * (self.Pos.size - len(self._segment_ids))

loss_mask = self._loss_mask + [0] * (self.Pos.size - len(self._loss_mask))

with local_cpu_mesh():
tokens = hax.named(ids, self.Pos).astype(jnp.int32)
segment_ids = hax.named(segment_ids, self.Pos).astype(jnp.int32)
loss_mask = hax.named(loss_mask, self.Pos).astype(jnp.int32)

attn_mask = AttentionMask.causal().with_segment_ids(segment_ids)

return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask)


@dataclass(frozen=True)
class PromptCompletion:
    ids: list[int]
    prompt_length: int
    segment_id: int | None = None


def pack_prompt_completions(
    Pos: hax.Axis,
    sequences: Iterable[PromptCompletion],
    pad_token: int,
    max_segments_per_example: int = 64,
    max_buffered_examples: int = 64,
) -> Iterator[LmExample]:
    """
    Packs a list of prompt completions into LmExamples using the SequencePacker
    """

    packers = [SequencePacker(Pos, max_segments_per_example, pad_token)]

    for sequence in sequences:
        loss_mask = np.arange(len(sequence.ids)) >= sequence.prompt_length - 1
        loss_mask[-1] = 0
        assert np.any(loss_mask)

        for packer in packers:
            if packer.can_pack(sequence.ids):
                packer.add_example(sequence.ids, loss_mask, sequence.segment_id)

                if packer.num_segments == max_segments_per_example:
                    yield packer.pack()
                    packers.remove(packer)
                break
        else:
            # no packer could fit the example, create a new one
            packer = SequencePacker(Pos, max_segments_per_example, pad_token)
            packer.add_example(sequence.ids, loss_mask, sequence.segment_id)
            packers.append(packer)

        while len(packers) >= max_buffered_examples:
            yield packers.pop(0).pack()

    for packer in packers:
        yield packer.pack()


def per_segment_loss(
    packed_example: LmExample, losses: hax.NamedArray, max_Segments: hax.Axis
) -> tuple[hax.NamedArray, hax.NamedArray]:
    """
    Returns a pair of arrays of shape (max_Segments,), where:
    * the first array is segment ids
    * the second is loss per segment.
    This code is designed to run in a jit-compiled function, meaning we have to be careful of shapes.
    """

    assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask"

    segment_ids = packed_example.attn_mask.segment_ids
    assert (
        segment_ids.ndim == 1
    ), f"Expected segment_ids to be 1D, got {segment_ids.ndim}. Use vmap if you have multiple examples"
    Pos = packed_example.tokens.axes[0]

    # mask out padding etc
    masked_losses = losses * packed_example.loss_mask

    # sum the losses for each segment
    unique_segment_ids = _unique_segment_ids(max_Segments, segment_ids)

    # Create a mask matrix where each row corresponds to a unique segment
    segment_mask = unique_segment_ids == segment_ids.broadcast_axis(max_Segments)

    segment_mask = segment_mask.astype(masked_losses.dtype)

    segment_losses = hax.dot(segment_mask, masked_losses, axis=Pos)

    return unique_segment_ids, segment_losses


def _unique_segment_ids(max_Segments, segment_ids):
    # Extract unique segment IDs with padding
    # TODO: add unique to haliax
    unique_segment_ids = jnp.unique(segment_ids.array, size=max_Segments.size, fill_value=-1)
    unique_segment_ids = hax.named(unique_segment_ids, max_Segments)
    return unique_segment_ids


def per_segment_correct(
    packed_example: LmExample, correct: hax.NamedArray, max_Segments: hax.Axis
) -> tuple[hax.NamedArray, hax.NamedArray]:
    """
    Returns a pair of arrays of shape (max_Segments,), where:
    * the first array is segment ids
    * the second is whether all tokens in the segment are correct.
    This code is designed to run in a jit-compiled function, meaning we have to be careful of shapes.
    correct is a boolean array of the same shape as the losses array indicating whether each token was correct
    """

    assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask"

    segment_ids = packed_example.attn_mask.segment_ids
    assert (
        segment_ids.ndim == 1
    ), f"Expected segment_ids to be 1D, got {segment_ids.ndim}. Use vmap if you have multiple examples"

    Pos = packed_example.tokens.axes[0]

    # mask out padding etc
    masked_correct = hax.logical_or(correct, hax.logical_not(packed_example.loss_mask))

    # sum the losses for each segment
    # Extract unique segment IDs with padding
    unique_segment_ids = _unique_segment_ids(max_Segments, segment_ids)

    # Create a mask matrix where each row corresponds to a unique segment
    segment_mask = unique_segment_ids == segment_ids.broadcast_axis(max_Segments)

    segment_mask = segment_mask.astype(masked_correct.dtype)

    segment_correct = hax.all(hax.where(segment_mask, masked_correct, True), axis=Pos)

    return unique_segment_ids, segment_correct
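To make the intended flow of `packing.py` concrete, here is a short usage sketch (not part of the diff; the axis name, sizes, and token ids are invented for illustration) that packs two prompt/completion pairs into a single `LmExample`:

```python
# Illustrative sketch, not from the PR: pack two short sequences into one example.
import haliax as hax

from levanter.data.packing import PromptCompletion, pack_prompt_completions

Pos = hax.Axis("position", 16)  # axis chosen arbitrarily for this sketch

sequences = [
    PromptCompletion(ids=[11, 12, 13, 14], prompt_length=3),  # loss only on completion tokens
    PromptCompletion(ids=[21, 22, 23], prompt_length=2),
]

for example in pack_prompt_completions(Pos, sequences, pad_token=0, max_segments_per_example=4):
    # Both sequences share one packed example; distinct segment ids plus the
    # segment-aware causal mask keep them from attending to each other.
    print(example.tokens)
    print(example.attn_mask.segment_ids)
```

Inside a jit-compiled eval step, `per_segment_loss` and `per_segment_correct` then undo the packing: they build a mask over the max_Segments and Pos axes from the segment ids and reduce the per-token losses (or correctness flags) back to one value per original sequence.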