Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Nov 14, 2020
1 parent fb1f1d9 commit c2bff8d
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 12 deletions.
4 changes: 2 additions & 2 deletions minitorch/fast_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ def tensor_conv2d(
s1 = input_strides
s2 = weight_strides

# TODO: Implement for Task 4.1.
raise NotImplementedError('Need to implement for Task 4.1')
# TODO: Implement for Task 4.2.
raise NotImplementedError('Need to implement for Task 4.2')


class Conv2dFun(Function):
Expand Down
33 changes: 24 additions & 9 deletions project/run_mnist_multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,8 @@ def __init__(self, in_size, out_size):
self.out_size = out_size

def forward(self, x):
batch, in_size = x.shape
return (
x.view(batch, 1, in_size) @ self.weights.value.view(in_size, self.out_size)
).view(batch, self.out_size) + self.bias.value
# TODO: Implement for Task 4.5.
raise NotImplementedError('Need to implement for Task 4.5')


class Conv2d(minitorch.Module):
Expand Down Expand Up @@ -110,8 +108,12 @@ def make_mnist(start, stop):
for batch_num, example_num in enumerate(range(0, N, BATCH)):
if N - example_num <= BATCH:
continue
y = minitorch.tensor_fromlist(ys[example_num : example_num + BATCH], backend=BACKEND)
x = minitorch.tensor_fromlist(X[example_num : example_num + BATCH], backend=BACKEND)
y = minitorch.tensor_fromlist(
ys[example_num : example_num + BATCH], backend=BACKEND
)
x = minitorch.tensor_fromlist(
X[example_num : example_num + BATCH], backend=BACKEND
)
x.requires_grad_(True)
y.requires_grad_(True)

Expand All @@ -134,8 +136,12 @@ def make_mnist(start, stop):

correct = 0
for val_example_num in range(0, 5 * BATCH, BATCH):
y = minitorch.tensor_fromlist(val_ys[val_example_num : val_example_num + BATCH], backend=BACKEND)
x = minitorch.tensor_fromlist(val_x[val_example_num : val_example_num + BATCH], backend=BACKEND)
y = minitorch.tensor_fromlist(
val_ys[val_example_num : val_example_num + BATCH], backend=BACKEND
)
x = minitorch.tensor_fromlist(
val_x[val_example_num : val_example_num + BATCH], backend=BACKEND
)
out = model.forward(x.view(BATCH, 1, H, W)).view(BATCH, C)
for i in range(BATCH):
m = -1000
Expand All @@ -147,7 +153,16 @@ def make_mnist(start, stop):
if y[i, ind] == 1.0:
correct += 1

print("Epoch ", epoch, " example ", example_num, " loss ", total_loss[0], " accuracy ", correct / float(5 * BATCH))
print(
"Epoch ",
epoch,
" example ",
example_num,
" loss ",
total_loss[0],
" accuracy ",
correct / float(5 * BATCH),
)

# Visualize test batch
for channel in range(4):
Expand Down
9 changes: 8 additions & 1 deletion tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,15 @@ def test_avg(t):


@pytest.mark.task4_4
@given(tensors(shape=(1, 1, 4, 4)))
@given(tensors(shape=(2, 3, 4)))
def test_max(t):
# TODO: Implement for Task 4.4.
raise NotImplementedError('Need to implement for Task 4.4')


@pytest.mark.task4_4
@given(tensors(shape=(1, 1, 4, 4)))
def test_max_pool(t):
out = minitorch.maxpool2d(t, (2, 2))
print(out)
print(t)
Expand Down

0 comments on commit c2bff8d

Please sign in to comment.