# grapharm.py
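'''
GraphARM: wraps a DiffusionOrderingNetwork and a DenoisingNetwork and trains
them jointly, alternating diffusion (node-masking) and denoising steps.
'''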

import logging

import torch
import torch.nn as nn
import wandb

from models import DiffusionOrderingNetwork, DenoisingNetwork
from utils import NodeMasking

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

class GraphARM(nn.Module):
    '''
    Encapsulates the DiffusionOrderingNetwork and DenoisingNetwork, as well as
    the training loop for both, with diffusion and denoising steps.
    '''
    def __init__(self,
                 dataset,
                 denoising_network,
                 diffusion_ordering_network,
                 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        super().__init__()
        self.device = device

        self.diffusion_ordering_network = diffusion_ordering_network.to(device)
        self.denoising_network = denoising_network.to(device)

        self.masker = NodeMasking(dataset)

        self.denoising_optimizer = torch.optim.Adam(self.denoising_network.parameters(), lr=1e-5, betas=(0.9, 0.999))
        self.ordering_optimizer = torch.optim.Adam(self.diffusion_ordering_network.parameters(), lr=5e-5, betas=(0.9, 0.999))

    def node_decay_ordering(self, datapoint):
        '''
        Returns a node ordering for the given graph, sampled from the
        diffusion ordering network.
        '''
        p = datapoint.clone().to(self.device)
        node_order = []
        sigma_t_dist_list = []
        for _ in range(p.x.shape[0]):
            # use the diffusion ordering network to get a distribution over the remaining nodes
            sigma_t_dist = self.diffusion_ordering_network(p, node_order)
            # sample from the categorical distribution restricted to unmasked nodes
            unmasked = torch.tensor([j not in node_order for j in range(p.x.shape[0])]).to(self.device)
            sigma_t_dist_list.append(sigma_t_dist.flatten())
            sigma_t = torch.distributions.Categorical(probs=sigma_t_dist[unmasked].flatten()).sample()
            # map the sample back to the original node index
            sigma_t = torch.where(unmasked.flatten())[0][sigma_t.long()]
            node_order.append(sigma_t)
        return node_order, sigma_t_dist_list

    def uniform_node_decay_ordering(self, datapoint):
        '''
        Returns a uniformly random node ordering for the given graph.
        '''
        return torch.randperm(datapoint.x.shape[0]).tolist()

    def generate_diffusion_trajectories(self, graph, M):
        '''
        Generates M diffusion trajectories for a given graph,
        using the node decay ordering mechanism.
        '''
        original_data = graph.clone().to(self.device)
        diffusion_trajectories = []
        for _ in range(M):
            node_order, sigma_t_dist = self.node_decay_ordering(graph)
            node_order_invariant = list(node_order)  # keep a copy of the ordering in the original indexing

            # create the diffusion trajectory G_0, G_1, ..., G_T
            diffusion_trajectory = [original_data]
            masked_data = graph.clone()
            for i in range(len(node_order)):
                node = node_order[i]
                masked_data = masked_data.clone().to(self.device)
                masked_data = self.masker.mask_node(masked_data, node)
                diffusion_trajectory.append(masked_data)
                if i < len(node_order) - 1:
                    masked_data = self.masker.remove_node(masked_data, node)
                    # shift the remaining indices down to account for the removed node
                    node_order = [n - 1 if n > node else n for n in node_order]

            diffusion_trajectories.append([diffusion_trajectory, node_order_invariant, sigma_t_dist])
        return diffusion_trajectories

    def preprocess(self, graph):
        '''
        Preprocesses a graph to be used by the denoising network.
        '''
        graph = graph.clone()
        graph = self.masker.idxify(graph)
        graph = self.masker.fully_connect(graph)
        return graph

    def compute_nll_node(self, node_type_probs, correct_node_type, sigma_t_dist):
        '''
        Computes the negative log-likelihood for node types.
        '''
        # weight each node's type probabilities by its probability of being denoised at this step
        node_probs = node_type_probs * sigma_t_dist.view(-1, 1).clone()
        # NLL of the correct node type, marginalized over the ordering distribution
        nll_node = -torch.log(node_probs[:, correct_node_type].sum() + 1e-8)
        return nll_node

    def compute_nll_edge(self, edge_type_probs, correct_edge_type):
        '''
        Computes the negative log-likelihood for edge types:
        gather the probability of the correct type for each edge,
        then sum the negative log-probabilities.
        '''
        edge_probs = edge_type_probs.view(-1, edge_type_probs.shape[-1])
        edge_probs = torch.gather(edge_probs, 1, correct_edge_type.view(-1, 1))
        nll_edge = -torch.log(edge_probs + 1e-8).sum()
        return nll_edge

    def compute_denoising_loss(self, diffusion_trajectory, node_order_invariant, sigma_t_dist_list):
        '''
        Computes the loss for the denoising network as a negative log-likelihood (NLL).
        '''
        loss = 0
        T = len(diffusion_trajectory) - 1  # total number of diffusion steps
        sigma_t = torch.stack(sigma_t_dist_list, dim=0)
        G_0 = diffusion_trajectory[0]  # original graph

        for t in range(T):
            graph_t_next = diffusion_trajectory[t + 1]  # G_{t+1}
            node_type_probs, edge_type_probs = self.denoising_network(graph_t_next.x, graph_t_next.edge_index, graph_t_next.edge_attr)

            # NLL for the node type: computed for all nodes and weighted by the
            # ordering distribution sigma_t, restricted to still-unmasked nodes
            sigma_t_dist = sigma_t[t]
            sigma_t_dist = sigma_t_dist[sigma_t_dist != 0]
            original_node_type = G_0.x[node_order_invariant[t]]
            nll_node = self.compute_nll_node(node_type_probs, original_node_type, sigma_t_dist)

            # NLL for the edge types: original types (in G_0) of the edges between the
            # node denoised at step t and the nodes still present at that step
            original_edge_types = G_0.edge_attr[(G_0.edge_index[0] == node_order_invariant[t]) &
                                                (torch.tensor([G_0.edge_index[1][i] in node_order_invariant[t:]
                                                               for i in range(G_0.edge_index.shape[1])]))]
            nll_edge = self.compute_nll_edge(edge_type_probs, original_edge_types)

            loss += nll_node + nll_edge
        return loss / T

    def compute_ordering_loss(self, diffusion_trajectories, M):
        '''
        Computes the loss for the diffusion ordering network using the REINFORCE algorithm.
        '''
        ordering_loss = 0
        for trajectory, node_order, sigma_t_dist_list in diffusion_trajectories:
            # the reward is the negative denoising loss; detach it so REINFORCE treats it
            # as a constant and no gradient flows back into the denoising network
            reward = -self.compute_denoising_loss(trajectory, node_order, sigma_t_dist_list).detach()
            wandb.log({"reward": reward.item()})

            # REINFORCE update (policy gradient):
            # log-probability of the sampled ordering under sigma_t_dist_list
            log_prob = torch.tensor(0.0, device=self.device)
            for t in range(len(sigma_t_dist_list)):
                log_prob += torch.log(sigma_t_dist_list[t][node_order[t]])
            wandb.log({"log_prob_sigma_t": log_prob.item()})

            # minimize -reward * log_prob, i.e. maximize the expected reward
            ordering_loss += -reward * log_prob
        return ordering_loss / M

    def train_step(self, train_batch, val_batch, M):
        '''
        Performs one training step for both the denoising and diffusion ordering networks.
        '''
        self.denoising_optimizer.zero_grad()
        self.ordering_optimizer.zero_grad()

        total_denoising_loss = 0
        total_ordering_loss = 0

        # train the denoising network: M diffusion trajectories per training graph
        for graph in train_batch:
            graph = self.preprocess(graph)
            diffusion_trajectories = self.generate_diffusion_trajectories(graph, M)
            denoising_loss = sum(self.compute_denoising_loss(traj[0], traj[1], traj[2]) for traj in diffusion_trajectories)
            total_denoising_loss += denoising_loss
        total_denoising_loss.backward()
        self.denoising_optimizer.step()
        wandb.log({"denoising_loss": total_denoising_loss.item()})

        # train the ordering network on the validation batch with REINFORCE
        for graph in val_batch:
            graph = self.preprocess(graph)
            diffusion_trajectories = self.generate_diffusion_trajectories(graph, M)
            ordering_loss = self.compute_ordering_loss(diffusion_trajectories, M)
            total_ordering_loss += ordering_loss
        total_ordering_loss.backward()
        self.ordering_optimizer.step()
        wandb.log({"ordering_loss": total_ordering_loss.item()})

        return total_denoising_loss.item(), total_ordering_loss.item()

    def predict_new_node(self, graph, sampling_method="sample", preprocess=True):
        '''
        Predicts the type of a new node for the graph, as well as its connections
        to all previously denoised nodes.
        sampling_method: "argmax" or "sample"
        - argmax: select the node and edge types with the highest probability
        - sample: sample the node and edge types from the predicted distributions
        '''
        assert sampling_method in ["argmax", "sample"], "sampling_method must be either 'argmax' or 'sample'"
        with torch.no_grad():
            if preprocess:
                graph = self.preprocess(graph)

            node_type_probs, edge_type_probs = self.denoising_network(graph.x, graph.edge_index, graph.edge_attr)
            node_type_probs = node_type_probs[-1]  # only predict for the last (masked) node

            # sample the node type
            if sampling_method == "sample":
                node_type = torch.distributions.Categorical(probs=node_type_probs.squeeze()).sample()
            elif sampling_method == "argmax":
                node_type = torch.argmax(node_type_probs.squeeze(), dim=-1).reshape(-1, 1)

            # sample the edge types
            if sampling_method == "sample":
                new_connections = torch.multinomial(edge_type_probs.squeeze(), num_samples=1, replacement=True)
            elif sampling_method == "argmax":
                new_connections = torch.argmax(edge_type_probs.squeeze(), dim=-1).reshape(-1, 1)

            # no need to filter connections to previously denoised nodes,
            # assuming only one new node is added at a time
            return node_type, new_connections

    def save_model(self, denoising_network_path="denoising_network.pt", diffusion_ordering_network_path="diffusion_ordering_network.pt"):
        torch.save(self.denoising_network.state_dict(), denoising_network_path)
        torch.save(self.diffusion_ordering_network.state_dict(), diffusion_ordering_network_path)

    def load_model(self, denoising_network_path="denoising_network.pt", diffusion_ordering_network_path="diffusion_ordering_network.pt"):
        self.denoising_network.load_state_dict(torch.load(denoising_network_path, map_location=self.device))
        self.diffusion_ordering_network.load_state_dict(torch.load(diffusion_ordering_network_path, map_location=self.device))
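

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original training pipeline.
    # Assumptions: a PyG-style dataset whose graphs carry x, edge_index and
    # edge_attr (TUDataset/MUTAG is used here purely as an illustration); the
    # real constructor arguments of the two networks live in models.py and
    # are elided below.
    from torch_geometric.datasets import TUDataset

    dataset = TUDataset(root="data", name="MUTAG")
    don = DiffusionOrderingNetwork(...)  # fill in constructor args from models.py
    dn = DenoisingNetwork(...)           # fill in constructor args from models.py

    wandb.init(project="grapharm")  # hypothetical project name; train_step logs to wandb
    grapharm = GraphARM(dataset=dataset,
                        denoising_network=dn,
                        diffusion_ordering_network=don)
    for epoch in range(5):
        grapharm.train_step(train_batch=dataset[:8], val_batch=dataset[8:10], M=4)
    grapharm.save_model()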