Skip to content

Commit

Permalink
update testing
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanMarx committed Jan 24, 2025
1 parent 6bc5a17 commit 4c9cf9a
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 84 deletions.
44 changes: 5 additions & 39 deletions ml4gw/waveforms/generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Dict, Optional, Tuple
import math
from typing import Callable, Dict, Tuple

import torch
from jaxtyping import Float
Expand All @@ -13,44 +14,6 @@
EXTRA_CYCLES = 3.0


class ParameterSampler(torch.nn.Module):
def __init__(
self,
conversion_function: Optional[Callable] = None,
**params: Callable,
):
"""
A class for sampling parameters from a prior distribution
Args:
conversion_function:
A callable that takes a dictionary of sampled parameters
and returns a dictionary of transformed parameters
**params:
A dictionary of parameter samplers that take an integer N
and return a tensor of shape (N, ...) representing
samples from the prior distribution
"""
super().__init__()
self.params = params
self.conversion_function = conversion_function or (lambda x: x)

def forward(
self,
N: int,
device: str = "cpu",
):
# sample parameters from prior
parameters = {
k: v.sample((N,)).to(device) for k, v in self.params.items()
}
# perform any necessary conversions
# to from sampled parameters to
# waveform generation parameters
parameters = self.conversion_function(parameters)
return parameters


class TimeDomainCBCWaveformGenerator(torch.nn.Module):
"""
Waveform generator that generates time-domain waveforms. Currently,
Expand Down Expand Up @@ -156,6 +119,9 @@ def forward(
(tchirp + tmerge + 2.0 * textra) * self.sample_rate
)

# pad to next power of 2
chirplen = 2 ** torch.ceil(torch.log(chirplen) / math.log(2))

# get smallest df corresponding to longest chirp length,
# which will make sure there is no wrap around effects.
df = min(1.0 / (chirplen.max() / self.sample_rate), self.delta_f)
Expand Down
30 changes: 30 additions & 0 deletions tests/Untitled-1.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.13 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "f9f85f796d01129d0dd105a088854619f454435301f6ffec2fea96ecbd9be4ac"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
101 changes: 101 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
import torch
from scipy.special import erfinv
from torch.distributions import Uniform


@pytest.fixture
Expand Down Expand Up @@ -69,3 +70,103 @@ def validate(whitened, highpass, sample_rate, df):
torch.testing.assert_close(passed, target, rtol=0, atol=0.07)

return validate


# number of samples to draw from
# the distributions for testing
N_SAMPLES = 100


@pytest.fixture(params=[256, 1024, 2048])
def sample_rate(request):
return request.param


@pytest.fixture()
def chirp_mass(request):
dist = Uniform(5, 100)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def mass_ratio():
dist = Uniform(0.125, 0.99)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def a_1(request):
dist = Uniform(0, 0.90)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def a_2(request):
dist = Uniform(0, 0.90)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def tilt_1(request):
dist = Uniform(0, torch.pi)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def tilt_2(request):
dist = Uniform(0, torch.pi)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def phi_12(request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def phi_jl(request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def distance(request):
dist = Uniform(100, 3000)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def distance_far(request):
dist = Uniform(400, 3000)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def distance_close(request):
dist = Uniform(100, 400)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def theta_jn(request):
dist = Uniform(0, torch.pi)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def phase(request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def chi1(request):
dist = Uniform(-0.999, 0.999)
return dist.sample((N_SAMPLES,))


@pytest.fixture()
def chi2(request):
dist = Uniform(-0.999, 0.999)
return dist.sample((N_SAMPLES,))
3 changes: 2 additions & 1 deletion tests/waveforms/cbc/test_cbc_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,8 @@ def test_phenom_p(
assert np.allclose(
1e21 * hc_lal_data.imag, 1e21 * hc_ml4gw.imag.numpy(), atol=1e-2
)

print(hp_lal.epoch)
assert False
# test batched outputs works as expected
hc_ml4gw, hp_ml4gw = waveforms.IMRPhenomPv2()(
torch_freqs,
Expand Down
File renamed without changes.
128 changes: 84 additions & 44 deletions tests/waveforms/test_generator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from math import pi

import astropy.units as u
import lal
import lalsimulation
import pytest
import torch

from ml4gw import distributions
from ml4gw.waveforms.generator import ParameterSampler, WaveformGenerator
from ml4gw.waveforms import IMRPhenomD, conversion
from ml4gw.waveforms.generator import TimeDomainCBCWaveformGenerator


@pytest.fixture(params=[10, 100, 1000])
Expand All @@ -22,46 +23,85 @@ def sample_rate(request):
return request.param


def test_parameter_sampler(n_samples):
parameter_sampler = ParameterSampler(
phi=torch.distributions.Uniform(0, 2 * pi),
dec=distributions.Cosine(),
snr=distributions.LogNormal(6, 4, 3),
def test_cbc_waveform_generator(
chirp_mass,
mass_ratio,
chi1,
chi2,
phase,
distance,
theta_jn,
sample_rate,
):
sample_rate = 4096
duration = 1
f_min = 20
f_ref = 40
right_pad = 0.1

generator = TimeDomainCBCWaveformGenerator(
approximant=IMRPhenomD(),
sample_rate=sample_rate,
duration=duration,
f_min=f_min,
f_ref=f_ref,
right_pad=right_pad,
)

samples = parameter_sampler(
n_samples,
mass_1, mass_2 = conversion.chirp_mass_and_mass_ratio_to_components(
chirp_mass, mass_ratio
)

for k in ["phi", "dec", "snr"]:
assert len(samples[k]) == n_samples


def test_waveform_generator(sample_rate, duration, n_samples):
def waveform(amplitude, frequency, phase):
frequency = frequency.view(-1, 1)
amplitude = amplitude.view(-1, 1)
phase = phase.view(-1, 1)

strain = torch.arange(0, duration, 1 / sample_rate)
hplus = amplitude * torch.sin(2 * pi * frequency * strain + phase)
hcross = amplitude * torch.cos(2 * pi * frequency * strain + phase)

hplus = hplus.unsqueeze(1)
hcross = hcross.unsqueeze(1)

waveforms = torch.cat([hplus, hcross], dim=1)
return waveforms

parameter_sampler = ParameterSampler(
amplitude=torch.distributions.Uniform(0, 1),
frequency=torch.distributions.Uniform(0, 1),
phase=torch.distributions.Uniform(0, 2 * pi),
)

generator = WaveformGenerator(waveform, parameter_sampler)
waveforms, parameters = generator(n_samples)

for k in ["amplitude", "frequency", "phase"]:
assert len(parameters[k]) == n_samples
assert waveforms.shape == (n_samples, 2, duration * sample_rate)
s1x = torch.zeros_like(chi1)
s1y = torch.zeros_like(chi1)
s1z = chi1
s2x = torch.zeros_like(chi2)
s2y = torch.zeros_like(chi2)
s2z = chi2
parameters = {
"chirp_mass": chirp_mass,
"mass_ratio": mass_ratio,
"mass_1": mass_1,
"mass_2": mass_2,
"chi1": chi1,
"chi2": chi2,
"s1z": s1z,
"s2z": s2z,
"s1x": s1x,
"s1y": s1y,
"s2x": s2x,
"s2y": s2y,
"phic": phase,
"distance": distance,
"inclination": theta_jn,
}
hc, hp = generator(**parameters)

# now compare each waveform with lalsimulation SimInspiralTD
for i in range(len(chirp_mass)):

# test far (> 400 Mpc) waveforms (O(1e-3) agreement)

# construct lalinference params
params = dict(
m1=mass_1[i].item() * lal.MSUN_SI,
m2=mass_2[i].item() * lal.MSUN_SI,
S1x=s1x[i].item(),
S1y=s2y[i].item(),
S1z=s1z[i].item(),
S2x=s2x[i].item(),
S2y=s2y[i].item(),
S2z=s1z[i].item(),
distance=(distance[i].item() * u.Mpc).to("m").value,
inclination=theta_jn[i].item(),
phiRef=phase[i].item(),
longAscNodes=0.0,
eccentricity=0.0,
meanPerAno=0.0,
deltaT=1 / sample_rate,
f_min=f_min,
f_ref=f_ref,
approximant=lalsimulation.IMRPhenomD,
LALparams=lal.CreateDict(),
)
return params
# hp_lal, hc_lal = lalsimulation.SimInspiralTD(**params)

0 comments on commit 4c9cf9a

Please sign in to comment.