Commit bc6897c

ok, histogram is much faster. still slower than not using it, but that's ok

dlwh committed Nov 21, 2024
1 parent a886691 commit bc6897c
Showing 5 changed files with 155 additions and 6 deletions.
13 changes: 10 additions & 3 deletions src/levanter/callbacks.py
@@ -23,6 +23,7 @@
 from tqdm_loggable.auto import tqdm

 import haliax.nn
+from haliax import NamedArray, is_named_array
 from haliax.jax_utils import is_jax_array_like

 import levanter.tracker
@@ -509,9 +510,9 @@ def on_step(self, step_info: StepInfo[S], cb_info: dict[str, float | Histogram])

     def _generate_statistics_for(self, kind: str, tree: M) -> dict[str, float | Histogram]:
         if self.split_scan_layers:
-            is_leaf = lambda n: isinstance(n, haliax.nn.Stacked)  # noqa: E731
+            is_leaf = lambda n: isinstance(n, haliax.nn.Stacked) or is_named_array(n)  # noqa: E731
         else:
-            is_leaf = lambda n: False  # noqa: E731
+            is_leaf = is_named_array

         def _rec_log_magnitudes(norms, hists, path_prefix, tree):
             leaf_key_paths = jax_utils.leaf_key_paths(tree, prefix=path_prefix, is_leaf=is_leaf)
@@ -539,7 +540,13 @@ def _rec_log_magnitudes(norms, hists, path_prefix, tree):
                             lambda x: x[i] if is_jax_array_like(x) else x, v
                         )

-                else:
+                elif isinstance(g, NamedArray):
+                    # TODO: add linalg.norm to Haliax
+                    norms[key_path] = jnp.linalg.norm(g.array)
+                    if self.include_histogram:
+                        hist = Histogram.from_named_array(g)
+                        hists[key_path] = hist
+                elif is_jax_array_like(g):
                     norms[key_path] = jnp.linalg.norm(g)

                     if self.include_histogram:
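A minimal sketch (not from the diff) of what the new NamedArray branch records for a single gradient leaf: an L2 norm plus a sharded histogram. It reuses the mesh/axis-mapping setup from the tests added in this commit; the axis names, sizes, and the random stand-in for a gradient leaf are illustrative only.

import jax
import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey
from jax.sharding import Mesh

import haliax as hax
from haliax.partitioning import ResourceAxis

from levanter.tracker.histogram import Histogram

Batch = hax.Axis("Batch", 64)
Feature = hax.Axis("Feature", 128)
g = hax.random.normal(PRNGKey(0), (Batch, Feature))  # stands in for one gradient leaf

mesh = Mesh(np.array(jax.devices()), (ResourceAxis.DATA,))
with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}):
    norm = jnp.linalg.norm(g.array)        # scalar norm recorded under the leaf's key path
    hist = Histogram.from_named_array(g)   # sharded histogram recorded alongside it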
2 changes: 1 addition & 1 deletion src/levanter/main/train_lm.py
@@ -170,7 +170,7 @@ def main(config: TrainLmConfig):
         every=config.hf_save_steps,
     )

-    trainer.add_hook(callbacks.GradWatchCallback(include_histogram=False), every=5)
+    trainer.add_hook(callbacks.GradWatchCallback(include_histogram=True), every=5)

     state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))

89 changes: 88 additions & 1 deletion src/levanter/tracker/histogram.py
@@ -1,7 +1,13 @@
 import equinox
 import jax
 import jax.numpy as jnp
 import numpy as np
-from jaxtyping import Scalar
+from jax._src.partition_spec import PartitionSpec
+from jax.experimental.shard_map import shard_map
+from jaxtyping import ArrayLike, Scalar

+import haliax as hax
+from haliax import NamedArray
+
+
 class Histogram(equinox.Module):
@@ -28,5 +34,86 @@ def from_array(array: jax.Array, num_bins: int = 64) -> "Histogram":
         counts, edges = jax.numpy.histogram(array, bins=num_bins)
         return Histogram(min, max, num, sum, sum_squares, edges, counts)

+    @staticmethod
+    def from_named_array(array: hax.NamedArray, num_bins: int = 64) -> "Histogram":
+        raw_array = array.array
+        min = raw_array.min()
+        max = raw_array.max()
+        num = array.size
+        sum = raw_array.sum()
+        sum_squares = (raw_array**2).sum()
+        counts, edges = sharded_histogram(array, bins=num_bins)
+        return Histogram(min, max, num, sum, sum_squares, edges, counts)
+
     def to_numpy_histogram(self) -> tuple[np.ndarray, np.ndarray]:
         return np.array(self.bucket_counts), np.array(self.bucket_limits)
+
+
+def sharded_histogram(a: NamedArray, bins: int | ArrayLike = 10) -> tuple[jnp.ndarray, jnp.ndarray]:
+    """
+    As [jax.numpy.histogram](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.histogram.html#jax.numpy.histogram),
+    except:
+    * It preserves sharding
+    * It only works with NamedArrays
+    * It is more performant on TPUs
+    Credit to @aphoh for the original implementation.
+    """
+    edges = jnp.histogram_bin_edges(a.array, bins=bins)
+    return _shardmap_histogram(a, edges), edges
+
+
+def _single_shard_histogram(a, bins):
+    """Modified version of jax.numpy.histogram that returns integer counts instead of using the datatype of the input.
+    Also avoids searchsorted, which is slow on TPUs.
+    Args:
+        a (Array): input array
+        bins (Array): bins to use for histogram
+    Returns:
+        Array: counts. has length len(bins) - 1
+    """
+    a = a.flatten()
+    prefix_sum = jnp.sum((a < bins[:, None]).astype(jnp.int32), axis=1)
+    last_count = jnp.sum(a <= bins[-1])
+    prefix_sum = prefix_sum.at[-1].set(last_count)
+    return jnp.expand_dims(jnp.diff(prefix_sum), 0)
+
+
+@jax.jit
+def _shardmap_histogram(a: NamedArray, bins):
+    mesh = hax.partitioning._get_mesh()
+    flattened_spec, spec = _flattened_spec(a)
+    shard_h = shard_map(
+        _single_shard_histogram,
+        mesh=mesh,
+        in_specs=(
+            spec,
+            PartitionSpec(
+                None,
+            ),
+        ),
+        out_specs=(flattened_spec),
+    )
+    res = shard_h(a.array, bins)
+    return res.sum(axis=0)
+
+
+def _flattened_spec(a):
+    spec = hax.partitioning.pspec_for_axis(a.axes)
+
+    def flatten_spec(spec):
+        # spec is a tuple None|str|tuple[str]
+        out = []
+        for s in spec:
+            if isinstance(s, tuple):
+                out.extend(s)
+            elif s is None:
+                pass
+            else:
+                out.append(s)
+
+        return tuple(out)
+
+    flattened_spec = PartitionSpec(flatten_spec(spec))
+    return flattened_spec, spec
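The counting trick in _single_shard_histogram (count how many elements fall strictly below each bin edge, make the last edge right-inclusive, then difference the prefix counts) is what lets it avoid searchsorted. A tiny, unsharded sketch of the same logic, checked against jnp.histogram; the toy array below is illustrative only.

import jax.numpy as jnp

a = jnp.array([0.1, 0.4, 0.4, 0.9])
edges = jnp.histogram_bin_edges(a, bins=4)                       # 5 edges -> 4 buckets
below = jnp.sum((a < edges[:, None]).astype(jnp.int32), axis=1)  # elements strictly below each edge
below = below.at[-1].set(jnp.sum(a <= edges[-1]))                # last bucket is right-inclusive
counts = jnp.diff(below)                                         # per-bucket counts: [1, 2, 0, 1]
assert jnp.array_equal(counts, jnp.histogram(a, bins=4)[0])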
11 changes: 10 additions & 1 deletion src/levanter/trainer.py
@@ -225,6 +225,14 @@ def num_train_steps(self) -> int:
     def add_hook(self, fn: Callable[[StepInfo], Any], *, every: int = 1):
         ...

+    @typing.overload
+    def add_hook(self, fn: JitCallback, *, every: int = 1):
+        ...
+
+    @typing.overload
+    def add_hook(self, fn: Callback, *, every: int = 1):
+        ...
+
     @typing.overload
     def add_hook(self, *, every: int = 1):
         ...
@@ -510,7 +518,8 @@ def obj_fun(trainable_model):
             model = self.mp.cast_to_compute(model)
             return self._raw_loss_function(model, *batch, **batch_kwargs, key=key).scalar()

-        hook_infos = self.hooks.run_jit_hooks(state, grads, force=False)
+        with hax.axis_mapping(self.parameter_axis_mapping):
+            hook_infos = self.hooks.run_jit_hooks(state, grads, force=False)

         new_state = state.take_step(grads, obj_fun=obj_fun)
         new_state = hax.shard(new_state, self.parameter_axis_mapping)
46 changes: 46 additions & 0 deletions tests/test_histogram.py
@@ -0,0 +1,46 @@
+import jax
+import numpy as np
+from jax.random import PRNGKey
+from jax.sharding import Mesh
+
+import haliax as hax
+from haliax.partitioning import ResourceAxis
+
+import levanter.tracker.histogram
+from test_utils import skip_if_not_enough_devices
+
+
+def test_sharded_histogram_simple():
+    mesh = Mesh(np.array(jax.devices()), (ResourceAxis.DATA,))
+
+    Batch = hax.Axis("Batch", 64)
+    Feature = hax.Axis("Feature", 128)
+
+    a = hax.random.normal(PRNGKey(0), (Batch, Feature))
+
+    with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}):
+        hist, _ = levanter.tracker.histogram.sharded_histogram(a, bins=10)
+
+    hist_normal = jax.numpy.histogram(a.array, bins=10)[0]
+
+    assert jax.numpy.allclose(hist, hist_normal)
+
+
+@skip_if_not_enough_devices(2)
+def test_sharded_histogram_tp():
+    mesh = Mesh(np.array(jax.devices()).reshape(-1, 2), (ResourceAxis.DATA, ResourceAxis.MODEL))
+
+    Batch = hax.Axis("Batch", 64)
+    Feature = hax.Axis("Feature", 128)
+
+    a = hax.random.normal(PRNGKey(0), (Batch, Feature)) * 100
+
+    with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA, "feature": ResourceAxis.MODEL}):
+        hist, bins = levanter.tracker.histogram.sharded_histogram(a, bins=64)
+
+    jnp_hist, jnp_bins = jax.numpy.histogram(a.array, bins=64)
+
+    print(hist, jnp_hist)
+
+    assert jax.numpy.allclose(hist, jnp_hist)
+    assert jax.numpy.allclose(bins, jnp_bins)

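For consumers of these histograms, to_numpy_histogram returns (counts, bucket edges) in the same convention as numpy.histogram, i.e. len(edges) == len(counts) + 1. A hedged sketch of handing that to a logger; wandb is only an example backend here, is not part of this commit, and the snippet assumes an active wandb run.

import jax.numpy as jnp
import wandb

from levanter.tracker.histogram import Histogram

hist = Histogram.from_array(jnp.arange(128.0))  # the pre-existing unsharded constructor
counts, limits = hist.to_numpy_histogram()      # numpy-style (counts, bin_edges)
wandb.log({"example/hist": wandb.Histogram(np_histogram=(counts, limits))})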