Skip to content

Commit

Permalink
wip: misc
Browse files Browse the repository at this point in the history
  • Loading branch information
joostvanzwieten committed May 23, 2024
1 parent 69323be commit b846344
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 7 deletions.
146 changes: 140 additions & 6 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,27 @@ def _optimized_for_numpy1(obj):
def _optimized_for_numpy(self):
    '''Return ``None`` unconditionally.'''
    return None

@util.shallow_replace
def _merge_inflates(self):
    '''Fuse a chain of nested inflations into one :class:`_MultiInflate`.

    Starting from an :class:`Inflate`, repeatedly absorbs a single-part
    inflation of the wrapped function, collecting one dofmap and one length
    per absorbed inflation.  Returns ``None`` when ``self`` is not an
    :class:`Inflate`.
    '''
    if not isinstance(self, Inflate):
        return None
    inner = self.func
    dofmaps = (self.dofmap,)
    dofshape = self.dofmap.shape
    lengths = (self.length,)
    while True:
        # First single-part inflation of `inner` along the axis directly
        # preceding the trailing axes already covered by `dofshape`.
        hit = next(
            ((axis, parts) for axis, parts in inner._inflations
             if axis == inner.ndim - len(dofshape) - 1 and len(parts) == 1),
            None)
        if hit is None:
            break
        axis, parts = hit
        # The length must be read from `inner` before it is replaced below.
        lengths = (inner.shape[axis], *lengths)
        (dofmap, inner), = parts.items()
        # Broadcast the new dofmap over the old dof axes and vice versa so
        # that all dofmaps share the combined shape.
        dofmaps = (appendaxes(dofmap, dofshape), *(prependaxes(d, dofmap.shape) for d in dofmaps))
        dofshape = dofmap.shape + dofshape
    # merge inner inflates
    inner = inner._merge_inflates()
    return _MultiInflate(inner, dofmaps, lengths)

@cached_property
def _loops(self):
deps = util.IDSet()
Expand Down Expand Up @@ -722,6 +743,11 @@ def _unaligned(self):
_diagonals = ()
_inflations = ()

@property
def _as_range_with_offset(self):
    '''Return ``(length, offset)`` if this array is a length-one vector.

    For a one-dimensional array whose length equals the constant one, the
    result is the length paired with the array's first entry as offset.
    In all other cases the result is ``None``.
    '''
    if self.ndim != 1:
        return None
    if not _equals_scalar_constant(self.shape[0], 1):
        return None
    return self.shape[0], _OfnTakeUnit(self, constant(0))

def _derivative(self, var, seen):
if self.dtype in (bool, int) or var not in self.arguments:
return Zeros(self.shape + var.shape, dtype=self.dtype)
Expand Down Expand Up @@ -1434,6 +1460,12 @@ def _sum(self, axis):
summed = sum(multiply(*factors[:i], *factors[i+1:]), axis)
return summed * align(unaligned, [i-(i > axis) for i in where], summed.shape)

def _loopsum(self, index):
    '''Pull the first factor that is independent of `index` out of a loop sum.

    Returns that factor times the loop sum of the remaining factors, or
    ``None`` if every factor has `index` among its arguments.
    '''
    factors = tuple(self._factors)
    for pos, factor in enumerate(factors):
        if index in factor.arguments:
            continue
        remaining = factors[:pos] + factors[pos+1:]
        return factor * loop_sum(multiply(*remaining), index)
    return None

def _add(self, other):
factors = list(self._factors)
other_factors = []
Expand Down Expand Up @@ -1681,6 +1713,16 @@ def _intbounds_impl(self):
lowers, uppers = zip(*[f._intbounds for f in self._terms])
return builtins.sum(lowers), builtins.sum(uppers)

@cached_property
def _as_range_with_offset(self):
    '''Return ``(length, offset)`` if this sum is a shifted range.

    Applies to one-dimensional sums of two terms where one term has a
    range-with-offset representation and the other is an
    :class:`InsertAxis`; the function wrapped by the :class:`InsertAxis` is
    added to the offset.  Both term orders are tried.  Returns ``None``
    otherwise.
    '''
    if self.ndim != 1:
        return None
    first, second = self.funcs
    for range_like, broadcast in ((first, second), (second, first)):
        range_repr = range_like._as_range_with_offset
        if range_repr is None or not isinstance(broadcast, InsertAxis):
            continue
        length, range_offset = range_repr
        return length, range_offset + broadcast.func
    return None


class Einsum(Array):

Expand Down Expand Up @@ -1727,6 +1769,10 @@ def _optimized_for_numpy(self):
continue
return Einsum(self.args[:i]+(arg.func,)+self.args[i+1:], self.args_idx[:i]+(idx,)+self.args_idx[i+1:], self.out_idx)

def _intbounds_impl(self):
    '''Return coarse integer bounds ``(lower, upper)``.

    The lower bound is zero when every argument has a nonnegative lower
    bound and minus infinity otherwise; the upper bound is always infinity.
    '''
    all_nonnegative = all(arg._intbounds[0] >= 0 for arg in self.args)
    if all_nonnegative:
        return 0, float('inf')
    return float('-inf'), float('inf')


class Sum(Array):

Expand Down Expand Up @@ -1850,6 +1896,7 @@ class Take(Array):
def __init__(self, func: Array, indices: Array):
    '''Take entries from the last axis of `func` at integer `indices`.

    The result shape is ``func.shape[:-1] + indices.shape``.
    '''
    assert isinstance(func, Array) and func.ndim > 0, f'func={func!r}'
    assert isinstance(indices, Array) and indices.dtype == int, f'indices={indices!r}'
    # TODO: the stricter check `indices._intbounds[0] >= 0` is currently disabled.
    self.func = func
    self.indices = indices
    result_shape = func.shape[:-1] + indices.shape
    super().__init__(args=(func, indices), shape=result_shape, dtype=func.dtype)
Expand All @@ -1869,6 +1916,13 @@ def _simplified(self):
if axis == self.func.ndim - 1:
return util.sum(Inflate(func, dofmap, self.func.shape[-1])._take(self.indices, self.func.ndim - 1) for dofmap, func in parts.items())

def _optimized_for_numpy(self):
    '''Replace the take by plain indexing or slicing where possible.

    A scalar index becomes an :class:`_OfnTakeUnit`; indices that have a
    range-with-offset representation become an :class:`_OfnTakeSlice`.
    Returns ``None`` if neither applies.
    '''
    indices = self.indices
    if indices.ndim == 0:
        return _OfnTakeUnit(self.func, indices)
    range_repr = indices._as_range_with_offset
    if range_repr is None:
        return None
    length, offset = range_repr
    return _OfnTakeSlice(self.func, offset, length)

def _compile_expression(self, py_self, arr, indices):
    '''Build the expression ``numpy.take(arr, indices, axis=-1)``.'''
    numpy_take = _pyast.Variable('numpy').get_attr('take')
    last_axis = _pyast.LiteralInt(-1)
    return numpy_take.call(arr, indices, axis=last_axis)

Expand All @@ -1894,6 +1948,42 @@ def _intbounds_impl(self):
return self.func._intbounds


class _OfnTakeUnit(Array):
    '''Select a single entry along the last axis of `func`.

    Compiles to ``func[..., index]``; the result shape is
    ``func.shape[:-1]``.
    '''

    def __init__(self, func: Array, index: Array):
        assert isinstance(func, Array) and func.ndim > 0, f'func={func!r}'
        assert isinstance(index, Array) and index.dtype == int and index.ndim == 0, f'index={index!r}'
        # TODO: the stricter check `_isindex(index)` is currently disabled.
        self.func = func
        self.index = index
        reduced_shape = func.shape[:-1]
        super().__init__(args=(func, index), shape=reduced_shape, dtype=func.dtype)

    def _compile_expression(self, py_self, func, index):
        # Ellipsis keeps all leading axes; `index` addresses the last one.
        selector = _pyast.Tuple((_pyast.Raw('...'), index))
        return func.get_item(selector)

    def _intbounds_impl(self):
        # Selecting entries cannot widen the bounds of the source array.
        source = self.func
        return source._intbounds


class _OfnTakeSlice(Array):
    '''Contiguous slice along the last axis of `func`.

    Compiles to ``func[..., slice(start, start + length)]``; the result
    shape is ``(*func.shape[:-1], length)``.

    Parameters
    ----------
    func : Array with at least one axis
        The array to slice.
    start : index
        First position of the slice along the last axis.
    length : index
        Number of entries in the slice.
    '''

    def __init__(self, func: Array, start: Array, length: Array):
        assert isinstance(func, Array) and func.ndim > 0, f'func={func!r}'
        assert _isindex(start), f'start={start!r}'
        # The message must report `length`: there is no `stop` in this scope,
        # so formatting it would raise NameError instead of AssertionError.
        assert _isindex(length), f'length={length!r}'
        self.func = func
        self.start = start
        self.length = length
        super().__init__(args=(func, start, length), shape=(*func.shape[:-1], length), dtype=func.dtype)

    def _compile_expression(self, py_self, func, start, length):
        # The slice end is derived from the start and the length.
        stop = _pyast.BinOp(start, '+', length)
        return func.get_item(_pyast.Tuple((_pyast.Raw('...'), _pyast.Raw('slice').call(start, stop))))

    def _intbounds_impl(self):
        # Selecting entries cannot widen the bounds of the source array.
        return self.func._intbounds


class Power(Array):

def __init__(self, func: Array, power: Array):
Expand Down Expand Up @@ -2908,6 +2998,42 @@ def _intbounds_impl(self):
return min(lower, 0), max(upper, 0)


class _MultiInflate(Array):
    '''Scatter-add `func` into a zero array along several axes at once.

    The leading ``func.ndim - dofmaps[0].ndim`` axes of `func` are kept as
    they are.  The remaining axes, whose shape must equal that of every
    dofmap, are replaced by ``len(lengths)`` axes of the given lengths, with
    the dofmaps supplying the target indices for the respective new axes.
    '''

    def __init__(self, func: Array, dofmaps: typing.Tuple[Array, ...], lengths: typing.Tuple[Array, ...]):
        assert isinstance(func, Array), f'func={func!r}'
        assert isinstance(dofmaps, tuple) and all(isinstance(dofmap, Array) and dofmap.dtype == int for dofmap in dofmaps), f'dofmaps={dofmaps!r}'
        assert isinstance(lengths, tuple) and all(map(_isindex, lengths)), f'lengths={lengths!r}'
        assert len(dofmaps) == len(lengths)
        self._ndofmaps = len(dofmaps)
        # Number of leading axes of `func` that are passed through unchanged.
        self._npointwise = func.ndim - dofmaps[0].ndim
        assert self._npointwise >= 0
        dof_ndim = dofmaps[0].ndim
        dof_shape = func.shape[self._npointwise:]
        for dofmap in dofmaps:
            assert dofmap.ndim == dof_ndim and equalshape(dofmap.shape, dof_shape)
        self.func = func
        self.dofmaps = dofmaps
        self.lengths = lengths
        # Full slices addressing the unchanged leading axes in `evalf`.
        self._indices_head = (slice(None),) * self._npointwise
        super().__init__(args=(func, *dofmaps, *lengths), shape=func.shape[:self._npointwise] + lengths, dtype=func.dtype)

    def evalf(self, array, *args):
        # `args` holds the evaluated dofmaps followed by the evaluated lengths.
        dofmap_values = tuple(args[:self._ndofmaps])
        length_values = tuple(args[self._ndofmaps:])
        inflated = numpy.zeros(array.shape[:self._npointwise] + length_values, dtype=self.dtype)
        # `add.at` accumulates values for repeated indices.
        numpy.add.at(inflated, self._indices_head + dofmap_values, array)
        return inflated

    def _compile_with_out(self, builder, out, out_block_id, mode):
        assert mode in ('iadd', 'assign')
        if mode == 'assign':
            # Assigning means starting from zeros before accumulating.
            zero_block = builder.get_block_for_evaluable(self, block_id=out_block_id, comment='zero')
            zero_block.array_fill_zeros(out)
        head = (_pyast.Raw('slice(None)'),) * self._npointwise
        compiled_dofmaps = tuple(builder.compile(self.dofmaps))
        values = builder.compile(self.func)
        target = _pyast.Tuple(head + compiled_dofmaps)
        builder.get_block_for_evaluable(self).array_add_at(out, target, values)

    def _intbounds_impl(self):
        # Zero is always included in the returned bounds.
        lower, upper = self.func._intbounds
        return min(lower, 0), max(upper, 0)


class SwapInflateTake(Evaluable):

def __init__(self, inflateidx, takeidx):
Expand Down Expand Up @@ -3086,6 +3212,9 @@ def _simplified(self):
if self.isconstant:
return constant(self.eval())

def _intbounds_impl(self):
    '''Return ``(0, upper)`` with `upper` the upper bound of the first axis length.'''
    length_bounds = self.shape[0]._intbounds
    return 0, length_bounds[1]


class DerivativeTargetBase(Array):
'base class for derivative targets'
Expand Down Expand Up @@ -3461,6 +3590,10 @@ def _intbounds_impl(self):
assert lower >= 0
return 0, max(0, upper - 1)

@cached_property
def _as_range_with_offset(self):
    '''Return ``(self.length, offset)`` with a scalar integer zero offset.'''
    length = self.length
    return length, zeros((), dtype=int)


class InRange(Array):

Expand Down Expand Up @@ -4252,12 +4385,12 @@ def _add(self, other):
if isinstance(other, LoopSum) and other.index == self.index:
return loop_sum(self.func + other.func, self.index)

def _multiply(self, other):
    '''Move the factor `other` inside the loop sum when that is safe.

    Returns ``None`` if `other` depends on the loop index.
    '''
    # If `other` depends on `self.index`, e.g. because `self` is the inner
    # loop of two nested `LoopSum`s over the same index, then we should not
    # move `other` inside this loop.
    if self.index in other.arguments:
        return None
    return loop_sum(self.func * other, self.index)
#def _multiply(self, other):
# # If `other` depends on `self.index`, e.g. because `self` is the inner
# # loop of two nested `LoopSum`s over the same index, then we should not
# # move `other` inside this loop.
# if self.index not in other.arguments:
# return loop_sum(self.func * other, self.index)

@cached_property
def _assparse(self):
Expand Down Expand Up @@ -5198,6 +5331,7 @@ def compile(func, *, simplify: bool = True, stats: typing.Optional[bool] = None,
# Simplify and optimize `funcs`.
if simplify:
funcs = [func.simplified for func in funcs]
funcs = [func._merge_inflates() for func in funcs]
funcs = [func._optimized_for_numpy1 for func in funcs]
funcs = _define_loop_block_structure(tuple(funcs))
assert not any(isinstance(arg, _LoopIndex) for func in funcs for arg in func.arguments)
Expand Down
2 changes: 1 addition & 1 deletion nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2457,7 +2457,7 @@ def __init__(self, ndofs: int, nelems: int, index: Array, coords: Array) -> None
arguments = _join_arguments((index.arguments, coords.arguments))
super().__init__((ndofs,), float, spaces=index.spaces | coords.spaces, arguments=arguments)

_index = evaluable.Argument('_index', shape=(), dtype=int)
_index = evaluable.InRange(evaluable.Argument('_index', shape=(), dtype=int), evaluable.constant(self.nelems))
self._arg_dofs_evaluable, self._arg_coeffs_evaluable = self.f_dofs_coeffs(_index)
self._arg_ndofs_evaluable = evaluable.asarray(self._arg_dofs_evaluable.shape[0])
assert self._arg_dofs_evaluable.ndim == 1
Expand Down

0 comments on commit b846344

Please sign in to comment.