Skip to content

Commit

Permalink
Fix issue with EMA and normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
dedeswim committed Apr 26, 2022
1 parent d812be3 commit b30fe8f
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/setup_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
SoftTargetCrossEntropy)
from timm.models import convert_splitbn_model, create_model, safe_model_name
from timm.optim import optimizer_kwargs
from timm.utils.model_ema import ModelEmaV2
from timm.scheduler import create_scheduler
from torchvision import transforms

Expand Down Expand Up @@ -440,12 +441,12 @@ def update_state_with_norm_model(dev_env: DeviceEnv, train_state: TrainState,
mean=data_config["mean"],
std=data_config["std"]))
train_state = replace(train_state, model=dev_env.to_device(train_state.model))

if train_state.model_ema is not None:
train_state = replace(train_state,
model_ema=utils.normalize_model(train_state.model_ema,
mean=data_config["mean"],
std=data_config["std"]))
train_state = replace(train_state, model_ema=dev_env.to_device(train_state.model_ema))
assert isinstance(train_state.model_ema, ModelEmaV2)
new_model_ema = ModelEmaV2(train_state.model, decay=train_state.model_ema.decay)
train_state = replace(train_state, model_ema=dev_env.to_device(new_model_ema))

return train_state


Expand Down

0 comments on commit b30fe8f

Please sign in to comment.