Skip to content

Commit

Permalink
rewrite function.Custom without EvaluableConstant
Browse files Browse the repository at this point in the history
The arguments passed to the convenience class `function.Custom` are lowered to
`evaluable.Array` if the argument is an instance of `function.Array` and to
`EvaluableConstant` if something else. This patch eliminates the use of
`EvaluableConstant` by passing the non-`function.Array` arguments to
`_CustomEvaluable` as is.
  • Loading branch information
joostvanzwieten committed Dec 20, 2023
1 parent 9d89a1e commit f5c34de
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def __init__(self, args: Iterable[Any], shape: Tuple[int], dtype: DType, npointw
super().__init__(shape=(*points_shape, *shape), dtype=dtype, spaces=spaces, arguments=arguments)

def lower(self, args: LowerArgs) -> evaluable.Array:
evalargs = tuple(arg.lower(args) if isinstance(arg, Array) else evaluable.EvaluableConstant(arg) for arg in self._args) # type: Tuple[Union[evaluable.Array, evaluable.EvaluableConstant], ...]
evalargs = tuple(arg.lower(args) if isinstance(arg, Array) else arg for arg in self._args)
add_points_shape = tuple(map(evaluable.asarray, self.shape[:self._npointwise]))
points_shape = args.points_shape + add_points_shape
coordinates = {space: evaluable.Transpose.to_end(evaluable.appendaxes(coords, add_points_shape), coords.ndim-1) for space, coords in args.coordinates.items()}
Expand Down Expand Up @@ -758,8 +758,7 @@ def partial_derivative(cls, iarg: int, *args: Any) -> IntoArray:

class _CustomEvaluable(evaluable.Array):

def __init__(self, name, evalf, partial_derivative, args: Tuple[Union[evaluable.Array, evaluable.EvaluableConstant], ...], shape: Tuple[int, ...], dtype: DType, spaces: FrozenSet[str], arguments: types.frozendict, lower_args: LowerArgs) -> None:
assert all(isinstance(arg, (evaluable.Array, evaluable.EvaluableConstant)) for arg in args)
def __init__(self, name, evalf, partial_derivative, args, shape: Tuple[int, ...], dtype: DType, spaces: FrozenSet[str], arguments: types.frozendict, lower_args: LowerArgs) -> None:
self.name = name
self.custom_evalf = evalf
self.custom_partial_derivative = partial_derivative
Expand All @@ -768,7 +767,7 @@ def __init__(self, name, evalf, partial_derivative, args: Tuple[Union[evaluable.
self.lower_args = lower_args
self.spaces = spaces
self.function_arguments = arguments
super().__init__((evaluable.Tuple(lower_args.points_shape), *args), shape=(*lower_args.points_shape, *map(evaluable.constant, shape)), dtype=dtype)
super().__init__((evaluable.Tuple(lower_args.points_shape), *(arg for arg in args if isinstance(arg, evaluable.Array))), shape=(*lower_args.points_shape, *map(evaluable.constant, shape)), dtype=dtype)

@property
def _node_details(self) -> str:
Expand All @@ -777,11 +776,18 @@ def _node_details(self) -> str:
def evalf(self, points_shape: Tuple[numpy.ndarray, ...], *args: Any) -> numpy.ndarray:
points_shape = tuple(n.__index__() for n in points_shape)
npoints = util.product(points_shape, 1)
# Flatten the points axes of the array arguments and call `custom_evalf`.
flattened = (arg.reshape(npoints, *arg.shape[self.points_dim:]) if isinstance(origarg, evaluable.Array) else arg for arg, origarg in zip(args, self.args))
# Flatten the points axes of the evaluable arguments, merge with the
# unevaluable arguments and call `custom_evalf`.
flattened = []
args = iter(args)
for arg in self.args:
if isinstance(arg, evaluable.Array):
arg = next(args)
arg = arg.reshape(npoints, *arg.shape[self.points_dim:])
flattened.append(arg)
result = self.custom_evalf(*flattened)
assert result.ndim == self.ndim + 1 - self.points_dim
# Unflatten the points axes of the result. If there are no array arguments,
# Unflatten the points axes of the result. If there are no arguments,
# the points axis must have length one. Otherwise the length must be
# `npoints` (checked by `reshape`).
if not any(isinstance(origarg, evaluable.Array) for origarg in self.args):
Expand All @@ -795,7 +801,7 @@ def _derivative(self, var: evaluable.Array, seen: Dict[evaluable.Array, evaluabl
if self.dtype in (bool, int):
return super()._derivative(var, seen)
result = evaluable.Zeros(self.shape + var.shape, dtype=self.dtype)
unlowered_args = tuple(_Unlower(arg, self.spaces, self.function_arguments, self.lower_args) if isinstance(arg, evaluable.Array) else arg.value for arg in self.args)
unlowered_args = tuple(_Unlower(arg, self.spaces, self.function_arguments, self.lower_args) if isinstance(arg, evaluable.Array) else arg for arg in self.args)
for iarg, arg in enumerate(self.args):
if not isinstance(arg, evaluable.Array) or arg.dtype in (bool, int) or var not in arg.arguments and var != arg:
continue
Expand Down

0 comments on commit f5c34de

Please sign in to comment.