diff --git a/supirfactor_dynamical/_io/_args.py b/supirfactor_dynamical/_io/_args.py index 8548b44..6c9795a 100644 --- a/supirfactor_dynamical/_io/_args.py +++ b/supirfactor_dynamical/_io/_args.py @@ -5,6 +5,7 @@ 'time_dependent_decay', 'decay_k', 'decay_epoch_delay', + 'decay_loss_weight', 'n_genes', 'hidden_layer_width', 'n_peaks', diff --git a/supirfactor_dynamical/models/biophysical_model.py b/supirfactor_dynamical/models/biophysical_model.py index d78e343..88b473e 100644 --- a/supirfactor_dynamical/models/biophysical_model.py +++ b/supirfactor_dynamical/models/biophysical_model.py @@ -41,6 +41,7 @@ class SupirFactorBiophysical( separately_optimize_decay_model = False decay_epoch_delay = 0 decay_k = 20 + decay_loss_weight = None @property def has_decay(self): @@ -52,6 +53,7 @@ def __init__( decay_model=None, decay_epoch_delay=None, decay_k=20, + decay_loss_weight=None, separately_optimize_decay_model=False, use_prior_weights=False, input_dropout_rate=0.5, @@ -104,6 +106,7 @@ def __init__( self.decay_epoch_delay = decay_epoch_delay self.separately_optimize_decay_model = separately_optimize_decay_model + self.decay_loss_weight = decay_loss_weight if decay_model is False: output_activation = None @@ -433,32 +436,17 @@ def _training_step( decay_loss = 0. - loss = self._training_step_joint( - epoch_num, - train_x, - optimizer[2] if self._decay_optimize(epoch_num) else - optimizer[0], - loss_function - ) - - return loss, decay_loss - - def _training_step_joint( - self, - epoch_num, - train_x, - optimizer, - loss_function - ): - - return super()._training_step( + loss = super()._training_step( epoch_num, train_x, - optimizer, + optimizer[2] if self._decay_optimize(epoch_num) else + optimizer[0], loss_function, input_x=self._slice_data_and_forward(train_x) ) + return loss, decay_loss + def _training_step_decay( self, epoch_num, @@ -494,7 +482,8 @@ def _training_step_decay( optimizer, loss_function, input_x=neg, - target_x=_compare_x + target_x=_compare_x, + loss_weight=self.decay_loss_weight ) def _calculate_all_losses( diff --git a/supirfactor_dynamical/training/train_decoders.py b/supirfactor_dynamical/training/train_decoders.py index 73cc8ed..a32d6dc 100644 --- a/supirfactor_dynamical/training/train_decoders.py +++ b/supirfactor_dynamical/training/train_decoders.py @@ -191,6 +191,10 @@ def train_decoder_submodels( pbar.set_description(f"[{epoch_num} n={np.sum(_batch_n)}]") if post_epoch_hook is not None: - post_epoch_hook(model_ref) + + # If the hook returns True + # it's early stopping time + if post_epoch_hook(model_ref) is True: + break return model diff --git a/supirfactor_dynamical/training/train_simple_decoders.py b/supirfactor_dynamical/training/train_simple_decoders.py index 5780528..8c56ebc 100644 --- a/supirfactor_dynamical/training/train_simple_decoders.py +++ b/supirfactor_dynamical/training/train_simple_decoders.py @@ -205,6 +205,10 @@ def train_simple_multidecoder( pbar.set_description(f"[{epoch_num} n={np.sum(_batch_n)}]") if post_epoch_hook is not None: - post_epoch_hook(model_ref) + + # If the hook returns True + # it's early stopping time + if post_epoch_hook(model_ref) is True: + break return model diff --git a/supirfactor_dynamical/training/train_simple_models.py b/supirfactor_dynamical/training/train_simple_models.py index 8a349fb..113cef5 100644 --- a/supirfactor_dynamical/training/train_simple_models.py +++ b/supirfactor_dynamical/training/train_simple_models.py @@ -139,6 +139,10 @@ def train_simple_model( pbar.set_description(f"[{epoch_num} n={np.sum(_batch_n)}]") if post_epoch_hook is not None: - post_epoch_hook(model_ref) + + # If the hook returns True + # it's early stopping time + if post_epoch_hook(model_ref) is True: + break return model diff --git a/supirfactor_dynamical/training/train_standard_loop.py b/supirfactor_dynamical/training/train_standard_loop.py index 0ee3052..c39814a 100644 --- a/supirfactor_dynamical/training/train_standard_loop.py +++ b/supirfactor_dynamical/training/train_standard_loop.py @@ -132,7 +132,11 @@ def train_model( pbar.set_description(f"[{epoch_num} n={np.sum(_batch_n)}]") if post_epoch_hook is not None: - post_epoch_hook(model_ref) + + # If the hook returns True + # it's early stopping time + if post_epoch_hook(model_ref) is True: + break model_ref.eval()