Added chainer_torch_function to call torch functions inside chainer
Emilio Castillo committed Jun 19, 2020
1 parent e953560 commit b7c0d80
Showing 3 changed files with 118 additions and 0 deletions.
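
The new chainer_torch_function wraps a torch callable in a Chainer FunctionNode, so a torch function can be applied to Chainer variables and gradients flow back through the torch call. A minimal usage sketch, consistent with the tests added in this commit:

import chainer
import numpy
import torch

import chainer_pytorch_migration as cpm

x = chainer.Variable(numpy.ones((5, 5), dtype=numpy.float32))
# Apply torch.sigmoid to a Chainer variable; the result is a Chainer variable.
y = cpm.chainer_torch_function(torch.sigmoid, x)
y.grad = numpy.ones((5, 5), dtype=numpy.float32)
y.backward()  # The gradient is computed through the torch call.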
1 change: 1 addition & 0 deletions chainer_pytorch_migration/__init__.py
@@ -1,6 +1,7 @@
from . import links
from .allocator import use_mempool_in_cupy_malloc, use_torch_in_cupy_malloc
from .datasets import TransformDataset
from .functions import chainer_torch_function
from .links import TorchModule
from .parameter import ChainerParameter, LinkAsTorchModel, Optimizer
from .tensor import asarray, astensor, to_numpy_dtype
68 changes: 68 additions & 0 deletions chainer_pytorch_migration/functions.py
@@ -0,0 +1,68 @@
import chainer
import torch

import chainer_pytorch_migration as cpm


class _ChainerTorchFunction(chainer.FunctionNode):
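    """Runs a torch function as a Chainer FunctionNode.

    Inputs are converted with cpm.astensor so that torch records its graph,
    and backward is delegated to torch.autograd through a grad node.
    """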
    def __init__(self, torch_fn, *args, **kwargs):
        super(_ChainerTorchFunction, self).__init__()
        self.torch_fn = torch_fn
self.torch_fwd_inputs = None
self.torch_fwd_outputs = None
self.args = args
self.kwargs = kwargs

def forward(self, inputs):
t_inputs = [cpm.astensor(x) for x in inputs]
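        # Mark the torch inputs as requiring grad so that torch records
        # the operations performed by torch_fn.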
for t_x in t_inputs:
t_x.requires_grad = True
self.torch_fwd_inputs = t_inputs
        # The torch function may take extra arguments besides the input
        # tensors, so append them here.
        f_inputs = t_inputs + list(self.args)
t_outs = self.torch_fn(*f_inputs, **self.kwargs)
        if not isinstance(t_outs, (list, tuple)):
t_outs = (t_outs,)
self.torch_fwd_outputs = t_outs
        # Wrap the torch outputs so they can be used as Chainer variables.
c_outs = tuple(cpm.asarray(out) for out in t_outs)
        # Retain the outputs: they share memory with the torch tensors
        # that the backward pass needs.
self.retain_outputs(tuple(range(len(c_outs))))
return c_outs

def backward(self, indexes, grads):
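        # Gradients are computed in a separate FunctionNode so that Chainer
        # can differentiate the gradient itself (double backward).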
out_grads = _ChainerTorchFunctionGrad(
self.torch_fwd_inputs, self.torch_fwd_outputs
).apply(grads)
return out_grads


class _ChainerTorchFunctionGrad(chainer.FunctionNode):
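    """Gradient node for _ChainerTorchFunction.

    Drives torch.autograd.backward over the recorded torch graph and can be
    re-applied to support double backward.
    """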
def __init__(self, inputs, outputs):
super(_ChainerTorchFunctionGrad, self).__init__()
self.inputs = inputs
self.outputs = outputs

def forward(self, inputs):
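        # Convert the incoming Chainer gradients to torch tensors and run
        # torch's autograd over the recorded graph.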
        t_grads = tuple(cpm.astensor(g) for g in inputs)
torch.autograd.backward(self.outputs, t_grads)
out_grads = tuple(
cpm.asarray(x.grad) for x in self.inputs
)
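        # Swap inputs and outputs so that applying this node again
        # differentiates one order higher (double backward).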
self.outputs = [x.grad for x in self.inputs]
self.inputs = t_grads
return out_grads

def backward(self, indexes, grads):
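        # Chain another grad node so higher-order derivatives keep working.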
return _ChainerTorchFunctionGrad(
self.inputs, self.outputs).apply(grads)


def chainer_torch_function(torch_fn, inputs, *args, **kwargs):
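    # Accept a single variable as well as a list or tuple of variables.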
    if not isinstance(inputs, (list, tuple)):
inputs = (inputs,)
y = _ChainerTorchFunction(torch_fn, *args, **kwargs).apply(inputs)
if len(y) == 1:
return y[0]
return y
49 changes: 49 additions & 0 deletions tests/test_functions.py
@@ -0,0 +1,49 @@
import chainer
import numpy
import torch

import chainer_pytorch_migration as cpm


def test_one_output():
torch_fn = torch.sigmoid
x = chainer.Variable(numpy.ones((5, 5), dtype=numpy.float32))
z = chainer.functions.sin(x)
res = cpm.chainer_torch_function(torch_fn, z)
res = chainer.functions.sqrt(res)
res = cpm.chainer_torch_function(torch_fn, res)
res = chainer.functions.sqrt(res)
res.grad = numpy.ones((5, 5), dtype=numpy.float32)
res.backward()
c_grad = x.grad

    # Now do the same in PyTorch and compare.
x = torch.ones((5, 5), requires_grad=True)
z = torch.sin(x)
y = torch.sigmoid(torch.sigmoid(z).sqrt()).sqrt()
y.backward(torch.ones(5, 5))
t_grad = x.grad
assert torch.allclose(t_grad, cpm.astensor(c_grad))


def test_multiple_outputs():
torch_fn = torch.split
x = chainer.Variable(numpy.ones((6, 5), dtype=numpy.float32))
y = chainer.functions.sin(x)
y, z = cpm.chainer_torch_function(torch_fn, y, 3, dim=0)
y = chainer.functions.log(y)
z = chainer.functions.cos(z)
z = y + z
z.grad = numpy.ones((3, 5), dtype=numpy.float32)
z.backward()
c_grad = x.grad

x = torch.ones((6, 5), requires_grad=True)
z = torch.sin(x)
y, z = torch.split(z, 3, dim=0)
y = torch.log(y)
z = torch.cos(z)
z = y + z
z.backward(torch.ones((3, 5)))
t_grad = x.grad
assert torch.allclose(t_grad, cpm.astensor(c_grad))
