diff --git a/nutils/evaluable.py b/nutils/evaluable.py index d03b13e09..8575709c3 100644 --- a/nutils/evaluable.py +++ b/nutils/evaluable.py @@ -850,9 +850,6 @@ def _node(self, cache, subgraph, times): cache[self] = node = RegularNode(label, args, {}, (type(self).__name__, times[self]), subgraph) return node - # simplifications - _eig = lambda self, symmetric: None - representations = () def _as(self, op, condition=None): @@ -1033,13 +1030,6 @@ def _node(self, cache, subgraph, times): def _isunit(self): return numpy.equal(self.value, 1).all() - def _eig(self, symmetric): - eigval, eigvec = (numpy.linalg.eigh if symmetric else numpy.linalg.eig)(self.value) - if not symmetric: - eigval = eigval.astype(complex, copy=False) - eigvec = eigvec.astype(complex, copy=False) - return Tuple((constant(eigval), constant(eigvec))) - def _intbounds_impl(self): if self.dtype == int and self.value.size: return int(self.value.min()), int(self.value.max()) @@ -1392,10 +1382,6 @@ def _simplified(self): def _derivative(self, var, seen): return -einsum('Aij,AjkB,Akl->AilB', self, derivative(self.func, var, seen), self) - def _eig(self, symmetric): - eigval, eigvec = Eig(self.func, symmetric) - return Tuple((reciprocal(eigval), eigvec)) - class Determinant(Array): @@ -2793,7 +2779,16 @@ def __getitem__(self, index): return ArrayFromTuple(self, index=index, shape=shape, dtype=dtype) def _simplified(self): - return self.func._eig(self.symmetric) + for v, in self.func._as(constant): + eigval, eigvec = (numpy.linalg.eigh if self.symmetric else numpy.linalg.eig)(v) + if not self.symmetric: + eigval = eigval.astype(complex, copy=False) + eigvec = eigvec.astype(complex, copy=False) + return Tuple((constant(eigval), constant(eigvec))) + for f, *axes in self.func._as(inverse): + if min(axes) == f.ndim-2: + eigval, eigvec = Eig(f, self.symmetric) + return Tuple((reciprocal(eigval), eigvec)) def evalf(self, arr): w, vt = (numpy.linalg.eigh if self.symmetric else numpy.linalg.eig)(arr)