From db083412f10d0cea0e0d236c23d8930a4803c009 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Wed, 6 Nov 2024 09:30:22 -0800
Subject: [PATCH] increase tolerances

---
 src/levanter/infra/ray_tpu.py |   3 +-
 tests/test_loss.py            | 118 +++++++++++++++-------------------
 2 files changed, 54 insertions(+), 67 deletions(-)

diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py
index b04648079..1a9342c54 100644
--- a/src/levanter/infra/ray_tpu.py
+++ b/src/levanter/infra/ray_tpu.py
@@ -201,7 +201,8 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env):
 
     # ray doesn't merge the runtime envs properly, so we have to do it ourselves
     # we need to do a deep merge
-    runtime_env = mergedeep.merge({}, runtime_env, remote_fn._runtime_env, strategy=mergedeep.Strategy.ADDITIVE)
+    sources = [e for e in [remote_fn._runtime_env, runtime_env] if e is not None]
+    runtime_env = mergedeep.merge({}, *sources, strategy=mergedeep.Strategy.ADDITIVE)
 
     remote_fn = remote_fn.options(
         runtime_env=runtime_env,
diff --git a/tests/test_loss.py b/tests/test_loss.py
index d7960becd..30d140ede 100644
--- a/tests/test_loss.py
+++ b/tests/test_loss.py
@@ -1,4 +1,6 @@
 # test_cross_entropy.py
+import math
+
 import equinox
 import jax.numpy as jnp
 import jax.random
@@ -13,32 +15,25 @@
 from levanter.utils.jax_utils import key_iterator
 
 
-@pytest.fixture
-def axes():
-    """
-    Define and return Haliax axes for testing.
-    """
-    Batch = hax.Axis("batch", size=2)
-    Seq = hax.Axis("seq", size=3)
-    Embed = hax.Axis("embed", size=8)
-    Vocab = hax.Axis("vocab", size=16)
-    return Batch, Seq, Embed, Vocab
+Batch = hax.Axis("batch", size=2)
+Seq = hax.Axis("seq", size=3)
+Embed = hax.Axis("embed", size=8)
+Vocab = hax.Axis("vocab", size=16)
 
 
 @pytest.fixture
-def test_data(axes):
+def test_data():
     """
     Create synthetic test data for cross-entropy loss computation.
     """
-    Batch, Seq, Embed, Vocab = axes
     key = key_iterator(jax.random.PRNGKey(0))
 
     # Initialize pred_embeddings with ones
-    pred_embeddings = hax.random.normal(next(key), (Batch, Seq, Embed))
+    pred_embeddings = hax.random.normal(next(key), (Batch, Seq, Embed), dtype=jnp.float32) / math.sqrt(Embed.size)
 
     # Initialize pred_lm_head with ones
-    pred_lm_head = hax.random.normal(next(key), (Embed, Vocab))
+    pred_lm_head = hax.random.normal(next(key), (Vocab, Embed), dtype=jnp.float32) / math.sqrt(Embed.size)
 
     # Define true_ids such that the target is always the first token in vocab
     true_ids = hax.random.randint(next(key), (Batch, Seq), 0, Vocab.size)
 
@@ -46,11 +41,10 @@
     return pred_embeddings, pred_lm_head, true_ids
 
 
-def test_basic_equivalence(axes, test_data):
+def test_basic_equivalence(test_data):
     """
     Test that block-wise loss equals full loss when block_size perfectly divides vocab_size.
     """
-    Batch, Seq, Embed, Vocab = axes
     pred_embeddings, pred_lm_head, true_ids = test_data
 
     # Compute full loss
@@ -63,19 +57,20 @@
         Contract=Embed,
         Label=Vocab,
         labels_y=true_ids,
-        block_size=2,
+        block_size=8,
         dtype=pred_embeddings.dtype,
     )
 
     # Assert that the losses are close
-    assert hax.all(hax.isclose(loss_full, loss_block, atol=1e-4)), "Block-wise loss does not match full loss."
+    assert hax.all(
+        hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3)
+    ), "Block-wise loss does not match full loss."
 
 
-def test_single_block(axes, test_data):
+def test_single_block(test_data):
     """
     Test behavior when vocab_size equals block_size.
""" - Batch, Seq, Embed, Vocab = axes pred_embeddings, pred_lm_head, true_ids = test_data # Compute full loss @@ -94,9 +89,11 @@ def test_single_block(axes, test_data): # Assert that the losses are close assert hax.all( - hax.isclose(sumexp_full, sumexp_block, atol=1e-4) + hax.isclose(sumexp_full, sumexp_block, atol=1e-3, rtol=1e-3) ), "Single block-wise sumexp does not match full sumexp." - assert hax.all(hax.isclose(loss_full, loss_block, atol=1e-4)), "Single block-wise loss does not match full loss." + assert hax.all( + hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3) + ), "Single block-wise loss does not match full loss." def _compute_full(Vocab, pred_embeddings, pred_lm_head, true_ids): @@ -106,11 +103,10 @@ def _compute_full(Vocab, pred_embeddings, pred_lm_head, true_ids): return loss_full, sumexp_full -def test_multiple_blocks(axes, test_data): +def test_multiple_blocks(test_data): """ Test block-wise loss with multiple blocks. """ - Batch, Seq, Embed, Vocab = axes pred_embeddings, pred_lm_head, true_ids = test_data # Compute full loss @@ -127,12 +123,15 @@ def test_multiple_blocks(axes, test_data): ) # Assert that the losses are close - assert hax.all(hax.isclose(logz_full, logz_block, atol=1e-4)), "Multiple block-wise logz does not match full logz." - assert hax.all(hax.isclose(loss_full, loss_block, atol=1e-4)), "Multiple block-wise loss does not match full loss." + assert hax.all( + hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3) + ), "Multiple block-wise logz does not match full logz." + assert hax.all( + hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3) + ), "Multiple block-wise loss does not match full loss." -def test_block_size_not_dividing_vocab(axes, test_data): - Batch, Seq, Embed, Vocab = axes +def test_block_size_not_dividing_vocab(test_data): pred_embeddings, pred_lm_head, true_ids = test_data # Set block_size that does not divide vocab_size @@ -156,15 +155,18 @@ def test_block_size_not_dividing_vocab(axes, test_data): ) # Assert that the losses are close - assert hax.all(hax.isclose(loss_full, loss_block, atol=1e-4)), "Block-wise loss does not match full loss." - assert hax.all(hax.isclose(logz_full, logz_block, atol=1e-4)), "Block-wise logz does not match full logz." + assert hax.all( + hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3) + ), "Block-wise loss does not match full loss." + assert hax.all( + hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3) + ), "Block-wise logz does not match full logz." -def test_vocab_size_less_than_block_size(axes, test_data): +def test_vocab_size_less_than_block_size(test_data): """ Test behavior when vocab_size is less than block_size. """ - Batch, Seq, Embed, Vocab = axes pred_embeddings, pred_lm_head, true_ids = test_data # Set block_size greater than vocab_size @@ -188,11 +190,11 @@ def test_vocab_size_less_than_block_size(axes, test_data): ) # Assert that the losses are close - assert hax.all(hax.isclose(loss_full, loss_block, atol=1e-4)), "loss does not match full loss." - assert hax.all(hax.isclose(logz_full, logz_block, atol=1e-4)), "logz does not match full logz." + assert hax.all(hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3)), "loss does not match full loss." + assert hax.all(hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3)), "logz does not match full logz." -def test_large_vocab(axes): +def test_large_vocab(): """ Test block-wise loss with a larger vocabulary. 
""" @@ -233,29 +235,19 @@ def test_large_vocab(axes): # Assert that the losses are close assert hax.all( - hax.isclose(loss_full, loss_block, atol=1e-4) + hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3) ), "Large vocab block-wise loss does not match full loss." assert hax.all( - hax.isclose(logz_full, logz_block, atol=1e-4) + hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3) ), "Large vocab block-wise logz does not match full logz." @pytest.mark.parametrize("block_size", [1, 2, 3, 4, 5]) -def test_gradient_block_cross_entropy(block_size): +def test_gradient_block_cross_entropy(block_size, test_data): """ Test the gradient of block-wise cross-entropy loss. """ - # Define axes - Batch = hax.Axis("batch", size=2) - Seq = hax.Axis("seq", size=3) - Embed = hax.Axis("embed", size=8) - Vocab = hax.Axis("vocab", size=16) - - # Define test data - key = jax.random.PRNGKey(0) - pred_embeddings = hax.random.normal(key, (Batch, Seq, Embed)) - pred_lm_head = hax.random.normal(key, (Embed, Vocab)) - true_ids = hax.random.randint(key, (Batch, Seq), 0, Vocab.size) + pred_embeddings, pred_lm_head, true_ids = test_data # Compute block-wise loss def custom_fn(pred): @@ -286,25 +278,17 @@ def direct_fn(pred): g_embed_direct, g_head_direct = equinox.filter_grad(direct_fn)((pred_embeddings, pred_lm_head)) - assert hax.all(hax.isclose(g_embed, g_embed_direct, atol=1e-4)), "Gradient of embeddings does not match." - assert hax.all(hax.isclose(g_head, g_head_direct, atol=1e-4)), "Gradient of lm_head does not match." + assert hax.all( + hax.isclose(g_embed, g_embed_direct, atol=1e-3, rtol=1e-3) + ), "Gradient of embeddings does not match." + assert hax.all(hax.isclose(g_head, g_head_direct, atol=1e-3, rtol=1e-3)), "Gradient of lm_head does not match." -def test_grad_loss_without_logz(): +def test_grad_loss_without_logz(test_data): """ Test the gradient of block-wise cross-entropy loss without logz. """ - # Define axes - Batch = hax.Axis("batch", size=2) - Seq = hax.Axis("seq", size=3) - Embed = hax.Axis("embed", size=8) - Vocab = hax.Axis("vocab", size=16) - - # Define test data - key = jax.random.PRNGKey(0) - pred_embeddings = hax.random.normal(key, (Batch, Seq, Embed)) - pred_lm_head = hax.random.normal(key, (Embed, Vocab)) - true_ids = hax.random.randint(key, (Batch, Seq), 0, Vocab.size) + pred_embeddings, pred_lm_head, true_ids = test_data # Compute block-wise loss def custom_fn(pred): @@ -328,12 +312,14 @@ def custom_fn(pred): def direct_fn(pred): pred_embeddings, pred_lm_head = pred - logits = hax.dot(pred_embeddings, pred_lm_head, axis="embed") + logits = hax.dot(pred_embeddings, pred_lm_head, axis="embed", preferred_element_type=pred_embeddings.dtype) target_y = hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype) loss, _ = cross_entropy_loss_and_log_normalizers(logits, Vocab, target_y) return loss.mean().scalar() g_embed_direct, g_head_direct = equinox.filter_grad(direct_fn)((pred_embeddings, pred_lm_head)) - assert hax.all(hax.isclose(g_embed, g_embed_direct, atol=1e-4)), "Gradient of embeddings does not match." - assert hax.all(hax.isclose(g_head, g_head_direct, atol=1e-4)), "Gradient of lm_head does not match." + assert hax.all( + hax.isclose(g_embed, g_embed_direct, atol=1e-3, rtol=1e-3) + ), "Gradient of embeddings does not match." + assert hax.all(hax.isclose(g_head, g_head_direct, atol=1e-3, rtol=1e-3)), "Gradient of lm_head does not match."