Skip to content

Commit

Permalink
make tests more robust; fix bug in phenomd implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanMarx committed Dec 16, 2024
1 parent fdebbb5 commit 962424d
Show file tree
Hide file tree
Showing 4 changed files with 454 additions and 327 deletions.
45 changes: 25 additions & 20 deletions ml4gw/waveforms/cbc/phenom_d.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def phenom_d_htilde(
total_mass = chirp_mass * (1 + mass_ratio) ** 1.2 / mass_ratio**0.6
mass_1 = total_mass / (1 + mass_ratio)
mass_2 = mass_1 * mass_ratio

eta = (chirp_mass / total_mass) ** (5 / 3)
eta2 = eta * eta
Seta = torch.sqrt(1.0 - 4.0 * eta)
Expand Down Expand Up @@ -122,8 +123,6 @@ def phenom_d_htilde(

amp, _ = self.phenom_d_amp(
Mf,
mass_1,
mass_2,
eta,
eta2,
Seta,
Expand All @@ -132,7 +131,6 @@ def phenom_d_htilde(
chi12,
chi22,
xi,
distance,
)

amp_0 = self.taylorf2_amplitude(
Expand All @@ -146,8 +144,6 @@ def phenom_d_htilde(
def phenom_d_amp(
self,
Mf,
mass_1,
mass_2,
eta,
eta2,
Seta,
Expand All @@ -156,7 +152,6 @@ def phenom_d_amp(
chi12,
chi22,
xi,
distance,
):
ins_amp, ins_Damp = self.phenom_d_inspiral_amp(
Mf, eta, eta2, Seta, xi, chi1, chi2, chi12, chi22
Expand Down Expand Up @@ -213,6 +208,7 @@ def phenom_d_int_amp(
gamma3 = self.gamma3_fun(eta, eta2, xi)

fpeak = self.fmaxCalc(fRD, fDM, gamma2, gamma3)

Mf3 = (torch.ones_like(Mf).mT * fpeak).mT
dfx = 0.5 * (Mf3 - Mf1)
Mf2 = Mf1 + dfx
Expand All @@ -221,6 +217,7 @@ def phenom_d_int_amp(
Mf1, eta, eta2, Seta, xi, chi1, chi2, chi12, chi22
)
v3, d2 = self.phenom_d_mrd_amp(Mf3, eta, eta2, chi1, chi2, xi)

v2 = (
torch.ones_like(Mf).mT * self.AmpIntColFitCoeff(eta, eta2, xi)
).mT
Expand Down Expand Up @@ -248,7 +245,9 @@ def phenom_d_mrd_amp(self, Mf, eta, eta2, chi1, chi2, xi):
gamma3 = self.gamma3_fun(eta, eta2, xi)
fDMgamma3 = fDM * gamma3
pow2_fDMgamma3 = (torch.ones_like(Mf).mT * fDMgamma3 * fDMgamma3).mT

fminfRD = Mf - (torch.ones_like(Mf).mT * fRD).mT

exp_times_lorentzian = torch.exp(fminfRD.mT * gamma2 / fDMgamma3).mT
exp_times_lorentzian *= fminfRD**2 + pow2_fDMgamma3

Expand All @@ -257,6 +256,7 @@ def phenom_d_mrd_amp(self, Mf, eta, eta2, chi1, chi2, xi):
fminfRD * fminfRD + pow2_fDMgamma3
).mT - (gamma2 * gamma1)
Damp = Damp.mT / exp_times_lorentzian

return amp, Damp

def phenom_d_inspiral_amp(
Expand Down Expand Up @@ -580,7 +580,7 @@ def phenom_d_inspiral_phase(
return ins_phasing, ins_Dphasing

def fring_fdamp(self, eta, eta2, chi1, chi2):
finspin = self.FinalSpin0815(eta, eta2, chi1, chi2)
finspin = self.FinalSpin0815(eta, chi1, chi2)
Erad = self.PhenomInternal_EradRational0815(eta, eta2, chi1, chi2)

fRD, fDM = self._linear_interp_finspin(finspin)
Expand All @@ -590,18 +590,19 @@ def fring_fdamp(self, eta, eta2, chi1, chi2):
return fRD, fDM

def fmaxCalc(self, fRD, fDM, gamma2, gamma3):
res = torch.zeros_like(gamma2)
res = torch.abs(fRD + (-fDM * gamma3) / gamma2) * (gamma2 > 1).to(
torch.int
) + torch.abs(
fRD
+ (fDM * (-1 + torch.sqrt(1 - gamma2 * gamma2)) * gamma3) / gamma2
) * (
gamma2 <= 1
).to(
torch.int
)
return res
mask = gamma2 <= 1
# calculate result for gamma2 <= 1 case
sqrt_term = torch.sqrt(1 - gamma2.pow(2))
result_case1 = fRD + (fDM * (-1 + sqrt_term) * gamma3) / gamma2

# calculate result for gamma2 > 1 case
# i.e. don't add sqrt term
result_case2 = fRD + (-fDM * gamma3) / gamma2

# combine results using mask
result = torch.where(mask, result_case1, result_case2)

return torch.abs(result)

def _linear_interp_finspin(self, finspin):
# chi is a batch of final spins i.e. torch.Size([n])
Expand Down Expand Up @@ -1083,14 +1084,18 @@ def rho3_fun(self, eta, eta2, xi):
* xi
)

def FinalSpin0815(self, eta, eta2, chi1, chi2):
def FinalSpin0815(self, eta, chi1, chi2):
Seta = torch.sqrt(1.0 - 4.0 * eta)
Seta = torch.nan_to_num(Seta) # avoid nan around eta = 0.25
m1 = 0.5 * (1.0 + Seta)
m2 = 0.5 * (1.0 - Seta)
m1s = m1 * m1
m2s = m2 * m2
s = m1s * chi1 + m2s * chi2
return self.FinalSpin0815_s(eta, s)

def FinalSpin0815_s(self, eta, s):
eta2 = eta * eta
eta3 = eta2 * eta
s2 = s * s
s3 = s2 * s
Expand Down
9 changes: 6 additions & 3 deletions ml4gw/waveforms/cbc/taylorf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ml4gw.constants import MPC_SEC, MTSUN_SI, PI
from ml4gw.constants import EulerGamma as GAMMA
from ml4gw.types import BatchTensor, FrequencySeries1d
from ml4gw.waveforms.conversion import chirp_mass_and_mass_ratio_to_components


class TaylorF2(torch.nn.Module):
Expand Down Expand Up @@ -31,7 +32,7 @@ def forward(
chirp_mass:
Chirp mass in solar masses
mass_ratio:
Mass ratio m1/m2
Mass ratio m1 / m2, which by convention is <= 1.
chi1:
Spin of m1
chi2:
Expand Down Expand Up @@ -60,8 +61,10 @@ def forward(
or phic.shape[0] != inclination.shape[0]
):
raise RuntimeError("Tensors should have same batch size")
mass2 = chirp_mass * (1.0 + mass_ratio) ** 0.2 / mass_ratio**0.6
mass1 = mass_ratio * mass2

mass1, mass2 = chirp_mass_and_mass_ratio_to_components(
chirp_mass, mass_ratio
)
cfac = torch.cos(inclination)
pfac = 0.5 * (1.0 + cfac * cfac)

Expand Down
16 changes: 16 additions & 0 deletions ml4gw/waveforms/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ def XLALSimInspiralL_2PN(eta: BatchTensor):
return 1.5 + eta / 6.0


def chirp_mass_and_mass_ratio_to_components(
chirp_mass: BatchTensor, mass_ratio: BatchTensor
):
"""
Compute component masses from chirp mass and mass ratio.
Args:
chirp_mass: Tensor of chirp mass values
mass_ratio: Tensor of mass ratio values, `m1 / m2`, where m1 >= m2
"""
total_mass = chirp_mass * (1 + mass_ratio) ** 1.2 / mass_ratio**0.6
mass_1 = total_mass / (1 + mass_ratio)
mass_2 = mass_1 * mass_ratio
return mass_1, mass_2


def bilby_spins_to_lalsim(
theta_jn: BatchTensor,
phi_jl: BatchTensor,
Expand Down
Loading

0 comments on commit 962424d

Please sign in to comment.