diff --git a/nutils/_util.py b/nutils/_util.py index a3658502e..0854e6a77 100644 --- a/nutils/_util.py +++ b/nutils/_util.py @@ -951,7 +951,7 @@ class deep_replace_property: Args ---- func - Callable which maps an object onto a new object, or ``None`` if no + Callable which maps an object onto a new object, or onto itself if no replacement is made. It must have precisely one positional argument for the object. ''' @@ -983,7 +983,7 @@ def __get__(self, obj, objtype=None): if isinstance(obj, self.recreate): # recreate object from rstack f, nargs = obj r = f(*[rstack.pop() for _ in range(nargs)]) - if isinstance(r, self.owner) and (newr := self.func(r)) is not None: + if isinstance(r, self.owner) and (newr := self.func(r)) is not r: fstack.append(newr) # recursion else: rstack.append(r) diff --git a/nutils/evaluable.py b/nutils/evaluable.py index b3a417ec0..5093bcaba 100644 --- a/nutils/evaluable.py +++ b/nutils/evaluable.py @@ -293,7 +293,9 @@ def _format_stack(self, values, e): @util.deep_replace_property def simplified(obj): retval = obj._simplified() - if retval is not None and isinstance(obj, Array): + if retval is None: + return obj + if isinstance(obj, Array): assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape) and retval.dtype == obj.dtype, '{} --simplify--> {}'.format(obj, retval) return retval @@ -310,7 +312,9 @@ def optimized_for_numpy(self): @util.deep_replace_property def _optimized_for_numpy1(obj): retval = obj._simplified() or obj._optimized_for_numpy() - if retval is not None and isinstance(obj, Array): + if retval is None: + return obj + if isinstance(obj, Array): assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape), '{0}._optimized_for_numpy or {0}._simplified resulted in shape change'.format(type(obj).__name__) return retval diff --git a/tests/test_util.py b/tests/test_util.py index ec05d7cbe..07927f3e3 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -540,8 +540,10 @@ def simple(self): self.called = True if isinstance(self, replace.Ten): return replace.Intermediate() # to test recursion - if isinstance(self, replace.Intermediate): + elif isinstance(self, replace.Intermediate): return 10 + else: + return self class Ten(Base): pass class Intermediate(Base): pass