Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
deepchatterjeeligo committed Dec 7, 2023
1 parent 447a20a commit d31a4a3
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 24 deletions.
52 changes: 29 additions & 23 deletions ml4gw/waveforms/taylorf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,16 @@ def taylorf2_phase(
# Phase coefficients from https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiralPNCoefficients.c # noqa E501
pfaN = 3.0 / (128.0 * eta)
pfa_v0 = 1.0
pfa_v0 *= pfaN
pfa_v1 = 0.0
pfa_v1 *= pfaN
pfa_v2 = 5.0 * (74.3 / 8.4 + 11.0 * eta) / 9.0
pfa_v2 *= pfaN
pfa_v3 = -16.0 * PI
pfa_v3 *= pfaN
pfa_v4 = (
5.0
* (3058.673 / 7.056 + 5429.0 / 7.0 * eta + 617.0 * eta * eta)
/ 72.0
)
pfa_v4 *= pfaN
pfa_v5logv = 5.0 / 3.0 * (772.9 / 8.4 - 13.0 * eta) * PI
pfa_v5logv *= pfaN
pfa_v5 = 5.0 / 9.0 * (772.9 / 8.4 - 13.0 * eta) * PI
pfa_v5 *= pfaN
pfa_v6logv = -684.8 / 2.1
pfa_v6 = (
11583.231236531 / 4.694215680
Expand All @@ -76,12 +69,9 @@ def taylorf2_phase(
- eta * eta * eta * 127.825 / 1.296
+ pfa_v6logv * torch.log(torch.tensor(4.0))
)
pfa_v6logv *= pfaN
pfa_v6 *= pfaN
pfa_v7 = PI * (
770.96675 / 2.54016 + 378.515 / 1.512 * eta - 740.45 / 7.56 * eta * eta
)
pfa_v7 *= pfaN
# construct power series
phasing = (v7.T * pfa_v7).T
phasing += (v6.T * pfa_v6 + v6_logv.T * pfa_v6logv).T
Expand All @@ -93,6 +83,8 @@ def taylorf2_phase(
phasing += (v0.T * pfa_v0).T
# Divide by 0PN v-dependence
phasing /= v5
# Multiply by 0PN coefficient
phasing = (phasing.T * pfaN).T

return phasing

Expand Down Expand Up @@ -121,13 +113,15 @@ def taylorf2_amplitude(f: TensorType, mass1, mass2, distance) -> TensorType:
return amp


def taylorf2_htilde(f: TensorType, params: TensorType, f_ref: float):
mass1 = params[:, 0]
mass2 = params[:, 1]
distance = params[:, 2]
phic = params[:, 3]

# repeat freq across batch size
def taylorf2_htilde(
f: TensorType,
mass1: TensorType,
mass2: TensorType,
distance: TensorType,
phic: TensorType,
f_ref: float,
):
# frequency array is repeated along batch
f = f.repeat([mass1.shape[0], 1])
f_ref = torch.tensor(f_ref).repeat([mass1.shape[0], 1])

Expand All @@ -142,22 +136,34 @@ def taylorf2_htilde(f: TensorType, params: TensorType, f_ref: float):
return h0


def TaylorF2(f: TensorType, params: TensorType, f_ref: float):
def TaylorF2(
f: TensorType,
mass1: TensorType,
mass2: TensorType,
distance: TensorType,
phic: TensorType,
inclination: TensorType,
f_ref: float,
):
"""
TaylorF2 up to 3.5 PN in phase. SPA amplitude.
params = [mass1, mass2, D, phic, inclination]
TaylorF2 up to 3.5 PN in phase. Newtonian SPA amplitude.
Returns:
--------
hp, hc
"""
# shape assumed (n_batch, params)
# frequency array is repeated along batch
inclination = params[:, 4]
if (
mass1.shape[0] != mass2.shape[0]
or mass2.shape[0] != distance.shape[0]
or distance.shape[0] != phic.shape[0]
or phic.shape[0] != inclination.shape[0]
):
raise RuntimeError("Tensors should have same batch size")
cfac = torch.cos(inclination)
pfac = 0.5 * (1.0 + cfac * cfac)

htilde = taylorf2_htilde(f, params, f_ref)
htilde = taylorf2_htilde(f, mass1, mass2, distance, phic, f_ref)

hp = (htilde.T * pfac).T
hc = -1j * (htilde.T * cfac).T
Expand Down
15 changes: 14 additions & 1 deletion tests/waveforms/test_taylorf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,20 @@ def test_taylor_f2(mass_1, mass_2, distance, inclination, sample_rate):
).repeat(
10, 1
) # repeat along batch dim for testing
hp_torch, hc_torch = waveforms.TaylorF2(torch_freqs, _params, f_ref)
batched_mass1 = _params[:, 0]
batched_mass2 = _params[:, 1]
batched_distance = _params[:, 2]
batched_phic = _params[:, 3]
batched_inclination = _params[:, 4]
hp_torch, hc_torch = waveforms.TaylorF2(
torch_freqs,
batched_mass1,
batched_mass2,
batched_distance,
batched_phic,
batched_inclination,
f_ref,
)

assert hp_torch.shape[0] == 10 # entire batch is returned

Expand Down

0 comments on commit d31a4a3

Please sign in to comment.