Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

eliminate and remove EvaluableConstant #847

Merged
merged 2 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,32 +521,6 @@ def _node(self, cache, subgraph, times):
EVALARGS = EVALARGS()


class EvaluableConstant(Evaluable):
'''Evaluate to the given constant value.

Parameters
----------
value
The return value of ``eval``.
'''

def __init__(self, value):
self.value = value
super().__init__(())

def evalf(self):
return self.value

@property
def _node_details(self):
s = repr(self.value)
if '\n' in s:
s = s.split('\n', 1)[0] + '...'
if len(s) > 20:
s = s[:17] + '...'
return s


class Tuple(Evaluable):

def __init__(self, items):
Expand Down
22 changes: 14 additions & 8 deletions nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def __init__(self, args: Iterable[Any], shape: Tuple[int], dtype: DType, npointw
super().__init__(shape=(*points_shape, *shape), dtype=dtype, spaces=spaces, arguments=arguments)

def lower(self, args: LowerArgs) -> evaluable.Array:
evalargs = tuple(arg.lower(args) if isinstance(arg, Array) else evaluable.EvaluableConstant(arg) for arg in self._args) # type: Tuple[Union[evaluable.Array, evaluable.EvaluableConstant], ...]
evalargs = tuple(arg.lower(args) if isinstance(arg, Array) else arg for arg in self._args)
add_points_shape = tuple(map(evaluable.asarray, self.shape[:self._npointwise]))
points_shape = args.points_shape + add_points_shape
coordinates = {space: evaluable.Transpose.to_end(evaluable.appendaxes(coords, add_points_shape), coords.ndim-1) for space, coords in args.coordinates.items()}
Expand Down Expand Up @@ -758,8 +758,7 @@ def partial_derivative(cls, iarg: int, *args: Any) -> IntoArray:

class _CustomEvaluable(evaluable.Array):

def __init__(self, name, evalf, partial_derivative, args: Tuple[Union[evaluable.Array, evaluable.EvaluableConstant], ...], shape: Tuple[int, ...], dtype: DType, spaces: FrozenSet[str], arguments: types.frozendict, lower_args: LowerArgs) -> None:
assert all(isinstance(arg, (evaluable.Array, evaluable.EvaluableConstant)) for arg in args)
def __init__(self, name, evalf, partial_derivative, args, shape: Tuple[int, ...], dtype: DType, spaces: FrozenSet[str], arguments: types.frozendict, lower_args: LowerArgs) -> None:
self.name = name
self.custom_evalf = evalf
self.custom_partial_derivative = partial_derivative
Expand All @@ -768,7 +767,7 @@ def __init__(self, name, evalf, partial_derivative, args: Tuple[Union[evaluable.
self.lower_args = lower_args
self.spaces = spaces
self.function_arguments = arguments
super().__init__((evaluable.Tuple(lower_args.points_shape), *args), shape=(*lower_args.points_shape, *map(evaluable.constant, shape)), dtype=dtype)
super().__init__((evaluable.Tuple(lower_args.points_shape), *(arg for arg in args if isinstance(arg, evaluable.Array))), shape=(*lower_args.points_shape, *map(evaluable.constant, shape)), dtype=dtype)

@property
def _node_details(self) -> str:
Expand All @@ -777,11 +776,18 @@ def _node_details(self) -> str:
def evalf(self, points_shape: Tuple[numpy.ndarray, ...], *args: Any) -> numpy.ndarray:
points_shape = tuple(n.__index__() for n in points_shape)
npoints = util.product(points_shape, 1)
# Flatten the points axes of the array arguments and call `custom_evalf`.
flattened = (arg.reshape(npoints, *arg.shape[self.points_dim:]) if isinstance(origarg, evaluable.Array) else arg for arg, origarg in zip(args, self.args))
# Flatten the points axes of the evaluable arguments, merge with the
# unevaluable arguments and call `custom_evalf`.
flattened = []
args = iter(args)
for arg in self.args:
if isinstance(arg, evaluable.Array):
arg = next(args)
arg = arg.reshape(npoints, *arg.shape[self.points_dim:])
flattened.append(arg)
result = self.custom_evalf(*flattened)
assert result.ndim == self.ndim + 1 - self.points_dim
# Unflatten the points axes of the result. If there are no array arguments,
# Unflatten the points axes of the result. If there are no arguments,
# the points axis must have length one. Otherwise the length must be
# `npoints` (checked by `reshape`).
if not any(isinstance(origarg, evaluable.Array) for origarg in self.args):
Expand All @@ -795,7 +801,7 @@ def _derivative(self, var: evaluable.Array, seen: Dict[evaluable.Array, evaluabl
if self.dtype in (bool, int):
return super()._derivative(var, seen)
result = evaluable.Zeros(self.shape + var.shape, dtype=self.dtype)
unlowered_args = tuple(_Unlower(arg, self.spaces, self.function_arguments, self.lower_args) if isinstance(arg, evaluable.Array) else arg.value for arg in self.args)
unlowered_args = tuple(_Unlower(arg, self.spaces, self.function_arguments, self.lower_args) if isinstance(arg, evaluable.Array) else arg for arg in self.args)
for iarg, arg in enumerate(self.args):
if not isinstance(arg, evaluable.Array) or arg.dtype in (bool, int) or var not in arg.arguments and var != arg:
continue
Expand Down
21 changes: 0 additions & 21 deletions tests/test_evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,27 +1146,6 @@ def test_nested_variant(self):
self.assertEqual(actual, desired)


class EvaluableConstant(TestCase):

def test_evalf(self):
self.assertEqual(evaluable.EvaluableConstant(1).evalf(), 1)
self.assertEqual(evaluable.EvaluableConstant('1').evalf(), '1')

def test_node_details(self):

class Test:
def __init__(self, s):
self.s = s

def __repr__(self):
return self.s

self.assertEqual(evaluable.EvaluableConstant(Test('some string'))._node_details, 'some string')
self.assertEqual(evaluable.EvaluableConstant(Test('a very long string that should be abbreviated'))._node_details, 'a very long strin...')
self.assertEqual(evaluable.EvaluableConstant(Test('a string with\nmultiple lines'))._node_details, 'a string with...')
self.assertEqual(evaluable.EvaluableConstant(Test('a very long string with\nmultiple lines'))._node_details, 'a very long strin...')


class Einsum(TestCase):

def test_swapaxes(self):
Expand Down