
Commit

just about there!
dlwh committed Nov 25, 2024
1 parent 7837060 commit 0c26eb3
Showing 2 changed files with 34 additions and 14 deletions.
8 changes: 2 additions & 6 deletions src/levanter/tracker/histogram.py
@@ -64,7 +64,6 @@ def sharded_histogram(a: NamedArray, bins: int | ArrayLike = 10) -> tuple[jnp.nd
"""
edges = jnp.histogram_bin_edges(a.array, bins=bins)
return _shardmap_histogram(a, edges), edges
- # return jnp.histogram(a.array, bins=edges)


def _single_shard_histogram(a, bins, reduce_mesh):
@@ -78,10 +77,8 @@ def _single_shard_histogram(a, bins, reduce_mesh):
"""
a = a.flatten()

- bin_idx = jnp.searchsorted(bins, a, side="right", method="compare_all")
- bin_idx = jnp.where(a == bins[-1], len(bins) - 1, bin_idx)
- weights = jnp.ones_like(a)
- counts = jnp.zeros(len(bins), weights.dtype).at[bin_idx].add(weights)[1:]
+ bin_idx = (a[..., None] >= bins[:-1]).astype(jnp.int32) & (a[..., None] < bins[1:]).astype(jnp.int32)
+ counts = bin_idx.sum(axis=0, dtype=jnp.int32)

if len(reduce_mesh):
counts = jax.lax.psum(counts, axis_name=reduce_mesh)
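
Note for readers skimming the diff: the replacement lines above compute the per-shard histogram by comparing every element against the bin edges and summing the resulting mask, instead of `searchsorted` plus a scatter-add. A minimal, self-contained sketch of that mask-and-sum binning (the helper name and sample values are made up and are not part of the commit):

```python
import jax.numpy as jnp

def bucket_counts(a, edges):
    # count elements falling in each half-open bin [edges[i], edges[i+1])
    a = a.flatten()
    in_bin = (a[:, None] >= edges[:-1]) & (a[:, None] < edges[1:])
    # note: values exactly equal to edges[-1] land in no bin under this formulation
    return in_bin.sum(axis=0, dtype=jnp.int32)

a = jnp.array([0.1, 0.4, 0.4, 0.9])
edges = jnp.linspace(0.0, 1.0, 5)   # 4 bins: [0, .25), [.25, .5), [.5, .75), [.75, 1)
print(bucket_counts(a, edges))      # [1 2 0 1]
```

The O(n × bins) comparison does more arithmetic than `searchsorted`, but it avoids scatter-adds and is trivial to combine across shards with a `psum`, which is presumably the motivation for the change.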
@@ -102,7 +99,6 @@ def _shardmap_histogram(a: NamedArray, bins):
)
res = shard_h(a.array, bins)
return res
- # return res


def _flattened_spec(spec):
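
`_shardmap_histogram` itself is only partially visible above; it combines per-shard counts across the mesh. For orientation only, here is the general shard_map-plus-psum pattern with an invented mesh and axis name; it is not Levanter's implementation, which presumably derives its specs from the NamedArray's sharding rather than hard-coding them:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

def _per_shard_counts(x, edges):
    # same mask-and-sum idea as _single_shard_histogram, applied to this shard's slice of x
    x = x.flatten()
    in_bin = (x[:, None] >= edges[:-1]) & (x[:, None] < edges[1:])
    counts = in_bin.sum(axis=0, dtype=jnp.int32)
    # add up the partial counts from every shard along the (invented) "data" axis
    return jax.lax.psum(counts, axis_name="data")

mesh = Mesh(jax.devices(), axis_names=("data",))
histogram = shard_map(
    _per_shard_counts,
    mesh=mesh,
    in_specs=(P("data"), P()),   # x sharded over "data", edges replicated
    out_specs=P(),               # counts are identical on every device after the psum
)

x = jnp.arange(16.0)
edges = jnp.histogram_bin_edges(x, bins=4)
print(histogram(x, edges))       # [4 4 4 3] on one device; 15.0 == edges[-1] is dropped
```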
40 changes: 32 additions & 8 deletions src/levanter/trainer.py
@@ -384,16 +384,25 @@ def train_step(self, state: S, *batch: X, **batch_kwargs) -> StepInfo[S]:
"""
Performs a single training step.
"""
+ # jit hooks impose a nontrivial cost even when they're not run (since they defeat some compiler optimizations)
+ # so we avoid running them when they're not needed
+ # this results in two compiles, but the cost of the second compile is worth it
+ hooks_this_time = any(state.step % h.every == 0 for h in self.hooks.stateful_hooks)

with capture_time() as step_time:
- loss, new_state, cb_states = self._jit_train_step_fn(state, *batch, **batch_kwargs)
- # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?)
+ if hooks_this_time:
+     loss, new_state, cb_states = self._jit_train_step_fn(state, batch, batch_kwargs)
+     # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?)
+ else:
+     loss, new_state, cb_states = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs)
loss = loss.item() # type: ignore

info = StepInfo(new_state, loss, step_time())

with capture_time() as hook_time:
self.run_hooks(info)
- self.hooks.run_jit_hooks_outside_step(info, cb_states)
+ if hooks_this_time:
+     self.hooks.run_jit_hooks_outside_step(info, cb_states)

levanter.tracker.log({"throughput/hook_time": hook_time()}, step=info.step)
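
The comment at the top of this hunk explains the design: compiling a hook-free variant of the step costs a second compile but keeps the common path fast. A condensed sketch of that dispatch pattern, using plain `jax.jit` and invented names rather than the trainer's `named_jit` machinery:

```python
import functools
import jax
import jax.numpy as jnp

def _step(params, batch, *, with_hooks: bool):
    loss = jnp.mean((params - batch) ** 2)
    new_params = params - 0.1 * (params - batch)
    if with_hooks:
        # extra side outputs are only computed (and only traced) in this variant
        hook_info = {"grad_norm": jnp.sum(jnp.abs(params - batch))}
        return loss, new_params, hook_info
    return loss, new_params

# `with_hooks` is bound before tracing, so each wrapper compiles its own branch
step_with_hooks = jax.jit(functools.partial(_step, with_hooks=True))
step_without_hooks = jax.jit(functools.partial(_step, with_hooks=False))

params = jnp.zeros(4)
for step in range(4):
    batch = jnp.full(4, float(step))
    if step % 2 == 0:   # stand-in for `state.step % h.every == 0`
        loss, params, hook_info = step_with_hooks(params, batch)
    else:
        loss, params = step_without_hooks(params, batch)
```

The host-side `if` keeps both compiled functions cached, so the hook bookkeeping never enters the fast path, matching the rationale in the added comment.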

@@ -505,25 +514,40 @@ def _jit_train_step_fn(self):
donate_args=(True,),
)

- def _train_step(self, state: S, *batch, **batch_kwargs) -> tuple[Scalar, S, Sequence[CBInfo]]:
+ @cached_property
+ def _jit_train_step_fn_no_hook(self):
+     return named_jit(
+         functools.partial(self._train_step, _no_hooks=True),
+         axis_resources=self.parameter_axis_mapping,
+         out_axis_resources=self.parameter_axis_mapping,
+         donate_args=(True,),
+     )
+
+ def _train_step(
+     self, state: S, batch, batch_kwargs, _no_hooks=False
+ ) -> tuple[Scalar, S, Sequence[CBInfo]] | tuple[Scalar, S]:
key, new_key = jax.random.split(state.training_key)
model = inference_mode(state.model, False)

loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, *batch, **batch_kwargs, key=key)

+ with hax.axis_mapping(self.parameter_axis_mapping):
+     if not _no_hooks:
+         hook_infos = self.hooks.run_jit_hooks(state, grads, force=False)

# Sophia needs to be able to access the loss function in the optimizer
def obj_fun(trainable_model):
model = eqx.combine(trainable_model, state.model)
with hax.axis_mapping(self.compute_axis_mapping):
model = self.mp.cast_to_compute(model)
return self._raw_loss_function(model, *batch, **batch_kwargs, key=key).scalar()

- with hax.axis_mapping(self.parameter_axis_mapping):
-     hook_infos = self.hooks.run_jit_hooks(state, grads, force=False)

new_state = state.take_step(grads, obj_fun=obj_fun)
new_state = hax.shard(new_state, self.parameter_axis_mapping)
- return loss, new_state, hook_infos
+ if _no_hooks:
+     return loss, new_state
+ else:
+     return loss, new_state, hook_infos

def _compute_gradients_microbatched(self, loss_fn, model: M, *batch, **batch_kwargs) -> tuple[Scalar, M]:
grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False)
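
One piece of unchanged context worth calling out: `obj_fun` closes over the batch so that the optimizer itself can re-evaluate the loss, per the "Sophia needs to be able to access the loss function in the optimizer" comment. A toy illustration of that pattern, with a made-up backtracking update rather than Sophia:

```python
import jax
import jax.numpy as jnp

def take_step(params, grads, lr, obj_fun):
    # the optimizer calls the objective itself, which is why _train_step passes obj_fun through
    proposal = params - lr * grads
    improved = obj_fun(proposal) <= obj_fun(params)
    # keep the full step if it helped, otherwise fall back to a half step (illustrative only)
    return jnp.where(improved, proposal, params - 0.5 * lr * grads)

def loss_fn(params, batch):
    return jnp.mean((params - batch) ** 2)

params = jnp.zeros(3)
batch = jnp.array([1.0, 2.0, 3.0])
grads = jax.grad(loss_fn)(params, batch)
obj_fun = lambda p: loss_fn(p, batch)   # batch baked in, like the closure built in _train_step
params = take_step(params, grads, lr=0.5, obj_fun=obj_fun)
print(params)
```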
