diff --git a/ml4gw/gw.py b/ml4gw/gw.py
index 3c146d8..cb1bcd9 100644
--- a/ml4gw/gw.py
+++ b/ml4gw/gw.py
@@ -10,7 +10,7 @@ https://github.com/lscsoft/bilby/blob/master/bilby/gw/detector/interferometer.py
 """
 
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torchtyping import TensorType
 
@@ -285,6 +285,133 @@ def get_ifo_geometry(
     return torch.Tensor(tensors), torch.Tensor(vertices)
 
 
+def snr_frequency_series(
+    template: PSDTensor,
+    df: float,
+    psd: Optional[PSDTensor] = None,
+    strain: Optional[PSDTensor] = None,
+    highpass: Union[float, TensorType["frequency"], None] = None,
+    scale: float = 1e20,
+) -> PSDTensor:
+    """
+    Returns SNR as a function of frequency
+
+    Args:
+        template:
+            Frequency-domain template whose last axis is frequency.
+        df:
+            Spacing between frequency bins, used both to build a
+            highpass mask and to scale the returned integrand.
+        psd:
+            Background power spectral density. If `None`, the
+            inputs are used as-is with no ASD normalization.
+        strain:
+            Frequency-domain data to match against `template`.
+            If `None`, the template is matched against itself.
+        highpass:
+            Either a cutoff frequency below which bins are zeroed,
+            or a precomputed mask over the frequency bins.
+        scale:
+            Factor whose square root multiplies each input before
+            they're combined; it is divided back out of the result.
+    Returns:
+        Real-valued float32 SNR integrand for each frequency bin.
+    Raises:
+        ValueError:
+            If a `highpass` mask is given whose length doesn't
+            match the number of frequency bins.
+    """
+
+    if strain is None:
+        other = template
+    else:
+        other = strain
+
+    # TODO: handle scale issues by 1) dividing
+    # each tensor by the ASD separately and 2)
+    # multiplying by the sqrt of the scale factor.
+    # This is less optimal than our previous implementation
+    # because we're now doing 2 element-wise divides by
+    # the ASD rather than 1, so worth figuring out if
+    # there's a better, more consistent way to make the
+    # scale work out. Also see TODO below
+    template = template * scale**0.5
+    other = other.conj() * scale**0.5
+
+    # psd is optional: skip the ASD normalization when it's
+    # omitted rather than failing on `None**0.5`
+    if psd is not None:
+        asd = psd**0.5
+        template = template / asd
+        other = other / asd
+    integrand = template * other
+
+    # convert to real before we divide the scale
+    # because something about the complex datatype
+    # is making this go to 0.
+    # TODO: we should probably be returning the full
+    # complex integrand and letting downstream functions
+    # convert to real where necessary, this way this
+    # can ultimately be used to take an ifft and compute
+    # an SNR timeseries. But this requires figuring out
+    # how to get the scaling right for the complex version
+    integrand = integrand.real / scale
+    integrand = integrand.type(torch.float32)
+
+    # mask out low frequency components if a critical
+    # frequency or frequency mask was provided
+    if highpass is not None:
+        if not isinstance(highpass, torch.Tensor):
+            freqs = torch.arange(template.size(-1)) * df
+            highpass = freqs >= highpass
+        elif len(highpass) != integrand.shape[-1]:
+            raise ValueError(
+                "Can't apply highpass filter mask with {} frequency bins "
+                "to signal fft with {} frequency bins".format(
+                    len(highpass), integrand.shape[-1]
+                )
+            )
+        integrand *= highpass.to(integrand)
+    return integrand * 4 * df
+
+
+def snr_frequency_series_from_timeseries(
+    template: WaveformTensor,
+    sample_rate: float,
+    psd: Optional[PSDTensor] = None,
+    strain: Optional[WaveformTensor] = None,
+    highpass: Union[float, TensorType["frequency"], None] = None,
+) -> PSDTensor:
+    """
+    Returns SNR as a function of frequency for
+    time-domain inputs by taking their real FFTs
+    along the last axis and deferring to
+    `snr_frequency_series`.
+    """
+    # TODO: should this and snr_frequency_series just
+    # be combined into a single function with an `input_domain`
+    # argument? Or maybe expose both `df` and `sample` rate
+    # arguments and then infer input domain from whichever
+    # is not `None`?
+    df = sample_rate / template.size(-1)
+
+    # TODO: should we do windowing here?
+    # compute frequency power, upsampling precision so that
+    # computing absolute value doesn't accidentally zero some
+    # values out.
+    htilde = torch.fft.rfft(template, dim=-1).type(torch.complex128)
+    if strain is not None:
+        stilde = torch.fft.rfft(strain, dim=-1).type(torch.complex128)
+    else:
+        stilde = None
+    integrand = snr_frequency_series(htilde, df, psd, stilde, highpass)
+
+    # factor of sample_rate**2 that should have been applied
+    # to each FFT separately, but doing it after the fact
+    # where the orders of magnitude are more manageable
+    return integrand / sample_rate**2
+
+
 def compute_ifo_snr(
     responses: WaveformTensor,
     psd: PSDTensor,
@@ -339,51 +466,10 @@ def compute_ifo_snr(
         Batch of SNRs computed for each interferometer
     """
 
-    # TODO: should we do windowing here?
-    # compute frequency power, upsampling precision so that
-    # computing absolute value doesn't accidentally zero some
-    # values out.
-    fft = torch.fft.rfft(responses, axis=-1).type(torch.complex128)
-    fft = fft.abs() / sample_rate
-
-    # divide by background asd, then go back to FP32 precision
-    # and square now that values are back in a reasonable range
-    integrand = fft / (psd**0.5)
-    integrand = integrand.type(torch.float32) ** 2
-
-    # mask out low frequency components if a critical
-    # frequency or frequency mask was provided
-    if highpass is not None:
-        if not isinstance(highpass, torch.Tensor):
-            freqs = torch.fft.rfftfreq(responses.shape[-1], 1 / sample_rate)
-            highpass = freqs >= highpass
-        elif len(highpass) != integrand.shape[-1]:
-            raise ValueError(
-                "Can't apply highpass filter mask with {} frequecy bins"
-                "to signal fft with {} frequency bins".format(
-                    len(highpass), integrand.shape[-1]
-                )
-            )
-        integrand *= highpass.to(integrand.device)
-
-    # sum over the desired frequency range and multiply
-    # by df to turn it into an integration (and get
-    # our units to drop out)
-    # TODO: we could in principle do this without requiring
-    # that the user specify the sample rate by taking the
-    # fft as-is (without dividing by sample rate) and then
-    # taking the mean here (or taking the sum and dividing
-    # by the sum of `highpass` if it's a mask). If we want
-    # to allow the user to pass a float for highpass, we'll
-    # need the sample rate to compute the mask, but if we
-    # replace this with a `mask` argument instead we're in
-    # the clear
-    df = sample_rate / responses.shape[-1]
-    integrated = integrand.sum(axis=-1) * df
-
-    # multiply by 4 for mystical reasons
-    integrated = 4 * integrated  # rho-squared
-    return torch.sqrt(integrated)
+    integrand = snr_frequency_series_from_timeseries(
+        responses, sample_rate, psd, highpass=highpass
+    )
+    return integrand.sum(-1) ** 0.5
 
 
 def compute_network_snr(
diff --git a/ml4gw/transforms/__init__.py b/ml4gw/transforms/__init__.py
index 8b014d5..c4fdb50 100644
--- a/ml4gw/transforms/__init__.py
+++ b/ml4gw/transforms/__init__.py
@@ -1,3 +1,4 @@
+from .chisq import ChiSq
 from .scaler import ChannelWiseScaler
 from .snr_rescaler import SnrRescaler
 from .spectral import SpectralDensity
diff --git a/ml4gw/transforms/chisq.py b/ml4gw/transforms/chisq.py
new file mode 100644
index 0000000..f7b18fc
--- /dev/null
+++ b/ml4gw/transforms/chisq.py
@@ -0,0 +1,225 @@
+from typing import Literal, Optional
+
+import torch
+
+from ml4gw.gw import snr_frequency_series
+
+
+class ChiSq(torch.nn.Module):
+    """
+    Compare how the SNR between a template and strain is
+    distributed across frequency bins against the even split
+    expected from the template itself.
+
+    Args:
+        num_bins:
+            Number of frequency bins to split the template's
+            SNR into. Must be at least 2.
+        fftlength:
+            Duration of time-domain inputs, which are expected
+            to contain `fftlength * sample_rate` samples.
+        sample_rate:
+            Rate at which time-domain inputs are sampled.
+        highpass:
+            If given, frequency bins below this value are dropped.
+        return_snr:
+            If `True`, `forward` also returns the total SNR.
+        input_domain:
+            `"time"` to FFT the inputs in `forward`; any other
+            value treats them as already frequency-domain.
+    """
+
+    def __init__(
+        self,
+        num_bins: int,
+        fftlength: float,
+        sample_rate: float,
+        highpass: Optional[float] = None,
+        return_snr: bool = False,
+        input_domain: Literal["time", "frequency"] = "time",
+    ) -> None:
+        super().__init__()
+        # forward normalizes by num_bins - 1 degrees of freedom
+        if num_bins < 2:
+            raise ValueError(
+                "ChiSq requires at least 2 bins, got {}".format(num_bins)
+            )
+        self.sample_rate = sample_rate
+
+        # include extra bin so that we have all left and right edges
+        self.num_bins = num_bins
+        bins = torch.arange(num_bins + 1) / num_bins
+        self.register_buffer("bins", bins)
+
+        self.fftsize = int(fftlength * sample_rate)
+        self.num_freqs = int(fftlength * sample_rate // 2 + 1)
+        freqs = torch.arange(self.num_freqs) / fftlength
+        self.register_buffer("freqs", freqs)
+
+        if highpass is not None:
+            mask = freqs >= highpass
+            self.register_buffer("mask", mask)
+        else:
+            self.mask = None
+        self.return_snr = return_snr
+        self.input_domain = input_domain
+
+    def get_cumulative_snr(self, htilde, psd=None, stilde=None):
+        """
+        Compute the cumulative integral of the SNR frequency
+        series along the frequency dimension.
+        """
+
+        # frequency resolution is the spacing between rfft
+        # bins, not the sample rate
+        df = self.sample_rate / self.fftsize
+        snr = snr_frequency_series(htilde, df, psd, stilde)
+        return snr.cumsum(dim=-1)
+
+    def make_indices(self, batch_size, num_channels):
+        """
+        Helper function for selecting arbitrary indices
+        along the last axis of our batches by building
+        tensors of repeated index selectors for the
+        batch and channel axes.
+        """
+        idx0 = torch.arange(batch_size)
+        idx0 = idx0.view(-1, 1, 1).repeat(1, num_channels, self.num_bins)
+
+        idx1 = torch.arange(num_channels)
+        idx1 = idx1.view(1, -1, 1).repeat(batch_size, 1, self.num_bins)
+        return idx0, idx1
+
+    def get_snr_per_bin(self, qtilde, stilde, edges, psd=None):
+        """
+        For a normalized frequency template qtilde and
+        frequency-domain strain measurement stilde, measure
+        the SNR in the bins between the specified edges
+        (whose last dimension should be one greater than the
+        number of bins).
+        """
+
+        # calculate how much SNR _actually_ ended up in each bin
+        cumulative_snr = self.get_cumulative_snr(qtilde, psd, stilde)
+
+        # since we have the cumulative SNR, all we need to
+        # do is grab the value at the left and right bin
+        # edges and then subtract them to get the sum
+        # in between them
+        batch_size, num_channels, _ = cumulative_snr.shape
+        idx0, idx1 = self.make_indices(batch_size, num_channels)
+
+        right = cumulative_snr[idx0, idx1, edges[:, :, 1:]]
+        left = cumulative_snr[idx0, idx1, edges[:, :, :-1]]
+
+        # need the actual total SNR to see how much
+        # we deviated from the expected breakdown
+        total_snr = cumulative_snr[:, :, -1:]
+        return right - left, total_snr
+
+    def partition_frequencies(self, htilde, psd=None):
+        """
+        Compute the edges of the frequency bins that would
+        (roughly) evenly break up the optimal SNR of the
+        template. Normalize the template by its maximum
+        SNR as illustrated in TODO: cite
+        """
+        # compute the cumulative SNR of our template
+        # wrt the background PSD as a function of frequency
+        cumulative_snr = self.get_cumulative_snr(htilde, psd)
+
+        # break the total SNR up into even bins
+        total_snr = cumulative_snr[:, :, -1:]
+        bins = self.bins * total_snr
+
+        # figure out which indices along the frequency axis
+        # break up the SNR as closely into these bins as possible
+        edges = torch.searchsorted(cumulative_snr, bins, side="right")
+        edges = edges.clamp(0, cumulative_snr.size(-1) - 1)
+
+        # normalize by the sqrt of the total SNR
+        qtilde = htilde / total_snr**0.5
+        return qtilde, edges
+
+    def interpolate_psd(self, psd):
+        """
+        Linearly interpolate `psd` along its last axis
+        to `self.num_freqs` frequency bins.
+        """
+        # have to scale the interpolated psd to ensure
+        # that the integral of the power remains constant
+        factor = (psd.size(-1) / self.num_freqs) ** 2
+        psd = torch.nn.functional.interpolate(
+            psd, size=self.num_freqs, mode="linear"
+        )
+        return psd * factor
+
+    def _check_time_domain(self, template, strain):
+        """
+        Raise a `ValueError` if either timeseries doesn't
+        have `self.fftsize` samples along its last axis.
+        """
+        bad_tensors, bad_shapes = [], []
+        if template.size(-1) != self.fftsize:
+            bad_tensors.append("template")
+            # stringify so the shapes can be joined into the message
+            bad_shapes.append(str(tuple(template.shape)))
+        if strain.size(-1) != self.fftsize:
+            bad_tensors.append("strain")
+            bad_shapes.append(str(tuple(strain.shape)))
+
+        if bad_tensors:
+            verb = "has" if len(bad_tensors) == 1 else "have"
+            bad_tensors = " and ".join(bad_tensors)
+            raise ValueError(
+                "Both template and strain timeseries are "
+                "expected to have time dimension of size {}, "
+                "but {} {} shape(s) {}".format(
+                    self.fftsize, bad_tensors, verb, " and ".join(bad_shapes)
+                )
+            )
+
+    def forward(
+        self,
+        template: torch.Tensor,
+        strain: torch.Tensor,
+        psd: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Compute the chi-squared statistic between `template`
+        and `strain` for each batch element and channel.
+
+        Make PSD optional in case strain has already been whitened
+        """
+        if psd is not None:
+            if psd.size(-1) != self.num_freqs:
+                psd = self.interpolate_psd(psd)
+            if self.mask is not None:
+                psd = psd[:, :, self.mask]
+
+        if self.input_domain == "time":
+            self._check_time_domain(template, strain)
+            htilde = torch.fft.rfft(template, dim=-1) / self.sample_rate
+            stilde = torch.fft.rfft(strain, dim=-1) / self.sample_rate
+        else:
+            htilde, stilde = template, strain
+        if self.mask is not None:
+            htilde = htilde[:, :, self.mask]
+            stilde = stilde[:, :, self.mask]
+
+        qtilde, edges = self.partition_frequencies(htilde, psd)
+        snr_per_bin, total_snr = self.get_snr_per_bin(
+            qtilde, stilde, edges, psd
+        )
+
+        # for each frequency bin, compute the square of the
+        # deviation from the expected amount of SNR in the bin
+        # and then sum it over all the bins
+        chisq_summand = (snr_per_bin - total_snr / self.num_bins) ** 2
+        chisq = chisq_summand.sum(dim=-1)
+
+        # normalize by number of degrees of freedom
+        chisq *= self.num_bins / (self.num_bins - 1)
+        if self.return_snr:
+            return chisq, total_snr
+        return chisq
diff --git a/tests/transforms/test_chisq.py b/tests/transforms/test_chisq.py
new file mode 100644
index 0000000..ba814bc
--- /dev/null
+++ b/tests/transforms/test_chisq.py
@@ -0,0 +1,26 @@
+import torch
+
+from ml4gw.transforms import ChiSq, SpectralDensity
+
+
+def test_chisq():
+    scale = 10 ** (-19)
+    background = scale * torch.randn(4, 2, 2048 * 32)
+    strain = scale * torch.randn(4, 2, 2048 * 4)
+
+    t = torch.arange(2048 * 4) / 2048
+    freq = 10 + t**3
+    amp = scale * (0.1 + t**3 / 64)
+    signal = amp * torch.sin(2 * torch.pi * freq * t)
+    signal = signal.view(1, 1, -1).repeat(4, 2, 1)
+    injected = strain + signal
+
+    spec = SpectralDensity(
+        sample_rate=2048, fftlength=4, overlap=2, average="median"
+    )
+    psd = spec(background)
+
+    transform = ChiSq(num_bins=8, fftlength=4, sample_rate=2048, highpass=10)
+    chisq = transform(signal, injected, psd)
+    # one chi-squared statistic per (batch, channel) pair
+    assert chisq.shape == (4, 2)