diff --git a/brainpy/_src/optimizers/optimizer.py b/brainpy/_src/optimizers/optimizer.py index c2aec25a..75dfef12 100644 --- a/brainpy/_src/optimizers/optimizer.py +++ b/brainpy/_src/optimizers/optimizer.py @@ -901,6 +901,7 @@ def __init__( amsgrad: bool = False, name: Optional[str] = None, ): + self.amsgrad = amsgrad super(AdamW, self).__init__(lr=lr, train_vars=train_vars, weight_decay=weight_decay, @@ -919,7 +920,6 @@ def __init__( self.beta2 = beta2 self.eps = eps self.weight_decay = weight_decay - self.amsgrad = amsgrad def __repr__(self): return (f"{self.__class__.__name__}(lr={self.lr}, "