diff --git a/debug_jit_log.py b/debug_jit_log.py deleted file mode 100644 index 4e2a11633..000000000 --- a/debug_jit_log.py +++ /dev/null @@ -1,26 +0,0 @@ -import jax - - -def log(metrics, *, step): - """ - Log metrics to the global tracker. - - Args: - metrics: Metrics to log. We use LoggableValues just to give you a sense of what you can log. Backends may - support additional types. - step: Step to log at - commit: Whether to commit the metrics. If None, uses the default for the tracker. - """ - print(metrics, step) - - -def _do_jit_log(metrics, *, step=None): - try: - log(metrics, step=step) - except Exception as e: - raise e - - -def jit_log(metrics, *, step=None): - """uses jax effect callback to log to wandb from the host""" - jax.debug.callback(_do_jit_log, metrics, step=step) diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 83e8b8779..fec2eb02e 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -169,9 +169,6 @@ def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, n if n > 0: total_loss /= n - # logger.info(f"eval loading time: {total_load_time / n:.3f} s/ba") - # logger.info(f"eval loss time: {total_loss_time / n:.3f} s/ba") - return total_loss @@ -523,19 +520,14 @@ def _rec_log_magnitudes(norms, hists, path_prefix, tree): strict=True, ): if self.split_scan_layers and isinstance(g, haliax.nn.Stacked): - # unstacked = g.unstacked() - # for i, layer in enumerate(unstacked): - # _rec_log_magnitudes(to_log, join_key(key_path, str(i)), layer) - # vmap over the layers - Block = g.Block - vmapped_norms, vmapped_hists = haliax.vmap(_rec_log_magnitudes, Block)({}, {}, "", g.stacked) - # manual loop + vmapped_norms, vmapped_hists = haliax.vmap(_rec_log_magnitudes, g.Block)({}, {}, "", g.stacked) + for k, v in vmapped_norms.items(): - for i in range(Block.size): + for i in range(g.Block.size): norms[f"{key_path}/{i}/{k}"] = v[i] for k, v in vmapped_hists.items(): - for i in range(Block.size): + for i in range(g.Block.size): hists[f"{key_path}/{i}/{k}"] = jax.tree.map( lambda x: x[i] if is_jax_array_like(x) else x, v ) diff --git a/tests/test_histogram.py b/tests/test_histogram.py index a219427d0..b2e008724 100644 --- a/tests/test_histogram.py +++ b/tests/test_histogram.py @@ -16,9 +16,9 @@ def test_sharded_histogram_simple(): Batch = hax.Axis("Batch", 64) Feature = hax.Axis("Feature", 128) - a = hax.random.normal(PRNGKey(0), (Batch, Feature)) - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA}): + a = hax.random.normal(PRNGKey(0), (Batch, Feature)) + a = hax.shard(a) hist = levanter.tracker.histogram.sharded_histogram(a, bins=10) hist_normal = jax.numpy.histogram(a.array, bins=10)[0] @@ -33,14 +33,12 @@ def test_sharded_histogram_tp(): Batch = hax.Axis("Batch", 64) Feature = hax.Axis("Feature", 128) - a = hax.random.normal(PRNGKey(0), (Batch, Feature)) * 100 - with mesh, hax.axis_mapping({"batch": ResourceAxis.DATA, "feature": ResourceAxis.MODEL}): + a = hax.random.normal(PRNGKey(0), (Batch, Feature)) * 100 + a = hax.shard(a) hist, bins = levanter.tracker.histogram.sharded_histogram(a, bins=64) jnp_hist, jnp_bins = jax.numpy.histogram(a.array, bins=64) - print(hist, jnp_hist) - assert jax.numpy.allclose(hist, jnp_hist) assert jax.numpy.allclose(bins, jnp_bins)