Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
add test_consistency for psroi_pooling_2d
Browse files Browse the repository at this point in the history
  • Loading branch information
knorth55 committed Mar 28, 2018
1 parent b972636 commit aa2c88b
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions tests/functions_tests/test_psroi_pooling_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ def test_forward_gpu(self):
cuda.to_gpu(self.x), cuda.to_gpu(self.rois),
cuda.to_gpu(self.roi_indices))

def check_backward(self, x_data, roi_data, roi_index_data, y_grad):
def check_backward(self, x_data, roi_data, roi_index_data, y_grad_data):
gradient_check.check_backward(
functions.PSROIPooling2D(
self.out_c, self.out_h, self.out_w,
self.spatial_scale, self.group_size),
(x_data, roi_data, roi_index_data), y_grad,
(x_data, roi_data, roi_index_data), y_grad_data,
no_grads=[False, True, True], **self.check_backward_options)

@condition.retry(3)
Expand All @@ -82,5 +82,27 @@ def test_backward_gpu(self):
cuda.to_gpu(self.x), cuda.to_gpu(self.rois),
cuda.to_gpu(self.roi_indices), cuda.to_gpu(self.gy))

def apply_backward(self, x_data, roi_data, roi_index_data, y_grad_data):
x = chainer.Variable(x_data)
rois = chainer.Variable(roi_data)
roi_indices = chainer.Variable(roi_index_data)
y = functions.psroi_pooling_2d(
x, rois, roi_indices, self.out_c, self.out_h, self.out_w,
self.spatial_scale, self.group_size)
x.cleargrad()
y.grad = y_grad_data
y.backward()
return x, y

@attr.gpu
@condition.retry(3)
def test_consistency_with_gpu(self):
x_cpu, y_cpu = self.apply_backward(
self.x, self.rois, self.roi_indices, self.gy)
x_gpu, y_gpu = self.apply_backward(
cuda.to_gpu(self.x), cuda.to_gpu(self.rois),
cuda.to_gpu(self.roi_indices), cuda.to_gpu(self.gy))
testing.assert_allclose(y_cpu.data, y_gpu.data)
testing.assert_allclose(x_cpu.grad, x_gpu.grad)

testing.run_module(__name__, __file__)

0 comments on commit aa2c88b

Please sign in to comment.