Skip to content

Commit

Permalink
update trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmeda14960 committed Sep 11, 2024
1 parent 43cb92d commit ca31208
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/levanter/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import typing
import warnings
from dataclasses import dataclass
import dataclasses
from functools import cached_property
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -336,7 +337,10 @@ def init_state_and_model(model_init, training_key):
return state

trainer_state_shape = eqx.filter_eval_shape(init_state_and_model, model_init, training_key)

saveable_train_state = saveable_training_mask(trainer_state_shape, is_trainable)
if self.config.reset_optimizer_state:
saveable_train_state = dataclasses.replace(saveable_train_state, optimizer=False)

state = load_checkpoint_or_initialize(
init_state_and_model,
Expand Down Expand Up @@ -481,6 +485,7 @@ def _jit_train_step_fn(self):
def _train_step(self, state: S, *batch, **batch_kwargs) -> tuple[Scalar, S]:
key, new_key = jax.random.split(state.training_key)
model = inference_mode(state.model, False)

loss, grads = self._compute_gradients_microbatched(self.loss_fn, model, *batch, **batch_kwargs, key=key)

# Sophia needs to be able to access the loss function in the optimizer
Expand Down Expand Up @@ -583,6 +588,8 @@ class TrainerConfig:
# whether or not to shut down the TPU at exit. If a float, shut down after that many seconds. True = 5 minutes
shutdown_at_exit: Union[bool, float] = False

reset_optimizer_state: bool = False

@property
def TrainBatch(self):
return Axis("batch", self.train_batch_size)
Expand Down

0 comments on commit ca31208

Please sign in to comment.