Added chainer_torch_functions to call torch functions inside chainer

Emilio Castillo committed Jun 19, 2020
1 parent e953560 · commit b7c0d80
Showing 3 changed files with 118 additions and 0 deletions.
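A quick usage sketch of the new helper, adapted from the tests added in this commit (it assumes chainer_pytorch_migration exposes chainer_torch_function at the top level and is imported as cpm):

import chainer
import numpy
import torch

import chainer_pytorch_migration as cpm

# Call a torch function on a chainer.Variable and mix it freely with
# chainer functions; backward() propagates through both frameworks.
x = chainer.Variable(numpy.ones((5, 5), dtype=numpy.float32))
y = cpm.chainer_torch_function(torch.sigmoid, x)
y = chainer.functions.sqrt(y)
y.grad = numpy.ones((5, 5), dtype=numpy.float32)
y.backward()
print(x.grad)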
@@ -0,0 +1,68 @@

import chainer
import torch

import chainer_pytorch_migration as cpm


class _ChainerTorchFunction(chainer.FunctionNode):
    def __init__(self, torch_fn, *args, **kwargs):
        self.torch_fn = torch_fn
        self.torch_fwd_inputs = None
        self.torch_fwd_outputs = None
        self.args = args
        self.kwargs = kwargs

    def forward(self, inputs):
        t_inputs = [cpm.astensor(x) for x in inputs]
        for t_x in t_inputs:
            t_x.requires_grad = True
        self.torch_fwd_inputs = t_inputs
        # The torch function may require extra arguments besides the input
        # tensors, so append them here.
        f_inputs = t_inputs + list(self.args)
        t_outs = self.torch_fn(*f_inputs, **self.kwargs)
        if type(t_outs) is not list and type(t_outs) is not tuple:
            t_outs = (t_outs,)
        self.torch_fwd_outputs = t_outs
        # The results need to be accessible from the chainer side.
        c_outs = tuple(cpm.asarray(out) for out in t_outs)
        # The outputs are needed in the backward pass, so retain them.
        self.retain_outputs(tuple(range(len(c_outs))))
        return c_outs

    def backward(self, indexes, grads):
        out_grads = _ChainerTorchFunctionGrad(
            self.torch_fwd_inputs, self.torch_fwd_outputs
        ).apply(grads)
        return out_grads


class _ChainerTorchFunctionGrad(chainer.FunctionNode):
    def __init__(self, inputs, outputs):
        super(_ChainerTorchFunctionGrad, self).__init__()
        self.inputs = inputs
        self.outputs = outputs

    def forward(self, inputs):
        t_grads = tuple([cpm.astensor(g) for g in inputs])
        torch.autograd.backward(self.outputs, t_grads)
        out_grads = tuple(
            cpm.asarray(x.grad) for x in self.inputs
        )
        self.outputs = [x.grad for x in self.inputs]
        self.inputs = t_grads
        return out_grads

    def backward(self, indexes, grads):
        return _ChainerTorchFunctionGrad(
            self.inputs, self.outputs).apply(grads)


def chainer_torch_function(torch_fn, inputs, *args, **kwargs):
    if type(inputs) is not list and type(inputs) is not tuple:
        inputs = (inputs,)
    y = _ChainerTorchFunction(torch_fn, *args, **kwargs).apply(inputs)
    if len(y) == 1:
        return y[0]
    return y
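The wrapper forwards any extra positional and keyword arguments, placed after the input variables, straight to the wrapped torch function. A minimal sketch of that pass-through (hypothetical usage, not part of this commit, with torch.cumsum chosen only for illustration):

import chainer
import numpy
import torch

import chainer_pytorch_migration as cpm

x = chainer.Variable(numpy.arange(6, dtype=numpy.float32).reshape(2, 3))
# dim=1 is forwarded to torch.cumsum through *args/**kwargs of the wrapper.
y = cpm.chainer_torch_function(torch.cumsum, x, dim=1)
y.grad = numpy.ones((2, 3), dtype=numpy.float32)
y.backward()
print(x.grad)  # gradient of the torch op, seen from the chainer side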
@@ -0,0 +1,49 @@

import chainer
import numpy
import torch

import chainer_pytorch_migration as cpm


def test_one_output():
    torch_fn = torch.sigmoid
    x = chainer.Variable(numpy.ones((5, 5), dtype=numpy.float32))
    z = chainer.functions.sin(x)
    res = cpm.chainer_torch_function(torch_fn, z)
    res = chainer.functions.sqrt(res)
    res = cpm.chainer_torch_function(torch_fn, res)
    res = chainer.functions.sqrt(res)
    res.grad = numpy.ones((5, 5), dtype=numpy.float32)
    res.backward()
    c_grad = x.grad

    # Do the same computation in PyTorch and compare the gradients
    x = torch.ones((5, 5), requires_grad=True)
    z = torch.sin(x)
    y = torch.sigmoid(torch.sigmoid(z).sqrt()).sqrt()
    y.backward(torch.ones(5, 5))
    t_grad = x.grad
    assert torch.allclose(t_grad, cpm.astensor(c_grad))


def test_multiple_outputs():
    torch_fn = torch.split
    x = chainer.Variable(numpy.ones((6, 5), dtype=numpy.float32))
    y = chainer.functions.sin(x)
    y, z = cpm.chainer_torch_function(torch_fn, y, 3, dim=0)
    y = chainer.functions.log(y)
    z = chainer.functions.cos(z)
    z = y + z
    z.grad = numpy.ones((3, 5), dtype=numpy.float32)
    z.backward()
    c_grad = x.grad

    x = torch.ones((6, 5), requires_grad=True)
    z = torch.sin(x)
    y, z = torch.split(z, 3, dim=0)
    y = torch.log(y)
    z = torch.cos(z)
    z = y + z
    z.backward(torch.ones((3, 5)))
    t_grad = x.grad
    assert torch.allclose(t_grad, cpm.astensor(c_grad))