Skip to content

Commit

Permalink
removed get_xy_batches, simplified tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jettjaniak committed May 20, 2024
1 parent 6770514 commit 2fa7396
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 125 deletions.
34 changes: 7 additions & 27 deletions src/delphi/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math
import os
import time
from collections.abc import Generator
from collections.abc import Iterator
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Type, cast
Expand Down Expand Up @@ -139,22 +139,6 @@ def get_indices_for_epoch(
return indices


def get_xy_batch(
dataset: Dataset,
indices: list[int],
batch_size: int,
batch_num: int,
feature_name: str,
device: torch.device,
) -> torch.Tensor:
"""Get a batch of data from a dataset given a batch number and indices"""
start = batch_num * batch_size
end = (batch_num + 1) * batch_size
batch_indices = indices[start:end]
data = dataset[batch_indices][feature_name].to(device)
return data


def gen_minibatches(
dataset: Dataset,
batch_size: int,
Expand All @@ -163,21 +147,17 @@ def gen_minibatches(
indices: list[int],
device: torch.device,
feature_name: str,
) -> Generator[torch.Tensor, None, None]:
) -> Iterator[torch.Tensor]:
"""
Generate minibatches from a dataset given a step and indices
"""
minibatch_size = batch_size // num_minibatches
first_minibatch_num = num_minibatches * step
for i in range(num_minibatches):
yield get_xy_batch(
dataset=dataset,
indices=indices,
batch_num=first_minibatch_num + i,
batch_size=minibatch_size,
feature_name=feature_name,
device=device,
)
for batch_num in range(first_minibatch_num, first_minibatch_num + num_minibatches):
start = batch_num * minibatch_size
end = (batch_num + 1) * minibatch_size
batch_indices = indices[start:end]
yield dataset[batch_indices][feature_name].to(device)


@torch.no_grad()
Expand Down
143 changes: 45 additions & 98 deletions tests/train/test_train_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import torch
from datasets import Dataset
from jaxtyping import Float
from transformers import PreTrainedModel

from delphi.constants import TEST_CONFIGS_DIR
from delphi.train.config import TrainingConfig
from delphi.train.config.utils import build_config_from_files_and_overrides
from delphi.train.train_step import accumulate_gradients, train_step
from delphi.train.utils import (
ModelTrainingState,
get_xy_batch,
gen_minibatches,
init_model,
setup_determinism,
)
Expand Down Expand Up @@ -90,69 +91,44 @@ def test_basic_reproducibility(dataset, model):
).all()


def get_grads(model: PreTrainedModel) -> Float[torch.Tensor, "grads"]:
grads = [
param.grad.flatten() for param in model.parameters() if param.grad is not None
]
return torch.cat(grads)


def test_accumulate_gradients_accumulates(dataset, model):
"""
check that gradient accumulation works as expected and doesn't reset on each microstep
"""
# setup
indices_set_a = [
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
]
# different batch but idential last batch;
indices_set_a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
# different batch but idential last batch (with batches of 3);
# this should result in a different accumulated gradient
indices_set_b = [7, 8, 9, 7, 8, 9, 7, 8, 9]
batch_size = 3
num_batches = len(indices_set_a) // batch_size

batches_a = [
get_xy_batch(
dataset=dataset,
indices=indices_set_a,
batch_size=3,
batch_num=microstep,
feature_name="tokens",
device=torch.device("cpu"),
)
for microstep in range(num_batches)
]
batches_b = [
get_xy_batch(
dataset=dataset,
indices=indices_set_b,
batch_size=3,
batch_num=microstep,
feature_name="tokens",
device=torch.device("cpu"),
)
for microstep in range(num_batches)
]
num_batches = 3
# first 2 mini-batches different, last mini-batch the same
indices_set_a = [1, 2, 3] + [4, 5, 6] + [7, 8, 9]
indices_set_b = [7, 8, 9] * 3

kwargs = dict(
dataset=dataset,
batch_size=batch_size,
num_minibatches=num_batches,
step=0,
device=torch.device("cpu"),
feature_name="tokens",
)
batches_a = gen_minibatches(indices=indices_set_a, **kwargs) # type: ignore
batches_b = gen_minibatches(indices=indices_set_b, **kwargs) # type: ignore

# accumulate
_total_loss = accumulate_gradients(model, batches_a, len(batches_a))

grads_a = torch.cat(
[
param.grad.clone().detach().flatten()
for param in model.parameters()
if param.grad is not None
]
)
_total_loss = accumulate_gradients(model, batches_a, num_batches)

grads_a = get_grads(model)

# reset grad on model
model.zero_grad()

_total_loss = accumulate_gradients(model, batches_b, len(batches_b))
grads_b = torch.cat(
[
param.grad.clone().detach().flatten()
for param in model.parameters()
if param.grad is not None
]
)
_total_loss = accumulate_gradients(model, batches_b, num_batches)
grads_b = get_grads(model)

# test
assert not torch.isclose(grads_a, grads_b).all()
Expand All @@ -163,59 +139,30 @@ def test_accumulate_gradients_consistent(dataset, model):
Validate that the gradients are consistent when the same batch is passed to accumulate_gradients
"""
# setup
indices_set = [
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
]
indices_set = list(range(1, 10))
num_batches = 3
batch_size = 3
batches_a = [
get_xy_batch(
dataset=dataset,
indices=indices_set,
batch_size=batch_size,
batch_num=microstep,
feature_name="tokens",
device=torch.device("cpu"),
)
for microstep in range(num_batches)
]
batches_aa = [
get_xy_batch(
dataset=dataset,
indices=indices_set,
batch_size=batch_size,
batch_num=microstep,
feature_name="tokens",
device=torch.device("cpu"),
)
for microstep in range(num_batches)
]
kwargs = dict(
indices=list(range(1, 10)),
dataset=dataset,
batch_size=batch_size,
num_minibatches=num_batches,
step=0,
device=torch.device("cpu"),
feature_name="tokens",
)
batches_a = gen_minibatches(**kwargs) # type: ignore
batches_aa = gen_minibatches(**kwargs) # type: ignore

# accumulate
total_loss = accumulate_gradients(model, batches_a, num_batches)

grads_a = torch.cat(
[
param.grad.clone().detach().flatten()
for param in model.parameters()
if param.grad is not None
]
)
_total_loss = accumulate_gradients(model, batches_a, num_batches)

grads_a = get_grads(model)

# reset grad on model
model.zero_grad()

total_loss = accumulate_gradients(model, batches_aa, num_batches)
grads_aa = torch.cat(
[
param.grad.clone().detach().flatten()
for param in model.parameters()
if param.grad is not None
]
)
_total_loss = accumulate_gradients(model, batches_aa, num_batches)
grads_aa = get_grads(model)

# test
assert torch.isclose(grads_a, grads_aa).all()
Expand Down

0 comments on commit 2fa7396

Please sign in to comment.