Skip to content

Commit

Permalink
test eager
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Oct 25, 2024
1 parent 57d1fe5 commit 8a12442
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
3 changes: 0 additions & 3 deletions doctr/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,7 @@ def ensure_keras_v2() -> None: # pragma: no cover
else:
logging.info(f"TensorFlow version {_tf_version} available.")
ensure_keras_v2()
import tensorflow as tf

# Enable eager execution - this is required for some models to work properly
tf.config.run_functions_eagerly(True)
else: # pragma: no cover
logging.info("Disabling Tensorflow because USE_TORCH is set")
_tf_available = False
Expand Down
1 change: 0 additions & 1 deletion doctr/models/recognition/parseq/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ def decode(
target_query = self.dropout(target_query, **kwargs)
return self.decoder(target_query, content, memory, target_mask, **kwargs)

@tf.function
def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = None, **kwargs) -> tf.Tensor:
"""Generate predictions for the given features."""
max_length = max_len if max_len is not None else self.max_length
Expand Down
7 changes: 7 additions & 0 deletions doctr/models/utils/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ def export_model_to_onnx(
-------
the path to the exported model and a list with the output layer names
"""
# get the users eager mode
eager_mode = tf.executing_eagerly()
# set eager mode to false to avoid issues with tf2onnx
tf.config.run_functions_eagerly(False)
large_model = kwargs.get("large_model", False)
model_proto, _ = tf2onnx.convert.from_keras(
model,
Expand All @@ -171,6 +175,9 @@ def export_model_to_onnx(
# Get the output layer names
output = [n.name for n in model_proto.graph.output]

# reset the eager mode to the users mode
tf.config.run_functions_eagerly(eager_mode)

# models which are too large (weights > 2GB while converting to ONNX) needs to be handled
# about an external tensor storage where the graph and weights are seperatly stored in a archive
if large_model:
Expand Down

0 comments on commit 8a12442

Please sign in to comment.