-
Notifications
You must be signed in to change notification settings - Fork 0
/
learning.v
376 lines (340 loc) · 15.6 KB
/
learning.v
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
Set Implicit Arguments.
Unset Strict Implicit.
Require Import mathcomp.ssreflect.ssreflect.
From mathcomp Require Import all_ssreflect.
From mathcomp Require Import all_algebra.
Import GRing.Theory Num.Def Num.Theory.
Require Import QArith Reals Rpower Ranalysis Fourier Lra.
Require Import bigops numerics expfacts dist chernoff.
Section learning.
(** PAC-style generalization bounds derived from Chernoff/Hoeffding
    inequalities.  [A] is the example (feature) type, [B] the label type;
    both are finite. *)
Variables A B : finType.
(* [d] is a probability mass function on labeled examples A*B ... *)
Variable d : A*B -> R.
(* ... it sums to 1 over all of A*B and is pointwise nonnegative: *)
Variable d_dist : big_sum (enum [finType of A*B]) d = 1.
Variable d_nonneg : forall x, 0 <= d x.
Variable m : nat. (*The number of training samples*)
Variable m_gt0 : (0 < m)%nat.
(* [mR]: the sample count as a real, used in the exponents of the bounds. *)
Notation mR := (INR m).
(* [i0]: a canonical sample index (exists since 0 < m); used only to state
   expectations, which are index-independent under identical distribution. *)
Definition i0 : 'I_m := Ordinal m_gt0.
(** Training sets: finite functions from sample indices to labeled examples. *)
Definition training_set : finType := [finType of {ffun 'I_m -> [finType of A*B]}].
Section error_RV.
(* [Hyp]: the (finite) hypothesis class. *)
Variable Hyp : finType.
(** The (hypothesis-indexed) set of random variables being evaluated:
    [X h i xy] is the score of example [xy] at sample position [i] under
    hypothesis [h]. *)
Variable X : Hyp -> 'I_m -> A*B -> R.
(* Scores are bounded in [0,1], as required by Hoeffding-style bounds. *)
Variable X_range : forall h i x, 0 <= X h i x <= 1.
(** The empirical average of h on T: the mean of [X h i (T i)] over the
    m sample positions. *)
Definition empVal (T : training_set) (h : Hyp) :=
(big_sum (enum 'I_m) (fun i => X h i (T i))) / mR.
(** The expected value in D of X h (stated at the canonical index [i0]). *)
Definition expVal (h : Hyp) := expValR d (X h i0).
(* Nontriviality: every hypothesis has true value strictly between 0 and 1. *)
Variable expVal_nontrivial : forall h : Hyp, 0 < expVal h < 1.
(** One-sided Chernoff/Hoeffding bound for a single hypothesis [h]: under the
    product distribution on training sets, the probability that the empirical
    value [empVal T h] reaches [expVal h + eps] is at most
    [exp (-2 * eps^2 * m)].  Side condition: [0 < eps < 1 - expVal h]. *)
Lemma chernoff_bound_h
(h : Hyp)
(Hid : identically_distributed d (X h))
(eps : R) (eps_gt0 : 0 < eps) (Hyp_eps : eps < 1 - expVal h) :
probOfR
(prodR (fun _ : 'I_m => d))
[pred T : training_set | Rle_lt_dec (expVal h + eps) (empVal T h)] <=
exp (-2%R * eps^2 * mR).
Proof.
(* Restate [expVal]/[empVal] as the [p_exp]/[p_hat] forms that the library
   lemma [chernoff] expects; both rewrites are definitional up to
   commutativity of multiplication. *)
have ->: expVal h = p_exp d m_gt0 (X h) by [].
set (P := probOfR _ _).
have ->: P =
probOfR
(prodR (T:=prod_finType A B) (fun _ : 'I_m => d))
(fun T => Rle_lt_dec (p_exp (T:=prod_finType A B) d m_gt0 (X h) + eps) (p_hat (X h) T)).
{ rewrite /probOfR; apply: big_sum_ext => //=; apply eq_in_filter => /= T Hin.
have ->: empVal T h = p_hat (X h) T by rewrite /p_hat/empVal Rmult_comm.
by []. }
(* Apply the library bound; [expVal_nontrivial] supplies 0 < expVal h. *)
apply: chernoff => //; by case: (expVal_nontrivial h).
Qed.
(** [eps_Hyp eps]: the subtype of hypotheses satisfying the one-sided bound's
    side condition [eps < 1 - expVal h]. *)
Definition eps_Hyp (eps : R) : finType :=
[finType of {h : Hyp | Rlt_le_dec eps (1 - expVal h)}].
(* Assume each hypothesis's sample variables are identically distributed. *)
Variable identical : forall h : Hyp, identically_distributed d (X h).
(** Union bound over the hypotheses in [eps_Hyp eps], indexed by ordinals:
    the probability that SOME such hypothesis is empirically at least [eps]
    above its expectation is at most
    [#|eps_Hyp eps| * exp (-2*eps^2*m)]. *)
Lemma chernoff_bound1 (eps : R) (eps_gt0 : 0 < eps) :
probOfR (prodR (fun _ : 'I_m => d))
[pred T:training_set
| [exists i : 'I_#|eps_Hyp eps|,
let: h := projT1 (enum_val i)
in Rle_lt_dec (expVal h + eps) (empVal T h)]]
<= INR #|eps_Hyp eps| * exp (-2%R * eps^2 * mR).
Proof.
(* [P i]: the "bad event" for the i-th hypothesis of [eps_Hyp eps]. *)
set (P := fun i:'I_#|eps_Hyp eps| =>
[pred T : training_set |
let: h := projT1 (enum_val i)
in Rle_lt_dec (expVal h + eps) (empVal T h)]).
change (probOfR (prodR (fun _ => d))
[pred T:training_set | [exists i : 'I_#|eps_Hyp eps|, P i T]]
<= INR #|eps_Hyp eps| * exp (-2%R * eps^2 * mR)).
(* The union bound turns the existential event into a sum of probabilities. *)
apply: Rle_trans; [apply: union_bound|].
{ by apply: prodR_nonneg. }
(* Each summand is covered by the single-hypothesis bound [chernoff_bound_h];
   the side condition is exactly the sig-type witness carried by [eps_Hyp]. *)
have Hle:
\big[Rplus/0]_(i in 'I_#|eps_Hyp eps|)
probOfR
(prodR (T:=prod_finType A B) (fun _ : 'I_m => d)) [eta P i]
<= \big[Rplus/0]_(i in 'I_#|eps_Hyp eps|) (exp (-2%R * eps^2 * mR)).
{ rewrite -2!big_sum_sumP; apply big_sum_le => c Hin.
apply chernoff_bound_h => //; case: (enum_val c) => //= => x p.
case: (Rlt_le_dec eps (1 - expVal x)) p => //. }
apply: Rle_trans; first by apply: Hle.
(* A sum of #|eps_Hyp eps| equal terms equals the stated product; shown by
   induction on the cardinality. *)
rewrite big_const card_ord; elim: #|eps_Hyp eps|.
{ rewrite !Rmult_0_l; apply: Rle_refl. }
move => n H; rewrite iterS.
have ->:
INR n.+1 * exp (- (2) * eps ^ 2 * mR)
= (exp (- (2) * eps ^ 2 * mR)) + INR n * exp (- (2) * eps ^ 2 * mR).
{ by rewrite S_INR Rmult_assoc Rmult_plus_distr_r Rmult_1_l Rplus_comm. }
apply: Rplus_le_compat_l => //.
Qed.
(** Same bound as [chernoff_bound1], but with the event quantifying directly
    over elements of [eps_Hyp eps] instead of over index ordinals. *)
Lemma chernoff_bound2 (eps : R) (eps_gt0 : 0 < eps) :
probOfR (prodR (fun _ : 'I_m => d))
[pred T:training_set
| [exists h : eps_Hyp eps,
Rle_lt_dec (expVal (sval h) + eps) (empVal T (sval h))]]
<= INR #|eps_Hyp eps| * exp (-2%R * eps^2 * mR).
Proof.
apply: Rle_trans; last by apply: chernoff_bound1.
(* The element-indexed event implies the ordinal-indexed one: map each
   witness h to its rank via [enum_rank]. *)
apply: probOfR_le; first by apply: prodR_nonneg.
move => j /=; case/existsP => h H.
by apply/existsP; exists (enum_rank h); rewrite enum_rankK.
Qed.
(** The empirical value is at most 1: each score is <= 1 (by [X_range]),
    so the sum over the m samples is <= m, and dividing by m gives <= 1. *)
Lemma empVal_le1 T h : empVal T h <= 1.
Proof.
rewrite /empVal; set (f := big_sum _ _).
(* Step 1: the sum of scores is bounded by m. *)
have H: f <= mR.
{ have H1: f <= big_sum (enum 'I_m) (fun _ => 1).
{ rewrite /f; apply: big_sum_le => /= i _.
by case: (X_range h i (T i)). }
apply: Rle_trans; first by apply: H1.
rewrite big_sum_constant Rmult_1_r; apply: le_INR.
rewrite size_enum_ord; apply/leP => //. }
(* Step 2: divide both sides by m (positive since 0 < m). *)
have H1: f / mR <= mR / mR.
{ rewrite /Rdiv; apply: Rmult_le_compat_r => //.
rewrite -[/mR]Rmult_1_l; apply: Rle_mult_inv_pos.
lra.
have ->: 0 = INR 0 by [].
apply: lt_INR; apply/ltP => //. }
apply: Rle_trans; first by apply: H1.
(* Step 3: mR / mR = 1 because m <> 0. *)
rewrite /Rdiv Rinv_r; first by apply: Rle_refl.
apply: not_0_INR => Heq; move: m_gt0; rewrite Heq //.
Qed.
(** Bound for a learned hypothesis: for any learning map
    [learn : training_set -> Hyp], the probability that the hypothesis
    produced from T is empirically MORE than [eps] above its true value is
    at most [#|eps_Hyp eps| * exp (-2*eps^2*m)].  (Strict inequality in the
    event lets us derive the side condition eps < 1 - expVal h from
    [empVal_le1], so no extra hypothesis on [learn] is needed.) *)
Lemma chernoff_bound3
(learn : training_set -> Hyp) (eps : R) (eps_gt0 : 0 < eps) :
probOfR (prodR (fun _ : 'I_m => d))
[pred T:training_set |
let: h := learn T in
Rlt_le_dec (expVal h + eps) (empVal T h)]
<= INR #|eps_Hyp eps| * exp (-2%R * eps^2 * mR).
Proof.
apply: Rle_trans; last by apply: chernoff_bound2.
apply: probOfR_le; first by apply: prodR_nonneg.
(* Any bad T yields a witness in eps_Hyp eps, namely learn T itself. *)
move => T /= H; apply/existsP.
rewrite /eps_Hyp.
have X1: expVal (learn T) + eps < empVal T (learn T).
{ move: H.
case: (Rlt_le_dec (expVal (learn T) + eps) (empVal T (learn T))) => //. }
(* Since empVal T (learn T) <= 1, strictness gives eps < 1 - expVal. *)
move: (empVal_le1 T (learn T)) => X2.
have X3: eps < 1 - expVal (learn T) by lra.
have X4: Rlt_le_dec eps (1 - expVal (learn T)).
{ case: (Rlt_le_dec eps (1 - expVal (learn T))) => //.
move => b; lra. }
exists (exist _ (learn T) X4) => /=.
case: (Rle_lt_dec (expVal (learn T) + eps) (empVal T (learn T))) => //.
move => b; lra.
Qed.
(* The restricted class eps_Hyp eps is no larger than the full class Hyp
   (it is a sig-subtype of Hyp). *)
Lemma eps_Hyp_card eps : (#|eps_Hyp eps| <= #|Hyp|)%nat.
Proof.
rewrite /eps_Hyp /= card_sig; apply: leq_trans; first by apply: subset_leq_card.
by rewrite cardsT.
Qed.
(** Main one-sided result: for any learning map and any eps > 0, the
    probability that the learned hypothesis is empirically more than [eps]
    above its true value is at most [#|Hyp| * exp (-2*eps^2*m)].
    Follows from [chernoff_bound3] by weakening the cardinality factor
    via [eps_Hyp_card]. *)
Lemma chernoff_bound
(learn : training_set -> Hyp) (eps : R) (eps_gt0 : 0 < eps) :
probOfR (prodR (fun _ : 'I_m => d))
[pred T:training_set |
let: h := learn T in
Rlt_le_dec (expVal h + eps) (empVal T h)]
<= INR #|Hyp| * exp (-2%R * eps^2 * mR).
Proof.
apply: Rle_trans; first by apply: chernoff_bound3.
apply Rmult_le_compat_r; first by apply: Rlt_le; apply: exp_pos.
apply: le_INR; apply/leP; apply: eps_Hyp_card.
Qed.
(** Holdout version: for a single FIXED hypothesis [h] (not chosen from the
    data), the union-bound factor disappears and the bound is just
    [exp (-2*eps^2*m)].  Requires the explicit side condition
    [eps < 1 - expVal h]. *)
Lemma chernoff_bound_holdout
(h : Hyp) (eps : R) (eps_gt0 : 0 < eps) (eps_lt : eps < 1 - expVal h) :
probOfR (prodR (fun _ : 'I_m => d))
[pred T:training_set |
Rlt_le_dec (expVal h + eps) (empVal T h)]
<= exp (-2%R * eps^2 * mR).
Proof.
apply: Rle_trans; last by apply: (@chernoff_bound_h h _ eps _ _).
(* The strict event (Rlt_le_dec) implies the non-strict one (Rle_lt_dec). *)
apply: probOfR_le.
{ move => x; apply: prodR_nonneg => _ y; apply: d_nonneg. }
move => x /=; case: (Rlt_le_dec _ _) => // H1 _.
case: (Rle_lt_dec _ _) => // H2; lra.
Qed.
(** Side condition for the two-sided bound:
    [eps < min (expVal h) (1 - expVal h)]. *)
Definition eps_Hyp_condition_twosided (eps : R) :=
[pred h : Hyp | Rlt_le_dec eps (Rmin (expVal h) (1 - expVal h))].
(** Two-sided Chernoff/Hoeffding bound for a single hypothesis [h]: the
    probability that the empirical value deviates from the expectation by at
    least [eps] IN EITHER DIRECTION is at most [2 * exp (-2*eps^2*m)]. *)
Lemma chernoff_twosided_bound_h
(h : Hyp)
(Hid : identically_distributed d (X h))
(eps : R) (eps_gt0 : 0 < eps) (Hyp_eps : eps_Hyp_condition_twosided eps h) :
probOfR
(prodR (fun _ : 'I_m => d))
[pred T : training_set | Rle_lt_dec eps (Rabs (expVal h - empVal T h))] <=
2 * exp (-2%R * eps^2 * mR).
Proof.
(* Unpack the boolean side condition into the two real inequalities
   eps < 1 - expVal h and eps < expVal h required by [chernoff_twosided]. *)
have eps_range : eps < Rmin (expVal h) (1 - expVal h).
{ rewrite /eps_Hyp_condition_twosided /= in Hyp_eps.
move: Hyp_eps; case: (Rlt_le_dec _ _) => //. }
have eps_range1 : eps < 1 - expVal h.
{ apply: Rlt_le_trans; [by apply: eps_range|by apply: Rmin_r]. }
have eps_range2 : eps < expVal h.
{ apply: Rlt_le_trans; [by apply: eps_range|by apply: Rmin_l]. }
(* As in [chernoff_bound_h], restate the event in p_exp/p_hat form. *)
have H1: expVal h = p_exp d m_gt0 (X h) by [].
have H2:
probOfR
(prodR (T:=prod_finType A B) (fun _ : 'I_m => d))
(fun T : training_set =>
Rle_lt_dec eps (Rabs (p_exp (T:=prod_finType A B) d m_gt0 (X h) - empVal T h))) =
probOfR
(prodR (T:=prod_finType A B) (fun _ : 'I_m => d))
(fun T => Rle_lt_dec eps (Rabs (p_exp (T:=prod_finType A B) d m_gt0 (X h) - p_hat (X h) T))).
{ rewrite /probOfR; apply: big_sum_ext => //=; apply eq_in_filter => T Hin.
have ->: empVal T h = p_hat (X h) T.
{ rewrite /p_hat /empVal; rewrite Rmult_comm //. }
by []. }
rewrite H1 H2.
apply: chernoff_twosided => //; try lra.
{ move: H1; rewrite /p_exp => <- //. }
move: H1; rewrite /p_exp => <- //.
Qed.
(** Subtype of hypotheses satisfying the two-sided side condition. *)
Definition eps_Hyp_twosided (eps : R) : finType :=
[finType of {h : Hyp | eps_Hyp_condition_twosided eps h}].
(** Sanity check: the two-sided side condition is satisfiable for every
    hypothesis — take eps = min (expVal h) (1 - expVal h) / 2, which is
    positive by [expVal_nontrivial]. *)
Lemma eps_Hyp_twosided_inhabited :
forall h : Hyp,
exists eps, 0 < eps /\ eps_Hyp_condition_twosided eps h.
Proof.
move => h; rewrite /eps_Hyp_condition_twosided.
exists ((Rmin (expVal h) (1 - expVal h))/2); split.
{ apply: Rlt_mult_inv_pos; [|lra].
case: (expVal_nontrivial h) => H1 H2.
apply: Rmin_glb_lt => //; lra. }
(* The remaining case of the decision procedure contradicts Z/2 < Z. *)
rewrite /= /is_true; case Hlt: (Rlt_le_dec _ _) => // [H].
move {Hlt}; move: H; set (Z := Rmin _ _) => H; elimtype False.
have H1: Z / 2 < Z.
{ rewrite /Rdiv -{2}[Z]Rmult_1_r; apply: Rmult_lt_compat_l.
{ rewrite /Z; apply: Rmin_pos; first by case: (expVal_nontrivial h).
case: (expVal_nontrivial h) => H1 H2; lra. }
lra. }
apply: (RIneq.Rle_not_lt _ _ H H1).
Qed.
(** Union bound for the two-sided deviation over the hypotheses in
    [eps_Hyp_twosided eps], indexed by ordinals:
    probability <= [2 * #|eps_Hyp_twosided eps| * exp (-2*eps^2*m)]. *)
Lemma chernoff_twosided_bound_eps_Hyp (eps : R) (eps_gt0 : 0 < eps) :
probOfR (prodR (fun _ : 'I_m => d))
[pred T:training_set
| [exists i : 'I_#|eps_Hyp_twosided eps|,
let: h := projT1 (enum_val i)
in Rle_lt_dec eps (Rabs (expVal h - empVal T h))]]
<= 2 * INR #|eps_Hyp_twosided eps| * exp (-2%R * eps^2 * mR).
Proof.
(* [P i]: the two-sided bad event for the i-th hypothesis. *)
set (P := fun i:'I_#|eps_Hyp_twosided eps| =>
[pred T : training_set |
let: h := projT1 (enum_val i)
in Rle_lt_dec eps (Rabs (expVal h - empVal T h))]).
change (probOfR (prodR (fun _ => d))
[pred T:training_set | [exists i : 'I_#|eps_Hyp_twosided eps|, P i T]]
<= 2 * INR #|eps_Hyp_twosided eps| * exp (-2%R * eps^2 * mR)).
(* Union bound, then bound each summand by [chernoff_twosided_bound_h]. *)
apply: Rle_trans; [apply: union_bound|].
{ by apply: prodR_nonneg. }
rewrite [2 * _]Rmult_comm.
have Hle:
\big[Rplus/0]_(i in 'I_#|eps_Hyp_twosided eps|)
probOfR
(prodR (T:=prod_finType A B) (fun _ : 'I_m => d)) [eta P i]
<= \big[Rplus/0]_(i in 'I_#|eps_Hyp_twosided eps|) (2 * exp (-2%R * eps^2 * mR)).
{ rewrite -2!big_sum_sumP; apply big_sum_le => c Hin.
apply chernoff_twosided_bound_h => //.
case: (enum_val c) => //. }
apply: Rle_trans; first by apply: Hle.
(* Collapse the constant sum by induction on the cardinality. *)
rewrite big_const card_ord; elim: #|eps_Hyp_twosided eps|.
{ rewrite !Rmult_0_l; apply: Rle_refl. }
move => n H; rewrite iterS.
have ->:
INR n.+1 * 2 * exp (- (2) * eps ^ 2 * mR)
= (2 * exp (- (2) * eps ^ 2 * mR)) + INR n * 2 * exp (- (2) * eps ^ 2 * mR).
{ rewrite S_INR Rmult_assoc Rmult_plus_distr_r Rmult_1_l Rplus_comm; f_equal.
by rewrite -Rmult_assoc. }
apply: Rplus_le_compat_l => //.
Qed.
(* As for the one-sided case: the restricted class is no larger than Hyp. *)
Lemma eps_Hyp_twosided_card eps : (#|eps_Hyp_twosided eps| <= #|Hyp|)%nat.
Proof.
rewrite /eps_Hyp_twosided /= card_sig; apply: leq_trans; first by apply: subset_leq_card.
by rewrite cardsT.
Qed.
(** Weaken the cardinality factor of [chernoff_twosided_bound_eps_Hyp] from
    [#|eps_Hyp_twosided eps|] to [#|Hyp|]. *)
Lemma chernoff_twosided_bound1 (eps : R) (eps_gt0 : 0 < eps) :
probOfR (prodR (fun _ : 'I_m => d))
[pred T:training_set
| [exists i : 'I_#|eps_Hyp_twosided eps|,
let: h := projT1 (enum_val i)
in Rle_lt_dec eps (Rabs (expVal h - empVal T h))]]
<= 2 * INR #|Hyp| * exp (-2%R * eps^2 * mR).
Proof.
apply: Rle_trans; first by apply: chernoff_twosided_bound_eps_Hyp.
apply: Rmult_le_compat_r; first by apply: Rlt_le; apply: exp_pos.
apply: Rmult_le_compat_l; first by lra.
apply: le_INR; apply/leP; apply: eps_Hyp_twosided_card.
Qed.
(** Two-sided bound with the event quantifying directly over elements of
    [eps_Hyp_twosided eps] rather than over index ordinals. *)
Lemma chernoff_twosided_bound2 (eps : R) (eps_gt0 : 0 < eps) :
probOfR (prodR (fun _ : 'I_m => d))
[pred T:training_set
| [exists h : eps_Hyp_twosided eps,
Rle_lt_dec eps (Rabs (expVal (projT1 h) - empVal T (projT1 h)))]]
<= 2 * INR #|Hyp| * exp (-2%R * eps^2 * mR).
Proof.
apply: Rle_trans; last by apply: chernoff_twosided_bound1.
apply: probOfR_le; first by apply: prodR_nonneg.
(* Map each witness h to its index via enum_rank, as in chernoff_bound2. *)
move => j /=; case/existsP => h H.
by apply/existsP; exists (enum_rank h); rewrite enum_rankK.
Qed.
End error_RV.
Section zero_one_accuracy.
(** Instantiation of the bounds above to 0-1 accuracy of a prediction
    function over a finite parameter space. *)
Variable Params : finType. (*the type of parameters*)
Variable predict : Params -> A -> B. (*the prediction function*)
(* 0-1 accuracy: 1 if the prediction matches the label, else 0. *)
Definition accuracy01 (p : Params) (i : 'I_m) (xy : A*B) : R :=
let: (x,y) := xy in if predict p x == y then 1%R else 0%R.
(* 0-1 loss is the complement of 0-1 accuracy. *)
Definition loss01 (p : Params) (i : 'I_m) (xy : A*B) : R :=
1 - accuracy01 p i xy.
(*For any function from training_set to Params, assuming joint independence
and that the target class isn't perfectly representable:*)
Variable learn : training_set -> Params.
Variable not_perfectly_learnable : forall p : Params, 0 < expVal accuracy01 p < 1.
(*we get the following result for any eps: the probability that
the expected accuracy of h is more than eps lower than the empirical
accuracy of h on T is less than |Params| * exp(-2*eps^2*m),
where m is the number of training examples in T.*)
Lemma chernoff_bound_accuracy01 (eps : R) (eps_gt0 : 0 < eps) :
probOfR (prodR (fun _ : 'I_m => d))
[pred T:training_set
| let: h := learn T in
Rlt_le_dec (expVal accuracy01 h + eps) (empVal accuracy01 T h)]
<= INR #|Params| * exp (-2%R * eps^2 * mR).
Proof.
(* [accuracy01] takes only values 0 and 1, so the [0,1]-range side condition
   of the generic [chernoff_bound] holds by case analysis on the match. *)
apply chernoff_bound => // p i x; rewrite /accuracy01; case: x => a b.
case: (predict p a == b)%B; split; lra.
Qed.
(*Here's the holdout version of the above lemma, for a single fixed
hypothesis h evaluated on an independent holdout set.  The additional
condition on epsilon appears to fall out -- cf. Mulzer's tutorial on
Chernoff bound proofs.*)
Lemma chernoff_bound_accuracy01_holdout
(h : Params) (eps : R) (eps_gt0 : 0 < eps) (eps_lt : eps < 1 - expVal accuracy01 h) :
probOfR (prodR (fun _ : 'I_m => d))
[pred T:training_set
| Rlt_le_dec (expVal accuracy01 h + eps) (empVal accuracy01 T h)]
<= exp (-2%R * eps^2 * mR).
Proof.
apply: Rle_trans; last first.
{ apply: chernoff_bound_holdout => //; last by apply: eps_lt.
(* Again, 0-1 accuracy satisfies the [0,1] range condition. *)
move => hx i x; rewrite /accuracy01; case: x => a b.
case: (predict _ _ == _)%B; split; lra. }
apply: Rle_refl.
Qed.
End zero_one_accuracy.
End learning.