Skip to content

Commit

Permalink
actually shard
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Nov 24, 2024
1 parent bc6897c commit ae99971
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 44 deletions.
26 changes: 0 additions & 26 deletions debug_jit_log.py

This file was deleted.

16 changes: 4 additions & 12 deletions src/levanter/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,6 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n
if n > 0:
total_loss /= n

# logger.info(f"eval loading time: {total_load_time / n:.3f} s/ba")
# logger.info(f"eval loss time: {total_loss_time / n:.3f} s/ba")

return total_loss


Expand Down Expand Up @@ -523,19 +520,14 @@ def _rec_log_magnitudes(norms, hists, path_prefix, tree):
strict=True,
):
if self.split_scan_layers and isinstance(g, haliax.nn.Stacked):
# unstacked = g.unstacked()
# for i, layer in enumerate(unstacked):
# _rec_log_magnitudes(to_log, join_key(key_path, str(i)), layer)
# vmap over the layers
Block = g.Block
vmapped_norms, vmapped_hists = haliax.vmap(_rec_log_magnitudes, Block)({}, {}, "", g.stacked)
# manual loop
vmapped_norms, vmapped_hists = haliax.vmap(_rec_log_magnitudes, g.Block)({}, {}, "", g.stacked)

for k, v in vmapped_norms.items():
for i in range(Block.size):
for i in range(g.Block.size):
norms[f"{key_path}/{i}/{k}"] = v[i]

for k, v in vmapped_hists.items():
for i in range(Block.size):
for i in range(g.Block.size):
hists[f"{key_path}/{i}/{k}"] = jax.tree.map(
lambda x: x[i] if is_jax_array_like(x) else x, v
)
Expand Down
10 changes: 4 additions & 6 deletions tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ def test_sharded_histogram_simple():
Batch = hax.Axis("Batch", 64)
Feature = hax.Axis("Feature", 128)

a = hax.random.normal(PRNGKey(0), (Batch, Feature))

with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}):
a = hax.random.normal(PRNGKey(0), (Batch, Feature))
a = hax.shard(a)
hist = levanter.tracker.histogram.sharded_histogram(a, bins=10)

hist_normal = jax.numpy.histogram(a.array, bins=10)[0]
Expand All @@ -33,14 +33,12 @@ def test_sharded_histogram_tp():
Batch = hax.Axis("Batch", 64)
Feature = hax.Axis("Feature", 128)

a = hax.random.normal(PRNGKey(0), (Batch, Feature)) * 100

with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA, "feature": ResourceAxis.MODEL}):
a = hax.random.normal(PRNGKey(0), (Batch, Feature)) * 100
a = hax.shard(a)
hist, bins = levanter.tracker.histogram.sharded_histogram(a, bins=64)

jnp_hist, jnp_bins = jax.numpy.histogram(a.array, bins=64)

print(hist, jnp_hist)

assert jax.numpy.allclose(hist, jnp_hist)
assert jax.numpy.allclose(bins, jnp_bins)

0 comments on commit ae99971

Please sign in to comment.