-
Notifications
You must be signed in to change notification settings - Fork 6
/
non_linearities.py
52 lines (40 loc) · 1.39 KB
/
non_linearities.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import theano.tensor as T
def relu(x):
    """Rectified linear unit: elementwise max(x, 0) as a symbolic Theano expression."""
    zero = 0
    return T.switch(x > zero, x, zero)
class NonLinearity:
    """String identifiers for the supported activation functions.

    These names are the canonical keys accepted by get_non_linearity_fn
    (e.g. when reading an activation choice from configuration).
    """
    RELU = "rectifier"
    TANH = "tanh"
    SIGMOID = "sigmoid"
    SOFTMAX = "softmax"
def softmax(x):
    """Row-wise softmax: exponentiate x and normalize each row to sum to 1.

    NOTE(review): no max-subtraction for numerical stability, unlike
    T.nnet.softmax — kept as-is to preserve the original expression.
    """
    exps = T.exp(x)
    return exps / exps.sum(1, keepdims=True)
def get_non_linearity_fn(nonlinearity):
    """Map a NonLinearity name to the callable that applies it.

    Parameters
    ----------
    nonlinearity : str or None
        One of the NonLinearity constants, or None for no activation.

    Returns
    -------
    callable or None
        The symbolic activation function, or None when `nonlinearity` is None.

    Raises
    ------
    ValueError
        If `nonlinearity` is not a recognized NonLinearity constant.
        (Previously an unknown name silently fell through and returned
        None; failing loudly is consistent with get_non_linearity_str.)
    """
    if nonlinearity == NonLinearity.SIGMOID:
        return T.nnet.sigmoid
    elif nonlinearity == NonLinearity.RELU:
        return relu
    elif nonlinearity == NonLinearity.TANH:
        return T.tanh
    elif nonlinearity == NonLinearity.SOFTMAX:
        # Local softmax rather than T.nnet.softmax, as in the original.
        return softmax
    elif nonlinearity is None:
        return None
    else:
        raise ValueError("Unknown non-linearity")
def get_non_linearity_str(nonlinearity):
    """Inverse of get_non_linearity_fn: map an activation callable to its name.

    Parameters
    ----------
    nonlinearity : callable or None
        An activation function (T.nnet.sigmoid, relu, T.tanh,
        T.nnet.softmax, or the local softmax), or None.

    Returns
    -------
    str or None
        The matching NonLinearity constant, or None for no activation.

    Raises
    ------
    ValueError
        If the callable is not one of the known activations.
    """
    if nonlinearity == T.nnet.sigmoid:
        return NonLinearity.SIGMOID
    elif nonlinearity == relu:
        return NonLinearity.RELU
    elif nonlinearity == T.tanh:
        return NonLinearity.TANH
    elif nonlinearity == T.nnet.softmax:
        # BUGFIX: previously returned None, breaking the round-trip with
        # get_non_linearity_fn (SOFTMAX -> softmax -> None).
        return NonLinearity.SOFTMAX
    elif nonlinearity == softmax:
        return NonLinearity.SOFTMAX
    elif nonlinearity is None:
        return None
    else:
        raise ValueError("Unknown non-linearity")
class CostType:
    """String identifiers for the supported cost (loss) functions."""
    MeanSquared = "MeanSquaredCost"
    CrossEntropy = "CrossEntropy"
    # NOTE(review): lowercase 'l' in "NegativelogLikelihood" looks like a
    # typo, but the string is part of the interface (may be compared or
    # serialized by callers) — confirm before changing it.
    NegativeLogLikelihood = "NegativelogLikelihood"