From e4375f301a4582637877f78e60729d18f6b35de9 Mon Sep 17 00:00:00 2001
From: Akifumi Imanishi
Date: Thu, 27 Feb 2020 09:58:52 +0000
Subject: [PATCH 1/4] Support LinkAsTorchModel.to(device)

---
 chainer_pytorch_migration/parameter.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/chainer_pytorch_migration/parameter.py b/chainer_pytorch_migration/parameter.py
index d783b00..89993ed 100644
--- a/chainer_pytorch_migration/parameter.py
+++ b/chainer_pytorch_migration/parameter.py
@@ -29,8 +29,9 @@ class LinkAsTorchModel(torch.nn.Module):
         link (:class:`chainer.Link`): A link. Must have been initialized.
     '''
 
-    def __init__(self, link):
+    def __init__(self, link, **kwargs):
         super().__init__()
+        device = kwargs.pop('_device', None)
         uninitialized_params = [
             n for n, p in sorted(_named_params(link)) if p.array is None]
         if uninitialized_params:
@@ -43,9 +44,11 @@ def __init__(self, link):
                 ', '.join(repr(n) for n in uninitialized_params)))
 
         for name, child in _named_children(link):
-            child_module = LinkAsTorchModel(child)
+            child_module = LinkAsTorchModel(child, _device=device)
             setattr(self, name, child_module)
         for name, param in sorted(_named_params(link)):
+            if device is not None:
+                param.to_device(device)
             setattr(self, name, ChainerParameter(param))
         self.link = link
 
@@ -70,6 +73,12 @@ def __as_tensor(self, value):
             return _ChainerTensor(value)
         return value
 
+    def to(self, *, device):
+        if device is None:
+            return self
+        device = cpm.to_chainer_device(torch.device(device))
+        return LinkAsTorchModel(self.link, _device=device)
+
 
 class Optimizer(torch.optim.Optimizer):
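A minimal usage sketch of the keyword-only ``to()`` added in this patch (not part of the series; assumes CuPy/CUDA is available and uses an illustrative ``chainer.links.Linear``). At this stage the call returns a new ``LinkAsTorchModel`` over the same link rather than modifying the module in place:

    import chainer
    import chainer_pytorch_migration as cpm

    link = chainer.links.Linear(3, 2)   # in_size given, so parameters are initialized
    model = cpm.LinkAsTorchModel(link)

    # Keyword-only at this point; parameters are moved through Chainer and a
    # fresh LinkAsTorchModel wrapping the same link is returned.
    moved = model.to(device='cuda:0')
    assert moved is not model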
From 8b74ed5a22ec3a805da3bde118681b79c8681db1 Mon Sep 17 00:00:00 2001
From: Akifumi Imanishi
Date: Mon, 2 Mar 2020 09:20:10 +0000
Subject: [PATCH 2/4] Add parser for LinkAsTorchModel

---
 chainer_pytorch_migration/parameter.py | 35 +++++++++++++++++++++++---
 1 file changed, 32 insertions(+), 3 deletions(-)

diff --git a/chainer_pytorch_migration/parameter.py b/chainer_pytorch_migration/parameter.py
index 89993ed..50db275 100644
--- a/chainer_pytorch_migration/parameter.py
+++ b/chainer_pytorch_migration/parameter.py
@@ -17,6 +17,32 @@ def _named_params(link):
         yield name, getattr(link, name)
 
 
+# Corresponding to ``torch._C._nn._parse_to``.
+def _parse_to(*args, device=None, dtype=None, non_blocking=False):
+
+    if len(args) > 0:
+        if isinstance(args[0], torch.Tensor):
+            tensor = args.pop(0)
+            device = tensor.device
+            dtype = tensor.dtype
+        elif isinstance(args[0], torch.dtype):
+            dtype = args.pop(0)
+        elif isinstance(args[0], (str, torch.device)):
+            device = torch.device(args.pop(0))
+            if len(args) > 0 and isinstance(args[0], torch.dtype):
+                dtype = torch.dtype(args.pop(0))
+        else:
+            raise TypeError('Received an invalid combination of arguments.')
+
+    if len(args) > 0:
+        non_blocking = bool(args.pop(0))
+
+    if len(args) > 0:
+        raise TypeError('Received an invalid combination of arguments.')
+
+    return device, dtype, non_blocking
+
+
 class LinkAsTorchModel(torch.nn.Module):
     '''Converts a Chainer Link to a PyTorch module.
 
@@ -73,10 +99,13 @@ def __as_tensor(self, value):
             return _ChainerTensor(value)
         return value
 
-    def to(self, *, device):
-        if device is None:
-            return self
+    def to(self, *args, **kwargs):
+        device, dtype, non_blocking = _parse_to(*args, **kwargs)
         device = cpm.to_chainer_device(torch.device(device))
+        if dtype is not None:
+            raise NotImplementedError
+        if non_blocking:
+            raise NotImplementedError
         return LinkAsTorchModel(self.link, _device=device)
 
 
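A sketch of the call forms the new parser is meant to accept, mirroring ``torch.nn.Module.to`` (illustrative names; ``dtype`` and ``non_blocking`` are parsed but rejected with ``NotImplementedError``). The positional forms additionally rely on the ``args = list(args)`` change in the next patch, because ``args`` arrives as a tuple and tuples have no ``pop``:

    import torch
    import chainer_pytorch_migration as cpm

    model = cpm.LinkAsTorchModel(link)    # ``link`` is an initialized chainer.Link

    model.to(device='cuda:0')             # keyword form
    model.to('cuda:0')                    # positional str (needs the next patch)
    model.to(torch.device('cuda:0'))      # positional torch.device (needs the next patch)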
From b85dd9bf66dfc57dcc5a4e7e3c7a73cc329db037 Mon Sep 17 00:00:00 2001
From: Akifumi Imanishi
Date: Mon, 2 Mar 2020 19:28:30 +0000
Subject: [PATCH 3/4] Fix to modifying self

---
 chainer_pytorch_migration/parameter.py | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/chainer_pytorch_migration/parameter.py b/chainer_pytorch_migration/parameter.py
index 50db275..ba2f7d5 100644
--- a/chainer_pytorch_migration/parameter.py
+++ b/chainer_pytorch_migration/parameter.py
@@ -19,6 +19,7 @@ def _named_params(link):
 
 # Corresponding to ``torch._C._nn._parse_to``.
 def _parse_to(*args, device=None, dtype=None, non_blocking=False):
+    args = list(args)
 
     if len(args) > 0:
         if isinstance(args[0], torch.Tensor):
@@ -28,7 +29,7 @@ def _parse_to(*args, device=None, dtype=None, non_blocking=False):
         elif isinstance(args[0], torch.dtype):
             dtype = args.pop(0)
         elif isinstance(args[0], (str, torch.device)):
-            device = torch.device(args.pop(0))
+            device = args.pop(0)
             if len(args) > 0 and isinstance(args[0], torch.dtype):
                 dtype = torch.dtype(args.pop(0))
         else:
@@ -40,9 +41,19 @@ def _parse_to(*args, device=None, dtype=None, non_blocking=False):
     if len(args) > 0:
         raise TypeError('Received an invalid combination of arguments.')
 
+    if device is not None:
+        device = torch.device(device)
+
     return device, dtype, non_blocking
 
 
+def _setattr_recursive(obj, name, value):
+    attr_list = name.split('.')
+    for attr in attr_list[:-1]:
+        obj = getattr(obj, attr)
+    setattr(obj, attr_list[-1], value)
+
+
 class LinkAsTorchModel(torch.nn.Module):
     '''Converts a Chainer Link to a PyTorch module.
 
@@ -101,12 +112,18 @@ def __as_tensor(self, value):
 
     def to(self, *args, **kwargs):
         device, dtype, non_blocking = _parse_to(*args, **kwargs)
-        device = cpm.to_chainer_device(torch.device(device))
+        chainer_device = cpm.to_chainer_device(device)
         if dtype is not None:
             raise NotImplementedError
         if non_blocking:
             raise NotImplementedError
-        return LinkAsTorchModel(self.link, _device=device)
+        for name, value in self.named_parameters():
+            assert isinstance(value, ChainerParameter)
+            param = value._param
+            param.to_device(chainer_device)
+            value = ChainerParameter(param)
+            _setattr_recursive(self, name, value)
+        return self
 
 
 class Optimizer(torch.optim.Optimizer):

From 097042092af3c02e22988d5078cc9f59e43be7b0 Mon Sep 17 00:00:00 2001
From: Akifumi Imanishi
Date: Wed, 4 Mar 2020 11:44:44 +0000
Subject: [PATCH 4/4] Add tests

---
 tests/test_parameter.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/tests/test_parameter.py b/tests/test_parameter.py
index 2009e03..dc2a605 100644
--- a/tests/test_parameter.py
+++ b/tests/test_parameter.py
@@ -274,3 +274,24 @@ def test_named_params():
     numpy.testing.assert_array_equal(a_arr, n_params['a'].detach())
     assert 'b' in n_params
     numpy.testing.assert_array_equal(b_arr, n_params['b'].detach())
+
+
+def test_link_to_device():
+    a_arr = numpy.ones((3, 2), 'float32')
+    a_chainer_param = chainer.Parameter(a_arr)
+    # 0-size parameter
+    b_arr = numpy.ones((2, 0, 1), 'float32')
+    b_chainer_param = chainer.Parameter(b_arr)
+
+    link = chainer.Link()
+    with link.init_scope():
+        link.a = a_chainer_param
+        link.b = b_chainer_param
+
+    torched = cpm.LinkAsTorchModel(link)
+    ret = torched.to('cuda')
+
+    assert torched is ret
+
+    for name, param in torched.named_parameters():
+        assert param.device.type == 'cuda'
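A short sketch (not part of the patch) of why ``_setattr_recursive`` is needed once ``to()`` rebuilds parameters in place: ``named_parameters()`` yields dotted names for nested children, so each rebuilt ``ChainerParameter`` has to be reassigned by walking down the module hierarchy. Assumes CuPy/CUDA and an illustrative nested link:

    import chainer
    import chainer_pytorch_migration as cpm

    # A chain with a child link, so parameter names are dotted ('child.W', 'child.b').
    outer = chainer.Chain()
    with outer.init_scope():
        outer.child = chainer.links.Linear(3, 2)

    model = cpm.LinkAsTorchModel(outer)
    ret = model.to('cuda:0')              # modifies the module in place and returns self
    assert ret is model

    for name, param in model.named_parameters():
        # 'child.W' is reassigned via _setattr_recursive(model, 'child.W', ...),
        # i.e. getattr down to model.child, then setattr('W', ...).
        print(name, param.device)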