From fd802722c9fdb24b4ab3ad5033406553381db5ab Mon Sep 17 00:00:00 2001
From: Emilio Castillo
Date: Fri, 19 Jun 2020 05:07:45 +0000
Subject: [PATCH] Add `chainer_torch_function` to call torch functions inside
 Chainer

---
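`chainer_torch_function` wraps a torch callable in a `chainer.FunctionNode`
so a torch op can sit inside a Chainer graph, with gradients flowing back
through `torch.autograd.backward`. A minimal usage sketch (illustrative
only; the shapes and the choice of `torch.tanh` are arbitrary and not part
of this patch):

    import chainer
    import numpy
    import torch

    import chainer_pytorch_migration as cpm

    x = chainer.Variable(numpy.ones((2, 3), dtype=numpy.float32))
    # The forward pass converts the array to a tensor with cpm.astensor,
    # runs torch.tanh, and converts the result back with cpm.asarray.
    y = cpm.chainer_torch_function(torch.tanh, x)
    y.grad = numpy.ones((2, 3), dtype=numpy.float32)
    y.backward()   # the gradient itself is computed by torch autograd
    print(x.grad)  # a numpy array: d tanh(x)/dx evaluated at x = 1
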
 chainer_pytorch_migration/__init__.py  |  1 +
 chainer_pytorch_migration/functions.py | 68 ++++++++++++++++++++++++++
 tests/test_functions.py                | 49 +++++++++++++++++++
 3 files changed, 118 insertions(+)
 create mode 100644 chainer_pytorch_migration/functions.py
 create mode 100644 tests/test_functions.py

diff --git a/chainer_pytorch_migration/__init__.py b/chainer_pytorch_migration/__init__.py
index 546ccd4..0e07c34 100644
--- a/chainer_pytorch_migration/__init__.py
+++ b/chainer_pytorch_migration/__init__.py
@@ -1,6 +1,7 @@
 from . import links
 from .allocator import use_mempool_in_cupy_malloc, use_torch_in_cupy_malloc
 from .datasets import TransformDataset
+from .functions import chainer_torch_function
 from .links import TorchModule
 from .parameter import ChainerParameter, LinkAsTorchModel, Optimizer
 from .tensor import asarray, astensor, to_numpy_dtype
diff --git a/chainer_pytorch_migration/functions.py b/chainer_pytorch_migration/functions.py
new file mode 100644
index 0000000..9aff504
--- /dev/null
+++ b/chainer_pytorch_migration/functions.py
@@ -0,0 +1,68 @@
+import chainer
+import torch
+
+import chainer_pytorch_migration as cpm
+
+
+class _ChainerTorchFunction(chainer.FunctionNode):
+    def __init__(self, torch_fn, *args, **kwargs):
+        self.torch_fn = torch_fn
+        self.torch_fwd_inputs = None
+        self.torch_fwd_outputs = None
+        self.args = args
+        self.kwargs = kwargs
+
+    def forward(self, inputs):
+        t_inputs = [cpm.astensor(x) for x in inputs]
+        for t_x in t_inputs:
+            t_x.requires_grad = True
+        self.torch_fwd_inputs = t_inputs
+        # The torch function may need extra arguments besides the input
+        # tensors, so append them here.
+        f_inputs = t_inputs + list(self.args)
+        t_outs = self.torch_fn(*f_inputs, **self.kwargs)
+        if not isinstance(t_outs, (list, tuple)):
+            t_outs = (t_outs,)
+        self.torch_fwd_outputs = t_outs
+        # Convert the results back so they are visible to Chainer.
+        c_outs = tuple(cpm.asarray(out) for out in t_outs)
+        # The outputs are needed again in the backward pass, so retain
+        # them here.
+        self.retain_outputs(tuple(range(len(c_outs))))
+        return c_outs
+
+    def backward(self, indexes, grads):
+        out_grads = _ChainerTorchFunctionGrad(
+            self.torch_fwd_inputs, self.torch_fwd_outputs
+        ).apply(grads)
+        return out_grads
+
+
+class _ChainerTorchFunctionGrad(chainer.FunctionNode):
+    def __init__(self, inputs, outputs):
+        super(_ChainerTorchFunctionGrad, self).__init__()
+        self.inputs = inputs
+        self.outputs = outputs
+
+    def forward(self, inputs):
+        t_grads = tuple(cpm.astensor(g) for g in inputs)
+        torch.autograd.backward(self.outputs, t_grads)
+        out_grads = tuple(
+            cpm.asarray(x.grad) for x in self.inputs
+        )
+        self.outputs = [x.grad for x in self.inputs]  # for double backward
+        self.inputs = t_grads
+        return out_grads
+
+    def backward(self, indexes, grads):
+        return _ChainerTorchFunctionGrad(
+            self.inputs, self.outputs).apply(grads)
+
+
+def chainer_torch_function(torch_fn, inputs, *args, **kwargs):
+    if not isinstance(inputs, (list, tuple)):
+        inputs = (inputs,)
+    y = _ChainerTorchFunction(torch_fn, *args, **kwargs).apply(inputs)
+    if len(y) == 1:
+        return y[0]
+    return y
diff --git a/tests/test_functions.py b/tests/test_functions.py
new file mode 100644
index 0000000..ac3fa86
--- /dev/null
+++ b/tests/test_functions.py
@@ -0,0 +1,49 @@
+import chainer
+import numpy
+import torch
+
+import chainer_pytorch_migration as cpm
+
+
+def test_one_output():
+    torch_fn = torch.sigmoid
+    x = chainer.Variable(numpy.ones((5, 5), dtype=numpy.float32))
+    z = chainer.functions.sin(x)
+    res = cpm.chainer_torch_function(torch_fn, z)
+    res = chainer.functions.sqrt(res)
+    res = cpm.chainer_torch_function(torch_fn, res)
+    res = chainer.functions.sqrt(res)
+    res.grad = numpy.ones((5, 5), dtype=numpy.float32)
+    res.backward()
+    c_grad = x.grad
+
+    # Now run the same computation in PyTorch and compare gradients.
+    x = torch.ones((5, 5), requires_grad=True)
+    z = torch.sin(x)
+    y = torch.sigmoid(torch.sigmoid(z).sqrt()).sqrt()
+    y.backward(torch.ones(5, 5))
+    t_grad = x.grad
+    assert torch.allclose(t_grad, cpm.astensor(c_grad))
+
+
+def test_multiple_outputs():
+    torch_fn = torch.split
+    x = chainer.Variable(numpy.ones((6, 5), dtype=numpy.float32))
+    y = chainer.functions.sin(x)
+    y, z = cpm.chainer_torch_function(torch_fn, y, 3, dim=0)
+    y = chainer.functions.log(y)
+    z = chainer.functions.cos(z)
+    z = y + z
+    z.grad = numpy.ones((3, 5), dtype=numpy.float32)
+    z.backward()
+    c_grad = x.grad
+
+    x = torch.ones((6, 5), requires_grad=True)
+    z = torch.sin(x)
+    y, z = torch.split(z, 3, dim=0)
+    y = torch.log(y)
+    z = torch.cos(z)
+    z = y + z
+    z.backward(torch.ones((3, 5)))
+    t_grad = x.grad
+    assert torch.allclose(t_grad, cpm.astensor(c_grad))
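
Positional and keyword arguments given after the inputs are forwarded to
the torch callable unchanged (`f_inputs = t_inputs + list(self.args)` and
`**self.kwargs` above), so parameterized torch functions also work. A short
sketch, again illustrative rather than part of the patch, using
`torch.clamp` as the wrapped function:

    import chainer
    import numpy
    import torch

    import chainer_pytorch_migration as cpm

    x = chainer.Variable(numpy.linspace(-2, 2, 5, dtype=numpy.float32))
    # min and max are forwarded to torch.clamp through **kwargs.
    y = cpm.chainer_torch_function(torch.clamp, x, min=-1.0, max=1.0)
    y.grad = numpy.ones(5, dtype=numpy.float32)
    y.backward()
    print(x.grad)  # 0 where x was clamped, 1 elsewhere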