diff --git a/trax/tf_numpy/examples/mnist/train_test.py b/trax/tf_numpy/examples/mnist/train_test.py index 8b51da7c1..77122c3b0 100644 --- a/trax/tf_numpy/examples/mnist/train_test.py +++ b/trax/tf_numpy/examples/mnist/train_test.py @@ -46,8 +46,8 @@ def testRuns(self): def fake_mnist_data(): def gen_examples(num_examples): - x = np.array( - np.random.randn(num_examples, 784), copy=False, dtype=np.float32) + x = np.asarray( + np.random.randn(num_examples, 784), dtype=np.float32) y = np.zeros((num_examples, 10), dtype=np.float32) y[:][0] = 1. return (x, y)