Commit ab3ac33

Merge pull request #81 from disktnk/fix/out-scope-param

Fix to accept parameters out of init scope

disktnk authored Feb 12, 2019
2 parents 3f55ff1 + a1e85d6 commit ab3ac33
Showing 2 changed files with 45 additions and 12 deletions.
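
"Out of init scope" refers to a chainer.Parameter assigned to a Link outside of `with self.init_scope():`. Such a parameter is not registered with the Link, so `model.params()` does not yield it, yet it still appears in the computational graph and must be exported as an initializer. A minimal sketch of the failing pattern (the same shape as the new test added below):

    import chainer
    import numpy as np

    class Model(chainer.Chain):

        def __init__(self):
            super(Model, self).__init__()
            # Assigned outside `with self.init_scope():`, so it is not
            # registered and model.params() will not report it.
            self.frac = chainer.Parameter(np.array(2, dtype=np.float32))

        def __call__(self, x):
            return x / self.frac

    assert len(list(Model().params())) == 0

Before this fix, the exporter collected initializers only from model.params(), so a parameter like `frac` never reached the ONNX graph inputs.
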
35 changes: 23 additions & 12 deletions onnx_chainer/export.py
@@ -43,8 +43,6 @@ def convert_parameter(parameter):
             'The type of parameter is unknown. It should be either Parameter '
             'or Variable or ndarray, but the type was {}.'.format(
                 type(parameter)))
-    if array.shape == ():
-        array = array[None]
     return numpy_helper.from_array(array, str(id(parameter)))
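
The removed branch expanded rank-0 arrays to shape (1,). It appears unnecessary now that the value-info entries built later take their dims directly from the TensorProto: onnx.numpy_helper.from_array preserves a scalar's empty shape, so initializer and graph input stay consistent. A quick check (not part of the PR):

    import numpy as np
    from onnx import numpy_helper

    tensor = numpy_helper.from_array(np.array(2, dtype=np.float32), 'frac')
    print(tensor.dims)       # []  (rank-0)
    print(tensor.data_type)  # 1   (TensorProto.FLOAT)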


@@ -97,6 +95,7 @@ class ONNXExport(chainer.FunctionHook):

     def __init__(self, opset_version=None):
         self.graph = []
+        self.inputs = {}
         self.additional_parameters = []
         self.middle_output_var_to_varnode = {}
         self.specified_opset_version = opset_version
@@ -110,9 +109,11 @@ def backward_postprocess(self, function, in_data, out_grad):
             # 'i' is a VariableNode, so check if it has a Variable/Parameter
             var = i.get_variable_or_none()
             if var is None:  # No reference to Variable/Parameter
-                input_names.append(str(id(i)))  # Use VariableNode as is
+                input_name = str(id(i))  # Use VariableNode as is
             else:  # It is a parameter inside a Link or network input
-                input_names.append(str(id(var)))
+                input_name = str(id(var))
+                self.inputs[input_name] = var
+            input_names.append(input_name)

         # This is to get corresponding VariableNode id from the output
         # Variable of the network
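
The hook now records every Variable/Parameter it encounters during the backward pass, keyed by the same stringified object id that serves as the graph-level name. A toy illustration of the scheme (not from the PR):

    import chainer
    import numpy as np

    v = chainer.Variable(np.zeros((2,), dtype=np.float32))
    name = str(id(v))    # unique per live object; used as the ONNX name
    inputs = {name: v}   # what ONNXExport.backward_postprocess accumulates
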
@@ -232,14 +233,17 @@ def export(model, args, filename=None, export_params=True,

     initializers = []
     input_tensors = []
+    param_names = set()
     for param in model.params():
-        initializers.append(convert_parameter(param))
-        param_shape = (1,) if param.shape == () else param.shape
+        param_names.add(str(id(param)))
+        tensor = convert_parameter(param)
+        initializers.append(tensor)
         input_tensors.append(helper.make_tensor_value_info(
-            str(id(param)), NP_TYPE_TO_TENSOR_TYPE[param.array.dtype],
-            param_shape))
+            str(id(param)), tensor.data_type, tensor.dims))

+    network_input_names = set()
     for i in network_inputs:
+        network_input_names.add(str(id(i)))
         input_tensors.append(helper.make_tensor_value_info(
             str(id(i)), NP_TYPE_TO_TENSOR_TYPE[i.dtype], i.shape))
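
Deriving the value info from the TensorProto itself (tensor.data_type, tensor.dims) guarantees that each graph input matches its initializer exactly, including rank-0 parameters. Roughly equivalent to this illustrative snippet (not from the PR):

    import numpy as np
    from onnx import helper, numpy_helper

    t = numpy_helper.from_array(np.zeros((3, 2), dtype=np.float32), 'W')
    vi = helper.make_tensor_value_info(t.name, t.data_type, t.dims)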

@@ -259,14 +263,21 @@ def export(model, args, filename=None, export_params=True,
         outputs.grad = numpy.ones_like(outputs.array)
         outputs.backward()

+    implicit_input_names = set(o.inputs.keys()) - param_names -\
+        network_input_names
+    for name in implicit_input_names:
+        tensor = convert_parameter(o.inputs[name])
+        initializers.append(tensor)
+        input_tensors.append(helper.make_tensor_value_info(
+            name, tensor.data_type, tensor.dims))
+
     # If additional parameters are created during conversion
     if o.additional_parameters:
         for param in o.additional_parameters:
-            initializers.append(convert_parameter(param))
-            param_shape = (1,) if param.shape == () else param.shape
+            tensor = convert_parameter(param)
+            initializers.append(tensor)
             input_tensors.append(helper.make_tensor_value_info(
-                str(id(param)), NP_TYPE_TO_TENSOR_TYPE[param.array.dtype],
-                param_shape))
+                str(id(param)), tensor.data_type, tensor.dims))

     # The graph must be topologically sorted
     graph = reversed(o.graph)
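
After the backward pass, any input the hook recorded that is neither a registered parameter nor a network input must be an out-of-scope parameter, and it gets exported as an initializer. A toy walk-through with made-up ids:

    recorded = {'4401': 'frac', '4402': 'x'}  # o.inputs after backward
    param_names = set()                       # model.params() saw nothing
    network_input_names = {'4402'}            # the argument x
    implicit = set(recorded) - param_names - network_input_names
    print(implicit)  # {'4401'} -> converted and appended to initializers
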
22 changes: 22 additions & 0 deletions tests/test_inout.py
@@ -45,3 +45,25 @@ def test_variable_dicts(self):
         ins = {arg_names[i]: chainer.Variable(v)
                for i, v in enumerate(self.ins)}
         test_onnxruntime.check_output(self.model, ins, self.fn)
+
+
+class TestImplicitInput(unittest.TestCase):
+
+    def setUp(self):
+
+        class Model(chainer.Chain):
+
+            def __init__(self):
+                super(Model, self).__init__()
+
+                self.frac = chainer.Parameter(np.array(2, dtype=np.float32))
+
+            def __call__(self, x):
+                return x / self.frac
+
+        self.model = Model()
+        self.fn = 'ImplicitInput.onnx'
+
+    def test_implicit_input(self):
+        x = chainer.Variable(np.array(1, dtype=np.float32))
+        test_onnxruntime.check_output(self.model, x, self.fn)
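
With this fix, exporting such a model works end to end. A usage sketch built on the export signature shown above (the model class is the one defined in setUp; the filename is arbitrary):

    import numpy as np
    import onnx_chainer

    model = Model()  # the Model defined in setUp above
    x = np.array(1, dtype=np.float32)
    onnx_chainer.export(model, x, filename='ImplicitInput.onnx')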
