diff --git a/supirfactor_dynamical/models/biophysical_model.py b/supirfactor_dynamical/models/biophysical_model.py index f246d28..a92435e 100644 --- a/supirfactor_dynamical/models/biophysical_model.py +++ b/supirfactor_dynamical/models/biophysical_model.py @@ -422,10 +422,7 @@ def _training_step( :rtype: float, float, float """ - if ( - self.separately_optimize_decay_model and - self._decay_optimize(epoch_num) - ): + if self.separately_optimize_decay_model: # Call the training step and compare the negative velocity # to the decay output data @@ -450,14 +447,24 @@ def _training_step( decay_loss = 0. + # Call the main training step and compare velocity to the + # expected output velocity loss = super()._training_step( epoch_num, train_x, - optimizer[2] if self._decay_optimize(epoch_num) else - optimizer[0], - loss_function + None, + loss_function, + optimizer_step=False ) + if self._decay_optimize(epoch_num): + _optimizer = optimizer[2] + else: + _optimizer = optimizer[0] + + _optimizer.step() + _optimizer.zero_grad() + return loss, decay_loss def _calculate_all_losses(