Commit

Merge branch 'main' of https://github.com/stanford-crfm/levanter into nikil/group_task_fixes
nikil-ravi committed Jan 9, 2025
2 parents b92c090 + ab7dc65 commit bb708c4
Showing 22 changed files with 973 additions and 301 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/launch_small_fast.yaml
@@ -41,7 +41,7 @@ jobs:
- name: Install locally
run: |
python -m pip install --upgrade pip
pip install -e .[test] "jax[cpu]==0.4.30"
pip install -e .[test] "jax[cpu]==0.4.38"
- name: Launch Small Fast TPU Train LM job
run: |
2 changes: 1 addition & 1 deletion .github/workflows/run_entry_tests.yaml
@@ -9,7 +9,7 @@ jobs:
strategy:
matrix:
python-version: ["3.10"]
jax-version: ["0.4.26"]
jax-version: ["0.4.38"]

steps:
- uses: actions/checkout@v3
2 changes: 1 addition & 1 deletion .github/workflows/run_pre_commit.yaml
@@ -10,7 +10,7 @@ jobs:
strategy:
matrix:
python-version: ["3.10"]
jax-version: ["0.4.14"]
jax-version: ["0.4.38"]

steps:
- uses: actions/checkout@v3
2 changes: 1 addition & 1 deletion .github/workflows/run_ray_tests.yaml
@@ -9,7 +9,7 @@ jobs:
strategy:
matrix:
python-version: ["3.10"]
jax-version: ["0.4.26"]
jax-version: ["0.4.38"]

steps:
- uses: actions/checkout@v3
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yaml
@@ -10,7 +10,7 @@ jobs:
strategy:
matrix:
python-version: ["3.10"]
jax-version: ["0.4.26"]
jax-version: ["0.4.38"]

steps:
- uses: actions/checkout@v3
5 changes: 4 additions & 1 deletion config/harness/harness_nano.yaml
@@ -1,5 +1,8 @@
eval_harness:
task_spec: ["hellaswag"]
# task_spec: ["hellaswag"]
task_spec:
- task: commonsense_qa # 5-way multiple-choice questions based on common-sense, everyday scenarios
num_fewshot: 1
tokenizer: "gpt2"
model:
type: gpt2
3 changes: 1 addition & 2 deletions docker/tpu/Dockerfile.base
@@ -5,8 +5,7 @@ RUN pip install virtualenv
# venv binaries encode their directory, so we need to setup the venv in the final location
RUN virtualenv -p python3.10 /opt/levanter/.venv
ENV PATH /opt/levanter/.venv/bin:$PATH
#RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]==0.4.34" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Install package dependencies to make incremental builds faster.
WORKDIR /tmp/
2 changes: 1 addition & 1 deletion infra/helpers/setup-tpu-vm-tests.sh
@@ -105,7 +105,7 @@ pip install -U wheel

# jax and jaxlib
# libtpu sometimes has issues installing for clinical (probably firewall?)
retry pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
retry pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# clone levanter
git clone $REPO levanter
2 changes: 1 addition & 1 deletion infra/helpers/setup-tpu-vm.sh
@@ -105,7 +105,7 @@ pip install -U wheel

# jax and jaxlib
# libtpu sometimes has issues installing for clinical (probably firewall?)
retry pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
retry pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# clone levanter
git clone $REPO levanter
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -43,7 +43,7 @@ dependencies = [
"tensorstore>=0.1.65",
"pytimeparse>=1.1.8",
"humanfriendly==10.0",
"safetensors[numpy]~=0.4.2",
"safetensors[numpy]>=0.4.2,<0.6.0",
"matplotlib>=3.7.0",
"tblib>=1.7.0,<4.0.0",
"dataclasses-json~=0.6.4",
12 changes: 1 addition & 11 deletions scripts/gcs_bulk_delete.py
@@ -1,13 +1,10 @@
import re
import sys
import time
from datetime import datetime

import google.auth
from google.api_core import operations_v1
from google.cloud import storage_transfer_v1
from google.type.date_pb2 import Date
from google.type.timeofday_pb2 import TimeOfDay


EMPTY_BUCKET = "levanter-empty"
@@ -33,21 +30,14 @@ def schedule_gcs_deletion_job(project_id, gcs_bucket_name, path_to_delete):
gcs_data_source=storage_transfer_v1.types.GcsData(bucket_name=EMPTY_BUCKET),
transfer_options=storage_transfer_v1.types.TransferOptions(delete_objects_unique_in_sink=True),
),
schedule=storage_transfer_v1.types.Schedule(
schedule_start_date=Date(
year=datetime.utcnow().year, month=datetime.utcnow().month, day=datetime.utcnow().day
),
start_time_of_day=TimeOfDay(
hours=datetime.utcnow().hour, minutes=datetime.utcnow().minute + 2 # Start in 2 minutes
),
),
status=storage_transfer_v1.types.TransferJob.Status.ENABLED,
description=f"Delete all files in {gcs_bucket_name}/{path_to_delete}",
)

# Create the transfer job
response = client.create_transfer_job(request={"transfer_job": transfer_job})
print(f"Created transfer job: {response.name}")
client.run_transfer_job({"job_name": response.name, "project_id": project_id})

# Wait for job completion
wait_for_transfer_job(response.name, timeout=3600, poll_interval=2, project_id=project_id)
8 changes: 4 additions & 4 deletions src/levanter/data/loader.py
@@ -1,9 +1,9 @@
import functools
import logging
import time
from collections import defaultdict
from typing import AsyncIterator, Callable, Iterable, Iterator, Optional, Tuple, TypeVar

import equinox
import jax
from jax import Array
from jax import numpy as jnp
@@ -180,7 +180,7 @@ def get_local_batch(begin: int, end: int) -> list:

# TODO: if we ever do "big data" (i.e. huge examples) we might want to be able to load part of an example
# which will require support from the datastore (i.e. tensorstore)
device_batch = _stack_tree(self.dl.Batch.name, [data_for_this_batch[i] for i in range(begin, end)])
device_batch = stack_tree(self.dl.Batch.name, [data_for_this_batch[i] for i in range(begin, end)])
batch_leaves = hax.tree_util.tree_leaves(device_batch)

cache[(begin, end)] = batch_leaves
@@ -267,8 +267,8 @@ def _fill_queue_with_batches(self):
super()._fill_queue_with_batches()


@functools.partial(jax.jit, static_argnums=(0,))
def _stack_tree(batch_name, individual_datums):
@equinox.filter_jit
def stack_tree(batch_name, individual_datums):
def _stack_leaves_unchecked(*leaves):
if is_named_array(leaves[0]):
return hax.stack(batch_name, leaves)
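
For context on the loader.py change above: jax.jit with static_argnums marks the non-array batch_name argument static by position, while equinox.filter_jit traces only array-like leaves and automatically holds everything else (strings, configs, callables) static. A minimal sketch of the two patterns, using a made-up toy function rather than the real stack_tree:

# Illustrative sketch only; scale_by_* are hypothetical stand-ins, not levanter code.
import functools

import equinox
import jax
import jax.numpy as jnp


# Before: the non-array argument must be declared static explicitly.
@functools.partial(jax.jit, static_argnums=(0,))
def scale_by_old(label: str, x: jax.Array) -> jax.Array:
    return x * 2.0  # label only affects tracing, not the computation


# After: filter_jit traces array leaves and treats non-arrays (like label) as static.
@equinox.filter_jit
def scale_by_new(label: str, x: jax.Array) -> jax.Array:
    return x * 2.0


scale_by_old("batch", jnp.ones(3))  # both compile and run identically
scale_by_new("batch", jnp.ones(3))
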
213 changes: 213 additions & 0 deletions src/levanter/data/packing.py
@@ -0,0 +1,213 @@
"""
Implements sequence packing, mostly for doing evaluation on lots of short sequences.
Our strategy is basically to maintain a pool of SequencePackers, each of which can hold a fixed number of tokens
(and a maximum number of segments). We then iterate over the sequences, adding them to the packers if they fit, and
yielding the packed examples when they are full.
This achieves about a 90% "real token" rate, compared to roughly 10% without packing.
"""
from dataclasses import dataclass
from typing import Iterable, Iterator

import jax.numpy as jnp
import numpy as np

import haliax as hax

from levanter.models.attention import AttentionMask
from levanter.models.lm_model import LmExample
from levanter.utils.jax_utils import local_cpu_mesh


# cf https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/data_generators/generator_utils.py#L623

# todo should we use something like this: https://arxiv.org/pdf/2107.02027?


class SequencePacker:
"""
Packs sequences into a single LmExample.
"""

def __init__(self, Pos: hax.Axis, max_pack_size: int, pad_token: int):
self.Pos = Pos
self._ids: list[int] = []
self._segment_ids: list[int] = []
self._loss_mask: list[int] = []
self.num_segments = 0
self.pad_token = pad_token
self.max_pack_size = max_pack_size
assert pad_token is not None, "pad_token must be set"

def can_pack(self, ids: list[int]) -> bool:
return len(ids) + len(self._ids) <= self.Pos.size and self.num_segments < self.max_pack_size

def add_example(self, ids: list[int], loss_mask: list[int] | np.ndarray, segment_id: int | None = None):
if len(ids) != len(loss_mask):
raise ValueError("ids and loss_mask must have the same length")

if len(ids) == 0:
return

if len(ids) + len(self._ids) > self.Pos.size:
raise ValueError("Too many tokens")

if self.num_segments >= self.max_pack_size:
raise ValueError("Too many segments")

self._ids.extend(ids)
if segment_id is None:
segment_id = self.num_segments

self.num_segments += 1

self._segment_ids.extend([segment_id] * len(ids))

self._loss_mask.extend(loss_mask)

def pack(self) -> LmExample:
ids = self._ids + [self.pad_token] * (self.Pos.size - len(self._ids))

segment_ids = self._segment_ids + [-1] * (self.Pos.size - len(self._segment_ids))

loss_mask = self._loss_mask + [0] * (self.Pos.size - len(self._loss_mask))

with local_cpu_mesh():
tokens = hax.named(ids, self.Pos).astype(jnp.int32)
segment_ids = hax.named(segment_ids, self.Pos).astype(jnp.int32)
loss_mask = hax.named(loss_mask, self.Pos).astype(jnp.int32)

attn_mask = AttentionMask.causal().with_segment_ids(segment_ids)

return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask)


@dataclass(frozen=True)
class PromptCompletion:
ids: list[int]
prompt_length: int
segment_id: int | None = None


def pack_prompt_completions(
Pos: hax.Axis,
sequences: Iterable[PromptCompletion],
pad_token: int,
max_segments_per_example: int = 64,
max_buffered_examples: int = 64,
) -> Iterator[LmExample]:
"""
Packs a list of prompt completions into LmExamples using the SequencePacker
"""

packers = [SequencePacker(Pos, max_segments_per_example, pad_token)]

for sequence in sequences:
loss_mask = np.arange(len(sequence.ids)) >= sequence.prompt_length - 1
loss_mask[-1] = 0
assert np.any(loss_mask)

for packer in packers:
if packer.can_pack(sequence.ids):
packer.add_example(sequence.ids, loss_mask, sequence.segment_id)

if packer.num_segments == max_segments_per_example:
yield packer.pack()
packers.remove(packer)
break
else:
# no packer could fit the example, create a new one
packer = SequencePacker(Pos, max_segments_per_example, pad_token)
packer.add_example(sequence.ids, loss_mask, sequence.segment_id)
packers.append(packer)

while len(packers) >= max_buffered_examples:
yield packers.pop(0).pack()

for packer in packers:
yield packer.pack()


def per_segment_loss(
packed_example: LmExample, losses: hax.NamedArray, max_Segments: hax.Axis
) -> tuple[hax.NamedArray, hax.NamedArray]:
"""
Returns a pair of arrays of shape (max_Segments,), where:
* the first array is the segment ids
* the second is the loss per segment.
This code is designed to run inside a jit-compiled function, so we have to be careful about shapes.
"""

assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask"

segment_ids = packed_example.attn_mask.segment_ids
assert (
segment_ids.ndim == 1
), f"Expected segment_ids to be 1D, got {segment_ids.ndim}. Use vmap if you have multiple examples"
Pos = packed_example.tokens.axes[0]

# mask out padding etc
masked_losses = losses * packed_example.loss_mask

# sum the losses for each segment
unique_segment_ids = _unique_segment_ids(max_Segments, segment_ids)

# Create a mask matrix where each row corresponds to a unique segment
segment_mask = unique_segment_ids == segment_ids.broadcast_axis(max_Segments)

segment_mask = segment_mask.astype(masked_losses.dtype)

segment_losses = hax.dot(segment_mask, masked_losses, axis=Pos)

return unique_segment_ids, segment_losses


def _unique_segment_ids(max_Segments, segment_ids):
# Extract unique segment IDs with padding
# TODO: add unique to haliax
unique_segment_ids = jnp.unique(segment_ids.array, size=max_Segments.size, fill_value=-1)
unique_segment_ids = hax.named(unique_segment_ids, max_Segments)
return unique_segment_ids


def per_segment_correct(
packed_example: LmExample, correct: hax.NamedArray, max_Segments: hax.Axis
) -> tuple[hax.NamedArray, hax.NamedArray]:
"""
Returns a pair of arrays of shape (max_Segments,), where:
* the first array is the segment ids
* the second is whether all tokens in the segment are correct.
`correct` is a boolean array of the same shape as the losses array, indicating whether each token was predicted correctly.
This code is designed to run inside a jit-compiled function, so we have to be careful about shapes.
"""

assert packed_example.attn_mask.segment_ids is not None, "segment_ids must be set in the AttentionMask"

segment_ids = packed_example.attn_mask.segment_ids
assert (
segment_ids.ndim == 1
), f"Expected segment_ids to be 1D, got {segment_ids.ndim}. Use vmap if you have multiple examples"

Pos = packed_example.tokens.axes[0]

# mask out padding etc
masked_correct = hax.logical_or(correct, hax.logical_not(packed_example.loss_mask))

# sum the losses for each segment
# Extract unique segment IDs with padding
unique_segment_ids = _unique_segment_ids(max_Segments, segment_ids)

# Create a mask matrix where each row corresponds to a unique segment
segment_mask = unique_segment_ids == segment_ids.broadcast_axis(max_Segments)

segment_mask = segment_mask.astype(masked_correct.dtype)

segment_correct = hax.all(hax.where(segment_mask, masked_correct, True), axis=Pos)

return unique_segment_ids, segment_correct
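
A minimal usage sketch of the new packing module (not part of the commit; the token ids, axis names and sizes, and all-zero losses below are made-up placeholders):

# Hypothetical example for src/levanter/data/packing.py; values are placeholders.
import haliax as hax

from levanter.data.packing import PromptCompletion, pack_prompt_completions, per_segment_loss

Pos = hax.Axis("position", 32)  # packed sequence length (assumption)

sequences = [
    PromptCompletion(ids=[5, 6, 7, 8], prompt_length=2),  # loss is computed on completion tokens only
    PromptCompletion(ids=[9, 10, 11], prompt_length=1),
]

# Both sequences fit in one 32-token example, so this yields a single LmExample
# whose attention mask carries per-segment ids.
packed = list(pack_prompt_completions(Pos, sequences, pad_token=0, max_segments_per_example=4))

# Per-segment reduction: in real use `losses` comes from the model; zeros here
# only demonstrate the shapes.
max_Segments = hax.Axis("segment", 4)
losses = hax.zeros(Pos)
segment_ids, segment_losses = per_segment_loss(packed[0], losses, max_Segments)
print(segment_ids.array, segment_losses.array)  # ids padded with -1; one loss per segment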