DHNE.py
import torch
import torch.nn as nn

from openhgnn.models import BaseModel, register_model


@register_model("DHNE")
class DHNE(BaseModel):
r"""
**Title:** Structural Deep Embedding for Hyper-Networks
**Authors:** Ke Tu, Peng Cui, Xiao Wang, Fei Wang, Wenwu Zhu
DHNE was introduced in `[paper] <https://arxiv.org/abs/1711.10146>`_
and parameters are defined as follows:
Parameters
----------
nums_type : list
the type of nodes
dim_features : array
the embedding dimension of nodes
embedding_sizes : int
the embedding dimension size
hidden_size : int
The hidden full connected layer size
device : int
the device DHNE working on
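
    Examples
    --------
    A minimal sketch with illustrative sizes (none of the values below are
    OpenHGNN defaults). Each input block concatenates a node's adjacency to
    the other two node types, so with ``nums_type=[3, 4, 5]`` the blocks have
    widths 9, 8 and 7:

    >>> import torch
    >>> model = DHNE(nums_type=[3, 4, 5], dim_features=[3, 4, 5],
    ...              embedding_sizes=[16, 16, 16], hidden_size=32, device="cpu")
    >>> inputs = [torch.rand(8, 9), torch.rand(8, 8), torch.rand(8, 7)]
    >>> outs = model(inputs)  # three reconstructions + one logits tensor
    >>> outs[-1].shape
    torch.Size([8, 1])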
"""
@classmethod
def build_model_from_args(cls, args):
return cls(dim_features=args.dim_features,
embedding_sizes=args.embedding_sizes,
hidden_size=args.hidden_size,
nums_type=args.nums_type,
                   device=args.device)

    def __init__(self, nums_type, dim_features, embedding_sizes, hidden_size, device):
        super().__init__()
        self.dim_features = dim_features
        self.embedding_sizes = embedding_sizes
        self.hidden_size = hidden_size
        self.nums_type = nums_type
        self.device = device
        # Auto-encoder: one encoder/decoder pair per node type. The encoder for
        # type i takes the concatenated adjacency to the other two node types as
        # input. nn.ModuleList (rather than a plain Python list) registers the
        # layers as submodules, so their parameters are trained and they follow
        # model.to(device).
        self.encodeds = nn.ModuleList([
            nn.Linear(sum(self.nums_type[j] for j in range(3) if j != i),
                      self.embedding_sizes[i]) for i in range(3)])
        self.decodeds = nn.ModuleList([
            nn.Linear(self.embedding_sizes[i],
                      sum(self.nums_type[j] for j in range(3) if j != i)) for i in range(3)])
        # Fully connected layers that score a (node, node, node) tuple.
        self.hidden_layer = nn.Linear(
            sum(self.embedding_sizes), self.hidden_size)
        self.output_layer = nn.Linear(self.hidden_size, 1)

    def forward(self, input_ids):
        """
        The forward pass of DHNE.

        Parameters
        ----------
        input_ids : list[torch.Tensor]
            the three input blocks of this batch, one per node type

        Returns
        -------
        list[torch.Tensor]
            the three autoencoder reconstructions followed by the tuplewise logits
        """
        encodeds = []
        decodeds = []
        for i in range(3):
            # Encode each node type from its input block, then reconstruct the
            # block with the matching decoder. The layers live in ModuleLists,
            # so they are already on self.device once the model has been moved.
            encoded = torch.tanh(self.encodeds[i](input_ids[i].to(self.device)))
            encodeds.append(encoded)
            decodeds.append(torch.sigmoid(self.decodeds[i](encoded)))
        # Concatenate the three embeddings and score the tuple.
        merged = torch.cat(encodeds, dim=1)
        hidden = torch.tanh(self.hidden_layer(merged))
        output = self.output_layer(hidden)
        return decodeds + [output]

    def embedding_lookup(self, index, sparse_input=False):
        # ``self.embeddings`` is not created in this module; it is expected to
        # be attached externally (e.g. by the trainer) before lookup is used.
        return [self.embeddings[i][index[:, i], :] for i in range(3)]
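

# A minimal sketch of one training step, assuming binary tuple labels and the
# two-part loss described in the DHNE paper (autoencoder reconstruction plus
# tuplewise classification). The sizes, the plain MSE reconstruction term and
# the weight ``alpha`` below are illustrative, not values from OpenHGNN.
if __name__ == "__main__":
    model = DHNE(nums_type=[3, 4, 5], dim_features=[3, 4, 5],
                 embedding_sizes=[16, 16, 16], hidden_size=32, device="cpu")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    inputs = [torch.rand(8, 9), torch.rand(8, 8), torch.rand(8, 7)]
    labels = torch.randint(0, 2, (8, 1)).float()  # 1 = observed hyperedge

    *reconstructions, logits = model(inputs)
    # Reconstruction loss for each autoencoder plus classification loss
    # for the tuplewise logits.
    recon_loss = sum(nn.functional.mse_loss(rec, inp)
                     for rec, inp in zip(reconstructions, inputs))
    cls_loss = nn.functional.binary_cross_entropy_with_logits(logits, labels)
    alpha = 1.0
    loss = recon_loss + alpha * cls_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()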