From 4edecb5cfde30569b00a06885be962f2d719122f Mon Sep 17 00:00:00 2001 From: Ethan Marx <61295922+EthanMarx@users.noreply.github.com> Date: Mon, 16 Dec 2024 16:19:43 -0500 Subject: [PATCH] Fix bug in `IMRPhenomD` fmax calc (#172) * fix nan issue in fmax calc * update tests to use wider distribution * remove debug save --- ml4gw/waveforms/cbc/phenom_d.py | 25 +- tests/waveforms/cbc/test_cbc_waveforms.py | 677 ++++++++++++---------- 2 files changed, 384 insertions(+), 318 deletions(-) diff --git a/ml4gw/waveforms/cbc/phenom_d.py b/ml4gw/waveforms/cbc/phenom_d.py index 6630e925..9a883825 100644 --- a/ml4gw/waveforms/cbc/phenom_d.py +++ b/ml4gw/waveforms/cbc/phenom_d.py @@ -590,18 +590,19 @@ def fring_fdamp(self, eta, eta2, chi1, chi2): return fRD, fDM def fmaxCalc(self, fRD, fDM, gamma2, gamma3): - res = torch.zeros_like(gamma2) - res = torch.abs(fRD + (-fDM * gamma3) / gamma2) * (gamma2 > 1).to( - torch.int - ) + torch.abs( - fRD - + (fDM * (-1 + torch.sqrt(1 - gamma2 * gamma2)) * gamma3) / gamma2 - ) * ( - gamma2 <= 1 - ).to( - torch.int - ) - return res + mask = gamma2 <= 1 + # calculate result for gamma2 <= 1 case + sqrt_term = torch.sqrt(1 - gamma2.pow(2)) + result_case1 = fRD + (fDM * (-1 + sqrt_term) * gamma3) / gamma2 + + # calculate result for gamma2 > 1 case + # i.e. don't add sqrt term + result_case2 = fRD + (-fDM * gamma3) / gamma2 + + # combine results using mask + result = torch.where(mask, result_case1, result_case2) + + return torch.abs(result) def _linear_interp_finspin(self, finspin): # chi is a batch of final spins i.e. torch.Size([n]) diff --git a/tests/waveforms/cbc/test_cbc_waveforms.py b/tests/waveforms/cbc/test_cbc_waveforms.py index 2d28b04d..b2ed6cad 100644 --- a/tests/waveforms/cbc/test_cbc_waveforms.py +++ b/tests/waveforms/cbc/test_cbc_waveforms.py @@ -4,352 +4,417 @@ import pytest import torch from astropy import units as u +from torch.distributions import Uniform import ml4gw.waveforms as waveforms -@pytest.fixture(params=[128, 256]) +@pytest.fixture(params=[256, 1024, 2048]) def sample_rate(request): return request.param -@pytest.fixture(params=[20.0, 30.0, 40.0]) -def mass_1(request): - return request.param +@pytest.fixture() +def chirp_mass(request): + dist = Uniform(5, 100) + return dist.sample((100,)) -@pytest.fixture(params=[15.0, 25.0, 35.0]) -def mass_2(request): - return request.param +@pytest.fixture() +def mass_ratio(): + dist = Uniform(0.125, 0.99) + return dist.sample((100,)) -@pytest.fixture(params=[15.0, 30.0]) -def chirp_mass(request): - return request.param +@pytest.fixture() +def a_1(request): + dist = Uniform(0, 0.90) + return dist.sample((100,)) -@pytest.fixture(params=[0.99, 0.5]) -def mass_ratio(request): - return request.param +@pytest.fixture() +def a_2(request): + dist = Uniform(0, 0.90) + return dist.sample((100,)) -@pytest.fixture(params=[0.0, 0.5]) -def chi1z(request): - return request.param +@pytest.fixture() +def tilt_1(request): + dist = Uniform(0, torch.pi) + return dist.sample((100,)) -@pytest.fixture(params=[-0.1, 0.1]) -def chi2z(request): - return request.param +@pytest.fixture() +def tilt_2(request): + dist = Uniform(0, torch.pi) + return dist.sample((100,)) + + +@pytest.fixture() +def phi_12(request): + dist = Uniform(0, 2 * torch.pi) + return dist.sample((100,)) -@pytest.fixture(params=[100.0, 1000.0]) +@pytest.fixture() +def phi_jl(request): + dist = Uniform(0, 2 * torch.pi) + return dist.sample((100,)) + + +@pytest.fixture() def distance(request): - return request.param + dist = Uniform(100, 3000) + return dist.sample((100,)) -@pytest.fixture(params=[100.0, 1000.0]) -def inclination(request): +@pytest.fixture() +def theta_jn(request): + dist = Uniform(0, torch.pi) + return dist.sample((100,)) + + +@pytest.fixture() +def phase(request): + dist = Uniform(0, 2 * torch.pi) + return dist.sample((100,)) + + +@pytest.fixture() +def chi_1(request): + dist = Uniform(-0.999, 0.999) + return dist.sample((100,)) + + +@pytest.fixture() +def chi_2(request): + dist = Uniform(-0.999, 0.999) + return dist.sample((100,)) + + +@pytest.fixture(params=[20, 40]) +def f_ref(request): return request.param def test_taylor_f2( - chirp_mass, mass_ratio, chi1z, chi2z, distance, inclination, sample_rate + chirp_mass, + mass_ratio, + chi_1, + chi_2, + phase, + distance, + f_ref, + theta_jn, + sample_rate, ): mass_2 = chirp_mass * (1 + mass_ratio) ** 0.2 / mass_ratio**0.6 mass_1 = mass_ratio * mass_2 - # Fix coal. phase, ref, freq. - phic, f_ref = 0.0, 25 - params = dict( - m1=mass_1 * lal.MSUN_SI, - m2=mass_2 * lal.MSUN_SI, - S1x=0, - S1y=0, - S1z=chi1z, - S2x=0, - S2y=0, - S2z=chi2z, - distance=(distance * u.Mpc).to("m").value, - inclination=inclination, - phiRef=phic, - longAscNodes=0.0, - eccentricity=0.0, - meanPerAno=0.0, - deltaF=1.0 / sample_rate, - f_min=10.0, - f_ref=f_ref, - f_max=100, - approximant=lalsimulation.TaylorF2, - LALpars=lal.CreateDict(), - ) - hp_lal, hc_lal = lalsimulation.SimInspiralChooseFDWaveform(**params) - lal_freqs = np.array( - [hp_lal.f0 + ii * hp_lal.deltaF for ii in range(len(hp_lal.data.data))] - ) - - torch_freqs = torch.arange( - params["f_min"], params["f_max"], params["deltaF"] - ) - _params = torch.tensor( - [chirp_mass, mass_ratio, chi1z, chi2z, distance, phic, inclination] - ).repeat( - 10, 1 - ) # repeat along batch dim for testing - batched_chirp_mass = _params[:, 0] - batched_mass_ratio = _params[:, 1] - batched_chi1 = _params[:, 2] - batched_chi2 = _params[:, 3] - batched_distance = _params[:, 4] - batched_phic = _params[:, 5] - batched_inclination = _params[:, 6] - hc_torch, hp_torch = waveforms.TaylorF2()( - torch_freqs, - batched_chirp_mass, - batched_mass_ratio, - batched_chi1, - batched_chi2, - batched_distance, - batched_phic, - batched_inclination, - f_ref, - ) - - assert hp_torch.shape[0] == 10 # entire batch is returned - - # select only first element of the batch for further testing since - # all are repeated - hp_torch = hp_torch[0] - hc_torch = hc_torch[0] - # restrict between fmin and fmax - lal_mask = (lal_freqs > params["f_min"]) & (lal_freqs < params["f_max"]) - torch_mask = (torch_freqs > params["f_min"]) & ( - torch_freqs < params["f_max"] - ) - - hp_lal_data = hp_lal.data.data[lal_mask] - hc_lal_data = hc_lal.data.data[lal_mask] - hp_torch = hp_torch[torch_mask] - hc_torch = hc_torch[torch_mask] - - assert np.allclose( - 1e21 * hp_lal_data.real, 1e21 * hp_torch.real.numpy(), atol=1e-3 - ) - assert np.allclose( - 1e21 * hp_lal_data.imag, 1e21 * hp_torch.imag.numpy(), atol=1e-3 - ) - assert np.allclose( - 1e21 * hc_lal_data.real, 1e21 * hc_torch.real.numpy(), atol=1e-3 - ) - assert np.allclose( - 1e21 * hc_lal_data.imag, 1e21 * hc_torch.imag.numpy(), atol=1e-3 - ) + + # compare each waveform with lalsimulation + for i in range(len(chirp_mass)): + + # construct lalinference params + params = dict( + m1=mass_1[i].item() * lal.MSUN_SI, + m2=mass_2[i].item() * lal.MSUN_SI, + S1x=0, + S1y=0, + S1z=chi_1[i].item(), + S2x=0, + S2y=0, + S2z=chi_2[i].item(), + distance=(distance[i].item() * u.Mpc).to("m").value, + inclination=theta_jn[i].item(), + phiRef=phase[i].item(), + longAscNodes=0.0, + eccentricity=0.0, + meanPerAno=0.0, + deltaF=1.0 / sample_rate, + f_min=20, + f_ref=f_ref, + f_max=300, + approximant=lalsimulation.TaylorF2, + LALpars=lal.CreateDict(), + ) + hp_lal, hc_lal = lalsimulation.SimInspiralChooseFDWaveform(**params) + + # reconstruct frequencies generated by + # lal and filter based on fmin and fmax + lal_freqs = np.array( + [ + hp_lal.f0 + ii * hp_lal.deltaF + for ii in range(len(hp_lal.data.data)) + ] + ) + + lal_mask = (lal_freqs > params["f_min"]) & ( + lal_freqs < params["f_max"] + ) + + lal_freqs = lal_freqs[lal_mask] + torch_freqs = torch.tensor(lal_freqs, dtype=torch.float64) + + # generate waveforms using ml4gw + hc_ml4gw, hp_ml4gw = waveforms.TaylorF2()( + torch_freqs, + chirp_mass[i][None], + mass_ratio[i][None], + chi_1[i][None], + chi_2[i][None], + distance[i][None], + phase[i][None], + theta_jn[i][None], + f_ref, + ) + + hc_ml4gw = hc_ml4gw[0] + hp_ml4gw = hp_ml4gw[0] + + hp_lal_data = hp_lal.data.data[lal_mask] + hc_lal_data = hc_lal.data.data[lal_mask] + + # ensure no nans + assert not torch.any(torch.isnan(hc_ml4gw)) + assert not torch.any(torch.isnan(hp_ml4gw)) + + assert np.allclose( + 1e21 * hp_lal_data.real, 1e21 * hp_ml4gw.real.numpy(), atol=1e-3 + ) + assert np.allclose( + 1e21 * hp_lal_data.imag, 1e21 * hp_ml4gw.imag.numpy(), atol=1e-3 + ) + assert np.allclose( + 1e21 * hc_lal_data.real, 1e21 * hc_ml4gw.real.numpy(), atol=1e-3 + ) + assert np.allclose( + 1e21 * hc_lal_data.imag, 1e21 * hc_ml4gw.imag.numpy(), atol=1e-3 + ) + + # taylor f2 is symmetric w.r.t m1 --> m2 flip. + # so test that the waveforms are the same when m1 and m2 + # (and corresponding chi_1, chi_2 are flipped) + # are flipped this can be done by flipping mass ratio + hc_ml4gw, hp_ml4gw = waveforms.TaylorF2()( + torch_freqs, + chirp_mass[i][None], + 1 / mass_ratio[i][None], + chi_2[i][None], + chi_1[i][None], + distance[i][None], + phase[i][None], + theta_jn[i][None], + f_ref, + ) + + hc_ml4gw = hc_ml4gw[0] + hp_ml4gw = hp_ml4gw[0] + + assert np.allclose( + 1e21 * hp_lal_data.real, 1e21 * hp_ml4gw.real.numpy(), atol=1e-3 + ) + assert np.allclose( + 1e21 * hp_lal_data.imag, 1e21 * hp_ml4gw.imag.numpy(), atol=1e-3 + ) + assert np.allclose( + 1e21 * hc_lal_data.real, 1e21 * hc_ml4gw.real.numpy(), atol=1e-3 + ) + assert np.allclose( + 1e21 * hc_lal_data.imag, 1e21 * hc_ml4gw.imag.numpy(), atol=1e-3 + ) def test_phenom_d( - chirp_mass, mass_ratio, chi1z, chi2z, distance, inclination, sample_rate + chirp_mass, + mass_ratio, + chi_1, + chi_2, + distance, + phase, + theta_jn, + sample_rate, + f_ref, ): total_mass = chirp_mass * (1 + mass_ratio) ** 1.2 / mass_ratio**0.6 mass_1 = total_mass / (1 + mass_ratio) mass_2 = mass_1 * mass_ratio - phic, f_ref = 0.0, 25 - - params = dict( - m1=mass_1 * lal.MSUN_SI, - m2=mass_2 * lal.MSUN_SI, - S1x=0, - S1y=0, - S1z=chi1z, - S2x=0, - S2y=0, - S2z=chi2z, - distance=(distance * u.Mpc).to("m").value, - inclination=inclination, - phiRef=phic, - longAscNodes=0.0, - eccentricity=0.0, - meanPerAno=0.0, - deltaF=1.0 / sample_rate, - f_min=10.0, - f_ref=f_ref, - f_max=300, - approximant=lalsimulation.IMRPhenomD, - LALpars=lal.CreateDict(), - ) - hp_lal, hc_lal = lalsimulation.SimInspiralChooseFDWaveform(**params) - lal_freqs = np.array( - [hp_lal.f0 + ii * hp_lal.deltaF for ii in range(len(hp_lal.data.data))] - ) - - torch_freqs = torch.arange( - params["f_min"], params["f_max"], params["deltaF"] - ) - _params = torch.tensor( - [chirp_mass, mass_ratio, chi1z, chi2z, distance, phic, inclination] - ).repeat( - 10, 1 - ) # repeat along batch dim for testing - batched_chirp_mass = _params[:, 0] - batched_mass_ratio = _params[:, 1] - batched_chi1 = _params[:, 2] - batched_chi2 = _params[:, 3] - batched_distance = _params[:, 4] - batched_phic = _params[:, 5] - batched_inclination = _params[:, 6] - hc_torch, hp_torch = waveforms.IMRPhenomD()( - torch_freqs, - batched_chirp_mass, - batched_mass_ratio, - batched_chi1, - batched_chi2, - batched_distance, - batched_phic, - batched_inclination, - f_ref, - ) - - assert hp_torch.shape[0] == 10 # entire batch is returned - - # select only first element of the batch for further testing since - # all are repeated - hp_torch = hp_torch[0] - hc_torch = hc_torch[0] - # restrict between fmin and fmax - lal_mask = (lal_freqs > params["f_min"]) & (lal_freqs < params["f_max"]) - torch_mask = (torch_freqs > params["f_min"]) & ( - torch_freqs < params["f_max"] - ) - - hp_lal_data = hp_lal.data.data[lal_mask] - hc_lal_data = hc_lal.data.data[lal_mask] - hp_torch = hp_torch[torch_mask] - hc_torch = hc_torch[torch_mask] - - assert np.allclose( - 1e21 * hp_lal_data.real, 1e21 * hp_torch.real.numpy(), atol=2e-4 - ) - assert np.allclose( - 1e21 * hp_lal_data.imag, 1e21 * hp_torch.imag.numpy(), atol=2e-4 - ) - assert np.allclose( - 1e21 * hc_lal_data.real, 1e21 * hc_torch.real.numpy(), atol=2e-4 - ) - assert np.allclose( - 1e21 * hc_lal_data.imag, 1e21 * hc_torch.imag.numpy(), atol=2e-4 - ) - - -def test_phenom_p(chirp_mass, mass_ratio, chi1z, chi2z, distance, sample_rate): + + # compare each waveform with lalsimulation + for i in range(len(chirp_mass)): + + # construct lalinference params + params = dict( + m1=mass_1[i].item() * lal.MSUN_SI, + m2=mass_2[i].item() * lal.MSUN_SI, + S1x=0, + S1y=0, + S1z=chi_1[i].item(), + S2x=0, + S2y=0, + S2z=chi_2[i].item(), + distance=(distance[i].item() * u.Mpc).to("m").value, + inclination=theta_jn[i].item(), + phiRef=phase[i].item(), + longAscNodes=0.0, + eccentricity=0.0, + meanPerAno=0.0, + deltaF=1.0 / sample_rate, + f_min=20, + f_ref=f_ref, + f_max=300, + approximant=lalsimulation.IMRPhenomD, + LALpars=lal.CreateDict(), + ) + hp_lal, hc_lal = lalsimulation.SimInspiralChooseFDWaveform(**params) + + # reconstruct frequencies generated by + # lal and filter based on fmin and fmax + lal_freqs = np.array( + [ + hp_lal.f0 + ii * hp_lal.deltaF + for ii in range(len(hp_lal.data.data)) + ] + ) + + lal_mask = (lal_freqs > params["f_min"]) & ( + lal_freqs < params["f_max"] + ) + + lal_freqs = lal_freqs[lal_mask] + torch_freqs = torch.tensor(lal_freqs, dtype=torch.float32) + + # generate waveforms using ml4gw + hc_ml4gw, hp_ml4gw = waveforms.IMRPhenomD()( + torch_freqs, + chirp_mass[i][None], + mass_ratio[i][None], + chi_1[i][None], + chi_2[i][None], + distance[i][None], + phase[i][None], + theta_jn[i][None], + f_ref, + ) + + hc_ml4gw = hc_ml4gw[0] + + hp_ml4gw = hp_ml4gw[0] + + hp_lal_data = hp_lal.data.data[lal_mask] + hc_lal_data = hc_lal.data.data[lal_mask] + + assert not torch.any(torch.isnan(hc_ml4gw)) + assert not torch.any(torch.isnan(hp_ml4gw)) + + assert np.allclose( + 1e21 * hp_lal_data.real, 1e21 * hp_ml4gw.real.numpy(), atol=1e-3 + ) + assert np.allclose( + 1e21 * hp_lal_data.imag, 1e21 * hp_ml4gw.imag.numpy(), atol=1e-3 + ) + assert np.allclose( + 1e21 * hc_lal_data.real, 1e21 * hc_ml4gw.real.numpy(), atol=1e-3 + ) + assert np.allclose( + 1e21 * hc_lal_data.imag, 1e21 * hc_ml4gw.imag.numpy(), atol=1e-3 + ) + + +def test_phenom_p(): + chirp_mass = torch.tensor([15.0, 30.0]) + mass_ratio = torch.tensor([0.99, 0.5]) + chi1z = torch.tensor([0.0, 0.5]) + chi2z = torch.tensor([-0.1, 0.1]) + distance = torch.tensor([100.0, 1000.0]) + sample_rate = 2048 + mass_2 = chirp_mass * (1 + mass_ratio) ** 0.2 / mass_ratio**0.6 mass_1 = mass_2 * mass_ratio - if mass_2 > mass_1: - mass_1, mass_2 = mass_2, mass_1 - mass_ratio = 1 / mass_ratio + f_ref = 20.0 phic = 0.0 tc = 0.0 inclination = 0.0 - params = dict( - m1=mass_1 * lal.MSUN_SI, - m2=mass_2 * lal.MSUN_SI, - S1x=0, - S1y=0, - S1z=chi1z, - S2x=0, - S2y=0, - S2z=chi2z, - distance=(distance * u.Mpc).to("m").value, - inclination=inclination, - phiRef=phic, - longAscNodes=0.0, - eccentricity=0.0, - meanPerAno=0.0, - deltaF=1.0 / sample_rate, - f_min=10.0, - f_ref=f_ref, - f_max=300, - approximant=lalsimulation.IMRPhenomPv2, - LALpars=lal.CreateDict(), - ) - hp_lal, hc_lal = lalsimulation.SimInspiralChooseFDWaveform(**params) - lal_freqs = np.array( - [hp_lal.f0 + ii * hp_lal.deltaF for ii in range(len(hp_lal.data.data))] - ) - - torch_freqs = torch.arange( - params["f_min"], params["f_max"], params["deltaF"] - ) - _params = torch.tensor( - [ - chirp_mass, - mass_ratio, - 0, - 0, - chi1z, - 0, - 0, - chi2z, - distance, - tc, - phic, - inclination, - ] - ).repeat(10, 1) - # repeat along batch dim for testing - batched_chirp_mass = _params[:, 0] - batched_mass_ratio = _params[:, 1] - batched_chi1x = _params[:, 2] - batched_chi1y = _params[:, 3] - batched_chi1z = _params[:, 4] - batched_chi2x = _params[:, 5] - batched_chi2y = _params[:, 6] - batched_chi2z = _params[:, 7] - batched_distance = _params[:, 8] - batched_tc = _params[:, 9] - batched_phic = _params[:, 10] - batched_inclination = _params[:, 11] - hc_torch, hp_torch = waveforms.IMRPhenomPv2()( - torch_freqs, - batched_chirp_mass, - batched_mass_ratio, - batched_chi1x, - batched_chi1y, - batched_chi1z, - batched_chi2x, - batched_chi2y, - batched_chi2z, - batched_distance, - batched_phic, - batched_inclination, - f_ref, - batched_tc, - ) - - assert hp_torch.shape[0] == 10 # entire batch is returned - - # select only first element of the batch for further testing since - # all are repeated - hp_torch = hp_torch[0] - hc_torch = hc_torch[0] - # restrict between fmin and fmax - lal_mask = (lal_freqs > params["f_min"]) & (lal_freqs < params["f_max"]) - torch_mask = (torch_freqs > params["f_min"]) & ( - torch_freqs < params["f_max"] - ) - - hp_lal_data = hp_lal.data.data[lal_mask] - hc_lal_data = hc_lal.data.data[lal_mask] - hp_torch = hp_torch[torch_mask] - hc_torch = hc_torch[torch_mask] - - assert np.allclose( - 1e21 * hp_lal_data.real, 1e21 * hp_torch.real.numpy(), atol=2e-3 - ) - assert np.allclose( - 1e21 * hp_lal_data.imag, 1e21 * hp_torch.imag.numpy(), atol=2e-3 - ) - assert np.allclose( - 1e21 * hc_lal_data.real, 1e21 * hc_torch.real.numpy(), atol=2e-3 - ) - assert np.allclose( - 1e21 * hc_lal_data.imag, 1e21 * hc_torch.imag.numpy(), atol=2e-3 - ) + + for i in range(chirp_mass.shape[0]): + m1, m2 = mass_1[i], mass_2[i] + mr = mass_ratio[i] + if m2 > m1: + m1, m2 = m2, m1 + mr = 1 / mr + + params = dict( + m1=m1.item() * lal.MSUN_SI, + m2=m2.item() * lal.MSUN_SI, + S1x=0, + S1y=0, + S1z=chi1z[i].item(), + S2x=0, + S2y=0, + S2z=chi2z[i].item(), + distance=(distance[i].item() * u.Mpc).to("m").value, + inclination=inclination, + phiRef=phic, + longAscNodes=0.0, + eccentricity=0.0, + meanPerAno=0.0, + deltaF=1.0 / sample_rate, + f_min=10.0, + f_ref=f_ref, + f_max=300, + approximant=lalsimulation.IMRPhenomPv2, + LALpars=lal.CreateDict(), + ) + hp_lal, hc_lal = lalsimulation.SimInspiralChooseFDWaveform(**params) + + # reconstruct frequencies generated by + # lal and filter based on fmin and fmax + lal_freqs = np.array( + [ + hp_lal.f0 + ii * hp_lal.deltaF + for ii in range(len(hp_lal.data.data)) + ] + ) + + lal_mask = (lal_freqs > params["f_min"]) & ( + lal_freqs < params["f_max"] + ) + + lal_freqs = lal_freqs[lal_mask] + torch_freqs = torch.tensor(lal_freqs, dtype=torch.float32) + + hc_ml4gw, hp_ml4gw = waveforms.IMRPhenomPv2()( + torch_freqs, + chirp_mass[i][None], + torch.tensor([mr]), + torch.tensor([0.0]), + torch.tensor([0.0]), + chi1z[i][None], + torch.tensor([0.0]), + torch.tensor([0.0]), + chi2z[i][None], + distance[i][None], + torch.tensor([phic]), + torch.tensor([inclination]), + f_ref, + torch.tensor([tc]), + ) + + hp_ml4gw = hp_ml4gw[0] + hc_ml4gw = hc_ml4gw[0] + + hp_lal_data = hp_lal.data.data[lal_mask] + hc_lal_data = hc_lal.data.data[lal_mask] + + assert np.allclose( + 1e21 * hp_lal_data.real, 1e21 * hp_ml4gw.real.numpy(), atol=2e-3 + ) + assert np.allclose( + 1e21 * hp_lal_data.imag, 1e21 * hp_ml4gw.imag.numpy(), atol=2e-3 + ) + assert np.allclose( + 1e21 * hc_lal_data.real, 1e21 * hc_ml4gw.real.numpy(), atol=2e-3 + ) + assert np.allclose( + 1e21 * hc_lal_data.imag, 1e21 * hc_ml4gw.imag.numpy(), atol=2e-3 + )