Skip to content

Commit

Permalink
compile two train steps, one with jitcallbacks and one without
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Nov 25, 2024
1 parent 0c26eb3 commit eaa977a
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/levanter/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,17 +394,17 @@ def train_step(self, state: S, *batch: X, **batch_kwargs) -> StepInfo[S]:
loss, new_state, cb_states = self._jit_train_step_fn(state, batch, batch_kwargs)
# force the loss so timing numbers are accurate. laziness isn't going to help here (i think?)
else:
loss, new_state, cb_states = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs)
loss, new_state = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs)
loss = loss.item() # type: ignore

info = StepInfo(new_state, loss, step_time())
info = StepInfo(new_state, loss, step_time())

with capture_time() as hook_time:
self.run_hooks(info)
if hooks_this_time:
self.hooks.run_jit_hooks_outside_step(info, cb_states)
with capture_time() as hook_time:
self.run_hooks(info)
if hooks_this_time:
self.hooks.run_jit_hooks_outside_step(info, cb_states)

levanter.tracker.log({"throughput/hook_time": hook_time()}, step=info.step)
levanter.tracker.log({"throughput/hook_time": hook_time()}, step=info.step)

return info

Expand Down

0 comments on commit eaa977a

Please sign in to comment.