"""A Gaussian distribution with tanh transformation."""
import torch
from torch.distributions import Normal
from torch.distributions.independent import Independent
class TanhNormal(torch.distributions.Distribution):
r"""A distribution induced by applying a tanh transformation to a Gaussian random variable.
Algorithms like SAC and Pearl use this transformed distribution.
It can be thought of as a distribution of X where
:math:`Y ~ \mathcal{N}(\mu, \sigma)`
:math:`X = tanh(Y)`
Args:
loc (torch.Tensor): The mean of this distribution.
scale (torch.Tensor): The stdev of this distribution.
""" # noqa: 501
def __init__(self, loc, scale):
self._normal = Independent(Normal(loc, scale), 1)
super().__init__()
    def log_prob(self, value, pre_tanh_value=None, epsilon=1e-6):
        """The log likelihood of a sample under this TanhNormal distribution.

        Args:
            value (torch.Tensor): The sample whose log likelihood is being
                computed.
            pre_tanh_value (torch.Tensor): The value prior to having the tanh
                function applied to it, but after it has been sampled from the
                normal distribution.
            epsilon (float): Regularization constant. Making this value larger
                makes the computation more stable but less precise.

        Note:
            When pre_tanh_value is None, an estimate is made of what the
            value is. This leads to a worse estimate of the log_prob.
            If the value being used was collected from `sample` or `rsample`,
            one can instead use `rsample_with_pre_tanh_value`, which also
            returns the pre-tanh value.

        Returns:
            torch.Tensor: The log likelihood of value under the distribution.

        """
        # pylint: disable=arguments-differ
        if pre_tanh_value is None:
            # Recover the pre-tanh value via arctanh(x) = 0.5 * log((1+x)/(1-x)),
            # regularized by epsilon to avoid division by zero near |value| = 1.
            pre_tanh_value = torch.log(
                (1 + epsilon + value) / (1 + epsilon - value)) / 2
        norm_lp = self._normal.log_prob(pre_tanh_value)
        # Change-of-variables correction for the tanh transform.
        ret = (norm_lp - torch.sum(
            torch.log(self._clip_but_pass_gradient((1. - value**2)) + epsilon),
            axis=-1))
        return ret

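    # The correction term above follows from the change-of-variables formula.
    # A short worked sketch (not part of the original source):
    #
    #   if X = tanh(Y) with Y ~ Normal(mu, sigma), then
    #       log p_X(x) = log p_Y(arctanh(x)) - log |d tanh(y) / dy|
    #                  = log p_Y(arctanh(x)) - log (1 - x**2)
    #
    #   Summing the last term over the final dimension (the
    #   torch.sum(..., axis=-1) call) matches the Independent(..., 1)
    #   wrapper used for the underlying Normal.
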
    def sample(self, sample_shape=torch.Size()):
        """Return a sample, sampled from this TanhNormal distribution.

        Args:
            sample_shape (list): Shape of the returned value.

        Note:
            Gradients `do not` pass through this operation.

        Returns:
            torch.Tensor: Sample from this TanhNormal distribution.

        """
        with torch.no_grad():
            return self.rsample(sample_shape=sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Return a sample, sampled from this TanhNormal distribution.

        Args:
            sample_shape (list): Shape of the returned value.

        Note:
            Gradients pass through this operation.

        Returns:
            torch.Tensor: Sample from this TanhNormal distribution.

        """
        z = self._normal.rsample(sample_shape)
        return torch.tanh(z)

    def rsample_with_pre_tanh_value(self, sample_shape=torch.Size()):
        """Return a sample, sampled from this TanhNormal distribution.

        Returns the sampled value before the tanh transform is applied and the
        sampled value with the tanh transform applied to it.

        Args:
            sample_shape (list): Shape of the returned value.

        Note:
            Gradients pass through this operation.

        Returns:
            torch.Tensor: Samples from the underlying
                :obj:`torch.distributions.Normal` distribution, prior to being
                transformed with `tanh`.
            torch.Tensor: Samples from this distribution.

        """
        z = self._normal.rsample(sample_shape)
        return z, torch.tanh(z)

    def cdf(self, value):
        """Returns the CDF at the value.

        Returns the cumulative density/mass function evaluated at
        `value` on the underlying normal distribution.

        Args:
            value (torch.Tensor): The value at which the CDF is evaluated.

        Returns:
            torch.Tensor: The result of computing the CDF.

        """
        return self._normal.cdf(value)

    def icdf(self, value):
        """Returns the inverse CDF evaluated at `value`.

        Returns the inverse CDF evaluated at `value` on the underlying
        normal distribution.

        Args:
            value (torch.Tensor): The value at which the inverse CDF is
                evaluated.

        Returns:
            torch.Tensor: The result of computing the inverse CDF.

        """
        return self._normal.icdf(value)

    @classmethod
    def _from_distribution(cls, new_normal):
        """Construct a new TanhNormal distribution from a normal distribution.

        Args:
            new_normal (Independent(Normal)): underlying normal dist for
                the new TanhNormal distribution.

        Returns:
            TanhNormal: A new distribution whose underlying normal dist
                is new_normal.

        """
        # pylint: disable=protected-access
        # Build a placeholder instance, then swap in the provided normal.
        new = cls(torch.zeros(1), torch.zeros(1))
        new._normal = new_normal
        return new

    def expand(self, batch_shape, _instance=None):
        """Returns a new TanhNormal distribution.

        (or populates an existing instance provided by a derived class) with
        batch dimensions expanded to `batch_shape`. This method calls
        :class:`~torch.Tensor.expand` on the distribution's parameters. As
        such, this does not allocate new memory for the expanded distribution
        instance. Additionally, this does not repeat any args checking or
        parameter broadcasting from `__init__`, when an instance is first
        created.

        Args:
            batch_shape (torch.Size): the desired expanded size.
            _instance (instance): new instance provided by subclasses that
                need to override `.expand`.

        Returns:
            Instance: New distribution instance with batch dimensions expanded
                to `batch_shape`.

        """
        new_normal = self._normal.expand(batch_shape, _instance)
        new = self._from_distribution(new_normal)
        return new

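    # A minimal sketch of expanding the batch dimensions (not part of the
    # original source); the shapes below are illustrative only:
    #
    #   d = TanhNormal(torch.zeros(2), torch.ones(2))
    #   d10 = d.expand(torch.Size([10]))
    #   d10.rsample().shape   # torch.Size([10, 2])
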
    def enumerate_support(self, expand=True):
        """Returns a tensor containing all values supported by a discrete dist.

        The result will enumerate over dimension 0, so the shape
        of the result will be `(cardinality,) + batch_shape + event_shape`
        (where `event_shape = ()` for univariate distributions).

        Note that this enumerates over all batched tensors in lock-step
        `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
        along dim 0, but with the remaining batch dimensions being
        singleton dimensions, `[[0], [1], ...]`.

        To iterate over the full Cartesian product use
        `itertools.product(m.enumerate_support())`.

        Args:
            expand (bool): whether to expand the support over the
                batch dims to match the distribution's `batch_shape`.

        Note:
            Calls the enumerate_support function of the underlying normal
            distribution.

        Returns:
            torch.Tensor: Tensor iterating over dimension 0.

        """
        return self._normal.enumerate_support(expand)

    @property
    def mean(self):
        """torch.Tensor: mean of the distribution."""
        return torch.tanh(self._normal.mean)

    @property
    def variance(self):
        """torch.Tensor: variance of the underlying normal distribution."""
        return self._normal.variance

    def entropy(self):
        """Returns entropy of the underlying normal distribution.

        Returns:
            torch.Tensor: entropy of the underlying normal distribution.

        """
        return self._normal.entropy()

    @staticmethod
    def _clip_but_pass_gradient(x, lower=0., upper=1.):
        """Clipping function that allows for gradients to flow through.

        Args:
            x (torch.Tensor): value to be clipped
            lower (float): lower bound of clipping
            upper (float): upper bound of clipping

        Returns:
            torch.Tensor: x clipped between lower and upper.

        """
        clip_up = (x > upper).float()
        clip_low = (x < lower).float()
        with torch.no_grad():
            # Compute the clipping offset without tracking gradients, so the
            # forward value equals clamp(x, lower, upper) while gradients
            # flow through x as if no clipping had happened.
            clip = ((upper - x) * clip_up + (lower - x) * clip_low)
        return x + clip

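    # A minimal sketch (not in the original source) of how the helper behaves:
    #
    #   x = torch.tensor([-0.5, 0.5, 1.5], requires_grad=True)
    #   y = TanhNormal._clip_but_pass_gradient(x)   # tensor([0.0, 0.5, 1.0])
    #   y.sum().backward()                          # x.grad == tensor([1., 1., 1.])
    #
    # The forward value matches torch.clamp(x, 0., 1.), but the gradient is the
    # identity everywhere, which keeps log_prob differentiable near |value| = 1.
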
    def __repr__(self):
        """Returns the parameterization of the distribution.

        Returns:
            str: The parameterization of the distribution and underlying
                distribution.

        """
        return self.__class__.__name__

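
# A minimal usage sketch (not part of the original source). It assumes a
# 2-dimensional action space; the shapes and values are illustrative only.
if __name__ == '__main__':
    loc = torch.zeros(2)
    scale = torch.ones(2)
    dist = TanhNormal(loc, scale)

    # rsample uses the reparameterization trick, so gradients can flow
    # through the sample when loc/scale require grad.
    pre_tanh, action = dist.rsample_with_pre_tanh_value()
    print(action)  # values lie in (-1, 1)

    # Passing the pre-tanh value gives a more accurate log-probability than
    # letting log_prob re-estimate it from the squashed action.
    print(dist.log_prob(action, pre_tanh_value=pre_tanh))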