-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcheby_poly_layer.py
168 lines (142 loc) · 7.01 KB
/
cheby_poly_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""
Use Chebyshev polynomials for fast computation of the heat/wave equation solutions.
This file contains the implementation of the ChebyLayer class.
Xingzhi Sun
April 2023
for the chebyshev polynomials, refer to:
https://en.wikipedia.org/wiki/Chebyshev_polynomials
written as a pytorch_geometric layer.
"""
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import get_laplacian
class ChebyPolyLayer(MessagePassing):
    """
    Chebyshev polynomial filter of the graph Laplacian, as a message passing layer.

    Computes ``sum_i coefs[:, i] * T_i(L - I) x``, where ``L`` is the symmetrically
    normalized graph Laplacian (eigenvalues in [0, 2]). Its shifted version ``L - I``
    therefore has spectrum in [-1, 1], the domain of the Chebyshev polynomials ``T_i``.
    The graph (``edge_index``, ``edge_weight``) is passed to ``forward``.
    """

    def __init__(self, coefs=None):
        """
        Store default expansion coefficients for the layer.

        The coefficients depend on the equation being solved and on the sampled time
        points. They can be overridden per call by passing ``coefs`` to ``forward``.

        Args:
            coefs (tensor, optional): shape (T, k). k is the number of Chebyshev terms
                (polynomial degrees 0 .. k - 1). T is the number of time points.
        """
        # Sum incoming messages; node features live on axis -2 of x.
        super().__init__(aggr="add", node_dim=-2)
        self.coefs = coefs

    def forward(self, edge_index, edge_weight, x, coefs=None):
        """Evaluate the Chebyshev expansion through recursive message passing.

        Uses the three-term recurrence
        ``T_0 = x``, ``T_1 = (L - I) x``, ``T_i = 2 (L - I) T_{i-1} - T_{i-2}``.
        Each application of ``(L - I)`` is one round of message passing.

        Args:
            edge_index (tensor): shape (2, E). E is number of edges.
            edge_weight (tensor, optional): shape (E, ), or None.
            x (tensor): shape (n, m). n is number of nodes, m is number of features.
                Leading batch dimensions (..., n, m) are also accepted; the node axis
                must be -2.
            coefs (tensor, optional): shape (T, k). Defaults to the ``coefs`` given at
                construction. k is the number of Chebyshev terms (degrees 0 .. k - 1).
                T is the number of time points.

        Returns:
            tensor: shape (T, n, m), or (T, ..., n, m) for batched ``x``.

        Raises:
            ValueError: if no coefficients are available or ``k < 1``.
        """
        if coefs is None:
            coefs = self.coefs
        if coefs is None:
            raise ValueError("coefs must be given either to __init__ or to forward()")
        k = coefs.size(1)
        if k < 1:
            raise ValueError(f"need at least one Chebyshev coefficient, got k={k}")

        # Symmetrically normalized Laplacian, so the eigenvalues are within [0, 2];
        # see https://math.stackexchange.com/questions/2511544/largest-eigenvalue-of-a-normalized-graph-laplacian
        # num_nodes is passed explicitly so that isolated nodes with the largest
        # indices (absent from edge_index) still receive their diagonal entry.
        lap_index, lap_weight = get_laplacian(
            edge_index, edge_weight, normalization="sym", num_nodes=x.size(-2))

        def shifted_laplacian(h):
            # One message passing round: (L - I) h (the "- I" happens in update()).
            return self.propagate(edge_index=lap_index, x=h, edge_weight=lap_weight)

        # Reshape each coefficient column to (T, 1, ..., 1) so it broadcasts against
        # x of any rank and yields a leading time axis.
        coef_shape = (-1,) + (1,) * x.dim()

        T0 = x
        out = coefs[:, 0].view(coef_shape) * T0
        if k == 1:
            return out
        T1 = shifted_laplacian(x)
        out += coefs[:, 1].view(coef_shape) * T1
        for i in range(2, k):
            T2 = 2 * shifted_laplacian(T1) - T0
            out += coefs[:, i].view(coef_shape) * T2
            T0, T1 = T1, T2
        return out

    def message(self, x_j, edge_weight):
        """Scale neighbor features by the Laplacian edge weight.

        ``edge_weight`` here is the weight of the (normalized) graph Laplacian edge,
        including its self-loop (diagonal) entries.
        """
        # x_j: (..., E, m); edge_weight: (E,) -> (E, 1) broadcasts over features.
        return edge_weight.view(-1, 1) * x_j

    def update(self, aggr_out, x):
        """Turn the aggregated ``L x`` into ``(L - I) x``.

        The shift maps the Laplacian spectrum from [0, 2] to [-1, 1]:
        ``L_tilde = L - I``, so ``L_tilde x = L x - x``.
        """
        return aggr_out - x
class ChebyLayer(MessagePassing):
    """
    [DEPRECATED.] Prefer ``ChebyPolyLayer``, which receives the graph in ``forward``.

    Chebyshev polynomial filter of a fixed graph's Laplacian, as a message passing
    layer. The graph is fixed at construction. Node features are expected on axis -3
    of ``x``.
    """

    def __init__(self, edge_index, edge_weight=None, num_nodes=None):
        """Compute and cache the symmetrically normalized graph Laplacian.

        Args:
            edge_index (tensor): shape (2, E). E is number of edges.
            edge_weight (tensor, optional): shape (E, ), or None.
            num_nodes (int, optional): number of nodes in the graph. If None it is
                inferred from ``edge_index``. In that case isolated nodes with the
                largest indices get no diagonal Laplacian entry.
        """
        # Sum incoming messages; node features live on axis -3 of x.
        super().__init__(aggr="add", node_dim=-3)
        self.edge_index = edge_index
        self.edge_weight = edge_weight
        # Symmetrically normalized Laplacian, so the eigenvalues are within [0, 2];
        # see https://math.stackexchange.com/questions/2511544/largest-eigenvalue-of-a-normalized-graph-laplacian
        self.laplacian_edge_index, self.laplacian_edge_weight = get_laplacian(
            self.edge_index, self.edge_weight, normalization="sym", num_nodes=num_nodes)

    def forward(self, x, coefs):
        """Evaluate the Chebyshev expansion through recursive message passing.

        Uses the recurrence ``T_0 = x``, ``T_1 = (L - I) x``,
        ``T_i = 2 (L - I) T_{i-1} - T_{i-2}``.

        Args:
            x (tensor): shape (n, *, *). n is number of nodes.
            coefs (tensor): shape (T, k). k is the number of Chebyshev terms
                (degrees 0 .. k - 1). T is the number of time points.

        Returns:
            tensor: shape (T, n, *, *).

        Raises:
            ValueError: if ``k < 1``.
        """
        k = coefs.size(1)
        if k < 1:
            raise ValueError(f"need at least one Chebyshev coefficient, got k={k}")

        def shifted_laplacian(h):
            # One message passing round: (L - I) h (the "- I" happens in update()).
            return self.propagate(
                edge_index=self.laplacian_edge_index,
                x=h,
                edge_weight=self.laplacian_edge_weight,
            )

        # (T, 1, ..., 1): broadcast each coefficient column against x of any rank.
        coef_shape = (-1,) + (1,) * x.dim()

        T0 = x
        out = coefs[:, 0].view(coef_shape) * T0
        if k == 1:
            return out
        T1 = shifted_laplacian(x)
        out += coefs[:, 1].view(coef_shape) * T1
        for i in range(2, k):
            T2 = 2 * shifted_laplacian(T1) - T0
            out += coefs[:, i].view(coef_shape) * T2
            T0, T1 = T1, T2
        return out

    def message(self, x_j):
        """Scale neighbor features by the cached Laplacian edge weights."""
        # x_j: (..., E, *, *); weights reshaped to (E, 1, 1) to broadcast.
        return self.laplacian_edge_weight.view(-1, 1, 1) * x_j

    def update(self, aggr_out, x):
        """Turn the aggregated ``L x`` into ``(L - I) x``.

        The shift maps the Laplacian spectrum from [0, 2] to [-1, 1]:
        ``L_tilde = L - I``, so ``L_tilde x = L x - x``.
        """
        return aggr_out - x
def get_cheby_coefs(func, ts, degree, device, N=1000):
    """
    Compute Chebyshev expansion coefficients of ``func(t, lam)`` in ``lam`` on [0, 2].

    The spectral variable ``lam`` (an eigenvalue of the symmetrically normalized graph
    Laplacian, in [0, 2]) is mapped to ``x = lam - 1`` in [-1, 1]. The coefficients
    come from Chebyshev-Gauss quadrature at the N nodes
    ``x_k = cos(pi * (k + 0.5) / N)``:

        c_n(t) = (2 / N) * sum_k func(t, x_k + 1) * T_n(x_k),   with c_0 halved,

    so that ``func(t, lam) ~= sum_n c_n(t) * T_n(lam - 1)``. This is the form
    evaluated by the Chebyshev layers, which apply ``T_n`` to ``L - I``.

    Args:
        func (callable): ``func(ts, lams)``, where ``ts`` holds the time points and
            ``lams`` is a (N,) tensor of Laplacian eigenvalues. It must return a
            tensor of shape (T, N).
        ts (tensor): shape (T, ). The sample time points. Moved to ``device`` if it
            is a tensor.
        degree (int): number of Chebyshev terms (the min power is 0, the max power is
            degree - 1). Must be >= 1. Keep ``degree <= N`` for the discrete
            orthogonality to hold.
        device (torch.device or str): device on which the coefficients are computed.
        N (int, optional): number of quadrature nodes. Defaults to 1000.

    Returns:
        tensor: shape (T, degree). The coefficients of the Chebyshev polynomials.

    Raises:
        ValueError: if ``degree < 1`` or ``N < 1``.

    Note:
        See https://en.wikipedia.org/wiki/Chebyshev_polynomials#Orthogonality for the
        formula used to compute the coefficients.
    """
    if degree < 1:
        raise ValueError(f"degree must be >= 1, got {degree}")
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")
    if torch.is_tensor(ts):
        # Keep ts on the same device as the quadrature nodes passed to func.
        ts = ts.to(device)

    # Chebyshev nodes as angles: x_k = cos(theta_k), theta_k = pi * (k + 0.5) / N.
    thetas = torch.pi * (torch.arange(N, device=device) + 0.5) / N
    xks = torch.cos(thetas) + 1  # reparameterize from [-1, 1] to [0, 2]
    ns = torch.arange(degree, device=device).unsqueeze(-1)
    Tn_xks = torch.cos(ns * thetas)  # T_n(x_k) = cos(n * theta_k); shape (degree, N)
    func_xks = func(ts, xks).to(device)  # shape (T, N)

    # The quadrature sum is a matmul; promote dtypes so float64/int outputs of func
    # behave as under elementwise broadcasting.
    dtype = torch.promote_types(func_xks.dtype, Tn_xks.dtype)
    coefs = (2.0 / N) * (func_xks.to(dtype) @ Tn_xks.to(dtype).T)  # shape (T, degree)
    coefs[:, 0] /= 2
    return coefs