-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcv_lr_platform.m
73 lines (59 loc) · 2.55 KB
/
cv_lr_platform.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
%Program for CSC522 course project: Precipitation prediction of
%Northwestern United States
%run after day_data_preprocessing
%10-fold cross validation model selection
% 10-fold cross-validation of one linear regression model per grid cell.
%
% Workspace inputs (not defined in this script):
%   correlation_matrix - 4-D array; (m, n, j, 2) and (m, n, j, 3) are used as
%                        the row/column of the j-th predictor cell for cell
%                        (m, n).
%   mod_sel_data       - lo-by-la-by-samples array of predictor values.
%   mod_sel_reg        - regression targets, indexed like mod_sel_data.
%   mod_sel_labels     - class labels, compared with discretize() bin indices.
%   cv_index           - fold id (1..10) for each sample along dimension 3.
%   no_detect          - logical mask; cells where it is true are skipped.
%   lon, lat           - only their lengths are used, to size the outputs.
%   meanRMSE           - helper function defined elsewhere.
%
% Workspace outputs:
%   SSE       - per-cell sum of squared errors over all held-out samples.
%   accur     - per-cell COUNT of held-out samples whose predicted bin equals
%               the label (not normalised to a rate).
%   MSE, RMSE - per-cell mean squared error and its square root.
%   mean_RMSE - result of meanRMSE(RMSE, no_detect).
feature_matrix = correlation_matrix;
% Bin edges used to turn a predicted value into a class index.
% NOTE(review): presumably precipitation amounts - confirm units with the
% preprocessing script.
edges = [0, 0.1, 1, 2.5, 10, 25, 50, 100, 400];
num_folds = 10;
SSE = zeros(length(lon), length(lat));
accur = zeros(length(lon), length(lat));
d = size(feature_matrix, 3);
% dat is not used below; it is kept because scripts share the base workspace.
[lo, la, dat] = size(mod_sel_data);
% Number of held-out predictions made per cell, accumulated over the folds.
n_evaluated = 0;
for i = 1:num_folds
    cv_set = (cv_index == i);
    train_set = ~cv_set;
    n_train = nnz(train_set);
    n_cv = nnz(cv_set);
    if n_cv == 0
        % Nothing is held out for this fold id, so there is nothing to score.
        continue;
    end
    n_evaluated = n_evaluated + n_cv;
    train_data = mod_sel_data(:, :, train_set);
    train_reg = mod_sel_reg(:, :, train_set);
    cv_data = mod_sel_data(:, :, cv_set);
    % Held-out regression targets come from mod_sel_reg, mirroring train_reg.
    cv_reg = mod_sel_reg(:, :, cv_set);
    cv_labels = mod_sel_labels(:, :, cv_set);
    for m = 1:lo
        for n = 1:la
            if no_detect(m, n)
                continue;
            end
            % Targets and labels of the current cell as column vectors.
            cur_train_reg = reshape(train_reg(m, n, :), [n_train, 1]);
            cur_cv_labels = reshape(cv_labels(m, n, :), [n_cv, 1]);
            cur_cv_reg = reshape(cv_reg(m, n, :), [n_cv, 1]);
            % Design matrices: one row per sample, one column per predictor
            % cell.
            % NOTE(review): only the first d-1 entries along dimension 3 of
            % feature_matrix are used; confirm why the last one is excluded.
            cur_train_data = zeros(n_train, d - 1, 'like', train_data);
            cur_cv_data = zeros(n_cv, d - 1, 'like', cv_data);
            for j = 1:(d - 1)
                cm = feature_matrix(m, n, j, 2);
                cn = feature_matrix(m, n, j, 3);
                cur_train_data(:, j) = reshape(train_data(cm, cn, :), [n_train, 1]);
                cur_cv_data(:, j) = reshape(cv_data(cm, cn, :), [n_cv, 1]);
            end
            LRMD = fitlm(cur_train_data, cur_train_reg);
            % Predict every held-out row in one call.
            prd = predict(LRMD, cur_cv_data);
            pl = discretize(prd, edges);
            % discretize returns NaN below the first edge; negative
            % predictions are assigned class 0 instead.
            pl(prd < 0) = 0;
            % Report held-out samples whose label is missing.
            missing_rows = find(isnan(cur_cv_labels));
            for p = missing_rows'
                disp([m, n, p]);
            end
            SSE(m, n) = SSE(m, n) + sum((cur_cv_reg - prd).^2);
            % NaN never compares equal, so unlabeled or out-of-range
            % predictions are not counted as correct.
            accur(m, n) = accur(m, n) + sum(pl == cur_cv_labels);
        end
    end
end
% Average over the number of held-out predictions per cell. length() of a
% 3-D array would return its largest dimension, which need not be the sample
% count.
MSE = SSE ./ n_evaluated;
RMSE = sqrt(MSE);
mean_RMSE = meanRMSE(RMSE, no_detect);