Skip to content

Commit

Permalink
add taylorF2 3.5PN in torch
Browse files Browse the repository at this point in the history
  • Loading branch information
deepchatterjeeligo committed Dec 7, 2023
1 parent ab9205d commit 00b7e99
Show file tree
Hide file tree
Showing 3 changed files with 262 additions and 0 deletions.
1 change: 1 addition & 0 deletions ml4gw/waveforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .sine_gaussian import SineGaussian
from .taylorf2 import TaylorF2
165 changes: 165 additions & 0 deletions ml4gw/waveforms/taylorf2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import torch
from torchtyping import TensorType

GAMMA = 0.577215664901532860606512090082402431
"""Euler-Mascheroni constant. Same as lal.GAMMA"""

MSUN_SI = 1.988409870698050731911960804878414216e30
"""Solar mass in kg. Same as lal.MSUN_SI"""

MTSUN_SI = 4.925490947641266978197229498498379006e-6
"""1 solar mass in seconds. Same value as lal.MTSUN_SI"""

PI = 3.141592653589793238462643383279502884
"""Archimedes constant. Same as lal.PI"""

MPC_SEC = 1.02927125e14
"""
1 Mpc in seconds.
"""


def taylorf2_phase(
f: TensorType,
mass1: TensorType,
mass2: TensorType,
) -> TensorType:
"""
Calculate the inspiral phase for the IMRPhenomD waveform.
"""
mass1_s = mass1 * MTSUN_SI
mass2_s = mass2 * MTSUN_SI
M_s = mass1_s + mass2_s
eta = mass1_s * mass2_s / M_s / M_s

Mf = (f.T * M_s).T

v0 = torch.ones_like(Mf)
v1 = (PI * Mf) ** (1.0 / 3.0)
v2 = v1 * v1
v3 = v2 * v1
v4 = v3 * v1
v5 = v4 * v1
v6 = v5 * v1
v7 = v6 * v1
logv = torch.log(v1)
v5_logv = v5 * logv
v6_logv = v6 * logv

# Phase coeffeciencts from https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiralPNCoefficients.c # noqa E501
pfaN = 3.0 / (128.0 * eta)
pfa_v0 = 1.0
pfa_v0 *= pfaN
pfa_v1 = 0.0
pfa_v1 *= pfaN
pfa_v2 = 5.0 * (74.3 / 8.4 + 11.0 * eta) / 9.0
pfa_v2 *= pfaN
pfa_v3 = -16.0 * PI
pfa_v3 *= pfaN
pfa_v4 = (
5.0
* (3058.673 / 7.056 + 5429.0 / 7.0 * eta + 617.0 * eta * eta)
/ 72.0
)
pfa_v4 *= pfaN
pfa_v5logv = 5.0 / 3.0 * (772.9 / 8.4 - 13.0 * eta) * PI
pfa_v5logv *= pfaN
pfa_v5 = 5.0 / 9.0 * (772.9 / 8.4 - 13.0 * eta) * PI
pfa_v5 *= pfaN
pfa_v6logv = -684.8 / 2.1
pfa_v6 = (
11583.231236531 / 4.694215680
- 640.0 / 3.0 * PI * PI
- 684.8 / 2.1 * GAMMA
+ eta * (-15737.765635 / 3.048192 + 225.5 / 1.2 * PI * PI)
+ eta * eta * 76.055 / 1.728
- eta * eta * eta * 127.825 / 1.296
+ pfa_v6logv * torch.log(torch.tensor(4.0))
)
pfa_v6logv *= pfaN
pfa_v6 *= pfaN
pfa_v7 = PI * (
770.96675 / 2.54016 + 378.515 / 1.512 * eta - 740.45 / 7.56 * eta * eta
)
pfa_v7 *= pfaN
# construct power series
phasing = (v7.T * pfa_v7).T
phasing += (v6.T * pfa_v6 + v6_logv.T * pfa_v6logv).T
phasing += (v5.T * pfa_v5 + v5_logv.T * pfa_v5logv).T
phasing += (v4.T * pfa_v4).T
phasing += (v3.T * pfa_v3).T
phasing += (v2.T * pfa_v2).T
phasing += (v1.T * pfa_v1).T
phasing += (v0.T * pfa_v0).T
# Divide by 0PN v-dependence
phasing /= v5

return phasing


def taylorf2_amplitude(f: TensorType, mass1, mass2, distance) -> TensorType:
mass1_s = mass1 * MTSUN_SI
mass2_s = mass2 * MTSUN_SI
M_s = mass1_s + mass2_s
eta = mass1_s * mass2_s / M_s / M_s
Mf = (f.T * M_s).T
v = (PI * Mf) ** (1.0 / 3.0)
v10 = v**10

# Flux and energy coefficient at newtonian
FTaN = 32.0 * eta * eta / 5.0
dETaN = 2 * (-eta / 2.0)

amp0 = -4.0 * mass1_s * mass2_s * (PI / 12.0) ** 0.5

amp0 /= distance * MPC_SEC
flux = (v10.T * FTaN).T
dEnergy = (v.T * dETaN).T
amp = torch.sqrt(-dEnergy / flux) * v
amp = (amp.T * amp0).T

return amp


def taylorf2_htilde(f: TensorType, params: TensorType, f_ref: float):
mass1 = params[:, 0]
mass2 = params[:, 1]
distance = params[:, 2]
phic = params[:, 3]

# repeat freq across batch size
f = f.repeat([mass1.shape[0], 1])
f_ref = torch.tensor(f_ref).repeat([mass1.shape[0], 1])

Psi = taylorf2_phase(f, mass1, mass2)
Psi_ref = taylorf2_phase(f_ref, mass1, mass2)

Psi = (Psi.T - 2 * phic).T
Psi -= Psi_ref

amp0 = taylorf2_amplitude(f, mass1, mass2, distance)
h0 = amp0 * torch.exp(-1j * (Psi - PI / 4))
return h0


def TaylorF2(f: TensorType, params: TensorType, f_ref: float):
"""
TaylorF2 up to 3.5 PN in phase. SPA amplitude.
params = [mass1, mass2, chi1, chi2, D, phic, inclination]
Returns:
--------
hp, hc
"""
# shape assumed (n_batch, params)
# frequency array is repeated along batch
inclination = params[:, 4]
cfac = torch.cos(inclination)
pfac = 0.5 * (1.0 + cfac * cfac)

htilde = taylorf2_htilde(f, params, f_ref)

hp = (htilde.T * pfac).T
hc = -1j * (htilde.T * cfac).T

return hp, hc
96 changes: 96 additions & 0 deletions tests/waveforms/test_taylorf2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import lal
import lalsimulation
import numpy as np
import pytest
import torch
from astropy import units as u

import ml4gw.waveforms as waveforms


@pytest.fixture(params=[2048, 4096])
def sample_rate(request):
return request.param


@pytest.fixture(params=[20.0, 30.0, 40.0])
def mass_1(request):
return request.param


@pytest.fixture(params=[15.0, 25.0, 35.0])
def mass_2(request):
return request.param


@pytest.fixture(params=[100.0, 1000.0])
def distance(request):
return request.param


@pytest.fixture(params=[100.0, 1000.0])
def inclination(request):
return request.param


def test_taylor_f2(mass_1, mass_2, distance, inclination, sample_rate):
# Fix spins and coal. phase, ref, freq.
phic, f_ref = 0.0, 15
params = dict(
m1=mass_1 * lal.MSUN_SI,
m2=mass_2 * lal.MSUN_SI,
S1x=0,
S1y=0,
S1z=0,
S2x=0,
S2y=0,
S2z=0,
distance=(distance * u.Mpc).to("m").value,
inclination=inclination,
phiRef=phic,
longAscNodes=0.0,
eccentricity=0.0,
meanPerAno=0.0,
deltaF=1.0 / sample_rate,
f_min=10.0,
f_ref=f_ref,
f_max=100,
approximant=lalsimulation.TaylorF2,
LALpars=lal.CreateDict(),
)
hp_lal, hc_lal = lalsimulation.SimInspiralChooseFDWaveform(**params)
lal_freqs = np.array(
[hp_lal.f0 + ii * hp_lal.deltaF for ii in range(len(hp_lal.data.data))]
)

torch_freqs = torch.arange(
params["f_min"], params["f_max"], params["deltaF"]
)
_params = torch.tensor(
[mass_1, mass_2, distance, phic, inclination]
).repeat(
10, 1
) # repeat along batch dim for testing
hp_torch, hc_torch = waveforms.TaylorF2(torch_freqs, _params, f_ref)

assert hp_torch.shape[0] == 10 # entire batch is returned

# select only first element of the batch for further testing since
# all are repeated
hp_torch = hp_torch[0]
hc_torch = hc_torch[0]
# restrict between fmin and fmax
lal_mask = (lal_freqs > params["f_min"]) & (lal_freqs < params["f_max"])
torch_mask = (torch_freqs > params["f_min"]) & (
torch_freqs < params["f_max"]
)

hp_lal_data = hp_lal.data.data[lal_mask]
hc_lal_data = hc_lal.data.data[lal_mask]
hp_torch = hp_torch[torch_mask]
hc_torch = hc_torch[torch_mask]

assert np.allclose(hp_lal_data.real, hp_torch.real)
assert np.allclose(hp_lal_data.imag, hp_torch.imag)
assert np.allclose(hc_lal_data.real, hc_torch.real)
assert np.allclose(hc_lal_data.imag, hc_torch.imag)

0 comments on commit 00b7e99

Please sign in to comment.