diff --git a/src/levanter/tracker/histogram.py b/src/levanter/tracker/histogram.py index 07b0e507e..9ab983a12 100644 --- a/src/levanter/tracker/histogram.py +++ b/src/levanter/tracker/histogram.py @@ -98,6 +98,10 @@ def _shardmap_histogram(a: NamedArray, bins): ), ) res = shard_h(a.array, bins) + + # the filter misses the last bin, so we need to add it + if res.size >= 1: + res = res.at[-1].add(1) return res diff --git a/tests/test_histogram.py b/tests/test_histogram.py index 0e3a1e843..f2ef4fd0a 100644 --- a/tests/test_histogram.py +++ b/tests/test_histogram.py @@ -17,11 +17,11 @@ def test_sharded_histogram_simple(): Feature = hax.Axis("feature", 128) with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): - a = hax.random.normal(PRNGKey(0), (Batch, Feature)) + a = hax.random.normal(PRNGKey(1), (Batch, Feature)) a = hax.shard(a) - hist, bins = levanter.tracker.histogram.sharded_histogram(a, bins=10) + hist, bins = levanter.tracker.histogram.sharded_histogram(a, bins=32) - hist_normal, bins_normal = jax.numpy.histogram(a.array, bins=10) + hist_normal, bins_normal = jax.numpy.histogram(a.array, bins=32) assert jax.numpy.allclose(hist, hist_normal) assert jax.numpy.allclose(bins, bins_normal)