Merge pull request #21 from asi1024/module-to-device
Support `LinkAsTorchModule.to(device)`
emcastillo authored Mar 5, 2020
2 parents 01abf44 + 0970420 commit 566e681
Showing 2 changed files with 78 additions and 2 deletions.
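For readers unfamiliar with the API, here is a minimal usage sketch of what this change enables. The toy `chainer.links.Linear` link, the `cpm` import alias, and the `'cuda'` target are illustrative assumptions (a CUDA-capable environment is required); only `LinkAsTorchModel` and its new `to()` behavior come from this commit.

    import chainer
    import chainer_pytorch_migration as cpm

    # Wrap an initialized Chainer link as a torch.nn.Module.
    link = chainer.links.Linear(3, 2)   # in/out sizes given, so parameters are initialized
    module = cpm.LinkAsTorchModel(link)

    # New in this commit: .to() moves the underlying chainer.Parameters.
    module = module.to('cuda')          # returns the same module instance

    for name, param in module.named_parameters():
        assert param.device.type == 'cuda'
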
59 changes: 57 additions & 2 deletions chainer_pytorch_migration/parameter.py
@@ -17,6 +17,43 @@ def _named_params(link):
         yield name, getattr(link, name)
 
 
+# Corresponding to ``torch._C._nn._parse_to``.
+def _parse_to(*args, device=None, dtype=None, non_blocking=False):
+    args = list(args)
+
+    if len(args) > 0:
+        if isinstance(args[0], torch.Tensor):
+            tensor = args.pop(0)
+            device = tensor.device
+            dtype = tensor.dtype
+        elif isinstance(args[0], torch.dtype):
+            dtype = args.pop(0)
+        elif isinstance(args[0], (str, torch.device)):
+            device = args.pop(0)
+            if len(args) > 0 and isinstance(args[0], torch.dtype):
+                dtype = torch.dtype(args.pop(0))
+        else:
+            raise TypeError('Received an invalid combination of arguments.')
+
+    if len(args) > 0:
+        non_blocking = bool(args.pop(0))
+
+    if len(args) > 0:
+        raise TypeError('Received an invalid combination of arguments.')
+
+    if device is not None:
+        device = torch.device(device)
+
+    return device, dtype, non_blocking
+
+
+def _setattr_recursive(obj, name, value):
+    attr_list = name.split('.')
+    for attr in attr_list[:-1]:
+        obj = getattr(obj, attr)
+    setattr(obj, attr_list[-1], value)
+
+
 class LinkAsTorchModel(torch.nn.Module):
 
     '''Converts a Chainer Link to a PyTorch module.
@@ -29,8 +66,9 @@ class LinkAsTorchModel(torch.nn.Module):
         link (:class:`chainer.Link`): A link. Must have been initialized.
     '''
 
-    def __init__(self, link):
+    def __init__(self, link, **kwargs):
         super().__init__()
+        device = kwargs.pop('_device', None)
         uninitialized_params = [
             n for n, p in sorted(_named_params(link)) if p.array is None]
         if uninitialized_params:
@@ -43,9 +81,11 @@ def __init__(self, link):
                 ', '.join(repr(n) for n in uninitialized_params)))
 
         for name, child in _named_children(link):
-            child_module = LinkAsTorchModel(child)
+            child_module = LinkAsTorchModel(child, _device=device)
             setattr(self, name, child_module)
         for name, param in sorted(_named_params(link)):
+            if device is not None:
+                param.to_device(device)
             setattr(self, name, ChainerParameter(param))
 
         self.link = link
@@ -70,6 +110,21 @@ def __as_tensor(self, value):
             return _ChainerTensor(value)
         return value
 
+    def to(self, *args, **kwargs):
+        device, dtype, non_blocking = _parse_to(*args, **kwargs)
+        chainer_device = cpm.to_chainer_device(device)
+        if dtype is not None:
+            raise NotImplementedError
+        if non_blocking:
+            raise NotImplementedError
+        for name, value in self.named_parameters():
+            assert isinstance(value, ChainerParameter)
+            param = value._param
+            param.to_device(chainer_device)
+            value = ChainerParameter(param)
+            _setattr_recursive(self, name, value)
+        return self
+
 
 class Optimizer(torch.optim.Optimizer):
 
21 changes: 21 additions & 0 deletions tests/test_parameter.py
@@ -274,3 +274,24 @@ def test_named_params():
     numpy.testing.assert_array_equal(a_arr, n_params['a'].detach())
     assert 'b' in n_params
     numpy.testing.assert_array_equal(b_arr, n_params['b'].detach())
+
+
+def test_link_to_device():
+    a_arr = numpy.ones((3, 2), 'float32')
+    a_chainer_param = chainer.Parameter(a_arr)
+    # 0-size parameter
+    b_arr = numpy.ones((2, 0, 1), 'float32')
+    b_chainer_param = chainer.Parameter(b_arr)
+
+    link = chainer.Link()
+    with link.init_scope():
+        link.a = a_chainer_param
+        link.b = b_chainer_param
+
+    torched = cpm.LinkAsTorchModel(link)
+    ret = torched.to('cuda')
+
+    assert torched is ret
+
+    for name, param in torched.named_parameters():
+        assert param.device.type == 'cuda'
