dbnn-linear.py
#!/usr/bin/env python
# coding: utf-8

# In[ ]:

import math

import torch
import torch.nn as nn
from torch import Tensor

# z_snd is assumed to be the standard-normal z-value lookup table exported by z_values.
from z_values import z_snd
class Linear(nn.Module):
    """
    Takes the number of input features and the number of output features and
    propagates the weights to the next layer. The weights are stored as the
    hyperparameters of a distribution (mu and sigma); the full weight matrix is
    not generated explicitly but is calculated on the fly in the forward function.
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor
    zval: Tensor
    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # One (mu, sigma) pair per input feature: column 0 holds mu, column 1 holds sigma.
        self.weight = nn.Parameter(torch.empty((in_features, 2), **factory_kwargs))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
    def reset_parameters(self) -> None:
        self.weight[:, 0].data.uniform_(-1, 1)
        self.weight[:, 1].data.uniform_(0, 1)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)
    def forward(self, data_inputs: Tensor) -> Tensor:
        r1 = data_inputs.T
        mu1 = self.weight[:, 0].unsqueeze(0)
        # sigma is a scale parameter, so it is used as a non-negative value.
        sigma1 = torch.absolute(self.weight[:, 1]).unsqueeze(0)
        l2 = self.out_features
        if torch.cuda.is_available():
            zval = torch.zeros(l2, device=torch.device("cuda"))
        else:
            zval = torch.zeros(l2)
        # One standard-normal z-value per output unit, taken from the z_snd table at
        # cumulative probability i / out_features; r and c are the table row and column
        # given by the first and remaining decimal digits of that probability.
        for i in range(1, l2 + 1):
            b = round(i / l2, 2)
            b1 = math.trunc(b * 10) / 10
            r = int(b1 * 10)
            c = int((b - b1) * 100)
            zval[i - 1] = z_snd[r][c]
        zval1 = zval.unsqueeze(0)
        # Build the effective weights from (mu, sigma, z) and apply them to the inputs.
        r1 = torch.matmul((torch.matmul(sigma1.T, zval1) + mu1.T).T, r1).T
        return r1
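    # Shape walkthrough of the forward pass: with x of shape (batch, in_features),
    #   sigma1.T @ zval1 + mu1.T  ->  (in_features, out_features),
    # so the effective weight seen by output unit j and input feature k is
    #   W[k, j] = mu_k + sigma_k * z_j,
    # where z_j is the standard-normal z-value at cumulative probability j / out_features.
    # The final matmul computes (W.T @ x.T).T, giving a (batch, out_features) output.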
    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
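# Minimal usage sketch, assuming the z_values module that provides z_snd is importable;
# the layer sizes and batch size below are illustrative only.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    layer = Linear(in_features=4, out_features=3).to(device)
    x = torch.randn(8, 4, device=device)  # batch of 8 samples with 4 features each
    y = layer(x)  # effective weights are built on the fly from (mu, sigma) and z_snd
    print(layer)  # Linear(in_features=4, out_features=3, bias=False)
    print(y.shape)  # torch.Size([8, 3])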