-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathregularization_functions.py
executable file
·60 lines (41 loc) · 1.29 KB
/
regularization_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# -*-coding: utf-8 -*-
"""
Created on April 8, 2016
"""
from six import add_metaclass
from zope.interface import implementer, Interface
from mapped_object_registry import MappedObjectsRegistry
class IRegularization(Interface):
    """Contract for a regularization term R(w, b).

    An implementation reports the partial derivatives of R with respect
    to the weights (w) and to the bias (b).
    """

    # zope.interface convention: interface method declarations omit `self`.

    def get_weights_reg_grad():
        """Return dR/dw, the derivative of the regularization R(w, b) w.r.t. the weights."""

    def get_bias_reg_grad():
        """Return dR/db, the derivative of the regularization R(w, b) w.r.t. the bias."""
class RegularizationRegistry(MappedObjectsRegistry):
    """Registry metaclass for regularization classes.

    Applied to BaseRegularization through ``six.add_metaclass``. The
    registration/lookup logic itself lives in MappedObjectsRegistry
    (not shown here); this subclass only names the mapping.
    """

    # Name of this registry's mapping, consumed by MappedObjectsRegistry.
    mapping = "regularization"
@add_metaclass(RegularizationRegistry)
class BaseRegularization(object):
    """Shared base for regularization implementations.

    Holds the regularization coefficient. Extra keyword arguments are
    accepted and deliberately ignored here, so subclasses can pick out
    the ones they need (L2Regularization reads ``weights``).
    """

    def __init__(self, lambda_coeff, **extra):
        """
        :param lambda_coeff: regularization coefficient (lambda).
        :param extra: any further keyword arguments; unused by the base class.
        """
        self.lambda_coeff = lambda_coeff
@implementer(IRegularization)
class L1Regularization(BaseRegularization):
    """L1 regularization, R(w, b) = lambda_coeff * sum(|w|).

    The weight gradient is ``lambda_coeff * sign(w)``; 0 is used as the
    subgradient where w == 0. The bias is not penalized.
    """

    MAPPING = "l1"

    def __init__(self, lambda_coeff, **kwargs):
        """
        :param lambda_coeff: regularization coefficient (lambda).
        :param kwargs: may contain ``weights``, the weights the penalty
            applies to. All keyword arguments are also forwarded to the
            base class, which ignores the ones it does not use.
        """
        super(L1Regularization, self).__init__(lambda_coeff, **kwargs)
        # Optional for backward compatibility: callers that never passed
        # weights keep the legacy constant result of get_weights_reg_grad.
        # NOTE(review): a reference (not a copy) is kept, as in
        # L2Regularization; if the caller rebinds its weights instead of
        # updating them in place, this goes stale.
        self.weights = kwargs.get("weights")

    def get_weights_reg_grad(self):
        """Return dR/dw = lambda_coeff * sign(w).

        Without weights (none supplied at construction) this falls back to
        the legacy value ``lambda_coeff``, which matches the true gradient
        only when every weight is strictly positive.
        """
        if self.weights is None:
            return self.lambda_coeff
        # Local import: numpy is only required when weights are available.
        import numpy
        # d|w|/dw = sign(w); a constant lambda would push negative weights
        # further from zero instead of shrinking them.
        return self.lambda_coeff * numpy.sign(self.weights)

    @staticmethod
    def get_bias_reg_grad():
        """Return dR/db, which is 0 because the bias is not regularized."""
        return 0
@implementer(IRegularization)
class L2Regularization(BaseRegularization):
    """L2 regularization, R(w, b) = lambda_coeff * sum(w ** 2).

    The weight gradient is 2 * lambda_coeff * w; the bias is left
    unpenalized.
    """

    MAPPING = "l2"

    def __init__(self, **kwargs):
        """
        :param kwargs: must provide ``lambda_coeff`` (forwarded to the base
            class) and ``weights`` (the weights whose gradient is reported).
        :raises TypeError: if ``lambda_coeff`` is missing.
        :raises KeyError: if ``weights`` is missing.
        """
        super(L2Regularization, self).__init__(**kwargs)
        # The object itself is stored (no copy), so in-place mutations the
        # caller makes to it are seen by get_weights_reg_grad.
        self.weights = kwargs["weights"]

    def get_weights_reg_grad(self):
        """Return dR/dw = 2 * lambda_coeff * w."""
        factor = 2 * self.lambda_coeff
        return factor * self.weights

    @staticmethod
    def get_bias_reg_grad():
        """Return dR/db, always 0 since the bias is not penalized."""
        return 0