Skip to content

Commit

Permalink
Add model stepping test for Mnist (#734)
Browse files Browse the repository at this point in the history
* Add model stepping test for Mnist

Add model stepping test for Mnist using ONNX runtime. The
assumption is that ONNX runtime is installed and the mnist model
from ONNX model zoo is downloaded.

Signed-off-by: Chin Huang <[email protected]>

* add tensor_dict back in TFRep

Signed-off-by: Chin Huang <[email protected]>
  • Loading branch information
chinhuang007 authored and masakistan committed Sep 26, 2020
1 parent 3bc773c commit 7b27f5d
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 19 deletions.
105 changes: 105 additions & 0 deletions example/test_mnist_onnxruntime_stepping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import unittest
import numpy as np

import onnx
from onnx import helper
from onnx import TensorProto
import tensorflow as tf
import onnxruntime.backend as ort

import onnx_tf.backend as otf
from onnx_tf.common import data_type


def find_between(s, first, last):
  """Return the part of ``s`` spanning ``first`` through the following ``last``.

  Both delimiters are included in the returned text. The closing delimiter is
  searched for starting at the position where ``first`` was found, so an
  earlier, unrelated occurrence of ``last`` can no longer yield an inverted
  slice.

  :param s: text to search.
  :param first: opening delimiter.
  :param last: closing delimiter.
  :return: substring from the start of ``first`` to the end of ``last``, or
    an empty string when either delimiter cannot be found.
  """
  try:
    start = s.index(first)
    # Ignore any occurrence of the closing delimiter that precedes `first`.
    end = s.index(last, start) + len(last)
  except ValueError:
    return ""
  return s[start:end]


class TestMnistModel(unittest.TestCase):
  """Step through the MNIST model comparing onnx-tf against ONNX Runtime.

  The first output of every node is appended to the graph outputs so that
  both backends return each intermediate result, which is then compared
  layer by layer. Mismatches are printed rather than raised.
  """
  # Make sure the onnx file path is correct, assuming copied to the
  # current directory
  model_path = 'mnist-8.onnx'

  def test(self):
    """Run one MNIST sample through both backends and report per-node results.

    For each checked output, the flattened values of both backends are written
    to ``<name>.rt`` and ``<name>.tf`` (with "/" replaced by "__") in the
    current directory before the comparison is made.
    """
    model = onnx.load(self.model_path)
    graph = model.graph
    print("Total node count in model: ", len(graph.node))

    # The input tensors could be provided as constants
    # The example below illustrates such a dictionary could be
    # provided for models with unknown input shapes. Since
    # mnist has known input shape, we don't provide input tensors.
    # input_tensors = {'Input3': tf.constant(0, dtype = tf.float32,
    #                                        name='Input3',
    #                                        shape=[1, 1, 28, 28])}
    input_tensors = {}
    tensor_dict = otf.prepare(model,
                              gen_tensor_dict=True,
                              input_tensor_dict=input_tensors).tensor_dict

    more_outputs = []
    output_to_check = []
    for node in graph.node:
      # add the first output of each node to the model output
      first_output = node.output[0]
      output_info = None
      # No early exit: as before, the last matching entry wins.
      for value_info in graph.value_info:
        if value_info.name == first_output:
          output_info = value_info
      for initializer in graph.initializer:
        if initializer.name == first_output:
          output_info = initializer

      # assume the first output is a tensor
      tensor = tensor_dict[first_output]
      if output_info is None:
        # Not described in the model: derive the value info from the traced
        # TensorFlow tensor instead.
        output_info = helper.make_tensor_value_info(
            first_output, data_type.tf2onnx(tensor.dtype), tensor.shape)
      more_outputs.append(output_info)
      output_to_check.append(first_output)
    graph.output.extend(more_outputs)

    tf_rep = otf.prepare(model)
    rt_rep = ort.prepare(model)

    # prepare input data; only the test split is needed, so the training
    # split is left untouched instead of being normalized and discarded.
    (_, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
    x_test = x_test / 255.0
    sample = x_test[:1].reshape(1, 1, 28, 28).astype(np.float32)

    inputs = [sample]
    my_out = tf_rep.run(inputs)
    rt_out = rt_rep.run(inputs)

    for op in output_to_check:
      expected = my_out[op]
      for i, tf_value in enumerate(my_out):
        # find the index of output in the list (identity, not equality)
        if expected is not tf_value:
          continue

        file_stem = op.replace("/", "__")
        try:
          np.savetxt(file_stem + ".rt", rt_out[i].flatten(), delimiter='\t')
          np.savetxt(file_stem + ".tf", tf_value.flatten(), delimiter='\t')
          np.testing.assert_allclose(tf_value, rt_out[i], rtol=1e-2)
          print(op, "results of this layer are correct within tolerence.")
        except Exception as e:
          # Best effort: report the mismatch and keep stepping through the
          # remaining layers instead of aborting the run.
          np.set_printoptions(threshold=np.inf)
          mismatch_percent = find_between(str(e), "(mismatch", "%)")
          print(op, "mismatch with percentage {} %".format(mismatch_percent))


if __name__ == '__main__':
  # Run the test case when this file is executed directly as a script.
  unittest.main()
68 changes: 51 additions & 17 deletions onnx_tf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ def prepare(cls,
super(TensorflowBackend, cls).prepare(model, device, **kwargs)
common.logger.setLevel(logging_level)
common.logger.handlers[0].setLevel(logging_level)
common.sys_config.auto_cast=auto_cast
common.sys_config.auto_cast = auto_cast

return cls.onnx_model_to_tensorflow_rep(model, strict)
return cls.onnx_model_to_tensorflow_rep(model, strict, **kwargs)

@classmethod
def onnx_model_to_tensorflow_rep(cls, model, strict):
def onnx_model_to_tensorflow_rep(cls, model, strict, **kwargs):
""" Convert ONNX model to TensorflowRep.
:param model: ONNX ModelProto object.
Expand All @@ -86,45 +86,68 @@ def onnx_model_to_tensorflow_rep(cls, model, strict):
opset_import = [make_opsetid(defs.ONNX_DOMAIN, 1)]
else:
opset_import = model.opset_import
return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict)
return cls._onnx_graph_to_tensorflow_rep(model.graph, opset_import, strict,
**kwargs)

@classmethod
def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
""" Convert ONNX graph to TensorflowRep.
:param graph_def: ONNX GraphProto object.
:param opset: ONNX OperatorSetIdProto list.
:param strict: whether to enforce semantic equivalence between the original model
and the converted tensorflow model.
:kwargs: additional arguments to generate tensor_dict for model debugging
:return: TensorflowRep object.
"""
# To generate tensor_dict or not, default is False
gen_tensor_dict = kwargs[
'gen_tensor_dict'] if 'gen_tensor_dict' in kwargs else False
# User provided input tensors, in the case the model inputs have unknown shapes
input_tensor_dict = kwargs[
'input_tensor_dict'] if 'input_tensor_dict' in kwargs else dict()

handlers = cls._get_handlers(opset)

# initializer: TensorProtos representing the values to initialize
# a given tensor.
# initialized: A list of names of the initialized tensors.

if graph_def.initializer:
input_dict_items = cls._onnx_initializer_to_input_dict_items(
graph_def.initializer)
initialized = {init.name for init in graph_def.initializer}
else:
input_dict_items = []
initialized = set()

module = BackendTFModule(handlers, opset, strict, graph_def, cls)
signatures = dict()

for value_info in graph_def.input:
if value_info.name in initialized:
continue
shape = list(
d.dim_value if (d.dim_value > 0 and d.dim_param == "") else None
for d in value_info.type.tensor_type.shape.dim)
value_info_name = value_info.name.replace(
":", "_tf_") + "_" + get_unique_suffix(
) if ":" in value_info.name else value_info.name
":", "_tf_") + "_" + get_unique_suffix(
) if ":" in value_info.name else value_info.name

tf_spec = tf.TensorSpec(shape, data_type.onnx2tf(value_info.type.tensor_type.elem_type), value_info_name)
tf_spec = tf.TensorSpec(
shape, data_type.onnx2tf(value_info.type.tensor_type.elem_type),
value_info_name)
signatures[value_info.name] = tf_spec

if gen_tensor_dict:
x = tf.constant(
0,
dtype=data_type.onnx2tf(value_info.type.tensor_type.elem_type),
name=value_info_name,
shape=shape
) if value_info.name not in input_tensor_dict else input_tensor_dict[
value_info.name]
input_dict_items.append((value_info_name, x))

tf_rep = TensorflowRep()
tf_rep.inputs = [
value_info.name
Expand All @@ -135,6 +158,9 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
module.outputs = tf_rep.outputs
tf_rep.tf_module = module
tf_rep.signatures = signatures
tf_rep.tensor_dict = module.gen_tensor_dict(
input_dict_items) if gen_tensor_dict else None

return tf_rep

@classmethod
Expand All @@ -148,7 +174,9 @@ def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
:param kwargs: Other args.
:return: Outputs.
"""

class TFModule(tf.Module):

def __init__(self, node):
super(TFModule, self).__init__()
self.node = node
Expand All @@ -171,13 +199,16 @@ def __call__(self, **input_dict):
feed_dict_raw = dict(zip(node.inputs, inputs))

# TODO: is constant the best way for feeding inputs?
input_dict = dict(
[(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()])
input_dict = dict([(x[0], tf.constant(x[1])) for x in feed_dict_raw.items()
])

module = TFModule(node)

output_vals = module(**input_dict)
output_vals = [val.numpy() if isinstance(val, tf.Tensor) else val for val in output_vals]
output_vals = [
val.numpy() if isinstance(val, tf.Tensor) else val
for val in output_vals
]

return namedtupledict('Outputs', node.outputs)(*output_vals)

Expand Down Expand Up @@ -231,11 +262,13 @@ def _onnx_node_to_tensorflow_op(cls,
"""
handlers = handlers or cls._get_handlers(opset)
if handlers:
handler = handlers[node.domain].get(node.op_type, None) if node.domain in handlers else None
handler = handlers[node.domain].get(
node.op_type, None) if node.domain in handlers else None
if handler:
return handler.handle(node, tensor_dict=tensor_dict, strict=strict)

raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format(node.op_type))
raise BackendIsNotSupposedToImplementIt("{} is not implemented.".format(
node.op_type))

@classmethod
def _get_handlers(cls, opset):
Expand Down Expand Up @@ -293,7 +326,8 @@ def onnx_graph_to_tensorflow_ops(cls,
nodes_outputs.append(o_name)
for node in subgraph.node:
for i_name in node.input:
if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys():
if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys(
):
subgraph_tensor_dict[i_name] = tensor_dict[i_name]
onnx_node = OnnxNode(node)
output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
Expand All @@ -305,7 +339,7 @@ def onnx_graph_to_tensorflow_ops(cls,
return subgraph_tensor_dict

@classmethod
def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True):
def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True, **kwargs):
"""
Converts ONNX graph to TensorflowRep
Args:
Expand All @@ -318,7 +352,7 @@ def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True):
"""
# get the opset of the installed ONNX
opset = [make_opsetid(defs.ONNX_DOMAIN, defs.onnx_opset_version())]
return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict)
return cls._onnx_graph_to_tensorflow_rep(graph_def, opset, strict, **kwargs)


prepare = TensorflowBackend.prepare
Expand Down
9 changes: 9 additions & 0 deletions onnx_tf/backend_rep.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self, graph=None, inputs=None, outputs=None, tensor_dict=None):
self._inputs = inputs or []
self._outputs = outputs or []
self._tensor_dict = tensor_dict or {}
self._tf_module = None

@property
def graph(self):
Expand Down Expand Up @@ -49,6 +50,14 @@ def tensor_dict(self):
def tensor_dict(self, tensor_dict):
self._tensor_dict = tensor_dict

@property
def tf_module(self):
  """Return the module stored on this rep (None until one is assigned)."""
  return self._tf_module

@tf_module.setter
def tf_module(self, tf_module):
  # Stores the given object as-is; no validation is performed here.
  self._tf_module = tf_module

def run(self, inputs, **kwargs):
""" Run TensorflowRep.
Expand Down
24 changes: 22 additions & 2 deletions onnx_tf/backend_tf_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import tensorflow as tf
from onnx_tf.pb_wrapper import OnnxNode


class BackendTFModule(tf.Module):

def __init__(self, handlers, opset, strict, graph_def, backend):
Expand All @@ -12,6 +13,22 @@ def __init__(self, handlers, opset, strict, graph_def, backend):
self.backend = backend
self.outputs = []

@tf.function
def gen_tensor_dict(self, input_dict_items):
  """Collect every tensor produced while converting the graph, keyed by name.

  The mapping is seeded from ``input_dict_items`` and then extended node by
  node, in the order the nodes appear in ``self.graph_def.node``.

  :param input_dict_items: (name, tensor) pairs used to seed the mapping;
    anything accepted by ``dict()``.
  :return: dict from tensor name to the value produced for it.
  """
  name_to_tensor = dict(input_dict_items)

  for raw_node in self.graph_def.node:
    wrapped = OnnxNode(raw_node)
    produced = self.backend._onnx_node_to_tensorflow_op(
        wrapped,
        name_to_tensor,
        self.handlers,
        opset=self.opset,
        strict=self.strict)
    # Register each produced op under the matching declared output name.
    name_to_tensor.update(zip(wrapped.outputs, produced))

  return name_to_tensor

@tf.function
def __call__(self, **kwargs):
tensor_dict = kwargs
Expand All @@ -26,8 +43,11 @@ def __call__(self, **kwargs):

for node in self.graph_def.node:
onnx_node = OnnxNode(node)
output_ops = self.backend._onnx_node_to_tensorflow_op(
onnx_node, tensor_dict, self.handlers, opset=self.opset, strict=self.strict)
output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node,
tensor_dict,
self.handlers,
opset=self.opset,
strict=self.strict)
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
tensor_dict.update(curr_node_output_map)

Expand Down

0 comments on commit 7b27f5d

Please sign in to comment.