Skip to content

Commit

Permalink
eliminate and remove EvaluableConstant (#847)
Browse files Browse the repository at this point in the history
In an effort to make the `evaluable` module array-only (#738), this PR
eliminates and removes the `evaluable.EvaluableConstant` class.
  • Loading branch information
joostvanzwieten committed Dec 20, 2023
2 parents 9d89a1e + 4b73a78 commit e1a61d2
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 56 deletions.
2 changes: 1 addition & 1 deletion nutils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
'Numerical Utilities for Finite Element Analysis'

__version__ = version = '9a12'
__version__ = version = '9a13'
version_name = 'jook-sing'
26 changes: 0 additions & 26 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,32 +521,6 @@ def _node(self, cache, subgraph, times):
EVALARGS = EVALARGS()


class EvaluableConstant(Evaluable):
'''Evaluate to the given constant value.
Parameters
----------
value
The return value of ``eval``.
'''

def __init__(self, value):
self.value = value
super().__init__(())

def evalf(self):
return self.value

@property
def _node_details(self):
s = repr(self.value)
if '\n' in s:
s = s.split('\n', 1)[0] + '...'
if len(s) > 20:
s = s[:17] + '...'
return s


class Tuple(Evaluable):

def __init__(self, items):
Expand Down
22 changes: 14 additions & 8 deletions nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def __init__(self, args: Iterable[Any], shape: Tuple[int], dtype: DType, npointw
super().__init__(shape=(*points_shape, *shape), dtype=dtype, spaces=spaces, arguments=arguments)

def lower(self, args: LowerArgs) -> evaluable.Array:
evalargs = tuple(arg.lower(args) if isinstance(arg, Array) else evaluable.EvaluableConstant(arg) for arg in self._args) # type: Tuple[Union[evaluable.Array, evaluable.EvaluableConstant], ...]
evalargs = tuple(arg.lower(args) if isinstance(arg, Array) else arg for arg in self._args)
add_points_shape = tuple(map(evaluable.asarray, self.shape[:self._npointwise]))
points_shape = args.points_shape + add_points_shape
coordinates = {space: evaluable.Transpose.to_end(evaluable.appendaxes(coords, add_points_shape), coords.ndim-1) for space, coords in args.coordinates.items()}
Expand Down Expand Up @@ -758,8 +758,7 @@ def partial_derivative(cls, iarg: int, *args: Any) -> IntoArray:

class _CustomEvaluable(evaluable.Array):

def __init__(self, name, evalf, partial_derivative, args: Tuple[Union[evaluable.Array, evaluable.EvaluableConstant], ...], shape: Tuple[int, ...], dtype: DType, spaces: FrozenSet[str], arguments: types.frozendict, lower_args: LowerArgs) -> None:
assert all(isinstance(arg, (evaluable.Array, evaluable.EvaluableConstant)) for arg in args)
def __init__(self, name, evalf, partial_derivative, args, shape: Tuple[int, ...], dtype: DType, spaces: FrozenSet[str], arguments: types.frozendict, lower_args: LowerArgs) -> None:
self.name = name
self.custom_evalf = evalf
self.custom_partial_derivative = partial_derivative
Expand All @@ -768,7 +767,7 @@ def __init__(self, name, evalf, partial_derivative, args: Tuple[Union[evaluable.
self.lower_args = lower_args
self.spaces = spaces
self.function_arguments = arguments
super().__init__((evaluable.Tuple(lower_args.points_shape), *args), shape=(*lower_args.points_shape, *map(evaluable.constant, shape)), dtype=dtype)
super().__init__((evaluable.Tuple(lower_args.points_shape), *(arg for arg in args if isinstance(arg, evaluable.Array))), shape=(*lower_args.points_shape, *map(evaluable.constant, shape)), dtype=dtype)

@property
def _node_details(self) -> str:
Expand All @@ -777,11 +776,18 @@ def _node_details(self) -> str:
def evalf(self, points_shape: Tuple[numpy.ndarray, ...], *args: Any) -> numpy.ndarray:
points_shape = tuple(n.__index__() for n in points_shape)
npoints = util.product(points_shape, 1)
# Flatten the points axes of the array arguments and call `custom_evalf`.
flattened = (arg.reshape(npoints, *arg.shape[self.points_dim:]) if isinstance(origarg, evaluable.Array) else arg for arg, origarg in zip(args, self.args))
# Flatten the points axes of the evaluable arguments, merge with the
# unevaluable arguments and call `custom_evalf`.
flattened = []
args = iter(args)
for arg in self.args:
if isinstance(arg, evaluable.Array):
arg = next(args)
arg = arg.reshape(npoints, *arg.shape[self.points_dim:])
flattened.append(arg)
result = self.custom_evalf(*flattened)
assert result.ndim == self.ndim + 1 - self.points_dim
# Unflatten the points axes of the result. If there are no array arguments,
# Unflatten the points axes of the result. If there are no arguments,
# the points axis must have length one. Otherwise the length must be
# `npoints` (checked by `reshape`).
if not any(isinstance(origarg, evaluable.Array) for origarg in self.args):
Expand All @@ -795,7 +801,7 @@ def _derivative(self, var: evaluable.Array, seen: Dict[evaluable.Array, evaluabl
if self.dtype in (bool, int):
return super()._derivative(var, seen)
result = evaluable.Zeros(self.shape + var.shape, dtype=self.dtype)
unlowered_args = tuple(_Unlower(arg, self.spaces, self.function_arguments, self.lower_args) if isinstance(arg, evaluable.Array) else arg.value for arg in self.args)
unlowered_args = tuple(_Unlower(arg, self.spaces, self.function_arguments, self.lower_args) if isinstance(arg, evaluable.Array) else arg for arg in self.args)
for iarg, arg in enumerate(self.args):
if not isinstance(arg, evaluable.Array) or arg.dtype in (bool, int) or var not in arg.arguments and var != arg:
continue
Expand Down
21 changes: 0 additions & 21 deletions tests/test_evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,27 +1146,6 @@ def test_nested_variant(self):
self.assertEqual(actual, desired)


class EvaluableConstant(TestCase):

def test_evalf(self):
self.assertEqual(evaluable.EvaluableConstant(1).evalf(), 1)
self.assertEqual(evaluable.EvaluableConstant('1').evalf(), '1')

def test_node_details(self):

class Test:
def __init__(self, s):
self.s = s

def __repr__(self):
return self.s

self.assertEqual(evaluable.EvaluableConstant(Test('some string'))._node_details, 'some string')
self.assertEqual(evaluable.EvaluableConstant(Test('a very long string that should be abbreviated'))._node_details, 'a very long strin...')
self.assertEqual(evaluable.EvaluableConstant(Test('a string with\nmultiple lines'))._node_details, 'a string with...')
self.assertEqual(evaluable.EvaluableConstant(Test('a very long string with\nmultiple lines'))._node_details, 'a very long strin...')


class Einsum(TestCase):

def test_swapaxes(self):
Expand Down

0 comments on commit e1a61d2

Please sign in to comment.