diff --git a/ml4gw/waveforms/generator.py b/ml4gw/waveforms/generator.py index fef52c4..5f94da7 100644 --- a/ml4gw/waveforms/generator.py +++ b/ml4gw/waveforms/generator.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Tuple +from typing import Callable, Dict, Optional, Tuple import torch from jaxtyping import Float @@ -14,15 +14,41 @@ class ParameterSampler(torch.nn.Module): - def __init__(self, **parameters: Callable) -> None: + def __init__( + self, + conversion_function: Optional[Callable] = None, + **params: Callable, + ): + """ + A class for sampling parameters from a prior distribution + + Args: + conversion_function: + A callable that takes a dictionary of sampled parameters + and returns a dictionary of transformed parameters + **params: + A dictionary of parameter samplers that take an integer N + and return a tensor of shape (N, ...) representing + samples from the prior distribution + """ super().__init__() - self.parameters = parameters + self.params = params + self.conversion_function = conversion_function or (lambda x: x) def forward( self, N: int, - ) -> Dict[str, Float[Tensor, " {N}"]]: - return {k: v.sample((N,)) for k, v in self.parameters.items()} + device: str = "cpu", + ): + # sample parameters from prior + parameters = { + k: v.sample((N,)).to(device) for k, v in self.params.items() + } + # perform any necessary conversions + # to from sampled parameters to + # waveform generation parameters + parameters = self.conversion_function(parameters) + return parameters class TimeDomainCBCWaveformGenerator(torch.nn.Module):