Skip to content

Commit

Permalink
repl const Arrays with Consts in optimize (#854)
Browse files Browse the repository at this point in the history
To speed-up repeated evaluation of the same `Evaluable`, this PR introduces
`Evaluable._deep_flatten_constants`, which replaces constant `Array`s with
`Constant`s, excluding the no-op arrays `Transpose` and `InsertAxis`, and
includes this in `optimized_for_numpy`.
  • Loading branch information
joostvanzwieten committed Feb 27, 2024
2 parents c0e6580 + 733bd99 commit 99349c0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
2 changes: 1 addition & 1 deletion nutils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
'Numerical Utilities for Finite Element Analysis'

__version__ = version = '9a15'
__version__ = version = '9a16'
version_name = 'jook-sing'
24 changes: 19 additions & 5 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ def _simplified(self):
@cached_property
def optimized_for_numpy(self):
retval = self.simplified._optimized_for_numpy1() or self
retval = retval._deep_flatten_constants() or retval
return retval._combine_loop_concatenates(frozenset())

@replace(depthfirst=True, recursive=True)
Expand All @@ -462,6 +463,11 @@ def _optimized_for_numpy1(obj):
def _optimized_for_numpy(self):
return

@replace(depthfirst=False, recursive=False)
def _deep_flatten_constants(self):
if isinstance(self, Array):
return self._flatten_constant()

@cached_property
def _loop_concatenate_deps(self):
deps = []
Expand Down Expand Up @@ -997,6 +1003,10 @@ def _const_uniform(self):
lower, upper = self._intbounds
return lower if lower == upper else None

def _flatten_constant(self):
if self.isconstant:
return constant(self.eval())


class Orthonormal(Array):
'make a vector orthonormal to a subspace'
Expand Down Expand Up @@ -1177,6 +1187,9 @@ def _const_uniform(self):
if self.ndim == 0:
return self.dtype(self.value[()])

def _flatten_constant(self):
pass


class InsertAxis(Array):

Expand Down Expand Up @@ -1296,6 +1309,9 @@ def _intbounds_impl(self):
def _const_uniform(self):
return self.func._const_uniform

def _flatten_constant(self):
pass


class Transpose(Array):

Expand Down Expand Up @@ -1491,6 +1507,9 @@ def _intbounds_impl(self):
def _const_uniform(self):
return self.func._const_uniform

def _flatten_constant(self):
pass


class Product(Array):

Expand Down Expand Up @@ -2178,11 +2197,6 @@ def _simplified(self):
if len(where) != self.ndim:
return align(self._newargs(*uninserted), where, self.shape)

def _optimized_for_numpy(self):
if self.isconstant:
retval = self.eval()
return constant(retval)

def _derivative(self, var, seen):
if self.dtype == complex or var.dtype == complex:
raise NotImplementedError('The complex derivative is not implemented.')
Expand Down

0 comments on commit 99349c0

Please sign in to comment.