-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathTrainTest.m
30 lines (24 loc) · 968 Bytes
/
TrainTest.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
function [ trnData,trnLab,tstData,tstLab ] = TrainTest(data, lab, trnPer, clsCnt)
%TRAINTEST Randomly split labelled samples into per-class train/test sets.
%   [trnData, trnLab, tstData, tstLab] = TrainTest(data, lab, trnPer, clsCnt)
%   For each class label i = 1..clsCnt, the rows of DATA whose label equals
%   i are shuffled with randperm; the first ceil(count * trnPer) of them go
%   to the training set and the remainder to the test set.
%
%   Inputs:
%     data   - matrix whose ROWS are samples (indexed by the positions in lab)
%     lab    - label of each row of DATA; row or column vector
%     trnPer - fraction of each class placed in the training set
%              (assumed to lie in [0, 1] -- values above 1 index past the
%              end of the class and raise an error)
%     clsCnt - labels 1..clsCnt are split; any other label value is ignored
%
%   Outputs (note: samples are returned as COLUMNS, i.e. transposed):
%     trnData - size(data,2)-by-nTrain matrix of training samples
%     trnLab  - 1-by-nTrain row vector of training labels, grouped by class
%     tstData - size(data,2)-by-nTest matrix of test samples
%     tstLab  - 1-by-nTest row vector of test labels, grouped by class
trnData = [];
trnLab = [];
tstData = [];
tstLab = [];
for i = 1 : clsCnt
    % Positions of class-i samples, in random order.
    index = find(lab == i);
    random_index = index(randperm(numel(index)));
    % numel (not size(index, 1)): find returns a row vector when lab is a
    % row vector, and size(...,1) would then report a class size of 1.
    trnCnt = ceil(numel(index) * trnPer);
    trnIdx = random_index(1 : trnCnt);
    tstIdx = random_index(trnCnt + 1 : end);
    % Growth is per class (clsCnt iterations), not per sample.
    trnData = [trnData data(trnIdx, :)']; %#ok<AGROW>
    trnLab = [trnLab ones(1, numel(trnIdx)) * i]; %#ok<AGROW>
    tstData = [tstData data(tstIdx, :)']; %#ok<AGROW>
    tstLab = [tstLab ones(1, numel(tstIdx)) * i]; %#ok<AGROW>
end