From 709ac23daea999794db13552b74d4af0facd293e Mon Sep 17 00:00:00 2001 From: Ethan Jacob Marx Date: Tue, 14 May 2024 11:04:39 -0700 Subject: [PATCH 1/2] fix type hints in dists --- ml4gw/distributions.py | 21 +++++++++++++-------- tests/test_distributions.py | 13 ++++++------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/ml4gw/distributions.py b/ml4gw/distributions.py index 4609ddad..d0392fe1 100644 --- a/ml4gw/distributions.py +++ b/ml4gw/distributions.py @@ -9,6 +9,7 @@ import torch import torch.distributions as dist +import math class Cosine(dist.Distribution): @@ -21,14 +22,15 @@ class Cosine(dist.Distribution): def __init__( self, - low: float = torch.as_tensor(-torch.pi / 2), - high: float = torch.as_tensor(torch.pi / 2), + low: float = -math.pi / 2, + high: float = math.pi / 2, validate_args=None, ): batch_shape = torch.Size() super().__init__(batch_shape, validate_args=validate_args) - self.low = low - self.norm = 1 / (torch.sin(high) - torch.sin(low)) + self.low = torch.as_tensor(low) + self.high = torch.as_tensor(high) + self.norm = 1 / (torch.sin(self.high) - torch.sin(self.low)) def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: u = torch.rand(sample_shape, device=self.low.device) @@ -48,13 +50,16 @@ class Sine(dist.TransformedDistribution): def __init__( self, - low: float = torch.as_tensor(0), - high: float = torch.as_tensor(torch.pi), + low: float = 0.0, + high: float = math.pi, validate_args=None, ): + low = torch.as_tensor(low) + high = torch.as_tensor(high) base_dist = Cosine( low - torch.pi / 2, high - torch.pi / 2, validate_args ) + super().__init__( base_dist, [ @@ -153,12 +158,12 @@ class DeltaFunction(dist.Distribution): def __init__( self, - peak: float = torch.as_tensor(0.0), + peak: float = 0.0, validate_args=None, ): batch_shape = torch.Size() super().__init__(batch_shape, validate_args=validate_args) - self.peak = peak + self.peak = torch.as_tensor(peak) def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: return self.peak * torch.ones( diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 6f84e6e7..cd1374df 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -2,8 +2,7 @@ import pytest import torch from scipy import optimize -from torch import pi - +import math from ml4gw import distributions # TODO: for all tests, how to validate that @@ -11,10 +10,10 @@ def test_log_uniform(): - sampler = distributions.LogUniform(torch.e, torch.e**2) + sampler = distributions.LogUniform(math.e, math.e**2) samples = sampler.sample((10,)) assert len(samples) == 10 - assert ((torch.e <= samples) & (torch.e**2 <= 100)).all() + assert ((math.e <= samples) & (math.e**2 <= 100)).all() # check that the mean is roughly correct # (within three standard deviations) @@ -32,9 +31,9 @@ def test_cosine(): sampler = distributions.Cosine() samples = sampler.sample((10,)) assert len(samples) == 10 - assert ((-pi / 2 <= samples) & (samples <= pi / 2)).all() + assert ((-math.pi / 2 <= samples) & (samples <= math.pi / 2)).all() - sampler = distributions.Cosine(torch.as_tensor(-3), torch.as_tensor(5)) + sampler = distributions.Cosine(-3, 5) samples = sampler.sample((100,)) assert len(samples) == 100 assert ((-3 <= samples) & (samples <= 5)).all() @@ -82,6 +81,6 @@ def foo(x, a, b): def test_delta_function(): - sampler = distributions.DeltaFunction(peak=torch.as_tensor(20)) + sampler = distributions.DeltaFunction(peak=20) samples = sampler.sample((10,)) assert (samples == 20).all() From 6e3c5fcdb8b23623c1c8f4562aa69f19aedd3029 Mon Sep 17 00:00:00 2001 From: Ethan Jacob Marx Date: Tue, 14 May 2024 11:05:31 -0700 Subject: [PATCH 2/2] fix type hints in dists --- ml4gw/distributions.py | 2 +- tests/test_distributions.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ml4gw/distributions.py b/ml4gw/distributions.py index d0392fe1..a4588ae9 100644 --- a/ml4gw/distributions.py +++ b/ml4gw/distributions.py @@ -5,11 +5,11 @@ from the corresponding distribution. """ +import math from typing import Optional import torch import torch.distributions as dist -import math class Cosine(dist.Distribution): diff --git a/tests/test_distributions.py b/tests/test_distributions.py index cd1374df..8ebed558 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -1,8 +1,9 @@ +import math + import numpy as np import pytest -import torch from scipy import optimize -import math + from ml4gw import distributions # TODO: for all tests, how to validate that