Add get_config and from_config methods to DenseVariational_v2 #1695

Open · wants to merge 2 commits into main
18 changes: 17 additions & 1 deletion tensorflow_probability/python/layers/dense_variational_v2.py
@@ -104,6 +104,7 @@ def build(self, input_shape):
           last_dim * self.units,
           self.units if self.use_bias else 0,
           dtype)
+
     with tf.name_scope('prior'):
       self._prior = self._make_prior_fn(
           last_dim * self.units,
@@ -157,10 +158,25 @@ def compute_output_shape(self, input_shape):
     input_shape = input_shape.with_rank_at_least(2)
     if input_shape[-1] is None:
       raise ValueError(
-          f'The innermost dimension of input_shape must be defined, but saw: {input_shape}'
+          f'The innermost dimension of input_shape must be defined, '
+          f'but saw: {input_shape}'
       )
     return input_shape[:-1].concatenate(self.units)
 
+  def get_config(self):
+    base_config = super(DenseVariational, self).get_config()
+    config = {
+        'units': self.units,
+        'make_posterior_fn': self._make_posterior_fn,
+        'make_prior_fn': self._make_prior_fn,
+        'activation': self.activation,
+        'use_bias': self.use_bias,
+    }
+    return dict(list(base_config.items()) + list(config.items()))
+
+  @classmethod
+  def from_config(cls, config):
+    return cls(**config)

 def _make_kl_divergence_penalty(
     use_exact_kl=False,
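
For illustration, a minimal sketch of the round trip these two methods enable. The make_posterior and make_prior builders below are hypothetical stand-ins (modeled on the mean-field posterior / trainable prior examples in the TFP docs), not part of this PR:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Hypothetical posterior builder: trainable mean-field normal over weights.
def make_posterior(kernel_size, bias_size, dtype=None):
  n = kernel_size + bias_size
  return tf.keras.Sequential([
      tfp.layers.VariableLayer(2 * n, dtype=dtype),
      tfp.layers.DistributionLambda(lambda t: tfd.Independent(
          tfd.Normal(loc=t[..., :n],
                     scale=1e-5 + tf.nn.softplus(t[..., n:])),
          reinterpreted_batch_ndims=1)),
  ])

# Hypothetical prior builder: unit-scale normal with a trainable mean.
def make_prior(kernel_size, bias_size, dtype=None):
  n = kernel_size + bias_size
  return tf.keras.Sequential([
      tfp.layers.VariableLayer(n, dtype=dtype),
      tfp.layers.DistributionLambda(lambda t: tfd.Independent(
          tfd.Normal(loc=t, scale=1.0),
          reinterpreted_batch_ndims=1)),
  ])

layer = tfp.layers.DenseVariational(
    units=1,
    make_posterior_fn=make_posterior,
    make_prior_fn=make_prior)

# With this PR, the layer can be rebuilt from its own config dict,
# which is what Keras cloning utilities rely on.
clone = tfp.layers.DenseVariational.from_config(layer.get_config())

Note that get_config stores the posterior/prior builders as raw Python callables, so the resulting config supports in-process cloning but is not JSON-serializable on its own; saving and reloading a full model would still require passing those callables as custom objects.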