A short project that trains a model with Keras in Python and uses the TensorFlow.js (TFJS) library for in-browser prediction.
This code has been tested with TensorFlow 1.11 and higher using Python 3.
# Start a virtual environment
virtualenv -p python3 venv
source venv/bin/activate
# Install libraries
pip3 install tensorflow==1.11
pip3 install tensorflowjs
# Train the model and export it in TFJS format
python3 model_builder.py
To convert the trained model to the TFJS format, the Python code uses the following call, as described in the TFJS docs:
import tensorflowjs as tfjs
tfjs.converters.save_keras_model(model, 'tfjs_target_dir')
Alternatively, if the model has been saved as an HDF5 file (my_mnist_model.h5 below), the converter can be run from the shell:
tensorflowjs_converter --input_format keras \
my_mnist_model.h5 \
tfjs_target_dir
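Either route writes a model.json file plus one or more binary weight shards into tfjs_target_dir. Before serving that directory to the browser, it can help to check that every shard listed in model.json is actually present. The sketch below assumes the TFJS layers-model format, where model.json has a top-level "weightsManifest" list whose groups each carry a "paths" list; the function names are hypothetical and not part of the tensorflowjs package.

```python
import json
import os
from typing import List


def shard_paths(model_json: dict) -> List[str]:
    """Collect weight-shard file names from a parsed model.json.

    Assumes the TFJS layers format: a "weightsManifest" list whose groups
    each have a "paths" list of file names relative to model.json.
    """
    paths = []
    for group in model_json.get("weightsManifest", []):
        paths.extend(group.get("paths", []))
    return paths


def missing_shards(target_dir: str) -> List[str]:
    """Return shard files referenced by target_dir/model.json that do not exist."""
    with open(os.path.join(target_dir, "model.json")) as f:
        model_json = json.load(f)
    return [p for p in shard_paths(model_json)
            if not os.path.isfile(os.path.join(target_dir, p))]


if __name__ == "__main__":
    # 'tfjs_target_dir' is the output directory used in the conversion step.
    target = "tfjs_target_dir"
    if os.path.isdir(target):
        print("missing shards:", missing_shards(target) or "none")
```

An empty result means the browser should be able to fetch every weight file the model description refers to.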
Our model is deliberately simple: a single hidden layer with ReLU activations, a dropout layer, and an output layer with softmax activations.
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 512)               401920
_________________________________________________________________
dropout (Dropout)            (None, 512)               0
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
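The parameter counts in the summary follow directly from the layer sizes: a dense layer with n_in inputs and n_out units has n_in × n_out weights plus n_out biases, and dropout has no trainable weights. The sketch below reproduces the numbers, assuming 28 × 28 MNIST images flattened to 784 inputs (which is what 401,920 = 784 × 512 + 512 implies).

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Trainable parameters of a fully connected layer: one weight per
    input/unit pair plus one bias per unit."""
    return n_in * n_out + n_out


INPUT_DIM = 28 * 28  # assumption: flattened 28x28 MNIST images

layers = [
    ("dense", dense_params(INPUT_DIM, 512)),
    ("dropout", 0),  # dropout only masks activations during training; no weights
    ("dense_1", dense_params(512, 10)),
]

for name, count in layers:
    print(name, count)
total = sum(count for _, count in layers)
print("total", total)  # total 407050
```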
Part of the JavaScript code is borrowed from the TFJS mnist-core example, and to avoid a lot of front-end work I used the page template from the TFJS examples repo.
Bonus tip: it helps to clone the TFJS examples repo and run every example, just to see what TFJS can do (a lot). The following commands do that:
# On systems other than macOS, install yarn as described at https://yarnpkg.com/lang/en/docs/install
brew install yarn
git clone https://github.com/tensorflow/tfjs-examples
cd tfjs-examples
# For each example directory: install dependencies with yarn, then start yarn watch under nohup.
# The whole chain runs in the background, so you get the terminal back immediately.
ls -d */ | xargs -I {} bash -c "cd '{}' && pwd && yarn && nohup yarn watch > /dev/null 2>&1 &"
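The one-liner above starts yarn watch in every top-level directory, including any whose package.json has no watch script. A slightly more defensive variant, sketched in Python below, only launches examples that declare a watch script. It assumes yarn is on your PATH and a POSIX system; the function names are hypothetical and not part of the tfjs-examples repo.

```python
import json
import subprocess
from pathlib import Path
from typing import List


def watchable_examples(root: Path) -> List[Path]:
    """Return sub-directories of root whose package.json defines a 'watch' script."""
    found = []
    for pkg in sorted(root.glob("*/package.json")):
        scripts = json.loads(pkg.read_text()).get("scripts", {})
        if "watch" in scripts:
            found.append(pkg.parent)
    return found


def launch_all(root: Path) -> List[subprocess.Popen]:
    """Install dependencies and start 'yarn watch' in each example, detached from this terminal."""
    procs = []
    for example in watchable_examples(root):
        print("starting", example)
        procs.append(subprocess.Popen(
            "yarn && yarn watch",
            shell=True,
            cwd=example,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            start_new_session=True,  # new session, so closing the terminal does not stop it (similar to nohup)
        ))
    return procs
```

Usage: launch_all(Path("tfjs-examples")) from the directory you cloned into. Popen returns immediately, so, like the shell version, all examples start in parallel.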
The focus of this project is not to reach very high accuracy on real handwritten digits in the demo, but to make a TensorFlow.js script work with a model trained in Python. Even so, there were several challenges caused by backward-incompatible changes in TFJS.
One reliable way to improve accuracy would be to change how the canvas drawing is converted into a 2D tensor (Issue). The current workaround produces disconnected segments (Look here), whereas real MNIST digits consist of continuous strokes. Fixing this would require significant changes to the canvas event-listener functions; since the results are already reasonably good, it has not been necessary.
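Gaps like these typically appear when only the individual points reported by pointer events are drawn: during a fast stroke, consecutive samples land several pixels apart. A common fix is to draw a straight line between each pair of consecutive samples (in the browser, stroking a path from one point to the next). The Python sketch below illustrates the idea on a plain grid rather than an HTML canvas; the 280 × 280 canvas size, the one-pixel stroke, and the 10 × 10 average pooling down to 28 × 28 are assumptions for illustration, not this project's actual values.

```python
from typing import List, Sequence, Tuple

Point = Tuple[int, int]


def draw_stroke(grid: List[List[float]], points: Sequence[Point]) -> None:
    """Mark every grid cell on the straight segments joining consecutive points."""
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        steps = max(abs(x1 - x0), abs(y1 - y0), 1)
        for i in range(steps + 1):
            x = round(x0 + (x1 - x0) * i / steps)
            y = round(y0 + (y1 - y0) * i / steps)
            grid[y][x] = 1.0
    if len(points) == 1:  # a single click still leaves a dot
        x, y = points[0]
        grid[y][x] = 1.0


def downsample(grid: List[List[float]], factor: int) -> List[List[float]]:
    """Average-pool a square grid by `factor` (e.g. 280x280 -> 28x28)."""
    n = len(grid) // factor
    return [
        [
            sum(grid[r * factor + dy][c * factor + dx]
                for dy in range(factor) for dx in range(factor)) / factor ** 2
            for c in range(n)
        ]
        for r in range(n)
    ]


CANVAS = 280  # hypothetical canvas size in pixels
canvas = [[0.0] * CANVAS for _ in range(CANVAS)]
# Sparse samples from a fast horizontal drag, 40 px apart; the stroke stays connected.
draw_stroke(canvas, [(20, 140), (60, 140), (100, 140), (140, 140)])
image = downsample(canvas, CANVAS // 28)
print(sum(canvas[140][20:141]))  # 121.0: every pixel from x=20 to x=140 is set
```

The resulting 28 × 28 grid of values in [0, 1] has the same shape as an MNIST input; a real implementation would also use a thicker stroke so the downsampled digit resembles MNIST's pen width.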