Commit

Clean up optimizer code
asistradition committed Jul 10, 2024
1 parent ade582b commit 57abb7f
Showing 2 changed files with 9 additions and 24 deletions.
31 changes: 8 additions & 23 deletions supirfactor_dynamical/models/biophysical_model.py
@@ -372,39 +372,24 @@ def forward_decay_model(
             return_decay_constants=True
         )
 
-    def train_model(
-        self,
-        training_dataloader,
-        epochs,
-        validation_dataloader=None,
-        loss_function=torch.nn.MSELoss(),
-        optimizer=None,
-        **kwargs
-    ):
+    def process_optimizer(self, optimizer, params=None):
 
-        # Create separate optimizers for the decay and transcription
-        # models and pass them as tuple
-        optimizer = (
+        # Create separate optimizers for the decay, for the
+        # transcription, and for the combined models
+
+        return (
             self._transcription_model.process_optimizer(
                 optimizer
             ),
             self._decay_model.process_optimizer(
                 optimizer
             ) if self.has_decay else False,
-            self.process_optimizer(
-                optimizer
+            super().process_optimizer(
+                optimizer,
+                params=params
             )
         )
-
-        return super().train_model(
-            training_dataloader,
-            epochs,
-            validation_dataloader,
-            loss_function,
-            optimizer,
-            **kwargs
-        )
 
     def _training_step(
         self,
         epoch_num,
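
With this change, process_optimizer on the biophysical model returns a 3-tuple: an optimizer for the transcription submodel, one for the decay submodel (or False when no decay model is present), and one for the combined model built by super().process_optimizer. The sketch below illustrates the general pattern with a hypothetical stand-in for the base-class method and toy torch.nn.Linear submodels; it is not the package's actual implementation.

    import torch


    def process_optimizer(module, optimizer, params=None):
        # Hypothetical base-class behavior: pass an existing optimizer
        # through unchanged, otherwise build a default Adam optimizer
        # over the requested parameters
        if params is None:
            params = module.parameters()

        if isinstance(optimizer, torch.optim.Optimizer):
            return optimizer

        return torch.optim.Adam(params)


    # Toy submodels standing in for the transcription and decay models
    transcription_model = torch.nn.Linear(10, 10)
    decay_model = torch.nn.Linear(10, 10)
    combined_model = torch.nn.Sequential(transcription_model, decay_model)

    # Applied once per submodel, as the refactored method does, this
    # yields a (transcription, decay, combined) optimizer tuple
    optimizers = (
        process_optimizer(transcription_model, None),
        process_optimizer(decay_model, None),
        process_optimizer(combined_model, None)
    )

Keeping separate optimizers lets training code step the decay and transcription submodels independently.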
2 changes: 1 addition & 1 deletion supirfactor_dynamical/tests/test_full_model.py
@@ -102,7 +102,7 @@ def setUp(self) -> None:
         _opt = None
 
         self.opt = (
-            self.dynamical_model.process_optimizer(None),
+            self.dynamical_model.process_optimizer(None)[2],
             self.dynamical_model._transcription_model.process_optimizer(None),
             _opt
         )
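
The test now indexes element [2] because the refactored process_optimizer returns the (transcription, decay-or-False, combined) tuple described above, and the combined-model optimizer is its last element. A minimal usage sketch, assuming dynamical_model is an already-constructed biophysical model (construction omitted):

    # The three optimizers can be unpacked by position
    transcription_opt, decay_opt, combined_opt = dynamical_model.process_optimizer(None)

    # Equivalent to the test's indexing of the combined-model optimizer
    combined_opt = dynamical_model.process_optimizer(None)[2]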
