forked from verivital/nnv
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtraining.m
76 lines (58 loc) · 2.04 KB
/
training.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
%% Training of an MNIST classifier (CNN)
% Code based on MathWorks example
% https://www.mathworks.com/help/deeplearning/ug/create-simple-deep-learning-network-for-classification.html
% The images come from the DigitDataset that ships with the toolbox under
% matlabroot (MNIST-style handwritten-digit images), so no download is needed.
t = tic; % track total time for the training
% Load data (no download necessary)
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
% Images: searched recursively; each image's label is its parent folder name
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');
% Display a random sample of the data.
% Sample from the number of files actually found (instead of assuming the
% dataset holds exactly 10000 images) and never request more images than
% exist, which would make randperm error out.
numFiles = numel(imds.Files);
numShow = min(20, numFiles);
figure;
perm = randperm(numFiles, numShow);
for k = 1:numShow
    subplot(4,5,k); % 4x5 grid holds at most 20 previews
    imshow(imds.Files{perm(k)});
end
% Show the number of images (per class) in the dataset
labelCount = countEachLabel(imds);
disp(labelCount);
% Select training procedure
% splitEachLabel with an integer count takes that many files from EACH label
% (chosen at random) for training; all remaining files go to validation.
numTrainFiles = 750; % number of training images per class (not in total)
% Split the data
[imdsTrain,imdsValidation] = splitEachLabel(imds,numTrainFiles,'randomize');
numClasses = height(labelCount); % number of classes in dataset (one row per label)
% Size of the images (Height x Width x Channels), read from the first image
% instead of being hard-coded, so the input layer always matches the data.
% For the toolbox DigitDataset this evaluates to [28 28 1].
sampleImg = readimage(imds, 1);
imgSize = [size(sampleImg,1) size(sampleImg,2) size(sampleImg,3)];
% Create the neural network model:
%   input -> [conv 3x3 x8 -> BN -> ReLU] -> avg-pool 2x2 (stride 2)
%         -> [conv 3x3 x16 -> BN -> ReLU] -> FC (numClasses) -> softmax
%         -> classification output
layers = [
    % input dimensions (Height x Width x Channels) come from imgSize
    imageInputLayer(imgSize);
    % stage 1: 8 filters of size 3, 'same' padding keeps the spatial size
    convolution2dLayer(3,8,'Padding','same');
    batchNormalizationLayer;
    reluLayer;
    % 2x2 pooling window with stride 2 halves height and width
    averagePooling2dLayer(2,'Stride',2);
    % stage 2: 16 filters of size 3, again with 'same' padding
    convolution2dLayer(3,16,'Padding','same');
    batchNormalizationLayer;
    reluLayer;
    % classifier head: one fully connected output per class
    fullyConnectedLayer(numClasses);
    softmaxLayer;
    classificationLayer];
% Training options: SGD with momentum, 4 passes over the training split,
% starting learning rate 0.01. The training data is reshuffled before every
% epoch; the held-out split is evaluated every 30 iterations and progress
% is printed to the command window.
options = trainingOptions('sgdm', ...
    'MaxEpochs',4, ...
    'InitialLearnRate',0.01, ...
    'Shuffle','every-epoch', ...
    'Verbose',true, ...
    'ValidationData',imdsValidation, ...
    'ValidationFrequency',30);
% Train the layer array on the training split with the options above
net = trainNetwork(imdsTrain,layers,options);
% Validate network: accuracy is the fraction of validation images whose
% predicted label equals the folder-derived ground-truth label.
YValidation = imdsValidation.Labels;
YPred = classify(net,imdsValidation);
accuracy = mean(YPred == YValidation);
disp("Validation accuracy = " + string(accuracy));
% Save the trained network together with its validation accuracy
disp("Saving model...");
save('mnist_model.mat', 'net', 'accuracy');
% Print the total elapsed time measured from the tic at the top of the script
toc(t);