forked from handspeaker/RandomForests
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.cpp
76 lines (68 loc) · 2.35 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <cstdio>
#include <vector>

#include "RandomForest.h"
#include "MnistPreProcess.h"
#define TRAIN_NUM 60000
#define TEST_NUM 10000
#define FEATURE 784
#define NUMBER_OF_CLASSES 10
int main(int argc, const char * argv[])
{
//1. prepare data
float**trainset;
float** testset;
float*trainlabels;
float*testlabels;
trainset=new float*[TRAIN_NUM];
testset=new float*[TEST_NUM];
trainlabels=new float[TRAIN_NUM];
testlabels=new float[TEST_NUM];
for(int i=0;i<TRAIN_NUM;++i)
{trainset[i]=new float[FEATURE];}
for(int i=0;i<TEST_NUM;++i)
{testset[i]=new float[FEATURE];}
readData(trainset,trainlabels,argv[1],argv[2]);
readData(testset,testlabels,argv[3],argv[4]);
// readData(trainset,trainlabels,
// "/Users/xinling/PycharmProjects/MNIST_data/train-images-idx3-ubyte",
// "/Users/xinling/PycharmProjects/MNIST_data/train-labels-idx1-ubyte");
// readData(testset,testlabels,
// "/Users/xinling/PycharmProjects/MNIST_data/t10k-images-idx3-ubyte",
// "/Users/xinling/PycharmProjects/MNIST_data/t10k-labels-idx1-ubyte");
//2. create RandomForest class and set some parameters
RandomForest randomForest(100,10,10,0);
//3. start to train RandomForest
// randomForest.train(trainset,trainlabels,TRAIN_NUM,FEATURE,10,true,56);//regression
randomForest.train(trainset,trainlabels,TRAIN_NUM,FEATURE,10,false);//classification
//restore model from file and save model to file
// randomForest.saveModel("E:\\RandomForest2.Model");
// randomForest.readModel("E:\\RandomForest.Model");
// RandomForest randomForest("E:\\RandomForest2.Model");
//predict single sample
// float resopnse;
// randomForest.predict(testset[0],resopnse);
//predict a list of samples
float*resopnses=new float[TEST_NUM];
randomForest.predict(testset,TEST_NUM,resopnses);
float errorRate=0;
for(int i=0;i<TEST_NUM;++i)
{
if(resopnses[i]!=testlabels[i])
{
errorRate+=1.0f;
}
//for regression
// float diff=abs(resopnses[i]-testlabels[i]);
// errorRate+=diff;
}
errorRate/=TEST_NUM;
printf("the total error rate is:%f\n",errorRate);
delete[] resopnses;
for(int i=0;i<TRAIN_NUM;++i)
{delete[] trainset[i];}
for(int i=0;i<TEST_NUM;++i)
{delete[] testset[i];}
delete[] trainlabels;
delete[] testlabels;
delete[] trainset;
delete[] testset;
return 0;
};