Skip to content

Commit

Permalink
support complex arguments in function.linearize
Browse files Browse the repository at this point in the history
This patch adds support for linearizing functions with complex-valued
arguments by instantiating `function.Argument`s with the correct dtype
in `function.linearize`.
  • Loading branch information
joostvanzwieten committed Sep 5, 2023
1 parent e040cd7 commit 4c9eb72
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
2 changes: 1 addition & 1 deletion nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1728,7 +1728,7 @@ def linearize(__array: IntoArray, __arguments: Union[str, Dict[str, str], Iterab
for kv in args:
k, v = kv.split(':', 1) if isinstance(kv, str) else kv
f = derivative(array, k)
parts.append(numpy.sum(f * Argument(v, f.shape[array.ndim:]), tuple(range(array.ndim, f.ndim))))
parts.append(numpy.sum(f * Argument(v, *array.arguments[k]), tuple(range(array.ndim, f.ndim))))
return util.sum(parts)


Expand Down
9 changes: 9 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1377,6 +1377,15 @@ def test(self):
_q = 5.
self.assertAllEqual(f.eval(u=_u, v=_v, q=_q).export('dense'), 3 * _u**2 * _v + _q)

def test_complex(self):
f = function.linearize(function.Argument('u', shape=(3, 4), dtype=complex)**3
+ function.Argument('p', shape=(), dtype=complex), 'u:v,p:q')
# test linearization of u**3 + p -> 3 u**2 v + q through evaluation
_u = numpy.array([1+2j, 3+4j, 5+6j])[:,numpy.newaxis].repeat(4, 1)
_v = numpy.array([5+1j, 6+2j, 7+3j, 8+4j])[numpy.newaxis,:].repeat(3, 0)
_q = 5.
self.assertAllEqual(f.eval(u=_u, v=_v, q=_q).export('dense'), 3 * _u**2 * _v + _q)


class attributes(TestCase):

Expand Down

0 comments on commit 4c9eb72

Please sign in to comment.