diff --git a/nutils/__init__.py b/nutils/__init__.py index 412cc90fa..4dbefc297 100644 --- a/nutils/__init__.py +++ b/nutils/__init__.py @@ -1,4 +1,4 @@ 'Numerical Utilities for Finite Element Analysis' -__version__ = version = '9a15' +__version__ = version = '9a16' version_name = 'jook-sing' diff --git a/nutils/evaluable.py b/nutils/evaluable.py index 0e9507849..208f75a74 100644 --- a/nutils/evaluable.py +++ b/nutils/evaluable.py @@ -449,6 +449,7 @@ def _simplified(self): @cached_property def optimized_for_numpy(self): retval = self.simplified._optimized_for_numpy1() or self + retval = retval._deep_flatten_constants() or retval return retval._combine_loop_concatenates(frozenset()) @replace(depthfirst=True, recursive=True) @@ -462,6 +463,11 @@ def _optimized_for_numpy1(obj): def _optimized_for_numpy(self): return + @replace(depthfirst=False, recursive=False) + def _deep_flatten_constants(self): + if isinstance(self, Array): + return self._flatten_constant() + @cached_property def _loop_concatenate_deps(self): deps = [] @@ -997,6 +1003,10 @@ def _const_uniform(self): lower, upper = self._intbounds return lower if lower == upper else None + def _flatten_constant(self): + if self.isconstant: + return constant(self.eval()) + class Orthonormal(Array): 'make a vector orthonormal to a subspace' @@ -1177,6 +1187,9 @@ def _const_uniform(self): if self.ndim == 0: return self.dtype(self.value[()]) + def _flatten_constant(self): + pass + class InsertAxis(Array): @@ -1296,6 +1309,9 @@ def _intbounds_impl(self): def _const_uniform(self): return self.func._const_uniform + def _flatten_constant(self): + pass + class Transpose(Array): @@ -1491,6 +1507,9 @@ def _intbounds_impl(self): def _const_uniform(self): return self.func._const_uniform + def _flatten_constant(self): + pass + class Product(Array): @@ -2178,11 +2197,6 @@ def _simplified(self): if len(where) != self.ndim: return align(self._newargs(*uninserted), where, self.shape) - def _optimized_for_numpy(self): - if self.isconstant: - retval = self.eval() - return constant(retval) - def _derivative(self, var, seen): if self.dtype == complex or var.dtype == complex: raise NotImplementedError('The complex derivative is not implemented.')