Skip to content

Commit

Permalink
tf - add bf16 numpy dtype conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Oct 11, 2023
1 parent 50d65d7 commit df2ac41
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
from tensorflow.keras.applications import ResNet50

from doctr.file_utils import CLASS_NAME
from doctr.models.utils import IntermediateLayerGetter, conv_sequence, load_pretrained_params
from doctr.models.utils import (
IntermediateLayerGetter,
_bf16_numpy_dtype_converter,
conv_sequence,
load_pretrained_params,
)
from doctr.utils.repr import NestedObject

from ...classification import mobilenet_v3_large
Expand Down Expand Up @@ -241,7 +246,7 @@ def call(
return out

if return_model_output or target is None or return_preds:
prob_map = tf.math.sigmoid(logits)
prob_map = _bf16_numpy_dtype_converter(tf.math.sigmoid(logits))

if return_model_output:
out["out_map"] = prob_map
Expand Down
10 changes: 8 additions & 2 deletions doctr/models/detection/linknet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@

from doctr.file_utils import CLASS_NAME
from doctr.models.classification import resnet18, resnet34, resnet50
from doctr.models.utils import IntermediateLayerGetter, conv_sequence, load_pretrained_params
from doctr.models.utils import (
IntermediateLayerGetter,
_bf16_numpy_dtype_converter,
conv_sequence,
load_pretrained_params,
)
from doctr.utils.repr import NestedObject

from .base import LinkNetPostProcessor, _LinkNet
Expand Down Expand Up @@ -229,7 +234,8 @@ def call(
return out

if return_model_output or target is None or return_preds:
prob_map = tf.math.sigmoid(logits)
prob_map = _bf16_numpy_dtype_converter(tf.math.sigmoid(logits))

if return_model_output:
out["out_map"] = prob_map

Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/crnn/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from doctr.datasets import VOCABS

from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
from ...utils.tensorflow import load_pretrained_params
from ...utils.tensorflow import _bf16_numpy_dtype_converter, load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor

__all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
Expand Down Expand Up @@ -199,7 +199,7 @@ def call(
w, h, c = transposed_feat.get_shape().as_list()[1:]
# B x W x H x C --> B x W x H * C
features_seq = tf.reshape(transposed_feat, shape=(-1, w, h * c))
logits = self.decoder(features_seq, **kwargs)
logits = _bf16_numpy_dtype_converter(self.decoder(features_seq, **kwargs))

out: Dict[str, tf.Tensor] = {}
if self.exportable:
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/master/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from doctr.models.classification import magc_resnet31
from doctr.models.modules.transformer import Decoder, PositionalEncoding

from ...utils.tensorflow import load_pretrained_params
from ...utils.tensorflow import _bf16_numpy_dtype_converter, load_pretrained_params
from .base import _MASTER, _MASTERPostProcessor

__all__ = ["MASTER", "master"]
Expand Down Expand Up @@ -181,7 +181,7 @@ def call(
output = self.decoder(gt, encoded, source_mask, target_mask, **kwargs)
logits = self.linear(output, **kwargs)
else:
logits = self.decode(encoded, **kwargs)
logits = _bf16_numpy_dtype_converter(self.decode(encoded, **kwargs))

if self.exportable:
out["logits"] = logits
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/parseq/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward

from ...classification import vit_s
from ...utils.tensorflow import load_pretrained_params
from ...utils.tensorflow import _bf16_numpy_dtype_converter, load_pretrained_params
from .base import _PARSeq, _PARSeqPostProcessor

__all__ = ["PARSeq", "parseq"]
Expand Down Expand Up @@ -388,7 +388,7 @@ def call(
)
)
else:
logits = self.decode_autoregressive(features, **kwargs)
logits = _bf16_numpy_dtype_converter(self.decode_autoregressive(features, **kwargs))

out: Dict[str, tf.Tensor] = {}
if self.exportable:
Expand Down
6 changes: 4 additions & 2 deletions doctr/models/recognition/sar/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from doctr.utils.repr import NestedObject

from ...classification import resnet31
from ...utils.tensorflow import load_pretrained_params
from ...utils.tensorflow import _bf16_numpy_dtype_converter, load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor

__all__ = ["SAR", "sar_resnet31"]
Expand Down Expand Up @@ -316,7 +316,9 @@ def call(
if kwargs.get("training", False) and target is None:
raise ValueError("Need to provide labels during training for teacher forcing")

decoded_features = self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
decoded_features = _bf16_numpy_dtype_converter(
self.decoder(features, encoded, gt=None if target is None else gt, **kwargs)
)

out: Dict[str, tf.Tensor] = {}
if self.exportable:
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/recognition/vitstr/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from doctr.datasets import VOCABS

from ...classification import vit_b, vit_s
from ...utils.tensorflow import load_pretrained_params
from ...utils.tensorflow import _bf16_numpy_dtype_converter, load_pretrained_params
from .base import _ViTSTR, _ViTSTRPostProcessor

__all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
Expand Down Expand Up @@ -131,7 +131,7 @@ def call(
logits = tf.reshape(
self.head(features, **kwargs), (B, N, len(self.vocab) + 1)
) # (batch_size, max_length, vocab + 1)
decoded_features = logits[:, 1:] # remove cls_token
decoded_features = _bf16_numpy_dtype_converter(logits[:, 1:]) # remove cls_token

out: Dict[str, tf.Tensor] = {}
if self.exportable:
Expand Down
14 changes: 13 additions & 1 deletion doctr/models/utils/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,25 @@
logging.getLogger("tensorflow").setLevel(logging.DEBUG)


__all__ = ["load_pretrained_params", "conv_sequence", "IntermediateLayerGetter", "export_model_to_onnx", "_copy_tensor"]
__all__ = [
"load_pretrained_params",
"conv_sequence",
"IntermediateLayerGetter",
"export_model_to_onnx",
"_copy_tensor",
"_bf16_numpy_dtype_converter",
]


def _copy_tensor(x: tf.Tensor) -> tf.Tensor:
return tf.identity(x)


def _bf16_numpy_dtype_converter(x: tf.Tensor) -> tf.Tensor:
# Convert bfloat16 to float32 for numpy compatibility
return tf.cast(x, tf.float32) if x.dtype == tf.bfloat16 else x


def load_pretrained_params(
model: Model,
url: Optional[str] = None,
Expand Down
14 changes: 13 additions & 1 deletion tests/tensorflow/test_models_utils_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
from tensorflow.keras import Sequential, layers
from tensorflow.keras.applications import ResNet50

from doctr.models.utils import IntermediateLayerGetter, _copy_tensor, conv_sequence, load_pretrained_params
from doctr.models.utils import (
IntermediateLayerGetter,
_bf16_numpy_dtype_converter,
_copy_tensor,
conv_sequence,
load_pretrained_params,
)


def test_copy_tensor():
Expand All @@ -14,6 +20,12 @@ def test_copy_tensor():
assert m.device == x.device and m.dtype == x.dtype and m.shape == x.shape and tf.reduce_all(tf.equal(m, x))


def test_bf16_numpy_dtype_converter():
x = tf.random.uniform(shape=[8], minval=0, maxval=1, dtype=tf.bfloat16)
m = _bf16_numpy_dtype_converter(x)
assert x.dtype == tf.bfloat16 and m.dtype == tf.float32 and tf.reduce_all(tf.equal(m, tf.cast(x, tf.float32)))


def test_load_pretrained_params(tmpdir_factory):
model = Sequential([layers.Dense(8, activation="relu", input_shape=(4,)), layers.Dense(4)])
# Retrieve this URL
Expand Down

0 comments on commit df2ac41

Please sign in to comment.