Skip to content

Commit

Permalink
Replace (#864)
Browse files Browse the repository at this point in the history
  • Loading branch information
Gertjan van Zwieten committed Mar 22, 2024
2 parents 9121e22 + 052536d commit bebe0d7
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 42 deletions.
2 changes: 1 addition & 1 deletion nutils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
'Numerical Utilities for Finite Element Analysis'

__version__ = version = '9a22'
__version__ = version = '9a23'
version_name = 'jook-sing'
162 changes: 124 additions & 38 deletions nutils/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,95 @@ def wrapper(*args, **kwargs):
return wrapper


class IDSetView:

def __init__(self, init=()):
self._dict = init._dict if isinstance(init, IDSetView) else {id(obj): obj for obj in init}

def __len__(self):
return len(self._dict)

def __bool__(self):
return bool(self._dict)

def __iter__(self):
return iter(self._dict.values())

def __and__(self, other):
return self.copy().__iand__(other)

def __or__(self, other):
return self.copy().__ior__(other)

def __sub__(self, other):
return self.copy().__isub__(other)

def isdisjoint(self, other):
return self._dict.isdisjoint(IDSetView(other))

def intersection(self, other):
return self.__and__(IDSetView(other))

def difference(self, other):
return self.__sub__(IDSetView(other))

def union(self, other):
return self.__or__(IDSetView(other))

def __repr__(self):
return '{' + ', '.join(map(repr, self)) + '}'

def copy(self):
return IDSet(self)


class IDSet(IDSetView):

def __init__(self, init=()):
self._dict = init._dict.copy() if isinstance(init, IDSetView) else {id(obj): obj for obj in init}

def __iand__(self, other):
if not isinstance(other, IDSetView):
return NotImplemented
if not other._dict:
self._dict.clear()
elif self._dict:
for k in set(self._dict) - set(other._dict):
del self._dict[k]
return self

def __ior__(self, other):
if not isinstance(other, IDSetView):
return NotImplemented
self._dict.update(other._dict)
return self

def __isub__(self, other):
if not isinstance(other, IDSetView):
return NotImplemented
for k in other._dict:
self._dict.pop(k, None)
return self

def add(self, obj):
self._dict[id(obj)] = obj

def pop(self):
return self._dict.popitem()[1]

def intersection_update(self, other):
self.__iand__(IDSetView(other))

def difference_update(self, other):
self.__isub__(IDSetView(other))

def update(self, other):
self.__ior__(IDSetView(other))

def view(self):
return IDSetView(self)


class IDDict:
'''Mapping from instance (is, not ==) to value. Keys need not be hashable.'''

Expand Down Expand Up @@ -821,11 +910,8 @@ def __iter__(self):
def __contains__(self, key):
return self.__dict.__contains__(id(key))

def __str__(self):
return '{' + ', '.join(f'{k!r}: {v!r}' for k, v in self.items()) + '}'

def __repr__(self):
return self.__str__()
return '{' + ', '.join(f'{k!r}: {v!r}' for k, v in self.items()) + '}'


def _tuple(*args):
Expand Down Expand Up @@ -865,7 +951,7 @@ class deep_replace_property:
Args
----
func
Callable which maps an object onto a new object, or ``None`` if no
Callable which maps an object onto a new object, or onto itself if no
replacement is made. It must have precisely one positional argument for
the object.
'''
Expand All @@ -889,15 +975,15 @@ def __delete__(self, obj):
def __get__(self, obj, objtype=None):
fstack = [obj] # stack of unprocessed objects and command tokens
rstack = [] # stack of processed objects
ostack = [] # stack of original objects to cache new value into
ostack = IDSet() # stack of original objects to cache new value into

while fstack:
obj = fstack.pop()

if isinstance(obj, self.recreate): # recreate object from rstack
f, nargs = obj
r = f(*[rstack.pop() for _ in range(nargs)])
if isinstance(r, self.owner) and (newr := self.func(r)) is not None:
if isinstance(r, self.owner) and (newr := self.func(r)) is not r:
fstack.append(newr) # recursion
else:
rstack.append(r)
Expand All @@ -913,10 +999,9 @@ def __get__(self, obj, objtype=None):
if (r := obj.__dict__.get(self.name)) is not None: # in cache
rstack.append(r if r is not self.identity else obj)
elif obj in ostack:
index = ostack.index(obj)
raise Exception(f'{type(obj).__name__}.{self.name} is caught in a loop of size {len(ostack)-index}')
raise Exception(f'{type(obj).__name__}.{self.name} is caught in a loop')
else:
ostack.append(obj)
ostack.add(obj)
fstack.append(ostack)
f, args = obj.__reduce__()
fstack.append(self.recreate(f, len(args)))
Expand All @@ -935,7 +1020,7 @@ def __get__(self, obj, objtype=None):
return rstack[0]


def shallow_replace(func):
def shallow_replace(func, *funcargs, **funckwargs):
'''decorator for deep object replacement
Generates a deep replacement method for reduceable objects based on a
Expand All @@ -958,42 +1043,43 @@ def shallow_replace(func):
The method that searches the object to perform the replacements.
'''

recreate = collections.namedtuple('recreate', ['f', 'nargs', 'orig'])
if not funcargs and not funckwargs: # decorator
# it would be nice to use partial here but then the decorator doesn't work with methods
return functools.wraps(func)(lambda *args, **kwargs: shallow_replace(func, *args, **kwargs))

@functools.wraps(func)
def wrapped(target, *funcargs, **funckwargs):
fstack = [target] # stack of unprocessed objects and command tokens
rstack = [] # stack of processed objects
cache = IDDict() # cache of seen objects
target, *funcargs = funcargs
recreate = collections.namedtuple('recreate', ['f', 'nargs', 'orig'])

while fstack:
obj = fstack.pop()
fstack = [target] # stack of unprocessed objects and command tokens
rstack = [] # stack of processed objects
cache = IDDict() # cache of seen objects

if isinstance(obj, recreate):
f, nargs, orig = obj
r = f(*[rstack.pop() for _ in range(nargs)])
cache[orig] = r
rstack.append(r)
while fstack:
obj = fstack.pop()

elif (r := cache.get(obj)) is not None:
rstack.append(r)
if isinstance(obj, recreate):
f, nargs, orig = obj
r = f(*[rstack.pop() for _ in range(nargs)])
cache[orig] = r
rstack.append(r)

elif (r := func(obj, *funcargs, **funckwargs)) is not None:
cache[obj] = r
rstack.append(r)
elif (r := cache.get(obj)) is not None:
rstack.append(r)

elif reduced := _reduce(obj):
f, args = reduced
fstack.append(recreate(f, len(args), obj))
fstack.extend(args)
elif (r := func(obj, *funcargs, **funckwargs)) is not None:
cache[obj] = r
rstack.append(r)

else: # obj cannot be reduced
rstack.append(obj)
elif reduced := _reduce(obj):
f, args = reduced
fstack.append(recreate(f, len(args), obj))
fstack.extend(args)

assert len(rstack) == 1
return rstack[0]
else: # obj cannot be reduced
rstack.append(obj)

return wrapped
assert len(rstack) == 1
return rstack[0]


# vim:sw=4:sts=4:et
8 changes: 6 additions & 2 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,9 @@ def _format_stack(self, values, e):
@util.deep_replace_property
def simplified(obj):
retval = obj._simplified()
if retval is not None and isinstance(obj, Array):
if retval is None:
return obj
if isinstance(obj, Array):
assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape) and retval.dtype == obj.dtype, '{} --simplify--> {}'.format(obj, retval)
return retval

Expand All @@ -310,7 +312,9 @@ def optimized_for_numpy(self):
@util.deep_replace_property
def _optimized_for_numpy1(obj):
retval = obj._simplified() or obj._optimized_for_numpy()
if retval is not None and isinstance(obj, Array):
if retval is None:
return obj
if isinstance(obj, Array):
assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape), '{0}._optimized_for_numpy or {0}._simplified resulted in shape change'.format(type(obj).__name__)
return retval

Expand Down
89 changes: 88 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,81 @@ def test_repr(self):
self.assertEqual(repr(self.d), "{'a': 1, 'b': 2}")


class IDSet(TestCase):

def setUp(self):
self.a, self.b, self.c = 'abc'
self.ab = util.IDSet([self.a, self.b])
self.ac = util.IDSet([self.a, self.c])

def test_union(self):
union = self.ab | self.ac
self.assertEqual(list(union), ['a', 'b', 'c'])
union = self.ac.union([self.a, self.b])
self.assertEqual(list(union), ['a', 'c', 'b'])

def test_union_update(self):
self.ab |= self.ac
self.assertEqual(list(self.ab), ['a', 'b', 'c'])
self.ac.update([self.a, self.b])
self.assertEqual(list(self.ac), ['a', 'c', 'b'])

def test_intersection(self):
intersection = self.ab & self.ac
self.assertEqual(list(intersection), ['a'])
intersection = self.ab.intersection([self.a, self.c])
self.assertEqual(list(intersection), ['a'])

def test_intersection_update(self):
self.ab &= self.ac
self.assertEqual(list(self.ab), ['a'])
self.ac.intersection_update([self.a, self.b])
self.assertEqual(list(self.ac), ['a'])

def test_difference(self):
difference = self.ab - self.ac
self.assertEqual(list(difference), ['b'])
difference = self.ac - self.ab
self.assertEqual(list(difference), ['c'])

def test_difference_update(self):
self.ab -= self.ac
self.assertEqual(list(self.ab), ['b'])
self.ac.difference_update([self.a, self.b])
self.assertEqual(list(self.ac), ['c'])

def test_add(self):
self.ab.add(self.a)
self.assertEqual(list(self.ab), ['a', 'b'])
self.ab.add(self.c)
self.assertEqual(list(self.ab), ['a', 'b', 'c'])
self.ac.add(self.b)
self.assertEqual(list(self.ac), ['a', 'c', 'b'])

def test_pop(self):
self.assertEqual(self.ab.pop(), 'b')
self.assertEqual(list(self.ab), ['a'])

def test_copy(self):
copy = self.ab.copy()
self.ab.pop()
self.assertEqual(list(self.ab), ['a'])
self.assertEqual(list(copy), ['a', 'b'])

def test_view(self):
view = self.ab.view()
self.ab.pop()
self.assertEqual(list(view), ['a'])
with self.assertRaises(AttributeError):
view.pop()

def test_str(self):
self.assertEqual(str(self.ab), "{'a', 'b'}")

def test_repr(self):
self.assertEqual(repr(self.ab), "{'a', 'b'}")


class replace(TestCase):

class Base:
Expand All @@ -465,8 +540,10 @@ def simple(self):
self.called = True
if isinstance(self, replace.Ten):
return replace.Intermediate() # to test recursion
if isinstance(self, replace.Intermediate):
elif isinstance(self, replace.Intermediate):
return 10
else:
return self

class Ten(Base): pass
class Intermediate(Base): pass
Expand Down Expand Up @@ -501,3 +578,13 @@ def test_shallow_nested(self):
newobj = self.subs10(obj, 20)
self.assertEqual(type(newobj), type(obj))
self.assertEqual(newobj.args, (5, {7, 20}))

def test_shallow_direct(self):
ten = self.Ten()
obj = self.Base(5, {7, ten})
def subs(arg):
if isinstance(arg, self.Ten):
return 20
newobj = util.shallow_replace(subs, obj)
self.assertEqual(type(newobj), type(obj))
self.assertEqual(newobj.args, (5, {7, 20}))

0 comments on commit bebe0d7

Please sign in to comment.