utils.py (forked from Wang-Shuo/GraphRec_PyTorch)
import random

import torch

# Maximum number of neighbors kept per node; longer neighbor lists are
# randomly down-sampled to this length below.
truncate_len = 30

"""
Ciao dataset info:
    Avg number of items rated per user: 38.3
    Avg number of users interacted per user: 2.7
    Avg number of users connected per item: 16.4
"""
def collate_fn(batch_data):
    """Pad each sample's neighbor lists to the max length in the batch.

    Meant to be passed as the `collate_fn` argument of a PyTorch DataLoader.
    """
    uids, iids, labels = [], [], []
    u_items, u_users, u_users_items, i_users = [], [], [], []
    u_items_len, u_users_len, i_users_len = [], [], []

    for data, u_items_u, u_users_u, u_users_items_u, i_users_i in batch_data:
        (uid, iid, label) = data
        uids.append(uid)
        iids.append(iid)
        labels.append(label)

        # user-items: (item, rating) pairs rated by this user
        if len(u_items_u) <= truncate_len:
            u_items.append(u_items_u)
        else:
            u_items.append(random.sample(u_items_u, truncate_len))
        u_items_len.append(min(len(u_items_u), truncate_len))

        # user-users and user-users-items: the user's social neighbors
        # and the (item, rating) pairs rated by each neighbor
        if len(u_users_u) <= truncate_len:
            u_users.append(u_users_u)
            u_u_items = []
            for uui in u_users_items_u:
                if len(uui) <= truncate_len:
                    u_u_items.append(uui)
                else:
                    u_u_items.append(random.sample(uui, truncate_len))
            u_users_items.append(u_u_items)
        else:
            # down-sample the neighbor list, keeping each sampled
            # neighbor's item list aligned with it
            sample_index = random.sample(range(len(u_users_u)), truncate_len)
            u_users.append([u_users_u[si] for si in sample_index])

            u_users_items_u_tr = [u_users_items_u[si] for si in sample_index]
            u_u_items = []
            for uui in u_users_items_u_tr:
                if len(uui) <= truncate_len:
                    u_u_items.append(uui)
                else:
                    u_u_items.append(random.sample(uui, truncate_len))
            u_users_items.append(u_u_items)
        u_users_len.append(min(len(u_users_u), truncate_len))

        # item-users: (user, rating) pairs of users who rated this item
        if len(i_users_i) <= truncate_len:
            i_users.append(i_users_i)
        else:
            i_users.append(random.sample(i_users_i, truncate_len))
        i_users_len.append(min(len(i_users_i), truncate_len))
    batch_size = len(batch_data)

    # padding: zero-pad every neighbor list to the batch-wise max length
    u_items_maxlen = max(u_items_len)
    u_users_maxlen = max(u_users_len)
    i_users_maxlen = max(i_users_len)

    # (item, rating) pairs rated by each user: [B, max_items, 2]
    u_item_pad = torch.zeros([batch_size, u_items_maxlen, 2], dtype=torch.long)
    for i, ui in enumerate(u_items):
        u_item_pad[i, :len(ui), :] = torch.LongTensor(ui)

    # social neighbors of each user: [B, max_users]
    u_user_pad = torch.zeros([batch_size, u_users_maxlen], dtype=torch.long)
    for i, uu in enumerate(u_users):
        u_user_pad[i, :len(uu)] = torch.LongTensor(uu)

    # (item, rating) pairs rated by each social neighbor: [B, max_users, max_items, 2]
    u_user_item_pad = torch.zeros([batch_size, u_users_maxlen, u_items_maxlen, 2], dtype=torch.long)
    for i, uu_items in enumerate(u_users_items):
        for j, ui in enumerate(uu_items):
            u_user_item_pad[i, j, :len(ui), :] = torch.LongTensor(ui)

    # (user, rating) pairs of users who rated each item: [B, max_i_users, 2]
    i_user_pad = torch.zeros([batch_size, i_users_maxlen, 2], dtype=torch.long)
    for i, iu in enumerate(i_users):
        i_user_pad[i, :len(iu), :] = torch.LongTensor(iu)

    return torch.LongTensor(uids), torch.LongTensor(iids), torch.FloatTensor(labels), \
        u_item_pad, u_user_pad, u_user_item_pad, i_user_pad
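

# --- Usage sketch (not part of the original file) ---
# A minimal example of how collate_fn might be exercised, assuming each
# sample has the form ((uid, iid, rating), u_items, u_users, u_users_items,
# i_users). The toy batch below is hypothetical and only illustrates the
# expected input structure and the padded output shapes.
if __name__ == '__main__':
    toy_batch = [
        ((1, 10, 4.0),                       # (uid, iid, rating)
         [[10, 4], [11, 5]],                 # items rated by user 1: (item, rating)
         [2],                                # social neighbors of user 1
         [[[12, 3]]],                        # items rated by each neighbor
         [[1, 4], [3, 2]]),                  # users who rated item 10: (user, rating)
        ((2, 12, 3.0),
         [[12, 3]],
         [1, 3],
         [[[10, 4], [11, 5]], [[13, 1]]],
         [[2, 3]]),
    ]
    uids, iids, labels, u_item_pad, u_user_pad, u_user_item_pad, i_user_pad = \
        collate_fn(toy_batch)
    # Expected shapes: u_item_pad [B, max_items, 2], u_user_pad [B, max_users],
    # u_user_item_pad [B, max_users, max_items, 2], i_user_pad [B, max_i_users, 2]
    print(u_item_pad.shape, u_user_pad.shape, u_user_item_pad.shape, i_user_pad.shape)
    # In training this would typically be wired into a DataLoader, e.g.:
    #   DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)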