Commit: Merge branch 'main' of https://github.com/stanford-crfm/levanter into nikil/group_task_fixes

Showing 22 changed files with 973 additions and 301 deletions.
""" | ||
Implements sequence packing, mostly for doing evaluation on lots of short sequences. | ||
Our strategy is basically to maintain a pool of SequencePackers, each of which can hold a fixed number of tokens | ||
(and a maximum number of segments). We then iterate over the sequences, adding them to the packers if they fit, and | ||
yielding the packed examples when they are full. | ||
This achieves about a 90% "real token" rate, compared to like 10% without packing. | ||
""" | ||
from dataclasses import dataclass
from typing import Iterable, Iterator

import jax.numpy as jnp
import numpy as np

import haliax as hax

from levanter.models.attention import AttentionMask
from levanter.models.lm_model import LmExample
from levanter.utils.jax_utils import local_cpu_mesh


# cf https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/data_generators/generator_utils.py#L623
# TODO: should we use something like this: https://arxiv.org/pdf/2107.02027?
class SequencePacker:
    """
    Packs sequences into a single LmExample.
    """

    def __init__(self, Pos: hax.Axis, max_pack_size: int, pad_token: int):
        self.Pos = Pos
        self._ids: list[int] = []
        self._segment_ids: list[int] = []
        self._loss_mask: list[int] = []
        self.num_segments = 0
        self.pad_token = pad_token
        self.max_pack_size = max_pack_size
        assert pad_token is not None, "pad_token must be set"

    def can_pack(self, ids: list[int]) -> bool:
        return len(ids) + len(self._ids) <= self.Pos.size and self.num_segments < self.max_pack_size

    def add_example(self, ids: list[int], loss_mask: list[int] | np.ndarray, segment_id: int | None = None):
        if len(ids) != len(loss_mask):
            raise ValueError("ids and loss_mask must have the same length")

        if len(ids) == 0:
            return

        if len(ids) + len(self._ids) > self.Pos.size:
            raise ValueError("Too many tokens")

        if self.num_segments >= self.max_pack_size:
            raise ValueError("Too many segments")

        self._ids.extend(ids)
        if segment_id is None:
            segment_id = self.num_segments

        self.num_segments += 1

        self._segment_ids.extend([segment_id] * len(ids))

        self._loss_mask.extend(loss_mask)

    def pack(self) -> LmExample:
        # pad out to the full Pos length; padding positions get segment id -1 and no loss
        ids = self._ids + [self.pad_token] * (self.Pos.size - len(self._ids))
        segment_ids = self._segment_ids + [-1] * (self.Pos.size - len(self._segment_ids))
        loss_mask = self._loss_mask + [0] * (self.Pos.size - len(self._loss_mask))

        with local_cpu_mesh():
            tokens = hax.named(ids, self.Pos).astype(jnp.int32)
            segment_ids = hax.named(segment_ids, self.Pos).astype(jnp.int32)
            loss_mask = hax.named(loss_mask, self.Pos).astype(jnp.int32)
            attn_mask = AttentionMask.causal().with_segment_ids(segment_ids)
            return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask)
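
# A minimal usage sketch (illustrative only; the axis name, sizes, and token ids
# below are invented for the example, not part of the library):
#
#     Pos = hax.Axis("position", 8)
#     packer = SequencePacker(Pos, max_pack_size=4, pad_token=0)
#     packer.add_example([1, 2, 3], [0, 1, 1])  # segment 0; loss on the last two tokens
#     packer.add_example([4, 5], [1, 1])        # segment 1
#     ex = packer.pack()  # LmExample: tokens padded to length 8, segment-aware causal mask
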
@dataclass(frozen=True)
class PromptCompletion:
    ids: list[int]
    prompt_length: int
    segment_id: int | None = None
def pack_prompt_completions(
    Pos: hax.Axis,
    sequences: Iterable[PromptCompletion],
    pad_token: int,
    max_segments_per_example: int = 64,
    max_buffered_examples: int = 64,
) -> Iterator[LmExample]:
    """
    Packs a list of prompt completions into LmExamples using the SequencePacker.
    """

    packers = [SequencePacker(Pos, max_segments_per_example, pad_token)]

    for sequence in sequences:
        # compute loss only on completion tokens: positions that predict a token at or after
        # the end of the prompt. The final position has no next token, so it gets no loss.
        loss_mask = np.arange(len(sequence.ids)) >= sequence.prompt_length - 1
        loss_mask[-1] = 0
        assert np.any(loss_mask)

        for packer in packers:
            if packer.can_pack(sequence.ids):
                packer.add_example(sequence.ids, loss_mask, sequence.segment_id)

                if packer.num_segments == max_segments_per_example:
                    yield packer.pack()
                    packers.remove(packer)
                break
        else:
            # no packer could fit the example, so create a new one
            packer = SequencePacker(Pos, max_segments_per_example, pad_token)
            packer.add_example(sequence.ids, loss_mask, sequence.segment_id)
            packers.append(packer)

        while len(packers) >= max_buffered_examples:
            yield packers.pop(0).pack()

    for packer in packers:
        yield packer.pack()
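
# A minimal usage sketch (illustrative only; the axis size, token ids, and prompt
# lengths are invented):
#
#     Pos = hax.Axis("position", 16)
#     seqs = [
#         PromptCompletion(ids=[5, 6, 7, 8], prompt_length=2),
#         PromptCompletion(ids=[9, 10, 11], prompt_length=1),
#     ]
#     for ex in pack_prompt_completions(Pos, seqs, pad_token=0):
#         ...  # each ex is an LmExample packing as many completions as fit in 16 tokens
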
def per_segment_loss(
    packed_example: LmExample, losses: hax.NamedArray, max_Segments: hax.Axis
) -> tuple[hax.NamedArray, hax.NamedArray]:
    """
    Returns a pair of arrays of shape (max_Segments,), where:

    * the first array is segment ids
    * the second is loss per segment.

    This code is designed to run in a jit-compiled function, meaning we have to be careful about shapes.
    """

    assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask"

    segment_ids = packed_example.attn_mask.segment_ids
    assert (
        segment_ids.ndim == 1
    ), f"Expected segment_ids to be 1D, got {segment_ids.ndim}. Use vmap if you have multiple examples"
    Pos = packed_example.tokens.axes[0]

    # mask out padding etc.
    masked_losses = losses * packed_example.loss_mask

    # sum the losses for each segment
    unique_segment_ids = _unique_segment_ids(max_Segments, segment_ids)

    # Create a mask matrix where each row corresponds to a unique segment
    segment_mask = unique_segment_ids == segment_ids.broadcast_axis(max_Segments)

    segment_mask = segment_mask.astype(masked_losses.dtype)

    segment_losses = hax.dot(segment_mask, masked_losses, axis=Pos)

    return unique_segment_ids, segment_losses
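
# Worked example of the mask-matrix trick above (values invented): with
# segment_ids = [0, 0, 1, -1] over Pos and max_Segments of size 3, jnp.unique
# yields unique_segment_ids = [-1, 0, 1], so broadcasting produces the
# (Segments, Pos) mask
#
#     [[0, 0, 0, 1],   # positions belonging to segment -1 (padding)
#      [1, 1, 0, 0],   # positions belonging to segment 0
#      [0, 0, 1, 0]]   # positions belonging to segment 1
#
# and hax.dot(segment_mask, masked_losses, axis=Pos) sums each row, giving one
# loss total per segment with fully static shapes.
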
def _unique_segment_ids(max_Segments, segment_ids):
    # Extract unique segment IDs with padding
    # TODO: add unique to haliax
    unique_segment_ids = jnp.unique(segment_ids.array, size=max_Segments.size, fill_value=-1)
    unique_segment_ids = hax.named(unique_segment_ids, max_Segments)
    return unique_segment_ids
def per_segment_correct(
    packed_example: LmExample, correct: hax.NamedArray, max_Segments: hax.Axis
) -> tuple[hax.NamedArray, hax.NamedArray]:
    """
    Returns a pair of arrays of shape (max_Segments,), where:

    * the first array is segment ids
    * the second is whether all tokens in the segment are correct.

    This code is designed to run in a jit-compiled function, meaning we have to be careful about shapes.

    `correct` is a boolean array of the same shape as the losses array indicating whether the token was correct.
    """

    assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask"

    segment_ids = packed_example.attn_mask.segment_ids
    assert (
        segment_ids.ndim == 1
    ), f"Expected segment_ids to be 1D, got {segment_ids.ndim}. Use vmap if you have multiple examples"

    Pos = packed_example.tokens.axes[0]

    # treat positions outside the loss mask (padding, prompt) as correct
    masked_correct = hax.logical_or(correct, hax.logical_not(packed_example.loss_mask))

    # Extract unique segment IDs with padding
    unique_segment_ids = _unique_segment_ids(max_Segments, segment_ids)

    # Create a mask matrix where each row corresponds to a unique segment
    segment_mask = unique_segment_ids == segment_ids.broadcast_axis(max_Segments)

    segment_mask = segment_mask.astype(masked_correct.dtype)

    # a segment is correct iff every position assigned to it is correct
    segment_correct = hax.all(hax.where(segment_mask, masked_correct, True), axis=Pos)

    return unique_segment_ids, segment_correct
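
# Usage sketch (illustrative only; `correct` is an assumed input, e.g. a boolean
# NamedArray over Pos marking whether the model's prediction matched the target
# at each position):
#
#     Segments = hax.Axis("segments", 64)
#     seg_ids, seg_correct = per_segment_correct(packed_example, correct, Segments)
#     # seg_correct[i] is True iff every loss-counted token in segment seg_ids[i] was correct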