From 57abb7fb8f6c273d869a05a0992807de728e21ea Mon Sep 17 00:00:00 2001
From: asistradition
Date: Wed, 10 Jul 2024 15:09:03 -0400
Subject: [PATCH] Clean up optimizer code

---
 .../models/biophysical_model.py | 31 +++++--------------
 .../tests/test_full_model.py    |  2 +-
 2 files changed, 9 insertions(+), 24 deletions(-)

diff --git a/supirfactor_dynamical/models/biophysical_model.py b/supirfactor_dynamical/models/biophysical_model.py
index 88506cd..d78e343 100644
--- a/supirfactor_dynamical/models/biophysical_model.py
+++ b/supirfactor_dynamical/models/biophysical_model.py
@@ -372,39 +372,24 @@ def forward_decay_model(
             return_decay_constants=True
         )
 
-    def train_model(
-        self,
-        training_dataloader,
-        epochs,
-        validation_dataloader=None,
-        loss_function=torch.nn.MSELoss(),
-        optimizer=None,
-        **kwargs
-    ):
+    def process_optimizer(self, optimizer, params=None):
 
-        # Create separate optimizers for the decay and transcription
-        # models and pass them as tuple
-        optimizer = (
+        # Create separate optimizers for the decay, for the
+        # transcription, and for the combined models
+
+        return (
             self._transcription_model.process_optimizer(
                 optimizer
             ),
             self._decay_model.process_optimizer(
                 optimizer
             ) if self.has_decay else False,
-            self.process_optimizer(
-                optimizer
+            super().process_optimizer(
+                optimizer,
+                params=params
             )
         )
 
-        return super().train_model(
-            training_dataloader,
-            epochs,
-            validation_dataloader,
-            loss_function,
-            optimizer,
-            **kwargs
-        )
-
     def _training_step(
         self,
         epoch_num,
diff --git a/supirfactor_dynamical/tests/test_full_model.py b/supirfactor_dynamical/tests/test_full_model.py
index 072ec1e..efbcae7 100644
--- a/supirfactor_dynamical/tests/test_full_model.py
+++ b/supirfactor_dynamical/tests/test_full_model.py
@@ -102,7 +102,7 @@ def setUp(self) -> None:
             _opt = None
 
         self.opt = (
-            self.dynamical_model.process_optimizer(None),
+            self.dynamical_model.process_optimizer(None)[2],
             self.dynamical_model._transcription_model.process_optimizer(None),
             _opt
         )
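
Note (not part of the patch): after this change, process_optimizer returns a
(transcription, decay, combined) tuple rather than a single optimizer, with the
decay slot set to False when the model has no decay submodel; that is why the
updated test now indexes [2] to get the combined-model optimizer. Below is a
minimal, hypothetical sketch of how such a tuple could be stepped during
training. The function name step_all, the variable names, and the loop
structure are assumptions for illustration only and do not appear in this
repository's training code.

    def step_all(optimizers, loss):
        # optimizers is the (transcription, decay, combined) tuple returned
        # by process_optimizer; the decay slot may be False.
        transcription_opt, decay_opt, combined_opt = optimizers

        # Keep only the optimizers that actually exist
        active = [
            opt
            for opt in (transcription_opt, decay_opt, combined_opt)
            if opt
        ]

        # Clear gradients, backpropagate once, then step each optimizer
        for opt in active:
            opt.zero_grad()

        loss.backward()

        for opt in active:
            opt.step()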