Skip to content

Commit

Permalink
fix normal tests in namespace, expression
Browse files Browse the repository at this point in the history
Two unittests where the normal in (namespace) expressions is being tested
incorrectly uses the dimension of the volume instead of the boundary. This
patch fixes the problem by explicitly defining the dimension to lower for.
  • Loading branch information
joostvanzwieten committed Apr 14, 2021
1 parent 5d5c27d commit 4cfa5c7
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def test_d_arg(self):
def test_n(self):
ns = function.Namespace()
topo, ns.x = mesh.rectilinear([1])
self.assertEqualLowered(ns.eval_i('n(x_i)'), function.normal(ns.x), ndims=topo.ndims)
self.assertEqualLowered(ns.eval_i('n(x_i)'), function.normal(ns.x), ndims=topo.ndims-1)

def test_functions(self):
def sqr(a):
Expand Down Expand Up @@ -738,8 +738,8 @@ def setUp(self):
self.ns.a32 = numpy.array([[1,2],[3,4],[5,6]])
self.x = function.Argument('x',())

def assertEqualLowered(self, s, f):
self.assertEqual((s @ self.ns).prepare_eval(ndims=2).simplified, f.prepare_eval(ndims=2).simplified)
def assertEqualLowered(self, s, f, *, ndims=2):
self.assertEqual((s @ self.ns).prepare_eval(ndims=ndims).simplified, f.prepare_eval(ndims=ndims).simplified)

def test_group(self): self.assertEqualLowered('(a)', self.ns.a)
def test_arg(self): self.assertEqualLowered('a2_i ?x_i', function.dot(self.ns.a2, function.Argument('x', [2]), axes=[0]))
Expand All @@ -748,7 +748,7 @@ def test_multisubstitute(self): self.assertEqualLowered('(a2_i + ?x_i + ?y_i)(x_
def test_call(self): self.assertEqualLowered('sin(a)', function.sin(self.ns.a))
def test_call2(self): self.assertEqual(self.ns.eval_ij('arctan2(a2_i, a3_j)').prepare_eval(ndims=2).simplified, function.arctan2(self.ns.a2[:,None], self.ns.a3[None,:]).prepare_eval(ndims=2).simplified)
def test_eye(self): self.assertEqualLowered('δ_ij a2_i', function.dot(function.eye(2), self.ns.a2, axes=[0]))
def test_normal(self): self.assertEqualLowered('n_i', self.ns.x.normal())
def test_normal(self): self.assertEqualLowered('n_i', self.ns.x.normal(), ndims=1)
def test_getitem(self): self.assertEqualLowered('a2_0', self.ns.a2[0])
def test_trace(self): self.assertEqualLowered('a22_ii', function.trace(self.ns.a22, 0, 1))
def test_sum(self): self.assertEqualLowered('a2_i a2_i', function.sum(self.ns.a2 * self.ns.a2, axis=0))
Expand Down

0 comments on commit 4cfa5c7

Please sign in to comment.