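"""Unit tests for the tensor operators exposed by gluonnlp.op."""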
import numpy as np
from numpy.testing import assert_allclose
import mxnet as mx
from mxnet import gluon
from scipy.stats import ks_2samp
import pytest
from gluonnlp.op import *
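# Switch MXNet to NumPy-compatible semantics for all tests below.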
mx.npx.set_np()


@pytest.mark.parametrize('batch_size', [1, 4])
@pytest.mark.parametrize('seq_length', [16, 32])
@pytest.mark.parametrize('num_sel_positions', [1, 5])
@pytest.mark.parametrize('feature_shape', [(16,), (16, 32)])
@pytest.mark.parametrize('hybridized', [False, True])
@pytest.mark.seed(1)
def test_select_vectors_by_position(batch_size, seq_length, num_sel_positions,
                                    feature_shape, hybridized):
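    """Check select_vectors_by_position against a plain numpy gather.

    The op is expected to pick, for every batch element, the vectors at the given
    positions, i.e. out[b, i] == data[b, positions[b, i]]; the numpy reference below
    computes the same thing via fancy indexing with a broadcast batch index.
    """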
    data = mx.np.random.uniform(-1, 1, (batch_size, seq_length) + feature_shape, dtype=np.float32)
    positions = mx.np.random.randint(0, seq_length, (batch_size, num_sel_positions), dtype=np.int32)

    class Foo(gluon.HybridBlock):
        def forward(self, p_data, p_positions):
            return select_vectors_by_position(p_data, p_positions)

    foo = Foo()
    if hybridized:
        foo.hybridize()
    out_mx = foo(data, positions)
    out_np = data.asnumpy()[np.expand_dims(np.arange(data.shape[0]).astype(np.int32), axis=1),
                            positions.asnumpy()]
    assert_allclose(out_mx.asnumpy(), out_np, 1E-4, 1E-4)


@pytest.mark.parametrize('batch_size', [1, 4])
@pytest.mark.parametrize('seq_length', [16, 32])
@pytest.mark.parametrize('num_sel_positions', [1, 5])
@pytest.mark.parametrize('feature_shape,increment_shape',
                         [((16,), (16,)),
                          ((16, 32), (16, 1)),
                          ((16, 32), (16, 32))])
@pytest.mark.parametrize('hybridized', [False, True])
@pytest.mark.seed(1)
def test_add_vectors_by_position(batch_size, seq_length, num_sel_positions,
                                 feature_shape, increment_shape, hybridized):
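    """Check add_vectors_by_position against a plain numpy loop.

    The op is expected to add increment[b, i] to data[b, positions[b, i]]. Since the
    random positions may repeat within a batch element, the numpy reference accumulates
    with an explicit loop rather than fancy-indexed assignment.
    """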
    data = mx.np.random.uniform(-1, 1, (batch_size, seq_length) + feature_shape, dtype=np.float32)
    positions = mx.np.random.randint(0, seq_length, (batch_size, num_sel_positions), dtype=np.int32)
    increment = mx.np.random.uniform(-1, 1, (batch_size, num_sel_positions) + increment_shape)

    class Foo(gluon.HybridBlock):
        def forward(self, p_data, p_increment, p_positions):
            return add_vectors_by_position(p_data, p_increment, p_positions)

    foo = Foo()
    if hybridized:
        foo.hybridize()
    out_mx = foo(data, increment, positions).asnumpy()
    out_np = data.asnumpy().copy()
    positions = positions.asnumpy()
    increment = increment.asnumpy()
    for bidx in range(batch_size):
        for sidx in range(num_sel_positions):
            sel = positions[bidx, sidx]
            out_np[bidx, sel] += increment[bidx, sidx]
    assert_allclose(out_np, out_mx, 1E-4, 1E-4)


@pytest.mark.parametrize('batch_size', [1, 4])
@pytest.mark.parametrize('seq_length', [16, 32])
@pytest.mark.parametrize('num_sel_positions', [1, 5])
@pytest.mark.parametrize('feature_shape,update_shape',
                         [((16,), (16,)),
                          ((16, 32), (16, 1)),
                          ((16, 32), (16, 32))])
@pytest.mark.parametrize('hybridized', [False, True])
@pytest.mark.seed(1)
def test_update_vectors_by_position(batch_size, seq_length, num_sel_positions,
                                    feature_shape, update_shape, hybridized):
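    """Check update_vectors_by_position against a plain numpy scatter.

    The op is expected to overwrite data[b, positions[b, i]] with val[b, i]. Positions
    are drawn without replacement within each batch element so the expected result of
    the overwrite is unambiguous.
    """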
    data = mx.np.random.uniform(-1, 1, (batch_size, seq_length) + feature_shape, dtype=np.float32)
    val = mx.np.random.uniform(-1, 1, (batch_size, num_sel_positions) + update_shape)
    positions = mx.np.zeros((batch_size, num_sel_positions), dtype=np.int32)
    for i in range(batch_size):
        positions[i, :] = np.random.choice(seq_length, num_sel_positions, replace=False)

    class Foo(gluon.HybridBlock):
        def forward(self, p_data, p_val, p_positions):
            return update_vectors_by_position(p_data, p_val, p_positions)

    foo = Foo()
    if hybridized:
        foo.hybridize()
    out_mx = foo(data, val, positions)
    out_np = data.asnumpy().copy()
    out_np[np.expand_dims(np.arange(data.shape[0]).astype(np.int32), axis=1),
           positions.asnumpy()] = val.asnumpy()
    assert_allclose(out_mx.asnumpy(), out_np, 1E-4, 1E-4)


@pytest.mark.parametrize('shape', [(10,), (5, 10)])
@pytest.mark.seed(1)
def test_gumbel_softmax(shape):
    # Here, we just verify that it will generate one-hot vectors and will have gradient
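    # With its default arguments, gumbel_softmax is assumed to return hard (straight-through)
    # samples, so each row along the last axis should contain exactly one entry equal to 1.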
    logits = mx.np.random.uniform(-2, -1, shape)
    ret = gumbel_softmax(logits)
    assume_allones = (ret == 1).sum(axis=-1).asnumpy()
    assert_allclose(assume_allones, np.ones_like(assume_allones))


@pytest.mark.parametrize('shape', [(50,)])
@pytest.mark.seed(1)
def test_trunc_gumbel(shape):
    """Verify trunc_gumbel in two steps.

    First, check that every sample is smaller than the provided threshold, i.e. that the
    samples are actually truncated. Then undo the truncation and check, via a two-sample
    KS test against independently drawn Gumbel samples, that the de-truncated samples
    follow a Gumbel distribution.
    """
    # Verify that the samples are truncated at the given threshold
    for i in range(1000):
        samples = trunc_gumbel(mx.np.ones(shape), 1.0).asnumpy()
        assert (samples < 1.0).all()
    # Perform the KS tests
    pvalues = []
    for i in range(1000):
        logits = mx.np.random.uniform(-2, -1, shape)
        # Sample from an (untruncated) Gumbel distribution centered at the logits
        sampled_gumbels = mx.np.random.gumbel(mx.np.zeros_like(logits)) + logits
        # Sample from the (potentially) truncated Gumbel distribution
        sampled_truncated_gumbels = trunc_gumbel(logits, 0.5)
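        # trunc_gumbel(logits, T) is assumed to follow the standard truncated-Gumbel
        # construction g_trunc = -log(exp(-T) + exp(-g)) with g ~ Gumbel(logits), so
        # inverting that map below should recover an untruncated Gumbel sample.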
        # remove the truncation
        reconstructed_sample = -mx.np.log(mx.np.exp(-sampled_truncated_gumbels) - mx.np.exp(-0.5))
        pvalue = ks_2samp(reconstructed_sample.asnumpy(), sampled_gumbels.asnumpy()).pvalue
        pvalues.append(pvalue)
    pvalues = np.array(pvalues)
    # Decision rule: if more than 90% of the 1000 p-values exceed 0.05, we fail to reject
    # the null hypothesis that the reconstructed samples come from the same Gumbel
    # distribution as the directly sampled ones.
    assert (len(pvalues[pvalues > 0.05]) > 900)