diff --git a/onnx_tf/handlers/backend/instance_normalization.py b/onnx_tf/handlers/backend/instance_normalization.py index 59beba3bf..04d83b68c 100644 --- a/onnx_tf/handlers/backend/instance_normalization.py +++ b/onnx_tf/handlers/backend/instance_normalization.py @@ -3,7 +3,7 @@ from onnx_tf.handlers.backend_handler import BackendHandler from onnx_tf.handlers.handler import onnx_op from onnx_tf.handlers.handler import tf_func - +from onnx_tf.common.tf_helper import tf_shape @onnx_op("InstanceNormalization") @tf_func(tf.nn.batch_normalization) @@ -31,7 +31,7 @@ def _common(cls, node, **kwargs): beta = tensor_dict[node.inputs[2]] inputs = tensor_dict[node.inputs[0]] - inputs_shape = inputs.shape + inputs_shape = tf_shape(inputs) inputs_rank = inputs.shape.ndims moments_axes = list(range(inputs_rank))[2:]