Skip to content

Commit

Permalink
add old phenom p tests
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanMarx committed Dec 16, 2024
1 parent 962424d commit d0c44af
Showing 1 changed file with 65 additions and 98 deletions.
163 changes: 65 additions & 98 deletions tests/waveforms/cbc/test_cbc_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
from torch.distributions import Uniform

import ml4gw.waveforms as waveforms
from ml4gw.waveforms.conversion import (
bilby_spins_to_lalsim,
chirp_mass_and_mass_ratio_to_components,
)
from ml4gw.waveforms.conversion import chirp_mass_and_mass_ratio_to_components


@pytest.fixture(params=[256, 1024, 2048])
Expand All @@ -21,79 +18,79 @@ def sample_rate(request):
@pytest.fixture()
def chirp_mass(request):
dist = Uniform(5, 100)
return dist.sample((1000,))
return dist.sample((100,))


@pytest.fixture()
def mass_ratio():
dist = Uniform(0.125, 0.99)
return dist.sample((1000,))
return dist.sample((100,))


@pytest.fixture()
def a_1(request):
dist = Uniform(0, 0.90)
return dist.sample((1000,))
return dist.sample((100,))


@pytest.fixture()
def a_2(request):
dist = Uniform(0, 0.90)
return dist.sample((1000,))
return dist.sample((100,))


@pytest.fixture()
def tilt_1(request):
dist = Uniform(0, torch.pi)
return dist.sample((1000,))
return dist.sample((100,))


@pytest.fixture()
def tilt_2(request):
dist = Uniform(0, torch.pi)
return dist.sample((1000,))
return dist.sample((100,))


@pytest.fixture()
def phi_12(request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample((1000,))
return dist.sample((100,))


@pytest.fixture()
def phi_jl(request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample((1000,))
return dist.sample((100,))


@pytest.fixture()
def distance(request):
dist = Uniform(100, 3000)
return dist.sample((1000,))
return dist.sample((100,))


@pytest.fixture()
def theta_jn(request):
dist = Uniform(0, torch.pi)
return dist.sample((1000,))
return dist.sample((100,))


@pytest.fixture()
def phase(request):
dist = Uniform(0, 2 * torch.pi)
return dist.sample((1000,))
return dist.sample((100,))


@pytest.fixture()
def chi_1(request):
dist = Uniform(-0.999, 0.999)
return dist.sample((1000,))
return dist.sample((100,))


@pytest.fixture()
def chi_2(request):
dist = Uniform(-0.999, 0.999)
return dist.sample((1000,))
return dist.sample((100,))


@pytest.fixture(params=[20, 40])
Expand Down Expand Up @@ -273,7 +270,7 @@ def test_phenom_d(
hp_lal, hc_lal = lalsimulation.SimInspiralChooseFDWaveform(**params)

# reconstruct frequencies generated by
# lal and filter based on fmin and fmax
# lal and filter based on fmin and fmax
lal_freqs = np.array(
[
hp_lal.f0 + ii * hp_lal.deltaF
Expand Down Expand Up @@ -302,6 +299,7 @@ def test_phenom_d(
)

hc_ml4gw = hc_ml4gw[0]

hp_ml4gw = hp_ml4gw[0]

hp_lal_data = hp_lal.data.data[lal_mask]
Expand All @@ -324,61 +322,46 @@ def test_phenom_d(
)


def test_phenom_p(
chirp_mass,
mass_ratio,
tilt_1,
tilt_2,
a_1,
a_2,
theta_jn,
phase,
distance,
phi_jl,
phi_12,
sample_rate,
f_ref,
):
# m1 > m2
mass_1, mass_2 = chirp_mass_and_mass_ratio_to_components(
chirp_mass, mass_ratio
)
def test_phenom_p():
chirp_mass = torch.tensor([15.0, 30.0])
mass_ratio = torch.tensor([0.99, 0.5])
chi1z = torch.tensor([0.0, 0.5])
chi2z = torch.tensor([-0.1, 0.1])
distance = torch.tensor([100.0, 1000.0])
sample_rate = 2048

inclination, s1x, s1y, s1z, s2x, s2y, s2z = bilby_spins_to_lalsim(
theta_jn,
phi_jl,
tilt_1,
tilt_2,
phi_12,
a_1,
a_2,
mass_1,
mass_2,
f_ref,
phase,
)
mass_2 = chirp_mass * (1 + mass_ratio) ** 0.2 / mass_ratio**0.6
mass_1 = mass_2 * mass_ratio

f_ref = 20.0
phic = 0.0
tc = 0.0
inclination = 0.0

for i in range(chirp_mass.shape[0]):
m1, m2 = mass_1[i], mass_2[i]
mr = mass_ratio[i]
if m2 > m1:
m1, m2 = m2, m1
mr = 1 / mr

# compare each waveform with lalsimulation
for i in range(2): # (len(chirp_mass)):
print(i, mass_ratio)
# construct lalinference params
params = dict(
m1=mass_1[i].item() * lal.MSUN_SI,
m2=mass_2[i].item() * lal.MSUN_SI,
S1x=s1x[i].item(),
S1y=s1y[i].item(),
S1z=s1z[i].item(),
S2x=s2x[i].item(),
S2y=s2y[i].item(),
S2z=s2z[i].item(),
m1=m1.item() * lal.MSUN_SI,
m2=m2.item() * lal.MSUN_SI,
S1x=0,
S1y=0,
S1z=chi1z[i].item(),
S2x=0,
S2y=0,
S2z=chi2z[i].item(),
distance=(distance[i].item() * u.Mpc).to("m").value,
inclination=inclination[i].item(),
phiRef=phase[i].item(),
inclination=inclination,
phiRef=phic,
longAscNodes=0.0,
eccentricity=0.0,
meanPerAno=0.0,
deltaF=1.0 / sample_rate,
f_min=20,
f_min=10.0,
f_ref=f_ref,
f_max=300,
approximant=lalsimulation.IMRPhenomPv2,
Expand All @@ -402,57 +385,41 @@ def test_phenom_p(
lal_freqs = lal_freqs[lal_mask]
torch_freqs = torch.tensor(lal_freqs, dtype=torch.float32)

# generate waveforms using ml4gw
hc_ml4gw, hp_ml4gw = waveforms.IMRPhenomPv2()(
torch_freqs,
chirp_mass[i][None],
mass_ratio[i][None],
s1x[i][None],
s1y[i][None],
s1z[i][None],
s2x[i][None],
s2y[i][None],
s2z[i][None],
torch.tensor([mr]),
torch.tensor([0.0]),
torch.tensor([0.0]),
chi1z[i][None],
torch.tensor([0.0]),
torch.tensor([0.0]),
chi2z[i][None],
distance[i][None],
phase[i][None],
inclination[i][None],
torch.tensor([phic]),
torch.tensor([inclination]),
f_ref,
torch.tensor([tc]),
)

hc_ml4gw = hc_ml4gw[0]
hp_ml4gw = hp_ml4gw[0]
hc_ml4gw = hc_ml4gw[0]

hp_lal_data = hp_lal.data.data[lal_mask]
hc_lal_data = hc_lal.data.data[lal_mask]

np.savetxt("lal_hp.txt", hp_lal_data.real)
np.savetxt("ml4gw_hp.txt", hp_ml4gw.real.numpy())

np.savetxt("lal_hp_imag.txt", hp_lal_data.imag)
np.savetxt("ml4gw_hp_imag.txt", hp_ml4gw.imag.numpy())

print(
i,
s1x[i],
s1y[i],
s1z[i],
s2x[i],
s2y[i],
s2z[i],
inclination[i],
phase[i],
)
np.savetxt("hp_lal_data", hp_lal_data.real)
np.savetxt("hp_ml4gw_data", hp_ml4gw.real.numpy())

assert np.allclose(
1e21 * hp_lal_data.real, 1e21 * hp_ml4gw.real.numpy(), atol=1e-2
1e21 * hp_lal_data.real, 1e21 * hp_ml4gw.real.numpy(), atol=2e-3
)
assert np.allclose(
1e21 * hp_lal_data.imag, 1e21 * hp_ml4gw.imag.numpy(), atol=1e-3
1e21 * hp_lal_data.imag, 1e21 * hp_ml4gw.imag.numpy(), atol=2e-3
)
assert np.allclose(
1e21 * hc_lal_data.real, 1e21 * hc_ml4gw.real.numpy(), atol=1e-3
1e21 * hc_lal_data.real, 1e21 * hc_ml4gw.real.numpy(), atol=2e-3
)
assert np.allclose(
1e21 * hc_lal_data.imag, 1e21 * hc_ml4gw.imag.numpy(), atol=1e-3
1e21 * hc_lal_data.imag, 1e21 * hc_ml4gw.imag.numpy(), atol=2e-3
)
# assert False

0 comments on commit d0c44af

Please sign in to comment.