diff --git a/nutils/_util.py b/nutils/_util.py index c40bbb440..a3658502e 100644 --- a/nutils/_util.py +++ b/nutils/_util.py @@ -1020,7 +1020,7 @@ def __get__(self, obj, objtype=None): return rstack[0] -def shallow_replace(func): +def shallow_replace(func, *funcargs, **funckwargs): '''decorator for deep object replacement Generates a deep replacement method for reduceable objects based on a @@ -1043,42 +1043,43 @@ def shallow_replace(func): The method that searches the object to perform the replacements. ''' - recreate = collections.namedtuple('recreate', ['f', 'nargs', 'orig']) + if not funcargs and not funckwargs: # decorator + # it would be nice to use partial here but then the decorator doesn't work with methods + return functools.wraps(func)(lambda *args, **kwargs: shallow_replace(func, *args, **kwargs)) - @functools.wraps(func) - def wrapped(target, *funcargs, **funckwargs): - fstack = [target] # stack of unprocessed objects and command tokens - rstack = [] # stack of processed objects - cache = IDDict() # cache of seen objects + target, *funcargs = funcargs + recreate = collections.namedtuple('recreate', ['f', 'nargs', 'orig']) - while fstack: - obj = fstack.pop() + fstack = [target] # stack of unprocessed objects and command tokens + rstack = [] # stack of processed objects + cache = IDDict() # cache of seen objects - if isinstance(obj, recreate): - f, nargs, orig = obj - r = f(*[rstack.pop() for _ in range(nargs)]) - cache[orig] = r - rstack.append(r) + while fstack: + obj = fstack.pop() - elif (r := cache.get(obj)) is not None: - rstack.append(r) + if isinstance(obj, recreate): + f, nargs, orig = obj + r = f(*[rstack.pop() for _ in range(nargs)]) + cache[orig] = r + rstack.append(r) - elif (r := func(obj, *funcargs, **funckwargs)) is not None: - cache[obj] = r - rstack.append(r) + elif (r := cache.get(obj)) is not None: + rstack.append(r) - elif reduced := _reduce(obj): - f, args = reduced - fstack.append(recreate(f, len(args), obj)) - fstack.extend(args) + elif (r := func(obj, *funcargs, **funckwargs)) is not None: + cache[obj] = r + rstack.append(r) - else: # obj cannot be reduced - rstack.append(obj) + elif reduced := _reduce(obj): + f, args = reduced + fstack.append(recreate(f, len(args), obj)) + fstack.extend(args) - assert len(rstack) == 1 - return rstack[0] + else: # obj cannot be reduced + rstack.append(obj) - return wrapped + assert len(rstack) == 1 + return rstack[0] # vim:sw=4:sts=4:et diff --git a/tests/test_util.py b/tests/test_util.py index db834d82a..ec05d7cbe 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -576,3 +576,13 @@ def test_shallow_nested(self): newobj = self.subs10(obj, 20) self.assertEqual(type(newobj), type(obj)) self.assertEqual(newobj.args, (5, {7, 20})) + + def test_shallow_direct(self): + ten = self.Ten() + obj = self.Base(5, {7, ten}) + def subs(arg): + if isinstance(arg, self.Ten): + return 20 + newobj = util.shallow_replace(subs, obj) + self.assertEqual(type(newobj), type(obj)) + self.assertEqual(newobj.args, (5, {7, 20}))