Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve dispatch mechanism #10

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ At this point, the dimension is not very useful yet as it lacks units. To
rectify this we define the radian by its abbreviation 'rad' in terms of the
provided reference quantity, and assign it to the global table of units:

>>> SI.units.rad = Angle.reference_quantity
>>> SI.units.rad = Angle.__wrap__(1.)

Additional units can be defined by relating them to pre-existing ones:

Expand Down
256 changes: 154 additions & 102 deletions nutils/SI.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,14 @@ def _binop(op, a, b):

def __mul__(cls, other):
if not isinstance(other, Dimension):
return cls
return NotImplemented
return cls._binop(operator.add, cls.__powers, other.__powers)

def __rmul__(cls, other):
assert not isinstance(other, Dimension)
return cls

def __truediv__(cls, other):
if not isinstance(other, Dimension):
return cls
return NotImplemented
return cls._binop(operator.sub, cls.__powers, other.__powers)

def __rtruediv__(cls, other):
assert not isinstance(other, Dimension)
return cls**-1

def __pow__(cls, other):
try:
# Fraction supports only a fixed set of input types, so to extend
Expand Down Expand Up @@ -126,10 +118,6 @@ def __wrap__(cls, value):
return value
return super().__call__(value)

@property
def reference_quantity(cls):
return cls.__wrap__(1.)


def parse(s):
if not isinstance(s, str):
Expand Down Expand Up @@ -172,93 +160,157 @@ def __str__(self):
return str(self.__value) + type(self).__name__

@staticmethod
def _dispatch(op, *args, **kwargs):
name = op.__name__
args = [parse(arg) if isinstance(arg, str) else arg for arg in args]
if name in ('add', 'sub', 'subtract', 'hypot'):
Dim = type(args[0])
if type(args[1]) != Dim:
raise TypeError(f'incompatible arguments for {name}: ' + ', '.join(type(arg).__name__ for arg in args))
elif name in ('mul', 'multiply', 'matmul'):
Dim = type(args[0]) * type(args[1])
elif name in ('truediv', 'true_divide', 'divide'):
Dim = type(args[0]) / type(args[1])
elif name in ('neg', 'negative', 'pos', 'positive', 'abs', 'absolute', 'sum', 'mean', 'broadcast_to', 'transpose', 'trace', 'take', 'ptp', 'getitem', 'amax', 'amin', 'max', 'min'):
Dim = type(args[0])
elif name == 'sqrt':
Dim = type(args[0])**fractions.Fraction(1,2)
elif name == 'setitem':
Dim = type(args[0])
if type(args[2]) != Dim:
raise TypeError(f'cannot assign {type(args[2]).__name__} to {Dim.__name__}')
elif name in ('pow', 'power'):
Dim = type(args[0])**args[1]
elif name in ('lt', 'le', 'eq', 'ne', 'gt', 'ge', 'equal', 'not_equal', 'less', 'less_equal', 'greater', 'greater_equal', 'isfinite', 'isnan'):
if any(type(q) != type(args[0]) for q in args[1:]):
raise TypeError(f'incompatible arguments for {name}: ' + ', '.join(type(arg).__name__ for arg in args))
Dim = Dimension.from_powers({})
elif name in ('stack', 'concatenate'):
stack_args, = args
Dim = type(stack_args[0])
if any(type(q) != Dim for q in stack_args[1:]):
raise TypeError(f'incompatible arguments for {name}: ' + ', '.join(type(arg).__name__ for arg in stack_args))
args = [q.__value for q in stack_args],
elif name in ('shape', 'ndim', 'size'):
Dim = Dimension.from_powers({})
else:
return NotImplemented
assert isinstance(Dim, Dimension)
return Dim.__wrap__(op(*(arg.__value if isinstance(arg, Quantity) else arg for arg in args), **kwargs))

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
if method != '__call__':
return NotImplemented
return self._dispatch(ufunc, *inputs, **kwargs)

def __array_function__(self, func, types, args, kwargs):
return self._dispatch(func, *args, **kwargs)

__getitem__ = lambda self, item: self._dispatch(operator.getitem, self, item)
__setitem__ = lambda self, item, value: self._dispatch(operator.setitem, self, item, value)

def _unary(name):
def __unpack(*args):
unpacked_any = False
for arg in args:
if isinstance(arg, Quantity):
yield type(arg), arg.__value
unpacked_any = True
else:
yield Dimension.from_powers({}), arg
assert unpacked_any, 'no dimensional quantities found'

__DISPATCH_TABLE = {}

## POPULATE DISPATCH TABLE

def register(*names, __table=__DISPATCH_TABLE):
assert not any(name in __table for name in names)
return lambda f: __table.update(dict.fromkeys(names, f))

@register('neg', 'negative', 'pos', 'positive', 'abs', 'absolute', 'sum',
'trace', 'ptp', 'amax', 'amin', 'max', 'min', 'opposite', 'mean',
'broadcast_to', 'transpose', 'take', 'getitem', 'mean', 'jump',
'replace_arguments', 'linearize', 'derivative', 'integral')
def __unary(op, *args, **kwargs):
(dim0, arg0), = Quantity.__unpack(args[0])
return dim0.__wrap__(op(arg0, *args[1:], **kwargs))

@register('add', 'sub', 'subtract', 'hypot', 'mod')
def __add_like(op, *args, **kwargs):
(dim0, arg0), (dim1, arg1) = Quantity.__unpack(args[0], args[1])
if dim0 != dim1:
raise TypeError(f'incompatible arguments for {op.__name__}: {dim0.__name__}, {dim1.__name__}')
return dim0.__wrap__(op(arg0, arg1, *args[2:], **kwargs))

@register('mul', 'multiply', 'matmul')
def __mul_like(op, *args, **kwargs):
(dim0, arg0), (dim1, arg1) = Quantity.__unpack(args[0], args[1])
return (dim0 * dim1).__wrap__(op(arg0, arg1, *args[2:], **kwargs))

@register('truediv', 'true_divide', 'divide', 'grad')
def __div_like(op, *args, **kwargs):
(dim0, arg0), (dim1, arg1) = Quantity.__unpack(args[0], args[1])
return (dim0 / dim1).__wrap__(op(arg0, arg1, *args[2:], **kwargs))

@register('sqrt')
def __sqrt(op, *args, **kwargs):
(dim0, arg0), = Quantity.__unpack(args[0])
return (dim0**fractions.Fraction(1,2)).__wrap__(op(arg0, *args[1:], **kwargs))

@register('setitem')
def __setitem(op, *args, **kwargs):
(dim0, arg0), (dim2, arg2) = Quantity.__unpack(args[0], args[2])
if dim0 != dim2:
raise TypeError(f'cannot assign {dim2.__name__} to {dim0.__name__}')
return dim0.__wrap__(op(arg0, args[1], arg2, *args[3:], **kwargs))

@register('pow', 'power', 'jacobian')
def __pow_like(op, *args, **kwargs):
(dim0, arg0), = Quantity.__unpack(args[0])
return (dim0**args[1]).__wrap__(op(arg0, *args[1:], **kwargs))

@register('isfinite', 'isnan', 'shape', 'ndim', 'size', 'normal')
def __unary_drop(op, *args, **kwargs):
(_dim0, arg0), = Quantity.__unpack(args[0])
return op(arg0, *args[1:], **kwargs)

@register('lt', 'le', 'eq', 'ne', 'gt', 'ge', 'equal', 'not_equal', 'less',
'less_equal', 'greater', 'greater_equal')
def __binary_drop(op, *args, **kwargs):
(dim0, arg0), (dim1, arg1) = Quantity.__unpack(args[0], args[1])
if dim0 != dim1:
raise TypeError(f'incompatible arguments for {op.__name__}: {dim0.__name__}, {dim1.__name__}')
return op(arg0, arg1, *args[2:], **kwargs)

@register('stack', 'concatenate')
def __stack_like(op, *args, **kwargs):
dims, arg0 = zip(*Quantity.__unpack(*args[0]))
if any(dim != dims[0] for dim in dims[1:]):
raise TypeError(f'incompatible arguments for {op.__name__}: ' + ', '.join(dim.__name__ for dim in dims))
return dims[0].__wrap__(op(arg0, *args[1:], **kwargs))

@register('evaluate')
def __evaluate(op, *args, **kwargs):
dims, args = zip(*Quantity.__unpack(*args))
return tuple(dim.__wrap__(ret) for (dim, ret) in zip(dims, op(*args, **kwargs)))

del register

## DEFINE OPERATORS

def op(name, with_reverse=False, *, __table=__DISPATCH_TABLE):
dispatch = __table[name]
op = getattr(operator, name)
return lambda self: self._dispatch(op, self)
ret = lambda *args: dispatch(op, *args)
if with_reverse:
ret = ret, lambda self, other: dispatch(op, other, self)
return ret

__getitem__ = op('getitem')
__setitem__ = op('setitem')
__neg__ = op('neg')
__pos__ = op('pos')
__abs__ = op('abs')
__lt__ = op('lt')
__le__ = op('le')
__eq__ = op('eq')
__ne__ = op('ne')
__gt__ = op('gt')
__ge__ = op('ge')
__add__, __radd__ = op('add', True)
__sub__, __rsub__ = op('sub', True)
__mul__, __rmul__ = op('mul', True)
__matmul__, __rmatmul__ = op('matmul', True)
__truediv, __rtruediv__ = op('truediv', True)
__mod__, __rmod__ = op('mod', True)
__pow__, __rpow__ = op('pow', True)

def __truediv__(self, other):
if type(other) is str:
return self.__value / self.__class__(other).__value
return self.__truediv(other)

del op

## DEFINE ATTRIBUTES

def attr(name):
return property(lambda self: getattr(self.__value, name))

__neg__ = _unary('neg')
__pos__ = _unary('pos')
__abs__ = _unary('abs')
shape = attr('shape')
size = attr('size')
ndim = attr('ndim')

def _binary(name):
op = getattr(operator, name)
return lambda self, other: self._dispatch(op, self, other)
del attr

__lt__ = _binary('lt')
__le__ = _binary('le')
__eq__ = _binary('eq')
__ne__ = _binary('ne')
__gt__ = _binary('gt')
__ge__ = _binary('ge')
## DISPATCH THIRD PARTY CALLS

def _binary_r(name):
op = getattr(operator, name)
return lambda self, other: self._dispatch(op, self, other), \
lambda self, other: self._dispatch(op, other, self)

__add__, __radd__ = _binary_r('add')
__sub__, __rsub__ = _binary_r('sub')
__mul__, __rmul__ = _binary_r('mul')
__matmul__, __rmatmul__ = _binary_r('matmul')
__truediv__, __rtruediv__ = _binary_r('truediv')
__mod__, __rmod__ = _binary_r('mod')
__pow__, __rpow__ = _binary_r('pow')

def _attr(name):
return property(lambda self: getattr(self.__value, name))
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
if method != '__call__' or ufunc.__name__ not in self.__DISPATCH_TABLE:
return NotImplemented
return self.__DISPATCH_TABLE[ufunc.__name__](ufunc, *inputs, **kwargs)

def __array_function__(self, func, types, args, kwargs):
if func.__name__ not in self.__DISPATCH_TABLE:
return NotImplemented
return self.__DISPATCH_TABLE[func.__name__](func, *args, **kwargs)

shape = _attr('shape')
size = _attr('size')
ndim = _attr('ndim')
@classmethod
def __nutils_function__(cls, func, args, kwargs):
if func.__name__ not in cls.__DISPATCH_TABLE:
return NotImplemented
return cls.__DISPATCH_TABLE[func.__name__](func, *args, **kwargs)


class Units(dict):
Expand Down Expand Up @@ -343,13 +395,13 @@ def _split_factors(s):

units = Units()

units.m = Length.reference_quantity
units.s = Time.reference_quantity
units.g = Mass.reference_quantity * 1e-3
units.A = ElectricCurrent.reference_quantity
units.K = Temperature.reference_quantity
units.mol = AmountOfSubstance.reference_quantity
units.cd = LuminousIntensity.reference_quantity
units.m = Length.__wrap__(1.)
units.s = Time.__wrap__(1.)
units.g = Mass.__wrap__(1e-3)
units.A = ElectricCurrent.__wrap__(1.)
units.K = Temperature.__wrap__(1.)
units.mol = AmountOfSubstance.__wrap__(1.)
units.cd = LuminousIntensity.__wrap__(1.)

units.N = 'kg*m/s2' # newton
units.Pa = 'N/m2' # pascal
Expand Down
4 changes: 0 additions & 4 deletions tests/test_SI.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,9 @@ class Dimension(unittest.TestCase):

def test_multiply(self):
self.assertEqual(SI.Velocity * SI.Time, SI.Length)
self.assertEqual(SI.Length * int, SI.Length)
self.assertEqual(float * SI.Time, SI.Time)

def test_divide(self):
self.assertEqual(SI.Length / SI.Time, SI.Velocity)
self.assertEqual(SI.Length / int, SI.Length)
self.assertEqual(float / SI.Time, SI.Frequency)

def test_power(self):
self.assertEqual(SI.Length**2, SI.Area)
Expand Down