Skip to content

Commit

Permalink
replace Evaluable._node_tuple with TupleNode
Browse files Browse the repository at this point in the history
  • Loading branch information
joostvanzwieten committed Mar 8, 2024
1 parent 259376f commit d2cb558
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from . import debug_flags, _util as util, types, numeric, cache, warnings, parallel, sparse
from functools import cached_property
from ._graph import Node, RegularNode, DuplicatedLeafNode, InvisibleNode, Subgraph
from ._graph import Node, RegularNode, DuplicatedLeafNode, InvisibleNode, Subgraph, TupleNode
import nutils_poly as poly
import numpy
import sys
Expand Down Expand Up @@ -562,6 +562,12 @@ def __radd__(self, other):

return Tuple(tuple(other) + self.items)

def _node(self, cache, subgraph, times):
if (cached := cache.get(self)) is not None:
return cached
cache[self] = node = TupleNode(tuple(item._node(cache, subgraph, times) for item in self.items), (type(self).__name__, times[self]), subgraph=subgraph)
return node


# ARRAYFUNC
#
Expand Down Expand Up @@ -2849,11 +2855,11 @@ def evalf(self, arrays):
def _node(self, cache, subgraph, times):
if self in cache:
return cache[self]
elif hasattr(self.arrays, '_node_tuple'):
cache[self] = node = self.arrays._node_tuple(cache, subgraph, times)[self.index]
return node
else:
return super()._node(cache, subgraph, times)
node = self.arrays._node(cache, subgraph, times)
if isinstance(node, TupleNode):
node = node.items[self.index]
cache[self] = node
return node

def _intbounds_impl(self):
return self._lower, self._upper
Expand Down Expand Up @@ -4431,7 +4437,7 @@ def _node(self, cache, subgraph, times):
if self in cache:
return cache[self]
else:
cache[self] = node = self._lcc._node_tuple(cache, subgraph, times)[0]
cache[self] = node = self._lcc._node(cache, subgraph, times)[0]
return node

def _simplified(self):
Expand Down Expand Up @@ -4547,7 +4553,7 @@ def evalf_withtimes(self, times, shapes, length, *args):
result[..., start:stop] = block
return tuple(results)

def _node_tuple(self, cache, subgraph, times):
def _node(self, cache, subgraph, times):
if (self, 'tuple') in cache:
return cache[self, 'tuple']
subcache = {}
Expand All @@ -4562,8 +4568,8 @@ def _node_tuple(self, cache, subgraph, times):
concat_kwargs['stop'] = stop._node(subcache, loopgraph, subtimes)
concat_kwargs['func'] = func._node(subcache, loopgraph, subtimes)
concats.append(RegularNode('LoopConcatenate', (), concat_kwargs, (type(self).__name__, subtimes['concat', func]), loopgraph))
cache[self, 'tuple'] = concats = tuple(concats)
return concats
cache[self, 'tuple'] = node = TupleNode(tuple(concats), (type(self).__name__, times[self]), subgraph)
return node


class SearchSorted(Array):
Expand Down

0 comments on commit d2cb558

Please sign in to comment.