Commit e4bae75: starting docs

felixdittrich92 committed Jul 11, 2023
1 parent 609cf4a commit e4bae75

Showing 2 changed files with 162 additions and 142 deletions.

100 changes: 56 additions & 44 deletions docs/source/using_doctr/using_model_export.rst

Preparing your model for inference
==================================

A well-trained model is a good achievement, but you might want to tune a few things to make it production-ready!

.. currentmodule:: doctr.models.utils


Model optimization
------------------

This section is meant to help you perform inference with optimized versions of your model.


Export to ONNX
^^^^^^^^^^^^^^

ONNX (Open Neural Network Exchange) is an open and interoperable format for representing and exchanging machine learning models.
It defines a common representation of a model, including the network structure, layer types, parameters, and metadata.
ONNX Runtime is an inference engine that enables efficient execution of ONNX models across different platforms and hardware accelerators,
with optimized performance for a wide range of hardware configurations.

.. tabs::

    .. tab:: TensorFlow

        .. code:: python3

            import tensorflow as tf
            from doctr.models import vitstr_small
            from doctr.models.utils import export_model_to_onnx

            batch_size = 16
            # TensorFlow models expect channels-last inputs (H, W, C)
            input_shape = (32, 128, 3)
            model = vitstr_small(pretrained=True, exportable=True)
            dummy_input = [tf.TensorSpec([batch_size, *input_shape], tf.float32, name="input")]
            model_path, output = export_model_to_onnx(model, model_name="vitstr.onnx", dummy_input=dummy_input)

    .. tab:: PyTorch

        .. code:: python3

            import torch
            from doctr.models import vitstr_small
            from doctr.models.utils import export_model_to_onnx

            batch_size = 16
            # PyTorch models expect channels-first inputs (C, H, W)
            input_shape = (3, 32, 128)
            model = vitstr_small(pretrained=True, exportable=True)
            dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
            model_path = export_model_to_onnx(model, model_name="vitstr.onnx", dummy_input=dummy_input)
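
Once exported, it can be worth sanity-checking the resulting file before shipping it. A minimal sketch using the ``onnx`` package (an assumption: ``onnx`` is installed in your environment; it is not required by the export itself):

.. code:: python3

    import onnx

    # load the exported graph and run ONNX's structural validation
    onnx_model = onnx.load("vitstr.onnx")
    onnx.checker.check_model(onnx_model)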

Half-precision
^^^^^^^^^^^^^^

Half-precision (FP16) is a binary floating-point format that occupies 16 bits in computer memory, half the size of single-precision (FP32) values.

.. tabs::

    .. tab:: TensorFlow

        .. code:: python3

            import tensorflow as tf
            from tensorflow.keras import mixed_precision
            from doctr.models import ocr_predictor

            mixed_precision.set_global_policy('mixed_float16')
            predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True)

    .. tab:: PyTorch

        .. code:: python3

            import torch
            from doctr.io import DocumentFile
            from doctr.models import ocr_predictor

            doc = DocumentFile.from_images("path/to/your/img.jpg")
            # half-precision inference requires a CUDA-capable GPU
            predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True).cuda().half()
            res = predictor(doc)
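
To confirm the cast took effect, you can inspect one of the predictor's parameters (an illustrative check for the PyTorch backend, assuming the predictor behaves as a ``torch.nn.Module``):

.. code:: python3

    # parameters should now be stored as float16
    param = next(predictor.parameters())
    print(param.dtype)  # expected: torch.float16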

Using your ONNX model inside DocTR
----------------------------------

**Coming soon**
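
In the meantime, an exported model can already be run with ONNX Runtime directly. A minimal sketch (assumptions: ``onnxruntime`` is installed, the model was exported with the PyTorch backend above, and the input name ``"input"`` matches the spec used at export time):

.. code:: python3

    import numpy as np
    import onnxruntime as ort

    # create an inference session from the exported file
    session = ort.InferenceSession("vitstr.onnx")

    # dummy channels-first batch matching the PyTorch export shape;
    # replace with real preprocessed word crops
    dummy = np.random.rand(16, 3, 32, 128).astype(np.float32)
    logits = session.run(None, {"input": dummy})[0]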
