Skip to content

Commit

Permalink
use evaluable.compile in function.evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
joostvanzwieten committed Apr 16, 2024
1 parent 8327a99 commit eb69386
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 12 deletions.
53 changes: 44 additions & 9 deletions nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def __repr__(self) -> str:
def eval(self, /, **arguments) -> numpy.ndarray:
'Evaluate this function.'

return evaluate(self, _post=_convert, arguments=arguments)[0]
return evaluate(self, _convert=_convert_dense_vec_sparse_matrix, arguments=arguments)[0]

def derivative(self, __var: Union[str, 'Argument']) -> 'Array':
'See :func:`derivative`.'
Expand Down Expand Up @@ -2115,7 +2115,33 @@ def nsymgrad(__arg: IntoArray, __geom: IntoArray, ndims: int = 0) -> Array:
# MISC


def _convert(data: numpy.ndarray, inplace: bool = True) -> Union[numpy.ndarray, matrix.Matrix]:
def _post_identity(x):
return x


def _post_scalar(x):
return x[()] if isinstance(x, numpy.ndarray) else x


def _post_coo_to_matrix(coo):
values, indices, shape = coo
return matrix.assemble(values, indices, shape)


def _post_coo_to_sparse(coo):
values, indices, shape = coo
data = numpy.empty((len(values),), dtype=sparse.dtype(shape, values.dtype))
data['value'] = values
for idim, ii in enumerate(indices):
data['index']['i'+str(idim)] = ii
return data


def _convert_dense(data: evaluable.Array):
return data, _post_scalar if data.ndim == 0 else _post_identity


def _convert_dense_vec_sparse_matrix(data: evaluable.Array):
'''Convert a two-dimensional sparse object to an appropriate object.
The return type is determined based on dimension: a zero-dimensional object
Expand All @@ -2124,10 +2150,19 @@ def _convert(data: numpy.ndarray, inplace: bool = True) -> Union[numpy.ndarray,
deduplicated and pruned sparse object.
'''

ndim = sparse.ndim(data)
return sparse.toarray(data) if ndim < 2 \
else matrix.fromsparse(data, inplace=inplace) if ndim == 2 \
else sparse.prune(sparse.dedup(data, inplace=inplace), inplace=True)
ndim = data.ndim
if ndim == 0:
return data, _post_scalar
elif ndim == 1:
return data, _post_identity
values, indices, shape = data.as_coo_with_shape()
post = _post_coo_to_matrix if ndim == 2 else _post_coo_to_sparse
return (values, indices, shape), post


def _convert_sparse(data: evaluable.Array):
values, indices, shape = data.as_coo_with_shape()
return (values, indices, shape), _post_coo_to_sparse


@util.single_or_multiple
Expand All @@ -2150,11 +2185,11 @@ def eval(funcs: evaluable.AsEvaluableArray, /, **arguments: numpy.ndarray) -> Tu


@nutils_dispatch
def evaluate(*arrays, _post=sparse.toarray, arguments={}):
def evaluate(*arrays, _convert=_convert_dense, arguments={}):
if len(arguments) == 1 and 'arguments' in arguments and isinstance(arguments['arguments'], dict):
arguments = arguments['arguments']
sparse_arrays = evaluable.eval_sparse(map(Array.cast, arrays), **arguments)
return tuple(map(_post, sparse_arrays))
arrays, posts = zip(*(_convert(Array.cast(array).as_evaluable_array) for array in arrays))
return tuple(post(array) for post, array in zip(posts, evaluable.compile(arrays)(**arguments)))


@nutils_dispatch
Expand Down
6 changes: 3 additions & 3 deletions nutils/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def integrate(self, funcs, /, **arguments):
'''


return function.evaluate(*map(self.integral, funcs), _post=function._convert, arguments=arguments)
return function.evaluate(*map(self.integral, funcs), _convert=function._convert_dense_vec_sparse_matrix, arguments=arguments)

@util.single_or_multiple
def integrate_sparse(self, funcs, /, **arguments):
Expand All @@ -181,7 +181,7 @@ def integrate_sparse(self, funcs, /, **arguments):
Optional arguments for function evaluation.
'''

return function.evaluate(*map(self.integral, funcs), _post=lambda x: x, arguments=arguments)
return function.evaluate(*map(self.integral, funcs), _convert=function._convert_sparse, arguments=arguments)

def integral(self, __func: function.IntoArray) -> function.Array:
'''Create Integral object for postponed integration.
Expand Down Expand Up @@ -220,7 +220,7 @@ def eval_sparse(self, funcs, /, **arguments):
Optional arguments for function evaluation.
'''

return function.evaluate(*map(self, funcs), _post=lambda x: x, arguments=arguments)
return function.evaluate(*map(self, funcs), _convert=function._convert_sparse, arguments=arguments)

def _integral(self, func: function.Array) -> function.Array:
'''Create Integral object for postponed integration.
Expand Down

0 comments on commit eb69386

Please sign in to comment.