Skip to content

Commit

Permalink
Fixes issue where tensors in waveform generator not getting built on …
Browse files Browse the repository at this point in the history
…correct device (#199)

* fix device  issue in waveform generator

* update version to 0.7.1
  • Loading branch information
EthanMarx authored Feb 7, 2025
1 parent b9895af commit a55d194
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions ml4gw/waveforms/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def generate_conditioned_fd_waveform(
k1s = torch.round(f_min / df)

num_freqs = frequencies.size(0)
frequency_indices = torch.arange(num_freqs)
frequency_indices = torch.arange(num_freqs, device=device)
taper_mask = frequency_indices <= k1s[:, None]
taper_mask &= frequency_indices >= k0s[:, None]

Expand Down Expand Up @@ -253,7 +253,7 @@ def generate_conditioned_fd_waveform(
# that will translate the coalescense time such that it is `right_pad`
# seconds from the right edge of the window
tshift = round(self.right_pad * self.sample_rate) / self.sample_rate
kvals = torch.arange(num_freqs)
kvals = torch.arange(num_freqs, device=device)
phase_shift = torch.exp(1j * 2 * torch.pi * df * tshift * kvals)

hc_spectrum *= phase_shift
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ml4gw"
version = "0.7.0"
version = "0.7.1"
description = "Tools for training torch models on gravitational wave data"
readme = "README.md"
authors = [
Expand Down

0 comments on commit a55d194

Please sign in to comment.