Datagen bns changes #466
base: bns
Changes from 23 commits
@@ -1,11 +1,12 @@
 import logging
 import random
+from concurrent.futures import ProcessPoolExecutor
 from pathlib import Path
 from typing import Callable, Optional

 import h5py
 import numpy as np
-from datagen.utils.injection import generate_gw
+from datagen.utils.injection import generate_gw, generate_gw_bns
 from typeo import scriptify

 from aframe.logging import configure_logging
@@ -22,6 +23,7 @@ def main(
     sample_rate: float,
     waveform_duration: float,
     waveform_approximant: str = "IMRPhenomPv2",
+    signal_type: str = "bbh",
     force_generation: bool = False,
     verbose: bool = False,
     seed: Optional[int] = None,
@@ -93,17 +95,32 @@ def main(
     prior, detector_frame_prior = prior()
     params = prior.sample(num_signals)

-    signals = generate_gw(
-        params,
-        minimum_frequency,
-        reference_frequency,
-        sample_rate,
-        waveform_duration,
-        waveform_approximant,
-        detector_frame_prior,
-    )
+    if signal_type == "bns":
+        with ProcessPoolExecutor(140) as exe:
+            future = exe.submit(
+                generate_gw_bns,
+                params,
+                minimum_frequency,
+                reference_frequency,
+                sample_rate,
+                waveform_duration,
+                waveform_approximant,
+                detector_frame_prior,
+            )
+            signals = future.result()
+    else:
Review comment: Since this is going into a dedicated …

Reply: If we don't want to create a separate repo for BNS, it will be most ideal and efficient if we take care of this now rather than later, and write the code keeping in mind that it works seamlessly with the main pipeline. Otherwise it will be a huge issue later to merge the bns branch into the main aframe branch.
+        signals = generate_gw(
+            params,
+            minimum_frequency,
+            reference_frequency,
+            sample_rate,
+            waveform_duration,
+            waveform_approximant,
+            detector_frame_prior,
+        )

     # Write params and similar to output file
     logging.info("Writing waveforms to file....")
     if np.isnan(signals).any():
         raise ValueError("The signals contain NaN values")
@@ -125,7 +142,7 @@ def main(
             "minimum_frequency": minimum_frequency,
         }
     )

     logging.info("Writing waveforms to file finished!")
     return signal_file
@@ -1,8 +1,12 @@
 import logging
 from typing import Dict, List, Tuple

 import numpy as np
-from bilby.gw.conversion import convert_to_lal_binary_black_hole_parameters
-from bilby.gw.source import lal_binary_black_hole
+from bilby.gw.conversion import (
+    convert_to_lal_binary_black_hole_parameters,
+    convert_to_lal_binary_neutron_star_parameters,
+)
+from bilby.gw.source import lal_binary_black_hole, lal_binary_neutron_star
 from bilby.gw.waveform_generator import WaveformGenerator
@@ -91,6 +95,67 @@ def generate_gw(
     return signals


+def generate_gw_bns(
+    sample_params: Dict[List, str],
+    minimum_frequency: float,
+    reference_frequency: float,
+    sample_rate: float,
+    waveform_duration: float,
+    waveform_approximant: str,
+    detector_frame_prior: bool = False,
+):
+    padding = 1
Review comment: What is the motivation for this? In general we should avoid "magic numbers"; maybe make this a parameter with a default value.

Reply: There is an issue where the bilby-generated waveforms have the coalescence at timestamp 0 s, and I think that's on purpose. We need to move the coalescence to the end of the waveform, at timestamp 16 s, say. I played with several waveforms and found that rolling them 200 datapoints to the left puts most of the coalescence at the very end. In some cases there are ring-down remnants, and to deal with those we chop off the first second of the waveform. So the padding is set to 1 s and the waveform is generated longer by an amount equal to the padding. After rolling and chopping off the first second, the resultant waveform has the intended length, with the coalescence nicely at the end.
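The shift-and-trim scheme described in this reply can be sketched as follows. This is a minimal illustration with synthetic data: the `sample_rate` and `waveform_duration` values are assumptions for the sketch, not taken from the pipeline config; only `padding = 1` and `dt = -200` come from the diff.

```python
import numpy as np

# Sketch of the shift-and-trim described above: a waveform generated
# with an extra `padding` seconds is rolled left by 200 samples to move
# the coalescence to the end, then the first `padding` seconds (where
# the wraparound lands) are cut off.
sample_rate = 2048        # assumed for this sketch
waveform_duration = 16    # assumed for this sketch
padding = 1               # extra second generated, later discarded
dt = -200                 # empirically chosen left shift (see reply above)

padded_size = int(sample_rate * (waveform_duration + padding))
# stand-in for the two stacked polarizations from the waveform generator
polarizations = np.arange(2 * padded_size, dtype=float).reshape(2, padded_size)

rolled = np.roll(polarizations, dt, axis=-1)
padding_datapoints = int(padding * sample_rate)
trimmed = rolled[:, padding_datapoints:]

# the result has exactly the intended length
assert trimmed.shape == (2, int(sample_rate * waveform_duration))
```

The trim removes exactly the padded region, so the output length matches `waveform_duration` regardless of the roll amount, as long as `|dt|` is smaller than the padding in samples.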
+
+    if not detector_frame_prior:
+        sample_params = convert_to_detector_frame(sample_params)
+
+    sample_params = [
+        dict(zip(sample_params, col)) for col in zip(*sample_params.values())
+    ]
+
+    n_samples = len(sample_params)
+
+    waveform_generator = WaveformGenerator(
+        duration=waveform_duration + padding,
+        sampling_frequency=sample_rate,
+        frequency_domain_source_model=lal_binary_neutron_star,
+        parameter_conversion=convert_to_lal_binary_neutron_star_parameters,
+        waveform_arguments={
+            "waveform_approximant": waveform_approximant,
+            "reference_frequency": reference_frequency,
+            "minimum_frequency": minimum_frequency,
+        },
+    )
+
+    logging.info("Generating BNS waveforms : {}".format(n_samples))
+
+    waveform_size = int(sample_rate * waveform_duration)
+    num_pols = 2
+    signals = np.zeros((n_samples, num_pols, waveform_size))
+
+    for i, p in enumerate(sample_params):
+        polarizations = waveform_generator.time_domain_strain(p)
+        polarization_names = sorted(polarizations.keys())
+        polarizations = np.stack(
+            [polarizations[p] for p in polarization_names]
+        )
+
+        # just shift the coalescence to the left by 200 datapoints
+        # to cancel wraparound in the beginning
+        dt = -200
Review comment: Again, another magic number: why is this 200? Is there a first-principles motivation for it?

Reply: Please see the explanation above.
+
+        polarizations = np.roll(polarizations, dt, axis=-1)
+
+        # cut off the first sec of the waveform where the wraparound occurs
+        padding_datapoints = padding * sample_rate
+        signals[i] = polarizations[:, int(padding_datapoints):]
+        if i == (n_samples / 4):
+            logging.info("Generated polarizations : {}".format(i))
+        elif i == (n_samples / 2):
+            logging.info("Generated polarizations : {}".format(i))
Review comment: I would argue this reduces readability more than it helps with logging. If you want to keep track, something cleaner would be:

    # every 10th waveform
    if not i % 10:
        # note the logging.debug so that it's only called if verbose=True
        logging.debug(f"{i + 1} polarizations generated")

Reply: Agree, will change.
+
+    logging.info("Finished Generated polarizations")
Review comment: "Finished Generated polarizations" --> "Finished generating polarizations"

Reply: Agree, will change.
+
+    return signals


 def inject_waveforms(
     background: Tuple[np.ndarray, np.ndarray],
     waveforms: np.ndarray,
@@ -211,28 +211,30 @@ def forward(self, X):
         # sample waveforms and use them to compute
         # interferometer responses
         N = mask.sum().item()
-        responses = self.sample_responses(N, X.shape[-1], psds[mask])
-        responses.to(X.device)
-
-        # perform swapping and muting augmentations
-        # on those responses, and then inject them
-        responses, swap_indices = self.swapper(responses)
-        responses, mute_indices = self.muter(responses)
-        X[mask] += responses
+        if N > 0:
Review comment: Remind me when this edge case would be reached? In what instance would we not want to inject waveforms?

Reply: This was also something Alec proposed. We ran into this issue when we had to decrease the batch size to around 8 for BNS, which led to N coming out to 0 on some instances with the waveform probability set to 0.277. For smaller batch sizes, the chance of N being zero, even for reasonable waveform probabilities, is pretty high, so Alec proposed adding this check to aframe.
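The frequency of this edge case is easy to estimate. A minimal sketch, assuming each batch element is independently selected for injection; the batch size of 8 and waveform probability of 0.277 are the values quoted in this thread:

```python
# Chance that a whole batch receives no injections (N == 0), assuming
# each of the `batch_size` samples is independently chosen for
# injection with probability `waveform_prob`.
batch_size = 8
waveform_prob = 0.277

p_empty = (1 - waveform_prob) ** batch_size
print(f"P(N == 0) = {p_empty:.3f}")  # ~0.075, i.e. roughly 1 batch in 13
```

At these settings an all-noise batch occurs often enough per epoch that guarding the `N == 0` path is worthwhile.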
+            responses = self.sample_responses(N, X.shape[-1], psds[mask])
+            responses.to(X.device)
+
+            # perform swapping and muting augmentations
+            # on those responses, and then inject them
+            responses, swap_indices = self.swapper(responses)
+            responses, mute_indices = self.muter(responses)
+            X[mask] += responses
+
+            # set response augmentation labels to noise
+            idx = torch.where(mask)[0]
+            mask[idx[mute_indices]] = 0
+            mask[idx[swap_indices]] = 0
+
+            # set labels to positive for injected signals
+            y[mask] = -y[mask] + 1

         # now that injections have been made,
         # whiten _all_ the strain using the
         # background psds computed up top
         X = self.whitener(X, psds)

-        # set response augmentation labels to noise
-        idx = torch.where(mask)[0]
-        mask[idx[mute_indices]] = 0
-        mask[idx[swap_indices]] = 0
-
-        # set labels to positive for injected signals
-        y[mask] = -y[mask] + 1

         # curriculum learning step
         if self.snr is not None:
             self.snr.step()
Review comment: I like the idea of parallelizing this, but I don't think it is doing what you expect: this submits a single job that generates all the requested waveforms in one process.

Reply: This was a long discussion with Alec on a thread that incidentally did not include you; I will forward it to you so you can see the entire conversation. The crux was that using concurrent.futures reduced the waveform generation time from days to under 46 minutes on the Hanford box. Let me know if you can access the following thread on Slack:
https://fastml.slack.com/archives/C05EHNRU8AK/p1695772374684639

Review comment: Ethan's right. This is only helpful if you submit a job for each choice of parameters; submitting one job that generates waveforms for all the parameters won't multiprocess anything, it will just generate all the waveforms serially in one process.

Reply: Yes, I get your point. I'm not sure why it reduced the generation time so drastically; it could be that the Hanford box just behaved nicely on that particular run, which is kind of puzzling. I will remove the concurrent.futures for now; if we run into bottlenecks generating BNS waveforms in the future, we can revisit it then.