From 8eb5b608ea6be0eb500b2846ef1e7a6235ade898 Mon Sep 17 00:00:00 2001 From: Rafia Omer Date: Tue, 3 Oct 2023 12:08:25 -0500 Subject: [PATCH 01/10] added check for N=0 --- projects/sandbox/train/train/augmentor.py | 32 ++++++++++++----------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/projects/sandbox/train/train/augmentor.py b/projects/sandbox/train/train/augmentor.py index edc39b886..e5d18bb17 100644 --- a/projects/sandbox/train/train/augmentor.py +++ b/projects/sandbox/train/train/augmentor.py @@ -211,28 +211,30 @@ def forward(self, X): # sample waveforms and use them to compute # interferometer responses N = mask.sum().item() - responses = self.sample_responses(N, X.shape[-1], psds[mask]) - responses.to(X.device) - # perform swapping and muting augmentations - # on those responses, and then inject them - responses, swap_indices = self.swapper(responses) - responses, mute_indices = self.muter(responses) - X[mask] += responses + if(N != 0): + responses = self.sample_responses(N, X.shape[-1], psds[mask]) + responses.to(X.device) + + # perform swapping and muting augmentations + # on those responses, and then inject them + responses, swap_indices = self.swapper(responses) + responses, mute_indices = self.muter(responses) + X[mask] += responses + + # set response augmentation labels to noise + idx = torch.where(mask)[0] + mask[idx[mute_indices]] = 0 + mask[idx[swap_indices]] = 0 + + # set labels to positive for injected signals + y[mask] = -y[mask] + 1 # now that injections have been made, # whiten _all_ the strain using the # background psds computed up top X = self.whitener(X, psds) - # set response augmentation labels to noise - idx = torch.where(mask)[0] - mask[idx[mute_indices]] = 0 - mask[idx[swap_indices]] = 0 - - # set labels to positive for injected signals - y[mask] = -y[mask] + 1 - # curriculum learning step if self.snr is not None: self.snr.step() From a70b4268e47420755f9a9ee66e906b0f5ca4a5b5 Mon Sep 17 00:00:00 2001 From: Rafia Omer Date: Tue, 3 Oct 2023 12:41:17 -0500 Subject: [PATCH 02/10] added N>0 --- projects/sandbox/train/train/augmentor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/sandbox/train/train/augmentor.py b/projects/sandbox/train/train/augmentor.py index e5d18bb17..4aa31f739 100644 --- a/projects/sandbox/train/train/augmentor.py +++ b/projects/sandbox/train/train/augmentor.py @@ -212,7 +212,7 @@ def forward(self, X): # interferometer responses N = mask.sum().item() - if(N != 0): + if(N > 0): responses = self.sample_responses(N, X.shape[-1], psds[mask]) responses.to(X.device) From 068af79ed570006afed7363a7204dbf45d359676 Mon Sep 17 00:00:00 2001 From: Rafia Omer Date: Thu, 5 Oct 2023 13:13:16 -0500 Subject: [PATCH 03/10] add check for psd length >= window length --- projects/sandbox/train/train/train.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/projects/sandbox/train/train/train.py b/projects/sandbox/train/train/train.py index d43cdf0d8..fb0496e8a 100644 --- a/projects/sandbox/train/train/train.py +++ b/projects/sandbox/train/train/train.py @@ -228,6 +228,14 @@ def main( sample_length = window_length + psd_length fftlength = fftlength or window_length + + if(psd_length < window_length): + raise ValueError( + "Can't have psd length {} longer than window length {}".format( + psd_length, window_length + ) + ) + # create objects that we'll use for whitening the data fast = highpass is not None psd_estimator = PsdEstimator( From 708f137a7fd0fe75560789f7625e22619fa4f45d Mon Sep 17 00:00:00 2001 From: Rafia Omer Date: Thu, 5 Oct 2023 13:27:21 -0500 Subject: [PATCH 04/10] pre-commit checks were failing --- projects/sandbox/train/train/augmentor.py | 2 +- projects/sandbox/train/train/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/projects/sandbox/train/train/augmentor.py b/projects/sandbox/train/train/augmentor.py index 4aa31f739..5d8a4b09e 100644 --- a/projects/sandbox/train/train/augmentor.py +++ b/projects/sandbox/train/train/augmentor.py @@ -212,7 +212,7 @@ def forward(self, X): # interferometer responses N = mask.sum().item() - if(N > 0): + if N > 0 : responses = self.sample_responses(N, X.shape[-1], psds[mask]) responses.to(X.device) diff --git a/projects/sandbox/train/train/train.py b/projects/sandbox/train/train/train.py index fb0496e8a..590fc5deb 100644 --- a/projects/sandbox/train/train/train.py +++ b/projects/sandbox/train/train/train.py @@ -229,7 +229,7 @@ def main( fftlength = fftlength or window_length - if(psd_length < window_length): + if psd_length < window_length : raise ValueError( "Can't have psd length {} longer than window length {}".format( psd_length, window_length From 433a6acafa9281b1a9aec836f1c173662ce099f1 Mon Sep 17 00:00:00 2001 From: Rafia Omer Date: Thu, 5 Oct 2023 13:46:18 -0500 Subject: [PATCH 05/10] pre-commit checks were failing --- projects/sandbox/train/train/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/projects/sandbox/train/train/train.py b/projects/sandbox/train/train/train.py index 590fc5deb..86c7200e2 100644 --- a/projects/sandbox/train/train/train.py +++ b/projects/sandbox/train/train/train.py @@ -229,12 +229,12 @@ def main( fftlength = fftlength or window_length - if psd_length < window_length : + if psd_length < window_length: raise ValueError( - "Can't have psd length {} longer than window length {}".format( + "Can't have psd length {} longer than window length {}".format( psd_length, window_length - ) ) + ) # create objects that we'll use for whitening the data fast = highpass is not None From a92ae7c71e6d768ece686f9aa2f85fa88af17d15 Mon Sep 17 00:00:00 2001 From: Rafia Omer Date: Thu, 5 Oct 2023 13:53:55 -0500 Subject: [PATCH 06/10] pre-commit checks were failing --- projects/sandbox/train/train/augmentor.py | 2 +- projects/sandbox/train/train/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/projects/sandbox/train/train/augmentor.py b/projects/sandbox/train/train/augmentor.py index 5d8a4b09e..3f60aa9a9 100644 --- a/projects/sandbox/train/train/augmentor.py +++ b/projects/sandbox/train/train/augmentor.py @@ -212,7 +212,7 @@ def forward(self, X): # interferometer responses N = mask.sum().item() - if N > 0 : + if N > 0: responses = self.sample_responses(N, X.shape[-1], psds[mask]) responses.to(X.device) diff --git a/projects/sandbox/train/train/train.py b/projects/sandbox/train/train/train.py index 86c7200e2..a54cd58a6 100644 --- a/projects/sandbox/train/train/train.py +++ b/projects/sandbox/train/train/train.py @@ -232,7 +232,7 @@ def main( if psd_length < window_length: raise ValueError( "Can't have psd length {} longer than window length {}".format( - psd_length, window_length + psd_length, window_length ) ) From da265662758828bf4cafafa1291eeceb019a5666 Mon Sep 17 00:00:00 2001 From: Rafia Omer Date: Thu, 5 Oct 2023 14:03:06 -0500 Subject: [PATCH 07/10] pre-commit checks were failing --- projects/sandbox/train/train/train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/projects/sandbox/train/train/train.py b/projects/sandbox/train/train/train.py index a54cd58a6..444f315fa 100644 --- a/projects/sandbox/train/train/train.py +++ b/projects/sandbox/train/train/train.py @@ -231,9 +231,10 @@ def main( if psd_length < window_length: raise ValueError( - "Can't have psd length {} longer than window length {}".format( - psd_length, window_length - ) + "Can't have psd length {} longer than " + "window length {}".format( + psd_length, window_length + ) ) # create objects that we'll use for whitening the data From f18371f748a357662dcdf130a3bd0fa8b792da5a Mon Sep 17 00:00:00 2001 From: Rafia Omer Date: Thu, 5 Oct 2023 14:06:00 -0500 Subject: [PATCH 08/10] pre-commit checks were failing --- projects/sandbox/train/train/train.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/projects/sandbox/train/train/train.py b/projects/sandbox/train/train/train.py index 444f315fa..4a54d3083 100644 --- a/projects/sandbox/train/train/train.py +++ b/projects/sandbox/train/train/train.py @@ -232,9 +232,7 @@ def main( if psd_length < window_length: raise ValueError( "Can't have psd length {} longer than " - "window length {}".format( - psd_length, window_length - ) + "window length {}".format(psd_length, window_length) ) # create objects that we'll use for whitening the data From f577295c3d7c0483ed8a19789ba19c3adc0d043d Mon Sep 17 00:00:00 2001 From: Rafia Omer Date: Thu, 5 Oct 2023 14:08:21 -0500 Subject: [PATCH 09/10] pre-commit checks were failing --- projects/sandbox/train/train/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/projects/sandbox/train/train/train.py b/projects/sandbox/train/train/train.py index 4a54d3083..1b70542b6 100644 --- a/projects/sandbox/train/train/train.py +++ b/projects/sandbox/train/train/train.py @@ -231,8 +231,8 @@ def main( if psd_length < window_length: raise ValueError( - "Can't have psd length {} longer than " - "window length {}".format(psd_length, window_length) + "Can't have psd length {} longer than " + "window length {}".format(psd_length, window_length) ) # create objects that we'll use for whitening the data From faaaf512adb95a667c9706f1f71aa439bb6d5278 Mon Sep 17 00:00:00 2001 From: Rafia Omer Date: Mon, 23 Oct 2023 13:46:51 -0500 Subject: [PATCH 10/10] fix for N=0 and proper psd length --- projects/sandbox/train/train/train.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/projects/sandbox/train/train/train.py b/projects/sandbox/train/train/train.py index 1b70542b6..80f4d1dfa 100644 --- a/projects/sandbox/train/train/train.py +++ b/projects/sandbox/train/train/train.py @@ -228,11 +228,10 @@ def main( sample_length = window_length + psd_length fftlength = fftlength or window_length - if psd_length < window_length: raise ValueError( - "Can't have psd length {} longer than " - "window length {}".format(psd_length, window_length) + "Can't have psd length shorter than " + "window length, {} < {}".format(psd_length, window_length) ) # create objects that we'll use for whitening the data