From 8b2ec95c627b742e0272588b43a4117a14a2f5d0 Mon Sep 17 00:00:00 2001 From: Alec Gunny Date: Wed, 8 Nov 2023 16:43:38 -0800 Subject: [PATCH 1/4] adding initial chisq implementation --- ml4gw/gw.py | 115 ++++++++++++--------- ml4gw/transforms/__init__.py | 1 + ml4gw/transforms/chisq.py | 176 +++++++++++++++++++++++++++++++++ tests/transforms/test_chisq.py | 29 ++++++ 4 files changed, 275 insertions(+), 46 deletions(-) create mode 100644 ml4gw/transforms/chisq.py create mode 100644 tests/transforms/test_chisq.py diff --git a/ml4gw/gw.py b/ml4gw/gw.py index 3c146d81..0f5d4ef8 100644 --- a/ml4gw/gw.py +++ b/ml4gw/gw.py @@ -10,7 +10,7 @@ https://github.com/lscsoft/bilby/blob/master/bilby/gw/detector/interferometer.py """ -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torchtyping import TensorType @@ -285,6 +285,72 @@ def get_ifo_geometry( return torch.Tensor(tensors), torch.Tensor(vertices) +def snr_from_freqs( + template: PSDTensor, + df: float, + psd: Optional[PSDTensor] = None, + strain: Optional[PSDTensor] = None, + highpass: Union[float, TensorType["frequency"], None] = None, + scale: float = 1e20 +) -> PSDTensor: + """ + Returns SNR as a function of frequency + """ + + if strain is None: + other = template + else: + other = strain + asd = psd**0.5 + template = template * scale**0.5 / asd + other = other.conj() * scale**0.5 / asd + integrand = template * other + integrand = integrand.real / scale + integrand = integrand.type(torch.float32) + + # mask out low frequency components if a critical + # frequency or frequency mask was provided + if highpass is not None: + if not isinstance(highpass, torch.Tensor): + freqs = torch.arange(template.size(-1)) * df + highpass = freqs >= highpass + elif len(highpass) != integrand.shape[-1]: + raise ValueError( + "Can't apply highpass filter mask with {} frequecy bins" + "to signal fft with {} frequency bins".format( + len(highpass), integrand.shape[-1] + ) + ) + 
integrand *= highpass.to(integrand) + return integrand * 4 * df + + +def snr_integral( + template: WaveformTensor, + sample_rate: float, + psd: Optional[PSDTensor] = None, + strain: Optional[WaveformTensor] = None, + highpass: Union[float, TensorType["frequency"], None] = None, +) -> PSDTensor: + df = sample_rate / template.size(-1) + + # TODO: should we do windowing here? + # compute frequency power, upsampling precision so that + # computing absolute value doesn't accidentally zero some + # values out. + htilde = torch.fft.rfft(template, axis=-1).type(torch.complex128) + if strain is not None: + stilde = torch.fft.rfft(strain, axis=-1).type(torch.complex128) + else: + stilde = None + integrand = snr_from_freqs(htilde, df, psd, stilde, highpass) + + # factor of sample_rate**2 that should have been applied + # to each FFT separately, but doing it after the fact + # where the orders of magnitude are more manageable + return integrand / sample_rate**2 + + def compute_ifo_snr( responses: WaveformTensor, psd: PSDTensor, @@ -339,51 +405,8 @@ def compute_ifo_snr( Batch of SNRs computed for each interferometer """ - # TODO: should we do windowing here? - # compute frequency power, upsampling precision so that - # computing absolute value doesn't accidentally zero some - # values out. 
- fft = torch.fft.rfft(responses, axis=-1).type(torch.complex128) - fft = fft.abs() / sample_rate - - # divide by background asd, then go back to FP32 precision - # and square now that values are back in a reasonable range - integrand = fft / (psd**0.5) - integrand = integrand.type(torch.float32) ** 2 - - # mask out low frequency components if a critical - # frequency or frequency mask was provided - if highpass is not None: - if not isinstance(highpass, torch.Tensor): - freqs = torch.fft.rfftfreq(responses.shape[-1], 1 / sample_rate) - highpass = freqs >= highpass - elif len(highpass) != integrand.shape[-1]: - raise ValueError( - "Can't apply highpass filter mask with {} frequecy bins" - "to signal fft with {} frequency bins".format( - len(highpass), integrand.shape[-1] - ) - ) - integrand *= highpass.to(integrand.device) - - # sum over the desired frequency range and multiply - # by df to turn it into an integration (and get - # our units to drop out) - # TODO: we could in principle do this without requiring - # that the user specify the sample rate by taking the - # fft as-is (without dividing by sample rate) and then - # taking the mean here (or taking the sum and dividing - # by the sum of `highpass` if it's a mask). 
If we want - # to allow the user to pass a float for highpass, we'll - # need the sample rate to compute the mask, but if we - # replace this with a `mask` argument instead we're in - # the clear - df = sample_rate / responses.shape[-1] - integrated = integrand.sum(axis=-1) * df - - # multiply by 4 for mystical reasons - integrated = 4 * integrated # rho-squared - return torch.sqrt(integrated) + integrand = snr_integral(responses, sample_rate, psd, highpass=highpass) + return integrand.sum(-1)**0.5 def compute_network_snr( diff --git a/ml4gw/transforms/__init__.py b/ml4gw/transforms/__init__.py index 8b014d5b..c4fdb507 100644 --- a/ml4gw/transforms/__init__.py +++ b/ml4gw/transforms/__init__.py @@ -1,3 +1,4 @@ +from .chisq import ChiSq from .scaler import ChannelWiseScaler from .snr_rescaler import SnrRescaler from .spectral import SpectralDensity diff --git a/ml4gw/transforms/chisq.py b/ml4gw/transforms/chisq.py new file mode 100644 index 00000000..bf9f2ae4 --- /dev/null +++ b/ml4gw/transforms/chisq.py @@ -0,0 +1,176 @@ +from typing import Literal, Optional + +import torch + +from ml4gw.gw import snr_from_freqs + + +class ChiSq(torch.nn.Module): + def __init__( + self, + num_bins: int, + fftlength: float, + sample_rate: float, + highpass: Optional[float] = None, + return_snr: bool = False, + input_domain: Literal["time", "frequncy"] = "time" + ) -> None: + super().__init__() + self.sample_rate = sample_rate + + # include extra bin so that we have all left and right edges + self.num_bins = num_bins + bins = torch.arange(num_bins + 1) / num_bins + self.register_buffer("bins", bins) + + self.fftsize = int(fftlength * sample_rate) + self.num_freqs = int(fftlength * sample_rate // 2 + 1) + freqs = torch.arange(self.num_freqs) / fftlength + self.register_buffer("freqs", freqs) + + if highpass is not None: + mask = freqs >= highpass + self.register_buffer("mask", mask) + else: + self.mask = None + self.return_snr = return_snr + self.input_domain = input_domain + + def 
get_cumulative_snr(self, htilde, psd=None, stilde=None): + snr = snr_from_freqs(htilde, self.sample_rate, psd, stilde) + return snr.cumsum(dim=-1) + + def make_indices(self, batch_size, num_channels): + """ + Helper function for selecting arbitrary indices + along the last axis of our batches by building + tensors of repeated index selectors for the + batch and channel axes. + """ + idx0 = torch.arange(batch_size) + idx0 = idx0.view(-1, 1, 1).repeat(1, num_channels, self.num_bins) + + idx1 = torch.arange(num_channels) + idx1 = idx1.view(1, -1, 1).repeat(batch_size, 1, self.num_bins) + return idx0, idx1 + + def get_snr_per_bin(self, qtilde, stilde, edges, psd=None): + """ + For a normalized frequency template qtilde and + frequency-domain strain measurement stilde, measure + the SNR in the bins between the specified edges + (whose last dimension should be one greater than the + number of bins). + """ + + # calculate how much SNR _actually_ ended up in each bin + cumulative_snr = self.get_cumulative_snr(qtilde, psd, stilde) + + # since we have the cumulative SNR, all we need to + # do is grab the value at the left and right bin + # edges and then subtract them to get the sum + # in between them + batch_size, num_channels, _ = cumulative_snr.shape + idx0, idx1 = self.make_indices(batch_size, num_channels) + + right = cumulative_snr[idx0, idx1, edges[:, :, 1:]] + left = cumulative_snr[idx0, idx1, edges[:, :, :-1]] + + # need the actual total SNR to see how much + # we deviated from the expected breakdown + total_snr = cumulative_snr[:, :, -1:] + return right - left, total_snr + + def partition_frequencies(self, htilde, psd=None): + """ + Compute the edges of the frequency bins that would + (roughly) evenly break up the optimal SNR of the + template. 
Normalize the template by its maximum + SNR as illustrated in TODO: cite + """ + # compute the cumulative SNR of our template + # wrt the background PSD as a function of frequency + snr_integral = self.get_cumulative_snr(htilde, psd) + + # break the total SNR up into even bins + total_snr = snr_integral[:, :, -1:] + bins = self.bins * total_snr + + # figure out which indices along the frequency axis + # break up the SNR as closely into these bins as possible + edges = torch.searchsorted(snr_integral, bins, side="right") + edges = edges.clamp(0, snr_integral.size(-1) - 1) + + # normalize by the sqrt of the total SNR + qtilde = htilde / total_snr**0.5 + return qtilde, edges + + def interpolate_psd(self, psd): + # have to scale the interpolated psd to ensure + # that the integral of the power remains constant + factor = (psd.size(-1) / self.num_freqs)**2 + psd = torch.nn.functional.interpolate( + psd, size=self.num_freqs, mode="linear" + ) + return psd * factor + + def _check_time_domain(self, template, strain): + bad_tensors, bad_shapes = [], [] + if template.size(-1) != self.fftsize: + bad_tensors.append("template") + bad_shapes.append(template.shape) + if strain.size(-1) != self.fftsize: + bad_tensors.append("strain") + bad_shapes.append(strain.shape) + + if bad_tensors: + verb = "has" if len(bad_tensors) == 1 else "have" + bad_tensors = " and ".join(bad_tensors) + raise ValueError( + "Both template and strain timeseries are " + "expected to have time dimension of size {}, " + "but {} {} shape(s) {}".format( + self.fftsize, bad_tensors, verb, ",".join(bad_shapes) + ) + ) + + def forward( + self, + template: torch.Tensor, + strain: torch.Tensor, + psd: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Make PSD optional in case strain has already been whitened + """ + if psd is not None: + if psd.size(-1) != self.num_freqs: + psd = self.interpolate_psd(psd) + if self.mask is not None: + psd = psd[:, :, self.mask] + + if self.input_domain == "time": + 
self._check_time_domain(template, strain) + htilde = torch.fft.rfft(template, dim=-1) / self.sample_rate + stilde = torch.fft.rfft(strain, dim=-1) / self.sample_rate + else: + htilde, stilde = template, strain + if self.mask is not None: + htilde = htilde[:, :, self.mask] + stilde = stilde[:, :, self.mask] + + qtilde, edges = self.partition_frequencies(htilde, psd) + snr_per_bin, total_snr = self.get_snr_per_bin(qtilde, stilde, edges, psd) + + # for each frequency bin, compute the square of the + # deviation from the expected amount of SNR in the bin + # and then sum it over all the bins + chisq_summand = (snr_per_bin - total_snr / self.num_bins)**2 + chisq = chisq_summand.sum(dim=-1) + + # normalize by number of degrees of freedom + chisq *= self.num_bins / (self.num_bins - 1) + if self.return_snr: + return chisq, total_snr + return chisq + diff --git a/tests/transforms/test_chisq.py b/tests/transforms/test_chisq.py new file mode 100644 index 00000000..2b55176a --- /dev/null +++ b/tests/transforms/test_chisq.py @@ -0,0 +1,29 @@ +import torch + +from ml4gw.transforms import ChiSq, SpectralDensity + + +def test_chisq(): + scale = 10**(-19) + background = scale * torch.randn(4, 2, 2048 * 32) + strain = scale * torch.randn(4, 2, 2048 * 4) + + t = torch.arange(2048 * 4) / 2048 + freq = 10 + t**3 + amp = scale * (0.1 + t**3/64) + signal = amp * torch.sin(2 * torch.pi * freq * t) + signal = signal.view(1, 1, -1).repeat(4, 2, 1) + injected = strain + signal + + spec = SpectralDensity( + sample_rate=2048, + fftlength=4, + overlap=2, + average="median" + ) + psd = spec(background) + + transform = ChiSq(num_bins=8, fftlength=4, sample_rate=2048, highpass=10) + chisq = transform(signal, injected, psd) + print(chisq) + raise ValueError From f546f9c979a1deb8097d29ca63dbb2e4c0d052b3 Mon Sep 17 00:00:00 2001 From: Alec Gunny Date: Wed, 8 Nov 2023 16:44:05 -0800 Subject: [PATCH 2/4] running pre-commmit checks --- ml4gw/gw.py | 4 ++-- ml4gw/transforms/chisq.py | 15 ++++++++------- 
tests/transforms/test_chisq.py | 9 +++------ 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/ml4gw/gw.py b/ml4gw/gw.py index 0f5d4ef8..76c22ba8 100644 --- a/ml4gw/gw.py +++ b/ml4gw/gw.py @@ -291,7 +291,7 @@ def snr_from_freqs( psd: Optional[PSDTensor] = None, strain: Optional[PSDTensor] = None, highpass: Union[float, TensorType["frequency"], None] = None, - scale: float = 1e20 + scale: float = 1e20, ) -> PSDTensor: """ Returns SNR as a function of frequency @@ -406,7 +406,7 @@ def compute_ifo_snr( """ integrand = snr_integral(responses, sample_rate, psd, highpass=highpass) - return integrand.sum(-1)**0.5 + return integrand.sum(-1) ** 0.5 def compute_network_snr( diff --git a/ml4gw/transforms/chisq.py b/ml4gw/transforms/chisq.py index bf9f2ae4..0491e92e 100644 --- a/ml4gw/transforms/chisq.py +++ b/ml4gw/transforms/chisq.py @@ -13,7 +13,7 @@ def __init__( sample_rate: float, highpass: Optional[float] = None, return_snr: bool = False, - input_domain: Literal["time", "frequncy"] = "time" + input_domain: Literal["time", "frequncy"] = "time", ) -> None: super().__init__() self.sample_rate = sample_rate @@ -108,7 +108,7 @@ def partition_frequencies(self, htilde, psd=None): def interpolate_psd(self, psd): # have to scale the interpolated psd to ensure # that the integral of the power remains constant - factor = (psd.size(-1) / self.num_freqs)**2 + factor = (psd.size(-1) / self.num_freqs) ** 2 psd = torch.nn.functional.interpolate( psd, size=self.num_freqs, mode="linear" ) @@ -135,10 +135,10 @@ def _check_time_domain(self, template, strain): ) def forward( - self, + self, template: torch.Tensor, strain: torch.Tensor, - psd: Optional[torch.Tensor] = None + psd: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Make PSD optional in case strain has already been whitened @@ -160,12 +160,14 @@ def forward( stilde = stilde[:, :, self.mask] qtilde, edges = self.partition_frequencies(htilde, psd) - snr_per_bin, total_snr = self.get_snr_per_bin(qtilde, stilde, 
edges, psd) + snr_per_bin, total_snr = self.get_snr_per_bin( + qtilde, stilde, edges, psd + ) # for each frequency bin, compute the square of the # deviation from the expected amount of SNR in the bin # and then sum it over all the bins - chisq_summand = (snr_per_bin - total_snr / self.num_bins)**2 + chisq_summand = (snr_per_bin - total_snr / self.num_bins) ** 2 chisq = chisq_summand.sum(dim=-1) # normalize by number of degrees of freedom @@ -173,4 +175,3 @@ def forward( if self.return_snr: return chisq, total_snr return chisq - diff --git a/tests/transforms/test_chisq.py b/tests/transforms/test_chisq.py index 2b55176a..ba814bcb 100644 --- a/tests/transforms/test_chisq.py +++ b/tests/transforms/test_chisq.py @@ -4,22 +4,19 @@ def test_chisq(): - scale = 10**(-19) + scale = 10 ** (-19) background = scale * torch.randn(4, 2, 2048 * 32) strain = scale * torch.randn(4, 2, 2048 * 4) t = torch.arange(2048 * 4) / 2048 freq = 10 + t**3 - amp = scale * (0.1 + t**3/64) + amp = scale * (0.1 + t**3 / 64) signal = amp * torch.sin(2 * torch.pi * freq * t) signal = signal.view(1, 1, -1).repeat(4, 2, 1) injected = strain + signal spec = SpectralDensity( - sample_rate=2048, - fftlength=4, - overlap=2, - average="median" + sample_rate=2048, fftlength=4, overlap=2, average="median" ) psd = spec(background) From 783a669d1b53d6ac07a6dc12246acd25a887b1a1 Mon Sep 17 00:00:00 2001 From: Alec Gunny Date: Thu, 9 Nov 2023 09:19:30 -0800 Subject: [PATCH 3/4] adding TODO explanations on SNR scaling issues --- ml4gw/gw.py | 34 ++++++++++++++++++++++++++++++---- ml4gw/transforms/chisq.py | 8 ++++---- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/ml4gw/gw.py b/ml4gw/gw.py index 76c22ba8..cb1bcd96 100644 --- a/ml4gw/gw.py +++ b/ml4gw/gw.py @@ -285,7 +285,7 @@ def get_ifo_geometry( return torch.Tensor(tensors), torch.Tensor(vertices) -def snr_from_freqs( +def snr_frequency_series( template: PSDTensor, df: float, psd: Optional[PSDTensor] = None, @@ -301,10 +301,29 @@ def 
snr_from_freqs( other = template else: other = strain + + # TODO: handle scale issues by 1) dividing + # each tensor by the ASD separately and 2) + # multiplying by the sqrt of the scale factor. + # This is less optimal than our previous implementation + # because we're now doing 2 element-wise divides by + # the ASD rather than 1, so worth figuring out if + # there's a better, more consistent way to make the + # scale work out. Also see TODO below asd = psd**0.5 template = template * scale**0.5 / asd other = other.conj() * scale**0.5 / asd integrand = template * other + + # convert to real before we divide the scale + # because something about the complex datatype + # is making this go to 0. + # TODO: we should probably be returning the full + # complex integrand and letting downstream functions + # convert to real where necessary, this way this + # can ultimately be used to take an ifft and compute + # an SNR timeseries. But this requires figuring out + # how to get the scaling right for the complex version integrand = integrand.real / scale integrand = integrand.type(torch.float32) @@ -325,13 +344,18 @@ def snr_from_freqs( return integrand * 4 * df -def snr_integral( +def snr_frequency_series_from_timeseries( template: WaveformTensor, sample_rate: float, psd: Optional[PSDTensor] = None, strain: Optional[WaveformTensor] = None, highpass: Union[float, TensorType["frequency"], None] = None, ) -> PSDTensor: + # TODO: should this and snr_frequency_series just + # be combined into a single function with an `input_domain` + # argument? Or maybe expose both `df` and `sample_rate` + # arguments and then infer input domain from whichever + # is not `None`? df = sample_rate / template.size(-1) # TODO: should we do windowing here? 
@@ -343,7 +367,7 @@ def snr_integral( stilde = torch.fft.rfft(strain, axis=-1).type(torch.complex128) else: stilde = None - integrand = snr_from_freqs(htilde, df, psd, stilde, highpass) + integrand = snr_frequency_series(htilde, df, psd, stilde, highpass) # factor of sample_rate**2 that should have been applied # to each FFT separately, but doing it after the fact @@ -405,7 +429,9 @@ def compute_ifo_snr( Batch of SNRs computed for each interferometer """ - integrand = snr_integral(responses, sample_rate, psd, highpass=highpass) + integrand = snr_frequency_series_from_timeseries( + responses, sample_rate, psd, highpass=highpass + ) return integrand.sum(-1) ** 0.5 diff --git a/ml4gw/transforms/chisq.py b/ml4gw/transforms/chisq.py index 0491e92e..01359669 100644 --- a/ml4gw/transforms/chisq.py +++ b/ml4gw/transforms/chisq.py @@ -90,16 +90,16 @@ def partition_frequencies(self, htilde, psd=None): """ # compute the cumulative SNR of our template # wrt the background PSD as a function of frequency - snr_integral = self.get_cumulative_snr(htilde, psd) + cumulative_snr = self.get_cumulative_snr(htilde, psd) # break the total SNR up into even bins - total_snr = snr_integral[:, :, -1:] + total_snr = cumulative_snr[:, :, -1:] bins = self.bins * total_snr # figure out which indices along the frequency axis # break up the SNR as closely into these bins as possible - edges = torch.searchsorted(snr_integral, bins, side="right") - edges = edges.clamp(0, snr_integral.size(-1) - 1) + edges = torch.searchsorted(cumulative_snr, bins, side="right") + edges = edges.clamp(0, cumulative_snr.size(-1) - 1) # normalize by the sqrt of the total SNR qtilde = htilde / total_snr**0.5 From 3e969d73276a55d23af936e443a2dc10055096af Mon Sep 17 00:00:00 2001 From: Alec Gunny Date: Thu, 9 Nov 2023 09:21:25 -0800 Subject: [PATCH 4/4] using new snr function name in chisq --- ml4gw/transforms/chisq.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ml4gw/transforms/chisq.py 
b/ml4gw/transforms/chisq.py index 01359669..f7b18fc1 100644 --- a/ml4gw/transforms/chisq.py +++ b/ml4gw/transforms/chisq.py @@ -2,7 +2,7 @@ import torch -from ml4gw.gw import snr_from_freqs +from ml4gw.gw import snr_frequency_series class ChiSq(torch.nn.Module): @@ -37,7 +37,12 @@ def __init__( self.input_domain = input_domain def get_cumulative_snr(self, htilde, psd=None, stilde=None): - snr = snr_from_freqs(htilde, self.sample_rate, psd, stilde) + """ + Compute the cumulative integral of the SNR frequency + series along the frequency dimension. + """ + + snr = snr_frequency_series(htilde, self.sample_rate, psd, stilde) return snr.cumsum(dim=-1) def make_indices(self, batch_size, num_channels):