-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathflow.py
108 lines (89 loc) · 3 KB
/
flow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from typing import List, Optional
import torch.nn as nn
import stribor as st
from torch import Tensor
from torch.nn import Module
class CouplingFlow(Module):
    """
    Affine coupling flow.

    Stacks ``n_layers`` ``st.ContinuousAffineCoupling`` transforms into one
    ``st.Flow``. Consecutive layers alternate between the ``'ordered_0'`` and
    ``'ordered_1'`` masks; when ``dim == 1`` the mask ``'none'`` is used.

    Args:
        dim: Data dimension
        n_layers: Number of flow layers
        hidden_dims: Hidden dimensions of the flow neural network
        time_net: Name of the time-embedding class, looked up as an attribute
            of ``st.net`` (it must be a string, since it is passed to getattr)
        time_hidden_dim: Time embedding hidden dimension
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        dim: int,
        n_layers: int,
        hidden_dims: List[int],
        time_net: str,
        time_hidden_dim: Optional[int] = None,
        **kwargs
    ):
        super().__init__()

        # Resolve the time-embedding class once. This also fails fast on an
        # unknown name, even when n_layers == 0.
        time_net_cls = getattr(st.net, time_net)

        transforms = [
            st.ContinuousAffineCoupling(
                # Input width is dim + 1 and output width is 2 * dim.
                # Presumably that is x plus a scalar time in, and a
                # scale/shift pair per dimension out -- confirm against stribor.
                latent_net=st.net.MLP(dim + 1, hidden_dims, 2 * dim),
                time_net=time_net_cls(2 * dim, hidden_dim=time_hidden_dim),
                mask='none' if dim == 1 else f'ordered_{i % 2}',
            )
            for i in range(n_layers)
        ]
        self.flow = st.Flow(transforms=transforms)

    def forward(
        self,
        x: Tensor,  # Initial conditions, (..., 1, dim) or (..., seq_len, dim)
        t: Tensor,  # Times to solve at, (..., seq_len, *); only shape[-2] read here
        t0: Optional[Tensor] = None,
    ) -> Tensor:  # Solutions to IVP given x at t, (..., seq_len, dim)
        """
        Evaluate the flow at times ``t`` starting from initial conditions ``x``.

        Args:
            x: Initial conditions. If its second-to-last axis has size 1, it
                is repeated along that axis ``t.shape[-2]`` times.
            t: Times to solve at.
            t0: Optional times at which ``x`` is given. When provided, ``x``
                is first passed through ``self.flow.inverse`` at ``t0``.
                Presumably this maps it back to the time-zero state -- confirm.

        Returns:
            The first element returned by ``self.flow(x, t=t)``.
        """
        if x.shape[-2] == 1:
            # (..., 1, dim) -> (..., seq_len, dim)
            x = x.repeat_interleave(t.shape[-2], dim=-2)

        # If t0 is given, invert first so the forward pass starts from there.
        if t0 is not None:
            x = self.flow.inverse(x, t=t0)[0]

        return self.flow(x, t=t)[0]
class ResNetFlow(Module):
    """
    ResNet flow.

    A sequence of ``n_layers`` ``st.net.ResNetFlow`` modules applied one after
    another, each conditioned on the same time tensor.

    Args:
        dim: Data dimension
        n_layers: Number of flow layers
        hidden_dims: Hidden dimensions of the residual neural network
        time_net: Name of the time-embedding module, forwarded to stribor
        time_hidden_dim: Time embedding hidden dimension
        invertible: Whether to make ResNet invertible (necessary for proper flow)
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        dim: int,
        n_layers: int,
        hidden_dims: List[int],
        time_net: str,
        time_hidden_dim: Optional[int] = None,
        invertible: Optional[bool] = True,
        **kwargs
    ):
        super().__init__()

        # NOTE(review): n_layers is used both as the number of stacked modules
        # and as the third positional argument of each st.net.ResNetFlow.
        # If that argument is the per-module depth, total depth grows as
        # n_layers ** 2. Confirm this is intended before changing it, since
        # changing it alters the parameter layout of saved checkpoints.
        self.layers = nn.ModuleList([
            st.net.ResNetFlow(
                dim,
                hidden_dims,
                n_layers,
                activation='ReLU',
                final_activation=None,
                time_net=time_net,
                time_hidden_dim=time_hidden_dim,
                invertible=invertible,
            )
            for _ in range(n_layers)
        ])

    def forward(
        self,
        x: Tensor,  # Initial conditions, (..., 1, dim) or (..., seq_len, dim)
        t: Tensor,  # Times to solve at, (..., seq_len, *); only shape[-2] read here
    ) -> Tensor:  # Solutions to IVP given x at t, (..., seq_len, dim)
        """
        Push initial conditions ``x`` through every layer at times ``t``.

        If ``x`` has size 1 along its second-to-last axis, it is first
        repeated ``t.shape[-2]`` times along that axis. Each layer is then
        called as ``layer(state, t)`` in registration order, and the final
        state is returned.
        """
        state = x
        if state.shape[-2] == 1:
            # One initial condition per sequence: tile it across all times.
            state = state.repeat_interleave(t.shape[-2], dim=-2)

        for block in self.layers:
            state = block(state, t)
        return state