Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #178

Merged
merged 13 commits into from
Jan 9, 2025
2 changes: 1 addition & 1 deletion ml4gw/transforms/spline_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def bivariate_spline_fit_natural(self, Z):
return torch.linalg.solve(self.BxT_Bx, Z_Bx.mT).mT

# Adding batch/channel dimension handling
# ByT @ Z @ Bx
# ByT @ Z @ BxW
ByT_Z_Bx = torch.einsum("ij,bcik,kl->bcjl", self.By, Z, self.Bx)
# (ByT @ By)^-1 @ (ByT @ Z @ Bx) = By^-1 @ Z @ Bx
E = torch.linalg.solve(self.ByT_By, ByT_Z_Bx)
Expand Down
86 changes: 71 additions & 15 deletions ml4gw/waveforms/cbc/phenom_d.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,20 +157,30 @@ def phenom_d_amp(
chi22,
xi,
distance,
fRD=None, # used for passing ringdown frequency from phenom_p
fDM=None, # used for passing damping frequency from phenom_p
):
ins_amp, ins_Damp = self.phenom_d_inspiral_amp(
Mf, eta, eta2, Seta, xi, chi1, chi2, chi12, chi22
)
int_amp, int_Damp = self.phenom_d_int_amp(
Mf, eta, eta2, Seta, chi1, chi2, chi12, chi22, xi
Mf, eta, eta2, Seta, chi1, chi2, chi12, chi22, xi, fRD, fDM
)
mrd_amp, mrd_Damp = self.phenom_d_mrd_amp(
Mf, eta, eta2, chi1, chi2, xi
Mf, eta, eta2, chi1, chi2, xi, fRD, fDM
)

gamma2 = self.gamma2_fun(eta, eta2, xi)
gamma3 = self.gamma3_fun(eta, eta2, xi)
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)

# merger ringdown
if (fRD is None) != (fDM is None):
raise ValueError(
"Both fRD and fDM must either be provided or both be None"
)
if (fRD is None) and (fDM is None):
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)

Mf_peak = self.fmaxCalc(fRD, fDM, gamma2, gamma3)
# Geometric peak and joining frequencies
Mf_peak = (torch.ones_like(Mf).mT * Mf_peak).mT
Expand Down Expand Up @@ -201,10 +211,27 @@ def phenom_d_amp(
return amp, Damp

def phenom_d_int_amp(
self, Mf, eta, eta2, Seta, chi1, chi2, chi12, chi22, xi
self,
Mf,
eta,
eta2,
Seta,
chi1,
chi2,
chi12,
chi22,
xi,
fRD=None, # used for passing ringdown frequency from phenom_p
fDM=None, # used for passing damping frequency from phenom_p
):
# merger ringdown
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)
if (fRD is None) != (fDM is None):
raise ValueError(
"Both fRD and fDM must either be provided or both be None"
)
if (fRD is None) and (fDM is None):
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)

# Geometric frequency definition from PhenomD header file
AMP_fJoin_INS = 0.014

Expand All @@ -220,7 +247,9 @@ def phenom_d_int_amp(
v1, d1 = self.phenom_d_inspiral_amp(
Mf1, eta, eta2, Seta, xi, chi1, chi2, chi12, chi22
)
v3, d2 = self.phenom_d_mrd_amp(Mf3, eta, eta2, chi1, chi2, xi)
v3, d2 = self.phenom_d_mrd_amp(
Mf3, eta, eta2, chi1, chi2, xi, fRD, fDM
)
v2 = (
torch.ones_like(Mf).mT * self.AmpIntColFitCoeff(eta, eta2, xi)
).mT
Expand All @@ -239,9 +268,18 @@ def phenom_d_int_amp(
)
return amp, Damp

def phenom_d_mrd_amp(self, Mf, eta, eta2, chi1, chi2, xi):
# fRD and fDM are to be passed for generating phenom_p waveforms
# and remain None for phenom_d
def phenom_d_mrd_amp(
self, Mf, eta, eta2, chi1, chi2, xi, fRD=None, fDM=None
):
# merger ringdown
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)
if (fRD is None) != (fDM is None):
raise ValueError(
"Both fRD and fDM must either be provided or both be None"
)
if (fRD is None) and (fDM is None):
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)

gamma1 = self.gamma1_fun(eta, eta2, xi)
gamma2 = self.gamma2_fun(eta, eta2, xi)
Expand Down Expand Up @@ -384,17 +422,26 @@ def phenom_d_inspiral_amp(

return amp, Damp

def phenom_d_phase(self, Mf, mass_1, mass_2, eta, eta2, chi1, chi2, xi):
# fRD and fDM are to be passed for generating phenom_p waveforms
# and remain None for phenom_d
def phenom_d_phase(
self, Mf, mass_1, mass_2, eta, eta2, chi1, chi2, xi, fRD=None, fDM=None
):
ins_phase, ins_Dphase = self.phenom_d_inspiral_phase(
Mf, mass_1, mass_2, eta, eta2, xi, chi1, chi2
)
int_phase, int_Dphase = self.phenom_d_int_phase(Mf, eta, eta2, xi)
mrd_phase, mrd_Dphase = self.phenom_d_mrd_phase(
Mf, eta, eta2, chi1, chi2, xi
Mf, eta, eta2, chi1, chi2, xi, fRD, fDM
)

# merger ringdown
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)
if (fRD is None) != (fDM is None):
raise ValueError(
"Both fRD and fDM must either be provided or both be None"
)
if (fRD is None) and (fDM is None):
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)
# definitions in Eq. (35) of arXiv:1508.07253
# PHI_fJoin_INS in header LALSimIMRPhenomD.h
# C1 continuity at intermediate region i.e. f_1
Expand All @@ -415,7 +462,7 @@ def phenom_d_phase(self, Mf, mass_1, mass_2, eta, eta2, chi1, chi2, xi):
fRDJoin, eta, eta2, xi
)
mrd_phase_rd, mrd_Dphase_rd = self.phenom_d_mrd_phase(
fRDJoin, eta, eta2, chi1, chi2, xi
fRDJoin, eta, eta2, chi1, chi2, xi, fRD, fDM
)
PhiIntTempVal = (int_phase_rd.mT / eta).mT + C1Int + C2Int * fRDJoin
# C2MRD = int_Dphase_rd - mrd_Dphase_rd
Expand Down Expand Up @@ -454,17 +501,26 @@ def phenom_d_phase(self, Mf, mass_1, mass_2, eta, eta2, chi1, chi2, xi):

return phasing, Dphasing

def phenom_d_mrd_phase(self, Mf, eta, eta2, chi1, chi2, xi):
# fRD and fDM are to be passed for generating phenom_p waveforms
# and remain None for phenom_d
def phenom_d_mrd_phase(
self, Mf, eta, eta2, chi1, chi2, xi, fRD=None, fDM=None
):
alpha1 = self.alpha1Fit(eta, eta2, xi)
alpha2 = self.alpha2Fit(eta, eta2, xi)
alpha3 = self.alpha3Fit(eta, eta2, xi)
alpha4 = self.alpha4Fit(eta, eta2, xi)
alpha5 = self.alpha5Fit(eta, eta2, xi)

# merger ringdown
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)
f_minus_alpha5_fRD = (Mf.t() - alpha5 * fRD).t()
if (fRD is None) != (fDM is None):
raise ValueError(
"Both fRD and fDM must either be provided or both be None"
)
if (fRD is None) and (fDM is None):
fRD, fDM = self.fring_fdamp(eta, eta2, chi1, chi2)

f_minus_alpha5_fRD = (Mf.t() - alpha5 * fRD).t()
# Leading 1/eta is not multiplied at this stage
mrd_phasing = (Mf.t() * alpha1).t()
mrd_phasing -= (1 / Mf.t() * alpha2).t()
Expand Down
77 changes: 52 additions & 25 deletions ml4gw/waveforms/cbc/phenom_p.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
"""
Based on the JAX implementation of IMRPhenomPv2 from
https://github.com/tedwards2412/ripple/blob/main/src/ripplegw/waveforms/IMRPhenomPv2.py
"""

from typing import Dict, Optional, Tuple

import torch
from jaxtyping import Float
from torch import Tensor

from ml4gw.constants import MPC_SEC, MTSUN_SI, PI
from ml4gw.types import BatchTensor, FrequencySeries1d
from ml4gw.waveforms.conversion import rotate_y, rotate_z

from ...constants import MPC_SEC, MTSUN_SI, PI
from ...types import BatchTensor, FrequencySeries1d
from ..conversion import (
chirp_mass_and_mass_ratio_to_components,
rotate_y,
rotate_z,
)
from .phenom_d import IMRPhenomD


Expand Down Expand Up @@ -75,15 +83,15 @@ def forward(
if tc is None:
tc = torch.zeros_like(chirp_mass)

m2 = chirp_mass * (1.0 + mass_ratio) ** 0.2 / mass_ratio**0.6
m1 = m2 * mass_ratio
m1, m2 = chirp_mass_and_mass_ratio_to_components(
chirp_mass, mass_ratio
)

# # flip m1 m2. For some reason LAL uses this convention for PhenomPv2
m1, m2 = m2, m1
s1x, s2x = s2x, s1x
s1y, s2y = s2y, s1y
s1z, s2z = s2z, s1z

(
chi1_l,
chi2_l,
Expand Down Expand Up @@ -320,42 +328,61 @@ def PhenomPOneFrequency(
phic: Orbital phase at peak of the underlying non precessing model
M: Total mass (Solar masses)
"""

M_s = M * MTSUN_SI
Mf = torch.outer(M_s, fs)
fRD, _ = self.phP_get_fRD_fdamp(m1, m2, chi1, chi2, chip)
fRD, fDM = self.phP_get_fRD_fdamp(m1, m2, chi1, chi2, chip)
# pass M_s * ringdown and M_s * damping frequency to PhenomD functions
MfRD, MfDM = M_s * fRD, M_s * fDM

phase, _ = self.phenom_d_phase(Mf, m1, m2, eta, eta2, chi1, chi2, xi)
phase, _ = self.phenom_d_phase(
Mf, m1, m2, eta, eta2, chi1, chi2, xi, MfRD, MfDM
)
phase = (phase.mT - (phic + PI / 4.0)).mT
# why are they subtracting 2*phic?
# https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimIMRPhenomP.c#L1316

Amp = self.phenom_d_amp(
Mf, m1, m2, eta, eta2, Seta, chi1, chi2, chi12, chi22, xi, distance
Mf,
m1,
m2,
eta,
eta2,
Seta,
chi1,
chi2,
chi12,
chi22,
xi,
distance,
MfRD,
MfDM,
)[0]
Amp0 = self.get_Amp0(Mf, eta)
dist_s = distance * MPC_SEC
Amp = ((Amp0 * Amp).mT * (M_s**2.0) / dist_s).mT
# phase -= 2. * phic; # line 1316 ???

hPhenom = Amp * (torch.exp(-1j * phase))

fRDs = torch.outer(
fRD, torch.linspace(0.5, 1.5, 101, device=fRD.device)
)
delta_fRds = torch.median(torch.diff(fRDs, axis=1), axis=1)[0]
# calculating derivative of phase with frequency following
# https://git.ligo.org/lscsoft/lalsuite/-/blame/master/lalsimulation/lib/LALSimIMRPhenomP.c?page=2#L1057 # noqa: E501
n_fixed = 1000
x = torch.linspace(0.8, 1.2, n_fixed, device=fRD.device)
fRDs = torch.outer(fRD, x)
delta_fRds = (1.2 * fRD - 0.8 * fRD) / (n_fixed - 1)
MfRDs = torch.zeros_like(fRDs)
for i in range(fRD.shape[0]):
MfRDs[i, :] = torch.outer(M_s, fRDs[i, :])[i, :]
RD_phase = self.phenom_d_phase(
MfRDs, m1, m2, eta, eta2, chi1, chi2, xi
MfRDs, m1, m2, eta, eta2, chi1, chi2, xi, MfRD, MfDM
)[0]
diff = torch.diff(RD_phase, axis=1)
diffRDphase = (diff[:, 1:] + diff[:, :-1]) / (
2 * delta_fRds.unsqueeze(1)
)
diffRDphase = -diffRDphase[:, 50]
# MfRD = torch.outer(M_s, fRD)
# Dphase = torch.diag(
# -self.phenom_d_phase(
# MfRD, m1, m2, eta, eta2, chi1, chi2, xi)[1] * M_s
# ).view(-1, 1)
# reshape x to have same shape as diffRDphase
x = x[1:-1].unsqueeze(0).expand(diffRDphase.shape)
# interpolate at x = 1, as thats the same as f = fRD
diffRDphase = -self.interpolate(torch.tensor([1]), x, diffRDphase)
return hPhenom, diffRDphase

# Utility functions
Expand Down Expand Up @@ -752,9 +779,9 @@ def FinalSpin_inplane(
eta2 = eta * eta
# m1 > m2, the convention used in phenomD
# (not the convention of internal phenomP)
mass_ratio = m1 / m2
q_factor = m1 / M
af_parallel = self.FinalSpin0815(eta, eta2, chi1_l, chi2_l)
Sperp = chip * mass_ratio * mass_ratio
Sperp = chip * q_factor * q_factor
af = torch.copysign(
torch.ones_like(af_parallel), af_parallel
) * torch.sqrt(Sperp * Sperp + af_parallel * af_parallel)
Expand Down
Loading
Loading