Skip to content

Commit

Permalink
complex features (#832)
Browse files Browse the repository at this point in the history
This PR adds support for

* calling `numpy.{real,imag,conjugate}` with a `Quantity` and
* linearizing a function with complex-valued arguments.
  • Loading branch information
joostvanzwieten committed Sep 6, 2023
2 parents f28fb33 + 4c9eb72 commit 288080e
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 3 deletions.
2 changes: 1 addition & 1 deletion nutils/SI.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def register(*names, __table=__DISPATCH_TABLE):
'trace', 'ptp', 'amax', 'amin', 'max', 'min', 'mean', 'take',
'broadcast_to', 'transpose', 'getitem', 'opposite', 'jump',
'replace_arguments', 'linearize', 'derivative', 'integral',
'sample', 'scatter', 'kronecker')
'sample', 'scatter', 'kronecker', 'real', 'imag', 'conjugate')
def __unary(op, *args, **kwargs):
(dim0, arg0), = Quantity.__unpack(args[0])
return dim0.wrap(op(arg0, *args[1:], **kwargs))
Expand Down
2 changes: 1 addition & 1 deletion nutils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
'Numerical Utilities for Finite Element Analysis'

__version__ = version = '9a3'
__version__ = version = '9a4'
version_name = 'jook-sing'
2 changes: 1 addition & 1 deletion nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1728,7 +1728,7 @@ def linearize(__array: IntoArray, __arguments: Union[str, Dict[str, str], Iterab
for kv in args:
k, v = kv.split(':', 1) if isinstance(kv, str) else kv
f = derivative(array, k)
parts.append(numpy.sum(f * Argument(v, f.shape[array.ndim:]), tuple(range(array.ndim, f.ndim))))
parts.append(numpy.sum(f * Argument(v, *array.arguments[k]), tuple(range(array.ndim, f.ndim))))
return util.sum(parts)


Expand Down
9 changes: 9 additions & 0 deletions tests/test_SI.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ def test_pos(self):
def test_abs(self):
self.assertEqual(numpy.abs(SI.Mass('-2kg')), SI.Mass('2kg'))

def test_real(self):
self.assertEqual(numpy.real(SI.ElectricPotential('1V') + 1j * SI.ElectricPotential('2V')), SI.ElectricPotential('1V'))

def test_imag(self):
self.assertEqual(numpy.imag(SI.ElectricPotential('1V') + 1j * SI.ElectricPotential('2V')), SI.ElectricPotential('2V'))

def test_conjugate(self):
self.assertEqual(numpy.conjugate(SI.ElectricPotential('1V') + 1j * SI.ElectricPotential('2V')), SI.ElectricPotential('1V') - 1j * SI.ElectricPotential('2V'))

def test_sqrt(self):
self.assertEqual(numpy.sqrt(SI.Area('4m2')), SI.Length('2m'))

Expand Down
9 changes: 9 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1377,6 +1377,15 @@ def test(self):
_q = 5.
self.assertAllEqual(f.eval(u=_u, v=_v, q=_q).export('dense'), 3 * _u**2 * _v + _q)

def test_complex(self):
f = function.linearize(function.Argument('u', shape=(3, 4), dtype=complex)**3
+ function.Argument('p', shape=(), dtype=complex), 'u:v,p:q')
# test linearization of u**3 + p -> 3 u**2 v + q through evaluation
_u = numpy.array([1+2j, 3+4j, 5+6j])[:,numpy.newaxis].repeat(4, 1)
_v = numpy.array([5+1j, 6+2j, 7+3j, 8+4j])[numpy.newaxis,:].repeat(3, 0)
_q = 5.
self.assertAllEqual(f.eval(u=_u, v=_v, q=_q).export('dense'), 3 * _u**2 * _v + _q)


class attributes(TestCase):

Expand Down

0 comments on commit 288080e

Please sign in to comment.