diff --git a/chainer_pytorch_migration/parameter.py b/chainer_pytorch_migration/parameter.py
index d783b00..ba2f7d5 100644
--- a/chainer_pytorch_migration/parameter.py
+++ b/chainer_pytorch_migration/parameter.py
@@ -17,6 +17,49 @@ def _named_params(link):
         yield name, getattr(link, name)
 
 
+# Corresponding to ``torch._C._nn._parse_to``.
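+# Accepts the same positional forms as ``Tensor.to``: a tensor (whose device
+# and dtype are copied), a dtype, or a device optionally followed by a dtype;
+# an optional trailing argument is interpreted as ``non_blocking``.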
+def _parse_to(*args, device=None, dtype=None, non_blocking=False):
+    args = list(args)
+
+    if len(args) > 0:
+        if isinstance(args[0], torch.Tensor):
+            tensor = args.pop(0)
+            device = tensor.device
+            dtype = tensor.dtype
+        elif isinstance(args[0], torch.dtype):
+            dtype = args.pop(0)
+        elif isinstance(args[0], (str, torch.device)):
+            device = args.pop(0)
+            if len(args) > 0 and isinstance(args[0], torch.dtype):
+                dtype = args.pop(0)
+        else:
+            raise TypeError('Received an invalid combination of arguments.')
+
+        if len(args) > 0:
+            non_blocking = bool(args.pop(0))
+
+        if len(args) > 0:
+            raise TypeError('Received an invalid combination of arguments.')
+
+    if device is not None:
+        device = torch.device(device)
+
+    return device, dtype, non_blocking
+
+
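+# Assign ``value`` to a possibly dotted attribute path such as
+# ``'child.weight'`` (the names yielded by ``named_parameters()``),
+# walking down the intermediate objects first.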
+def _setattr_recursive(obj, name, value):
+    attr_list = name.split('.')
+    for attr in attr_list[:-1]:
+        obj = getattr(obj, attr)
+    setattr(obj, attr_list[-1], value)
+
+
 class LinkAsTorchModel(torch.nn.Module):
 
     '''Converts a Chainer Link to a PyTorch module.
@@ -29,8 +72,12 @@ class LinkAsTorchModel(torch.nn.Module):
         link (:class:`chainer.Link`): A link. Must have been initialized.
     '''
 
-    def __init__(self, link):
+    def __init__(self, link, **kwargs):
         super().__init__()
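+        # ``_device`` is an internal keyword: it is forwarded when child
+        # modules are built recursively so that all parameters end up on
+        # the same device.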
+        device = kwargs.pop('_device', None)
         uninitialized_params = [
             n for n, p in sorted(_named_params(link)) if p.array is None]
         if uninitialized_params:
@@ -43,9 +90,13 @@ def __init__(self, link):
                     ', '.join(repr(n) for n in uninitialized_params)))
 
         for name, child in _named_children(link):
-            child_module = LinkAsTorchModel(child)
+            child_module = LinkAsTorchModel(child, _device=device)
             setattr(self, name, child_module)
         for name, param in sorted(_named_params(link)):
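+            # If a device was given, move the parameter before wrapping it
+            # so the torch-side tensor is created on that device.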
+            if device is not None:
+                param.to_device(device)
             setattr(self, name, ChainerParameter(param))
 
         self.link = link
@@ -70,6 +121,25 @@ def __as_tensor(self, value):
             return _ChainerTensor(value)
         return value
 
+    def to(self, *args, **kwargs):
+        device, dtype, non_blocking = _parse_to(*args, **kwargs)
+        if dtype is not None:
+            raise NotImplementedError
+        if non_blocking:
+            raise NotImplementedError
+        if device is None:
+            return self
+        chainer_device = cpm.to_chainer_device(device)
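+        # Move each underlying chainer.Parameter to the target device and
+        # rebuild its ChainerParameter wrapper in place.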
+        for name, value in self.named_parameters():
+            assert isinstance(value, ChainerParameter)
+            param = value._param
+            param.to_device(chainer_device)
+            value = ChainerParameter(param)
+            _setattr_recursive(self, name, value)
+        return self
+
 
 class Optimizer(torch.optim.Optimizer):
 
diff --git a/tests/test_parameter.py b/tests/test_parameter.py
index 2009e03..dc2a605 100644
--- a/tests/test_parameter.py
+++ b/tests/test_parameter.py
@@ -274,3 +274,25 @@ def test_named_params():
     numpy.testing.assert_array_equal(a_arr, n_params['a'].detach())
     assert 'b' in n_params
     numpy.testing.assert_array_equal(b_arr, n_params['b'].detach())
+
+
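+# Note: this test moves parameters to the GPU and requires a CUDA device.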
+def test_link_to_device():
+    a_arr = numpy.ones((3, 2), 'float32')
+    a_chainer_param = chainer.Parameter(a_arr)
+    # 0-size parameter
+    b_arr = numpy.ones((2, 0, 1), 'float32')
+    b_chainer_param = chainer.Parameter(b_arr)
+
+    link = chainer.Link()
+    with link.init_scope():
+        link.a = a_chainer_param
+        link.b = b_chainer_param
+
+    torched = cpm.LinkAsTorchModel(link)
+    ret = torched.to('cuda')
+
+    assert torched is ret
+
+    for name, param in torched.named_parameters():
+        assert param.device.type == 'cuda'