diff --git a/tests/test_function.py b/tests/test_function.py index ebf1437de..a2c1afebb 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -672,7 +672,7 @@ def test_d_arg(self): def test_n(self): ns = function.Namespace() topo, ns.x = mesh.rectilinear([1]) - self.assertEqualLowered(ns.eval_i('n(x_i)'), function.normal(ns.x), ndims=topo.ndims) + self.assertEqualLowered(ns.eval_i('n(x_i)'), function.normal(ns.x), ndims=topo.ndims-1) def test_functions(self): def sqr(a): @@ -738,8 +738,8 @@ def setUp(self): self.ns.a32 = numpy.array([[1,2],[3,4],[5,6]]) self.x = function.Argument('x',()) - def assertEqualLowered(self, s, f): - self.assertEqual((s @ self.ns).prepare_eval(ndims=2).simplified, f.prepare_eval(ndims=2).simplified) + def assertEqualLowered(self, s, f, *, ndims=2): + self.assertEqual((s @ self.ns).prepare_eval(ndims=ndims).simplified, f.prepare_eval(ndims=ndims).simplified) def test_group(self): self.assertEqualLowered('(a)', self.ns.a) def test_arg(self): self.assertEqualLowered('a2_i ?x_i', function.dot(self.ns.a2, function.Argument('x', [2]), axes=[0])) @@ -748,7 +748,7 @@ def test_multisubstitute(self): self.assertEqualLowered('(a2_i + ?x_i + ?y_i)(x_ def test_call(self): self.assertEqualLowered('sin(a)', function.sin(self.ns.a)) def test_call2(self): self.assertEqual(self.ns.eval_ij('arctan2(a2_i, a3_j)').prepare_eval(ndims=2).simplified, function.arctan2(self.ns.a2[:,None], self.ns.a3[None,:]).prepare_eval(ndims=2).simplified) def test_eye(self): self.assertEqualLowered('δ_ij a2_i', function.dot(function.eye(2), self.ns.a2, axes=[0])) - def test_normal(self): self.assertEqualLowered('n_i', self.ns.x.normal()) + def test_normal(self): self.assertEqualLowered('n_i', self.ns.x.normal(), ndims=1) def test_getitem(self): self.assertEqualLowered('a2_0', self.ns.a2[0]) def test_trace(self): self.assertEqualLowered('a22_ii', function.trace(self.ns.a22, 0, 1)) def test_sum(self): self.assertEqualLowered('a2_i a2_i', function.sum(self.ns.a2 * self.ns.a2, axis=0))