Skip to content

Commit

Permalink
ok this one doesn't seem to crash
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Nov 25, 2024
1 parent 78f98a3 commit 7837060
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 37 deletions.
65 changes: 32 additions & 33 deletions src/levanter/tracker/histogram.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import functools

import equinox
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -58,13 +60,14 @@ def sharded_histogram(a: NamedArray, bins: int | ArrayLike = 10) -> tuple[jnp.nd
* It only works with NamedArrays
* It is more performant on TPUs
Credit to @aphoh for the original implementation.
Credit to @aphoh for the original implementation, though that one crashes on TPUs due to some kind of driver bug
"""
edges = jnp.histogram_bin_edges(a.array, bins=bins)
return _shardmap_histogram(a, edges), edges
# return jnp.histogram(a.array, bins=edges)


def _single_shard_histogram(a, bins):
def _single_shard_histogram(a, bins, reduce_mesh):
"""Modified version of jax.numpy.histogram that returns integer counts instead of using the datatype of the input.
Also avoids searchsorted, which is slow on TPUs.
Args:
Expand All @@ -74,46 +77,42 @@ def _single_shard_histogram(a, bins):
Array: counts. has length len(bins) - 1
"""
a = a.flatten()
prefix_sum = jnp.sum((a < bins[:, None]).astype(jnp.int32), axis=1)
last_count = jnp.sum(a <= bins[-1])
prefix_sum = prefix_sum.at[-1].set(last_count)
return jnp.expand_dims(jnp.diff(prefix_sum), 0)

bin_idx = jnp.searchsorted(bins, a, side="right", method="compare_all")
bin_idx = jnp.where(a == bins[-1], len(bins) - 1, bin_idx)
weights = jnp.ones_like(a)
counts = jnp.zeros(len(bins), weights.dtype).at[bin_idx].add(weights)[1:]

if len(reduce_mesh):
counts = jax.lax.psum(counts, axis_name=reduce_mesh)
return counts


@jax.jit
def _shardmap_histogram(a: NamedArray, bins):
    """Histogram `a` via shard_map: one partial histogram per shard, reduced in-body.

    Each shard bins its local slice of ``a.array`` with
    ``_single_shard_histogram``, which psums the partial counts over the mesh
    axes the array is sharded on, so the reduction happens inside the
    shard_map body rather than on the replicated output.

    Args:
        a: NamedArray whose values are histogrammed.
        bins: array of bin edges, replicated across the mesh.

    Returns:
        Array of counts with length len(bins) - 1.
    """
    mesh = hax.partitioning._get_mesh()
    spec = hax.partitioning.pspec_for_axis(a.axes)
    # Mesh axis names `a` is sharded over, flattened into one tuple so they
    # can serve as the psum axis_name inside the shard_map body.
    flattened_spec = _flattened_spec(spec)
    shard_h = shard_map(
        functools.partial(_single_shard_histogram, reduce_mesh=flattened_spec),
        mesh=mesh,
        in_specs=(spec, PartitionSpec(None)),
        # NOTE(review): out_specs is the flattened axis-name tuple rather than
        # a PartitionSpec — reconstructed verbatim from the commit diff;
        # confirm against the repository's post-commit file.
        out_specs=(flattened_spec),
    )
    res = shard_h(a.array, bins)
    return res


def _flattened_spec(a):
spec = hax.partitioning.pspec_for_axis(a.axes)
def _flattened_spec(spec):
out = []
for s in spec:
if isinstance(s, tuple):
out.extend(s)
elif s is None:
pass
else:
out.append(s)

def flatten_spec(spec):
# spec is a tuple None|str|tuple[str]
out = []
for s in spec:
if isinstance(s, tuple):
out.extend(s)
elif s is None:
pass
else:
out.append(s)

return tuple(out)

flattened_spec = PartitionSpec(flatten_spec(spec))
return flattened_spec, spec
return tuple(out)
8 changes: 4 additions & 4 deletions tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
def test_sharded_histogram_simple():
mesh = Mesh((jax.devices()), (ResourceAxis.DATA,))

Batch = hax.Axis("Batch", 64)
Feature = hax.Axis("Feature", 128)
Batch = hax.Axis("batch", 64)
Feature = hax.Axis("feature", 128)

with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}):
a = hax.random.normal(PRNGKey(0), (Batch, Feature))
Expand All @@ -31,8 +31,8 @@ def test_sharded_histogram_simple():
def test_sharded_histogram_tp():
mesh = Mesh(np.array(jax.devices()).reshape(-1, 2), (ResourceAxis.DATA, ResourceAxis.MODEL))

Batch = hax.Axis("Batch", 64)
Feature = hax.Axis("Feature", 128)
Batch = hax.Axis("batch", 64)
Feature = hax.Axis("feature", 128)

with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA, "feature": ResourceAxis.MODEL}):
a = hax.random.normal(PRNGKey(0), (Batch, Feature)) * 100
Expand Down

0 comments on commit 7837060

Please sign in to comment.