import math
import torch
from torch.distributions import kl_divergence
from torch.nn import Module
from prob_utils import normal_parse_params


class VAEAC(Module):
    """
    Variational Autoencoder with Arbitrary Conditioning core model.
    It is rather flexible, but has several assumptions:
    + The batch of objects and the mask of unobserved features
      have the same shape.
    + The prior and proposal distributions in the latent space
      are component-wise independent Gaussians.

    The constructor takes
    + Prior and proposal networks which take as an input the concatenation
      of the batch of objects and the mask of unobserved features
      and return the parameters of Gaussians in the latent space.
      The range of the neural network outputs should not be restricted.
    + Generative network which takes a latent representation as an input
      and returns the parameters of the generative distribution
      p_theta(x_b | z, x_{1 - b}, b), where b is the mask
      of unobserved features. The information about x_{1 - b} and b
      can be transmitted to the generative network from the prior network
      through nn_utils.MemoryLayer. It is guaranteed that for every batch
      the prior network is always executed before the generative network.
    + Reconstruction log probability. rec_log_prob is a callable
      which takes (groundtruth, distr_params, mask) as an input
      and returns a vector of differentiable log probabilities
      p_theta(x_b | z, x_{1 - b}, b), one for each object of the batch.
    + sigma_mu and sigma_sigma are the coefficients of the regularization
      in the latent space. The default values correspond to a very weak,
      almost vanishing regularization, which is suitable for all
      experimental setups the model was tested on.
    """

    def __init__(self, rec_log_prob, proposal_network, prior_network,
                 generative_network, sigma_mu=1e4, sigma_sigma=1e-4):
        super().__init__()
        self.rec_log_prob = rec_log_prob
        self.proposal_network = proposal_network
        self.prior_network = prior_network
        self.generative_network = generative_network
        self.sigma_mu = sigma_mu
        self.sigma_sigma = sigma_sigma

    def make_observed(self, batch, mask):
        """
        Copy the batch of objects and zero out the unobserved features.
        """
        # clone() keeps the copy attached to the autograd graph, unlike the
        # deprecated torch.tensor(batch); boolean masks are required for
        # indexing in recent PyTorch versions, hence mask.bool().
        observed = batch.clone()
        observed[mask.bool()] = 0
        return observed

    def make_latent_distributions(self, batch, mask, no_proposal=False):
        """
        Make latent distributions for the given batch and mask.
        If no_proposal is True, return None instead of the proposal
        distribution.
        """
        observed = self.make_observed(batch, mask)
        if no_proposal:
            proposal = None
        else:
            full_info = torch.cat([batch, mask], 1)
            proposal_params = self.proposal_network(full_info)
            proposal = normal_parse_params(proposal_params, 1e-3)
        prior_params = self.prior_network(torch.cat([observed, mask], 1))
        prior = normal_parse_params(prior_params, 1e-3)
        return proposal, prior
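
    # Note that both conditioning networks receive the batch and the mask
    # concatenated along dimension 1 (the feature or channel dimension),
    # so they must accept inputs twice as wide as a single object.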

    def prior_regularization(self, prior):
        """
        The prior distribution regularization in the latent space.
        Though it keeps the prior distribution parameters from going to
        infinity, the model usually doesn't diverge even without this
        regularization. With the recommended default regularization
        parameters it has almost no effect on the learning process.
        """
        num_objects = prior.mean.shape[0]
        mu = prior.mean.view(num_objects, -1)
        sigma = prior.scale.view(num_objects, -1)
        mu_regularizer = -(mu ** 2).sum(-1) / 2 / (self.sigma_mu ** 2)
        sigma_regularizer = (sigma.log() - sigma).sum(-1) * self.sigma_sigma
        return mu_regularizer + sigma_regularizer
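
    # For intuition: the mu term is, up to an additive constant, the
    # log-density of a zero-mean Gaussian with standard deviation sigma_mu
    # on each latent mean, and the sigma term sigma_sigma * (log(sigma) -
    # sigma) is maximized at sigma = 1. With the defaults sigma_mu=1e4 and
    # sigma_sigma=1e-4 both terms are tiny unless the parameters become
    # extreme.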

    def batch_vlb(self, batch, mask):
        """
        Compute the differentiable variational lower bound for the given
        batch of objects and mask.
        """
        proposal, prior = self.make_latent_distributions(batch, mask)
        prior_regularization = self.prior_regularization(prior)
        latent = proposal.rsample()
        rec_params = self.generative_network(latent)
        rec_loss = self.rec_log_prob(batch, rec_params, mask)
        kl = kl_divergence(proposal, prior).view(batch.shape[0], -1).sum(-1)
        return rec_loss - kl + prior_regularization
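
    # A typical training step maximizes the VLB by minimizing its negation
    # (a sketch, assuming `optimizer` wraps this model's parameters):
    #   loss = -model.batch_vlb(batch, mask).mean()
    #   optimizer.zero_grad()
    #   loss.backward()
    #   optimizer.step()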

    def batch_iwae(self, batch, mask, K):
        """
        Compute the IWAE log-likelihood estimate with K samples per object.
        Technically it is differentiable, but it is recommended to use it
        for evaluation purposes inside torch.no_grad in order to save
        memory. With torch.no_grad the method requires almost no extra
        memory even for very large K.
        The method makes K independent passes through the generative
        network, so the feasible batch size is the same as for training
        with batch_vlb.
        """
        proposal, prior = self.make_latent_distributions(batch, mask)
        estimates = []
        for i in range(K):
            latent = proposal.rsample()
            rec_params = self.generative_network(latent)
            rec_loss = self.rec_log_prob(batch, rec_params, mask)
            prior_log_prob = prior.log_prob(latent)
            prior_log_prob = prior_log_prob.view(batch.shape[0], -1).sum(-1)
            proposal_log_prob = proposal.log_prob(latent)
            proposal_log_prob = proposal_log_prob.view(batch.shape[0], -1)
            proposal_log_prob = proposal_log_prob.sum(-1)
            estimate = rec_loss + prior_log_prob - proposal_log_prob
            estimates.append(estimate[:, None])
        return torch.logsumexp(torch.cat(estimates, 1), 1) - math.log(K)
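
    # The returned value is log(mean_k w_k), where each importance weight
    # is w_k = p(x_b | z_k, ...) * p(z_k) / q(z_k). It is computed stably
    # in log space: logsumexp over the K log-weights minus log(K).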

    def generate_samples_params(self, batch, mask, K=1):
        """
        Generate the parameters of the generative distributions for
        samples from the given batch.
        It draws K latent representations for each object from the batch
        and generates distribution parameters from each of them.
        The second axis is used to index the samples of an object, i.e.
        if the batch shape is [n x D1 x D2], then the result shape is
        [n x K x D1 x D2].
        It is better to use this method inside torch.no_grad in order to
        save memory. With torch.no_grad the method doesn't require extra
        memory except the memory for the result.
        """
        _, prior = self.make_latent_distributions(batch, mask)
        samples_params = []
        for i in range(K):
            latent = prior.rsample()
            sample_params = self.generative_network(latent)
            samples_params.append(sample_params.unsqueeze(1))
        return torch.cat(samples_params, 1)

    def generate_reconstructions_params(self, batch, mask, K=1):
        """
        Generate the parameters of the generative distributions for
        reconstructions from the given batch.
        It draws K latent representations for each object from the batch
        and generates distribution parameters from each of them.
        The second axis is used to index the reconstructions of an object,
        i.e. if the batch shape is [n x D1 x D2], then the result shape is
        [n x K x D1 x D2].
        It is better to use this method inside torch.no_grad in order to
        save memory. With torch.no_grad the method doesn't require extra
        memory except the memory for the result.
        """
        _, prior = self.make_latent_distributions(batch, mask)
        reconstructions_params = []
        for i in range(K):
            latent = prior.rsample()
            rec_params = self.generative_network(latent)
            reconstructions_params.append(rec_params.unsqueeze(1))
        return torch.cat(reconstructions_params, 1)
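

# A minimal usage sketch (not part of the original module): it wires VAEAC
# together with toy fully-connected networks for tabular data of width D.
# Everything below is illustrative; the network shapes are assumptions. In
# particular, it assumes that normal_parse_params splits its input along
# dimension 1 into means and unconstrained scale parameters, as its use in
# make_latent_distributions suggests, so each conditioning network outputs
# twice as many values as the latent space has dimensions.
if __name__ == '__main__':
    from torch import nn
    from torch.distributions import Normal

    D, latent_dim = 8, 4

    def gaussian_rec_log_prob(groundtruth, distr_params, mask):
        # Example reconstruction log probability: a component-wise Gaussian
        # over the features, summed over the unobserved entries only.
        mu, sigma_params = distr_params.split(D, dim=1)
        sigma = nn.functional.softplus(sigma_params).clamp(min=1e-3)
        log_prob = Normal(mu, sigma).log_prob(groundtruth)
        return (log_prob * mask).sum(-1)

    # Both conditioning networks see [batch, mask] concatenated, i.e. 2 * D
    # inputs; they output 2 * latent_dim Gaussian parameters.
    proposal_network = nn.Sequential(nn.Linear(2 * D, 64), nn.ReLU(),
                                     nn.Linear(64, 2 * latent_dim))
    prior_network = nn.Sequential(nn.Linear(2 * D, 64), nn.ReLU(),
                                  nn.Linear(64, 2 * latent_dim))
    generative_network = nn.Sequential(nn.Linear(latent_dim, 64), nn.ReLU(),
                                       nn.Linear(64, 2 * D))

    model = VAEAC(gaussian_rec_log_prob, proposal_network, prior_network,
                  generative_network)

    batch = torch.randn(16, D)
    mask = torch.bernoulli(torch.full((16, D), 0.5))  # 1 = unobserved

    vlb = model.batch_vlb(batch, mask)  # shape: [16]
    with torch.no_grad():
        iwae = model.batch_iwae(batch, mask, K=10)  # shape: [16]
        params = model.generate_samples_params(batch, mask, K=5)
    print(vlb.mean().item(), iwae.mean().item(), params.shape)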