Skip to content

Commit

Permalink
Eig
Browse files Browse the repository at this point in the history
  • Loading branch information
soraros committed Jan 19, 2021
1 parent d1c1ef2 commit 27d4d95
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2343,27 +2343,32 @@ def evalf(self, index):

class Eig(Evaluable):

__slots__ = 'symmetric', 'func'
__slots__ = 'symmetric', 'func', '_w_dtype', '_vt_dtype'

@types.apply_annotations
def __init__(self, func:asarray, symmetric:bool=False):
assert func.ndim >= 2 and func.shape[-1] == func.shape[-2]
self.symmetric = symmetric
self.func = func
self._w_dtype = float if symmetric else complex
self._vt_dtype = _jointdtype(float, func.dtype if symmetric else complex)
super().__init__(args=[func])

def __len__(self):
return 2

def __iter__(self):
yield ArrayFromTuple(self, index=0, shape=self.func.shape[:-1], dtype=complex if not self.symmetric else float)
yield ArrayFromTuple(self, index=1, shape=self.func.shape, dtype=complex if not self.symmetric or self.func.dtype == complex else float)
yield ArrayFromTuple(self, index=0, shape=self.func.shape[:-1], dtype=self._w_dtype)
yield ArrayFromTuple(self, index=1, shape=self.func.shape, dtype=self._vt_dtype)

def _simplified(self):
return self.func._eig(self.symmetric)

def evalf(self, arr):
return (numpy.linalg.eigh if self.symmetric else numpy.linalg.eig)(arr)
w, vt = (numpy.linalg.eigh if self.symmetric else numpy.linalg.eig)(arr)
w = w.astype(self._w_dtype, copy=False)
vt = vt.astype(self._vt_dtype, copy=False)
return (w, vt)

class ArrayFromTuple(Array):

Expand Down

0 comments on commit 27d4d95

Please sign in to comment.