Commit

fix typing
jennifgcrl committed Nov 13, 2024
1 parent 0503001 commit fe7a33d
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions src/levanter/trainer.py
@@ -65,7 +65,7 @@
 X = TypeVar("X")  # Input
 S = TypeVar("S", bound=TrainerState)

-DEFAULT_JAX_CONFIG = {
+DEFAULT_JAX_CONFIG: Dict[str, JsonAtom] = {
     "jax_threefry_partitionable": True,
     "jax_softmax_custom_jvp": True,
 }
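
A note on why the annotation matters (a minimal sketch; the JsonAtom alias below is an assumed stand-in, not levanter's actual definition): without it, a type checker infers dict[str, bool] for the literal, and since Dict is invariant in its value type, the inferred type is not interchangeable with Dict[str, JsonAtom].

    from typing import Dict, Union

    JsonAtom = Union[str, int, float, bool, None]  # assumed stand-in

    # Inferred as dict[str, bool], which is NOT a Dict[str, JsonAtom]:
    # Dict is invariant in its value type.
    inferred = {"jax_threefry_partitionable": True}

    # An explicit annotation fixes the declared type at the definition site.
    annotated: Dict[str, JsonAtom] = {"jax_threefry_partitionable": True}

    def update_jax_config(cfg: Dict[str, JsonAtom]) -> None:
        ...

    update_jax_config(inferred)   # mypy: incompatible type "dict[str, bool]"
    update_jax_config(annotated)  # OK

This also flows through to the jax_config field below: typeshed types copy.deepcopy as returning its argument's declared type, so the deepcopy default inherits Dict[str, JsonAtom] from this constant.
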
@@ -331,7 +331,12 @@ def init_state_and_model(model_init, training_key):
             model = model_init()
             # only force trainable params to param precision. Other params are cast to compute precision
             state = TrainerState.init(
-                self.optimizer, model, key=training_key, is_trainable=is_trainable, mp=self.mp, fp8=self.fp8
+                self.optimizer,
+                model,
+                key=training_key,
+                is_trainable=is_trainable,
+                mp=self.mp,
+                fp8=self.fp8,
             )
             return state

@@ -444,7 +449,10 @@ def eval_loss(model, *batch, **batch_kwargs):

         self.add_hook(
             callbacks.compute_validation_loss(
-                eval_loss, eval_loader, max_batches=self.config.max_eval_batches, name=name
+                eval_loss,
+                eval_loader,
+                max_batches=self.config.max_eval_batches,
+                name=name,
             ),
             every=self.config.steps_per_eval,
         )
@@ -497,7 +505,13 @@ def obj_fun(trainable_model):
     def _compute_gradients_microbatched(self, loss_fn, model: M, *batch, **batch_kwargs) -> tuple[Scalar, M]:
         grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False)
         mbs = self.config.microbatch_size
-        grad_fn = microbatched(grad_fn, self.TrainBatch, mbs, self.parameter_axis_mapping, self.compute_axis_mapping)
+        grad_fn = microbatched(
+            grad_fn,
+            self.TrainBatch,
+            mbs,
+            self.parameter_axis_mapping,
+            self.compute_axis_mapping,
+        )
         with hax.axis_mapping(self.compute_axis_mapping):
             return grad_fn(model, *batch, **batch_kwargs)
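
For context, gradient accumulation over microbatches is the technique behind the microbatched wrapper. The sketch below is a generic JAX version of the idea, not levanter's implementation (which also handles haliax named axes and sharding); the function name, the (params, batch) signature, and the leading-batch-axis layout are all hypothetical.

    import jax
    import jax.numpy as jnp

    def microbatched_value_and_grad(grad_fn, microbatch_size):
        """Average a value-and-grad function over microbatches of a batch."""

        def wrapped(params, batch):
            # Split the (batch, ...) array into (num, microbatch_size, ...).
            num = batch.shape[0] // microbatch_size
            micro = batch[: num * microbatch_size].reshape(
                num, microbatch_size, *batch.shape[1:]
            )

            def step(carry, mb):
                loss_acc, grad_acc = carry
                loss, grads = grad_fn(params, mb)  # e.g. jax.value_and_grad
                grad_acc = jax.tree_util.tree_map(jnp.add, grad_acc, grads)
                return (loss_acc + loss, grad_acc), None

            init = (jnp.zeros(()), jax.tree_util.tree_map(jnp.zeros_like, params))
            (loss_sum, grad_sum), _ = jax.lax.scan(step, init, micro)
            # Average so the result matches a single full-batch gradient.
            return loss_sum / num, jax.tree_util.tree_map(lambda g: g / num, grad_sum)

        return wrapped

Accumulating with lax.scan keeps peak activation memory proportional to a single microbatch, at the cost of a sequential loop over microbatches.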

@@ -569,7 +583,7 @@ class TrainerConfig:
     """can be a parent (to find latest) or a specific checkpoint. if None, will set to checkpointer.base_path."""
     initialize_from: Optional[str] = None  # Levanter trainer checkpoint to initialize from

-    jax_config: Dict[str, JsonAtom] = field(
+    jax_config: Mapping[str, JsonAtom] = field(
         default_factory=lambda: copy.deepcopy(DEFAULT_JAX_CONFIG)
     )  # config to pass to jax.config.update
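
The Dict to Mapping change is the usual variance fix (a minimal sketch, with JsonAtom again an assumed stand-in): Dict is invariant in its value type, so a Dict[str, bool] cannot be used where Dict[str, JsonAtom] is expected, while the read-only Mapping is covariant and accepts it.

    from typing import Dict, Mapping, Union

    JsonAtom = Union[str, int, float, bool, None]  # assumed stand-in

    def takes_dict(cfg: Dict[str, JsonAtom]) -> None: ...
    def takes_mapping(cfg: Mapping[str, JsonAtom]) -> None: ...

    bools: Dict[str, bool] = {"jax_threefry_partitionable": True}

    takes_dict(bools)     # mypy error: Dict is invariant in its value type
    takes_mapping(bools)  # OK: Mapping is covariant in its value type

Mapping also signals intent: the trainer only reads this config and never mutates it.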

@@ -597,7 +611,10 @@ def microbatch_size(self):

     def __post_init__(self):
         if self.wandb is not None:
-            warnings.warn("wandb is deprecated. use tracker with type wandb instead", DeprecationWarning)
+            warnings.warn(
+                "wandb is deprecated. use tracker with type wandb instead",
+                DeprecationWarning,
+            )
             self.tracker = self.wandb

     def initialize(self):
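
One caveat when exercising this path (a minimal standalone sketch, not levanter code): since Python 3.2, DeprecationWarning is ignored by default outside code running in __main__, so it must be enabled explicitly to see the message.

    import warnings

    # DeprecationWarning is hidden by default outside __main__;
    # enable it explicitly to surface the message above:
    warnings.simplefilter("always", DeprecationWarning)
    warnings.warn(
        "wandb is deprecated. use tracker with type wandb instead",
        DeprecationWarning,
    )

In tests, pytest.warns(DeprecationWarning) asserts the warning without touching the global filters.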
