Skip to content

Commit

Permalink
replace f @ sample by sample(f)
Browse files Browse the repository at this point in the history
This patch changes Sample.__rmatmul__ to Sample.__call__, in order to avoid
ambiguity issues when function arrays gain support for Numpy's array function
protocol.
  • Loading branch information
gertjanvanzwieten committed Nov 5, 2021
1 parent f226d3f commit 90b081f
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 21 deletions.
16 changes: 8 additions & 8 deletions nutils/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,9 @@ def eval_sparse(self, funcs: Iterable[function.IntoArray], arguments: Optional[M
Optional arguments for function evaluation.
'''

return eval_integrals_sparse(*map(self.__rmatmul__, funcs), **(arguments or {}))
return eval_integrals_sparse(*map(self, funcs), **(arguments or {}))

def __rmatmul__(self, __func: function.IntoArray) -> function.Array:
def __call__(self, __func: function.IntoArray) -> function.Array:
func = _ConcatenatePoints(function.Array.cast(__func), self)
ielem = evaluable.loop_index('_sample_' + '_'.join(self.spaces), self.nelems)
indices = evaluable.loop_concatenate(evaluable._flat(self.get_evaluable_indices(ielem)), ielem)
Expand Down Expand Up @@ -430,7 +430,7 @@ def get_evaluable_indices(self, ielem: evaluable.Array) -> evaluable.Array:
offset = evaluable.get(_offsets(self.points), 0, ielem)
return evaluable.Range(npoints) + offset

def __rmatmul__(self, __func: function.IntoArray) -> function.Array:
def __call__(self, __func: function.IntoArray) -> function.Array:
return _ConcatenatePoints(function.Array.cast(__func), self)

class _CustomIndex(_TransformChainsSample):
Expand Down Expand Up @@ -519,7 +519,7 @@ def integral(self, __func: function.IntoArray) -> function.Array:
func = function.Array.cast(__func)
return function.zeros(func.shape, func.dtype)

def __rmatmul__(self, __func: function.IntoArray) -> function.Array:
def __call__(self, __func: function.IntoArray) -> function.Array:
func = function.Array.cast(__func)
return function.zeros((0, *func.shape), func.dtype)

Expand Down Expand Up @@ -569,8 +569,8 @@ def take_elements(self, __indices: numpy.ndarray) -> Sample:
def integral(self, func: function.IntoArray) -> function.Array:
return self._sample1.integral(func) + self._sample2.integral(func)

def __rmatmul__(self, func: function.IntoArray) -> function.Array:
return function.concatenate([func @ self._sample1, func @ self._sample2])
def __call__(self, func: function.IntoArray) -> function.Array:
return function.concatenate([self._sample1(func), self._sample2(func)])

class _Mul(_TensorialSample):

Expand Down Expand Up @@ -650,8 +650,8 @@ def hull(self) -> numpy.ndarray:
def integral(self, func: function.IntoArray) -> function.Array:
return self._sample1.integral(self._sample2.integral(func))

def __rmatmul__(self, func: function.IntoArray) -> function.Array:
return function.ravel(func @ self._sample2 @ self._sample1, axis=0)
def __call__(self, func: function.IntoArray) -> function.Array:
return function.ravel(self._sample1(self._sample2(func)), axis=0)

def basis(self) -> Sample:
basis1 = self._sample1.basis()
Expand Down
8 changes: 4 additions & 4 deletions tests/test_expression_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@ def test_invalid_default_geometry_no_variable(self):
def assertEqualLowered(self, actual, desired, *, topo=None):
if topo:
smpl = topo.sample('gauss', 2)
lower = lambda f: evaluable.asarray(f @ smpl)
lower = lambda f: evaluable.asarray(smpl(f))
else:
lower = evaluable.asarray
return self.assertEqual(lower(actual), lower(desired))
Expand Down Expand Up @@ -1134,13 +1134,13 @@ def test_builtin_functions(self):
def test_builtin_jacobian_vector(self):
ns = expression_v1.Namespace()
domain, ns.x = mesh.rectilinear([1]*2)
l = lambda f: evaluable.asarray(f @ domain.sample('gauss', 2)).simplified
l = lambda f: evaluable.asarray(domain.sample('gauss', 2)(f)).simplified
self.assertEqual(l(ns.eval_('J(x)')), l(function.jacobian(ns.x)))

def test_builtin_jacobian_scalar(self):
ns = expression_v1.Namespace()
domain, (ns.t,) = mesh.rectilinear([1])
l = lambda f: evaluable.asarray(f @ domain.sample('gauss', 2)).simplified
l = lambda f: evaluable.asarray(domain.sample('gauss', 2)(f)).simplified
self.assertEqual(l(ns.eval_('J(t)')), l(function.jacobian(ns.t[None])))

def test_builtin_jacobian_matrix(self):
Expand Down Expand Up @@ -1173,7 +1173,7 @@ def assertEqualLowered(self, s, f, *, topo=None, indices=None):
if topo is None:
topo = self.domain
smpl = topo.sample('gauss', 2)
lower = lambda g: evaluable.asarray(g @ smpl).simplified
lower = lambda g: evaluable.asarray(smpl(g)).simplified
if indices:
evaluated = getattr(self.ns, 'eval_'+indices)(s)
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_expression_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def test_define_for_3d(self):
ns.ε = function.levicivita(3)
ns.f = function.Array.cast([['x', '-z', 'y'], ['0', 'x z', '0']] @ ns)
smpl = topo.sample('gauss', 5)
assertEvalAlmostEqual = lambda *args: self.assertAllAlmostEqual(*((f @ smpl).as_evaluable_array.eval() for f in args))
assertEvalAlmostEqual = lambda *args: self.assertAllAlmostEqual(*(smpl(f).as_evaluable_array.eval() for f in args))
assertEvalAlmostEqual('curl_ij(y δ_j0 - x δ_j1 + z δ_j2)' @ ns, '-2 δ_i2' @ ns)
assertEvalAlmostEqual('curl_ij(-x^2 δ_j1)' @ ns, '-2 x δ_i2' @ ns)
assertEvalAlmostEqual('curl_ij((x δ_j0 - z δ_j1 + y δ_j2) δ_k0 + x z δ_j1 δ_k1)' @ ns, '2 δ_i0 δ_k0 - x δ_i0 δ_k1 + z δ_i2 δ_k1' @ ns)
Expand Down
16 changes: 8 additions & 8 deletions tests/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,16 @@ def test_take_elements_empty(self):
self.assertEqual(take.npoints, 0)

def test_ones_at(self):
self.assertEqual((function.ones((), int) @ self.sample).eval().tolist(), [1]*self.desired_npoints)
self.assertEqual(self.sample(function.ones((), int)).eval().tolist(), [1]*self.desired_npoints)

def test_at_in_integral(self):
topo, geom = mesh.line(2, space='parent-integral')
actual = topo.integral(function.jacobian(geom) @ self.sample, degree=0)
actual = topo.integral(self.sample(function.jacobian(geom)), degree=0)
self.assertEqual(actual.eval().round(5).tolist(), [2]*self.desired_npoints)

def test_asfunction(self):
func = self.sample.asfunction(numpy.arange(self.sample.npoints))
self.assertEqual((func @ self.sample).eval().tolist(), numpy.arange(self.desired_npoints).tolist())
self.assertEqual(self.sample(func).eval().tolist(), numpy.arange(self.desired_npoints).tolist())

class Empty(TestCase, Common):

Expand Down Expand Up @@ -298,14 +298,14 @@ def setUp(self):

def test_at(self):
self.geom = function.rootcoords('a', 2) + numpy.array([0,2]) * function.transforms_index('a', self.transforms)
actual = (self.geom @ self.sample).as_evaluable_array.eval()
actual = self.sample(self.geom).as_evaluable_array.eval()
desired = numpy.array([[0,0],[0,1],[1,0],[1,1],[0,2],[1,2],[0,3],[0,4],[0,5],[1,4],[1,5]])
self.assertAllAlmostEqual(actual, desired)

def test_basis(self):
with _builtin_warnings.catch_warnings():
_builtin_warnings.simplefilter('ignore', category=evaluable.ExpensiveEvaluationWarning)
self.assertAllAlmostEqual((self.sample.basis() @ self.sample).as_evaluable_array.eval(), numpy.eye(11))
self.assertAllAlmostEqual(self.sample(self.sample.basis()).as_evaluable_array.eval(), numpy.eye(11))

class CustomIndex(TestCase, Common):

Expand All @@ -325,15 +325,15 @@ def setUp(self):

def test_at(self):
self.geom = function.rootcoords('a', 2) + numpy.array([0,2]) * function.transforms_index('a', self.transforms)
actual = (self.geom @ self.sample).as_evaluable_array.eval()
actual = self.sample(self.geom).as_evaluable_array.eval()
desired = numpy.array([[0,0],[0,1],[1,0],[1,1],[0,2],[1,2],[0,3],[0,4],[0,5],[1,4],[1,5]])
desired = numpy.take(desired, numpy.argsort(numpy.concatenate(self.desired_indices), axis=0), axis=0)
self.assertAllAlmostEqual(actual, desired)

def test_basis(self):
with _builtin_warnings.catch_warnings():
_builtin_warnings.simplefilter('ignore', category=evaluable.ExpensiveEvaluationWarning)
self.assertAllAlmostEqual((self.sample.basis() @ self.sample).as_evaluable_array.eval(), numpy.eye(11))
self.assertAllAlmostEqual(self.sample(self.sample.basis()).as_evaluable_array.eval(), numpy.eye(11))

class Special(TestCase):

Expand Down Expand Up @@ -391,7 +391,7 @@ def test_asfunction(self):
self.bezier2.eval(sampled)
self.assertAllEqual(self.gauss2.eval(sampled), values)
arg = function.Argument('dofs', [2,3])
self.assertTrue(evaluable.iszero(evaluable.asarray(function.derivative(sampled, arg) @ self.gauss2)))
self.assertTrue(evaluable.iszero(evaluable.asarray(self.gauss2(function.derivative(sampled, arg)))))

class integral(TestCase):

Expand Down

0 comments on commit 90b081f

Please sign in to comment.