Add chainer_torch_function and TorchChainerFunction #27

Open · wants to merge 2 commits into master

Changes from 1 commit
Added the reverse: call a Chainer function from torch
Emilio Castillo committed Jul 22, 2020
commit dd4d223c43adf8652b977ffd73f207549de50b40
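
For context, a minimal sketch of how the new class is meant to be used, mirroring the tests added below (cpm is the chainer_pytorch_migration package):

import chainer
import torch

import chainer_pytorch_migration as cpm


# Subclass TorchChainerFunction and override chainer_fn() to return the
# Chainer function to run; apply() then inserts it into a torch graph.
class TorchChainerSigmoid(cpm.functions.TorchChainerFunction):
    @staticmethod
    def chainer_fn():
        return chainer.functions.sigmoid


x = torch.ones(10)
x.requires_grad = True
y = TorchChainerSigmoid.apply(x)   # chainer's sigmoid inside torch autograd
y.sum().backward()                 # gradients flow back through the chainer op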
55 changes: 55 additions & 0 deletions chainer_pytorch_migration/functions.py
@@ -66,3 +66,58 @@ def chainer_torch_function(torch_fn, inputs, *args, **kwargs):
if len(y) == 1:
return y[0]
return y


class TorchChainerFunction(torch.autograd.Function):
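    # Base class: subclasses override chainer_fn() to return the Chainer
    # function that should run inside the torch autograd graph.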
@staticmethod
def chainer_fn():
        raise RuntimeError('chainer_fn function must be overridden')

@classmethod
def forward(cls, ctx, *inputs):
chainer_fn = cls.chainer_fn()
ctx.save_for_backward(*inputs)
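        # Wrap the incoming torch tensors as Chainer variables.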
        c_inputs = tuple(chainer.Variable(cpm.asarray(x)) for x in inputs)
ctx.c_inputs = c_inputs
c_outputs = chainer_fn(*c_inputs)
        if not isinstance(c_outputs, tuple):
c_outputs = (c_outputs,)
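        # Convert the Chainer outputs back to torch tensors.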
t_outputs = [cpm.astensor(y.array) for y in c_outputs]
for t_y in t_outputs:
t_y.requires_grad = True
ctx.c_outputs = c_outputs
if len(t_outputs) == 1:
return t_outputs[0]
else:
return tuple(t_outputs)

@staticmethod
def backward(ctx, *grads):
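        # Delegate to a second autograd.Function so the backward pass is
        # itself differentiable (double backprop).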
grads = [ctx.c_outputs, ctx.c_inputs] + list(grads)
out_grads = _TorchChainerFunctionGrad.apply(*grads)
return out_grads


class _TorchChainerFunctionGrad(torch.autograd.Function):

@staticmethod
def forward(ctx, *inputs):
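        # inputs[0] / inputs[1] are the Chainer outputs / inputs saved by
        # the caller; the remaining entries are the incoming torch grads.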
c_outputs = inputs[0]
c_inputs = inputs[1]
inputs = inputs[2:]
ctx.save_for_backward(*inputs)
        c_grads = tuple(chainer.Variable(cpm.asarray(g)) for g in inputs)
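        # Run Chainer's backward with double backprop enabled so that
        # higher-order gradients stay available.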
        chainer.backward(list(c_outputs), c_grads, enable_double_backprop=True)
out_grads = tuple(
cpm.astensor(x.grad) for x in c_inputs
)
for t_y in out_grads:
t_y.requires_grad = True
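        # Save the Chainer grads so a further backward pass can chain
        # through this function again.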
ctx.c_outputs = [x.grad for x in c_inputs]
ctx.c_inputs = c_grads
return out_grads

    @staticmethod
    def backward(ctx, *grads):
grads = [ctx.c_outputs, ctx.c_inputs] + list(grads)
return _TorchChainerFunctionGrad.apply(*grads)
56 changes: 56 additions & 0 deletions tests/test_functions.py
@@ -47,3 +47,59 @@ def test_multiple_outputs():
z.backward(torch.ones((3, 5)))
t_grad = x.grad
assert torch.allclose(t_grad, cpm.astensor(c_grad))


def test_torch_chainer_function():
class TorchChainerSigmoid(cpm.functions.TorchChainerFunction):
@staticmethod
def chainer_fn():
return chainer.functions.sigmoid
    # Combined torch + chainer graph
x = torch.ones(10)
x.requires_grad = True
y = torch.sin(x)
y = TorchChainerSigmoid.apply(y)
y = torch.sum(y)
y.backward()
ct_grad = x.grad

# All in torch
x = torch.ones(10)
x.requires_grad = True
y = torch.sin(x)
y = torch.sigmoid(y)
y = torch.sum(y)
y.backward()
assert torch.allclose(ct_grad, x.grad)


def test_torch_chainer_function_2():
class TorchChainerAdd(cpm.functions.TorchChainerFunction):
@staticmethod
def chainer_fn():
return chainer.functions.add
    # Combined torch + chainer graph
a = torch.ones(10)
a.requires_grad = True
    b = torch.ones(10) + 2
b.requires_grad = True
y = torch.sin(a)
z = torch.sin(b)
y = TorchChainerAdd.apply(y, z)
y = torch.sum(y)
y.backward()
a_ct_grad = a.grad
b_ct_grad = b.grad

# All in torch
a = torch.ones(10)
a.requires_grad = True
    b = torch.ones(10) + 2
b.requires_grad = True
y = torch.sin(a)
z = torch.sin(b)
y = torch.add(y, z)
y = torch.sum(y)
y.backward()
assert torch.allclose(a_ct_grad, a.grad)
assert torch.allclose(b_ct_grad, b.grad)