Skip to content

Commit

Permalink
fix fencepost
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Nov 25, 2024
1 parent eaa977a commit a1c264e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
4 changes: 4 additions & 0 deletions src/levanter/tracker/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def _shardmap_histogram(a: NamedArray, bins):
),
)
res = shard_h(a.array, bins)

# the filter misses the last bin, so we need to add it
if res.size >= 1:
res = res.at[-1].add(1)
return res


Expand Down
6 changes: 3 additions & 3 deletions tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ def test_sharded_histogram_simple():
Feature = hax.Axis("feature", 128)

with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}):
a = hax.random.normal(PRNGKey(0), (Batch, Feature))
a = hax.random.normal(PRNGKey(1), (Batch, Feature))
a = hax.shard(a)
hist, bins = levanter.tracker.histogram.sharded_histogram(a, bins=10)
hist, bins = levanter.tracker.histogram.sharded_histogram(a, bins=32)

hist_normal, bins_normal = jax.numpy.histogram(a.array, bins=10)
hist_normal, bins_normal = jax.numpy.histogram(a.array, bins=32)

assert jax.numpy.allclose(hist, hist_normal)
assert jax.numpy.allclose(bins, bins_normal)
Expand Down

0 comments on commit a1c264e

Please sign in to comment.