Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
asi1024 committed Mar 5, 2020
1 parent b85dd9b commit 0970420
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,24 @@ def test_named_params():
numpy.testing.assert_array_equal(a_arr, n_params['a'].detach())
assert 'b' in n_params
numpy.testing.assert_array_equal(b_arr, n_params['b'].detach())


def test_link_to_device():
a_arr = numpy.ones((3, 2), 'float32')
a_chainer_param = chainer.Parameter(a_arr)
# 0-size parameter
b_arr = numpy.ones((2, 0, 1), 'float32')
b_chainer_param = chainer.Parameter(b_arr)

link = chainer.Link()
with link.init_scope():
link.a = a_chainer_param
link.b = b_chainer_param

torched = cpm.LinkAsTorchModel(link)
ret = torched.to('cuda')

assert torched is ret

for name, param in torched.named_parameters():
assert param.device.type == 'cuda'

0 comments on commit 0970420

Please sign in to comment.