Skip to content

Commit

Permalink
Change training to collect gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Jul 16, 2024
1 parent 3cfbdc7 commit b6e3e2c
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions supirfactor_dynamical/models/biophysical_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,10 +422,7 @@ def _training_step(
:rtype: float, float, float
"""

if (
self.separately_optimize_decay_model and
self._decay_optimize(epoch_num)
):
if self.separately_optimize_decay_model:

# Call the training step and compare the negative velocity
# to the decay output data
Expand All @@ -450,14 +447,24 @@ def _training_step(

decay_loss = 0.

# Call the main training step and compare velocity to the
# expected output velocity
loss = super()._training_step(
epoch_num,
train_x,
optimizer[2] if self._decay_optimize(epoch_num) else
optimizer[0],
loss_function
None,
loss_function,
optimizer_step=False
)

if self._decay_optimize(epoch_num):
_optimizer = optimizer[2]
else:
_optimizer = optimizer[0]

_optimizer.step()
_optimizer.zero_grad()

return loss, decay_loss

def _calculate_all_losses(
Expand Down

0 comments on commit b6e3e2c

Please sign in to comment.