-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTestNN.py
36 lines (31 loc) · 1.19 KB
/
TestNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from NeuralNet import NeuralNet
import ReadData
import os
import tensorflow as tf
def main():
    """Train the neural network and report its accuracy on the test data.

    Reads the testing, training and validation datasets through ``ReadData``,
    builds a ``NeuralNet`` from the training and validation data, trains it,
    classifies the testing dataset and prints the reported accuracy.
    """
    # Only let TensorFlow log messages of level ERROR or higher through.
    tf.get_logger().setLevel('ERROR')

    # Checkpoint path template handed to the model; the "{epoch:04d}"
    # placeholder is deliberately left unformatted here.
    checkpoint_path = "Dog_path/cp-{epoch:04d}.ckpt"

    # Read the three datasets (same order as before: test, train, validation).
    testingDS = ReadData.get_testingData()
    trainingDS = ReadData.get_trainingData()
    # NOTE(review): "get_validationgData" looks like a typo, but it is the
    # name this script has always called -- confirm against ReadData before
    # renaming.
    validationDS = ReadData.get_validationgData()

    # Instantiate the NeuralNet object which does the training and prediction.
    # NOTE(review): [180, 180, 3] is presumably the input image shape
    # (height, width, channels) -- confirm against NeuralNet.
    model = NeuralNet(
        checkpoint_path,
        [180, 180, 3],
        train_data=trainingDS,
        val_data=validationDS,
    )

    # Train the model on the training and validation data.
    model.train_model()

    # Classify the test data; only the accuracy part of the result is used.
    _predictions, accuracy = model.classify_data(testingDS)

    # str() is required: concatenating a str with a non-str value (e.g. a
    # float metric) raises TypeError. If the value is already a str the
    # printed text is unchanged.
    print("The accuracy on the test data was: " + str(accuracy[1]))
# Run the whole train / classify / report pipeline.
# NOTE(review): this call is not behind an ``if __name__ == "__main__":``
# guard, so importing this module also runs main() -- confirm that is intended.
main()