Skip to content

Commit

Permalink
add conversion funtion (#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanMarx committed Dec 17, 2024
1 parent 4edecb5 commit 58c3b74
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 7 deletions.
8 changes: 6 additions & 2 deletions ml4gw/waveforms/cbc/taylorf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ml4gw.constants import MPC_SEC, MTSUN_SI, PI
from ml4gw.constants import EulerGamma as GAMMA
from ml4gw.types import BatchTensor, FrequencySeries1d
from ml4gw.waveforms.conversion import chirp_mass_and_mass_ratio_to_components


class TaylorF2(torch.nn.Module):
Expand Down Expand Up @@ -60,8 +61,11 @@ def forward(
or phic.shape[0] != inclination.shape[0]
):
raise RuntimeError("Tensors should have same batch size")
mass2 = chirp_mass * (1.0 + mass_ratio) ** 0.2 / mass_ratio**0.6
mass1 = mass_ratio * mass2

mass1, mass2 = chirp_mass_and_mass_ratio_to_components(
chirp_mass, mass_ratio
)

cfac = torch.cos(inclination)
pfac = 0.5 * (1.0 + cfac * cfac)

Expand Down
17 changes: 17 additions & 0 deletions ml4gw/waveforms/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,23 @@ def XLALSimInspiralL_2PN(eta: BatchTensor):
return 1.5 + eta / 6.0


def chirp_mass_and_mass_ratio_to_components(
chirp_mass: BatchTensor, mass_ratio: BatchTensor
):
"""
Compute component masses from chirp mass and mass ratio.
Args:
chirp_mass: Tensor of chirp mass values
mass_ratio:
Tensor of mass ratio values, `m2 / m1`,
where m1 >= m2, so that mass_ratio <= 1
"""
total_mass = chirp_mass * (1 + mass_ratio) ** 1.2 / mass_ratio**0.6
mass_1 = total_mass / (1 + mass_ratio)
mass_2 = mass_1 * mass_ratio
return mass_1, mass_2


def bilby_spins_to_lalsim(
theta_jn: BatchTensor,
phi_jl: BatchTensor,
Expand Down
12 changes: 7 additions & 5 deletions tests/waveforms/cbc/test_cbc_waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch.distributions import Uniform

import ml4gw.waveforms as waveforms
from ml4gw.waveforms.conversion import chirp_mass_and_mass_ratio_to_components


@pytest.fixture(params=[256, 1024, 2048])
Expand Down Expand Up @@ -108,8 +109,9 @@ def test_taylor_f2(
theta_jn,
sample_rate,
):
mass_2 = chirp_mass * (1 + mass_ratio) ** 0.2 / mass_ratio**0.6
mass_1 = mass_ratio * mass_2
mass_1, mass_2 = chirp_mass_and_mass_ratio_to_components(
chirp_mass, mass_ratio
)

# compare each waveform with lalsimulation
for i in range(len(chirp_mass)):
Expand Down Expand Up @@ -235,9 +237,9 @@ def test_phenom_d(
sample_rate,
f_ref,
):
total_mass = chirp_mass * (1 + mass_ratio) ** 1.2 / mass_ratio**0.6
mass_1 = total_mass / (1 + mass_ratio)
mass_2 = mass_1 * mass_ratio
mass_1, mass_2 = chirp_mass_and_mass_ratio_to_components(
chirp_mass, mass_ratio
)

# compare each waveform with lalsimulation
for i in range(len(chirp_mass)):
Expand Down

0 comments on commit 58c3b74

Please sign in to comment.