diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py index 20f74a2ad5..73856d5655 100644 --- a/doctr/models/detection/differentiable_binarization/tensorflow.py +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -15,7 +15,12 @@ from tensorflow.keras.applications import ResNet50 from doctr.file_utils import CLASS_NAME -from doctr.models.utils import IntermediateLayerGetter, conv_sequence, load_pretrained_params +from doctr.models.utils import ( + IntermediateLayerGetter, + _bf16_numpy_dtype_converter, + conv_sequence, + load_pretrained_params, +) from doctr.utils.repr import NestedObject from ...classification import mobilenet_v3_large @@ -241,7 +246,7 @@ def call( return out if return_model_output or target is None or return_preds: - prob_map = tf.math.sigmoid(logits) + prob_map = _bf16_numpy_dtype_converter(tf.math.sigmoid(logits)) if return_model_output: out["out_map"] = prob_map diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py index 3ac436088e..6522f9d329 100644 --- a/doctr/models/detection/linknet/tensorflow.py +++ b/doctr/models/detection/linknet/tensorflow.py @@ -15,7 +15,12 @@ from doctr.file_utils import CLASS_NAME from doctr.models.classification import resnet18, resnet34, resnet50 -from doctr.models.utils import IntermediateLayerGetter, conv_sequence, load_pretrained_params +from doctr.models.utils import ( + IntermediateLayerGetter, + _bf16_numpy_dtype_converter, + conv_sequence, + load_pretrained_params, +) from doctr.utils.repr import NestedObject from .base import LinkNetPostProcessor, _LinkNet @@ -229,7 +234,8 @@ def call( return out if return_model_output or target is None or return_preds: - prob_map = tf.math.sigmoid(logits) + prob_map = _bf16_numpy_dtype_converter(tf.math.sigmoid(logits)) + if return_model_output: out["out_map"] = prob_map diff --git a/doctr/models/recognition/crnn/tensorflow.py b/doctr/models/recognition/crnn/tensorflow.py index 618a4c0e92..89e467d180 100644 --- a/doctr/models/recognition/crnn/tensorflow.py +++ b/doctr/models/recognition/crnn/tensorflow.py @@ -13,7 +13,7 @@ from doctr.datasets import VOCABS from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r -from ...utils.tensorflow import load_pretrained_params +from ...utils.tensorflow import _bf16_numpy_dtype_converter, load_pretrained_params from ..core import RecognitionModel, RecognitionPostProcessor __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"] @@ -199,7 +199,7 @@ def call( w, h, c = transposed_feat.get_shape().as_list()[1:] # B x W x H x C --> B x W x H * C features_seq = tf.reshape(transposed_feat, shape=(-1, w, h * c)) - logits = self.decoder(features_seq, **kwargs) + logits = _bf16_numpy_dtype_converter(self.decoder(features_seq, **kwargs)) out: Dict[str, tf.Tensor] = {} if self.exportable: diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py index 908c2e8b8f..aecd41dfbe 100644 --- a/doctr/models/recognition/master/tensorflow.py +++ b/doctr/models/recognition/master/tensorflow.py @@ -13,7 +13,7 @@ from doctr.models.classification import magc_resnet31 from doctr.models.modules.transformer import Decoder, PositionalEncoding -from ...utils.tensorflow import load_pretrained_params +from ...utils.tensorflow import _bf16_numpy_dtype_converter, load_pretrained_params from .base import _MASTER, _MASTERPostProcessor __all__ = ["MASTER", "master"] @@ -181,7 +181,7 @@ def call( output = self.decoder(gt, encoded, source_mask, target_mask, **kwargs) logits = self.linear(output, **kwargs) else: - logits = self.decode(encoded, **kwargs) + logits = _bf16_numpy_dtype_converter(self.decode(encoded, **kwargs)) if self.exportable: out["logits"] = logits diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py index 21a35605f5..817c5a6ec1 100644 --- a/doctr/models/recognition/parseq/tensorflow.py +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -16,7 +16,7 @@ from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward from ...classification import vit_s -from ...utils.tensorflow import load_pretrained_params +from ...utils.tensorflow import _bf16_numpy_dtype_converter, load_pretrained_params from .base import _PARSeq, _PARSeqPostProcessor __all__ = ["PARSeq", "parseq"] @@ -388,7 +388,7 @@ def call( ) ) else: - logits = self.decode_autoregressive(features, **kwargs) + logits = _bf16_numpy_dtype_converter(self.decode_autoregressive(features, **kwargs)) out: Dict[str, tf.Tensor] = {} if self.exportable: diff --git a/doctr/models/recognition/sar/tensorflow.py b/doctr/models/recognition/sar/tensorflow.py index 6a688c7bac..76d79f8116 100644 --- a/doctr/models/recognition/sar/tensorflow.py +++ b/doctr/models/recognition/sar/tensorflow.py @@ -13,7 +13,7 @@ from doctr.utils.repr import NestedObject from ...classification import resnet31 -from ...utils.tensorflow import load_pretrained_params +from ...utils.tensorflow import _bf16_numpy_dtype_converter, load_pretrained_params from ..core import RecognitionModel, RecognitionPostProcessor __all__ = ["SAR", "sar_resnet31"] @@ -316,7 +316,9 @@ def call( if kwargs.get("training", False) and target is None: raise ValueError("Need to provide labels during training for teacher forcing") - decoded_features = self.decoder(features, encoded, gt=None if target is None else gt, **kwargs) + decoded_features = _bf16_numpy_dtype_converter( + self.decoder(features, encoded, gt=None if target is None else gt, **kwargs) + ) out: Dict[str, tf.Tensor] = {} if self.exportable: diff --git a/doctr/models/recognition/vitstr/tensorflow.py b/doctr/models/recognition/vitstr/tensorflow.py index 70c7325b3f..b7c2bf89fa 100644 --- a/doctr/models/recognition/vitstr/tensorflow.py +++ b/doctr/models/recognition/vitstr/tensorflow.py @@ -12,7 +12,7 @@ from doctr.datasets import VOCABS from ...classification import vit_b, vit_s -from ...utils.tensorflow import load_pretrained_params +from ...utils.tensorflow import _bf16_numpy_dtype_converter, load_pretrained_params from .base import _ViTSTR, _ViTSTRPostProcessor __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"] @@ -131,7 +131,7 @@ def call( logits = tf.reshape( self.head(features, **kwargs), (B, N, len(self.vocab) + 1) ) # (batch_size, max_length, vocab + 1) - decoded_features = logits[:, 1:] # remove cls_token + decoded_features = _bf16_numpy_dtype_converter(logits[:, 1:]) # remove cls_token out: Dict[str, tf.Tensor] = {} if self.exportable: diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py index 8490c09f11..199b2d6ca2 100644 --- a/doctr/models/utils/tensorflow.py +++ b/doctr/models/utils/tensorflow.py @@ -17,13 +17,25 @@ logging.getLogger("tensorflow").setLevel(logging.DEBUG) -__all__ = ["load_pretrained_params", "conv_sequence", "IntermediateLayerGetter", "export_model_to_onnx", "_copy_tensor"] +__all__ = [ + "load_pretrained_params", + "conv_sequence", + "IntermediateLayerGetter", + "export_model_to_onnx", + "_copy_tensor", + "_bf16_numpy_dtype_converter", +] def _copy_tensor(x: tf.Tensor) -> tf.Tensor: return tf.identity(x) +def _bf16_numpy_dtype_converter(x: tf.Tensor) -> tf.Tensor: + # Convert bfloat16 to float32 for numpy compatibility + return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x + + def load_pretrained_params( model: Model, url: Optional[str] = None, diff --git a/tests/tensorflow/test_models_utils_tf.py b/tests/tensorflow/test_models_utils_tf.py index 2e256cacb8..73e2c5ffd3 100644 --- a/tests/tensorflow/test_models_utils_tf.py +++ b/tests/tensorflow/test_models_utils_tf.py @@ -5,7 +5,13 @@ from tensorflow.keras import Sequential, layers from tensorflow.keras.applications import ResNet50 -from doctr.models.utils import IntermediateLayerGetter, _copy_tensor, conv_sequence, load_pretrained_params +from doctr.models.utils import ( + IntermediateLayerGetter, + _bf16_numpy_dtype_converter, + _copy_tensor, + conv_sequence, + load_pretrained_params, +) def test_copy_tensor(): @@ -14,6 +20,12 @@ def test_copy_tensor(): assert m.device == x.device and m.dtype == x.dtype and m.shape == x.shape and tf.reduce_all(tf.equal(m, x)) +def test_bf16_numpy_dtype_converter(): + x = tf.random.uniform(shape=[8], minval=0, maxval=1, dtype=tf.bfloat16) + m = _bf16_numpy_dtype_converter(x) + assert x.dtype == tf.bfloat16 and m.dtype == tf.float32 and tf.reduce_all(tf.equal(m, tf.cast(x, tf.float32))) + + def test_load_pretrained_params(tmpdir_factory): model = Sequential([layers.Dense(8, activation="relu", input_shape=(4,)), layers.Dense(4)]) # Retrieve this URL