diff --git a/nutils/evaluable.py b/nutils/evaluable.py index a138032d0..0e9507849 100644 --- a/nutils/evaluable.py +++ b/nutils/evaluable.py @@ -521,32 +521,6 @@ def _node(self, cache, subgraph, times): EVALARGS = EVALARGS() -class EvaluableConstant(Evaluable): - '''Evaluate to the given constant value. - - Parameters - ---------- - value - The return value of ``eval``. - ''' - - def __init__(self, value): - self.value = value - super().__init__(()) - - def evalf(self): - return self.value - - @property - def _node_details(self): - s = repr(self.value) - if '\n' in s: - s = s.split('\n', 1)[0] + '...' - if len(s) > 20: - s = s[:17] + '...' - return s - - class Tuple(Evaluable): def __init__(self, items): diff --git a/nutils/function.py b/nutils/function.py index d15ab0594..05e8aa073 100644 --- a/nutils/function.py +++ b/nutils/function.py @@ -687,7 +687,7 @@ def __init__(self, args: Iterable[Any], shape: Tuple[int], dtype: DType, npointw super().__init__(shape=(*points_shape, *shape), dtype=dtype, spaces=spaces, arguments=arguments) def lower(self, args: LowerArgs) -> evaluable.Array: - evalargs = tuple(arg.lower(args) if isinstance(arg, Array) else evaluable.EvaluableConstant(arg) for arg in self._args) # type: Tuple[Union[evaluable.Array, evaluable.EvaluableConstant], ...] + evalargs = tuple(arg.lower(args) if isinstance(arg, Array) else arg for arg in self._args) add_points_shape = tuple(map(evaluable.asarray, self.shape[:self._npointwise])) points_shape = args.points_shape + add_points_shape coordinates = {space: evaluable.Transpose.to_end(evaluable.appendaxes(coords, add_points_shape), coords.ndim-1) for space, coords in args.coordinates.items()} @@ -758,8 +758,7 @@ def partial_derivative(cls, iarg: int, *args: Any) -> IntoArray: class _CustomEvaluable(evaluable.Array): - def __init__(self, name, evalf, partial_derivative, args: Tuple[Union[evaluable.Array, evaluable.EvaluableConstant], ...], shape: Tuple[int, ...], dtype: DType, spaces: FrozenSet[str], arguments: types.frozendict, lower_args: LowerArgs) -> None: - assert all(isinstance(arg, (evaluable.Array, evaluable.EvaluableConstant)) for arg in args) + def __init__(self, name, evalf, partial_derivative, args, shape: Tuple[int, ...], dtype: DType, spaces: FrozenSet[str], arguments: types.frozendict, lower_args: LowerArgs) -> None: self.name = name self.custom_evalf = evalf self.custom_partial_derivative = partial_derivative @@ -768,7 +767,7 @@ def __init__(self, name, evalf, partial_derivative, args: Tuple[Union[evaluable. self.lower_args = lower_args self.spaces = spaces self.function_arguments = arguments - super().__init__((evaluable.Tuple(lower_args.points_shape), *args), shape=(*lower_args.points_shape, *map(evaluable.constant, shape)), dtype=dtype) + super().__init__((evaluable.Tuple(lower_args.points_shape), *(arg for arg in args if isinstance(arg, evaluable.Array))), shape=(*lower_args.points_shape, *map(evaluable.constant, shape)), dtype=dtype) @property def _node_details(self) -> str: @@ -777,11 +776,18 @@ def _node_details(self) -> str: def evalf(self, points_shape: Tuple[numpy.ndarray, ...], *args: Any) -> numpy.ndarray: points_shape = tuple(n.__index__() for n in points_shape) npoints = util.product(points_shape, 1) - # Flatten the points axes of the array arguments and call `custom_evalf`. - flattened = (arg.reshape(npoints, *arg.shape[self.points_dim:]) if isinstance(origarg, evaluable.Array) else arg for arg, origarg in zip(args, self.args)) + # Flatten the points axes of the evaluable arguments, merge with the + # unevaluable arguments and call `custom_evalf`. + flattened = [] + args = iter(args) + for arg in self.args: + if isinstance(arg, evaluable.Array): + arg = next(args) + arg = arg.reshape(npoints, *arg.shape[self.points_dim:]) + flattened.append(arg) result = self.custom_evalf(*flattened) assert result.ndim == self.ndim + 1 - self.points_dim - # Unflatten the points axes of the result. If there are no array arguments, + # Unflatten the points axes of the result. If there are no arguments, # the points axis must have length one. Otherwise the length must be # `npoints` (checked by `reshape`). if not any(isinstance(origarg, evaluable.Array) for origarg in self.args): @@ -795,7 +801,7 @@ def _derivative(self, var: evaluable.Array, seen: Dict[evaluable.Array, evaluabl if self.dtype in (bool, int): return super()._derivative(var, seen) result = evaluable.Zeros(self.shape + var.shape, dtype=self.dtype) - unlowered_args = tuple(_Unlower(arg, self.spaces, self.function_arguments, self.lower_args) if isinstance(arg, evaluable.Array) else arg.value for arg in self.args) + unlowered_args = tuple(_Unlower(arg, self.spaces, self.function_arguments, self.lower_args) if isinstance(arg, evaluable.Array) else arg for arg in self.args) for iarg, arg in enumerate(self.args): if not isinstance(arg, evaluable.Array) or arg.dtype in (bool, int) or var not in arg.arguments and var != arg: continue diff --git a/tests/test_evaluable.py b/tests/test_evaluable.py index 1da113e27..d10e4b389 100644 --- a/tests/test_evaluable.py +++ b/tests/test_evaluable.py @@ -1146,27 +1146,6 @@ def test_nested_variant(self): self.assertEqual(actual, desired) -class EvaluableConstant(TestCase): - - def test_evalf(self): - self.assertEqual(evaluable.EvaluableConstant(1).evalf(), 1) - self.assertEqual(evaluable.EvaluableConstant('1').evalf(), '1') - - def test_node_details(self): - - class Test: - def __init__(self, s): - self.s = s - - def __repr__(self): - return self.s - - self.assertEqual(evaluable.EvaluableConstant(Test('some string'))._node_details, 'some string') - self.assertEqual(evaluable.EvaluableConstant(Test('a very long string that should be abbreviated'))._node_details, 'a very long strin...') - self.assertEqual(evaluable.EvaluableConstant(Test('a string with\nmultiple lines'))._node_details, 'a string with...') - self.assertEqual(evaluable.EvaluableConstant(Test('a very long string with\nmultiple lines'))._node_details, 'a very long strin...') - - class Einsum(TestCase): def test_swapaxes(self):