-
Notifications
You must be signed in to change notification settings - Fork 147
/
base_model.py
63 lines (51 loc) · 1.72 KB
/
base_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from abc import ABCMeta
import torch.nn as nn
class BaseModel(nn.Module, metaclass=ABCMeta):
    r"""
    Abstract base class for heterogeneous-graph models.

    Subclasses act as encoders: they map original node features to new
    embeddings. Every subclass must override :meth:`build_model_from_args`
    and :meth:`forward`; :meth:`extra_loss` and :meth:`get_emb` are optional
    hooks that raise ``NotImplementedError`` unless overridden.
    """

    @classmethod
    def build_model_from_args(cls, args, hg):
        r"""
        Build the model instance from args and hg.
        So every subclass inheriting it should override the method.

        Parameters
        ----------
        args :
            The parsed arguments / config used to construct the model.
        hg :
            The heterogeneous graph the model will run on.
        """
        raise NotImplementedError("Models must implement the build_model_from_args method")

    def __init__(self):
        # Python-3 zero-argument form; equivalent to super(BaseModel, self).__init__()
        super().__init__()

    def forward(self, *args):
        r"""
        The model plays a role of encoder. So the forward will encode original
        features into new features.

        Parameters
        -----------
        hg : dgl.DGLHeteroGraph
            the heterogeneous graph
        h_dict : dict[str, th.Tensor]
            the dict of heterogeneous feature

        Return
        -------
        out_dic : dict[str, th.Tensor]
            A dict of encoded feature. In general, it should output all nodes embedding.
            It is allowed that just output the embedding of target nodes which are participated in loss calculation.
        """
        raise NotImplementedError

    def extra_loss(self):
        r"""
        Some model want to use L2Norm which is not applied all parameters.

        Returns
        -------
        th.Tensor
        """
        raise NotImplementedError

    def h2dict(self, h, hdict):
        r"""
        Split a stacked feature tensor back into a per-node-type dict.

        ``h`` is assumed to be the row-wise concatenation of per-type feature
        tensors, in the same (insertion) order as ``hdict``. Each output slice
        has as many rows as the corresponding value in ``hdict``.

        Parameters
        ----------
        h : th.Tensor
            The stacked feature tensor, shape ``(sum of per-type rows, ...)``.
        hdict : dict[str, th.Tensor]
            Template dict whose values' first dimension gives each type's row count.

        Returns
        -------
        dict[str, th.Tensor]
            A dict mapping each key of ``hdict`` to its slice of ``h``.
        """
        out_dict = {}
        pre = 0
        for ntype, value in hdict.items():
            num_rows = value.shape[0]  # rows belonging to this node type
            out_dict[ntype] = h[pre:pre + num_rows]
            pre += num_rows
        return out_dict

    def get_emb(self):
        r"""
        Return the embedding of a model for further analysis.

        Returns
        -------
        numpy.array
        """
        raise NotImplementedError