Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 553645627
  • Loading branch information
DistraxDev authored and DistraxDev committed Aug 4, 2023
1 parent 18c10e5 commit f06f16e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
1 change: 1 addition & 0 deletions distrax/_src/bijectors/scalar_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def log_scale(self) -> Numeric:
@property
def scale(self) -> Numeric:
"""The bijector's scale."""
assert self._scale is not None # By construction.
return self._scale

def forward(self, x: Array) -> Array:
Expand Down
5 changes: 4 additions & 1 deletion distrax/_src/distributions/transformed.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,13 @@ def _infer_shapes_and_dtype(self):

self._dtype = shape_dtype.dtype

# pylint:disable=invalid-unary-operand-type
if self.bijector.event_ndims_out == 0:
self._event_shape = ()
self._batch_shape = shape_dtype.shape
else:
# pylint: disable-next=invalid-unary-operand-type
self._event_shape = shape_dtype.shape[-self.bijector.event_ndims_out:]
# pylint: disable-next=invalid-unary-operand-type
self._batch_shape = shape_dtype.shape[:-self.bijector.event_ndims_out]

@property
Expand All @@ -145,13 +146,15 @@ def event_shape(self) -> Tuple[int, ...]:
"""See `Distribution.event_shape`."""
if self._event_shape is None:
self._infer_shapes_and_dtype()
assert self._event_shape is not None # By _infer_shapes_and_dtype()
return self._event_shape

@property
def batch_shape(self) -> Tuple[int, ...]:
"""See `Distribution.batch_shape`."""
if self._batch_shape is None:
self._infer_shapes_and_dtype()
assert self._batch_shape is not None # By _infer_shapes_and_dtype()
return self._batch_shape

def log_prob(self, value: EventT) -> Array:
Expand Down
15 changes: 10 additions & 5 deletions distrax/_src/utils/equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ def f(x, y):

return f

def _get_tfp_cls(self) -> type(tfd.Distribution):
if self.tfp_cls is None:
raise ValueError('TFP class undefined. Run _init_distr_cls() first.')
return self.tfp_cls

def _test_attribute(
self,
attribute_string: str,
Expand Down Expand Up @@ -135,7 +140,7 @@ def _test_attribute(
tfp_dist_kwargs = dist_kwargs

dist = self.distrax_cls(*dist_args, **dist_kwargs)
tfp_dist = self.tfp_cls(*tfp_dist_args, **tfp_dist_kwargs)
tfp_dist = self._get_tfp_cls()(*tfp_dist_args, **tfp_dist_kwargs)

if callable(getattr(dist, attribute_string)):
distrax_fn = getattr(dist, attribute_string)
Expand Down Expand Up @@ -202,7 +207,7 @@ def sample_fn(key, sample_shape=sample_shape):
sample_fn = self.variant(sample_fn)
samples = sample_fn(self.key)

tfp_dist = self.tfp_cls(*tfp_dist_args, **tfp_dist_kwargs)
tfp_dist = self._get_tfp_cls()(*tfp_dist_args, **tfp_dist_kwargs)
tfp_samples = tfp_dist.sample(sample_shape=sample_shape,
seed=self.key)
chex.assert_equal_shape([samples, tfp_samples])
Expand Down Expand Up @@ -231,7 +236,7 @@ def sample_and_log_prob_fn(key):
log_prob_fn = self.variant(dist.log_prob)
samples, log_prob = sample_and_log_prob_fn(self.key)

tfp_dist = self.tfp_cls(*tfp_dist_args, **tfp_dist_kwargs)
tfp_dist = self._get_tfp_cls()(*tfp_dist_args, **tfp_dist_kwargs)
tfp_samples = tfp_dist.sample(sample_shape=sample_shape,
seed=self.key)
tfp_log_prob = tfp_dist.log_prob(samples)
Expand Down Expand Up @@ -305,9 +310,9 @@ def _test_with_two_distributions(
tfp_dist2_kwargs = dist2_kwargs

dist1 = self.distrax_cls(*dist1_args, **dist1_kwargs)
tfp_dist1 = self.tfp_cls(*tfp_dist1_args, **tfp_dist1_kwargs)
tfp_dist1 = self._get_tfp_cls()(*tfp_dist1_args, **tfp_dist1_kwargs)
dist2 = self.distrax_cls(*dist2_args, **dist2_kwargs)
tfp_dist2 = self.tfp_cls(*tfp_dist2_args, **tfp_dist2_kwargs)
tfp_dist2 = self._get_tfp_cls()(*tfp_dist2_args, **tfp_dist2_kwargs)

tfp_comp_dist1_dist2 = getattr(tfp_dist1, attribute_string)(tfp_dist2)
tfp_comp_dist2_dist1 = getattr(tfp_dist2, attribute_string)(tfp_dist1)
Expand Down

0 comments on commit f06f16e

Please sign in to comment.