diff --git a/tests/test_histogram.py b/tests/test_histogram.py index b2e008724..b47522134 100644 --- a/tests/test_histogram.py +++ b/tests/test_histogram.py @@ -11,7 +11,7 @@ def test_sharded_histogram_simple(): - mesh = Mesh((jax.devices()), (ResourceAxis.DATA)) + mesh = Mesh((jax.devices()), (ResourceAxis.DATA,)) Batch = hax.Axis("Batch", 64) Feature = hax.Axis("Feature", 128) @@ -19,11 +19,12 @@ def test_sharded_histogram_simple(): with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): a = hax.random.normal(PRNGKey(0), (Batch, Feature)) a = hax.shard(a) - hist = levanter.tracker.histogram.sharded_histogram(a, bins=10) + hist, bins = levanter.tracker.histogram.sharded_histogram(a, bins=10) - hist_normal = jax.numpy.histogram(a.array, bins=10)[0] + hist_normal, bins_normal = jax.numpy.histogram(a.array, bins=10) assert jax.numpy.allclose(hist, hist_normal) + assert jax.numpy.allclose(bins, bins_normal) @skip_if_not_enough_devices(2)