From eaa977a2ea18ed0b940606ddc714b698fd3c985f Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 25 Nov 2024 10:07:15 -0800 Subject: [PATCH] compile two train steps, one with jitcallbacks and one without --- src/levanter/trainer.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 72fe67747..82f32422a 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -394,17 +394,17 @@ def train_step(self, state: S, *batch: X, **batch_kwargs) -> StepInfo[S]: loss, new_state, cb_states = self._jit_train_step_fn(state, batch, batch_kwargs) # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?) else: - loss, new_state, cb_states = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs) + loss, new_state = self._jit_train_step_fn_no_hook(state, batch, batch_kwargs) loss = loss.item() # type: ignore - info = StepInfo(new_state, loss, step_time()) + info = StepInfo(new_state, loss, step_time()) - with capture_time() as hook_time: - self.run_hooks(info) - if hooks_this_time: - self.hooks.run_jit_hooks_outside_step(info, cb_states) + with capture_time() as hook_time: + self.run_hooks(info) + if hooks_this_time: + self.hooks.run_jit_hooks_outside_step(info, cb_states) - levanter.tracker.log({"throughput/hook_time": hook_time()}, step=info.step) + levanter.tracker.log({"throughput/hook_time": hook_time()}, step=info.step) return info