Skip to content

Commit

Permalink
misc
Browse files Browse the repository at this point in the history
  • Loading branch information
joostvanzwieten committed Apr 16, 2024
1 parent eb69386 commit bf0ef32
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 11 deletions.
157 changes: 149 additions & 8 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,27 @@ def _optimized_for_numpy1(obj):
def _optimized_for_numpy(self):
return

@util.shallow_replace
def _merge_inflates(self):
if isinstance(self, Inflate):
func = self.func
dofmaps = self.dofmap,
dofshape = self.dofmap.shape
lengths = self.length,
while True:
for axis, parts in func._inflations:
if axis == func.ndim - len(dofshape) - 1 and len(parts) == 1:
lengths = (func.shape[axis],) + lengths
(dofmap, func), = parts.items()
dofmaps = appendaxes(dofmap, dofshape), *[prependaxes(d, dofmap.shape) for d in dofmaps]
dofshape = dofmap.shape + dofshape
break # continue the outer while loop
else:
break
# merge inner inflates
func = func._merge_inflates()
return _MultiInflate(func, dofmaps, lengths)

@cached_property
def _loop_deps(self):
deps = util.IDSet()
Expand Down Expand Up @@ -969,6 +990,11 @@ def _unaligned(self):
_diagonals = ()
_inflations = ()

@property
def _as_range_with_offset(self):
if self.ndim == 1 and _equals_scalar_constant(self.shape[0], 1):
return self.shape[0], _OfnTakeUnit(self, constant(0))

def _derivative(self, var, seen):
if self.dtype in (bool, int) or var not in self.arguments:
return Zeros(self.shape + var.shape, dtype=self.dtype)
Expand Down Expand Up @@ -1690,6 +1716,12 @@ def _sum(self, axis):
summed = sum(multiply(*factors[:i], *factors[i+1:]), axis)
return summed * align(unaligned, [i-(i > axis) for i in where], summed.shape)

def _loopsum(self, index):
factors = tuple(self._factors)
for i, fi in enumerate(factors):
if index not in fi.arguments:
return fi * loop_sum(multiply(*factors[:i], *factors[i+1:]), index)

def _add(self, other):
factors = list(self._factors)
other_factors = []
Expand Down Expand Up @@ -1937,6 +1969,16 @@ def _intbounds_impl(self):
lowers, uppers = zip(*[f._intbounds for f in self._terms])
return builtins.sum(lowers), builtins.sum(uppers)

@cached_property
def _as_range_with_offset(self):
if self.ndim != 1:
return
for func1, func2 in (self.funcs, reversed(tuple(self.funcs))):
if (length_offset1 := func1._as_range_with_offset) is not None and isinstance(func2, InsertAxis):
length, offset1 = length_offset1
offset2 = func2.func
return length, offset1 + offset2


class Einsum(Array):

Expand Down Expand Up @@ -2000,6 +2042,10 @@ def _takediag(self, axis1, axis2):
args_idx = tuple(tuple(ikeep if i == irm else i for i in idx) for idx in self.args_idx)
return Einsum(self.args, args_idx, self.out_idx[:axis1] + self.out_idx[axis1+1:axis2] + self.out_idx[axis2+1:] + (ikeep,))

def _intbounds_impl(self):
lower = 0 if all(arg._intbounds[0] >= 0 for arg in self.args) else float('-inf')
return lower, float('inf')


class Sum(Array):

Expand Down Expand Up @@ -2134,6 +2180,7 @@ class Take(Array):
def __init__(self, func: Array, indices: Array):
assert isinstance(func, Array) and func.ndim > 0, f'func={func!r}'
assert isinstance(indices, Array) and indices.dtype == int, f'indices={indices!r}'
#assert isinstance(indices, Array) and indices.dtype == int and indices._intbounds[0] >= 0, f'indices={indices!r}'
self.func = func
self.indices = indices
super().__init__(args=(func, indices), shape=func.shape[:-1]+indices.shape, dtype=func.dtype)
Expand All @@ -2153,6 +2200,13 @@ def _simplified(self):
if axis == self.func.ndim - 1:
return util.sum(Inflate(func, dofmap, self.func.shape[-1])._take(self.indices, self.func.ndim - 1) for dofmap, func in parts.items())

def _optimized_for_numpy(self):
if self.indices.ndim == 0:
return _OfnTakeUnit(self.func, self.indices)
if (length_offset := self.indices._as_range_with_offset) is not None:
length, offset = length_offset
return _OfnTakeSlice(self.func, offset, length)

@staticmethod
def evalf(arr, indices):
return arr[..., indices]
Expand Down Expand Up @@ -2182,6 +2236,48 @@ def _intbounds_impl(self):
return self.func._intbounds


class _OfnTakeUnit(Array):

def __init__(self, func: Array, index: Array):
assert isinstance(func, Array) and func.ndim > 0, f'func={func!r}'
assert isinstance(index, Array) and index.dtype == int and index.ndim == 0, f'index={index!r}'
#assert _isindex(index), f'index={index!r}'
self.func = func
self.index = index
super().__init__(args=(func, index), shape=func.shape[:-1], dtype=func.dtype)

def evalf(self, func, index):
return func[...,index]

def _compile_expression(self, add_const, func, index):
return func.get_item(_pyast.Tuple((_pyast.Raw('...'), index)))

def _intbounds_impl(self):
return self.func._intbounds


class _OfnTakeSlice(Array):

def __init__(self, func: Array, start: Array, length: Array):
assert isinstance(func, Array) and func.ndim > 0, f'func={func!r}'
assert _isindex(start), f'start={start!r}'
assert _isindex(length), f'stop={stop!r}'
self.func = func
self.start = start
self.length = length
super().__init__(args=(func, start, length), shape=(*func.shape[:-1], length), dtype=func.dtype)

def evalf(self, func, start, length):
return func[...,start:start+length]

def _compile_expression(self, add_const, func, start, length):
stop = _pyast.BinOp(start, '+', length)
return func.get_item(_pyast.Tuple((_pyast.Raw('...'), _pyast.Raw('slice').call(start, stop))))

def _intbounds_impl(self):
return self.func._intbounds


class Power(Array):

def __init__(self, func: Array, power: Array):
Expand Down Expand Up @@ -3204,6 +3300,43 @@ def _intbounds_impl(self):
return min(lower, 0), max(upper, 0)


class _MultiInflate(Array):

def __init__(self, func: Array, dofmaps: typing.Tuple[Array, ...], lengths: typing.Tuple[Array, ...]):
assert isinstance(func, Array), f'func={func!r}'
assert isinstance(dofmaps, tuple) and all(isinstance(dofmap, Array) and dofmap.dtype == int for dofmap in dofmaps), f'dofmaps={dofmaps!r}'
assert isinstance(lengths, tuple) and all(map(_isindex, lengths)), f'lengths={lengths!r}'
assert len(dofmaps) == len(lengths)
self._ndofmaps = len(dofmaps)
self._npointwise = func.ndim - dofmaps[0].ndim
assert self._npointwise >= 0
assert all(dofmap.ndim == dofmaps[0].ndim and equalshape(dofmap.shape, func.shape[self._npointwise:]) for dofmap in dofmaps)
self.func = func
self.dofmaps = dofmaps
self.lengths = lengths
self._indices_head = (slice(None),) * self._npointwise
super().__init__(args=(func, *dofmaps, *lengths), shape=func.shape[:self._npointwise] + lengths, dtype=func.dtype)

def evalf(self, array, *args):
inflated = numpy.zeros(array.shape[:self._npointwise] + tuple(args[self._ndofmaps:]), dtype=self.dtype)
numpy.add.at(inflated, self._indices_head + tuple(args[:self._ndofmaps]), array)
return inflated

@cached_property
def _can_compile_iadd(self):
return True

def _compile_iadd(self, builder, out, zeroed_block_id):
indices = (_pyast.Raw('slice(None)'),) * self._npointwise
indices += tuple(builder.compile(self.dofmaps))
values = builder.compile(self.func)
builder.get_block_for_evaluable(self).array_add_at(out, _pyast.Tuple(indices), values)

def _intbounds_impl(self):
lower, upper = self.func._intbounds
return min(lower, 0), max(upper, 0)


class SwapInflateTake(Evaluable):

def __init__(self, inflateidx, takeidx):
Expand Down Expand Up @@ -3397,6 +3530,9 @@ def _simplified(self):
if self.isconstant:
return constant(self.eval())

def _intbounds_impl(self):
return 0, self.shape[0]._intbounds[1]


class DerivativeTargetBase(Array):
'base class for derivative targets'
Expand Down Expand Up @@ -3793,6 +3929,10 @@ def _intbounds_impl(self):
assert lower >= 0
return 0, max(0, upper - 1)

@cached_property
def _as_range_with_offset(self):
return self.length, zeros((), dtype=int)


class InRange(Array):

Expand Down Expand Up @@ -3905,7 +4045,7 @@ def __init__(self, ncoeffs: Array, nvars: int) -> None:
super().__init__(args=(ncoeffs,), shape=(), dtype=int)

def evalf(self, ncoeffs):
return numpy.array(poly.degree(self.nvars, ncoeffs.__index__()))
return numpy.int_(poly.degree(self.nvars, ncoeffs.__index__()))

def _compile_expression(self, add_constant, ncoeffs):
ncoeffs = ncoeffs.get_attr('__index__').call()
Expand Down Expand Up @@ -3949,7 +4089,7 @@ def __init__(self, nvars: int, degree: Array) -> None:
super().__init__(args=(degree,), shape=(), dtype=int)

def evalf(self, degree):
return numpy.array(poly.ncoeffs(self.nvars, degree.__index__()))
return numpy.int_(poly.ncoeffs(self.nvars, degree.__index__()))

def _compile_expression(self, add_constant, degree):
degree = degree.get_attr('__index__').call()
Expand Down Expand Up @@ -4733,12 +4873,12 @@ def _add(self, other):
if isinstance(other, LoopSum) and other.index == self.index:
return loop_sum(self.func + other.func, self.index)

def _multiply(self, other):
# If `other` depends on `self.index`, e.g. because `self` is the inner
# loop of two nested `LoopSum`s over the same index, then we should not
# move `other` inside this loop.
if self.index not in other.arguments:
return loop_sum(self.func * other, self.index)
#def _multiply(self, other):
# # If `other` depends on `self.index`, e.g. because `self` is the inner
# # loop of two nested `LoopSum`s over the same index, then we should not
# # move `other` inside this loop.
# if self.index not in other.arguments:
# return loop_sum(self.func * other, self.index)

@cached_property
def _assparse(self):
Expand Down Expand Up @@ -5667,6 +5807,7 @@ def compile(func, *, simplify: bool = True, _plot_stats: typing.Optional[bool] =
# Simplify and optimize `funcs`.
if simplify:
funcs = [func.simplified for func in funcs]
funcs = [func._merge_inflates() for func in funcs]
funcs = [func._optimized_for_numpy1 for func in funcs]
funcs = _define_loop_block_structure(tuple(funcs))

Expand Down
2 changes: 1 addition & 1 deletion nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2457,7 +2457,7 @@ def __init__(self, ndofs: int, nelems: int, index: Array, coords: Array) -> None
arguments = _join_arguments((index.arguments, coords.arguments))
super().__init__((ndofs,), float, spaces=index.spaces | coords.spaces, arguments=arguments)

_index = evaluable.Argument('_index', shape=(), dtype=int)
_index = evaluable.InRange(evaluable.Argument('_index', shape=(), dtype=int), evaluable.constant(self.nelems))
self._arg_dofs_evaluable, self._arg_coeffs_evaluable = self.f_dofs_coeffs(_index)
self._arg_ndofs_evaluable = evaluable.asarray(self._arg_dofs_evaluable.shape[0])
assert self._arg_dofs_evaluable.ndim == 1
Expand Down
3 changes: 2 additions & 1 deletion nutils/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -2349,7 +2349,8 @@ def _asaffine(self, geom, arguments):
geom_ = self.sample('uniform', n).eval(geom, **arguments) \
.reshape(*self.shape, *[n] * self.ndims, self.ndims) \
.transpose(*(i+j for i in range(self.ndims) for j in (0, self.ndims)), self.ndims*2) \
.reshape(*sampleshape, self.ndims)
.reshape(*sampleshape, self.ndims) \
.copy()
# strategy: fit an affine plane through the minima and maxima of a uniform sample,
# and evaluate the error as the largest difference on the remaining sample points
xmin, xmax = geom_.reshape(-1, self.ndims)[[0, -1]]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ def assertInterfaces(self, domain, geom, periodic, interfaces=None, elemindicato
elemindicator = domain.basis('discont', degree=0)
elemindicator = elemindicator.vector(domain.ndims)
lhs = domain.integrate((elemindicator*f.grad(geom)[None]).sum(axis=1)*function.J(geom), ischeme='gauss2')
rhs = interfaces.integrate((-function.jump(elemindicator)*f*function.normal(geom)[None]).sum(axis=1)*function.J(geom), ischeme='gauss2')
rhs = interfaces.integrate((-function.jump(elemindicator)*f*function.normal(geom)[None]).sum(axis=1)*function.J(geom), ischeme='gauss2').copy()
if len(domain.boundary):
rhs += domain.boundary.integrate((elemindicator*f*function.normal(geom)[None]).sum(axis=1)*function.J(geom), ischeme='gauss2')
numpy.testing.assert_array_almost_equal(lhs, rhs)
Expand Down

0 comments on commit bf0ef32

Please sign in to comment.