From 13f61f6e5b0006113fb1d5b5d25deb8c166e4f7b Mon Sep 17 00:00:00 2001 From: Ravi Kumar Date: Mon, 6 Jan 2025 21:57:50 +0000 Subject: [PATCH] fix batch inputs bug --- ml4gw/waveforms/cbc/phenom_p.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ml4gw/waveforms/cbc/phenom_p.py b/ml4gw/waveforms/cbc/phenom_p.py index 39c4d0c..c1b2b0f 100644 --- a/ml4gw/waveforms/cbc/phenom_p.py +++ b/ml4gw/waveforms/cbc/phenom_p.py @@ -379,10 +379,10 @@ def PhenomPOneFrequency( diffRDphase = (diff[:, 1:] + diff[:, :-1]) / ( 2 * delta_fRds.unsqueeze(1) ) + # reshape x to have same shape as diffRDphase + x = x[1:-1].unsqueeze(0).expand(diffRDphase.shape) # interpolate at x = 1, as thats the same as f = fRD - diffRDphase = -self.interpolate( - torch.tensor([1]), x[1:-1], diffRDphase - ) + diffRDphase = -self.interpolate(torch.tensor([1]), x, diffRDphase) return hPhenom, diffRDphase # Utility functions