From bc6897c3b9b42104b4cb84a69210666faf87a48a Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 21 Nov 2024 12:18:31 -0800 Subject: [PATCH] ok, histogram is much faster. still slower than not using it, but that's ok --- src/levanter/callbacks.py | 13 +++-- src/levanter/main/train_lm.py | 2 +- src/levanter/tracker/histogram.py | 89 ++++++++++++++++++++++++++++++- src/levanter/trainer.py | 11 +++- tests/test_histogram.py | 46 ++++++++++++++++ 5 files changed, 155 insertions(+), 6 deletions(-) create mode 100644 tests/test_histogram.py diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 26bf880ec..83e8b8779 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -23,6 +23,7 @@ from tqdm_loggable.auto import tqdm import haliax.nn +from haliax import NamedArray, is_named_array from haliax.jax_utils import is_jax_array_like import levanter.tracker @@ -509,9 +510,9 @@ def on_step(self, step_info: StepInfo[S], cb_info: dict[str, float | Histogram]) def _generate_statistics_for(self, kind: str, tree: M) -> dict[str, float | Histogram]: if self.split_scan_layers: - is_leaf = lambda n: isinstance(n, haliax.nn.Stacked) # noqa: E731 + is_leaf = lambda n: isinstance(n, haliax.nn.Stacked) or is_named_array(n) # noqa: E731 else: - is_leaf = lambda n: False # noqa: E731 + is_leaf = is_named_array def _rec_log_magnitudes(norms, hists, path_prefix, tree): leaf_key_paths = jax_utils.leaf_key_paths(tree, prefix=path_prefix, is_leaf=is_leaf) @@ -539,7 +540,13 @@ def _rec_log_magnitudes(norms, hists, path_prefix, tree): lambda x: x[i] if is_jax_array_like(x) else x, v ) - else: + elif isinstance(g, NamedArray): + # TODO: add linalg.norm to Haliax + norms[key_path] = jnp.linalg.norm(g.array) + if self.include_histogram: + hist = Histogram.from_named_array(g) + hists[key_path] = hist + elif is_jax_array_like(g): norms[key_path] = jnp.linalg.norm(g) if self.include_histogram: diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 59b493f11..3fb1c6afa 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -170,7 +170,7 @@ def main(config: TrainLmConfig): every=config.hf_save_steps, ) - trainer.add_hook(callbacks.GradWatchCallback(include_histogram=False), every=5) + trainer.add_hook(callbacks.GradWatchCallback(include_histogram=True), every=5) state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) diff --git a/src/levanter/tracker/histogram.py b/src/levanter/tracker/histogram.py index cc331041e..0cada5551 100644 --- a/src/levanter/tracker/histogram.py +++ b/src/levanter/tracker/histogram.py @@ -1,7 +1,13 @@ import equinox import jax +import jax.numpy as jnp import numpy as np -from jaxtyping import Scalar +from jax._src.partition_spec import PartitionSpec +from jax.experimental.shard_map import shard_map +from jaxtyping import ArrayLike, Scalar + +import haliax as hax +from haliax import NamedArray class Histogram(equinox.Module): @@ -28,5 +34,86 @@ def from_array(array: jax.Array, num_bins: int = 64) -> "Histogram": counts, edges = jax.numpy.histogram(array, bins=num_bins) return Histogram(min, max, num, sum, sum_squares, edges, counts) + @staticmethod + def from_named_array(array: hax.NamedArray, num_bins: int = 64) -> "Histogram": + raw_array = array.array + min = raw_array.min() + max = raw_array.max() + num = array.size + sum = raw_array.sum() + sum_squares = (raw_array**2).sum() + counts, edges = sharded_histogram(array, bins=num_bins) + return Histogram(min, max, 
num, sum, sum_squares, edges, counts) + def to_numpy_histogram(self) -> tuple[np.ndarray, np.ndarray]: return np.array(self.bucket_counts), np.array(self.bucket_limits) + + +def sharded_histogram(a: NamedArray, bins: int | ArrayLike = 10) -> tuple[jnp.ndarray, jnp.ndarray]: + """ + As [jax.numpy.histogram](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.histogram.html#jax.numpy.histogram), + except: + + * It preserves sharding + * It only works with NamedArrays + * It is more performant on TPUs + + Credit to @aphoh for the original implementation. + """ + edges = jnp.histogram_bin_edges(a.array, bins=bins) + return _shardmap_histogram(a, edges), edges + + +def _single_shard_histogram(a, bins): + """Modified version of jax.numpy.histogram that returns integer counts instead of using the datatype of the input. + Also avoids searchsorted, which is slow on TPUs. + Args: + a (Array): input array + bins (Array): bins to use for histogram + Returns: + Array: counts. has length len(bins) - 1 + """ + a = a.flatten() + prefix_sum = jnp.sum((a < bins[:, None]).astype(jnp.int32), axis=1) + last_count = jnp.sum(a <= bins[-1]) + prefix_sum = prefix_sum.at[-1].set(last_count) + return jnp.expand_dims(jnp.diff(prefix_sum), 0) + + +@jax.jit +def _shardmap_histogram(a: NamedArray, bins): + mesh = hax.partitioning._get_mesh() + flattened_spec, spec = _flattened_spec(a) + shard_h = shard_map( + _single_shard_histogram, + mesh=mesh, + in_specs=( + spec, + PartitionSpec( + None, + ), + ), + out_specs=(flattened_spec), + ) + res = shard_h(a.array, bins) + return res.sum(axis=0) + + +def _flattened_spec(a): + spec = hax.partitioning.pspec_for_axis(a.axes) + + def flatten_spec(spec): + # spec is a tuple None|str|tuple[str] + out = [] + for s in spec: + if isinstance(s, tuple): + out.extend(s) + elif s is None: + pass + else: + out.append(s) + + return tuple(out) + + flattened_spec = PartitionSpec(flatten_spec(spec)) + return flattened_spec, spec diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 191cdef4e..83c8d7e45 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -225,6 +225,14 @@ def num_train_steps(self) -> int: def add_hook(self, fn: Callable[[StepInfo], Any], *, every: int = 1): ... + @typing.overload + def add_hook(self, fn: JitCallback, *, every: int = 1): + ... + + @typing.overload + def add_hook(self, fn: Callback, *, every: int = 1): + ... + @typing.overload def add_hook(self, *, every: int = 1): ... 
@@ -510,7 +518,8 @@ def obj_fun(trainable_model):
                 model = self.mp.cast_to_compute(model)
                 return self._raw_loss_function(model, *batch, **batch_kwargs, key=key).scalar()
 
-            hook_infos = self.hooks.run_jit_hooks(state, grads, force=False)
+            with hax.axis_mapping(self.parameter_axis_mapping):
+                hook_infos = self.hooks.run_jit_hooks(state, grads, force=False)
 
             new_state = state.take_step(grads, obj_fun=obj_fun)
             new_state = hax.shard(new_state, self.parameter_axis_mapping)
diff --git a/tests/test_histogram.py b/tests/test_histogram.py
new file mode 100644
index 000000000..a219427d0
--- /dev/null
+++ b/tests/test_histogram.py
@@ -0,0 +1,46 @@
+import jax
+import numpy as np
+from jax.random import PRNGKey
+from jax.sharding import Mesh
+
+import haliax as hax
+from haliax.partitioning import ResourceAxis
+
+import levanter.tracker.histogram
+from test_utils import skip_if_not_enough_devices
+
+
+def test_sharded_histogram_simple():
+    mesh = Mesh(jax.devices(), (ResourceAxis.DATA,))
+
+    Batch = hax.Axis("Batch", 64)
+    Feature = hax.Axis("Feature", 128)
+
+    a = hax.random.normal(PRNGKey(0), (Batch, Feature))
+
+    with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}):
+        hist, _ = levanter.tracker.histogram.sharded_histogram(a, bins=10)
+
+    hist_normal = jax.numpy.histogram(a.array, bins=10)[0]
+
+    assert jax.numpy.allclose(hist, hist_normal)
+
+
+@skip_if_not_enough_devices(2)
+def test_sharded_histogram_tp():
+    mesh = Mesh(np.array(jax.devices()).reshape(-1, 2), (ResourceAxis.DATA, ResourceAxis.MODEL))
+
+    Batch = hax.Axis("Batch", 64)
+    Feature = hax.Axis("Feature", 128)
+
+    a = hax.random.normal(PRNGKey(0), (Batch, Feature)) * 100
+
+    with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA, "feature": ResourceAxis.MODEL}):
+        hist, bins = levanter.tracker.histogram.sharded_histogram(a, bins=64)
+
+    jnp_hist, jnp_bins = jax.numpy.histogram(a.array, bins=64)
+
+    print(hist, jnp_hist)
+
+    assert jax.numpy.allclose(hist, jnp_hist)
+    assert jax.numpy.allclose(bins, jnp_bins)
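
Usage sketch (illustrative, not part of the patch): the snippet below shows one way the new sharded_histogram helper and Histogram.from_named_array can be called from user code, mirroring the tests above. The mesh setup, axis names ("batch", "embed"), and sizes are assumptions made up for the example; only sharded_histogram, Histogram.from_named_array, and the haliax/JAX calls shown come from the patch or public APIs.

    import jax
    from jax.random import PRNGKey
    from jax.sharding import Mesh

    import haliax as hax
    from haliax.partitioning import ResourceAxis

    from levanter.tracker.histogram import Histogram, sharded_histogram

    # Hypothetical named axes; any NamedArray works.
    Batch = hax.Axis("batch", 32)
    Embed = hax.Axis("embed", 64)
    grads = hax.random.normal(PRNGKey(0), (Batch, Embed))

    # A one-axis data-parallel mesh over all local devices (assumed setup).
    mesh = Mesh(jax.devices(), (ResourceAxis.DATA,))
    with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}):
        # Counts are computed per shard and summed, so the input stays
        # sharded instead of being gathered onto a single device.
        counts, edges = sharded_histogram(grads, bins=64)
        # Or build the full summary object that GradWatchCallback logs.
        hist = Histogram.from_named_array(grads, num_bins=64)

The shard_map-based implementation computes each shard's counts with a vectorized comparison against the bin edges rather than searchsorted, which is the TPU slowness the docstring calls out, and then sums the per-shard counts across shards.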