Skip to content

Commit

Permalink
add
Browse files Browse the repository at this point in the history
  • Loading branch information
gertjanvanzwieten committed Oct 7, 2022
1 parent f04c6a6 commit 79fe647
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions nutils/SI.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,18 +169,22 @@ def __repr__(self):
@staticmethod
def _dispatch(op, *args, **kwargs):
name = op.__name__
if name in ('add', 'sub', 'subtract', 'hypot', 'minimum', 'maximum', 'remainder'):
if name in ('add', 'sub', 'subtract', 'hypot', 'minimum', 'maximum', 'remainder', 'divmod'):
Dim = type(args[0])
if type(args[1]) != Dim:
raise TypeError(f'incompatible arguments for {name}: ' + ', '.join(type(arg).__name__ for arg in args))
elif name in ('mul', 'multiply', 'matmul'):
elif name == 'reciprocal':
Dim = type(args[0])**-1
elif name in ('mul', 'multiply', 'matmul', 'dot'):
Dim = type(args[0]) * type(args[1])
elif name in ('truediv', 'true_divide', 'divide'):
Dim = type(args[0]) / type(args[1])
elif name in ('neg', 'negative', 'pos', 'positive', 'abs', 'absolute', 'sum', 'mean', 'broadcast_to', 'transpose', 'trace', 'take', 'ptp', 'getitem', 'amax', 'amin', 'diff'):
elif name in ('neg', 'negative', 'pos', 'positive', 'abs', 'absolute', 'sum', 'cumsum', 'mean', 'broadcast_to', 'transpose', 'trace', 'take', 'compress', 'ptp', 'getitem', 'amax', 'amin', 'diff', 'reshape', 'ravel', 'repeat', 'swapaxes'):
Dim = type(args[0])
elif name == 'sqrt':
Dim = type(args[0])**fractions.Fraction(1,2)
elif name == 'square':
Dim = type(args[0])**2
elif name == 'setitem':
Dim = type(args[0])
if type(args[2]) != Dim:
Expand Down Expand Up @@ -214,8 +218,11 @@ def _dispatch(op, *args, **kwargs):
retval = op(*(arg.__value if isinstance(arg, Quantity) else arg for arg in args), **kwargs)
except TypeError:
return NotImplemented
else:
return Dim(retval) if Dim != Dimensionless else retval
if Dim == Dimensionless:
return retval
if name == 'divmod':
return retval[0], Dim(retval[1])
return Dim(retval)

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
if method != '__call__':
Expand Down

0 comments on commit 79fe647

Please sign in to comment.