diff --git a/libs/priors/aframe/priors/priors.py b/libs/priors/aframe/priors/priors.py index 2c82df496..ccd066843 100644 --- a/libs/priors/aframe/priors/priors.py +++ b/libs/priors/aframe/priors/priors.py @@ -12,7 +12,13 @@ Sine, Uniform, ) -from bilby.gw.prior import UniformComovingVolume, UniformSourceFrame +from bilby.gw.prior import ( + AlignedSpin, + UniformComovingVolume, + UniformInComponentsChirpMass, + UniformInComponentsMassRatio, + UniformSourceFrame, +) from aframe.priors.utils import ( mass_condition_powerlaw, @@ -133,6 +139,68 @@ def spin_bbh(cosmology: cosmo.Cosmology = COSMOLOGY) -> PriorDict: return prior, detector_frame_prior +def nonspin_bns(cosmology: cosmo.Cosmology = COSMOLOGY) -> PriorDict: + """ + Define a Bilby `PriorDict` that describes a reasonable population + of non-spinning binary black holes + + Masses are defined in the detector frame. + + Args: + cosmology: + An `astropy` cosmology, used to determine redshift sampling + + Returns: + prior: + `PriorDict` describing the binary black hole population + detector_frame_prior: + Boolean indicating which frame masses are defined in + """ + prior = PriorDict() + prior["mass1"] = Uniform(0.5, 5, unit=msun) + prior["mass2"] = Uniform(0.5, 5, unit=msun) + prior["mass_ratio"] = UniformInComponentsMassRatio( + name="mass_ratio", minimum=0.125, maximum=1 + ) + + # tidal deformability parameter + prior["lambda_tilde"] = Uniform(0, 5000, name="lambda_tilde") + prior["delta_lambda"] = Uniform(-5000, 5000, name="delta_lambda") + + prior["redshift"] = UniformSourceFrame( + 0, 0.5, name="redshift", cosmology=cosmology + ) + prior["chirp_mass"] = UniformInComponentsChirpMass( + name="chirp_mass", minimum=0.4, maximum=4.4 + ) + prior["distance"] = UniformSourceFrame( + name="luminosity_distance", minimum=1e2, maximum=5e3 + ) + prior["dec"] = Cosine(name="dec") + prior["ra"] = Uniform( + name="ra", minimum=0, maximum=2 * np.pi, boundary="periodic" + ) + prior["theta_jn"] = Sine(name="theta_jn") + prior["phase"] = Uniform( + name="phase", minimum=0, maximum=2 * np.pi, boundary="periodic" + ) + prior["psi"] = Uniform( + name="psi", minimum=0, maximum=np.pi, boundary="periodic" + ) + prior["chi_1"] = AlignedSpin( + name="chi_1", a_prior=Uniform(minimum=0, maximum=0.99) + ) + prior["chi_2"] = AlignedSpin( + name="chi_2", a_prior=Uniform(minimum=0, maximum=0.99) + ) + + prior["phi_jl"] = 0 + + detector_frame_prior = True + + return prior, detector_frame_prior + + def end_o3_ratesandpops( cosmology: cosmo.Cosmology = COSMOLOGY, ) -> ConditionalPriorDict: diff --git a/projects/sandbox/datagen/datagen/scripts/waveforms.py b/projects/sandbox/datagen/datagen/scripts/waveforms.py index c4b536d50..21d7b60bf 100644 --- a/projects/sandbox/datagen/datagen/scripts/waveforms.py +++ b/projects/sandbox/datagen/datagen/scripts/waveforms.py @@ -1,11 +1,12 @@ import logging import random +from concurrent.futures import ProcessPoolExecutor from pathlib import Path from typing import Callable, Optional import h5py import numpy as np -from datagen.utils.injection import generate_gw +from datagen.utils.injection import generate_gw, generate_gw_bns from typeo import scriptify from aframe.logging import configure_logging @@ -22,6 +23,7 @@ def main( sample_rate: float, waveform_duration: float, waveform_approximant: str = "IMRPhenomPv2", + signal_type: str = "bbh", force_generation: bool = False, verbose: bool = False, seed: Optional[int] = None, @@ -93,17 +95,29 @@ def main( prior, detector_frame_prior = prior() params = prior.sample(num_signals) - signals = generate_gw( - params, - minimum_frequency, - reference_frequency, - sample_rate, - waveform_duration, - waveform_approximant, - detector_frame_prior, - ) + if signal_type == "bns": + signals = generate_gw_bns( + params, + minimum_frequency, + reference_frequency, + sample_rate, + waveform_duration, + waveform_approximant, + detector_frame_prior, + ) + else: + signals = generate_gw( + params, + minimum_frequency, + reference_frequency, + sample_rate, + waveform_duration, + waveform_approximant, + detector_frame_prior, + ) # Write params and similar to output file + logging.info("Writing waveforms to file....") if np.isnan(signals).any(): raise ValueError("The signals contain NaN values") @@ -125,7 +139,7 @@ def main( "minimum_frequency": minimum_frequency, } ) - + logging.info("Writing waveforms to file finished!") return signal_file diff --git a/projects/sandbox/datagen/datagen/utils/injection.py b/projects/sandbox/datagen/datagen/utils/injection.py index 9793eea0b..7a8c0deeb 100644 --- a/projects/sandbox/datagen/datagen/utils/injection.py +++ b/projects/sandbox/datagen/datagen/utils/injection.py @@ -1,8 +1,12 @@ +import logging from typing import Dict, List, Tuple import numpy as np -from bilby.gw.conversion import convert_to_lal_binary_black_hole_parameters -from bilby.gw.source import lal_binary_black_hole +from bilby.gw.conversion import ( + convert_to_lal_binary_black_hole_parameters, + convert_to_lal_binary_neutron_star_parameters, +) +from bilby.gw.source import lal_binary_black_hole, lal_binary_neutron_star from bilby.gw.waveform_generator import WaveformGenerator @@ -91,6 +95,70 @@ def generate_gw( return signals +def generate_gw_bns( + sample_params: Dict[List, str], + minimum_frequency: float, + reference_frequency: float, + sample_rate: float, + waveform_duration: float, + waveform_approximant: str, + detector_frame_prior: bool = False, +): + # Generate a longer waveform by no of sec equal to padding + # After wrap-around effect is fixed the padded length would + # be chopped off leaving the waveform of intended length + padding = 1 + if not detector_frame_prior: + sample_params = convert_to_detector_frame(sample_params) + + sample_params = [ + dict(zip(sample_params, col)) for col in zip(*sample_params.values()) + ] + + n_samples = len(sample_params) + + waveform_generator = WaveformGenerator( + duration=waveform_duration + padding, + sampling_frequency=sample_rate, + frequency_domain_source_model=lal_binary_neutron_star, + parameter_conversion=convert_to_lal_binary_neutron_star_parameters, + waveform_arguments={ + "waveform_approximant": waveform_approximant, + "reference_frequency": reference_frequency, + "minimum_frequency": minimum_frequency, + }, + ) + + logging.info("Generating BNS waveforms : {}".format(n_samples)) + + waveform_size = int(sample_rate * waveform_duration) + num_pols = 2 + signals = np.zeros((n_samples, num_pols, waveform_size)) + + for i, p in enumerate(sample_params): + polarizations = waveform_generator.time_domain_strain(p) + polarization_names = sorted(polarizations.keys()) + polarizations = np.stack( + [polarizations[p] for p in polarization_names] + ) + + # just shift the coalescence to the left by 200 datapoints + # to cancel wraparound in the beginning + dt = -200 + polarizations = np.roll(polarizations, dt, axis=-1) + + # cut off the first sec of the waveform where the wraparound occurs + padding_length = padding * sample_rate + signals[i] = polarizations[:, int(padding_length) :] + + # every 1000th waveform + if not i % 1000: + # note the following is only called if verbose=True + logging.debug(f"{i + 1} polarizations generated") + logging.info("Finished generating polarizations") + return signals + + def inject_waveforms( background: Tuple[np.ndarray, np.ndarray], waveforms: np.ndarray, diff --git a/projects/sandbox/pyproject.toml b/projects/sandbox/pyproject.toml index 3f4bdb156..d1c2a9834 100644 --- a/projects/sandbox/pyproject.toml +++ b/projects/sandbox/pyproject.toml @@ -49,6 +49,7 @@ cosmology = "aframe.priors.cosmologies.planck" streams_per_gpu = 3 waveform_approximant = "IMRPhenomPv2" verbose = true +signal_type = "bbh" [tool.typeo.scripts.deploy-background] @@ -103,6 +104,7 @@ sample_rate = "${base.sample_rate}" waveform_duration = "${base.waveform_duration}" force_generation = "${base.force_generation}" waveform_approximant = "${base.waveform_approximant}" +signal_type = "${base.signal_type}" [tool.typeo.scripts.train] # input and output paths diff --git a/projects/sandbox/train/train/augmentor.py b/projects/sandbox/train/train/augmentor.py index edc39b886..d268e8ee7 100644 --- a/projects/sandbox/train/train/augmentor.py +++ b/projects/sandbox/train/train/augmentor.py @@ -211,28 +211,30 @@ def forward(self, X): # sample waveforms and use them to compute # interferometer responses N = mask.sum().item() - responses = self.sample_responses(N, X.shape[-1], psds[mask]) - responses.to(X.device) - # perform swapping and muting augmentations - # on those responses, and then inject them - responses, swap_indices = self.swapper(responses) - responses, mute_indices = self.muter(responses) - X[mask] += responses + if N > 0: + responses = self.sample_responses(N, X.shape[-1], psds[mask]) + responses.to(X.device) + + # perform swapping and muting augmentations + # on those responses, and then inject them + responses, swap_indices = self.swapper(responses) + responses, mute_indices = self.muter(responses) + X[mask] += responses + + # set response augmentation labels to noise + idx = torch.where(mask)[0] + mask[idx[mute_indices]] = 0 + mask[idx[swap_indices]] = 0 + + # set labels to positive for injected signals + y[mask] = -y[mask] + 1 # now that injections have been made, # whiten _all_ the strain using the # background psds computed up top X = self.whitener(X, psds) - # set response augmentation labels to noise - idx = torch.where(mask)[0] - mask[idx[mute_indices]] = 0 - mask[idx[swap_indices]] = 0 - - # set labels to positive for injected signals - y[mask] = -y[mask] + 1 - # curriculum learning step if self.snr is not None: self.snr.step() diff --git a/projects/sandbox/train/train/train.py b/projects/sandbox/train/train/train.py index d43cdf0d8..a5a204c85 100644 --- a/projects/sandbox/train/train/train.py +++ b/projects/sandbox/train/train/train.py @@ -217,10 +217,10 @@ def main( outdir.mkdir(exist_ok=True, parents=True) logdir.mkdir(exist_ok=True, parents=True) configure_logging(logdir / "train.log", verbose) + if seed is not None: logging.info(f"Setting global seed to {seed}") train_utils.seed_everything(seed) - # grab the names of the background files and determine the # length of data that will be handed to the preprocessor background_fnames = train_utils.get_background_fnames(background_dir) @@ -228,6 +228,12 @@ def main( sample_length = window_length + psd_length fftlength = fftlength or window_length + if psd_length < window_length: + raise ValueError( + "Can't have psd length shorter than " + "window length, {} < {}".format(psd_length, window_length) + ) + # create objects that we'll use for whitening the data fast = highpass is not None psd_estimator = PsdEstimator(