Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial chi-squared veto implementation #79

Open
wants to merge 4 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 95 additions & 46 deletions ml4gw/gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
https://github.com/lscsoft/bilby/blob/master/bilby/gw/detector/interferometer.py
"""

from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union

import torch
from torchtyping import TensorType
Expand Down Expand Up @@ -285,6 +285,96 @@ def get_ifo_geometry(
return torch.Tensor(tensors), torch.Tensor(vertices)


def snr_frequency_series(
    template: PSDTensor,
    df: float,
    psd: Optional[PSDTensor] = None,
    strain: Optional[PSDTensor] = None,
    highpass: Union[float, TensorType["frequency"], None] = None,
    scale: float = 1e20,
) -> PSDTensor:
    """
    Returns SNR as a function of frequency

    For each frequency bin this evaluates
    ``4 * df * Re(template * conj(other)) / psd``, where ``other``
    is ``strain`` when it is provided and ``template`` otherwise.

    Args:
        template:
            Frequency-domain template. Its last axis is treated
            as the frequency axis.
        df:
            Frequency resolution. Multiplies the returned integrand
            and, when ``highpass`` is a float, is used to build the
            frequency axis ``arange(num_freqs) * df``.
        psd:
            Power spectral density to weight the integrand by. If
            ``None``, no weighting is applied (presumably because the
            inputs have already been whitened).
        strain:
            Optional frequency-domain data to correlate the template
            against. If ``None``, the template is correlated with
            itself.
        highpass:
            Either a cutoff frequency, below which the integrand is
            zeroed, or a precomputed mask tensor over the frequency
            axis. If ``None``, no masking is applied.
        scale:
            Factor used to bring intermediate values into a
            numerically safer range. It is divided back out before
            returning.
    Returns:
        Real-valued ``float32`` tensor holding the SNR integrand
        at each frequency bin.
    Raises:
        ValueError:
            If ``highpass`` is a tensor whose length does not match
            the number of frequency bins of the integrand.
    """

    other = template if strain is None else strain

    # TODO: handle scale issues by 1) dividing
    # each tensor by the ASD separately and 2)
    # multiplying by the sqrt of the scale factor.
    # This is less optimal than our previous implementation
    # because we're now doing 2 element-wise divides by
    # the ASD rather than 1, so worth figuring out if
    # there's a better, more consistent way to make the
    # scale work out. Also see TODO below.
    # NOTE: per the review discussion on this change, the single-step
    # form `template * other.conj() * scale / psd` runs into precision
    # issues; splitting the division was the most stable implementation
    # found.
    root_scale = scale**0.5
    template = template * root_scale
    other = other.conj() * root_scale
    if psd is not None:
        asd = psd**0.5
        template = template / asd
        other = other / asd
    integrand = template * other

    # convert to real before we divide the scale
    # because something about the complex datatype
    # is making this go to 0.
    # TODO: we should probably be returning the full
    # complex integrand and letting downstream functions
    # convert to real where necessary, this way this
    # can ultimately be used to take an ifft and compute
    # an SNR timeseries. But this requires figuring out
    # how to get the scaling right for the complex version
    integrand = integrand.real / scale
    integrand = integrand.type(torch.float32)

    # mask out low frequency components if a critical
    # frequency or frequency mask was provided
    if highpass is not None:
        num_freqs = integrand.shape[-1]
        if not isinstance(highpass, torch.Tensor):
            freqs = torch.arange(num_freqs, device=integrand.device) * df
            highpass = freqs >= highpass
        elif len(highpass) != num_freqs:
            raise ValueError(
                "Can't apply highpass filter mask with {} frequency bins "
                "to signal fft with {} frequency bins".format(
                    len(highpass), num_freqs
                )
            )
        # match the integrand's dtype and device before masking in-place
        integrand *= highpass.to(integrand)
    return integrand * 4 * df


def snr_frequency_series_from_timeseries(
    template: WaveformTensor,
    sample_rate: float,
    psd: Optional[PSDTensor] = None,
    strain: Optional[WaveformTensor] = None,
    highpass: Union[float, TensorType["frequency"], None] = None,
) -> PSDTensor:
    """
    Compute the SNR frequency series of time-domain inputs.

    Transforms ``template`` (and ``strain``, if given) with a real
    FFT along the last axis, hands the results to
    ``snr_frequency_series`` and rescales the output by
    ``1 / sample_rate**2``.

    Args:
        template:
            Time-domain template whose last axis is time.
        sample_rate:
            Rate at which ``template`` and ``strain`` are sampled.
        psd:
            Passed through unchanged to ``snr_frequency_series``.
        strain:
            Optional time-domain data to correlate the template with.
        highpass:
            Passed through unchanged to ``snr_frequency_series``.
    Returns:
        The output of ``snr_frequency_series`` divided by
        ``sample_rate**2``.
    """
    # TODO: should this and snr_frequency_series just
    # be combined into a single function with an `input_domain`
    # argument? Or maybe expose both `df` and `sample` rate
    # arguments and then infer input domain from whichever
    # is not `None`?
    num_samples = template.size(-1)
    df = sample_rate / num_samples

    # TODO: should we do windowing here?
    def _to_frequency_domain(timeseries):
        # upcast to double precision so that tiny values
        # aren't accidentally zeroed out downstream
        spectrum = torch.fft.rfft(timeseries, dim=-1)
        return spectrum.type(torch.complex128)

    htilde = _to_frequency_domain(template)
    stilde = None if strain is None else _to_frequency_domain(strain)
    integrand = snr_frequency_series(htilde, df, psd, stilde, highpass)

    # each FFT should have been normalized by the sample rate;
    # apply both factors afterwards, where the orders of
    # magnitude are more manageable
    return integrand / sample_rate**2


def compute_ifo_snr(
responses: WaveformTensor,
psd: PSDTensor,
Expand Down Expand Up @@ -339,51 +429,10 @@ def compute_ifo_snr(
Batch of SNRs computed for each interferometer
"""

# TODO: should we do windowing here?
# compute frequency power, upsampling precision so that
# computing absolute value doesn't accidentally zero some
# values out.
fft = torch.fft.rfft(responses, axis=-1).type(torch.complex128)
fft = fft.abs() / sample_rate

# divide by background asd, then go back to FP32 precision
# and square now that values are back in a reasonable range
integrand = fft / (psd**0.5)
integrand = integrand.type(torch.float32) ** 2

# mask out low frequency components if a critical
# frequency or frequency mask was provided
if highpass is not None:
if not isinstance(highpass, torch.Tensor):
freqs = torch.fft.rfftfreq(responses.shape[-1], 1 / sample_rate)
highpass = freqs >= highpass
elif len(highpass) != integrand.shape[-1]:
raise ValueError(
"Can't apply highpass filter mask with {} frequecy bins"
"to signal fft with {} frequency bins".format(
len(highpass), integrand.shape[-1]
)
)
integrand *= highpass.to(integrand.device)

# sum over the desired frequency range and multiply
# by df to turn it into an integration (and get
# our units to drop out)
# TODO: we could in principle do this without requiring
# that the user specify the sample rate by taking the
# fft as-is (without dividing by sample rate) and then
# taking the mean here (or taking the sum and dividing
# by the sum of `highpass` if it's a mask). If we want
# to allow the user to pass a float for highpass, we'll
# need the sample rate to compute the mask, but if we
# replace this with a `mask` argument instead we're in
# the clear
df = sample_rate / responses.shape[-1]
integrated = integrand.sum(axis=-1) * df

# multiply by 4 for mystical reasons
integrated = 4 * integrated # rho-squared
return torch.sqrt(integrated)
integrand = snr_frequency_series_from_timeseries(
responses, sample_rate, psd, highpass=highpass
)
return integrand.sum(-1) ** 0.5


def compute_network_snr(
Expand Down
1 change: 1 addition & 0 deletions ml4gw/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .chisq import ChiSq
from .scaler import ChannelWiseScaler
from .snr_rescaler import SnrRescaler
from .spectral import SpectralDensity
Expand Down
182 changes: 182 additions & 0 deletions ml4gw/transforms/chisq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
from typing import Literal, Optional

import torch

from ml4gw.gw import snr_frequency_series


class ChiSq(torch.nn.Module):
    """
    Chi-squared veto statistic between a template and strain data.

    The template's own SNR is split into ``num_bins`` frequency
    bands of (roughly) equal contribution. The statistic measures
    how far the SNR actually recovered from the strain in each band
    deviates from an even split of the total.

    Args:
        num_bins:
            Number of frequency bands to split the template SNR into.
        fftlength:
            Duration, in seconds, of the inputs. Sets the frequency
            resolution ``1 / fftlength``.
        sample_rate:
            Rate at which time-domain inputs are sampled.
        highpass:
            If given, frequency bins below this value are dropped
            before any SNR is computed.
        return_snr:
            If ``True``, ``forward`` also returns the total SNR.
        input_domain:
            ``"time"`` if ``forward`` receives timeseries that need
            to be FFT'd. Any other value is treated as
            frequency-domain input.
    """

    def __init__(
        self,
        num_bins: int,
        fftlength: float,
        sample_rate: float,
        highpass: Optional[float] = None,
        return_snr: bool = False,
        input_domain: Literal["time", "frequency"] = "time",
    ) -> None:
        super().__init__()
        self.sample_rate = sample_rate

        # frequency resolution of the spectra we integrate over
        self.df = 1 / fftlength

        # include extra bin so that we have all left and right edges
        self.num_bins = num_bins
        bins = torch.arange(num_bins + 1) / num_bins
        self.register_buffer("bins", bins)

        self.fftsize = int(fftlength * sample_rate)
        self.num_freqs = int(fftlength * sample_rate // 2 + 1)
        freqs = torch.arange(self.num_freqs) / fftlength
        self.register_buffer("freqs", freqs)

        if highpass is not None:
            mask = freqs >= highpass
            self.register_buffer("mask", mask)
        else:
            self.mask = None
        self.return_snr = return_snr
        self.input_domain = input_domain

    def get_cumulative_snr(self, htilde, psd=None, stilde=None):
        """
        Compute the cumulative integral of the SNR frequency
        series along the frequency dimension.
        """
        # the second argument of `snr_frequency_series` is the
        # frequency resolution, not the sample rate
        snr = snr_frequency_series(htilde, self.df, psd, stilde)
        return snr.cumsum(dim=-1)

    def make_indices(self, batch_size, num_channels, device=None):
        """
        Helper function for selecting arbitrary indices
        along the last axis of our batches by building
        tensors of repeated index selectors for the
        batch and channel axes.

        Returns two integer tensors of shape
        ``(batch_size, num_channels, num_bins)`` holding the
        batch index and the channel index of each element,
        created on ``device`` if one is given.
        """
        idx0 = torch.arange(batch_size, device=device)
        idx0 = idx0.view(-1, 1, 1).repeat(1, num_channels, self.num_bins)

        idx1 = torch.arange(num_channels, device=device)
        idx1 = idx1.view(1, -1, 1).repeat(batch_size, 1, self.num_bins)
        return idx0, idx1

    def get_snr_per_bin(self, qtilde, stilde, edges, psd=None):
        """
        For a normalized frequency template qtilde and
        frequency-domain strain measurement stilde, measure
        the SNR in the bins between the specified edges
        (whose last dimension should be one greater than the
        number of bins).

        Returns the per-bin SNR along with the total SNR,
        the latter keeping a trailing singleton dimension.
        """

        # calculate how much SNR _actually_ ended up in each bin
        cumulative_snr = self.get_cumulative_snr(qtilde, psd, stilde)

        # since we have the cumulative SNR, all we need to
        # do is grab the value at the left and right bin
        # edges and then subtract them to get the sum
        # in between them
        batch_size, num_channels, _ = cumulative_snr.shape
        idx0, idx1 = self.make_indices(
            batch_size, num_channels, device=cumulative_snr.device
        )

        right = cumulative_snr[idx0, idx1, edges[:, :, 1:]]
        left = cumulative_snr[idx0, idx1, edges[:, :, :-1]]

        # need the actual total SNR to see how much
        # we deviated from the expected breakdown
        total_snr = cumulative_snr[:, :, -1:]
        return right - left, total_snr

    def partition_frequencies(self, htilde, psd=None):
        """
        Compute the edges of the frequency bins that would
        (roughly) evenly break up the optimal SNR of the
        template. Normalize the template by its maximum
        SNR as illustrated in TODO: cite

        Returns the normalized template along with the indices
        along the frequency axis of the ``num_bins + 1`` bin edges.
        """
        # compute the cumulative SNR of our template
        # wrt the background PSD as a function of frequency
        cumulative_snr = self.get_cumulative_snr(htilde, psd)

        # break the total SNR up into even bins
        total_snr = cumulative_snr[:, :, -1:]
        bins = self.bins * total_snr

        # figure out which indices along the frequency axis
        # break up the SNR as closely into these bins as possible
        edges = torch.searchsorted(cumulative_snr, bins, side="right")
        edges = edges.clamp(0, cumulative_snr.size(-1) - 1)

        # normalize by the sqrt of the total SNR
        qtilde = htilde / total_snr**0.5
        return qtilde, edges

    def interpolate_psd(self, psd):
        """
        Linearly resample ``psd`` along its last axis so that
        it has ``self.num_freqs`` frequency bins.
        """
        # have to scale the interpolated psd to ensure
        # that the integral of the power remains constant
        factor = (psd.size(-1) / self.num_freqs) ** 2
        psd = torch.nn.functional.interpolate(
            psd, size=self.num_freqs, mode="linear"
        )
        return psd * factor

    def _check_time_domain(self, template, strain):
        """
        Ensure both timeseries have ``self.fftsize`` samples
        along their last axis, raising a ``ValueError`` that
        names the offending tensor(s) otherwise.
        """
        bad_tensors, bad_shapes = [], []
        if template.size(-1) != self.fftsize:
            bad_tensors.append("template")
            bad_shapes.append(template.shape)
        if strain.size(-1) != self.fftsize:
            bad_tensors.append("strain")
            bad_shapes.append(strain.shape)

        if bad_tensors:
            verb = "has" if len(bad_tensors) == 1 else "have"
            bad_tensors = " and ".join(bad_tensors)
            # shapes are `torch.Size` objects, which have to be
            # converted to strings before they can be joined
            shapes = ", ".join(str(tuple(shape)) for shape in bad_shapes)
            raise ValueError(
                "Both template and strain timeseries are "
                "expected to have time dimension of size {}, "
                "but {} {} shape(s) {}".format(
                    self.fftsize, bad_tensors, verb, shapes
                )
            )

    def forward(
        self,
        template: torch.Tensor,
        strain: torch.Tensor,
        psd: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute the chi-squared statistic of ``strain``
        with respect to ``template``.

        Make PSD optional in case strain has already been whitened

        Args:
            template:
                Template to measure the strain against, in the
                domain given by ``input_domain``. Indexed as a
                3D tensor whose last axis is time or frequency.
            strain:
                Data to compute the statistic for, in the same
                domain and layout as ``template``.
            psd:
                Optional background PSD. It is interpolated to
                ``num_freqs`` bins if its last axis has a
                different size.
        Returns:
            The chi-squared statistic reduced over the bin axis,
            and additionally the total SNR if ``return_snr``
            was set.
        Raises:
            ValueError:
                If ``input_domain`` is ``"time"`` and either input
                does not have ``fftsize`` samples.
        """
        if psd is not None:
            if psd.size(-1) != self.num_freqs:
                psd = self.interpolate_psd(psd)
            if self.mask is not None:
                psd = psd[:, :, self.mask]

        if self.input_domain == "time":
            self._check_time_domain(template, strain)
            htilde = torch.fft.rfft(template, dim=-1) / self.sample_rate
            stilde = torch.fft.rfft(strain, dim=-1) / self.sample_rate
        else:
            htilde, stilde = template, strain

        # drop the frequency bins below the highpass cutoff
        if self.mask is not None:
            htilde = htilde[:, :, self.mask]
            stilde = stilde[:, :, self.mask]

        qtilde, edges = self.partition_frequencies(htilde, psd)
        snr_per_bin, total_snr = self.get_snr_per_bin(
            qtilde, stilde, edges, psd
        )

        # for each frequency bin, compute the square of the
        # deviation from the expected amount of SNR in the bin
        # and then sum it over all the bins
        chisq_summand = (snr_per_bin - total_snr / self.num_bins) ** 2
        chisq = chisq_summand.sum(dim=-1)

        # normalize by number of degrees of freedom
        chisq *= self.num_bins / (self.num_bins - 1)
        if self.return_snr:
            return chisq, total_snr
        return chisq
26 changes: 26 additions & 0 deletions tests/transforms/test_chisq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch

from ml4gw.transforms import ChiSq, SpectralDensity


def test_chisq():
    """
    Smoke test: inject a chirp-like signal into white noise and
    check that `ChiSq` returns one finite, non-negative value
    per batch element and channel.
    """
    # fix the seed so that failures are reproducible
    torch.manual_seed(0)

    batch_size, num_channels = 4, 2
    sample_rate, fftlength = 2048, 4
    num_samples = sample_rate * fftlength

    # noise at an amplitude well below float32's comfortable range
    scale = 10 ** (-19)
    background = scale * torch.randn(
        batch_size, num_channels, sample_rate * 32
    )
    strain = scale * torch.randn(batch_size, num_channels, num_samples)

    # signal whose frequency and amplitude both grow with time
    t = torch.arange(num_samples) / sample_rate
    freq = 10 + t**3
    amp = scale * (0.1 + t**3 / 64)
    signal = amp * torch.sin(2 * torch.pi * freq * t)
    signal = signal.view(1, 1, -1).repeat(batch_size, num_channels, 1)
    injected = strain + signal

    spec = SpectralDensity(
        sample_rate=sample_rate,
        fftlength=fftlength,
        overlap=2,
        average="median",
    )
    psd = spec(background)

    transform = ChiSq(
        num_bins=8,
        fftlength=fftlength,
        sample_rate=sample_rate,
        highpass=10,
    )
    chisq = transform(signal, injected, psd)

    assert chisq.shape == (batch_size, num_channels)
    assert torch.isfinite(chisq).all()
    # the statistic is a positively-weighted sum of squares
    assert (chisq >= 0).all()
Loading