This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Datagen bns changes #466

Open · wants to merge 34 commits into base: bns
Showing changes from 23 of 34 commits.

Commits (34):
- 8eb5b60  added check for N=0 (rafia17, Oct 3, 2023)
- a70b426  added N>0 (rafia17, Oct 3, 2023)
- 068af79  add check for psd length >= window length (rafia17, Oct 5, 2023)
- 708f137  pre-commit checks were failing (rafia17, Oct 5, 2023)
- 433a6ac  pre-commit checks were failing (rafia17, Oct 5, 2023)
- a92ae7c  pre-commit checks were failing (rafia17, Oct 5, 2023)
- da26566  pre-commit checks were failing (rafia17, Oct 5, 2023)
- f18371f  pre-commit checks were failing (rafia17, Oct 5, 2023)
- f577295  pre-commit checks were failing (rafia17, Oct 5, 2023)
- faaaf51  fix for N=0 and proper psd length (rafia17, Oct 23, 2023)
- a4caaeb  added bns nonspin prior (Nov 24, 2023)
- f9d9720  Merge branch 'add_bns_prior' into bns (Nov 24, 2023)
- 6955fc3  added imports for bns nonspin prior (Nov 27, 2023)
- f532652  Merge branch 'add_bns_prior' into bns (Nov 27, 2023)
- 93ec5b2  datagen bns changes (Nov 27, 2023)
- 4481b65  added signal_type variable (Nov 27, 2023)
- cb29e6c  datagen bns changes (Nov 28, 2023)
- 611eb29  datagen changes (Nov 28, 2023)
- 0c168d2  datagen changes (Nov 28, 2023)
- 71fe8dc  injection changes (Nov 28, 2023)
- a25cf64  datagen changes with fixes (Nov 28, 2023)
- d832644  datagen bns changes (Nov 28, 2023)
- b95db3d  datagen bns changes (Nov 28, 2023)
- 5499892  Fixes as per reviewer's comments (Nov 28, 2023)
- 474ddaa  review changes (Nov 28, 2023)
- 669e0db  Revert "review changes" (Nov 28, 2023)
- a1315bf  fixes based on review comments (Nov 29, 2023)
- 93717f8  fixed unit test errors (Nov 30, 2023)
- e398992  resolving flake8 test issues (Nov 30, 2023)
- f048fc4  random merge conflicts (Nov 30, 2023)
- dedb7a3  datagen changes (Dec 3, 2023)
- ea8086d  datagen bns changes (Dec 3, 2023)
- 1ed207a  datagen bns changes (Dec 3, 2023)
- 6e7cc95  datagen changes (Dec 3, 2023)
libs/priors/aframe/priors/priors.py (70 changes: 69 additions & 1 deletion)
@@ -12,7 +12,13 @@
Sine,
Uniform,
)
from bilby.gw.prior import UniformComovingVolume, UniformSourceFrame
from bilby.gw.prior import (
AlignedSpin,
UniformComovingVolume,
UniformInComponentsChirpMass,
UniformInComponentsMassRatio,
UniformSourceFrame,
)

from aframe.priors.utils import (
mass_condition_powerlaw,
@@ -133,6 +139,68 @@ def spin_bbh(cosmology: cosmo.Cosmology = COSMOLOGY) -> PriorDict:
return prior, detector_frame_prior


def nonspin_bns(cosmology: cosmo.Cosmology = COSMOLOGY) -> PriorDict:
"""
Define a Bilby `PriorDict` that describes a reasonable population
of non-spinning binary neutron stars

Masses are defined in the detector frame.

Args:
cosmology:
An `astropy` cosmology, used to determine redshift sampling

Returns:
prior:
`PriorDict` describing the binary neutron star population
detector_frame_prior:
Boolean indicating which frame masses are defined in
"""
prior = PriorDict()
prior["mass1"] = Uniform(0.5, 5, unit=msun)
prior["mass2"] = Uniform(0.5, 5, unit=msun)
prior["mass_ratio"] = UniformInComponentsMassRatio(
name="mass_ratio", minimum=0.125, maximum=1
)

# tidal deformability parameter
prior["lambda_tilde"] = Uniform(0, 5000, name="lambda_tilde")
prior["delta_lambda"] = Uniform(-5000, 5000, name="delta_lambda")

prior["redshift"] = UniformSourceFrame(
0, 0.5, name="redshift", cosmology=cosmology
)
prior["chirp_mass"] = UniformInComponentsChirpMass(
name="chirp_mass", minimum=0.4, maximum=4.4
)
prior["distance"] = UniformSourceFrame(
name="luminosity_distance", minimum=1e2, maximum=5e3
)
prior["dec"] = Cosine(name="dec")
prior["ra"] = Uniform(
name="ra", minimum=0, maximum=2 * np.pi, boundary="periodic"
)
prior["theta_jn"] = Sine(name="theta_jn")
prior["phase"] = Uniform(
name="phase", minimum=0, maximum=2 * np.pi, boundary="periodic"
)
prior["psi"] = Uniform(
name="psi", minimum=0, maximum=np.pi, boundary="periodic"
)
prior["chi_1"] = AlignedSpin(
name="chi_1", a_prior=Uniform(minimum=0, maximum=0.99)
)
prior["chi_2"] = AlignedSpin(
name="chi_2", a_prior=Uniform(minimum=0, maximum=0.99)
)

prior["phi_jl"] = 0

detector_frame_prior = True

return prior, detector_frame_prior


def end_o3_ratesandpops(
cosmology: cosmo.Cosmology = COSMOLOGY,
) -> ConditionalPriorDict:
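For orientation, a minimal sketch of how this new prior could be consumed downstream (the sample count is illustrative, not part of this PR):

```python
from aframe.priors.priors import nonspin_bns

# Build the BNS prior; the second return value records whether the
# sampled masses are defined in the detector frame (True here).
prior, detector_frame_prior = nonspin_bns()

# Draw samples: a dict mapping each parameter name to an array of values
params = prior.sample(100)
print(len(params["chirp_mass"]), detector_frame_prior)
```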
projects/sandbox/datagen/datagen/scripts/waveforms.py (39 changes: 28 additions & 11 deletions)
@@ -1,11 +1,12 @@
import logging
import random
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import Callable, Optional

import h5py
import numpy as np
from datagen.utils.injection import generate_gw
from datagen.utils.injection import generate_gw, generate_gw_bns
from typeo import scriptify

from aframe.logging import configure_logging
@@ -22,6 +23,7 @@ def main(
sample_rate: float,
waveform_duration: float,
waveform_approximant: str = "IMRPhenomPv2",
signal_type: str = "bbh",
force_generation: bool = False,
verbose: bool = False,
seed: Optional[int] = None,
@@ -93,17 +95,32 @@
prior, detector_frame_prior = prior()
params = prior.sample(num_signals)

signals = generate_gw(
params,
minimum_frequency,
reference_frequency,
sample_rate,
waveform_duration,
waveform_approximant,
detector_frame_prior,
)
if signal_type == "bns":
with ProcessPoolExecutor(140) as exe:
Collaborator: I like the idea of parallelizing this, but I don't think this is doing what you expect: it will submit a single job that generates all the requested waveforms in one process.

rafia17 (author), Nov 28, 2023: This was a long discussion with Alec on a thread that incidentally did not include you; I will forward it so you can see the entire conversation. The crux was that using concurrent.futures reduced the waveform generation time from days to under 46 minutes on the Hanford box. Let me know if you can access the following Slack thread: https://fastml.slack.com/archives/C05EHNRU8AK/p1695772374684639

Collaborator: Ethan's right. This is only helpful if you submit a job for each choice of parameters; submitting one job that generates waveforms for all the parameters won't multiprocess anything, it will just generate all the waveforms serially in one process.

rafia17 (author): Yes, I get your point. I am not sure why it reduced the generation time so drastically; it could be that the Hanford box simply behaved nicely on that particular run, which is puzzling. I will remove the concurrent.futures for now; if we run into bottlenecks generating BNS waveforms in the future, we can revisit it then.
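For reference, a minimal sketch of the per-chunk submission the reviewers describe; the `generate_in_chunks` helper and its chunking strategy are assumptions for illustration, not code from this PR:

```python
import numpy as np
from concurrent.futures import ProcessPoolExecutor

from datagen.utils.injection import generate_gw_bns


def generate_in_chunks(params, n_workers, **kwargs):
    """Generate waveforms for one chunk of parameters per process,
    rather than handing every parameter set to a single worker."""
    n_samples = len(next(iter(params.values())))
    # partition the sample indices into one chunk per worker
    splits = np.array_split(np.arange(n_samples), n_workers)
    with ProcessPoolExecutor(n_workers) as exe:
        futures = [
            exe.submit(
                generate_gw_bns,
                {k: np.asarray(v)[idx] for k, v in params.items()},
                **kwargs,
            )
            for idx in splits
        ]
        # blocks until every chunk has been generated
        results = [f.result() for f in futures]
    return np.concatenate(results, axis=0)
```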

future = exe.submit(
generate_gw_bns,
params,
minimum_frequency,
reference_frequency,
sample_rate,
waveform_duration,
waveform_approximant,
detector_frame_prior,
)
signals = future.result()
else:
Collaborator: Since this is going into a dedicated bns branch, I think it might make sense to just assume we are generating BNS waveforms and drop this if/else. We can move toward generalizing everything down the line.

rafia17 (author): If we don't want to create a separate repo for BNS, it will be most efficient to take care of this now rather than later, and to write the code so that it works seamlessly with the main pipeline. Otherwise, merging the bns branch into the main aframe branch will be a huge issue later.

signals = generate_gw(
params,
minimum_frequency,
reference_frequency,
sample_rate,
waveform_duration,
waveform_approximant,
detector_frame_prior,
)

# Write params and similar to output file
logging.info("Writing waveforms to file....")
if np.isnan(signals).any():
raise ValueError("The signals contain NaN values")

@@ -125,7 +142,7 @@
"minimum_frequency": minimum_frequency,
}
)

logging.info("Writing waveforms to file finished!")
return signal_file


projects/sandbox/datagen/datagen/utils/injection.py (69 changes: 67 additions & 2 deletions)
@@ -1,8 +1,12 @@
import logging
from typing import Dict, List, Tuple

import numpy as np
from bilby.gw.conversion import convert_to_lal_binary_black_hole_parameters
from bilby.gw.source import lal_binary_black_hole
from bilby.gw.conversion import (
convert_to_lal_binary_black_hole_parameters,
convert_to_lal_binary_neutron_star_parameters,
)
from bilby.gw.source import lal_binary_black_hole, lal_binary_neutron_star
from bilby.gw.waveform_generator import WaveformGenerator


@@ -91,6 +95,67 @@ def generate_gw(
return signals


def generate_gw_bns(
sample_params: Dict[str, List],
minimum_frequency: float,
reference_frequency: float,
sample_rate: float,
waveform_duration: float,
waveform_approximant: str,
detector_frame_prior: bool = False,
):
padding = 1
Collaborator: What is the motivation for this? In general we should avoid "magic numbers"; maybe make this a parameter with a default value.

rafia17 (author), Nov 28, 2023: The bilby-generated waveforms have the coalescence at timestamp 0 s, and I believe that is deliberate. We need to move the coalescence to the end of the waveform, at timestamp 16 s, say. Playing with several waveforms, I found that rolling them 200 datapoints to the left puts most of the coalescence at the very end. Some cases leave ring-down remnants, and to handle those we chop off the first second of the waveform. So the padding is set to 1 s and the waveform is generated longer by an amount equal to the padding; after rolling and chopping off the first second, the resulting waveform has the intended length with the coalescence nicely at the end. I will add comments to the code to explain this properly. Given the specific context it is used in, I don't think its value will change or be reused elsewhere, so with such a limited scope I don't think it needs to go in pyproject.toml.
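To make the mechanics concrete, a toy sketch of the roll-and-trim step described above (the sample rate and durations are illustrative):

```python
import numpy as np

sample_rate = 2048
padding = 1             # extra second generated, discarded below
waveform_duration = 16  # seconds of waveform we actually keep

# stand-in for one polarization returned by the waveform generator,
# of length (waveform_duration + padding) * sample_rate
strain = np.zeros(int((waveform_duration + padding) * sample_rate))

# bilby places the coalescence at t = 0, so it wraps around to the
# start of the array; rolling left pushes the merger to the end
strain = np.roll(strain, -200, axis=-1)

# drop the first `padding` seconds, where the wraparound remnants live
strain = strain[int(padding * sample_rate):]
assert strain.size == waveform_duration * sample_rate
```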

if not detector_frame_prior:
sample_params = convert_to_detector_frame(sample_params)

sample_params = [
dict(zip(sample_params, col)) for col in zip(*sample_params.values())
]

n_samples = len(sample_params)

waveform_generator = WaveformGenerator(
duration=waveform_duration + padding,
sampling_frequency=sample_rate,
frequency_domain_source_model=lal_binary_neutron_star,
parameter_conversion=convert_to_lal_binary_neutron_star_parameters,
waveform_arguments={
"waveform_approximant": waveform_approximant,
"reference_frequency": reference_frequency,
"minimum_frequency": minimum_frequency,
},
)

logging.info("Generating BNS waveforms : {}".format(n_samples))

waveform_size = int(sample_rate * waveform_duration)
num_pols = 2
signals = np.zeros((n_samples, num_pols, waveform_size))

for i, p in enumerate(sample_params):
polarizations = waveform_generator.time_domain_strain(p)
polarization_names = sorted(polarizations.keys())
polarizations = np.stack(
[polarizations[name] for name in polarization_names]
)

# just shift the coalescence to the left by 200 datapoints
# to cancel wraparound in the beginning
dt = -200
Collaborator: Again, another magic number: why is this 200? Is there a first-principles motivation for it?

rafia17 (author): Please see above.

polarizations = np.roll(polarizations, dt, axis=-1)

# cut off the first sec of the waveform where the wraparound occurs
padding_datapoints = padding * sample_rate
signals[i] = polarizations[:, int(padding_datapoints) :]
if i == (n_samples / 4):
logging.info("Generated polarizations : {}".format(i))
elif i == (n_samples / 2):
logging.info("Generated polarizations : {}".format(i))
Collaborator: I would argue this reduces readability more than it helps with logging. If you want to keep track, something cleaner might be:

    # every 10th waveform
    if not i % 10:
        # note the logging.debug, so it is only emitted when verbose=True
        logging.debug(f"{i + 1} polarizations generated")

rafia17 (author): Agree, will change.


logging.info("Finished Generated polarizations")
Collaborator: "Finished Generated polarizations" --> "Finished generating polarizations"

rafia17 (author): Agree, will change.

return signals


def inject_waveforms(
background: Tuple[np.ndarray, np.ndarray],
waveforms: np.ndarray,
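Tying the new prior and generator together, a hedged end-to-end sketch; the frequencies, duration, and approximant choice below are illustrative, not values prescribed by this PR:

```python
from aframe.priors.priors import nonspin_bns
from datagen.utils.injection import generate_gw_bns

prior, detector_frame_prior = nonspin_bns()
params = prior.sample(10)

# shape (10, 2, sample_rate * waveform_duration): one (cross, plus)
# polarization pair per sampled parameter set
signals = generate_gw_bns(
    params,
    minimum_frequency=20,
    reference_frequency=50,
    sample_rate=2048,
    waveform_duration=8,
    waveform_approximant="IMRPhenomPv2_NRTidal",
    detector_frame_prior=detector_frame_prior,
)
```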
projects/sandbox/pyproject.toml (2 changes: 2 additions & 0 deletions)
@@ -49,6 +49,7 @@ cosmology = "aframe.priors.cosmologies.planck"
streams_per_gpu = 3
waveform_approximant = "IMRPhenomPv2"
verbose = true
signal_type = "bbh"


[tool.typeo.scripts.deploy-background]
@@ -103,6 +104,7 @@ sample_rate = "${base.sample_rate}"
waveform_duration = "${base.waveform_duration}"
force_generation = "${base.force_generation}"
waveform_approximant = "${base.waveform_approximant}"
signal_type = "${base.signal_type}"

[tool.typeo.scripts.train]
# input and output paths
projects/sandbox/train/train/augmentor.py (32 changes: 17 additions & 15 deletions)
@@ -211,28 +211,30 @@ def forward(self, X):
# sample waveforms and use them to compute
# interferometer responses
N = mask.sum().item()
responses = self.sample_responses(N, X.shape[-1], psds[mask])
responses.to(X.device)

# perform swapping and muting augmentations
# on those responses, and then inject them
responses, swap_indices = self.swapper(responses)
responses, mute_indices = self.muter(responses)
X[mask] += responses
if N > 0:
Collaborator: Remind me when this edge case would be reached? In what instance would we not want to inject waveforms?

rafia17 (author): This was also something Alec proposed. We ran into the issue when we had to decrease the batch size to around 8 for BNS: with the waveform probability set to 0.277, N would come out to 0 on some batches. For small batch sizes the chance of N being zero is fairly high even for reasonable waveform probabilities, so Alec proposed adding this check to aframe.
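To make that concrete: if each batch element receives an injection independently with probability p, then P(N = 0) = (1 - p)^batch_size. Under that assumption, the numbers from this discussion give:

```python
p = 0.277       # waveform probability quoted above
batch_size = 8
print((1 - p) ** batch_size)  # ~0.075, i.e. roughly 1 in 13 batches has N == 0
```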

responses = self.sample_responses(N, X.shape[-1], psds[mask])
responses.to(X.device)

# perform swapping and muting augmentations
# on those responses, and then inject them
responses, swap_indices = self.swapper(responses)
responses, mute_indices = self.muter(responses)
X[mask] += responses

# set response augmentation labels to noise
idx = torch.where(mask)[0]
mask[idx[mute_indices]] = 0
mask[idx[swap_indices]] = 0

# set labels to positive for injected signals
y[mask] = -y[mask] + 1

# now that injections have been made,
# whiten _all_ the strain using the
# background psds computed up top
X = self.whitener(X, psds)

# set response augmentation labels to noise
idx = torch.where(mask)[0]
mask[idx[mute_indices]] = 0
mask[idx[swap_indices]] = 0

# set labels to positive for injected signals
y[mask] = -y[mask] + 1

# curriculum learning step
if self.snr is not None:
self.snr.step()
projects/sandbox/train/train/train.py (6 changes: 6 additions & 0 deletions)
@@ -228,6 +228,12 @@ def main(
sample_length = window_length + psd_length
fftlength = fftlength or window_length

if psd_length < window_length:
raise ValueError(
"Can't have psd length shorter than "
"window length, {} < {}".format(psd_length, window_length)
)

# create objects that we'll use for whitening the data
fast = highpass is not None
psd_estimator = PsdEstimator(