-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next]: concat_where for boundary conditions #1468
Changes from all commits
bf1f1f8
5495937
512d795
c7b01eb
d51cfca
fdb4423
1e0e228
05bdc67
96090e8
8c408dd
ea798d1
b96280e
8669f33
fab1185
d24f6a3
dcf17e2
3804eb6
e4a6f39
1f9ac85
07cffa6
93bf889
cb1d017
ce6adde
6462610
5c92491
35bc515
75d1b03
649d30b
26a2668
cbfd824
971dc44
ceb1a09
ddf6667
77ce553
864ddd3
a028503
05ae764
19a176f
6dd79d8
ab78dc4
f891e37
ddcc272
5c81c03
9816121
b604a7a
cd11b75
05fd105
0b1f12c
d3db930
f22c149
80f21ff
c9dfc8d
14acc84
071a9e0
a6186fc
6413753
c798f97
54428f7
c178f25
a99561e
5755d62
4ebf2d6
2c89af3
8b268ed
b15a0aa
c2738b6
70274a7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,16 +18,20 @@ | |
import functools | ||
from collections.abc import Callable, Sequence | ||
from types import ModuleType | ||
from typing import ClassVar | ||
from typing import ClassVar, Iterable | ||
|
||
import numpy as np | ||
from numpy import typing as npt | ||
|
||
from gt4py._core import definitions as core_defs | ||
from gt4py.eve.extended_typing import Any, Never, Optional, ParamSpec, TypeAlias, TypeVar | ||
from gt4py.next import common | ||
from gt4py.next.embedded import common as embedded_common, context as embedded_context | ||
from gt4py.next.ffront import fbuiltins | ||
from gt4py.next.embedded import ( | ||
common as embedded_common, | ||
context as embedded_context, | ||
exceptions as embedded_exceptions, | ||
) | ||
from gt4py.next.ffront import experimental, fbuiltins | ||
from gt4py.next.iterator import embedded as itir_embedded | ||
|
||
|
||
|
@@ -42,20 +46,22 @@ | |
jnp: Optional[ModuleType] = None # type:ignore[no-redef] | ||
|
||
|
||
def _get_nd_array_class(*fields: common.Field | core_defs.Scalar) -> type[NdArrayField]: | ||
for f in fields: | ||
if isinstance(f, NdArrayField): | ||
return f.__class__ | ||
raise AssertionError("No 'NdArrayField' found in the arguments.") | ||
|
||
|
||
def _make_builtin( | ||
builtin_name: str, array_builtin_name: str, reverse=False | ||
) -> Callable[..., NdArrayField]: | ||
def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: | ||
first = None | ||
for f in fields: | ||
if isinstance(f, NdArrayField): | ||
first = f | ||
break | ||
assert first is not None | ||
xp = first.__class__.array_ns | ||
cls_ = _get_nd_array_class(*fields) | ||
xp = cls_.array_ns | ||
op = getattr(xp, array_builtin_name) | ||
|
||
domain_intersection = embedded_common.intersect_domains(*[ | ||
domain_intersection = embedded_common.domain_intersection(*[ | ||
f.domain for f in fields if common.is_field(f) | ||
]) | ||
|
||
|
@@ -76,7 +82,7 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: | |
if reverse: | ||
transformed.reverse() | ||
new_data = op(*transformed) | ||
return first.__class__.from_array(new_data, domain=domain_intersection) | ||
return cls_.from_array(new_data, domain=domain_intersection) | ||
|
||
_builtin_op.__name__ = builtin_name | ||
return _builtin_op | ||
|
@@ -423,10 +429,7 @@ def inverse_image( | |
if relative_ranges is None: | ||
raise ValueError("Restriction generates non-contiguous dimensions.") | ||
|
||
new_dims = [ | ||
common.named_range((d, rr + ar.start)) | ||
for d, ar, rr in zip(self.domain.dims, self.domain.ranges, relative_ranges) | ||
] | ||
new_dims = _relative_ranges_to_domain(relative_ranges, self.domain) | ||
|
||
self._cache[cache_key] = new_dims | ||
|
||
|
@@ -448,6 +451,14 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field: | |
__getitem__ = restrict | ||
|
||
|
||
def _relative_ranges_to_domain( | ||
havogt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
relative_ranges: Sequence[common.UnitRange], domain: common.Domain | ||
) -> common.Domain: | ||
return common.Domain( | ||
dims=domain.dims, ranges=[rr + ar.start for ar, rr in zip(domain.ranges, relative_ranges)] | ||
) | ||
|
||
|
||
def _hypercube( | ||
index_array: core_defs.NDArrayObject, | ||
image_range: common.UnitRange, | ||
|
@@ -519,6 +530,172 @@ def _hypercube( | |
NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) | ||
|
||
|
||
def _compute_mask_ranges( | ||
mask: core_defs.NDArrayObject, | ||
) -> list[tuple[bool, common.UnitRange]]: | ||
"""Take a 1-dimensional mask and return a sequence of mappings from boolean values to ranges.""" | ||
egparedes marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# TODO: does it make sense to upgrade this naive algorithm to numpy? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if used in k, probably not relevant, if used in the horizontal, probably this is a performance problem There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe some |
||
assert mask.ndim == 1 | ||
cur = bool(mask[0].item()) | ||
havogt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ind = 0 | ||
res = [] | ||
for i in range(1, mask.shape[0]): | ||
if ( | ||
mask_i := bool(mask[i].item()) | ||
) != cur: # `.item()` to extract the scalar from a 0-d array in case of e.g. cupy | ||
res.append((cur, common.UnitRange(ind, i))) | ||
cur = mask_i | ||
ind = i | ||
res.append((cur, common.UnitRange(ind, mask.shape[0]))) | ||
return res | ||
|
||
|
||
def _trim_empty_domains( | ||
lst: Iterable[tuple[bool, common.Domain]], | ||
) -> list[tuple[bool, common.Domain]]: | ||
"""Remove empty domains from beginning and end of the list.""" | ||
lst = list(lst) | ||
if not lst: | ||
return lst | ||
if lst[0][1].is_empty(): | ||
return _trim_empty_domains(lst[1:]) | ||
if lst[-1][1].is_empty(): | ||
return _trim_empty_domains(lst[:-1]) | ||
return lst | ||
|
||
|
||
def _to_field( | ||
value: common.Field | core_defs.Scalar, nd_array_field_type: type[NdArrayField] | ||
) -> common.Field: | ||
# TODO(havogt): this function is only to workaround broadcasting of scalars, once we have a ConstantField, we can broadcast to that directly | ||
return ( | ||
value | ||
if common.is_field(value) | ||
else nd_array_field_type.from_array( | ||
nd_array_field_type.array_ns.asarray(value), domain=common.Domain() | ||
) | ||
) | ||
|
||
|
||
def _intersect_fields( | ||
*fields: common.Field | core_defs.Scalar, | ||
ignore_dims: Optional[common.Dimension | tuple[common.Dimension, ...]] = None, | ||
) -> tuple[common.Field, ...]: | ||
# TODO(havogt): this function could be moved to common, but then requires a broadcast implementation for all field implementations; | ||
# currently blocked, because requiring the `_to_field` function, see comment there. | ||
nd_array_class = _get_nd_array_class(*fields) | ||
promoted_dims = common.promote_dims(*(f.domain.dims for f in fields if common.is_field(f))) | ||
broadcasted_fields = [_broadcast(_to_field(f, nd_array_class), promoted_dims) for f in fields] | ||
|
||
intersected_domains = embedded_common.restrict_to_intersection( | ||
*[f.domain for f in broadcasted_fields], ignore_dims=ignore_dims | ||
) | ||
|
||
return tuple( | ||
nd_array_class.from_array( | ||
f.ndarray[_get_slices_from_domain_slice(f.domain, intersected_domain)], | ||
domain=intersected_domain, | ||
) | ||
for f, intersected_domain in zip(broadcasted_fields, intersected_domains, strict=True) | ||
) | ||
|
||
|
||
def _stack_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[common.Domain]: | ||
if not domains: | ||
return common.Domain() | ||
dim_start = domains[0][dim][1].start | ||
dim_stop = dim_start | ||
for domain in domains: | ||
if not domain[dim][1].start == dim_stop: | ||
return None | ||
else: | ||
dim_stop = domain[dim][1].stop | ||
return domains[0].replace(dim, (dim, common.UnitRange(dim_start, dim_stop))) | ||
|
||
|
||
def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field: | ||
# TODO(havogt): this function could be extended to a general concat | ||
# currently only concatenate along the given dimension and requires the fields to be ordered | ||
|
||
if ( | ||
len(fields) > 1 | ||
and not embedded_common.domain_intersection(*[f.domain for f in fields]).is_empty() | ||
): | ||
raise ValueError("Fields to concatenate must not overlap.") | ||
new_domain = _stack_domains(*[f.domain for f in fields], dim=dim) | ||
if new_domain is None: | ||
raise embedded_exceptions.NonContiguousDomain(f"Cannot concatenate fields along {dim}.") | ||
nd_array_class = _get_nd_array_class(*fields) | ||
return nd_array_class.from_array( | ||
nd_array_class.array_ns.concatenate( | ||
[nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) for f in fields], | ||
axis=new_domain.dim_index(dim), | ||
), | ||
domain=new_domain, | ||
) | ||
|
||
|
||
def _concat_where( | ||
havogt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
mask_field: common.Field, true_field: common.Field, false_field: common.Field | ||
) -> common.Field: | ||
cls_ = _get_nd_array_class(mask_field, true_field, false_field) | ||
xp = cls_.array_ns | ||
if mask_field.domain.ndim != 1: | ||
raise NotImplementedError( | ||
"'concat_where': Can only concatenate fields with a 1-dimensional mask." | ||
) | ||
mask_dim = mask_field.domain.dims[0] | ||
|
||
# intersect the field in dimensions orthogonal to the mask, then all slices in the mask field have same domain | ||
t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim) | ||
|
||
# TODO(havogt): for clarity, most of it could be implemented on named_range in the masked dimension, but we currently lack the utils | ||
# compute the consecutive ranges (first relative, then domain) of true and false values | ||
mask_values_to_relative_range_mapping: Iterable[tuple[bool, common.UnitRange]] = ( | ||
_compute_mask_ranges(mask_field.ndarray) | ||
) | ||
mask_values_to_domain_mapping: Iterable[tuple[bool, common.Domain]] = ( | ||
havogt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
(mask, _relative_ranges_to_domain((relative_range,), mask_field.domain)) | ||
for mask, relative_range in mask_values_to_relative_range_mapping | ||
) | ||
# mask domains intersected with the respective fields | ||
mask_values_to_intersected_domains_mapping: Iterable[tuple[bool, common.Domain]] = ( | ||
( | ||
mask_value, | ||
embedded_common.domain_intersection( | ||
t_broadcasted.domain if mask_value else f_broadcasted.domain, mask_domain | ||
), | ||
) | ||
for mask_value, mask_domain in mask_values_to_domain_mapping | ||
) | ||
|
||
# remove the empty domains from the beginning and end | ||
mask_values_to_intersected_domains_mapping = _trim_empty_domains( | ||
mask_values_to_intersected_domains_mapping | ||
) | ||
if any(d.is_empty() for _, d in mask_values_to_intersected_domains_mapping): | ||
raise embedded_exceptions.NonContiguousDomain( | ||
f"In 'concat_where', cannot concatenate the following 'Domain's: {[d for _, d in mask_values_to_intersected_domains_mapping]}." | ||
) | ||
|
||
# slice the fields with the domain ranges | ||
transformed = [ | ||
t_broadcasted[d] if v else f_broadcasted[d] | ||
for v, d in mask_values_to_intersected_domains_mapping | ||
] | ||
|
||
# stack the fields together | ||
if transformed: | ||
return _concat(*transformed, dim=mask_dim) | ||
else: | ||
result_domain = common.Domain((mask_dim, common.UnitRange(0, 0))) | ||
result_array = xp.empty(result_domain.shape) | ||
return cls_.from_array(result_array, domain=result_domain) | ||
|
||
|
||
NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[has-type] | ||
havogt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def _make_reduction( | ||
builtin_name: str, array_builtin_name: str, initial_value_op: Callable | ||
) -> Callable[ | ||
|
@@ -635,7 +812,7 @@ def __setitem__( | |
common._field.register(jnp.ndarray, JaxArrayField.from_array) | ||
|
||
|
||
def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field: | ||
def _broadcast(field: common.Field, new_dimensions: Sequence[common.Dimension]) -> common.Field: | ||
if field.domain.dims == new_dimensions: | ||
return field | ||
domain_slice: list[slice | None] = [] | ||
|
@@ -645,7 +822,7 @@ def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...] | |
domain_slice.append(slice(None)) | ||
named_ranges.append((dim, field.domain[pos][1])) | ||
else: | ||
domain_slice.append(np.newaxis) | ||
domain_slice.append(None) # np.newaxis | ||
named_ranges.append((dim, common.UnitRange.infinite())) | ||
return common._field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a suggestion, I think this could be transformed into the
.intersection()
method in theDomain
class to follow the same API asset
(https://devdocs.io/python~3.12/library/stdtypes#set)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe adding that makes sense, however I like more that this function works also when domains is empty