-
Notifications
You must be signed in to change notification settings - Fork 35
/
diff_operators.py
63 lines (45 loc) · 1.85 KB
/
diff_operators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
from torch.autograd import grad
def hessian(y, x):
    ''' hessian of y wrt x

    y: shape (meta_batch_size, num_observations, channels)
    x: shape (meta_batch_size, num_observations, dim); dim was documented as 2,
       but the output is sized from x.shape[-1], so any trailing size works.

    Returns (h, status):
      h      -- shape (meta_batch_size, num_observations, channels, dim, dim),
                with h[..., i, j, k] = d^2 y_i / (dx_j dx_k), allocated with
                y's dtype and device.
      status -- 0 normally, -1 if h contains any NaN.

    Every grad call uses create_graph=True, so h stays differentiable.
    '''
    meta_batch_size, num_observations = y.shape[:2]
    num_channels, dim = y.shape[-1], x.shape[-1]
    # Match y's dtype/device: a plain torch.zeros(...) is always float32 and
    # would silently truncate float64 second derivatives written into it.
    h = y.new_zeros(meta_batch_size, num_observations, num_channels, dim, dim)
    for i in range(num_channels):
        y_i = y[..., i]
        # First derivatives of channel i w.r.t. every component of x.
        dydx = grad(y_i, x, torch.ones_like(y_i), create_graph=True)[0]
        for j in range(dim):
            dydx_j = dydx[..., j]
            # grad_outputs must match the tensor being differentiated
            # (dydx_j), not y_i, or this breaks when their shapes differ.
            h[..., i, j, :] = grad(dydx_j, x, torch.ones_like(dydx_j), create_graph=True)[0]
    status = -1 if torch.any(torch.isnan(h)) else 0
    return h, status
def laplace(y, x):
    ''' laplacian of y wrt x

    Computed as the divergence of the gradient, using this module's
    `gradient` and `divergence` helpers. Both keep the autograd graph, so the
    result remains differentiable.
    '''
    return divergence(gradient(y, x), x)
def divergence(y, x):
    ''' divergence of y wrt x

    Sums d y[..., k] / d x[..., k] over every channel k of y (the last axis).
    Each term is sliced with k:k+1, so the result keeps a trailing singleton
    axis. The graph is kept (create_graph=True) for higher-order use.
    If y has no channels, the float 0. is returned.
    '''
    partials = (
        grad(y[..., k], x, torch.ones_like(y[..., k]), create_graph=True)[0][..., k:k + 1]
        for k in range(y.shape[-1])
    )
    return sum(partials, 0.)
def gradient(y, x, grad_outputs=None):
    ''' gradient of y wrt x

    grad_outputs weights the vector-Jacobian product. It defaults to all ones
    shaped like y, which gives the gradient of y.sum(). The graph is kept
    (create_graph=True) so the result can be differentiated again.
    '''
    weights = torch.ones_like(y) if grad_outputs is None else grad_outputs
    (dy_dx,) = torch.autograd.grad(y, [x], grad_outputs=weights, create_graph=True)
    return dy_dx
def jacobian(y, x):
    ''' jacobian of y wrt x

    y: shape (meta_batch_size, num_observations, channels)
    x: shape (meta_batch_size, num_observations, dim)

    Returns (jac, status):
      jac    -- shape (meta_batch_size, num_observations, channels, dim), with
                jac[..., i, k] = d y_i / d x_k, allocated with y's dtype and
                device.
      status -- 0 normally, -1 if jac contains any NaN.

    Every grad call uses create_graph=True, so jac stays differentiable.
    '''
    meta_batch_size, num_observations = y.shape[:2]
    num_channels, dim = y.shape[-1], x.shape[-1]
    # Match y's dtype/device: a plain torch.zeros(...) is always float32.
    jac = y.new_zeros(meta_batch_size, num_observations, num_channels, dim)
    for i in range(num_channels):
        y_i = y[..., i]
        # Differentiate the channel slice directly. The previous
        # `.view(-1, 1)` flattening was unnecessary (it yields the same
        # vector-Jacobian product) and raised for non-contiguous y.
        jac[:, :, i, :] = grad(y_i, x, torch.ones_like(y_i), create_graph=True)[0]
    status = -1 if torch.any(torch.isnan(jac)) else 0
    return jac, status