Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Next #8

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add support for more operations
  • Loading branch information
gertjanvanzwieten committed Nov 18, 2022
commit e30c84c7056603aa52353ac5d379d1dfd21e7ed3
19 changes: 12 additions & 7 deletions nutils/SI.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,25 +172,29 @@ def __repr__(self):
@staticmethod
def _dispatch(op, *args, **kwargs):
name = op.__name__
if name in ('add', 'sub', 'subtract', 'hypot'):
if name in ('add', 'sub', 'subtract', 'hypot', 'minimum', 'maximum', 'remainder', 'divmod'):
Dim = type(args[0])
if type(args[1]) != Dim:
raise TypeError(f'incompatible arguments for {name}: ' + ', '.join(type(arg).__name__ for arg in args))
elif name in ('mul', 'multiply', 'matmul'):
elif name == 'reciprocal':
Dim = type(args[0])**-1
elif name in ('mul', 'multiply', 'matmul', 'dot'):
Dim = type(args[0]) * type(args[1])
elif name in ('truediv', 'true_divide', 'divide'):
Dim = type(args[0]) / type(args[1])
elif name in ('neg', 'negative', 'pos', 'positive', 'abs', 'absolute', 'sum', 'mean', 'broadcast_to', 'transpose', 'trace', 'take', 'ptp', 'getitem', 'amax', 'amin'):
elif name in ('neg', 'negative', 'pos', 'positive', 'abs', 'absolute', 'sum', 'cumsum', 'mean', 'broadcast_to', 'transpose', 'trace', 'take', 'compress', 'ptp', 'getitem', 'amax', 'amin', 'diff', 'reshape', 'ravel', 'repeat', 'swapaxes'):
Dim = type(args[0])
elif name == 'sqrt':
Dim = type(args[0])**fractions.Fraction(1,2)
elif name == 'square':
Dim = type(args[0])**2
elif name == 'setitem':
Dim = type(args[0])
if type(args[2]) != Dim:
raise TypeError(f'cannot assign {type(args[2]).__name__} to {Dim.__name__}')
elif name in ('pow', 'power'):
Dim = type(args[0])**args[1]
elif name in ('lt', 'le', 'eq', 'ne', 'gt', 'ge', 'equal', 'not_equal', 'less', 'less_equal', 'greater', 'greater_equal', 'isfinite', 'isnan'):
elif name in ('lt', 'le', 'eq', 'ne', 'gt', 'ge', 'equal', 'not_equal', 'less', 'less_equal', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan'):
if any(type(q) != type(args[0]) for q in args[1:]):
raise TypeError(f'incompatible arguments for {name}: ' + ', '.join(type(arg).__name__ for arg in args))
Dim = Dimension.from_powers({})
Expand All @@ -200,7 +204,7 @@ def _dispatch(op, *args, **kwargs):
if any(type(q) != Dim for q in stack_args[1:]):
raise TypeError(f'incompatible arguments for {name}: ' + ', '.join(type(arg).__name__ for arg in stack_args))
args = [q.__value for q in stack_args],
elif name in ('shape', 'ndim', 'size'):
elif name in ('shape', 'ndim', 'size', 'sign'):
Dim = Dimension.from_powers({})
else:
return NotImplemented
Expand All @@ -209,8 +213,9 @@ def _dispatch(op, *args, **kwargs):
retval = op(*(arg.__value if isinstance(arg, Quantity) else arg for arg in args), **kwargs)
except TypeError:
return NotImplemented
else:
return Dim.__wrap__(retval)
if name == 'divmod':
return retval[0], Dim.__wrap__(retval[1])
return Dim.__wrap__(retval)

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
if method != '__call__':
Expand Down