Skip to content

Commit

Permalink
IMRPhenomPv2 fix (#176)
Browse files Browse the repository at this point in the history
* bug fix for xy spins

* run pre-commit

* pass fRD and fDM through to phenom_d phase and amp functions

* switch to relative imports

* use chirp mass to mass ratio conversion function

* create tests for phenom_p

* lower tolerance for phenom_p

* comments and additional checks for fRD and fDM
  • Loading branch information
ravioli1369 authored and EthanMarx committed Jan 4, 2025
1 parent 608d4a5 commit 4ceecbd
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 80 deletions.
86 changes: 71 additions & 15 deletions ml4gw/waveforms/cbc/phenom_d.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,20 +157,30 @@ def phenom_d_amp(
chi22,
xi,
distance,
fRD=None, # used for passing ringdown frequency from phenom_p
fDM=None, # used for passing damping frequency from phenom_p
):
ins_amp, ins_Damp = self.phenom_d_inspiral_amp(
Mf, eta, eta2, Seta, xi, chi1, chi2, chi12, chi22
)
int_amp, int_Damp = self.phenom_d_int_amp(
Mf, eta, eta2, Seta, chi1, chi2, chi12, chi22, xi
Mf, eta, eta2, Seta, chi1, chi2, chi12, chi22, xi, fRD, fDM
)
mrd_amp, mrd_Damp = self.phenom_d_mrd_amp(
Mf, eta, eta2, chi1, chi2, xi
Mf, eta, eta2, chi1, chi2, xi, fRD, fDM
)

gamma2 = self.gamma2_fun(eta, eta2, xi)
gamma3 = self.gamma3_fun(eta, eta2, xi)
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)

# merger ringdown
if (fRD is None) != (fDM is None):
raise ValueError(
"Both fRD and fDM must either be provided or both be None"
)
if (fRD is None) and (fDM is None):
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)

Mf_peak = self.fmaxCalc(fRD, fDM, gamma2, gamma3)
# Geometric peak and joining frequencies
Mf_peak = (torch.ones_like(Mf).mT * Mf_peak).mT
Expand Down Expand Up @@ -201,10 +211,27 @@ def phenom_d_amp(
return amp, Damp

def phenom_d_int_amp(
self, Mf, eta, eta2, Seta, chi1, chi2, chi12, chi22, xi
self,
Mf,
eta,
eta2,
Seta,
chi1,
chi2,
chi12,
chi22,
xi,
fRD=None, # used for passing ringdown frequency from phenom_p
fDM=None, # used for passing damping frequency from phenom_p
):
# merger ringdown
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)
if (fRD is None) != (fDM is None):
raise ValueError(
"Both fRD and fDM must either be provided or both be None"
)
if (fRD is None) and (fDM is None):
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)

# Geometric frequency definition from PhenomD header file
AMP_fJoin_INS = 0.014

Expand All @@ -220,7 +247,9 @@ def phenom_d_int_amp(
v1, d1 = self.phenom_d_inspiral_amp(
Mf1, eta, eta2, Seta, xi, chi1, chi2, chi12, chi22
)
v3, d2 = self.phenom_d_mrd_amp(Mf3, eta, eta2, chi1, chi2, xi)
v3, d2 = self.phenom_d_mrd_amp(
Mf3, eta, eta2, chi1, chi2, xi, fRD, fDM
)
v2 = (
torch.ones_like(Mf).mT * self.AmpIntColFitCoeff(eta, eta2, xi)
).mT
Expand All @@ -239,9 +268,18 @@ def phenom_d_int_amp(
)
return amp, Damp

def phenom_d_mrd_amp(self, Mf, eta, eta2, chi1, chi2, xi):
# fRD and fDM are to be passed for generating phenom_p waveforms
# and remain None for phenom_d
def phenom_d_mrd_amp(
self, Mf, eta, eta2, chi1, chi2, xi, fRD=None, fDM=None
):
# merger ringdown
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)
if (fRD is None) != (fDM is None):
raise ValueError(
"Both fRD and fDM must either be provided or both be None"
)
if (fRD is None) and (fDM is None):
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)

gamma1 = self.gamma1_fun(eta, eta2, xi)
gamma2 = self.gamma2_fun(eta, eta2, xi)
Expand Down Expand Up @@ -384,17 +422,26 @@ def phenom_d_inspiral_amp(

return amp, Damp

def phenom_d_phase(self, Mf, mass_1, mass_2, eta, eta2, chi1, chi2, xi):
# fRD and fDM are to be passed for generating phenom_p waveforms
# and remain None for phenom_d
def phenom_d_phase(
self, Mf, mass_1, mass_2, eta, eta2, chi1, chi2, xi, fRD=None, fDM=None
):
ins_phase, ins_Dphase = self.phenom_d_inspiral_phase(
Mf, mass_1, mass_2, eta, eta2, xi, chi1, chi2
)
int_phase, int_Dphase = self.phenom_d_int_phase(Mf, eta, eta2, xi)
mrd_phase, mrd_Dphase = self.phenom_d_mrd_phase(
Mf, eta, eta2, chi1, chi2, xi
Mf, eta, eta2, chi1, chi2, xi, fRD, fDM
)

# merger ringdown
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)
if (fRD is None) != (fDM is None):
raise ValueError(
"Both fRD and fDM must either be provided or both be None"
)
if (fRD is None) and (fDM is None):
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)
# definitions in Eq. (35) of arXiv:1508.07253
# PHI_fJoin_INS in header LALSimIMRPhenomD.h
# C1 continuity at intermediate region i.e. f_1
Expand All @@ -415,7 +462,7 @@ def phenom_d_phase(self, Mf, mass_1, mass_2, eta, eta2, chi1, chi2, xi):
fRDJoin, eta, eta2, xi
)
mrd_phase_rd, mrd_Dphase_rd = self.phenom_d_mrd_phase(
fRDJoin, eta, eta2, chi1, chi2, xi
fRDJoin, eta, eta2, chi1, chi2, xi, fRD, fDM
)
PhiIntTempVal = (int_phase_rd.mT / eta).mT + C1Int + C2Int * fRDJoin
# C2MRD = int_Dphase_rd - mrd_Dphase_rd
Expand Down Expand Up @@ -454,17 +501,26 @@ def phenom_d_phase(self, Mf, mass_1, mass_2, eta, eta2, chi1, chi2, xi):

return phasing, Dphasing

def phenom_d_mrd_phase(self, Mf, eta, eta2, chi1, chi2, xi):
# fRD and fDM are to be passed for generating phenom_p waveforms
# and remain None for phenom_d
def phenom_d_mrd_phase(
self, Mf, eta, eta2, chi1, chi2, xi, fRD=None, fDM=None
):
alpha1 = self.alpha1Fit(eta, eta2, xi)
alpha2 = self.alpha2Fit(eta, eta2, xi)
alpha3 = self.alpha3Fit(eta, eta2, xi)
alpha4 = self.alpha4Fit(eta, eta2, xi)
alpha5 = self.alpha5Fit(eta, eta2, xi)

# merger ringdown
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)
f_minus_alpha5_fRD = (Mf.t() - alpha5 * fRD).t()
if (fRD is None) != (fDM is None):
raise ValueError(
"Both fRD and fDM must either be provided or both be None"
)
if (fRD is None) and (fDM is None):
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)

f_minus_alpha5_fRD = (Mf.t() - alpha5 * fRD).t()
# Leading 1/eta is not multiplied at this stage
mrd_phasing = (Mf.t() * alpha1).t()
mrd_phasing -= (1 / Mf.t() * alpha2).t()
Expand Down
77 changes: 52 additions & 25 deletions ml4gw/waveforms/cbc/phenom_p.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
"""
Based on the JAX implementation of IMRPhenomPv2 from
https://github.com/tedwards2412/ripple/blob/main/src/ripplegw/waveforms/IMRPhenomPv2.py
"""

from typing import Dict, Optional, Tuple

import torch
from jaxtyping import Float
from torch import Tensor

from ml4gw.constants import MPC_SEC, MTSUN_SI, PI
from ml4gw.types import BatchTensor, FrequencySeries1d
from ml4gw.waveforms.conversion import rotate_y, rotate_z

from ...constants import MPC_SEC, MTSUN_SI, PI
from ...types import BatchTensor, FrequencySeries1d
from ..conversion import (
chirp_mass_and_mass_ratio_to_components,
rotate_y,
rotate_z,
)
from .phenom_d import IMRPhenomD


Expand Down Expand Up @@ -75,15 +83,15 @@ def forward(
if tc is None:
tc = torch.zeros_like(chirp_mass)

m2 = chirp_mass * (1.0 + mass_ratio) ** 0.2 / mass_ratio**0.6
m1 = m2 * mass_ratio
m1, m2 = chirp_mass_and_mass_ratio_to_components(
chirp_mass, mass_ratio
)

# # flip m1 m2. For some reason LAL uses this convention for PhenomPv2
m1, m2 = m2, m1
s1x, s2x = s2x, s1x
s1y, s2y = s2y, s1y
s1z, s2z = s2z, s1z

(
chi1_l,
chi2_l,
Expand Down Expand Up @@ -320,42 +328,61 @@ def PhenomPOneFrequency(
phic: Orbital phase at peak of the underlying non precessing model
M: Total mass (Solar masses)
"""

M_s = M * MTSUN_SI
Mf = torch.outer(M_s, fs)
fRD, _ = self.phP_get_fRD_fdamp(m1, m2, chi1, chi2, chip)
fRD, fDM = self.phP_get_fRD_fdamp(m1, m2, chi1, chi2, chip)
# pass M_s * ringdown and M_s * damping frequency to PhenomD functions
MfRD, MfDM = M_s * fRD, M_s * fDM

phase, _ = self.phenom_d_phase(Mf, m1, m2, eta, eta2, chi1, chi2, xi)
phase, _ = self.phenom_d_phase(
Mf, m1, m2, eta, eta2, chi1, chi2, xi, MfRD, MfDM
)
phase = (phase.mT - (phic + PI / 4.0)).mT
# why are they subtracting 2*phic?
# https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimIMRPhenomP.c#L1316

Amp = self.phenom_d_amp(
Mf, m1, m2, eta, eta2, Seta, chi1, chi2, chi12, chi22, xi, distance
Mf,
m1,
m2,
eta,
eta2,
Seta,
chi1,
chi2,
chi12,
chi22,
xi,
distance,
MfRD,
MfDM,
)[0]
Amp0 = self.get_Amp0(Mf, eta)
dist_s = distance * MPC_SEC
Amp = ((Amp0 * Amp).mT * (M_s**2.0) / dist_s).mT
# phase -= 2. * phic; # line 1316 ???

hPhenom = Amp * (torch.exp(-1j * phase))

fRDs = torch.outer(
fRD, torch.linspace(0.5, 1.5, 101, device=fRD.device)
)
delta_fRds = torch.median(torch.diff(fRDs, axis=1), axis=1)[0]
# calculating derivative of phase with frequency following
# https://git.ligo.org/lscsoft/lalsuite/-/blame/master/lalsimulation/lib/LALSimIMRPhenomP.c?page=2#L1057 # noqa: E501
n_fixed = 1000
x = torch.linspace(0.8, 1.2, n_fixed, device=fRD.device)
fRDs = torch.outer(fRD, x)
delta_fRds = (1.2 * fRD - 0.8 * fRD) / (n_fixed - 1)
MfRDs = torch.zeros_like(fRDs)
for i in range(fRD.shape[0]):
MfRDs[i, :] = torch.outer(M_s, fRDs[i, :])[i, :]
RD_phase = self.phenom_d_phase(
MfRDs, m1, m2, eta, eta2, chi1, chi2, xi
MfRDs, m1, m2, eta, eta2, chi1, chi2, xi, MfRD, MfDM
)[0]
diff = torch.diff(RD_phase, axis=1)
diffRDphase = (diff[:, 1:] + diff[:, :-1]) / (
2 * delta_fRds.unsqueeze(1)
)
diffRDphase = -diffRDphase[:, 50]
# MfRD = torch.outer(M_s, fRD)
# Dphase = torch.diag(
# -self.phenom_d_phase(
# MfRD, m1, m2, eta, eta2, chi1, chi2, xi)[1] * M_s
# ).view(-1, 1)
# interpolate at x = 1, as thats the same as f = fRD
diffRDphase = -self.interpolate(
torch.tensor([1]), x[1:-1], diffRDphase
)
return hPhenom, diffRDphase

# Utility functions
Expand Down Expand Up @@ -752,9 +779,9 @@ def FinalSpin_inplane(
eta2 = eta * eta
# m1 > m2, the convention used in phenomD
# (not the convention of internal phenomP)
mass_ratio = m1 / m2
q_factor = m1 / M
af_parallel = self.FinalSpin0815(eta, eta2, chi1_l, chi2_l)
Sperp = chip * mass_ratio * mass_ratio
Sperp = chip * q_factor * q_factor
af = torch.copysign(
torch.ones_like(af_parallel), af_parallel
) * torch.sqrt(Sperp * Sperp + af_parallel * af_parallel)
Expand Down
Loading

0 comments on commit 4ceecbd

Please sign in to comment.