Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

added check for N=0 #429

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions projects/sandbox/train/train/augmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,28 +211,30 @@ def forward(self, X):
# sample waveforms and use them to compute
# interferometer responses
N = mask.sum().item()
responses = self.sample_responses(N, X.shape[-1], psds[mask])
responses.to(X.device)

# perform swapping and muting augmentations
# on those responses, and then inject them
responses, swap_indices = self.swapper(responses)
responses, mute_indices = self.muter(responses)
X[mask] += responses
if N > 0:
responses = self.sample_responses(N, X.shape[-1], psds[mask])
responses.to(X.device)

# perform swapping and muting augmentations
# on those responses, and then inject them
responses, swap_indices = self.swapper(responses)
responses, mute_indices = self.muter(responses)
X[mask] += responses

# set response augmentation labels to noise
idx = torch.where(mask)[0]
mask[idx[mute_indices]] = 0
mask[idx[swap_indices]] = 0

# set labels to positive for injected signals
y[mask] = -y[mask] + 1

# now that injections have been made,
# whiten _all_ the strain using the
# background psds computed up top
X = self.whitener(X, psds)

# set response augmentation labels to noise
idx = torch.where(mask)[0]
mask[idx[mute_indices]] = 0
mask[idx[swap_indices]] = 0

# set labels to positive for injected signals
y[mask] = -y[mask] + 1

# curriculum learning step
if self.snr is not None:
self.snr.step()
Expand Down
6 changes: 6 additions & 0 deletions projects/sandbox/train/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,12 @@ def main(
sample_length = window_length + psd_length
fftlength = fftlength or window_length

if psd_length < window_length:
raise ValueError(
"Can't have psd length shorter than "
"window length, {} < {}".format(psd_length, window_length)
)

# create objects that we'll use for whitening the data
fast = highpass is not None
psd_estimator = PsdEstimator(
Expand Down