Skip to content

Commit

Permalink
Merge pull request #2 from anton164/master
Browse files Browse the repository at this point in the history
Added more tests for conv1d and tiling
  • Loading branch information
srush authored Dec 8, 2020
2 parents a1c5702 + c62118c commit 59fe29e
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tests/test_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,51 @@ def test_conv1d_simple():
assert out[0, 0, 2] == 2 * 1 + 3 * 2
assert out[0, 0, 3] == 3 * 1

@pytest.mark.task4_1
def test_conv1d_simple_backward():
input_tensor = minitorch.tensor_fromlist([0, 1, 2, 3]).view(1, 1, 4)
weight = minitorch.tensor_fromlist([[1, 2, 3]]).view(1, 1, 3)
grad_output = minitorch.tensor_fromlist([0, 1, 2, 3]).view(1, 1, 4)
ctx = minitorch.Context()
ctx.save_for_backward(input_tensor, weight)
grad_input, grad_weight = minitorch.Conv1dFun.backward(ctx, grad_output)

assert grad_input[0, 0, 0] == weight[0, 0, 0] * grad_output[0, 0, 0]
assert (
grad_input[0, 0, 1]
== weight[0, 0, 0] * grad_output[0, 0, 1]
+ weight[0, 0, 1] * grad_output[0, 0, 0]
)
assert (
grad_input[0, 0, 2]
== weight[0, 0, 0] * grad_output[0, 0, 2]
+ weight[0, 0, 1] * grad_output[0, 0, 1]
+ weight[0, 0, 2] * grad_output[0, 0, 0]
)
assert (
grad_input[0, 0, 3]
== weight[0, 0, 0] * grad_output[0, 0, 3]
+ weight[0, 0, 1] * grad_output[0, 0, 2]
+ weight[0, 0, 2] * grad_output[0, 0, 1]
)

@pytest.mark.task4_1
@given(tensors(shape=(1, 1, 6)), tensors(shape=(1, 1, 4)))
def test_conv1d(input, weight):
print(input, weight)
minitorch.grad_check(minitorch.Conv1dFun.apply, input, weight)

@pytest.mark.task4_1
def test_conv1d_in_channel():
t = minitorch.tensor_fromlist([[0, 1, 2, 3], [0, 1, 2, 3]]).view(1, 2, 4)
t.requires_grad_(True)
t2 = minitorch.tensor_fromlist([[1, 2, 3], [1, 2, 3]]).view(1, 2, 3)
out = minitorch.Conv1dFun.apply(t, t2)

assert out[0, 0, 0] == (0 * 1 + 1 * 2 + 2 * 3) * 2
assert out[0, 0, 1] == (1 * 1 + 2 * 2 + 3 * 3) * 2
assert out[0, 0, 2] == (2 * 1 + 3 * 2) * 2
assert out[0, 0, 3] == (3 * 1) * 2

@pytest.mark.task4_1
@given(tensors(shape=(2, 2, 6)), tensors(shape=(3, 2, 2)))
Expand Down
21 changes: 21 additions & 0 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,27 @@
from .strategies import tensors, assert_close
import pytest

@pytest.mark.task4_3
def test_tile():
t = minitorch.tensor_fromlist(
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
).view(1, 1, 4, 4)
tiled, _, _ = minitorch.tile(t, (2, 2))
assert tiled[0, 0, 0, 0, 0] == 1
assert tiled[0, 0, 0, 0, 1] == 2
assert tiled[0, 0, 0, 0, 2] == 5
assert tiled[0, 0, 0, 0, 3] == 6

@pytest.mark.task4_3
def test_tile_2():
t = minitorch.tensor_fromlist(
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
).view(1, 1, 4, 4)
tiled, _, _ = minitorch.tile(t, (1, 2))
assert tiled[0, 0, 0, 0, 0] == 1
assert tiled[0, 0, 0, 0, 1] == 2
assert tiled[0, 0, 0, 1, 0] == 3
assert tiled[0, 0, 0, 1, 1] == 4

@pytest.mark.task4_3
@given(tensors(shape=(1, 1, 4, 4)))
Expand Down

0 comments on commit 59fe29e

Please sign in to comment.