
Commit

increase tolerances
dlwh committed Nov 6, 2024
1 parent 5cf5f17 commit db08341
Showing 2 changed files with 54 additions and 67 deletions.
3 changes: 2 additions & 1 deletion src/levanter/infra/ray_tpu.py
@@ -201,7 +201,8 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env):

# ray doesn't merge the runtime envs properly, so we have to do it ourselves
# we need to do a deep merge
runtime_env = mergedeep.merge({}, runtime_env, remote_fn._runtime_env, strategy=mergedeep.Strategy.ADDITIVE)
sources = [e for e in [remote_fn._runtime_env, runtime_env] if e is not None]
runtime_env = mergedeep.merge({}, *sources, strategy=mergedeep.Strategy.ADDITIVE)

remote_fn = remote_fn.options(
runtime_env=runtime_env,
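For context on the fix above: mergedeep.merge takes a destination followed by any number of source mappings, and Strategy.ADDITIVE combines list and set values from later sources with earlier ones instead of replacing them, while nested dicts are deep-merged either way. A minimal sketch of that behavior, using made-up runtime_env contents (the env_vars/pip values are illustrative, not from the repository):

import mergedeep

# Hypothetical envs: one attached to the remote fn, one passed in by the caller.
fn_env = {"env_vars": {"TPU_NAME": "worker-0"}, "pip": ["ray"]}
caller_env = {"env_vars": {"JAX_PLATFORMS": "tpu"}, "pip": ["levanter"]}

# Filtering out None mirrors the patched code: remote_fn._runtime_env may be unset.
sources = [e for e in [fn_env, caller_env] if e is not None]
merged = mergedeep.merge({}, *sources, strategy=mergedeep.Strategy.ADDITIVE)
print(merged)
# expected roughly: {'env_vars': {'TPU_NAME': 'worker-0', 'JAX_PLATFORMS': 'tpu'},
#                    'pip': ['ray', 'levanter']}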
118 changes: 52 additions & 66 deletions tests/test_loss.py
@@ -1,4 +1,6 @@
# test_cross_entropy.py
import math

import equinox
import jax.numpy as jnp
import jax.random
@@ -13,44 +15,36 @@
from levanter.utils.jax_utils import key_iterator


@pytest.fixture
def axes():
"""
Define and return Haliax axes for testing.
"""
Batch = hax.Axis("batch", size=2)
Seq = hax.Axis("seq", size=3)
Embed = hax.Axis("embed", size=8)
Vocab = hax.Axis("vocab", size=16)
return Batch, Seq, Embed, Vocab
Batch = hax.Axis("batch", size=2)
Seq = hax.Axis("seq", size=3)
Embed = hax.Axis("embed", size=8)
Vocab = hax.Axis("vocab", size=16)


@pytest.fixture
def test_data(axes):
def test_data():
"""
Create synthetic test data for cross-entropy loss computation.
"""
Batch, Seq, Embed, Vocab = axes

key = key_iterator(jax.random.PRNGKey(0))

# Initialize pred_embeddings with scaled random normals
pred_embeddings = hax.random.normal(next(key), (Batch, Seq, Embed))
pred_embeddings = hax.random.normal(next(key), (Batch, Seq, Embed), dtype=jnp.float32) / math.sqrt(Embed.size)

# Initialize pred_lm_head with scaled random normals
pred_lm_head = hax.random.normal(next(key), (Embed, Vocab))
pred_lm_head = hax.random.normal(next(key), (Vocab, Embed), dtype=jnp.float32) / math.sqrt(Embed.size)

# Draw true_ids uniformly at random from the vocab
true_ids = hax.random.randint(next(key), (Batch, Seq), 0, Vocab.size)

return pred_embeddings, pred_lm_head, true_ids
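A note on the new 1/sqrt(Embed.size) scaling: with both factors drawn from N(0, 1) and divided by sqrt(Embed.size), each logit is a sum of Embed.size products whose variance is 1/Embed.size**2, so the logits end up with standard deviation about 1/sqrt(Embed.size) ≈ 0.35 instead of sqrt(Embed.size) ≈ 2.8 unscaled. Keeping logits near unit scale keeps the float32 softmax/logsumexp well conditioned, which is what makes the 1e-3 tolerances below achievable. A quick standalone check, written in plain JAX rather than Haliax and purely illustrative:

import math

import jax
import jax.numpy as jnp

embed, vocab = 8, 16
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
e = jax.random.normal(k1, (2, 3, embed)) / math.sqrt(embed)   # mirrors pred_embeddings
h = jax.random.normal(k2, (vocab, embed)) / math.sqrt(embed)  # mirrors pred_lm_head
logits = jnp.einsum("bse,ve->bsv", e, h)
print(float(jnp.std(logits)))  # around 1/sqrt(embed) ~= 0.35; ~2.8 without the scaling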


def test_basic_equivalence(axes, test_data):
def test_basic_equivalence(test_data):
"""
Test that block-wise loss equals full loss when block_size perfectly divides vocab_size.
"""
Batch, Seq, Embed, Vocab = axes
pred_embeddings, pred_lm_head, true_ids = test_data

# Compute full loss
@@ -63,19 +57,20 @@ def test_basic_equivalence(axes, test_data):
Contract=Embed,
Label=Vocab,
labels_y=true_ids,
block_size=2,
block_size=8,
dtype=pred_embeddings.dtype,
)

# Assert that the losses are close
assert hax.all(hax.isclose(loss_full, loss_block, atol=1e-4)), "Block-wise loss does not match full loss."
assert hax.all(
hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3)
), "Block-wise loss does not match full loss."


def test_single_block(axes, test_data):
def test_single_block(test_data):
"""
Test behavior when vocab_size equals block_size.
"""
Batch, Seq, Embed, Vocab = axes
pred_embeddings, pred_lm_head, true_ids = test_data

# Compute full loss
@@ -94,9 +89,11 @@

# Assert that the losses are close
assert hax.all(
hax.isclose(sumexp_full, sumexp_block, atol=1e-4)
hax.isclose(sumexp_full, sumexp_block, atol=1e-3, rtol=1e-3)
), "Single block-wise sumexp does not match full sumexp."
assert hax.all(hax.isclose(loss_full, loss_block, atol=1e-4)), "Single block-wise loss does not match full loss."
assert hax.all(
hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3)
), "Single block-wise loss does not match full loss."


def _compute_full(Vocab, pred_embeddings, pred_lm_head, true_ids):
@@ -106,11 +103,10 @@ def _compute_full(Vocab, pred_embeddings, pred_lm_head, true_ids):
return loss_full, sumexp_full
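For readers of these tests: the block-wise path computes the log-normalizer by streaming over vocab chunks with a running max, along the lines of the sketch below (an illustrative reimplementation in plain JAX, not Levanter's actual code). Reassociating the reduction this way is exactly why block-wise and full results agree only to within floating-point tolerance rather than bit-for-bit.

import jax.numpy as jnp

def blockwise_logsumexp(logits, block_size):
    # logits: (..., vocab); returns logsumexp over the last axis, one block at a time.
    vocab = logits.shape[-1]
    running_max = jnp.full(logits.shape[:-1], -jnp.inf, dtype=logits.dtype)
    running_sum = jnp.zeros(logits.shape[:-1], dtype=logits.dtype)
    for start in range(0, vocab, block_size):  # last block may be short; slicing handles it
        block = logits[..., start:start + block_size]
        block_max = jnp.max(block, axis=-1)
        new_max = jnp.maximum(running_max, block_max)
        # Rescale the running sum to the new max, then add this block's contribution.
        running_sum = running_sum * jnp.exp(running_max - new_max) + jnp.sum(
            jnp.exp(block - new_max[..., None]), axis=-1
        )
        running_max = new_max
    return running_max + jnp.log(running_sum)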


def test_multiple_blocks(axes, test_data):
def test_multiple_blocks(test_data):
"""
Test block-wise loss with multiple blocks.
"""
Batch, Seq, Embed, Vocab = axes
pred_embeddings, pred_lm_head, true_ids = test_data

# Compute full loss
@@ -127,12 +123,15 @@
)

# Assert that the losses are close
assert hax.all(hax.isclose(logz_full, logz_block, atol=1e-4)), "Multiple block-wise logz does not match full logz."
assert hax.all(hax.isclose(loss_full, loss_block, atol=1e-4)), "Multiple block-wise loss does not match full loss."
assert hax.all(
hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3)
), "Multiple block-wise logz does not match full logz."
assert hax.all(
hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3)
), "Multiple block-wise loss does not match full loss."


def test_block_size_not_dividing_vocab(axes, test_data):
Batch, Seq, Embed, Vocab = axes
def test_block_size_not_dividing_vocab(test_data):
pred_embeddings, pred_lm_head, true_ids = test_data

# Set block_size that does not divide vocab_size
@@ -156,15 +155,18 @@
)

# Assert that the losses are close
assert hax.all(hax.isclose(loss_full, loss_block, atol=1e-4)), "Block-wise loss does not match full loss."
assert hax.all(hax.isclose(logz_full, logz_block, atol=1e-4)), "Block-wise logz does not match full logz."
assert hax.all(
hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3)
), "Block-wise loss does not match full loss."
assert hax.all(
hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3)
), "Block-wise logz does not match full logz."


def test_vocab_size_less_than_block_size(axes, test_data):
def test_vocab_size_less_than_block_size(test_data):
"""
Test behavior when vocab_size is less than block_size.
"""
Batch, Seq, Embed, Vocab = axes
pred_embeddings, pred_lm_head, true_ids = test_data

# Set block_size greater than vocab_size
@@ -188,11 +190,11 @@
)

# Assert that the losses are close
assert hax.all(hax.isclose(loss_full, loss_block, atol=1e-4)), "loss does not match full loss."
assert hax.all(hax.isclose(logz_full, logz_block, atol=1e-4)), "logz does not match full logz."
assert hax.all(hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3)), "loss does not match full loss."
assert hax.all(hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3)), "logz does not match full logz."


def test_large_vocab(axes):
def test_large_vocab():
"""
Test block-wise loss with a larger vocabulary.
"""
@@ -233,29 +235,19 @@ def test_large_vocab(axes):

# Assert that the losses are close
assert hax.all(
hax.isclose(loss_full, loss_block, atol=1e-4)
hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3)
), "Large vocab block-wise loss does not match full loss."
assert hax.all(
hax.isclose(logz_full, logz_block, atol=1e-4)
hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3)
), "Large vocab block-wise logz does not match full logz."


@pytest.mark.parametrize("block_size", [1, 2, 3, 4, 5])
def test_gradient_block_cross_entropy(block_size):
def test_gradient_block_cross_entropy(block_size, test_data):
"""
Test the gradient of block-wise cross-entropy loss.
"""
# Define axes
Batch = hax.Axis("batch", size=2)
Seq = hax.Axis("seq", size=3)
Embed = hax.Axis("embed", size=8)
Vocab = hax.Axis("vocab", size=16)

# Define test data
key = jax.random.PRNGKey(0)
pred_embeddings = hax.random.normal(key, (Batch, Seq, Embed))
pred_lm_head = hax.random.normal(key, (Embed, Vocab))
true_ids = hax.random.randint(key, (Batch, Seq), 0, Vocab.size)
pred_embeddings, pred_lm_head, true_ids = test_data

# Compute block-wise loss
def custom_fn(pred):
Expand Down Expand Up @@ -286,25 +278,17 @@ def direct_fn(pred):

g_embed_direct, g_head_direct = equinox.filter_grad(direct_fn)((pred_embeddings, pred_lm_head))

assert hax.all(hax.isclose(g_embed, g_embed_direct, atol=1e-4)), "Gradient of embeddings does not match."
assert hax.all(hax.isclose(g_head, g_head_direct, atol=1e-4)), "Gradient of lm_head does not match."
assert hax.all(
hax.isclose(g_embed, g_embed_direct, atol=1e-3, rtol=1e-3)
), "Gradient of embeddings does not match."
assert hax.all(hax.isclose(g_head, g_head_direct, atol=1e-3, rtol=1e-3)), "Gradient of lm_head does not match."


def test_grad_loss_without_logz():
def test_grad_loss_without_logz(test_data):
"""
Test the gradient of block-wise cross-entropy loss without logz.
"""
# Define axes
Batch = hax.Axis("batch", size=2)
Seq = hax.Axis("seq", size=3)
Embed = hax.Axis("embed", size=8)
Vocab = hax.Axis("vocab", size=16)

# Define test data
key = jax.random.PRNGKey(0)
pred_embeddings = hax.random.normal(key, (Batch, Seq, Embed))
pred_lm_head = hax.random.normal(key, (Embed, Vocab))
true_ids = hax.random.randint(key, (Batch, Seq), 0, Vocab.size)
pred_embeddings, pred_lm_head, true_ids = test_data

# Compute block-wise loss
def custom_fn(pred):
@@ -328,12 +312,14 @@ def custom_fn(pred):

def direct_fn(pred):
pred_embeddings, pred_lm_head = pred
logits = hax.dot(pred_embeddings, pred_lm_head, axis="embed")
logits = hax.dot(pred_embeddings, pred_lm_head, axis="embed", preferred_element_type=pred_embeddings.dtype)
target_y = hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype)
loss, _ = cross_entropy_loss_and_log_normalizers(logits, Vocab, target_y)
return loss.mean().scalar()

g_embed_direct, g_head_direct = equinox.filter_grad(direct_fn)((pred_embeddings, pred_lm_head))

assert hax.all(hax.isclose(g_embed, g_embed_direct, atol=1e-4)), "Gradient of embeddings does not match."
assert hax.all(hax.isclose(g_head, g_head_direct, atol=1e-4)), "Gradient of lm_head does not match."
assert hax.all(
hax.isclose(g_embed, g_embed_direct, atol=1e-3, rtol=1e-3)
), "Gradient of embeddings does not match."
assert hax.all(hax.isclose(g_head, g_head_direct, atol=1e-3, rtol=1e-3)), "Gradient of lm_head does not match."
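One more note on the reference path above: passing preferred_element_type to hax.dot pins the matmul accumulation/output dtype, presumably so the full-logits reference is computed in the same precision as the block-wise implementation it is compared against (which takes an explicit dtype argument). A small illustration with jnp.dot; the Haliax call is assumed to forward to the same lax dot machinery:

import jax.numpy as jnp

a = jnp.ones((4, 8), dtype=jnp.bfloat16)
b = jnp.ones((8, 16), dtype=jnp.bfloat16)
print(jnp.dot(a, b).dtype)                                      # bfloat16
print(jnp.dot(a, b, preferred_element_type=jnp.float32).dtype)  # float32 accumulation/output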
