Add some loss weighting and provisioning for early stopping
asistradition committed Jul 11, 2024
1 parent 57abb7f commit df64349
Showing 6 changed files with 31 additions and 25 deletions.
supirfactor_dynamical/_io/_args.py (1 change: 1 addition & 0 deletions)
@@ -5,6 +5,7 @@
     'time_dependent_decay',
     'decay_k',
     'decay_epoch_delay',
+    'decay_loss_weight',
     'n_genes',
     'hidden_layer_width',
     'n_peaks',
supirfactor_dynamical/models/biophysical_model.py (31 changes: 10 additions & 21 deletions)
@@ -41,6 +41,7 @@ class SupirFactorBiophysical(
     separately_optimize_decay_model = False
     decay_epoch_delay = 0
     decay_k = 20
+    decay_loss_weight = None

     @property
     def has_decay(self):
@@ -52,6 +53,7 @@ def __init__(
         decay_model=None,
         decay_epoch_delay=None,
         decay_k=20,
+        decay_loss_weight=None,
         separately_optimize_decay_model=False,
         use_prior_weights=False,
         input_dropout_rate=0.5,
@@ -104,6 +106,7 @@ def __init__(

         self.decay_epoch_delay = decay_epoch_delay
         self.separately_optimize_decay_model = separately_optimize_decay_model
+        self.decay_loss_weight = decay_loss_weight

         if decay_model is False:
             output_activation = None
@@ -433,32 +436,17 @@ def _training_step(

         decay_loss = 0.

-        loss = self._training_step_joint(
-            epoch_num,
-            train_x,
-            optimizer[2] if self._decay_optimize(epoch_num) else
-            optimizer[0],
-            loss_function
-        )
-
-        return loss, decay_loss
-
-    def _training_step_joint(
-        self,
-        epoch_num,
-        train_x,
-        optimizer,
-        loss_function
-    ):
-
-        return super()._training_step(
+        loss = super()._training_step(
             epoch_num,
             train_x,
-            optimizer,
+            optimizer[2] if self._decay_optimize(epoch_num) else
+            optimizer[0],
             loss_function,
             input_x=self._slice_data_and_forward(train_x)
         )

+        return loss, decay_loss
+
     def _training_step_decay(
         self,
         epoch_num,
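The joint step now picks its optimizer inline: optimizer[2] when self._decay_optimize(epoch_num) is truthy, otherwise optimizer[0]. The _decay_optimize method itself is not part of this diff; purely as an assumption-labelled sketch of what an epoch-delay gate of this kind might look like (its real logic may differ):

    # Illustrative sketch only; the actual _decay_optimize implementation
    # is not shown in this commit. It plausibly gates decay optimization
    # on decay_epoch_delay, e.g.:
    def _decay_optimize(self, epoch_num):

        # No decay model means there is nothing to optimize
        if not self.has_decay:
            return False

        # Hold off on the decay parameters until the delay has elapsed
        if self.decay_epoch_delay is not None:
            return epoch_num >= self.decay_epoch_delay

        return True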
@@ -494,7 +482,8 @@ def _training_step_decay(
             optimizer,
             loss_function,
             input_x=neg,
-            target_x=_compare_x
+            target_x=_compare_x,
+            loss_weight=self.decay_loss_weight
         )

     def _calculate_all_losses(
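_training_step_decay now forwards loss_weight=self.decay_loss_weight into the parent class's _training_step. The base implementation is not shown in this commit, so the following is only a hypothetical illustration of what a loss weight typically does (scale one term of the objective before backpropagation), not the package's actual code; the function name and signature are invented for the example:

    # Hypothetical illustration of loss weighting; not part of
    # supirfactor_dynamical.
    import torch

    def weighted_training_step(model, optimizer, input_x, target_x, loss_weight=None):

        optimizer.zero_grad()

        # Unweighted loss for this sub-objective
        loss = torch.nn.functional.mse_loss(model(input_x), target_x)

        # A loss_weight of None is treated as no scaling
        if loss_weight is not None:
            loss = loss * loss_weight

        loss.backward()
        optimizer.step()

        return loss.item()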
supirfactor_dynamical/training/train_decoders.py (6 changes: 5 additions & 1 deletion)
@@ -191,6 +191,10 @@ def train_decoder_submodels(
         pbar.set_description(f"[{epoch_num} n={np.sum(_batch_n)}]")

         if post_epoch_hook is not None:
-            post_epoch_hook(model_ref)
+
+            # If the hook returns True
+            # it's early stopping time
+            if post_epoch_hook(model_ref) is True:
+                break

     return model
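The same early-stopping provision appears in each of the four training loops in this commit: post_epoch_hook is still called once per epoch, but a return value of True now breaks out of the epoch loop. The hook contract shown in the diff is simply hook(model_ref) -> bool, so a caller could build a patience-based hook like the sketch below; validation_fn here is an assumed user-supplied callable and is not part of this package:

    # Sketch of a patience-based early-stopping hook. validation_fn is an
    # assumed user-supplied callable that scores model_ref on held-out data.
    def make_early_stopping_hook(validation_fn, patience=5):

        state = {"best": float("inf"), "bad_epochs": 0}

        def hook(model_ref):

            val_loss = validation_fn(model_ref)

            if val_loss < state["best"]:
                state["best"] = val_loss
                state["bad_epochs"] = 0
            else:
                state["bad_epochs"] += 1

            # Returning True asks the training loop to stop early
            return state["bad_epochs"] >= patience

        return hook

The resulting hook can then be passed as post_epoch_hook to any of the loops changed here.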
supirfactor_dynamical/training/train_simple_decoders.py (6 changes: 5 additions & 1 deletion)
@@ -205,6 +205,10 @@ def train_simple_multidecoder(
         pbar.set_description(f"[{epoch_num} n={np.sum(_batch_n)}]")

         if post_epoch_hook is not None:
-            post_epoch_hook(model_ref)
+
+            # If the hook returns True
+            # it's early stopping time
+            if post_epoch_hook(model_ref) is True:
+                break

     return model
supirfactor_dynamical/training/train_simple_models.py (6 changes: 5 additions & 1 deletion)
@@ -139,6 +139,10 @@ def train_simple_model(
         pbar.set_description(f"[{epoch_num} n={np.sum(_batch_n)}]")

         if post_epoch_hook is not None:
-            post_epoch_hook(model_ref)
+
+            # If the hook returns True
+            # it's early stopping time
+            if post_epoch_hook(model_ref) is True:
+                break

     return model
supirfactor_dynamical/training/train_standard_loop.py (6 changes: 5 additions & 1 deletion)
@@ -132,7 +132,11 @@ def train_model(
         pbar.set_description(f"[{epoch_num} n={np.sum(_batch_n)}]")

         if post_epoch_hook is not None:
-            post_epoch_hook(model_ref)
+
+            # If the hook returns True
+            # it's early stopping time
+            if post_epoch_hook(model_ref) is True:
+                break

     model_ref.eval()

