Skip to content

Commit

Permalink
simplify Pointwise of tranposed or inserted axes
Browse files Browse the repository at this point in the history
The commit adds simplifications for `evaluable.Pointwise` subclasses with a
single transposed argument, or arguments that have a common inserted axis.
  • Loading branch information
joostvanzwieten committed Dec 11, 2020
1 parent ea18825 commit a057ef2
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1844,7 +1844,7 @@ def __init__(self, *args:asarrays):
retval = self.evalf(*[numpy.ones((), dtype=arg.dtype) for arg in args])
shapes = set(arg.shape for arg in args)
assert len(shapes) == 1, 'pointwise arguments have inconsistent shapes'
shape, = shapes
shape = tuple(axes[0] if len(set(axes)) == 1 and not isinstance(axes[0], Sparse) else Axis(axes[0].length) for axes in zip(*(arg._axes for arg in args)))
self.args = args
super().__init__(args=args, shape=shape, dtype=retval.dtype)

Expand All @@ -1869,6 +1869,12 @@ def _simplified(self):
if self.isconstant:
retval = self.eval()
return Constant(retval)
if len(self.args) == 1 and isinstance(self.args[0], Transpose):
arg, = self.args
return Transpose(self.__class__(arg.func), arg.axes)
for i in reversed(range(self.ndim)):
if all(isinstance(arg._axes[i], Inserted) for arg in self.args):
return insertaxis(self.__class__(*(arg._uninsert(i) for arg in self.args)), i, self.shape[i])

def _derivative(self, var, seen):
if self.deriv is None:
Expand Down

0 comments on commit a057ef2

Please sign in to comment.