From 4c9eb7202f30129738e381c18201189c951c6b40 Mon Sep 17 00:00:00 2001 From: Joost van Zwieten Date: Tue, 5 Sep 2023 10:15:13 +0200 Subject: [PATCH] support complex arguments in function.linearize This patch adds support for linearizing functions with complex-valued arguments by instantiating `function.Argument`s with the correct dtype in `function.linearize`. --- nutils/function.py | 2 +- tests/test_function.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/nutils/function.py b/nutils/function.py index 7dc89a5df..97497498b 100644 --- a/nutils/function.py +++ b/nutils/function.py @@ -1728,7 +1728,7 @@ def linearize(__array: IntoArray, __arguments: Union[str, Dict[str, str], Iterab for kv in args: k, v = kv.split(':', 1) if isinstance(kv, str) else kv f = derivative(array, k) - parts.append(numpy.sum(f * Argument(v, f.shape[array.ndim:]), tuple(range(array.ndim, f.ndim)))) + parts.append(numpy.sum(f * Argument(v, *array.arguments[k]), tuple(range(array.ndim, f.ndim)))) return util.sum(parts) diff --git a/tests/test_function.py b/tests/test_function.py index fbc5c61cb..4176d26a0 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -1377,6 +1377,15 @@ def test(self): _q = 5. self.assertAllEqual(f.eval(u=_u, v=_v, q=_q).export('dense'), 3 * _u**2 * _v + _q) + def test_complex(self): + f = function.linearize(function.Argument('u', shape=(3, 4), dtype=complex)**3 + + function.Argument('p', shape=(), dtype=complex), 'u:v,p:q') + # test linearization of u**3 + p -> 3 u**2 v + q through evaluation + _u = numpy.array([1+2j, 3+4j, 5+6j])[:,numpy.newaxis].repeat(4, 1) + _v = numpy.array([5+1j, 6+2j, 7+3j, 8+4j])[numpy.newaxis,:].repeat(3, 0) + _q = 5. + self.assertAllEqual(f.eval(u=_u, v=_v, q=_q).export('dense'), 3 * _u**2 * _v + _q) + class attributes(TestCase):