-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathinv_conv.py
36 lines (28 loc) · 1.13 KB
/
inv_conv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class InvConv(nn.Module):
    """Invertible 1x1 Convolution for 2D inputs.

    Originally described in Glow (https://arxiv.org/abs/1807.03039).
    Does not support LU-decomposed version.

    The layer holds a single square ``num_channels x num_channels`` weight
    matrix that is applied at every spatial position as a 1x1 convolution.
    In reverse mode the inverse of that matrix is applied instead.

    Args:
        num_channels (int): Number of channels in the input and output.
    """

    def __init__(self, num_channels):
        super(InvConv, self).__init__()
        self.num_channels = num_channels

        # Initialize with a random orthogonal matrix: the Q factor of the QR
        # decomposition of a Gaussian matrix. An orthogonal matrix has
        # |det| == 1, so the initial log-determinant is zero.
        w_init = np.random.randn(num_channels, num_channels)
        w_init = np.linalg.qr(w_init)[0].astype(np.float32)
        self.weight = nn.Parameter(torch.from_numpy(w_init))

    def forward(self, x, sldj, reverse=False):
        """Apply the 1x1 convolution (or its inverse) and update ``sldj``.

        Args:
            x (torch.Tensor): 4-D input; dimension 1 must have
                ``num_channels`` entries and dimensions 2 and 3 are treated
                as the spatial height and width.
            sldj (torch.Tensor): Running log-determinant accumulator
                (presumably the per-sample sum of log-det-Jacobians; confirm
                the shape against the caller). The layer's contribution is
                added in the forward direction and subtracted in reverse.
            reverse (bool): If True, apply the inverse of the weight matrix.

        Returns:
            tuple: ``(z, sldj)`` where ``z`` is the convolved tensor and
            ``sldj`` is the updated accumulator.
        """
        # The same matrix is applied at each of the H*W positions, so the
        # total log|det| is the per-position value times H*W.
        ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)

        if reverse:
            # Invert in double precision for accuracy, then cast back.
            weight = torch.inverse(self.weight.double()).float()
            sldj = sldj - ldj
        else:
            weight = self.weight
            sldj = sldj + ldj

        # Reshape the (C, C) matrix into a (C_out, C_in, 1, 1) conv kernel.
        weight = weight.view(self.num_channels, self.num_channels, 1, 1)
        z = F.conv2d(x, weight)

        return z, sldj