import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import numpy as np
from . import BaseModel, register_model


@register_model("Mg2vec")
class Mg2vec(BaseModel):
r"""
This is a model mg2vec from `mg2vec: Learning Relationship-Preserving Heterogeneous Graph Representations via
Metagraph Embedding<https://ieeexplore.ieee.org/document/9089251>`__
It contains following parts:
Achieve the metagraph and metagraph instances by mining the raw graph. Please go to
`DataMaker-For-Mg2vec<https://github.com/null-xyj/DataMaker-For-Mg2vec>`__ for more details.
Initialize the embedding for every node and metagraph and adopt an unsupervised method to train the node embeddings
and metagraph embeddings. In detail, for every node, we keep its embedding close to the metagraph it belongs to and
far away from the metagraph we get by negative sampling.
Every node and meta-graph can be represented as an n-dim vector.We define the first-order loss and second-order
loss.
First-Order Loss is for single core node in every meta-graph. We compute the dot product of the node embedding and
the positive meta-graph embedding as the true logit. Then We compute the dot product of the node embedding and
the sampled negative meta-graph embedding as the neg logit. We use the binary_cross_entropy_with_logits function to
compute the first-order loss.
Second-Order Loss consider two core nodes in every meta-graph. First, we cancat the two node's embedding, what is a
2n-dim vector. Then we use a 2n*n matrix and an n-dim vector to map the 2n-dim vector to an n-dim vector. The map
function is showed below:
.. math::
f(u,v) = RELU([u||v]W + b)
u and v means the origin embedding of the two nodes, || is the concatenation operator. W is the 2n*n matrix and b is
the n-dim vector. RELU is the an activation function. f(u,v) means the n-dim vector after transforming.
Then, the computation of second-order loss is the same as the first-order loss.
Finally, we use a parameter alpha to balance the first-order loss and second-order loss.
.. math::
L=(1-alpha)*L_1 + alpha*L_2
After we train the node embeddings, we use the embeddings to complete the relation prediction task.
The relation prediction task is achieved by edge classification task. If two nodes are connected with a relation, we
see the relation as an edge. Then we can adopt the edge classification to complete relation prediction task.
Parameters
----------
node_num: int
the number of core-nodes
mg_num: int
the number of meta-graphs
emb_dimension: int
the embedding dimension of nodes and meta-graphs
unigram: float
the frequency of every meta-graph, for negative sampling
sample_num: int
the number of sampled negative meta-graph
"""

    @classmethod
    def build_model_from_args(cls, args, hg):
        return cls(args.node_num, args.mg_num, args.emb_dimension, args.unigram, args.sample_num)

    def __init__(self, node_num, mg_num, emb_dimension, unigram, sample_num):
        super(Mg2vec, self).__init__()
        self.node_num = node_num
        self.mg_num = mg_num
        # unnormalized metagraph frequencies used as the negative-sampling distribution
        self.mg_unigrams = torch.tensor(unigram, dtype=torch.float64)
        self.sample_num = sample_num
        self.emb_dimension = emb_dimension
        # embedding tables for core nodes and metagraphs
        self.n_embedding = nn.Embedding(node_num, emb_dimension, sparse=False)
        self.m_embedding = nn.Embedding(mg_num, emb_dimension, sparse=False)
        # parameters of the second-order map f(u, v) = ReLU([u || v] W + b)
        self.n_w_t = nn.Parameter(torch.empty([emb_dimension * 2, emb_dimension]), requires_grad=True)
        self.n_b = nn.Parameter(torch.empty(emb_dimension), requires_grad=True)
        init.xavier_normal_(self.n_embedding.weight.data)
        init.xavier_normal_(self.m_embedding.weight.data)
        init.xavier_normal_(self.n_w_t)
        init.constant_(self.n_b, 0)

    def forward(self, train_a, train_b, train_labels, train_freq, train_weight, device):
        batch_size = len(train_a)
        n_embed_a = self.n_embedding(train_a)
        n_embed_b = self.n_embedding(train_b)
        n_embed_con = torch.cat([n_embed_a, n_embed_b], dim=1)
        # first-order samples have a single core node (a == b); second-order samples have two distinct core nodes
        mask_o1 = torch.eq(train_a, train_b).type(torch.FloatTensor).reshape(batch_size, 1).to(device)
        mask_o2 = torch.not_equal(train_a, train_b).type(torch.FloatTensor).reshape(batch_size, 1).to(device)
        m_embed_pos = self.m_embedding(train_labels)
        # draw negative metagraphs from the unigram distribution
        neg_sample_id = torch.multinomial(self.mg_unigrams, min(self.mg_num, self.sample_num)).to(device)
        neg_m_embed = self.m_embedding(neg_sample_id)
        n_embed_o1 = n_embed_a * mask_o1
        n_embed_o2 = F.relu(torch.mm(n_embed_con, self.n_w_t) + self.n_b) * mask_o2
        n_embed = torch.add(n_embed_o1, n_embed_o2)
        # true logit against the positive metagraph, negative logits against the sampled metagraphs
        true_logit = torch.sum((n_embed * m_embed_pos), dim=1, keepdim=True)
        neg_logit = torch.mm(n_embed, neg_m_embed.T)
        logit = torch.cat([true_logit, neg_logit], dim=1)
        labels = torch.cat([torch.ones_like(true_logit), torch.zeros_like(neg_logit)], dim=1)
        xent = torch.sum(F.binary_cross_entropy_with_logits(logit, labels, reduction='none'), dim=1, keepdim=True)
        unsupervised_loss = torch.mean(train_weight * (train_freq * xent))
        return unsupervised_loss

    def normalize_embedding(self):
        norm = torch.sqrt_(torch.sum(torch.square(self.n_embedding.weight.data), dim=1, keepdim=True))
        self.n_embedding.weight.data = self.n_embedding.weight.data / norm
        m_norm = torch.sqrt_(torch.sum(torch.square(self.m_embedding.weight.data), dim=1, keepdim=True))
        self.m_embedding.weight.data = self.m_embedding.weight.data / m_norm

    def save_embedding(self, id2node, file_name):
        self.normalize_embedding()
        embedding = self.n_embedding.weight.cpu().data.numpy()
        with open(file_name, 'w') as f:
            for nId, node in id2node.items():
                to_write = str(node) + ' ' + ' '.join(map(lambda x: str(x), embedding[nId])) + '\n'
                f.write(to_write)

    def save_embedding_np(self, file_name):
        self.normalize_embedding()
        embedding = self.n_embedding.weight.cpu().data.numpy()
        np.save(file_name, embedding)
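

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): it shows
# how the model documented above could be instantiated and how one
# unsupervised step of ``forward`` might be driven. All sizes, tensor values
# and the uniform ``unigram`` list below are assumptions; the sketch also
# presumes the module is imported from within its package, since it relies on
# the relative import of ``BaseModel`` and ``register_model``.
#
#     import torch
#
#     node_num, mg_num, dim, sample_num = 1000, 50, 64, 5
#     unigram = [1.0] * mg_num                    # assumed metagraph frequencies
#     model = Mg2vec(node_num, mg_num, dim, unigram, sample_num)
#
#     batch = 8
#     train_a = torch.randint(0, node_num, (batch,))      # first core node id
#     train_b = torch.randint(0, node_num, (batch,))      # second core node id (== train_a for first-order samples)
#     train_labels = torch.randint(0, mg_num, (batch,))   # positive metagraph ids
#     train_freq = torch.ones(batch, 1)
#     train_weight = torch.ones(batch, 1)
#
#     loss = model(train_a, train_b, train_labels, train_freq, train_weight, torch.device("cpu"))
#     loss.backward()
# ---------------------------------------------------------------------------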