Skip to content

Commit

Permalink
add conversion function to parameter sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanMarx committed Jan 21, 2025
1 parent 2dbac9b commit dfab4d4
Showing 1 changed file with 31 additions and 5 deletions.
36 changes: 31 additions & 5 deletions ml4gw/waveforms/generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Dict, Tuple
from typing import Callable, Dict, Optional, Tuple

import torch
from jaxtyping import Float
Expand All @@ -14,15 +14,41 @@


class ParameterSampler(torch.nn.Module):
def __init__(self, **parameters: Callable) -> None:
def __init__(
self,
conversion_function: Optional[Callable] = None,
**params: Callable,
):
"""
A class for sampling parameters from a prior distribution
Args:
conversion_function:
A callable that takes a dictionary of sampled parameters
and returns a dictionary of transformed parameters
**params:
A dictionary of parameter samplers that take an integer N
and return a tensor of shape (N, ...) representing
samples from the prior distribution
"""
super().__init__()
self.parameters = parameters
self.params = params
self.conversion_function = conversion_function or (lambda x: x)

def forward(
self,
N: int,
) -> Dict[str, Float[Tensor, " {N}"]]:
return {k: v.sample((N,)) for k, v in self.parameters.items()}
device: str = "cpu",
):
# sample parameters from prior
parameters = {
k: v.sample((N,)).to(device) for k, v in self.params.items()
}
# perform any necessary conversions
# to from sampled parameters to
# waveform generation parameters
parameters = self.conversion_function(parameters)
return parameters


class TimeDomainCBCWaveformGenerator(torch.nn.Module):
Expand Down

0 comments on commit dfab4d4

Please sign in to comment.