Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Nov 24, 2024
1 parent c8553b4 commit aa70b88
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,20 @@


def test_sharded_histogram_simple():
mesh = Mesh((jax.devices()), (ResourceAxis.DATA))
mesh = Mesh((jax.devices()), (ResourceAxis.DATA,))

Batch = hax.Axis("Batch", 64)
Feature = hax.Axis("Feature", 128)

with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}):
a = hax.random.normal(PRNGKey(0), (Batch, Feature))
a = hax.shard(a)
hist = levanter.tracker.histogram.sharded_histogram(a, bins=10)
hist, bins = levanter.tracker.histogram.sharded_histogram(a, bins=10)

hist_normal = jax.numpy.histogram(a.array, bins=10)[0]
hist_normal, bins_normal = jax.numpy.histogram(a.array, bins=10)

assert jax.numpy.allclose(hist, hist_normal)
assert jax.numpy.allclose(bins, bins_normal)


@skip_if_not_enough_devices(2)
Expand Down

0 comments on commit aa70b88

Please sign in to comment.