Skip to content

Commit

Permalink
handle trivial transposes directly
Browse files Browse the repository at this point in the history
This patch makes the trans argument of evaluable.transpose mandatory and adds a
simplification path for the trivial case. As transpose is used in the pre- and
post processing of almost all other operations there are considerable merits to
handling this case directly. To avoid double work, the Transpose operation no
longer accepts the trivial permutation as an argument.
  • Loading branch information
gertjanvanzwieten committed Apr 10, 2024
1 parent 7a56a1d commit b64aaee
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 32 deletions.
58 changes: 27 additions & 31 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,14 +533,12 @@ def imag(arg):
return zeros_like(arg)


def transpose(arg, trans=None):
def transpose(arg, trans):
arg = asarray(arg)
if trans is None:
normtrans = tuple(range(arg.ndim-1, -1, -1))
else:
normtrans = tuple(numeric.normdim(arg.ndim, sh).__index__() for sh in trans)
assert sorted(normtrans) == list(range(arg.ndim))
return Transpose(arg, normtrans)
trans = tuple(i.__index__() for i in trans)
if all(i == n for i, n in enumerate(trans)):
return arg
return Transpose(arg, trans)


def swapaxes(arg, axis1, axis2):
Expand Down Expand Up @@ -621,7 +619,7 @@ def unalign(*args, naxes: int = None):
if not ret: # first argument
commonwhere = tuple(i for i in where if i < naxes)
if where != commonwhere + keep:
unaligned = Transpose(unaligned, tuple(where.index(n) for n in commonwhere + keep))
unaligned = transpose(unaligned, tuple(where.index(n) for n in commonwhere + keep))
ret.append(unaligned)
return (*ret, commonwhere)

Expand Down Expand Up @@ -759,7 +757,7 @@ def __index__(self):
index = self.__index = int(self.simplified.eval())
return index

T = property(lambda self: transpose(self))
T = property(lambda self: transpose(self, tuple(range(self.ndim-1, -1, -1))))

__add__ = __radd__ = add
__sub__ = lambda self, other: subtract(self, other)
Expand Down Expand Up @@ -1220,7 +1218,7 @@ def _flatten_constant(self):
class Transpose(Array):

@classmethod
def _mk_axes(cls, ndim, axes, invert=False):
def _mk_axes(cls, ndim, axes):
axes = [numeric.normdim(ndim, axis) for axis in axes]
if all(a == b for a, b in enumerate(axes, start=ndim-len(axes))):
return
Expand Down Expand Up @@ -1248,6 +1246,7 @@ def __init__(self, func: Array, axes: typing.Tuple[int, ...]):
assert isinstance(func, Array), f'func={func!r}'
assert isinstance(axes, tuple) and all(isinstance(axis, int) for axis in axes), f'axes={axes!r}'
assert sorted(axes) == list(range(func.ndim))
assert axes != tuple(range(func.ndim))
self.func = func
self.axes = axes
super().__init__(args=(func,), shape=tuple(func.shape[n] for n in axes), dtype=func.dtype)
Expand All @@ -1258,7 +1257,7 @@ def _diagonals(self):

@cached_property
def _inflations(self):
return tuple((self._invaxes[axis], types.frozendict((dofmap, Transpose(func, self._axes_for(dofmap.ndim, self._invaxes[axis]))) for dofmap, func in parts.items())) for axis, parts in self.func._inflations)
return tuple((self._invaxes[axis], types.frozendict((dofmap, transpose(func, self._axes_for(dofmap.ndim, self._invaxes[axis]))) for dofmap, func in parts.items())) for axis, parts in self.func._inflations)

@cached_property
def _unaligned(self):
Expand All @@ -1270,8 +1269,6 @@ def _invaxes(self):
return tuple(n.__index__() for n in numpy.argsort(self.axes))

def _simplified(self):
if self.axes == tuple(range(self.ndim)):
return self.func
return self.func._transpose(self.axes)

def evalf(self, arr):
Expand All @@ -1289,22 +1286,22 @@ def _transpose(self, axes):
# to separately account for the trivial case.
return self.func
newaxes = tuple(self.axes[i] for i in axes)
return Transpose(self.func, newaxes)
return transpose(self.func, newaxes)

def _takediag(self, axis1, axis2):
assert axis1 < axis2
orig1, orig2 = sorted(self.axes[axis] for axis in [axis1, axis2])
trytakediag = self.func._takediag(orig1, orig2)
if trytakediag is not None:
exclude_orig = [ax-(ax > orig1)-(ax > orig2) for ax in self.axes[:axis1] + self.axes[axis1+1:axis2] + self.axes[axis2+1:]]
return Transpose(trytakediag, (*exclude_orig, self.ndim-2))
return transpose(trytakediag, (*exclude_orig, self.ndim-2))

def _sum(self, i):
axis = self.axes[i]
trysum = self.func._sum(axis)
if trysum is not None:
axes = tuple(ax-(ax > axis) for ax in self.axes if ax != axis)
return Transpose(trysum, axes)
return transpose(trysum, axes)

def _derivative(self, var, seen):
return transpose(derivative(self.func, var, seen), self.axes+tuple(range(self.ndim, self.ndim+var.ndim)))
Expand All @@ -1331,7 +1328,7 @@ def _add(self, other):
def _take(self, indices, axis):
trytake = self.func._take(indices, self.axes[axis])
if trytake is not None:
return Transpose(trytake, self._axes_for(indices.ndim, axis))
return transpose(trytake, self._axes_for(indices.ndim, axis))

def _axes_for(self, ndim, axis):
funcaxis = self.axes[axis]
Expand All @@ -1352,7 +1349,7 @@ def _unravel(self, axis, shape):
if tryunravel is not None:
axes = [ax + (ax > orig_axis) for ax in self.axes]
axes.insert(axis+1, orig_axis+1)
return Transpose(tryunravel, tuple(axes))
return transpose(tryunravel, tuple(axes))

def _product(self):
if self.axes[-1] == self.ndim-1:
Expand All @@ -1363,7 +1360,7 @@ def _determinant(self, axis1, axis2):
trydet = self.func._determinant(orig1, orig2)
if trydet:
axes = tuple(ax-(ax > orig1)-(ax > orig2) for ax in self.axes if ax != orig1 and ax != orig2)
return Transpose(trydet, axes)
return transpose(trydet, axes)

def _inverse(self, axis1, axis2):
tryinv = self.func._inverse(self.axes[axis1], self.axes[axis2])
Expand All @@ -1381,7 +1378,7 @@ def _inflate(self, dofmap, length, axis):
if tryinflate is not None:
axes = [ax-(ax > i)*(dofmap.ndim-1) for ax in self.axes]
axes[axis:axis+dofmap.ndim] = i,
return Transpose(tryinflate, tuple(axes))
return transpose(tryinflate, tuple(axes))

def _diagonalize(self, axis):
trydiagonalize = self.func._diagonalize(self.axes[axis])
Expand Down Expand Up @@ -2680,23 +2677,22 @@ def Elemwise(data: typing.Tuple[types.arraydata, ...], index: Array, dtype: Dtyp
index = Take(constant(indices), index)
# Move all axes with constant shape to the left and ravel the remainder.
is_constant = numpy.all(shapes[1:] == shapes[0], axis=0)
nconstant = is_constant.sum()
reorder = numpy.argsort(~is_constant)
raveled = [numpy.transpose(d, reorder).reshape(*shapes[0, reorder[:nconstant]], -1) for d in unique]
const_axes = tuple(is_constant.nonzero()[0])
var_axes = tuple((~is_constant).nonzero()[0])
raveled = [numpy.transpose(d, const_axes + var_axes).reshape(*shapes[0, const_axes], -1) for d in unique]
# Concatenate the raveled axis, take slices, unravel and reorder the axes to
# the original position.
concat = constant(numpy.concatenate(raveled, axis=-1))
if is_constant.all():
if not var_axes:
return Take(concat, index)
var_shape = tuple(shape[i] for i in reorder[nconstant:])
cumprod = list(var_shape)
for i in reversed(range(len(var_shape)-1)):
cumprod = [shape[i] for i in var_axes]
for i in reversed(range(len(cumprod)-1)):
cumprod[i] *= cumprod[i+1] # work backwards so that the shape check matches in Unravel
offsets = _SizesToOffsets(asarray([d.shape[-1] for d in raveled]))
elemwise = Take(concat, Range(cumprod[0]) + Take(offsets, index))
for i in range(len(var_shape)-1):
elemwise = Unravel(elemwise, var_shape[i], cumprod[i+1])
return Transpose.inv(elemwise, reorder)
for i in range(len(var_axes)-1):
elemwise = Unravel(elemwise, shape[var_axes[i]], cumprod[i+1])
return Transpose.from_end(elemwise, *var_axes)


class Eig(Evaluable):
Expand Down Expand Up @@ -5121,7 +5117,7 @@ def einsum(fmt, *args, **dims):
if c not in index:
index[c] = arg.ndim
arg = InsertAxis(arg, shapes[c])
v = Transpose(arg, tuple(index[c] for c in sall))
v = transpose(arg, tuple(index[c] for c in sall))
ret = v if ret is None else ret * v
for i in range(len(sout), len(sall)):
ret = Sum(ret)
Expand Down
2 changes: 1 addition & 1 deletion nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@ def lower(self, args: LowerArgs) -> evaluable.Array:
arg = self._arg.lower(args)
offset = len(args.points_shape)
axes = (*range(offset), *(i+offset for i in self._axes))
return evaluable.Transpose(arg, axes)
return evaluable.transpose(arg, axes)


class _Opposite(Array):
Expand Down

0 comments on commit b64aaee

Please sign in to comment.