From fdebbb5916ce81fb1259ac5eff0a05d1727ae1c0 Mon Sep 17 00:00:00 2001 From: Ethan Marx <61295922+EthanMarx@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:42:24 -0500 Subject: [PATCH] Merge `dev` into `main` in preparation for `v0.6.0` (#170) * Added spline interpolation code * Integrate spline interp into qtransform * Better case handling * Did time interpolation more efficiently * Make interpolation method an argument * Correct location of qtile stacking * Changed spline API and added data shape check * Add spline to qtransform tests * Separated data validation and added initial spline testing * Added documentation * Added error tests * Correct tolerance * Update Qscan * Fix tests * Specify output size for irfft * add option to regularize scaler std (#166) This is handy when scaling priors which have some parameters fixed, i.e. delta functions * Updates to `IMRPhenomP` api (#167) * consolidate phenomp and phenomd apis * use chirp mass for initializing tensor size * fix ordering of parameters in tests * add precessing spin conversion (#168) * add precessing spin conversion * add conversion file * restructure waveforms module * update tests * add back generator * add more robust fref check * version bump to 0.6.0 (#169) --------- Co-authored-by: William Benoit Co-authored-by: William Benoit Co-authored-by: William Benoit Co-authored-by: wbenoit26 <90333821+wbenoit26@users.noreply.github.com> Co-authored-by: Deep Chatterjee --- ml4gw/constants.py | 29 +- ml4gw/spectral.py | 2 +- ml4gw/transforms/__init__.py | 1 + ml4gw/transforms/qtransform.py | 176 +++++++-- ml4gw/transforms/scaler.py | 6 +- ml4gw/transforms/spline_interpolation.py | 370 ++++++++++++++++++ ml4gw/waveforms/__init__.py | 7 +- ml4gw/waveforms/adhoc/__init__.py | 2 + ml4gw/waveforms/{ => adhoc}/ringdown.py | 0 ml4gw/waveforms/{ => adhoc}/sine_gaussian.py | 0 ml4gw/waveforms/cbc/__init__.py | 3 + ml4gw/waveforms/{ => cbc}/phenom_d.py | 0 ml4gw/waveforms/{ => cbc}/phenom_d_data.py | 0 ml4gw/waveforms/{ => cbc}/phenom_p.py | 78 ++-- ml4gw/waveforms/{ => cbc}/taylorf2.py | 0 ml4gw/waveforms/conversion.py | 187 +++++++++ poetry.lock | 2 +- pyproject.toml | 2 +- tests/transforms/test_qtransform.py | 48 ++- tests/transforms/test_scaler.py | 10 + tests/transforms/test_spline_interpolation.py | 101 +++++ .../{ => adhoc}/test_sine_gaussian.py | 0 .../waveforms/{ => cbc}/test_cbc_waveforms.py | 2 +- tests/waveforms/test_conversion.py | 66 ++++ 24 files changed, 973 insertions(+), 119 deletions(-) create mode 100644 ml4gw/transforms/spline_interpolation.py create mode 100644 ml4gw/waveforms/adhoc/__init__.py rename ml4gw/waveforms/{ => adhoc}/ringdown.py (100%) rename ml4gw/waveforms/{ => adhoc}/sine_gaussian.py (100%) create mode 100644 ml4gw/waveforms/cbc/__init__.py rename ml4gw/waveforms/{ => cbc}/phenom_d.py (100%) rename ml4gw/waveforms/{ => cbc}/phenom_d_data.py (100%) rename ml4gw/waveforms/{ => cbc}/phenom_p.py (92%) rename ml4gw/waveforms/{ => cbc}/taylorf2.py (100%) create mode 100644 ml4gw/waveforms/conversion.py create mode 100644 tests/transforms/test_spline_interpolation.py rename tests/waveforms/{ => adhoc}/test_sine_gaussian.py (100%) rename tests/waveforms/{ => cbc}/test_cbc_waveforms.py (100%) create mode 100644 tests/waveforms/test_conversion.py diff --git a/ml4gw/constants.py b/ml4gw/constants.py index 00fe12d..c30eb65 100644 --- a/ml4gw/constants.py +++ b/ml4gw/constants.py @@ -4,42 +4,33 @@ EulerGamma = 0.577215664901532860606512090082402431 +# solar mass MSUN = 1.988409902147041637325262574352366540e30 # kg 
-"""Solar mass""" +# Geometrized nominal solar mass, m MRSUN = 1.476625038050124729627979840144936351e3 -"""Geometrized nominal solar mass, m""" +# Newton's gravitational constant G = 6.67430e-11 # m^3 / kg / s^2 -"""Newton's gravitational constant""" +# Speed of light C = 299792458.0 # m / s -"""Speed of light""" -"""Pi""" +# pi and 2pi PI = 3.141592653589793238462643383279502884 - TWO_PI = 6.283185307179586476925286766559005768 +# G MSUN / C^3 in seconds gt = G * MSUN / (C**3.0) -""" -G MSUN / C^3 in seconds -""" +# 1 solar mass in seconds. Same value as lal.MTSUN_SI MTSUN_SI = 4.925490947641266978197229498498379006e-6 -"""1 solar mass in seconds. Same value as lal.MTSUN_SI""" +# Meters per Mpc. m_per_Mpc = 3.085677581491367278913937957796471611e22 -""" -Meters per Mpc. -""" +# 1 Mpc in seconds. MPC_SEC = m_per_Mpc / C -""" -1 Mpc in seconds. -""" +# Speed of light in vacuum (:math:`c`), in gigaparsecs per second clightGpc = C / 3.0856778570831e22 -""" -Speed of light in vacuum (:math:`c`), in gigaparsecs per second -""" diff --git a/ml4gw/spectral.py b/ml4gw/spectral.py index 6317efd..934d666 100644 --- a/ml4gw/spectral.py +++ b/ml4gw/spectral.py @@ -441,7 +441,7 @@ def normalize_by_psd( # convert back to the time domain and normalize # TODO: what's this normalization factor? - X = torch.fft.irfft(X_tilde, norm="forward", dim=-1) + X = torch.fft.irfft(X_tilde, n=X.shape[-1], norm="forward", dim=-1) X = X.float() / sample_rate**0.5 # slice off corrupted data at edges of kernel diff --git a/ml4gw/transforms/__init__.py b/ml4gw/transforms/__init__.py index 5d381db..74a7eb8 100644 --- a/ml4gw/transforms/__init__.py +++ b/ml4gw/transforms/__init__.py @@ -4,5 +4,6 @@ from .snr_rescaler import SnrRescaler from .spectral import SpectralDensity from .spectrogram import MultiResolutionSpectrogram +from .spline_interpolation import SplineInterpolate from .waveforms import WaveformProjector, WaveformSampler from .whitening import FixedWhiten, Whiten diff --git a/ml4gw/transforms/qtransform.py b/ml4gw/transforms/qtransform.py index 18b7f6a..2898ce7 100644 --- a/ml4gw/transforms/qtransform.py +++ b/ml4gw/transforms/qtransform.py @@ -1,11 +1,13 @@ import math -from typing import List, Optional, Tuple +import warnings +from typing import List, Tuple import torch import torch.nn.functional as F from jaxtyping import Float, Int from torch import Tensor +from ml4gw.transforms.spline_interpolation import SplineInterpolate from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d """ @@ -38,7 +40,6 @@ class QTile(torch.nn.Module): mismatch: The maximum fractional mismatch between neighboring tiles - """ def __init__( @@ -100,7 +101,9 @@ def get_data_indices(self) -> Int[Tensor, " windowsize"]: ).type(torch.long) def forward( - self, fseries: FrequencySeries1to3d, norm: str = "median" + self, + fseries: FrequencySeries1to3d, + norm: str = "median", ) -> TimeSeries1to3d: """ Compute the transform for this row @@ -144,7 +147,7 @@ def forward( energy /= means else: raise ValueError("Invalid normalisation %r" % norm) - return energy.type(torch.float32) + energy = energy.type(torch.float32) return energy @@ -172,6 +175,19 @@ class SingleQTransform(torch.nn.Module): be chosen based on q, sample_rate, and duration mismatch: The maximum fractional mismatch between neighboring tiles + interpolation_method: + The method by which to interpolate each `QTile` to the specified + number of time and frequency bins. The acceptable values are + "bilinear", "bicubic", and "spline". 
The "bilinear" and "bicubic" + options will use PyTorch's built-in interpolation modes, while + "spline" will use the custom Torch-based implementation in + `ml4gw`, as PyTorch does not have spline-based intertpolation. + The "spline" mode is most similar to the results of GWpy's + Q-transform, which uses `scipy` to do spline interpolation. + However, it is also the slowest and most memory intensive due to + the matrix equation solving steps. Therefore, the default method + is "bicubic" as it produces the most similar results while + optimizing for computing performance. """ def __init__( @@ -182,6 +198,7 @@ def __init__( q: float = 12, frange: List[float] = [0, torch.inf], mismatch: float = 0.2, + interpolation_method: str = "bicubic", ) -> None: super().__init__() self.q = q @@ -190,20 +207,87 @@ def __init__( self.duration = duration self.mismatch = mismatch + # If q is too large, the minimum of the frange computed + # below will be larger than the maximum + max_q = torch.pi * duration * sample_rate / 50 - 11 ** (0.5) + if q >= max_q: + raise ValueError( + "The given q value is too large for the given duration and " + f"sample rate. The maximum allowable value is {max_q}" + ) + + if interpolation_method not in ["bilinear", "bicubic", "spline"]: + raise ValueError( + "Interpolation method must be either 'bilinear', 'bicubic', " + f"or 'spline'; got {interpolation_method}" + ) + self.interpolation_method = interpolation_method + qprime = self.q / 11 ** (1 / 2.0) if self.frange[0] <= 0: # set non-zero lower frequency self.frange[0] = 50 * self.q / (2 * torch.pi * duration) if math.isinf(self.frange[1]): # set non-infinite upper frequency self.frange[1] = sample_rate / 2 / (1 + 1 / qprime) + self.freqs = self.get_freqs() self.qtile_transforms = torch.nn.ModuleList( [ - QTile(self.q, freq, self.duration, sample_rate, self.mismatch) + QTile( + q=self.q, + frequency=freq, + duration=self.duration, + sample_rate=sample_rate, + mismatch=self.mismatch, + ) for freq in self.freqs ] ) self.qtiles = None + if self.interpolation_method == "spline": + self._set_up_spline_interp() + + def _set_up_spline_interp(self): + ntiles = [qtile.ntiles() for qtile in self.qtile_transforms] + # For efficiency, we'll stack all qtiles of the same length before + # interpolating, so we need to figure out which those are + unique_ntiles = sorted(list(set(ntiles))) + idx = torch.arange(len(ntiles)) + self.stack_idx = [idx[Tensor(ntiles) == n] for n in unique_ntiles] + + t_out = torch.arange( + 0, self.duration, self.duration / self.spectrogram_shape[1] + ) + self.qtile_interpolators = torch.nn.ModuleList( + [ + SplineInterpolate( + kx=3, + x_in=torch.arange(0, self.duration, self.duration / tiles), + y_in=torch.arange(len(idx)), + x_out=t_out, + y_out=torch.arange(len(idx)), + ) + for tiles, idx in zip(unique_ntiles, self.stack_idx) + ] + ) + + t_in = t_out + f_in = self.freqs + f_out = torch.logspace( + math.log10(self.frange[0]), + math.log10(self.frange[-1]), + self.spectrogram_shape[0], + ) + + self.interpolator = SplineInterpolate( + kx=3, + ky=3, + x_in=t_in, + y_in=f_in, + x_out=t_out, + y_out=f_out, + ) + def get_freqs(self) -> Float[Tensor, " nfreq"]: """ Calculate the frequencies that will be used in this transform. 
@@ -220,7 +304,8 @@ def get_freqs(self) -> Float[Tensor, " nfreq"]: freq_base = math.exp(2 / ((2 + self.q**2) ** (1 / 2.0)) * fstep) freqs = torch.Tensor([freq_base ** (i + 0.5) for i in range(nfreq)]) - freqs = (minf * freqs // fstepmin) * fstepmin + # Cast freqs to float64 to avoid off-by-ones from rounding + freqs = (minf * freqs.double() // fstepmin) * fstepmin return torch.unique(freqs) def get_max_energy( @@ -268,7 +353,11 @@ def get_max_energy( if dimension == "batch": return torch.max(max_across_ft, dim=-1).values - def compute_qtiles(self, X: TimeSeries1to3d, norm: str = "median") -> None: + def compute_qtiles( + self, + X: TimeSeries1to3d, + norm: str = "median", + ) -> None: """ Take the FFT of the input timeseries and calculate the transform for each `QTile` @@ -278,28 +367,40 @@ def compute_qtiles(self, X: TimeSeries1to3d, norm: str = "median") -> None: X[..., 1:] *= 2 self.qtiles = [qtile(X, norm) for qtile in self.qtile_transforms] - def interpolate(self, num_f_bins: int, num_t_bins: int) -> TimeSeries3d: - """ - Interpolate each `QTile` to the specified number of time and - frequency bins. Note that PyTorch does not have the same - interpolation methods that GWpy uses, and so the interpolated - spectrograms will be different even though the uninterpolated - values match. The `bicubic` interpolation method is used as - it seems to match GWpy most closely. - """ + def interpolate(self) -> TimeSeries3d: if self.qtiles is None: raise RuntimeError( "Q-tiles must first be computed with .compute_qtiles()" ) + if self.interpolation_method == "spline": + qtiles = [ + torch.stack([self.qtiles[i] for i in idx], dim=-2) + for idx in self.stack_idx + ] + time_interped = torch.cat( + [ + interpolator(qtile) + for qtile, interpolator in zip( + qtiles, self.qtile_interpolators + ) + ], + dim=-2, + ) + return self.interpolator(time_interped) + num_f_bins, num_t_bins = self.spectrogram_shape resampled = [ F.interpolate( - qtile[None], (qtile.shape[-2], num_t_bins), mode="bicubic" + qtile[None], + (qtile.shape[-2], num_t_bins), + mode=self.interpolation_method, ) for qtile in self.qtiles ] resampled = torch.stack(resampled, dim=-2) resampled = F.interpolate( - resampled[0], (num_f_bins, num_t_bins), mode="bicubic" + resampled[0], + (num_f_bins, num_t_bins), + mode=self.interpolation_method, ) return torch.squeeze(resampled) @@ -307,7 +408,6 @@ def forward( self, X: TimeSeries1to3d, norm: str = "median", - spectrogram_shape: Optional[Tuple[int, int]] = None, ): """ Compute the Q-tiles and interpolate @@ -321,24 +421,15 @@ def forward( three-dimensional, axes will be added during Q-tile computation. norm: - The method of interpolation used by each QTile - spectrogram_shape: - The shape of the interpolated spectrogram, specified as - `(num_f_bins, num_t_bins)`. Because the - frequency spacing of the Q-tiles is in log-space, the frequency - interpolation is log-spaced as well. If not given, the shape - used to initialize the transform will be used. + The method of normalization used by each QTile Returns: The interpolated Q-transform for the batch of data. 
Output will have one more dimension than the input """ - if spectrogram_shape is None: - spectrogram_shape = self.spectrogram_shape - num_f_bins, num_t_bins = spectrogram_shape self.compute_qtiles(X, norm) - return self.interpolate(num_f_bins, num_t_bins) + return self.interpolate() class QScan(torch.nn.Module): @@ -376,14 +467,22 @@ def __init__( spectrogram_shape: Tuple[int, int], qrange: List[float] = [4, 64], frange: List[float] = [0, torch.inf], + interpolation_method="bicubic", mismatch: float = 0.2, ) -> None: super().__init__() self.qrange = qrange self.mismatch = mismatch - self.qs = self.get_qs() self.frange = frange self.spectrogram_shape = spectrogram_shape + max_q = torch.pi * duration * sample_rate / 50 - 11 ** (0.5) + self.qs = self.get_qs() + if self.qs[-1] >= max_q: + warnings.warn( + "Some Q values exceed the maximum allowable Q value of " + f"{max_q}. The list of Q values to be tested in this " + "scan will be truncated to avoid those values." + ) # Deliberately doing something different from GWpy here. # Their final frange is the intersection of the frange @@ -397,9 +496,11 @@ def __init__( spectrogram_shape=spectrogram_shape, q=q, frange=self.frange.copy(), + interpolation_method=interpolation_method, mismatch=self.mismatch, ) for q in self.qs + if q < max_q ] ) @@ -415,6 +516,7 @@ def get_qs(self) -> List[float]: self.qrange[0] * math.exp(2 ** (1 / 2.0) * dq * (i + 0.5)) for i in range(nplanes) ] + return qs def forward( @@ -422,7 +524,6 @@ def forward( X: TimeSeries1to3d, fsearch_range: List[float] = None, norm: str = "median", - spectrogram_shape: Optional[Tuple[int, int]] = None, ): """ Compute the set of QTiles for each Q transform and determine which @@ -442,12 +543,6 @@ def forward( for the maximum energy norm: The method of interpolation used by each QTile - spectrogram_shape: - The shape of the interpolated spectrogram, specified as - `(num_f_bins, num_t_bins)`. Because the - frequency spacing of the Q-tiles is in log-space, the frequency - interpolation is log-spaced as well. If not given, the shape - used to initialize the transform will be used. Returns: An interpolated Q-transform for the batch of data. Output will @@ -463,7 +558,4 @@ def forward( ] ) ) - if spectrogram_shape is None: - spectrogram_shape = self.spectrogram_shape - num_f_bins, num_t_bins = spectrogram_shape - return self.q_transforms[idx].interpolate(num_f_bins, num_t_bins) + return self.q_transforms[idx].interpolate() diff --git a/ml4gw/transforms/scaler.py b/ml4gw/transforms/scaler.py index e18d730..d868511 100644 --- a/ml4gw/transforms/scaler.py +++ b/ml4gw/transforms/scaler.py @@ -36,7 +36,9 @@ def __init__(self, num_channels: Optional[int] = None) -> None: self.register_buffer("mean", mean) self.register_buffer("std", std) - def fit(self, X: Float[Tensor, "... time"]) -> None: + def fit( + self, X: Float[Tensor, "... time"], std_reg: Optional[float] = 0.0 + ) -> None: """Fit the scaling parameters to a timeseries Computes the channel-wise mean and standard deviation @@ -59,7 +61,7 @@ def fit(self, X: Float[Tensor, "... 
time"]) -> None: "Can't fit channel wise mean and standard deviation " "from tensor of shape {}".format(X.shape) ) - + std += std_reg * torch.ones_like(std) super().build(mean=mean, std=std) def forward( diff --git a/ml4gw/transforms/spline_interpolation.py b/ml4gw/transforms/spline_interpolation.py new file mode 100644 index 0000000..6972560 --- /dev/null +++ b/ml4gw/transforms/spline_interpolation.py @@ -0,0 +1,370 @@ +""" +Adaptation of code from https://github.com/dottormale/Qtransform +""" + +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor + + +class SplineInterpolate(torch.nn.Module): + """ + Perform 1D or 2D spline interpolation based on De Boor's method. + Supports batched, multi-channel inputs, so acceptable data + shapes are `(width)`, `(height, width)`, `(batch, width)`, + `(batch, height, width)`, `(batch, channel, width)`, and + `(batch, channel, height, width)`. + + During initialization of this Module, both the desired input + and output coordinate Tensors can be specified to allow + pre-computation of the B-spline basis matrices, though the only + mandatory argument is the coordinates of the data along the + `width` dimension. If no argument is given for coordinates along + the `height` dimension, it is assumed that 1D interpolation is + desired. + + Unlike scipy's implementation of spline interpolation, the data + to be interpolated is not passed until actually calling the + object. This is useful for cases where the input and output + coordinates are known in advance, but the data is not, so that + the interpolator can be set up ahead of time. + + WARNING: compared to scipy's spline interpolation, this method + produces edge artifacts when the output coordinates are near + the boundaries of the input coordinates. Therefore, it is + recommended to interpolate only to coordinates that are well + within the input coordinate range. Unfortunately, the specific + definition of "well within" changes based on the size of the + data, so some testing may be required to get good results. + + Args: + x_in: + Coordinates of the width dimension of the data + y_in: + Coordinates of the height dimension of the data. If not + specified, it is assumed the 1D interpolation is desired, + and so the default value is a Tensor of length 1 + kx: + Degree of spline interpolation along the width dimension. + Default is cubic. + ky: + Degree of spline interpolation along the height dimension. + Default is cubic. + sx: + Regularization factor to avoid singularities during matrix + inversion for interpolation along the width dimension. Not + to be confused with the `s` parameter in scipy's spline + methods, which controls the number of knots. + sy: + Regularization factor to avoid singularities during matrix + inversion for interpolation along the height dimension. + x_out: + Coordinates for the data to be interpolated to along the + width dimension. If not specified during initialization, + this must be specified during the object call. + y_out: + Coordinates for the data to be interpolated to along the + height dimension. If not specified during initialization, + this must be specified during the object call. 
+ + """ + + def __init__( + self, + x_in: Tensor, + y_in: Tensor = Tensor([1]), + kx: int = 3, + ky: int = 3, + sx: float = 0.001, + sy: float = 0.001, + x_out: Optional[Tensor] = None, + y_out: Optional[Tensor] = None, + ): + super().__init__() + self.kx = kx + self.ky = ky + self.sx = sx + self.sy = sy + self.register_buffer("x_in", x_in) + self.register_buffer("y_in", y_in) + self.register_buffer("x_out", x_out) + self.register_buffer("y_out", y_out) + + tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx) + self.register_buffer("tx", tx) + self.register_buffer("Bx", Bx) + self.register_buffer("BxT_Bx", BxT_Bx) + + ty, By, ByT_By = self._compute_knots_and_basis_matrices(y_in, ky, sy) + self.register_buffer("ty", ty) + self.register_buffer("By", By) + self.register_buffer("ByT_By", ByT_By) + + if self.x_out is not None: + Bx_out = self.bspline_basis_natural(x_out, kx, self.tx) + self.register_buffer("Bx_out", Bx_out) + if self.y_out is not None: + By_out = self.bspline_basis_natural(y_out, ky, self.ty) + self.register_buffer("By_out", By_out) + + def _compute_knots_and_basis_matrices(self, x, k, s): + knots = self.generate_natural_knots(x, k) + basis_matrix = self.bspline_basis_natural(x, k, knots) + identity = torch.eye(basis_matrix.shape[-1]) + B_T_B = basis_matrix.T @ basis_matrix + s * identity + return knots, basis_matrix, B_T_B + + def generate_natural_knots(self, x: Tensor, k: int) -> Tensor: + """ + Generates a natural knot sequence for B-spline interpolation. + Natural knot sequence means that 2*k knots are added to the beginning + and end of datapoints as replicas of first and last datapoint + respectively in order to enforce natural boundary conditions, + i.e. second derivative = 0. + The other n nodes are placed in correspondece of the data points. + + Args: + x: Tensor of data point positions. + k: Degree of the spline. + + Returns: + Tensor of knot positions. + """ + return F.pad(x[None], (k, k), mode="replicate")[0] + + def compute_L_R( + self, + x: Tensor, + t: Tensor, + d: int, + m: int, + ) -> Tuple[Tensor, Tensor]: + + """ + Compute the L and R values for B-spline basis functions. + L and R are respectively the first and second coefficient multiplying + B_{i,p-1}(x) and B_{i+1,p-1}(x) in De Boor's recursive formula for + Bspline basis funciton computation + See https://en.wikipedia.org/wiki/De_Boor%27s_algorithm for details + + Args: + x: + Tensor of data point positions. + t: + Tensor of knot positions. + d: + Current degree of the basis function. + m: + Number of intervals (n - k - 1, where n is the number of knots + and k is the degree). + + Returns: + L: Tensor containing left values for the B-spline basis functions. + R: Tensor containing right values for the B-spline basis functions. + """ + left_num = x.unsqueeze(1) - t[:m].unsqueeze(0) + left_den = t[d : m + d] - t[:m] + L = left_num / left_den.unsqueeze(0) + L = torch.nan_to_num_(L, nan=0.0, posinf=0.0, neginf=0.0) + + right_num = t[d + 1 : m + d + 1] - x.unsqueeze(1) + right_den = t[d + 1 : m + d + 1] - t[1 : m + 1] + R = right_num / right_den.unsqueeze(0) + R = torch.nan_to_num_(R, nan=0.0, posinf=0.0, neginf=0.0) + + return L, R + + def zeroth_order( + self, + x: Tensor, + k: int, + t: Tensor, + n: int, + m: int, + ) -> Tensor: + + """ + Compute the zeroth-order B-spline basis functions + according to de Boors recursive formula. + See https://en.wikipedia.org/wiki/De_Boor%27s_algorithm for reference + + Args: + x: + Tensor of data point positions. + k: + Degree of the spline. 
+ t: + Tensor of knot positions. + n: + Number of data points. + m: + Number of intervals (n - k - 1, where n is the number of knots + and k is the degree). + + Returns: + b: Tensor containing the zeroth-order B-spline basis functions. + """ + b = torch.zeros((n, m, k + 1)) + + mask_lower = t[: m + 1].unsqueeze(0)[:, :-1] <= x.unsqueeze(1) + mask_upper = x.unsqueeze(1) < t[: m + 1].unsqueeze(0)[:, 1:] + + b[:, :, 0] = mask_lower & mask_upper + b[:, 0, 0] = torch.where(x < t[1], torch.ones_like(x), b[:, 0, 0]) + b[:, -1, 0] = torch.where(x >= t[-2], torch.ones_like(x), b[:, -1, 0]) + return b + + def bspline_basis_natural( + self, + x: Tensor, + k: int, + t: Tensor, + ) -> Tensor: + """ + Compute bspline basis function using de Boor's recursive formula + (See https://en.wikipedia.org/wiki/De_Boor%27s_algorithm for reference) + Args: + x: Tensor of data point positions. + k: Degree of the spline. + t: Tensor of knot positions. + + Returns: + Tensor containing the kth-order B-spline basis functions + """ + + if len(x) == 1: + return torch.eye(1) + n = x.shape[0] + m = t.shape[0] - k - 1 + + # calculate zeroth order basis funciton + b = self.zeroth_order(x, k, t, n, m) + + zeros_tensor = torch.zeros(b.shape[0], 1) + # recursive de Boors formula for bspline basis functions + for d in range(1, k + 1): + L, R = self.compute_L_R(x, t, d, m) + left = L * b[:, :, d - 1] + + temp_b = torch.cat([b[:, 1:, d - 1], zeros_tensor], dim=1) + + right = R * temp_b + b[:, :, d] = left + right + + return b[:, :, -1] + + def bivariate_spline_fit_natural(self, Z): + + if len(Z.shape) == 3: + Z_Bx = torch.matmul(Z, self.Bx) + # ((BxT @ Bx)^-1 @ (Z @ Bx)T)T = Z @ BxT^-1 + return torch.linalg.solve(self.BxT_Bx, Z_Bx.mT).mT + + # Adding batch/channel dimension handling + # ByT @ Z @ Bx + ByT_Z_Bx = torch.einsum("ij,bcik,kl->bcjl", self.By, Z, self.Bx) + # (ByT @ By)^-1 @ (ByT @ Z @ Bx) = By^-1 @ Z @ Bx + E = torch.linalg.solve(self.ByT_By, ByT_Z_Bx) + # ((BxT @ Bx)^-1 @ (By^-1 @ Z @ Bx)T)T = By^-1 @ Z @ BxT^-1 + return torch.linalg.solve(self.BxT_Bx, E.mT).mT + + def evaluate_bivariate_spline(self, C: Tensor): + """ + Evaluate a bivariate spline on a grid of x and y points. + + Args: + C: Coefficient tensor of shape (batch_size, mx, my). + + Returns: + Z_interp: Interpolated values at the grid points. + """ + # Perform matrix multiplication using einsum to get Z_interp + if len(C.shape) == 3: + return torch.matmul(C, self.Bx_out.mT) + return torch.einsum("ik,bckm,mj->bcij", self.By_out, C, self.Bx_out.mT) + + def _validate_inputs(self, Z, x_out, y_out): + if x_out is None and self.x_out is None: + raise ValueError( + "Output x-coordinates were not specified in either object " + "creation or in forward call" + ) + + if y_out is None and self.y_out is None: + y_out = self.y_in + + dims = len(Z.shape) + if dims > 4: + raise ValueError("Input data has more than 4 dimensions") + + if len(self.y_in) > 1 and dims == 1: + raise ValueError( + "An input y-coordinate array with length greater than 1 " + "was given, but the input data is 1-dimensional. Expected " + "input data to be at least 2-dimensional" + ) + + # Expand Z to have 4 dimensions + # There are 6 valid input shapes: (w), (b, w), (b, c, w), + # (h, w), (b, h, w), and (b, c, h, w). 
+ + # If the input y coordinate array has length 1, + # assume the first dimension(s) are batch dimensions + # and that no height dimension is included in Z + idx = -2 if len(self.y_in) == 1 else -3 + while len(Z.shape) < 4: + Z = Z.unsqueeze(idx) + + if Z.shape[-2:] != torch.Size([len(self.y_in), len(self.x_in)]): + raise ValueError( + "The spatial dimensions of the data tensor do not match " + "the given input dimensions. " + f"Expected [{len(self.y_in)}, {len(self.x_in)}], but got " + f"[{Z.shape[-2]}, {Z.shape[-1]}]" + ) + + return Z, y_out + + def forward( + self, + Z: Tensor, + x_out: Optional[Tensor] = None, + y_out: Optional[Tensor] = None, + ) -> Tensor: + """ + Compute the interpolated data + + Args: + Z: + Tensor of data to be interpolated. Must be between 1 and 4 + dimensions. The shape of the tensor must agree with the + input coordinates given on initialization. If `y_in` was + not specified during initialization, it is assumed that + Z does not have a height dimension. + x_out: + Coordinates to interpolate the data to along the width + dimension. Overrides any value that was set during + initialization. + y_out: + Coordinates to interpolate the data to along the height + dimension. Overrides any value that was set during + initialization. + + Returns: + A 4D tensor with shape `(batch, channel, height, width)`. + Depending on the input data shape, many of these dimensions + may have length 1. + """ + + Z, y_out = self._validate_inputs(Z, x_out, y_out) + + if x_out is not None: + self.Bx_out = self.bspline_basis_natural(x_out, self.kx, self.tx) + if y_out is not None: + self.By_out = self.bspline_basis_natural(y_out, self.ky, self.ty) + + coef = self.bivariate_spline_fit_natural(Z) + Z_interp = self.evaluate_bivariate_spline(coef) + return Z_interp diff --git a/ml4gw/waveforms/__init__.py b/ml4gw/waveforms/__init__.py index 9d3e5a7..9192f68 100644 --- a/ml4gw/waveforms/__init__.py +++ b/ml4gw/waveforms/__init__.py @@ -1,5 +1,2 @@ -from .phenom_d import IMRPhenomD -from .phenom_p import IMRPhenomPv2 -from .ringdown import Ringdown -from .sine_gaussian import SineGaussian -from .taylorf2 import TaylorF2 +from .adhoc import * +from .cbc import * diff --git a/ml4gw/waveforms/adhoc/__init__.py b/ml4gw/waveforms/adhoc/__init__.py new file mode 100644 index 0000000..da5814e --- /dev/null +++ b/ml4gw/waveforms/adhoc/__init__.py @@ -0,0 +1,2 @@ +from .ringdown import Ringdown +from .sine_gaussian import SineGaussian diff --git a/ml4gw/waveforms/ringdown.py b/ml4gw/waveforms/adhoc/ringdown.py similarity index 100% rename from ml4gw/waveforms/ringdown.py rename to ml4gw/waveforms/adhoc/ringdown.py diff --git a/ml4gw/waveforms/sine_gaussian.py b/ml4gw/waveforms/adhoc/sine_gaussian.py similarity index 100% rename from ml4gw/waveforms/sine_gaussian.py rename to ml4gw/waveforms/adhoc/sine_gaussian.py diff --git a/ml4gw/waveforms/cbc/__init__.py b/ml4gw/waveforms/cbc/__init__.py new file mode 100644 index 0000000..cb3eb9d --- /dev/null +++ b/ml4gw/waveforms/cbc/__init__.py @@ -0,0 +1,3 @@ +from .phenom_d import IMRPhenomD +from .phenom_p import IMRPhenomPv2 +from .taylorf2 import TaylorF2 diff --git a/ml4gw/waveforms/phenom_d.py b/ml4gw/waveforms/cbc/phenom_d.py similarity index 100% rename from ml4gw/waveforms/phenom_d.py rename to ml4gw/waveforms/cbc/phenom_d.py diff --git a/ml4gw/waveforms/phenom_d_data.py b/ml4gw/waveforms/cbc/phenom_d_data.py similarity index 100% rename from ml4gw/waveforms/phenom_d_data.py rename to ml4gw/waveforms/cbc/phenom_d_data.py diff --git 
a/ml4gw/waveforms/phenom_p.py b/ml4gw/waveforms/cbc/phenom_p.py similarity index 92% rename from ml4gw/waveforms/phenom_p.py rename to ml4gw/waveforms/cbc/phenom_p.py index 17528fc..df91ef2 100644 --- a/ml4gw/waveforms/phenom_p.py +++ b/ml4gw/waveforms/cbc/phenom_p.py @@ -1,4 +1,4 @@ -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple import torch from jaxtyping import Float @@ -6,6 +6,7 @@ from ml4gw.constants import MPC_SEC, MTSUN_SI, PI from ml4gw.types import BatchTensor, FrequencySeries1d +from ml4gw.waveforms.conversion import rotate_y, rotate_z from .phenom_d import IMRPhenomD @@ -25,11 +26,11 @@ def forward( s2x: BatchTensor, s2y: BatchTensor, s2z: BatchTensor, - dist_mpc: BatchTensor, - tc: BatchTensor, - phiRef: BatchTensor, - incl: BatchTensor, + distance: BatchTensor, + phic: BatchTensor, + inclination: BatchTensor, f_ref: float, + tc: Optional[BatchTensor] = None, ): """ IMRPhenomPv2 waveform @@ -53,13 +54,13 @@ def forward( Spin component y of the second BH. s2z : Spin component z of the second BH. - dist_mpc : + distance : Luminosity distance in Mpc. tc : Coalescence time. - phiRef : + phic : Reference phase. - incl : + inclination : Inclination angle. f_ref : Reference frequency in Hz. @@ -71,6 +72,9 @@ def forward( Note: m1 must be larger than m2. """ + if tc is None: + tc = torch.zeros_like(chirp_mass) + m2 = chirp_mass * (1.0 + mass_ratio) ** 0.2 / mass_ratio**0.6 m1 = m2 * mass_ratio @@ -89,7 +93,7 @@ def forward( phi_aligned, zeta_polariz, ) = self.convert_spins( - m1, m2, f_ref, phiRef, incl, s1x, s1y, s1z, s2x, s2y, s2z + m1, m2, f_ref, phic, inclination, s1x, s1y, s1z, s2x, s2y, s2z ) phic = 2 * phi_aligned @@ -152,7 +156,7 @@ def forward( phic, M, xi, - dist_mpc, + distance, ) hp, hc = self.PhenomPCoreTwistUp( @@ -309,7 +313,7 @@ def PhenomPOneFrequency( phic, M, xi, - dist_mpc, + distance, ): """ m1, m2: in solar masses @@ -324,10 +328,10 @@ def PhenomPOneFrequency( phase, _ = self.phenom_d_phase(Mf, m1, m2, eta, eta2, chi1, chi2, xi) phase = (phase.mT - (phic + PI / 4.0)).mT Amp = self.phenom_d_amp( - Mf, m1, m2, eta, eta2, Seta, chi1, chi2, chi12, chi22, xi, dist_mpc + Mf, m1, m2, eta, eta2, Seta, chi1, chi2, chi12, chi22, xi, distance )[0] Amp0 = self.get_Amp0(Mf, eta) - dist_s = dist_mpc * MPC_SEC + dist_s = distance * MPC_SEC Amp = ((Amp0 * Amp).mT * (M_s**2.0) / dist_s).mT # phase -= 2. * phic; # line 1316 ??? 
hPhenom = Amp * (torch.exp(-1j * phase)) @@ -391,16 +395,6 @@ def interpolate( return interpolated.reshape(original_shape) - def ROTATEZ(self, angle: BatchTensor, x, y, z): - tmp_x = x * torch.cos(angle) - y * torch.sin(angle) - tmp_y = x * torch.sin(angle) + y * torch.cos(angle) - return tmp_x, tmp_y, z - - def ROTATEY(self, angle, x, y, z): - tmp_x = x * torch.cos(angle) + z * torch.sin(angle) - tmp_z = -x * torch.sin(angle) + z * torch.cos(angle) - return tmp_x, y, tmp_z - def L2PNR( self, v: BatchTensor, @@ -425,8 +419,8 @@ def convert_spins( m1: BatchTensor, m2: BatchTensor, f_ref: float, - phiRef: BatchTensor, - incl: BatchTensor, + phic: BatchTensor, + inclination: BatchTensor, s1x: BatchTensor, s1y: BatchTensor, s1z: BatchTensor, @@ -486,32 +480,32 @@ def convert_spins( # First we determine kappa # in the source frame, the components of N are given in # Eq (35c) of T1500606-v6 - Nx_sf = torch.sin(incl) * torch.cos(PI / 2.0 - phiRef) - Ny_sf = torch.sin(incl) * torch.sin(PI / 2.0 - phiRef) - Nz_sf = torch.cos(incl) + Nx_sf = torch.sin(inclination) * torch.cos(PI / 2.0 - phic) + Ny_sf = torch.sin(inclination) * torch.sin(PI / 2.0 - phic) + Nz_sf = torch.cos(inclination) tmp_x = Nx_sf tmp_y = Ny_sf tmp_z = Nz_sf - tmp_x, tmp_y, tmp_z = self.ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z) - tmp_x, tmp_y, tmp_z = self.ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = rotate_z(-phiJ_sf, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = rotate_y(-thetaJ_sf, tmp_x, tmp_y, tmp_z) kappa = -torch.arctan2(tmp_y, tmp_x) # Then we determine alpha0, by rotating LN tmp_x, tmp_y, tmp_z = 0, 0, 1 - tmp_x, tmp_y, tmp_z = self.ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z) - tmp_x, tmp_y, tmp_z = self.ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z) - tmp_x, tmp_y, tmp_z = self.ROTATEZ(kappa, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = rotate_z(-phiJ_sf, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = rotate_y(-thetaJ_sf, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = rotate_z(kappa, tmp_x, tmp_y, tmp_z) alpha0 = torch.arctan2(tmp_y, tmp_x) # Finally we determine thetaJ, by rotating N tmp_x, tmp_y, tmp_z = Nx_sf, Ny_sf, Nz_sf - tmp_x, tmp_y, tmp_z = self.ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z) - tmp_x, tmp_y, tmp_z = self.ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z) - tmp_x, tmp_y, tmp_z = self.ROTATEZ(kappa, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = rotate_z(-phiJ_sf, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = rotate_y(-thetaJ_sf, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = rotate_z(kappa, tmp_x, tmp_y, tmp_z) Nx_Jf, Nz_Jf = tmp_x, tmp_z thetaJN = torch.arccos(Nz_Jf) @@ -528,13 +522,13 @@ def convert_spins( # Both triads differ from each other by a rotation around N by an angle # \zeta and we need to rotate the polarizations accordingly by 2\zeta - Xx_sf = -torch.cos(incl) * torch.sin(phiRef) - Xy_sf = -torch.cos(incl) * torch.cos(phiRef) - Xz_sf = torch.sin(incl) + Xx_sf = -torch.cos(inclination) * torch.sin(phic) + Xy_sf = -torch.cos(inclination) * torch.cos(phic) + Xz_sf = torch.sin(inclination) tmp_x, tmp_y, tmp_z = Xx_sf, Xy_sf, Xz_sf - tmp_x, tmp_y, tmp_z = self.ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z) - tmp_x, tmp_y, tmp_z = self.ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z) - tmp_x, tmp_y, tmp_z = self.ROTATEZ(kappa, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = rotate_z(-phiJ_sf, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = rotate_y(-thetaJ_sf, tmp_x, tmp_y, tmp_z) + tmp_x, tmp_y, tmp_z = rotate_z(kappa, tmp_x, tmp_y, tmp_z) # Now the tmp_a are the components of X in the J frame # We need the polar angle of that 
vector in the P,Q basis of Arun et al diff --git a/ml4gw/waveforms/taylorf2.py b/ml4gw/waveforms/cbc/taylorf2.py similarity index 100% rename from ml4gw/waveforms/taylorf2.py rename to ml4gw/waveforms/cbc/taylorf2.py diff --git a/ml4gw/waveforms/conversion.py b/ml4gw/waveforms/conversion.py new file mode 100644 index 0000000..7c36234 --- /dev/null +++ b/ml4gw/waveforms/conversion.py @@ -0,0 +1,187 @@ +import torch + +from ml4gw.constants import MTSUN_SI, PI +from ml4gw.types import BatchTensor + + +def rotate_z(angle: BatchTensor, x, y, z): + x_tmp = x * torch.cos(angle) - y * torch.sin(angle) + y_tmp = x * torch.sin(angle) + y * torch.cos(angle) + return x_tmp, y_tmp, z + + +def rotate_y(angle, x, y, z): + x_tmp = x * torch.cos(angle) + z * torch.sin(angle) + z_tmp = -x * torch.sin(angle) + z * torch.cos(angle) + return x_tmp, y, z_tmp + + +def XLALSimInspiralLN( + total_mass: BatchTensor, eta: BatchTensor, v: BatchTensor +): + """ + See https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiralPNCoefficients.c#L2173 # noqa + """ + return total_mass**2 * eta / v + + +def XLALSimInspiralL_2PN(eta: BatchTensor): + """ + See https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiralPNCoefficients.c#L2181 # noqa + """ + return 1.5 + eta / 6.0 + + +def bilby_spins_to_lalsim( + theta_jn: BatchTensor, + phi_jl: BatchTensor, + tilt_1: BatchTensor, + tilt_2: BatchTensor, + phi_12: BatchTensor, + a_1: BatchTensor, + a_2: BatchTensor, + mass_1: BatchTensor, + mass_2: BatchTensor, + f_ref: float, + phi_ref: BatchTensor, +): + """ + Converts between bilby spin and lalsimulation spin conventions. + + See https://github.com/bilby-dev/bilby/blob/cccdf891e82d46319e69dbfdf48c4970b4e9a727/bilby/gw/conversion.py#L105 # noqa + and https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiral.c#L3594 # noqa + + Args: + theta_jn: BatchTensor, + phi_jl: BatchTensor, + tilt_1: BatchTensor, + tilt_2: BatchTensor, + phi_12: BatchTensor, + a_1: BatchTensor, + a_2: BatchTensor, + mass_1: BatchTensor, + mass_2: BatchTensor, + f_ref: float, + phi_ref: BatchTensor, + """ + + # check if f_ref is valid + if f_ref <= 0.0: + raise ValueError( + "f_ref <= 0 is invalid. " + "Please pass in the starting GW frequency instead." + ) + + # starting frame: LNhat is along the z-axis and the unit + # spin vectors are defined from the angles relative to LNhat. + # Note that we put s1hat in the x-z plane, and phi12 + # sets the azimuthal angle of s2hat measured from the x-axis. + lnh_x = 0 + lnh_y = 0 + lnh_z = 1 + # Spins are given wrt to L, + # but still we cannot fill the spin as we do not know + # what will be the relative orientation of L and N. + # Note that these spin components are NOT wrt to binary + # separation vector, but wrt to binary separation vector + # at phiref=0. 
+ + s1hatx = torch.sin(tilt_1) * torch.cos(phi_ref) + s1haty = torch.sin(tilt_1) * torch.sin(phi_ref) + s1hatz = torch.cos(tilt_1) + s2hatx = torch.sin(tilt_2) * torch.cos(phi_12 + phi_ref) + s2haty = torch.sin(tilt_2) * torch.sin(phi_12 + phi_ref) + s2hatz = torch.cos(tilt_2) + + total_mass = mass_1 + mass_2 + + eta = mass_1 * mass_2 / (mass_1 + mass_2) / (mass_1 + mass_2) + + # v parameter at reference point + v0 = ((mass_1 + mass_2) * MTSUN_SI * PI * f_ref) ** (1 / 3) + + # Define S1, S2, J with proper magnitudes */ + + l_mag = XLALSimInspiralLN(total_mass, eta, v0) * ( + 1.0 + v0 * v0 * XLALSimInspiralL_2PN(eta) + ) + s1x = mass_1 * mass_1 * a_1 * s1hatx + s1y = mass_1 * mass_1 * a_1 * s1haty + s1z = mass_1 * mass_1 * a_1 * s1hatz + s2x = mass_2 * mass_2 * a_2 * s2hatx + s2y = mass_2 * mass_2 * a_2 * s2haty + s2z = mass_2 * mass_2 * a_2 * s2hatz + Jx = s1x + s2x + Jy = s1y + s2y + Jz = l_mag + s1z + s2z + + # Normalize J to Jhat, find its angles in starting frame */ + Jnorm = torch.sqrt(Jx * Jx + Jy * Jy + Jz * Jz) + Jhatx = Jx / Jnorm + Jhaty = Jy / Jnorm + Jhatz = Jz / Jnorm + theta0 = torch.acos(Jhatz) + phi0 = torch.atan2(Jhaty, Jhatx) + + # Rotation 1: Rotate about z-axis by -phi0 to put Jhat in x-z plane + s1hatx, s1haty, s1hatz = rotate_z(-phi0, s1hatx, s1haty, s1hatz) + s2hatx, s2haty, s2hatz = rotate_z(-phi0, s2hatx, s2haty, s2hatz) + + # Rotation 2: Rotate about new y-axis by -theta0 + # to put Jhat along z-axis + + lnh_x, lnh_y, lnh_z = rotate_y(-theta0, lnh_x, lnh_y, lnh_z) + s1hatx, s1haty, s1hatz = rotate_y(-theta0, s1hatx, s1haty, s1hatz) + s2hatx, s2haty, s2hatz = rotate_y(-theta0, s2hatx, s2haty, s2hatz) + + # Rotation 3: Rotate about new z-axis by phiJL to put L at desired + # azimuth about J. Note that is currently in x-z plane towards -x + # (i.e. azimuth=pi). Hence we rotate about z by phiJL - LAL_PI + lnh_x, lnh_y, lnh_z = rotate_z(phi_jl - PI, lnh_x, lnh_y, lnh_z) + s1hatx, s1haty, s1hatz = rotate_z(phi_jl - PI, s1hatx, s1haty, s1hatz) + s2hatx, s2haty, s2hatz = rotate_z(phi_jl - PI, s2hatx, s2haty, s2hatz) + + # The cosinus of the angle between L and N is the scalar + # product of the two vectors. + # We do not need to perform additional rotation to compute it. + Nx = 0.0 + Ny = torch.sin(theta_jn) + Nz = torch.cos(theta_jn) + incl = torch.acos(Nx * lnh_x + Ny * lnh_y + Nz * lnh_z) + + # Rotation 4-5: Now J is along z and N in y-z plane, inclined from J + # by thetaJN and with >ve component along y. + # Now we bring L into the z axis to get spin components. + thetalj = torch.acos(lnh_z) + phil = torch.atan2(lnh_y, lnh_x) + + s1hatx, s1haty, s1hatz = rotate_z(-phil, s1hatx, s1haty, s1hatz) + s2hatx, s2haty, s2hatz = rotate_z(-phil, s2hatx, s2haty, s2hatz) + Nx, Ny, Nz = rotate_z(-phil, Nx, Ny, Nz) + + s1hatx, s1haty, s1hatz = rotate_y(-thetalj, s1hatx, s1haty, s1hatz) + s2hatx, s2haty, s2hatz = rotate_y(-thetalj, s2hatx, s2haty, s2hatz) + Nx, Ny, Nz = rotate_y(-thetalj, Nx, Ny, Nz) + + # Rotation 6: Now L is along z and we have to bring N + # in the y-z plane with >ve y components. 
+ + phiN = torch.atan2(Ny, Nx) + # Note the extra -phiRef here: + # output spins must be given wrt to two body separations + # which are rigidly rotated with spins + s1hatx, s1haty, s1hatz = rotate_z( + PI / 2.0 - phiN - phi_ref, s1hatx, s1haty, s1hatz + ) + s2hatx, s2haty, s2hatz = rotate_z( + PI / 2.0 - phiN - phi_ref, s2hatx, s2haty, s2hatz + ) + + s1x = s1hatx * a_1 + s1y = s1haty * a_1 + s1z = s1hatz * a_1 + s2x = s2hatx * a_2 + s2y = s2haty * a_2 + s2z = s2hatz * a_2 + + return incl, s1x, s1y, s1z, s2x, s2y, s2z diff --git a/poetry.lock b/poetry.lock index 21078f6..535748e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2759,8 +2759,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" diff --git a/pyproject.toml b/pyproject.toml index b12507e..8c5a379 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "ml4gw" -version = "0.5.1" +version = "0.6.0" description = "Tools for training torch models on gravitational wave data" readme = "README.md" authors = [ diff --git a/tests/transforms/test_qtransform.py b/tests/transforms/test_qtransform.py index 4b43e0a..345f728 100644 --- a/tests/transforms/test_qtransform.py +++ b/tests/transforms/test_qtransform.py @@ -14,7 +14,7 @@ def duration(request): return request.param -@pytest.fixture(params=[2048, 4096]) +@pytest.fixture(params=[1024, 2048]) def sample_rate(request): return request.param @@ -29,7 +29,7 @@ def spectrogram_shape(request): return request.param -@pytest.fixture(params=[12, 100]) +@pytest.fixture(params=[12, 50]) def q(request): return request.param @@ -39,11 +39,16 @@ def mismatch(request): return request.param -@pytest.fixture(params=[128, 512]) +@pytest.fixture(params=[128, 256]) def frequency(request): return request.param +@pytest.fixture(params=["bilinear", "bicubic", "spline"]) +def interpolation_method(request): + return request.param + + def test_qtile( q, frequency, @@ -85,11 +90,34 @@ def test_singleqtransform( mismatch, norm, spectrogram_shape, + interpolation_method, ): X = torch.randn(int(duration * sample_rate)) fseries = torch.fft.rfft(X, norm="forward") fseries[..., 1:] *= 2 + with pytest.raises(ValueError): + qtransform = SingleQTransform( + duration, + sample_rate, + spectrogram_shape, + q, + frange=[0, torch.inf], + mismatch=mismatch, + interpolation_method="nonsense", + ) + + with pytest.raises(ValueError): + qtransform = SingleQTransform( + duration, + sample_rate, + spectrogram_shape, + q=1000, + frange=[0, torch.inf], + mismatch=mismatch, + interpolation_method="nonsense", + ) + qtransform = SingleQTransform( duration, sample_rate, @@ -97,13 +125,14 @@ def test_singleqtransform( q, frange=[0, torch.inf], mismatch=mismatch, + interpolation_method=interpolation_method, ) with pytest.raises(RuntimeError): qtransform.get_max_energy() with pytest.raises(RuntimeError): - qtransform.interpolate(*spectrogram_shape) + qtransform.interpolate() qplane = QPlane( q, @@ -138,10 +167,19 @@ def test_get_qs( qrange = [1, 1000] qscan = QScan( - duration, sample_rate, spectrogram_shape, qrange, frange, mismatch + duration, + sample_rate, + spectrogram_shape, + qrange, + frange, + mismatch=mismatch, ) qtiling = QTiling( duration, 
sample_rate, qrange, frange=[0, np.inf], mismatch=mismatch ) assert np.allclose(qscan.get_qs(), qtiling.qs) + + # Just check that the QScan runs + data = torch.randn(int(sample_rate * duration)) + _ = qscan(data) diff --git a/tests/transforms/test_scaler.py b/tests/transforms/test_scaler.py index 6a89bc6..a28e922 100644 --- a/tests/transforms/test_scaler.py +++ b/tests/transforms/test_scaler.py @@ -37,6 +37,16 @@ def test_scaler_1d(): assert torch.isclose(x, y, rtol=1e-6).all().item() +def test_scaler_regularization(): + scaler_non_reg = ChannelWiseScaler(num_channels=2) + scaler_reg = ChannelWiseScaler(num_channels=2) + X = torch.ones(2, 100).type(torch.float32) + scaler_non_reg.fit(X) + scaler_reg.fit(X, std_reg=1e-8) + assert torch.all(scaler_non_reg.std == 0) + assert torch.all(scaler_reg.std > 0) + + def test_scaler_2d(): num_channels = 4 scaler = ChannelWiseScaler(num_channels) diff --git a/tests/transforms/test_spline_interpolation.py b/tests/transforms/test_spline_interpolation.py new file mode 100644 index 0000000..5b99766 --- /dev/null +++ b/tests/transforms/test_spline_interpolation.py @@ -0,0 +1,101 @@ +import numpy as np +import pytest +import torch +from scipy.interpolate import RectBivariateSpline, UnivariateSpline +from torch import Tensor + +from ml4gw.transforms import SplineInterpolate + + +class TestSplineInterpolate: + @pytest.fixture(params=[50, 100, 200]) + def x_out_len(self, request): + return request.param + + @pytest.fixture(params=[25, 200, 1000]) + def y_out_len(self, request): + return request.param + + def test_1d_interpolation(self, x_out_len): + x_in = np.linspace(0, 10, 100) + data = np.sin(x_in) + # There are edge effects in the torch transform that + # aren't present in scipy. Would be great to solve that, + # but a workaround is to interpolate well within the + # boundaries of the input coordinates. Unfortunately, + # what specifically that means depends on the size of + # the input array. + pad = len(x_in) // 10 + x_out = np.linspace(x_in[pad], x_in[-pad], x_out_len) + + scipy_spline = UnivariateSpline(x_in, data, k=3, s=0) + expected = scipy_spline(x_out) + + torch_spline = SplineInterpolate( + x_in=Tensor(x_in), + x_out=Tensor(x_out), + kx=3, + ) + actual = torch_spline(Tensor(data)).squeeze().numpy() + + # The "steady-state" ratio between the torch and scipy + # interpolations is about 0.9990, with some minor fluctuations. + # Would be nice to know why the torch interpolation is + # consistently smaller + assert np.allclose(actual, expected, rtol=5e-3) + + def test_2d_interpolation(self, x_out_len, y_out_len): + x_in = np.linspace(0, 10, 100) + y_in = np.linspace(0, 5, 200) + x_grid, y_grid = np.meshgrid(x_in, y_in) + data = np.sin(x_grid) * np.cos(y_grid) + + pad = len(x_in) // 10 + x_out = np.linspace(x_in[pad], x_in[-pad], x_out_len) + pad = len(y_in) // 10 + y_out = np.linspace(y_in[pad], y_in[-pad], y_out_len) + + scipy_spline = RectBivariateSpline(x_in, y_in, data.T, kx=3, ky=3, s=0) + expected = scipy_spline(x_out, y_out).T + + torch_spline = SplineInterpolate( + x_in=Tensor(x_in), + x_out=Tensor(x_out), + y_in=Tensor(y_in), + y_out=Tensor(y_out), + kx=3, + ky=3, + ) + actual = torch_spline(Tensor(data)).squeeze().numpy() + + # The "steady-state" ratio between the torch and scipy + # interpolations is about 0.999, with some minor fluctuations. 
+ # Would be nice to know why the torch interpolation is + # consistently smaller + assert np.allclose(actual, expected, rtol=5e-3) + + def test_errors(self): + x_in = torch.arange(10) + x_out = x_in + torch_spline = SplineInterpolate(x_in) + data = torch.randn(len(x_in)) + with pytest.raises(ValueError) as exc: + torch_spline(data) + assert str(exc.value).startswith("Output x-coordinates were not") + + data = torch.randn((1, 2, 3, 4, 5)) + with pytest.raises(ValueError) as exc: + torch_spline(data, x_out=x_out) + assert str(exc.value).startswith("Input data has more than 4") + + y_in = torch.arange(10) + torch_spline = SplineInterpolate(x_in=x_in, y_in=y_in) + data = torch.randn(len(x_in)) + with pytest.raises(ValueError) as exc: + torch_spline(data, x_out=x_out) + assert str(exc.value).startswith("An input y-coordinate array") + + data = torch.randn((len(y_in) - 1, len(x_in) - 1)) + with pytest.raises(ValueError) as exc: + torch_spline(data, x_out=x_out) + assert str(exc.value).startswith("The spatial dimensions of the data") diff --git a/tests/waveforms/test_sine_gaussian.py b/tests/waveforms/adhoc/test_sine_gaussian.py similarity index 100% rename from tests/waveforms/test_sine_gaussian.py rename to tests/waveforms/adhoc/test_sine_gaussian.py diff --git a/tests/waveforms/test_cbc_waveforms.py b/tests/waveforms/cbc/test_cbc_waveforms.py similarity index 100% rename from tests/waveforms/test_cbc_waveforms.py rename to tests/waveforms/cbc/test_cbc_waveforms.py index a888108..2d28b04 100644 --- a/tests/waveforms/test_cbc_waveforms.py +++ b/tests/waveforms/cbc/test_cbc_waveforms.py @@ -318,10 +318,10 @@ def test_phenom_p(chirp_mass, mass_ratio, chi1z, chi2z, distance, sample_rate): batched_chi2y, batched_chi2z, batched_distance, - batched_tc, batched_phic, batched_inclination, f_ref, + batched_tc, ) assert hp_torch.shape[0] == 10 # entire batch is returned diff --git a/tests/waveforms/test_conversion.py b/tests/waveforms/test_conversion.py new file mode 100644 index 0000000..f4fc528 --- /dev/null +++ b/tests/waveforms/test_conversion.py @@ -0,0 +1,66 @@ +import numpy as np +import torch +from lalsimulation import SimInspiralTransformPrecessingNewInitialConditions +from torch.distributions import Uniform + +from ml4gw.constants import MSUN +from ml4gw.waveforms.conversion import bilby_spins_to_lalsim + + +def test_bilby_to_lalsim_spins(): + theta_jn = Uniform(0, torch.pi).sample((100,)) + phi_jl = Uniform(0, 2 * torch.pi).sample((100,)) + tilt_1 = Uniform(0, torch.pi).sample((100,)) + tilt_2 = Uniform(0, torch.pi).sample((100,)) + phi_12 = Uniform(0, 2 * torch.pi).sample((100,)) + a_1 = Uniform(0, 0.99).sample((100,)) + a_2 = Uniform(0, 0.99).sample((100,)) + mass_1 = Uniform(3, 100).sample((100,)) + mass_2 = Uniform(3, 100).sample((100,)) + f_ref = 40.0 + phi_ref = Uniform(0, torch.pi).sample((100,)) + incl, s1x, s1y, s1z, s2x, s2y, s2z = bilby_spins_to_lalsim( + theta_jn, + phi_jl, + tilt_1, + tilt_2, + phi_12, + a_1, + a_2, + mass_1, + mass_2, + f_ref, + phi_ref, + ) + for i in range(2): + + ( + lal_incl, + lal_s1x, + lal_s1y, + lal_s1z, + lal_s2x, + lal_s2y, + lal_s2z, + ) = SimInspiralTransformPrecessingNewInitialConditions( + theta_jn[i].item(), + phi_jl[i].item(), + tilt_1[i].item(), + tilt_2[i].item(), + phi_12[i].item(), + a_1[i].item(), + a_2[i].item(), + mass_1[i].item() * MSUN, + mass_2[i].item() * MSUN, + f_ref, + phi_ref[i].item(), + ) + + # check if the values are close up to 4 decimal places + assert np.isclose(incl[i].item(), lal_incl, atol=1e-4) + assert 
np.isclose(s1x[i].item(), lal_s1x, atol=1e-4) + assert np.isclose(s1y[i].item(), lal_s1y, atol=1e-4) + assert np.isclose(s1z[i].item(), lal_s1z, atol=1e-4) + assert np.isclose(s2x[i].item(), lal_s2x, atol=1e-4) + assert np.isclose(s2y[i].item(), lal_s2y, atol=1e-4) + assert np.isclose(s2z[i].item(), lal_s2z, atol=1e-4)
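
For reference, a minimal sketch of the spin-convention conversion introduced in `ml4gw/waveforms/conversion.py` above; the batch size and parameter ranges are arbitrary illustrative choices:

    import torch
    from ml4gw.waveforms.conversion import bilby_spins_to_lalsim

    n = 8  # arbitrary batch size
    theta_jn = torch.rand(n) * torch.pi
    phi_jl = torch.rand(n) * 2 * torch.pi
    tilt_1 = torch.rand(n) * torch.pi
    tilt_2 = torch.rand(n) * torch.pi
    phi_12 = torch.rand(n) * 2 * torch.pi
    a_1 = torch.rand(n) * 0.99
    a_2 = torch.rand(n) * 0.99
    mass_1 = 3 + torch.rand(n) * 97  # solar masses
    mass_2 = 3 + torch.rand(n) * 97
    phi_ref = torch.rand(n) * torch.pi
    f_ref = 40.0  # Hz

    # Inclination and Cartesian spin components in the convention used by
    # the precessing waveform models (e.g. IMRPhenomPv2)
    incl, s1x, s1y, s1z, s2x, s2y, s2z = bilby_spins_to_lalsim(
        theta_jn, phi_jl, tilt_1, tilt_2, phi_12,
        a_1, a_2, mass_1, mass_2, f_ref, phi_ref,
    )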