-
Notifications
You must be signed in to change notification settings - Fork 1
/
overparametrization.py
67 lines (57 loc) · 2.31 KB
/
overparametrization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
from numpy import sqrt
from numpy.linalg import inv
from numpy import log
from numpy.linalg import norm
import pandas as pd
import matplotlib.pyplot as plt
import os
import tensorflow as tf
import keras
# ---------------------------------------------------------------------------
# Experiment parameters
# ---------------------------------------------------------------------------
n = 300  # number of training samples (the test set uses the same size)
# NOTE(review): d, N, psi_1 and psi_2 are not read by the simulation below,
# which derives its own projection width from zeta. They look like
# random-features aspect ratios (psi_1 = N / d, psi_2 = n / d) — TODO confirm
# whether they are still needed; kept so these module-level names still exist.
d = 100
N = 200
psi_1 = N / d
psi_2 = n / d

# NOTE(review): the names below were read by the original script but never
# defined in it, so it could not run. The values are PLACEHOLDERS chosen only
# so the experiment executes — confirm each one against the intended setup.
p = 150  # feature dimension of X (placeholder)
gamma = p / n  # assumed to be the aspect ratio p / n — TODO confirm
lbd = 1e-3  # ridge penalty lambda (placeholder)
alpha = None  # only appears in the plot title; meaning unknown — TODO set
sigma = 0  # noise level; the original hard-coded sigma = 0 inside the loop
rep = 10  # Monte Carlo repetitions per zeta (placeholder)
zeta_seq = np.linspace(0.1, 2.0, 20)  # ratios zeta ~ width / n (placeholder)
num_steps = len(zeta_seq)


def simulate_test_errors(n, p, zeta_seq, rep, lbd, gamma, sigma=0, *, verbose=True):
    """Monte Carlo test errors of ridge, a Gaussian linear sketch and ReLU features.

    For every ``zeta`` in *zeta_seq* the projection width is ``int(n * zeta)``.
    For each of *rep* repetitions the function:

    * draws a Gaussian design ``X`` (n x p), coefficients ``beta ~ N(0, I / p)``
      and responses ``Y = X @ beta + sigma * noise``;
    * fits ridge regression in its dual (n x n kernel) form with penalty *lbd*;
    * draws a Gaussian projection ``R`` (p x width) and fits ridge, in dual
      form, on the linear features ``X @ R`` and on the ReLU features
      ``max(X @ R, 0)``, both with penalty ``lbd * zeta / gamma``;
    * records each fit's mean squared prediction error on *n* fresh test
      samples drawn from the same model (same ``beta`` and ``R``).

    Random numbers come from NumPy's global RNG, drawn in the same order as
    the original script, so ``np.random.seed`` makes a run reproducible.

    Args:
        n: Number of training samples (and of test samples).
        p: Feature dimension.
        zeta_seq: Sequence of projection ratios; each ``int(n * zeta)`` must
            be at least 1.
        rep: Repetitions per ratio.
        lbd: Ridge penalty for the unsketched estimator.
        gamma: Divisor in the sketched penalty ``lbd * zeta / gamma``.
        sigma: Standard deviation of the additive Gaussian noise.
        verbose: Print the step index every 10 ratios, as the original did.

    Returns:
        ``(dual, relu, ridge)``: arrays of shape ``(rep, len(zeta_seq))``
        with the test errors of the linear-sketch, ReLU-feature and plain
        ridge estimators.

    Raises:
        ValueError: if some ``zeta`` yields a projection width below 1. The
            sketched features would then be empty. For ``zeta == 0`` the
            sketched kernel is also the zero matrix, which cannot be inverted.
    """
    zeta_seq = np.asarray(zeta_seq, dtype=float)
    widths = [int(n * zeta) for zeta in zeta_seq]
    # Validate all ratios before spending any compute.
    bad = [float(z) for z, w in zip(zeta_seq, widths) if w < 1]
    if bad:
        raise ValueError(
            f"zeta values {bad} give a projection width int(n * zeta) < 1 for n={n}"
        )

    steps = len(zeta_seq)
    dual = np.zeros((rep, steps))
    relu = np.zeros((rep, steps))
    ridge = np.zeros((rep, steps))
    eye_n = np.identity(n)

    for i, (zeta, width) in enumerate(zip(zeta_seq, widths)):
        if verbose and i % 10 == 0:
            print(i)
        # Penalty shared by both sketched estimators at this ratio.
        lbd_sketch = lbd * zeta / gamma
        for k in range(rep):
            # Training data. The noise is drawn even when sigma == 0 so the
            # RNG stream matches the original script draw for draw.
            X = np.random.randn(n, p)
            beta = np.random.randn(p, 1) / sqrt(p)
            Y = X @ beta + np.random.randn(n, 1) * sigma

            # Ridge in dual form: X^T (X X^T / n + lbd I)^{-1} Y / n.
            # solve() is used instead of an explicit inverse for accuracy.
            beta_ridge = X.T @ np.linalg.solve(X @ X.T / n + lbd * eye_n, Y) / n

            # Gaussian projection, then the same dual-form ridge on the
            # projected features and on their ReLU.
            R = np.random.randn(p, width)
            XR = X @ R
            XR_relu = np.maximum(XR, 0)
            beta_dual = XR.T @ np.linalg.solve(XR @ XR.T / n + lbd_sketch * eye_n, Y) / n
            beta_relu = (
                XR_relu.T
                @ np.linalg.solve(XR_relu @ XR_relu.T / n + lbd_sketch * eye_n, Y)
                / n
            )

            # Fresh test set from the same model (same beta and R).
            X_test = np.random.randn(n, p)
            Y_test = X_test @ beta + np.random.randn(n, 1) * sigma
            XR_test = X_test @ R

            dual[k, i] = norm(Y_test - XR_test @ beta_dual) ** 2 / n
            relu[k, i] = norm(Y_test - np.maximum(XR_test, 0) @ beta_relu) ** 2 / n
            ridge[k, i] = norm(Y_test - X_test @ beta_ridge) ** 2 / n

    return dual, relu, ridge


def plot_test_errors(zeta_seq, dual, relu, ridge, gamma, alpha, sigma):
    """Plot mean +/- one std of the test error against zeta for each estimator.

    *dual*, *relu* and *ridge* have one row per repetition and one column per
    zeta, as returned by ``simulate_test_errors``. Draws on the current
    matplotlib figure.
    """
    for errors, label in ((dual, "linear"), (relu, "ReLU"), (ridge, "ridge")):
        plt.errorbar(
            zeta_seq,
            np.mean(errors, axis=0),
            np.std(errors, axis=0),
            capsize=2,
            lw=2,
            label=label,
        )
    plt.legend()
    plt.grid(linestyle='dotted')
    plt.title(r"$\gamma={:.2f},\alpha={},\sigma={}$".format(gamma, alpha, sigma))
    plt.xlabel(r'$\zeta$')
    plt.ylabel("Test error")


if __name__ == "__main__":
    dual, relu, ridge = simulate_test_errors(n, p, zeta_seq, rep, lbd, gamma, sigma)
    plot_test_errors(zeta_seq, dual, relu, ridge, gamma, alpha, sigma)
    # To save the figure, e.g.:
    # plt.savefig("double_descent_gamma_{:.2f}_alpha_{}_sigma_{}.png".format(gamma, alpha, sigma))
    plt.show()