###################################
### SAGA ###
###################################
# Reference:
# A. Defazio, F. Bach, and S. Lacoste-Julien
# "Saga: A fast incremental gradient method with support for non-strongly convex composite objectives."
# Advances in Neural Information Processing Systems. 2014.
# Authors: Aurelien Lucchi and Jonas Kohler, 2017
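#
# The parameter update implemented below (see the "Parameter update" block)
# is the standard SAGA step: for a uniformly sampled index j,
#     w <- w - eta * ( grad_j(w) - mem_gradients[j] + avg_mg )
# where mem_gradients[j] is the last gradient stored for datapoint j and
# avg_mg is the running average of all stored gradients, maintained
# incrementally in O(d) per step.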
from datetime import datetime
import numpy as np
def SAGA(w, loss, gradient, X=None, Y=None, opt=None, **kwargs):
    """Run SAGA on the dataset (X, Y) starting from the iterate w.

    loss(w, X, Y, **kwargs) must return a scalar and
    gradient(w, X, Y, **kwargs) a gradient of the same shape as w.
    Options are read from the dict opt ('n_epochs_saga', 'learning_rate_saga').
    Returns the final iterate and the recorded timings, losses and sample counts.
    """
    print('--- SAGA ---')

    n = X.shape[0]
    d = X.shape[1]

    n_epochs = opt.get('n_epochs_saga', 100)
    eta = opt.get('learning_rate_saga', 1e-1)
    batch_size = 1

    loss_collector = []
    timings_collector = []
    samples_collector = []

    _loss = loss(w, X, Y, **kwargs)
    loss_collector.append(_loss)
    timings_collector.append(0)
    samples_collector.append(0)

    start = datetime.now()
    timing = 0

    # Store past gradients in a table (one entry per datapoint)
    mem_gradients = {}
    nGradients = 0  # no gradient stored in mem_gradients at initialization
    avg_mg = np.zeros(d)

    # Fill in the table with the gradient of every datapoint at the initial w
    a = 1.0 / n
    bool_idx = np.zeros(n, dtype=bool)
    for i in range(n):
        bool_idx[i] = True
        _X = np.compress(bool_idx, X, axis=0)
        _Y = np.compress(bool_idx, Y, axis=0)
        grad = gradient(w, _X, _Y, **kwargs)
        bool_idx[i] = False
        mem_gradients[i] = grad
        # avg_mg = avg_mg + (grad*a)
        avg_mg = avg_mg + grad
    avg_mg = avg_mg / n
    nGradients = n

    n_samples_per_step = 1
    n_steps = int((n_epochs * n) / n_samples_per_step)
    n_samples_seen = 0  # number of samples processed so far
    k = 0

    for i in range(n_steps):
        # I: subsampling
        # int_idx = np.random.permutation(n)[0:batch_size]
        int_idx = np.random.randint(0, high=n, size=1)
        bool_idx = np.zeros(n, dtype=bool)
        bool_idx[int_idx] = True
        idx = int_idx[0]
        _X = np.compress(bool_idx, X, axis=0)
        _Y = np.compress(bool_idx, Y, axis=0)

        # II: compute step
        grad = gradient(w, _X, _Y, **kwargs)

        # Log the full loss once per epoch
        n_samples_seen += batch_size
        if n_samples_seen >= n * k:
            _timing = timing
            timing = (datetime.now() - start).total_seconds()
            _loss = loss(w, X, Y, **kwargs)
            print('Epoch ' + str(k) + ': loss = ' + str(_loss)
                  + ' norm_grad = ' + str(np.linalg.norm(grad)),
                  'time=', round(timing - _timing, 3))
            timings_collector.append(timing)
            samples_collector.append((i + 1) * batch_size)
            loss_collector.append(_loss)
            k += 1

        # Parameter update
        if idx in mem_gradients:
            w = w - eta * (grad - mem_gradients[idx] + avg_mg)  # SAGA step
        else:
            w = w - eta * grad  # SGD step

        # Update the running average of the stored gradients
        if idx in mem_gradients:
            delta_grad = grad - mem_gradients[idx]
            a = 1.0 / nGradients
            avg_mg = avg_mg + (delta_grad * a)
        else:
            # Gradient for datapoint idx does not exist yet
            nGradients = nGradients + 1  # increment number of gradients
            a = 1.0 / nGradients
            b = 1.0 - a
            avg_mg = (avg_mg * b) + (grad * a)

        # Sanity check
        # a = 1.0/n
        # avg_mg_2 = np.zeros(d)
        # for j in range(n):
        #     avg_mg_2 = avg_mg_2 + (mem_gradients[j]*a)
        # print('diff = ', np.linalg.norm(avg_mg_2-avg_mg), np.linalg.norm(avg_mg))

        # Update memorized gradients
        mem_gradients[idx] = grad

    return w, timings_collector, loss_collector, samples_collector
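

# ----------------------------------------------------------------------
# Illustrative usage sketch (not part of the original saga.py).
# The quadratic loss and gradient below are assumptions chosen only to
# demonstrate the expected call signature; any pair of callables
# loss(w, X, Y) / gradient(w, X, Y) with the same interface will work.
# ----------------------------------------------------------------------
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    n, d = 200, 10
    X_demo = rng.randn(n, d)
    w_true = rng.randn(d)
    Y_demo = X_demo.dot(w_true) + 0.01 * rng.randn(n)

    def squared_loss(w, X, Y):
        # 0.5 * mean squared residual of the linear model X w ~ Y
        residual = X.dot(w) - Y
        return 0.5 * np.mean(residual ** 2)

    def squared_loss_gradient(w, X, Y):
        # Gradient of squared_loss with respect to w
        residual = X.dot(w) - Y
        return X.T.dot(residual) / X.shape[0]

    opt_demo = {'n_epochs_saga': 10, 'learning_rate_saga': 1e-2}
    w_final, timings, losses, samples = SAGA(np.zeros(d), squared_loss,
                                             squared_loss_gradient,
                                             X=X_demo, Y=Y_demo, opt=opt_demo)
    print('final loss:', losses[-1])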