diff --git a/chainer_pytorch_migration/parameter.py b/chainer_pytorch_migration/parameter.py
index d783b00..ba2f7d5 100644
--- a/chainer_pytorch_migration/parameter.py
+++ b/chainer_pytorch_migration/parameter.py
@@ -17,6 +17,43 @@ def _named_params(link):
         yield name, getattr(link, name)
 
 
+# Corresponding to ``torch._C._nn._parse_to``.
+def _parse_to(*args, device=None, dtype=None, non_blocking=False):
+    args = list(args)
+
+    if len(args) > 0:
+        if isinstance(args[0], torch.Tensor):
+            tensor = args.pop(0)
+            device = tensor.device
+            dtype = tensor.dtype
+        elif isinstance(args[0], torch.dtype):
+            dtype = args.pop(0)
+        elif isinstance(args[0], (str, torch.device)):
+            device = args.pop(0)
+            if len(args) > 0 and isinstance(args[0], torch.dtype):
+                dtype = args.pop(0)
+        else:
+            raise TypeError('Received an invalid combination of arguments.')
+
+        if len(args) > 0:
+            non_blocking = bool(args.pop(0))
+
+        if len(args) > 0:
+            raise TypeError('Received an invalid combination of arguments.')
+
+    if device is not None:
+        device = torch.device(device)
+
+    return device, dtype, non_blocking
+
+
+def _setattr_recursive(obj, name, value):
+    attr_list = name.split('.')
+    for attr in attr_list[:-1]:
+        obj = getattr(obj, attr)
+    setattr(obj, attr_list[-1], value)
+
+
 class LinkAsTorchModel(torch.nn.Module):
 
     '''Converts a Chainer Link to a PyTorch module.
@@ -29,8 +66,9 @@ class LinkAsTorchModel(torch.nn.Module):
         link (:class:`chainer.Link`): A link. Must have been initialized.
     '''
 
-    def __init__(self, link):
+    def __init__(self, link, **kwargs):
         super().__init__()
+        device = kwargs.pop('_device', None)
         uninitialized_params = [
             n for n, p in sorted(_named_params(link)) if p.array is None]
         if uninitialized_params:
@@ -43,9 +81,11 @@ def __init__(self, link):
                     ', '.join(repr(n) for n in uninitialized_params)))
 
         for name, child in _named_children(link):
-            child_module = LinkAsTorchModel(child)
+            child_module = LinkAsTorchModel(child, _device=device)
             setattr(self, name, child_module)
         for name, param in sorted(_named_params(link)):
+            if device is not None:
+                param.to_device(device)
             setattr(self, name, ChainerParameter(param))
         self.link = link
 
@@ -70,6 +110,21 @@ def __as_tensor(self, value):
             return _ChainerTensor(value)
         return value
 
+    def to(self, *args, **kwargs):
+        device, dtype, non_blocking = _parse_to(*args, **kwargs)
+        chainer_device = cpm.to_chainer_device(device)
+        if dtype is not None:
+            raise NotImplementedError
+        if non_blocking:
+            raise NotImplementedError
+        for name, value in self.named_parameters():
+            assert isinstance(value, ChainerParameter)
+            param = value._param
+            param.to_device(chainer_device)
+            value = ChainerParameter(param)
+            _setattr_recursive(self, name, value)
+        return self
+
 
 class Optimizer(torch.optim.Optimizer):
 
diff --git a/tests/test_parameter.py b/tests/test_parameter.py
index 2009e03..dc2a605 100644
--- a/tests/test_parameter.py
+++ b/tests/test_parameter.py
@@ -274,3 +274,24 @@ def test_named_params():
     numpy.testing.assert_array_equal(a_arr, n_params['a'].detach())
     assert 'b' in n_params
     numpy.testing.assert_array_equal(b_arr, n_params['b'].detach())
+
+
+def test_link_to_device():
+    a_arr = numpy.ones((3, 2), 'float32')
+    a_chainer_param = chainer.Parameter(a_arr)
+    # 0-size parameter
+    b_arr = numpy.ones((2, 0, 1), 'float32')
+    b_chainer_param = chainer.Parameter(b_arr)
+
+    link = chainer.Link()
+    with link.init_scope():
+        link.a = a_chainer_param
+        link.b = b_chainer_param
+
+    torched = cpm.LinkAsTorchModel(link)
+    ret = torched.to('cuda')
+
+    assert torched is ret
+
+    for name, param in torched.named_parameters():
+        assert param.device.type == 'cuda'
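
Usage sketch (not part of the patch; assumes a CUDA-capable environment and the package imported as `cpm`, as in the tests). The new `to()` is called like `torch.nn.Module.to()`, but it moves the underlying `chainer.Parameter` objects and re-wraps them; only device moves are supported, since `dtype` and `non_blocking=True` raise `NotImplementedError`:

    import chainer
    import numpy
    import chainer_pytorch_migration as cpm

    link = chainer.Link()
    with link.init_scope():
        link.w = chainer.Parameter(numpy.zeros((2, 3), 'float32'))

    module = cpm.LinkAsTorchModel(link)
    module.to('cuda')  # returns the same module; parameters are re-wrapped on the GPU
    assert next(module.parameters()).device.type == 'cuda'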