From f06f16e3576f60564cb4a35f233b7b9fa34b07a8 Mon Sep 17 00:00:00 2001 From: DistraxDev Date: Thu, 3 Aug 2023 18:07:27 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 553645627 --- distrax/_src/bijectors/scalar_affine.py | 1 + distrax/_src/distributions/transformed.py | 5 ++++- distrax/_src/utils/equivalence.py | 15 ++++++++++----- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/distrax/_src/bijectors/scalar_affine.py b/distrax/_src/bijectors/scalar_affine.py index 601ac35..d661592 100644 --- a/distrax/_src/bijectors/scalar_affine.py +++ b/distrax/_src/bijectors/scalar_affine.py @@ -92,6 +92,7 @@ def log_scale(self) -> Numeric: @property def scale(self) -> Numeric: """The bijector's scale.""" + assert self._scale is not None # By construction. return self._scale def forward(self, x: Array) -> Array: diff --git a/distrax/_src/distributions/transformed.py b/distrax/_src/distributions/transformed.py index 7349b22..f066a28 100644 --- a/distrax/_src/distributions/transformed.py +++ b/distrax/_src/distributions/transformed.py @@ -125,12 +125,13 @@ def _infer_shapes_and_dtype(self): self._dtype = shape_dtype.dtype - # pylint:disable=invalid-unary-operand-type if self.bijector.event_ndims_out == 0: self._event_shape = () self._batch_shape = shape_dtype.shape else: + # pylint: disable-next=invalid-unary-operand-type self._event_shape = shape_dtype.shape[-self.bijector.event_ndims_out:] + # pylint: disable-next=invalid-unary-operand-type self._batch_shape = shape_dtype.shape[:-self.bijector.event_ndims_out] @property @@ -145,6 +146,7 @@ def event_shape(self) -> Tuple[int, ...]: """See `Distribution.event_shape`.""" if self._event_shape is None: self._infer_shapes_and_dtype() + assert self._event_shape is not None # By _infer_shapes_and_dtype() return self._event_shape @property @@ -152,6 +154,7 @@ def batch_shape(self) -> Tuple[int, ...]: """See `Distribution.batch_shape`.""" if self._batch_shape is None: self._infer_shapes_and_dtype() + assert self._batch_shape is not None # By _infer_shapes_and_dtype() return self._batch_shape def log_prob(self, value: EventT) -> Array: diff --git a/distrax/_src/utils/equivalence.py b/distrax/_src/utils/equivalence.py index ceb4dcd..da5621c 100644 --- a/distrax/_src/utils/equivalence.py +++ b/distrax/_src/utils/equivalence.py @@ -92,6 +92,11 @@ def f(x, y): return f + def _get_tfp_cls(self) -> type(tfd.Distribution): + if self.tfp_cls is None: + raise ValueError('TFP class undefined. Run _init_distr_cls() first.') + return self.tfp_cls + def _test_attribute( self, attribute_string: str, @@ -135,7 +140,7 @@ def _test_attribute( tfp_dist_kwargs = dist_kwargs dist = self.distrax_cls(*dist_args, **dist_kwargs) - tfp_dist = self.tfp_cls(*tfp_dist_args, **tfp_dist_kwargs) + tfp_dist = self._get_tfp_cls()(*tfp_dist_args, **tfp_dist_kwargs) if callable(getattr(dist, attribute_string)): distrax_fn = getattr(dist, attribute_string) @@ -202,7 +207,7 @@ def sample_fn(key, sample_shape=sample_shape): sample_fn = self.variant(sample_fn) samples = sample_fn(self.key) - tfp_dist = self.tfp_cls(*tfp_dist_args, **tfp_dist_kwargs) + tfp_dist = self._get_tfp_cls()(*tfp_dist_args, **tfp_dist_kwargs) tfp_samples = tfp_dist.sample(sample_shape=sample_shape, seed=self.key) chex.assert_equal_shape([samples, tfp_samples]) @@ -231,7 +236,7 @@ def sample_and_log_prob_fn(key): log_prob_fn = self.variant(dist.log_prob) samples, log_prob = sample_and_log_prob_fn(self.key) - tfp_dist = self.tfp_cls(*tfp_dist_args, **tfp_dist_kwargs) + tfp_dist = self._get_tfp_cls()(*tfp_dist_args, **tfp_dist_kwargs) tfp_samples = tfp_dist.sample(sample_shape=sample_shape, seed=self.key) tfp_log_prob = tfp_dist.log_prob(samples) @@ -305,9 +310,9 @@ def _test_with_two_distributions( tfp_dist2_kwargs = dist2_kwargs dist1 = self.distrax_cls(*dist1_args, **dist1_kwargs) - tfp_dist1 = self.tfp_cls(*tfp_dist1_args, **tfp_dist1_kwargs) + tfp_dist1 = self._get_tfp_cls()(*tfp_dist1_args, **tfp_dist1_kwargs) dist2 = self.distrax_cls(*dist2_args, **dist2_kwargs) - tfp_dist2 = self.tfp_cls(*tfp_dist2_args, **tfp_dist2_kwargs) + tfp_dist2 = self._get_tfp_cls()(*tfp_dist2_args, **tfp_dist2_kwargs) tfp_comp_dist1_dist2 = getattr(tfp_dist1, attribute_string)(tfp_dist2) tfp_comp_dist2_dist1 = getattr(tfp_dist2, attribute_string)(tfp_dist1)