diff --git a/ml4gw/distributions.py b/ml4gw/distributions.py index 4609ddad..a4588ae9 100644 --- a/ml4gw/distributions.py +++ b/ml4gw/distributions.py @@ -5,6 +5,7 @@ from the corresponding distribution. """ +import math from typing import Optional import torch @@ -21,14 +22,15 @@ class Cosine(dist.Distribution): def __init__( self, - low: float = torch.as_tensor(-torch.pi / 2), - high: float = torch.as_tensor(torch.pi / 2), + low: float = -math.pi / 2, + high: float = math.pi / 2, validate_args=None, ): batch_shape = torch.Size() super().__init__(batch_shape, validate_args=validate_args) - self.low = low - self.norm = 1 / (torch.sin(high) - torch.sin(low)) + self.low = torch.as_tensor(low) + self.high = torch.as_tensor(high) + self.norm = 1 / (torch.sin(self.high) - torch.sin(self.low)) def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: u = torch.rand(sample_shape, device=self.low.device) @@ -48,13 +50,16 @@ class Sine(dist.TransformedDistribution): def __init__( self, - low: float = torch.as_tensor(0), - high: float = torch.as_tensor(torch.pi), + low: float = 0.0, + high: float = math.pi, validate_args=None, ): + low = torch.as_tensor(low) + high = torch.as_tensor(high) base_dist = Cosine( low - torch.pi / 2, high - torch.pi / 2, validate_args ) + super().__init__( base_dist, [ @@ -153,12 +158,12 @@ class DeltaFunction(dist.Distribution): def __init__( self, - peak: float = torch.as_tensor(0.0), + peak: float = 0.0, validate_args=None, ): batch_shape = torch.Size() super().__init__(batch_shape, validate_args=validate_args) - self.peak = peak + self.peak = torch.as_tensor(peak) def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: return self.peak * torch.ones( diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 6f84e6e7..8ebed558 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -1,8 +1,8 @@ +import math + import numpy as np import pytest -import torch from scipy import optimize -from torch import pi from ml4gw import distributions @@ -11,10 +11,10 @@ def test_log_uniform(): - sampler = distributions.LogUniform(torch.e, torch.e**2) + sampler = distributions.LogUniform(math.e, math.e**2) samples = sampler.sample((10,)) assert len(samples) == 10 - assert ((torch.e <= samples) & (torch.e**2 <= 100)).all() + assert ((math.e <= samples) & (math.e**2 <= 100)).all() # check that the mean is roughly correct # (within three standard deviations) @@ -32,9 +32,9 @@ def test_cosine(): sampler = distributions.Cosine() samples = sampler.sample((10,)) assert len(samples) == 10 - assert ((-pi / 2 <= samples) & (samples <= pi / 2)).all() + assert ((-math.pi / 2 <= samples) & (samples <= math.pi / 2)).all() - sampler = distributions.Cosine(torch.as_tensor(-3), torch.as_tensor(5)) + sampler = distributions.Cosine(-3, 5) samples = sampler.sample((100,)) assert len(samples) == 100 assert ((-3 <= samples) & (samples <= 5)).all() @@ -82,6 +82,6 @@ def foo(x, a, b): def test_delta_function(): - sampler = distributions.DeltaFunction(peak=torch.as_tensor(20)) + sampler = distributions.DeltaFunction(peak=20) samples = sampler.sample((10,)) assert (samples == 20).all()