From f25891830dd4917c57ff62595ce1d5598764bcbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kaan=20B=C4=B1=C3=A7akc=C4=B1?= Date: Fri, 17 Feb 2023 01:09:37 +0000 Subject: [PATCH] Add get_config and from_config methods --- .../python/layers/dense_variational_v2.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tensorflow_probability/python/layers/dense_variational_v2.py b/tensorflow_probability/python/layers/dense_variational_v2.py index 9f8dd3ebcd..61ceaac3d6 100644 --- a/tensorflow_probability/python/layers/dense_variational_v2.py +++ b/tensorflow_probability/python/layers/dense_variational_v2.py @@ -104,6 +104,7 @@ def build(self, input_shape): last_dim * self.units, self.units if self.use_bias else 0, dtype) + with tf.name_scope('prior'): self._prior = self._make_prior_fn( last_dim * self.units, @@ -157,10 +158,25 @@ def compute_output_shape(self, input_shape): input_shape = input_shape.with_rank_at_least(2) if input_shape[-1] is None: raise ValueError( - f'The innermost dimension of input_shape must be defined, but saw: {input_shape}' + f'The innermost dimension of input_shape must be defined, ' + f'but saw: {input_shape}' ) return input_shape[:-1].concatenate(self.units) + def get_config(self): + base_config = super(DenseVariational, self).get_config() + config = { + 'units': self.units, + 'make_posterior_fn': self._make_posterior_fn, + 'make_prior_fn': self._make_prior_fn, + 'activation': self.activation, + 'use_bias': self.use_bias, + } + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config): + return cls(**config) def _make_kl_divergence_penalty( use_exact_kl=False,