-
Notifications
You must be signed in to change notification settings - Fork 0
/
layers.py
37 lines (30 loc) · 1.03 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
class Linear(Module):
    """Simple linear layer followed by dropout and an activation.

    The underlying ``nn.Linear`` is registered into ``args.eucl_vars`` so the
    caller can treat its parameters as Euclidean variables (e.g. give them a
    separate optimizer group).

    Args:
        args: namespace holding a mutable ``eucl_vars`` list; this layer
            appends its ``nn.Linear`` submodule to it as a side effect.
        in_features: size of each input sample.
        out_features: size of each output sample.
        dropout: dropout probability applied to the linear output
            (active only in training mode).
        act: callable activation applied last, e.g. ``torch.relu``.
        use_bias: whether the linear transform has a bias term.
    """
    def __init__(self, args, in_features, out_features, dropout, act, use_bias):
        super(Linear, self).__init__()
        self.dropout = dropout
        self.linear = nn.Linear(in_features, out_features, use_bias)
        self.act = act
        args.eucl_vars.append(self.linear)

    def forward(self, x):
        """Return ``act(dropout(Wx + b))`` for input tensor ``x``."""
        # Call the module, not .forward(), so registered hooks still run.
        hidden = self.linear(x)
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        out = self.act(hidden)
        return out
class FermiDiracDecoder(Module):
    """Map distances to edge probabilities via the Fermi-Dirac distribution.

    ``p = 1 / (exp((dist - r) / t) + 1)`` — probability is 0.5 at
    ``dist == r`` and decays toward 0 as ``dist`` grows.

    Args:
        r: radius offset (distance at which probability equals 0.5).
        t: temperature controlling how sharply probability falls off.
    """
    def __init__(self, r, t):
        super(FermiDiracDecoder, self).__init__()
        self.r = r
        self.t = t

    def forward(self, dist):
        """Return edge probabilities for a tensor of distances."""
        logits = (dist - self.r) / self.t
        return 1. / (torch.exp(logits) + 1.0)