Skip to content

Commit

Permalink
new config
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmeda14960 committed Sep 11, 2024
1 parent 2f02a64 commit 4b97bba
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/levanter/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import typing
import warnings
from dataclasses import dataclass
import dataclasses
from functools import cached_property
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -337,7 +338,7 @@ def init_state_and_model(model_init, training_key):

trainer_state_shape = eqx.filter_eval_shape(init_state_and_model, model_init, training_key)
if self.config.reset_optimizer_state:
saveable_train_state = dataclass.replace(saveable_train_state, optimizer=False)
saveable_train_state = dataclasses.replace(saveable_train_state, optimizer=False)
saveable_train_state = saveable_training_mask(trainer_state_shape, is_trainable)

state = load_checkpoint_or_initialize(
Expand Down

0 comments on commit 4b97bba

Please sign in to comment.