diff --git a/nutils/evaluable.py b/nutils/evaluable.py index 1018f0c68..e124194ed 100644 --- a/nutils/evaluable.py +++ b/nutils/evaluable.py @@ -757,10 +757,26 @@ def _assparse(self): return chunks return _assparse + class _evalf_checker: + def __init__(self, orig): + self.evalf_obj = getattr(orig, '__get__', lambda *args: orig) + def __get__(self, instance, owner): + evalf = self.evalf_obj(instance, owner) + @functools.wraps(evalf) + def evalf_with_check(*args, **kwargs): + res = evalf(*args, **kwargs) + assert not hasattr(instance, 'dtype') or asdtype(res.dtype) == instance.dtype, ((instance.dtype, res.dtype), instance, res) + assert not hasattr(instance, 'ndim') or res.ndim == instance.ndim + assert not hasattr(instance, 'shape') or all(m == n for m, n in zip(res.shape, instance.shape) if isinstance(n, int)), 'shape mismatch' + return res + return evalf_with_check + class _ArrayMeta(type(Evaluable)): def __new__(mcls, name, bases, namespace): if '_assparse' in namespace: namespace['_assparse'] = _chunked_assparse_checker(namespace['_assparse']) + if 'evalf' in namespace: + namespace['evalf'] = _evalf_checker(namespace['evalf']) return super().__new__(mcls, name, bases, namespace) else: