diff --git a/projects/sandbox/train/train/augmentor.py b/projects/sandbox/train/train/augmentor.py index edc39b886..3f60aa9a9 100644 --- a/projects/sandbox/train/train/augmentor.py +++ b/projects/sandbox/train/train/augmentor.py @@ -211,28 +211,30 @@ def forward(self, X): # sample waveforms and use them to compute # interferometer responses N = mask.sum().item() - responses = self.sample_responses(N, X.shape[-1], psds[mask]) - responses.to(X.device) - # perform swapping and muting augmentations - # on those responses, and then inject them - responses, swap_indices = self.swapper(responses) - responses, mute_indices = self.muter(responses) - X[mask] += responses + if N > 0: + responses = self.sample_responses(N, X.shape[-1], psds[mask]) + responses.to(X.device) + + # perform swapping and muting augmentations + # on those responses, and then inject them + responses, swap_indices = self.swapper(responses) + responses, mute_indices = self.muter(responses) + X[mask] += responses + + # set response augmentation labels to noise + idx = torch.where(mask)[0] + mask[idx[mute_indices]] = 0 + mask[idx[swap_indices]] = 0 + + # set labels to positive for injected signals + y[mask] = -y[mask] + 1 # now that injections have been made, # whiten _all_ the strain using the # background psds computed up top X = self.whitener(X, psds) - # set response augmentation labels to noise - idx = torch.where(mask)[0] - mask[idx[mute_indices]] = 0 - mask[idx[swap_indices]] = 0 - - # set labels to positive for injected signals - y[mask] = -y[mask] + 1 - # curriculum learning step if self.snr is not None: self.snr.step() diff --git a/projects/sandbox/train/train/train.py b/projects/sandbox/train/train/train.py index d43cdf0d8..80f4d1dfa 100644 --- a/projects/sandbox/train/train/train.py +++ b/projects/sandbox/train/train/train.py @@ -228,6 +228,12 @@ def main( sample_length = window_length + psd_length fftlength = fftlength or window_length + if psd_length < window_length: + raise ValueError( + "Can't have psd length shorter than " + "window length, {} < {}".format(psd_length, window_length) + ) + # create objects that we'll use for whitening the data fast = highpass is not None psd_estimator = PsdEstimator(