Skip to content

Commit

Permalink
evalf dtype runtime checker
Browse files Browse the repository at this point in the history
  • Loading branch information
soraros committed Jan 19, 2021
1 parent 111eecd commit 46542f8
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,10 +757,26 @@ def _assparse(self):
return chunks
return _assparse

class _evalf_checker:
def __init__(self, orig):
self.evalf_obj = getattr(orig, '__get__', lambda *args: orig)
def __get__(self, instance, owner):
evalf = self.evalf_obj(instance, owner)
@functools.wraps(evalf)
def evalf_with_check(*args, **kwargs):
res = evalf(*args, **kwargs)
assert not hasattr(instance, 'dtype') or asdtype(res.dtype) == instance.dtype, ((instance.dtype, res.dtype), instance, res)
assert not hasattr(instance, 'ndim') or res.ndim == instance.ndim
assert not hasattr(instance, 'shape') or all(m == n for m, n in zip(res.shape, instance.shape) if isinstance(n, int)), 'shape mismatch'
return res
return evalf_with_check

class _ArrayMeta(type(Evaluable)):
def __new__(mcls, name, bases, namespace):
if '_assparse' in namespace:
namespace['_assparse'] = _chunked_assparse_checker(namespace['_assparse'])
if 'evalf' in namespace:
namespace['evalf'] = _evalf_checker(namespace['evalf'])
return super().__new__(mcls, name, bases, namespace)

else:
Expand Down

0 comments on commit 46542f8

Please sign in to comment.