-
Notifications
You must be signed in to change notification settings - Fork 5
/
evaluate_models_utils.py
389 lines (327 loc) · 24.8 KB
/
evaluate_models_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import logging
import time
import argparse
import os
import json
from models.EdgeBank import edge_bank_link_prediction
from utils.metrics import get_link_prediction_metrics, get_node_classification_metrics
from utils.utils import set_random_seed
from utils.utils import NegativeEdgeSampler_local, NeighborSampler
from utils.DataLoader import Data
def evaluate_model_link_prediction(model_name: str, model: nn.Module, neighbor_sampler: NeighborSampler, evaluate_idx_data_loader: DataLoader,
                                   evaluate_neg_edge_sampler: NegativeEdgeSampler_local, evaluate_data: Data, loss_func: nn.Module,
                                   num_neighbors: int = 20, time_gap: int = 2000):
    """
    evaluate models on the link prediction task
    every batch of positive (observed) edges is scored together with an equally sized batch of negative edges drawn from
    evaluate_neg_edge_sampler; positive edges get label 1 and negative edges get label 0
    for the memory-based models (JODIE, DyRep, TGN) the positive edges change the memories of the model, so calling this
    function updates the model state
    :param model_name: str, name of the model, one of 'JODIE', 'DyRep', 'TGAT', 'TGN', 'CAWN', 'TCL', 'GraphMixer', 'DyGFormer'
    :param model: nn.Module, the model to be evaluated, indexed as model[0] (backbone exposing
                  compute_src_dst_node_temporal_embeddings) and model[1] (link predictor whose output is passed through a sigmoid)
    :param neighbor_sampler: NeighborSampler, neighbor sampler, installed on model[0] for every supported model except JODIE
    :param evaluate_idx_data_loader: DataLoader, evaluate index data loader, yields batches of indices into evaluate_data
    :param evaluate_neg_edge_sampler: NegativeEdgeSampler_local, evaluate negative edge sampler, must have a fixed seed
    :param evaluate_data: Data, data to be evaluated
    :param loss_func: nn.Module, loss function, called with the predicted probabilities and the 0/1 labels
    :param num_neighbors: int, number of neighbors to sample for each node (not passed to DyGFormer)
    :param time_gap: int, time gap for neighbors to compute node features (only used by GraphMixer)
    :return: tuple (evaluate_losses, evaluate_metrics), the loss value (float) of every batch and the metrics returned by
             get_link_prediction_metrics for every batch
    :raises ValueError: if evaluate_neg_edge_sampler has no seed, or if model_name is not supported
    """
    # Ensures the random sampler uses a fixed seed for evaluation (i.e. we always sample the same negatives for validation / test set).
    # An explicit check (instead of assert) keeps this guarantee when Python runs with -O.
    if evaluate_neg_edge_sampler.seed is None:
        raise ValueError('evaluate_neg_edge_sampler must be created with a fixed seed for reproducible evaluation!')
    evaluate_neg_edge_sampler.reset_random_state()

    if model_name in ['DyRep', 'TGAT', 'TGN', 'CAWN', 'TCL', 'GraphMixer', 'DyGFormer']:
        # evaluation phase use all the graph information
        model[0].set_neighbor_sampler(neighbor_sampler)

    model.eval()

    with torch.no_grad():
        # store evaluate losses and metrics
        evaluate_losses, evaluate_metrics = [], []
        evaluate_idx_data_loader_tqdm = tqdm(evaluate_idx_data_loader, ncols=120)
        for batch_idx, evaluate_data_indices in enumerate(evaluate_idx_data_loader_tqdm):
            batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times, batch_edge_ids = \
                evaluate_data.src_node_ids[evaluate_data_indices], evaluate_data.dst_node_ids[evaluate_data_indices], \
                evaluate_data.node_interact_times[evaluate_data_indices], evaluate_data.edge_ids[evaluate_data_indices]

            if evaluate_neg_edge_sampler.negative_sample_strategy != 'random':
                # non-random strategies may also replace the source nodes and need the batch time span
                batch_neg_src_node_ids, batch_neg_dst_node_ids = evaluate_neg_edge_sampler.sample(size=len(batch_src_node_ids),
                                                                                                  batch_src_node_ids=batch_src_node_ids,
                                                                                                  batch_dst_node_ids=batch_dst_node_ids,
                                                                                                  current_batch_start_time=batch_node_interact_times[0],
                                                                                                  current_batch_end_time=batch_node_interact_times[-1])
            else:
                # random strategy: keep the positive source nodes, only the destinations are negative
                _, batch_neg_dst_node_ids = evaluate_neg_edge_sampler.sample(size=len(batch_src_node_ids))
                batch_neg_src_node_ids = batch_src_node_ids

            # we need to compute for positive and negative edges respectively, because the new sampling strategy (for evaluation) allows the negative source nodes to be
            # different from the source nodes, this is different from previous works that just replace destination nodes with negative destination nodes
            if model_name in ['TGAT', 'CAWN', 'TCL']:
                # get temporal embedding of source and destination nodes
                # two Tensors, with shape (batch_size, node_feat_dim)
                batch_src_node_embeddings, batch_dst_node_embeddings = \
                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
                                                                      dst_node_ids=batch_dst_node_ids,
                                                                      node_interact_times=batch_node_interact_times,
                                                                      num_neighbors=num_neighbors)

                # get temporal embedding of negative source and negative destination nodes
                # two Tensors, with shape (batch_size, node_feat_dim)
                batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \
                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,
                                                                      dst_node_ids=batch_neg_dst_node_ids,
                                                                      node_interact_times=batch_node_interact_times,
                                                                      num_neighbors=num_neighbors)
            elif model_name in ['JODIE', 'DyRep', 'TGN']:
                # note that negative nodes do not change the memories while the positive nodes change the memories,
                # we need to first compute the embeddings of negative nodes for memory-based models
                # get temporal embedding of negative source and negative destination nodes
                # two Tensors, with shape (batch_size, node_feat_dim)
                batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \
                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,
                                                                      dst_node_ids=batch_neg_dst_node_ids,
                                                                      node_interact_times=batch_node_interact_times,
                                                                      edge_ids=None,
                                                                      edges_are_positive=False,
                                                                      num_neighbors=num_neighbors)

                # get temporal embedding of source and destination nodes
                # two Tensors, with shape (batch_size, node_feat_dim)
                batch_src_node_embeddings, batch_dst_node_embeddings = \
                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
                                                                      dst_node_ids=batch_dst_node_ids,
                                                                      node_interact_times=batch_node_interact_times,
                                                                      edge_ids=batch_edge_ids,
                                                                      edges_are_positive=True,
                                                                      num_neighbors=num_neighbors)
            elif model_name in ['GraphMixer']:
                # get temporal embedding of source and destination nodes
                # two Tensors, with shape (batch_size, node_feat_dim)
                batch_src_node_embeddings, batch_dst_node_embeddings = \
                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
                                                                      dst_node_ids=batch_dst_node_ids,
                                                                      node_interact_times=batch_node_interact_times,
                                                                      num_neighbors=num_neighbors,
                                                                      time_gap=time_gap)

                # get temporal embedding of negative source and negative destination nodes
                # two Tensors, with shape (batch_size, node_feat_dim)
                batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \
                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,
                                                                      dst_node_ids=batch_neg_dst_node_ids,
                                                                      node_interact_times=batch_node_interact_times,
                                                                      num_neighbors=num_neighbors,
                                                                      time_gap=time_gap)
            elif model_name in ['DyGFormer']:
                # get temporal embedding of source and destination nodes
                # two Tensors, with shape (batch_size, node_feat_dim)
                batch_src_node_embeddings, batch_dst_node_embeddings = \
                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
                                                                      dst_node_ids=batch_dst_node_ids,
                                                                      node_interact_times=batch_node_interact_times)

                # get temporal embedding of negative source and negative destination nodes
                # two Tensors, with shape (batch_size, node_feat_dim)
                batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = \
                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_neg_src_node_ids,
                                                                      dst_node_ids=batch_neg_dst_node_ids,
                                                                      node_interact_times=batch_node_interact_times)
            else:
                raise ValueError(f"Wrong value for model_name {model_name}!")

            # get positive and negative probabilities, shape (batch_size, )
            positive_probabilities = model[1](input_1=batch_src_node_embeddings, input_2=batch_dst_node_embeddings).squeeze(dim=-1).sigmoid()
            negative_probabilities = model[1](input_1=batch_neg_src_node_embeddings, input_2=batch_neg_dst_node_embeddings).squeeze(dim=-1).sigmoid()

            # positives first (label 1), then negatives (label 0)
            predicts = torch.cat([positive_probabilities, negative_probabilities], dim=0)
            labels = torch.cat([torch.ones_like(positive_probabilities), torch.zeros_like(negative_probabilities)], dim=0)

            loss = loss_func(input=predicts, target=labels)

            evaluate_losses.append(loss.item())

            evaluate_metrics.append(get_link_prediction_metrics(predicts=predicts, labels=labels))

            evaluate_idx_data_loader_tqdm.set_description(f'evaluate for the {batch_idx + 1}-th batch, evaluate loss: {loss.item()}')

    return evaluate_losses, evaluate_metrics
def evaluate_model_node_classification(model_name: str, model: nn.Module, neighbor_sampler: NeighborSampler, evaluate_idx_data_loader: DataLoader,
                                       evaluate_data: Data, loss_func: nn.Module, num_neighbors: int = 20, time_gap: int = 2000):
    """
    evaluate models on the node classification task
    the label of every interaction in evaluate_data is predicted from the temporal embedding of its source node
    :param model_name: str, name of the model, one of 'JODIE', 'DyRep', 'TGAT', 'TGN', 'CAWN', 'TCL', 'GraphMixer', 'DyGFormer'
    :param model: nn.Module, the model to be evaluated, indexed as model[0] (backbone exposing
                  compute_src_dst_node_temporal_embeddings) and model[1] (node classifier whose output is passed through a sigmoid)
    :param neighbor_sampler: NeighborSampler, neighbor sampler, installed on model[0] for every supported model except JODIE
    :param evaluate_idx_data_loader: DataLoader, evaluate index data loader, yields batches of indices into evaluate_data
    :param evaluate_data: Data, data to be evaluated
    :param loss_func: nn.Module, loss function, called with the predicted probabilities and the float labels
    :param num_neighbors: int, number of neighbors to sample for each node (not passed to DyGFormer)
    :param time_gap: int, time gap for neighbors to compute node features (only used by GraphMixer)
    :return: tuple (evaluate_total_loss, evaluate_metrics), the mean of the per-batch losses and the metrics returned by
             get_node_classification_metrics computed over the predictions of all batches
    :raises ValueError: if model_name is not supported, or if evaluate_idx_data_loader yields no batches
    """
    if model_name in ['DyRep', 'TGAT', 'TGN', 'CAWN', 'TCL', 'GraphMixer', 'DyGFormer']:
        # evaluation phase use all the graph information
        model[0].set_neighbor_sampler(neighbor_sampler)

    model.eval()

    with torch.no_grad():
        # store evaluate losses, trues and predicts
        evaluate_total_loss, evaluate_y_trues, evaluate_y_predicts = 0.0, [], []
        # counted explicitly: relying on the loop variable breaks (unbound name) when the loader is empty
        num_batches = 0
        evaluate_idx_data_loader_tqdm = tqdm(evaluate_idx_data_loader, ncols=120)
        for batch_idx, evaluate_data_indices in enumerate(evaluate_idx_data_loader_tqdm):
            batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times, batch_edge_ids, batch_labels = \
                evaluate_data.src_node_ids[evaluate_data_indices], evaluate_data.dst_node_ids[evaluate_data_indices], \
                evaluate_data.node_interact_times[evaluate_data_indices], evaluate_data.edge_ids[evaluate_data_indices], evaluate_data.labels[evaluate_data_indices]

            if model_name in ['TGAT', 'CAWN', 'TCL']:
                # get temporal embedding of source and destination nodes
                # two Tensors, with shape (batch_size, node_feat_dim)
                batch_src_node_embeddings, batch_dst_node_embeddings = \
                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
                                                                      dst_node_ids=batch_dst_node_ids,
                                                                      node_interact_times=batch_node_interact_times,
                                                                      num_neighbors=num_neighbors)
            elif model_name in ['JODIE', 'DyRep', 'TGN']:
                # get temporal embedding of source and destination nodes
                # two Tensors, with shape (batch_size, node_feat_dim)
                batch_src_node_embeddings, batch_dst_node_embeddings = \
                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
                                                                      dst_node_ids=batch_dst_node_ids,
                                                                      node_interact_times=batch_node_interact_times,
                                                                      edge_ids=batch_edge_ids,
                                                                      edges_are_positive=True,
                                                                      num_neighbors=num_neighbors)
            elif model_name in ['GraphMixer']:
                # get temporal embedding of source and destination nodes
                # two Tensors, with shape (batch_size, node_feat_dim)
                batch_src_node_embeddings, batch_dst_node_embeddings = \
                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
                                                                      dst_node_ids=batch_dst_node_ids,
                                                                      node_interact_times=batch_node_interact_times,
                                                                      num_neighbors=num_neighbors,
                                                                      time_gap=time_gap)
            elif model_name in ['DyGFormer']:
                # get temporal embedding of source and destination nodes
                # two Tensors, with shape (batch_size, node_feat_dim)
                batch_src_node_embeddings, batch_dst_node_embeddings = \
                    model[0].compute_src_dst_node_temporal_embeddings(src_node_ids=batch_src_node_ids,
                                                                      dst_node_ids=batch_dst_node_ids,
                                                                      node_interact_times=batch_node_interact_times)
            else:
                raise ValueError(f"Wrong value for model_name {model_name}!")

            # get predicted probabilities, shape (batch_size, )
            predicts = model[1](x=batch_src_node_embeddings).squeeze(dim=-1).sigmoid()
            labels = torch.from_numpy(batch_labels).float().to(predicts.device)

            loss = loss_func(input=predicts, target=labels)

            evaluate_total_loss += loss.item()
            num_batches += 1

            evaluate_y_trues.append(labels)
            evaluate_y_predicts.append(predicts)

            evaluate_idx_data_loader_tqdm.set_description(f'evaluate for the {batch_idx + 1}-th batch, evaluate loss: {loss.item()}')

        if num_batches == 0:
            raise ValueError('evaluate_idx_data_loader yielded no batches, cannot evaluate node classification!')

        # mean of the per-batch losses
        evaluate_total_loss /= num_batches
        evaluate_y_trues = torch.cat(evaluate_y_trues, dim=0)
        evaluate_y_predicts = torch.cat(evaluate_y_predicts, dim=0)

        evaluate_metrics = get_node_classification_metrics(predicts=evaluate_y_predicts, labels=evaluate_y_trues)

    return evaluate_total_loss, evaluate_metrics
def evaluate_edge_bank_link_prediction(args: argparse.Namespace, train_data: Data, val_data: Data, test_idx_data_loader: DataLoader,
                                       test_neg_edge_sampler: NegativeEdgeSampler_local, test_data: Data):
    """
    evaluate the EdgeBank model for link prediction
    EdgeBank is evaluated args.num_runs times; every run seeds the random generators with the run index, logs to a new file in
    ./logs/{model_name}/{dataset_name}/{save_result_name}/, scores all test batches against a history made of the train,
    validation and already seen test interactions, and saves the averaged test metrics to
    ./saved_results/{model_name}/{dataset_name}/{save_result_name}.json
    the metrics over all runs are logged into the log file of the last run; all log files are closed before returning
    note that args.seed and args.save_result_name are overwritten for every run
    :param args: argparse.Namespace, configuration, reads num_runs, negative_sample_strategy, model_name, dataset_name,
                 edge_bank_memory_mode, time_window_mode and test_ratio
    :param train_data: Data, train data
    :param val_data: Data, validation data
    :param test_idx_data_loader: DataLoader, test index data loader, yields batches of indices into test_data
    :param test_neg_edge_sampler: NegativeEdgeSampler_local, test negative edge sampler, must have a fixed seed
    :param test_data: Data, test data
    :return: None, results are written to the log files and the result json files
    :raises ValueError: if test_neg_edge_sampler has no seed
    """
    # the test negatives must be the same for every run, which requires a seeded sampler; an explicit check (instead of
    # assert) keeps this guarantee when Python runs with -O, and fails before any log file is created
    if test_neg_edge_sampler.seed is None:
        raise ValueError('test_neg_edge_sampler must be created with a fixed seed for reproducible evaluation!')

    # generate the train_validation split of the data: needed for constructing the memory for EdgeBank
    train_val_data = Data(src_node_ids=np.concatenate([train_data.src_node_ids, val_data.src_node_ids]),
                          dst_node_ids=np.concatenate([train_data.dst_node_ids, val_data.dst_node_ids]),
                          node_interact_times=np.concatenate([train_data.node_interact_times, val_data.node_interact_times]),
                          edge_ids=np.concatenate([train_data.edge_ids, val_data.edge_ids]),
                          labels=np.concatenate([train_data.labels, val_data.labels]))

    # configure the root logger once (basicConfig does nothing once the root logger already has handlers)
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    test_metric_all_runs = []
    # handlers of the current run; the ones of the last run stay attached until the summary over all runs is logged
    run_log_handlers = []
    try:
        for run in range(args.num_runs):

            set_random_seed(seed=run)

            args.seed = run
            args.save_result_name = f'{args.negative_sample_strategy}_negative_sampling_{args.model_name}_seed{args.seed}'

            # set up logger: a DEBUG-level file handler for this run and a WARNING-level console handler
            run_log_handlers = _add_run_log_handlers(logger=logger,
                                                     log_folder=f"./logs/{args.model_name}/{args.dataset_name}/{args.save_result_name}/")

            run_start_time = time.time()
            logger.info(f"********** Run {run + 1} starts. **********")

            logger.info(f'configuration is {args}')

            loss_func = nn.BCELoss()

            # evaluate EdgeBank
            logger.info(f'get final performance on dataset {args.dataset_name}...')

            # Ensures the random sampler uses a fixed seed for evaluation (i.e. we always sample the same negatives for validation / test set)
            test_neg_edge_sampler.reset_random_state()

            test_losses, test_metrics = [], []
            test_idx_data_loader_tqdm = tqdm(test_idx_data_loader, ncols=120)

            for batch_idx, test_data_indices in enumerate(test_idx_data_loader_tqdm):
                batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times = \
                    test_data.src_node_ids[test_data_indices], test_data.dst_node_ids[test_data_indices], \
                    test_data.node_interact_times[test_data_indices]

                if test_neg_edge_sampler.negative_sample_strategy != 'random':
                    batch_neg_src_node_ids, batch_neg_dst_node_ids = test_neg_edge_sampler.sample(size=len(batch_src_node_ids),
                                                                                                  batch_src_node_ids=batch_src_node_ids,
                                                                                                  batch_dst_node_ids=batch_dst_node_ids,
                                                                                                  current_batch_start_time=batch_node_interact_times[0],
                                                                                                  current_batch_end_time=batch_node_interact_times[-1])
                else:
                    # random strategy: keep the positive source nodes, only the destinations are negative
                    _, batch_neg_dst_node_ids = test_neg_edge_sampler.sample(size=len(batch_src_node_ids))
                    batch_neg_src_node_ids = batch_src_node_ids

                positive_edges = (batch_src_node_ids, batch_dst_node_ids)
                negative_edges = (batch_neg_src_node_ids, batch_neg_dst_node_ids)

                # incorporate the testing data before the current batch to history_data, which is similar to memory-based models
                # (the slices end at the first index of the batch)
                history_data = Data(src_node_ids=np.concatenate([train_val_data.src_node_ids, test_data.src_node_ids[: test_data_indices[0]]]),
                                    dst_node_ids=np.concatenate([train_val_data.dst_node_ids, test_data.dst_node_ids[: test_data_indices[0]]]),
                                    node_interact_times=np.concatenate([train_val_data.node_interact_times, test_data.node_interact_times[: test_data_indices[0]]]),
                                    edge_ids=np.concatenate([train_val_data.edge_ids, test_data.edge_ids[: test_data_indices[0]]]),
                                    labels=np.concatenate([train_val_data.labels, test_data.labels[: test_data_indices[0]]]))

                # perform link prediction for EdgeBank
                positive_probabilities, negative_probabilities = edge_bank_link_prediction(history_data=history_data,
                                                                                           positive_edges=positive_edges,
                                                                                           negative_edges=negative_edges,
                                                                                           edge_bank_memory_mode=args.edge_bank_memory_mode,
                                                                                           time_window_mode=args.time_window_mode,
                                                                                           time_window_proportion=args.test_ratio)

                # positives first (label 1), then negatives (label 0)
                predicts = torch.from_numpy(np.concatenate([positive_probabilities, negative_probabilities])).float()
                labels = torch.cat([torch.ones(len(positive_probabilities)), torch.zeros(len(negative_probabilities))], dim=0)

                loss = loss_func(input=predicts, target=labels)

                test_losses.append(loss.item())

                test_metrics.append(get_link_prediction_metrics(predicts=predicts, labels=labels))

                test_idx_data_loader_tqdm.set_description(f'test for the {batch_idx + 1}-th batch, test loss: {loss.item()}')

            # store the evaluation metrics at the current run
            test_metric_dict = {}

            logger.info(f'test loss: {np.mean(test_losses):.4f}')
            for metric_name in test_metrics[0].keys():
                average_test_metric = np.mean([test_metric[metric_name] for test_metric in test_metrics])
                logger.info(f'test {metric_name}, {average_test_metric:.4f}')
                test_metric_dict[metric_name] = average_test_metric

            single_run_time = time.time() - run_start_time
            logger.info(f'Run {run + 1} cost {single_run_time:.2f} seconds.')

            test_metric_all_runs.append(test_metric_dict)

            # save model result
            result_json = {
                "test metrics": {metric_name: f'{test_metric_dict[metric_name]:.4f}' for metric_name in test_metric_dict}
            }
            result_json = json.dumps(result_json, indent=4)

            save_result_folder = f"./saved_results/{args.model_name}/{args.dataset_name}"
            os.makedirs(save_result_folder, exist_ok=True)
            save_result_path = os.path.join(save_result_folder, f"{args.save_result_name}.json")
            with open(save_result_path, 'w') as file:
                file.write(result_json)
            logger.info(f'save negative sampling results at {save_result_path}')

            # avoid the overlap of logs: detach and close the handlers of every run except the last one, whose log file
            # also receives the summary over all runs below (done after the save message so it lands in this run's log)
            if run < args.num_runs - 1:
                _remove_log_handlers(logger=logger, handlers=run_log_handlers)
                run_log_handlers = []

        # store the average metrics at the log of the last run
        logger.info(f'metrics over {args.num_runs} runs:')

        if test_metric_all_runs:
            for metric_name in test_metric_all_runs[0].keys():
                logger.info(f'test {metric_name}, {[test_metric_single_run[metric_name] for test_metric_single_run in test_metric_all_runs]}')
                logger.info(f'average test {metric_name}, {np.mean([test_metric_single_run[metric_name] for test_metric_single_run in test_metric_all_runs]):.4f} '
                            f'± {np.std([test_metric_single_run[metric_name] for test_metric_single_run in test_metric_all_runs], ddof=1):.4f}')
    finally:
        # release the log file of the last (or an aborted) run instead of leaving its handlers attached to the root logger
        _remove_log_handlers(logger=logger, handlers=run_log_handlers)


def _add_run_log_handlers(logger: logging.Logger, log_folder: str) -> list:
    """
    attach a DEBUG-level file handler (writing to a new timestamped file in log_folder) and a WARNING-level console handler to logger
    :param logger: logging.Logger, logger to attach the handlers to
    :param log_folder: str, folder of the log file, created if missing; must end with a path separator
    :return: list, the attached handlers [file handler, console handler]
    """
    os.makedirs(log_folder, exist_ok=True)
    # create file handler that logs debug and higher level messages
    fh = logging.FileHandler(f"{log_folder}{str(time.time())}.log")
    fh.setLevel(logging.DEBUG)
    # create console handler with a higher log level
    ch = logging.StreamHandler()
    ch.setLevel(logging.WARNING)
    # create formatter and add it to the handlers
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)
    # add the handlers to logger
    logger.addHandler(fh)
    logger.addHandler(ch)
    return [fh, ch]


def _remove_log_handlers(logger: logging.Logger, handlers: list) -> None:
    """
    detach handlers from logger and close them, so that log files are flushed and their file descriptors released
    :param logger: logging.Logger, logger the handlers are attached to
    :param handlers: list, handlers to detach and close (may be empty)
    :return: None
    """
    for handler in handlers:
        logger.removeHandler(handler)
        handler.close()