Merge pull request #100 from disktnk/fix/multiple-outputs

Fix graph exporter when model has multiple outputs

disktnk authored Feb 20, 2019
2 parents b07d255 + e0aea68, commit 33f9434
Showing 4 changed files with 79 additions and 24 deletions.
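Before this fix, the exporter produced an incorrect graph for models that return more than one output. The sketch below is illustrative only (the chain and file name are made up, modeled on the new test at the bottom of this diff); it shows the kind of call this commit makes work:

import chainer
import chainer.functions as F
import numpy as np
import onnx_chainer

class TwoHeads(chainer.Chain):
    # Two outputs computed from a shared intermediate value.
    def __call__(self, x):
        h = F.tanh(x)
        return {'Tanh_0': h, 'Sigmoid_0': F.sigmoid(h)}

x = np.zeros((1, 3, 32, 32), dtype=np.float32)
onnx_chainer.export(TwoHeads(), x, filename='MultipleOutputs.onnx')
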
onnx_chainer/export.py: 27 changes (13 additions & 14 deletions)
@@ -140,9 +140,7 @@ def backward_postprocess(self, function, in_data, out_grad):
nodes = create_node(
func_name, onnx_op_name, opset_version, function, input_names,
output_names, self.additional_parameters)
for node in nodes:
if node not in self.graph:
self.graph.append(node)
self.graph.extend(nodes)


def export(model, args, filename=None, export_params=True,
@@ -209,19 +207,23 @@ def export(model, args, filename=None, export_params=True,
if isinstance(arg, chainer.get_array_types()):
args[i] = chainer.Variable(arg)
network_inputs.append(args[i])
flat_args = args
outputs = model(*args)
elif isinstance(args, dict):
for key, arg in args.items():
if isinstance(arg, chainer.get_array_types()):
args[key] = chainer.Variable(arg)
network_inputs.append(args[key])
flat_args = list(args.values())
outputs = model(**args)
elif isinstance(args, chainer.get_array_types()):
args = chainer.Variable(args)
network_inputs.append(args)
flat_args = [args]
outputs = model(args)
elif isinstance(args, chainer.Variable):
network_inputs.append(args)
flat_args = [args]
outputs = model(args)
else:
raise ValueError(
@@ -247,19 +249,16 @@

with ONNXExport(opset_version) as o:
if isinstance(outputs, (list, tuple)):
for output in outputs:
output.grad = model.xp.ones_like(
output.array, dtype=output.array.dtype)
output.backward()
flat_outputs = outputs
elif isinstance(outputs, dict):
outputs = list(outputs.values())
for output in outputs:
output.grad = model.xp.ones_like(
output.array, dtype=output.array.dtype)
output.backward()
flat_outputs = list(outputs.values())
elif isinstance(outputs, chainer.Variable):
outputs.grad = model.xp.ones_like(outputs.array)
outputs.backward()
flat_outputs = [outputs]
else:
raise RuntimeError(
'Unexpected output type from the model: {}'.format(
type(outputs)))
chainer.grad(flat_outputs, list(model.params()) + flat_args)

implicit_input_names = set(o.inputs.keys()) - param_names -\
network_input_names
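The substantive change in this file is at the end: instead of seeding output.grad and calling backward() once per output, the outputs are flattened and a single chainer.grad call differentiates all of them with respect to the model parameters and the flattened inputs, so the graph is traversed once rather than once per output while the ONNXExport hook (whose backward_postprocess is shown above) records the nodes. A standalone illustration of that call pattern, separate from the exporter (toy graph, names made up for the sketch):

import chainer
import chainer.functions as F
import numpy as np

x = chainer.Variable(np.random.rand(1, 3).astype(np.float32))
h = F.tanh(x)          # intermediate node shared by both outputs
o1 = F.sigmoid(h)
o2 = F.exp(h)

# One chainer.grad call covers both outputs in a single backward traversal,
# instead of one backward() per output over the shared subgraph.
grads = chainer.grad([o1, o2], [x])
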
onnx_chainer/testing/test_mxnet.py: 8 changes (4 additions & 4 deletions)
@@ -19,7 +19,7 @@
MXNET_AVAILABLE = False


def check_compatibility(model, x, fn, out_key='prob', opset_version=None):
def check_compatibility(model, x, fn, out_keys=None, opset_version=None):
if opset_version is None:
opset_version = onnx.defs.onnx_opset_version()
if not MXNET_AVAILABLE:
@@ -45,9 +45,9 @@ def check_compatibility(model, x, fn, out_key='prob', opset_version=None):
if isinstance(chainer_out, (list, tuple)):
chainer_out = [y.array for y in chainer_out]
elif isinstance(chainer_out, dict):
chainer_out = chainer_out[out_key]
if isinstance(chainer_out, chainer.Variable):
chainer_out = (chainer_out.array,)
chainer_outs = [chainer_out[k] for k in out_keys]
chainer_out = tuple(out.array if isinstance(out, chainer.Variable) else
out for out in chainer_outs)
elif isinstance(chainer_out, chainer.Variable):
chainer_out = (chainer_out.array,)
else:
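The helper's dict branch now selects outputs by an explicit key list instead of a single out_key, unwrapping each Variable to its array while preserving the order given by out_keys. The same extraction pattern in isolation (toy values, not taken from the diff):

import chainer
import numpy as np

chainer_out = {'prob_a': chainer.Variable(np.ones(3, dtype=np.float32)),
               'prob_b': np.zeros(3, dtype=np.float32)}
out_keys = ['prob_a', 'prob_b']
# Keep the caller-specified key order and unwrap Variables to raw arrays.
arrays = tuple(o.array if isinstance(o, chainer.Variable) else o
               for o in (chainer_out[k] for k in out_keys))
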
onnx_chainer/testing/test_onnxruntime.py: 18 changes (12 additions & 6 deletions)
@@ -23,7 +23,7 @@
TEST_OUT_DIR = 'out'


def check_output(model, x, filename, out_key='prob', opset_version=None):
def check_output(model, x, filename, out_keys=None, opset_version=None):
model.xp.random.seed(42)

os.makedirs(TEST_OUT_DIR, exist_ok=True)
@@ -61,12 +61,18 @@ def check_output(model, x, filename, out_key='prob', opset_version=None):
' chainer.Variable itself. But a {} object was given.'.format(
type(x)))

rt_out_keys = None
if isinstance(chainer_out, (list, tuple)):
chainer_out = (y.array for y in chainer_out)
chainer_out = tuple(y.array for y in chainer_out)
if out_keys is not None:
assert len(out_keys) == len(chainer_out)
rt_out_keys = out_keys
elif isinstance(chainer_out, dict):
chainer_out = chainer_out[out_key]
if isinstance(chainer_out, chainer.Variable):
chainer_out = (chainer_out.array,)
if len(out_keys) > 1:
rt_out_keys = out_keys
chainer_outs = [chainer_out[k] for k in out_keys]
chainer_out = tuple(out.array if isinstance(out, chainer.Variable) else
out for out in chainer_outs)
elif isinstance(chainer_out, chainer.Variable):
chainer_out = (chainer_out.array,)
else:
@@ -93,7 +99,7 @@ def check_output(model, x, filename, out_key='prob', opset_version=None):
assert list(sorted(input_names)) == list(sorted(graph_input_names))

rt_out = sess.run(
None, {name: array for name, array in zip(input_names, x_rt)})
rt_out_keys, {name: array for name, array in zip(input_names, x_rt)})

for cy, my in zip(chainer_out, rt_out):
np.testing.assert_allclose(cy, my, rtol=1e-5, atol=1e-5)
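The run call now passes rt_out_keys instead of None as its first argument; in onnxruntime that argument is the list of output names to fetch, so naming the outputs pins the order of the returned arrays to match chainer_out. A standalone sketch of that call, assuming an already exported model and the output names used by the new test below (the path and names are assumptions, not taken from this helper):

import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession('out/MultipleOutputs.onnx')  # assumed path under TEST_OUT_DIR
x = np.zeros((1, 3, 32, 32), dtype=np.float32)
feed = {graph_input.name: x for graph_input in sess.get_inputs()}
# Fetch both outputs in a fixed, explicit order.
tanh_out, sigmoid_out = sess.run(['Tanh_0', 'Sigmoid_0'], feed)
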
tests/test_inout.py: 50 changes (50 additions & 0 deletions)
@@ -3,6 +3,7 @@
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import testing
import numpy as np

from onnx_chainer.testing import input_generator
@@ -68,3 +69,52 @@ def __call__(self, x):
def test_implicit_input(self):
x = chainer.Variable(np.array(1, dtype=np.float32))
test_onnxruntime.check_output(self.model, x, self.fn)


@testing.parameterize(
{'use_bn': True, 'out_type': 'dict'},
{'use_bn': False, 'out_type': 'dict'},
{'use_bn': True, 'out_type': 'tuple'},
{'use_bn': True, 'out_type': 'list'},
)
class TestMultipleOutput(unittest.TestCase):

def get_model(self, use_bn=False, out_type=None):
class Model(chainer.Chain):

def __init__(self, use_bn=False, out_type=None):
super(Model, self).__init__()

self._use_bn = use_bn
self._out_type = out_type
with self.init_scope():
self.conv = L.Convolution2D(None, 32, ksize=3, stride=1)
if self._use_bn:
self.bn = L.BatchNormalization(32)

def __call__(self, x):
h = self.conv(x)
if self._use_bn:
h = self.bn(h)
o1 = F.tanh(h)
o2 = F.sigmoid(h)
if self._out_type == 'dict':
return {
'Tanh_0': o1,
'Sigmoid_0': o2
}
elif self._out_type == 'tuple':
return o1, o2
elif self._out_type == 'list':
return [o1, o2]

return Model(use_bn=use_bn, out_type=out_type)

def test_multiple_outputs(self):
model = self.get_model(use_bn=self.use_bn, out_type=self.out_type)
x = np.zeros((1, 3, 32, 32), dtype=np.float32)
# 'out_keys' is necessary even if self.out_type is tuple or list
# because onnxruntime does not guarantee the order of outputs when
# output keys are None.
test_onnxruntime.check_output(
model, x, 'MultipleOutputs.onnx', out_keys=['Tanh_0', 'Sigmoid_0'])
