Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: concat_where for boundary conditions #1468

Merged
merged 67 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from 54 commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
bf1f1f8
skip value connectivity
havogt Jan 18, 2024
5495937
fix formatting
havogt Jan 23, 2024
512d795
Merge remote-tracking branch 'upstream/main' into embedded_skip_value…
havogt Jan 30, 2024
c7b01eb
Merge remote-tracking branch 'upstream/main' into embedded_skip_value…
havogt Jan 31, 2024
d51cfca
cleanup parts and tests
havogt Feb 1, 2024
fdb4423
skip fvm test with no atlas
havogt Feb 1, 2024
1e0e228
testcase which requires broadcasting the mask
havogt Feb 2, 2024
05bdc67
Merge remote-tracking branch 'upstream/main' into embedded_skip_value…
havogt Feb 6, 2024
96090e8
add comment
havogt Feb 6, 2024
8c408dd
cleanup
havogt Feb 6, 2024
ea798d1
Merge remote-tracking branch 'upstream/main' into embedded_skip_value…
havogt Feb 6, 2024
b96280e
fix lowering issue past to itir
havogt Feb 7, 2024
8669f33
fix bug
havogt Feb 7, 2024
fab1185
fix connectivity names
havogt Feb 9, 2024
d24f6a3
explicit xp.newaxis
havogt Feb 9, 2024
dcf17e2
wrap the mask hypercube
havogt Feb 9, 2024
3804eb6
prepare configurable skip_value
havogt Feb 9, 2024
e4a6f39
fix test
havogt Feb 9, 2024
1f9ac85
skip value refactoring
havogt Feb 9, 2024
07cffa6
fix skip_value check
havogt Feb 9, 2024
93bf889
fix assert
havogt Feb 10, 2024
cb1d017
Merge remote-tracking branch 'upstream/main' into embedded_skip_value…
havogt Feb 12, 2024
ce6adde
prototype
havogt Feb 14, 2024
6462610
fix bug in reverse sub and div
havogt Feb 14, 2024
5c92491
alternative concat_where that deals with multiple ranges
havogt Feb 14, 2024
35bc515
Merge remote-tracking branch 'upstream/main' into embedded_skip_value…
havogt Feb 22, 2024
75d1b03
SKIP_VALUE -> _DEFAULT_SKIP_VALUE
havogt Feb 22, 2024
649d30b
Merge remote-tracking branch 'local/embedded_skip_value_connectivity'…
havogt Feb 23, 2024
26a2668
Merge remote-tracking branch 'upstream/main' into concat_where
havogt Feb 23, 2024
cbfd824
format
havogt Feb 23, 2024
971dc44
add tests for scalar binary with field
havogt Feb 23, 2024
ceb1a09
cleanup tests
havogt Feb 23, 2024
ddf6667
cleanup
havogt Feb 23, 2024
77ce553
Merge branch 'fix_reverse_ops' into concat_where
havogt Feb 23, 2024
864ddd3
add concat_where for embedded
havogt Feb 23, 2024
a028503
add tests and very hacked version of broadcasting
havogt Feb 23, 2024
05ae764
add TODOs
havogt Feb 23, 2024
19a176f
cleanup
havogt Feb 24, 2024
6dd79d8
refactoring
havogt Feb 24, 2024
ab78dc4
more cleanups
havogt Feb 24, 2024
f891e37
address review comments
havogt Feb 26, 2024
ddcc272
change scalar value
havogt Feb 26, 2024
5c81c03
Merge remote-tracking branch 'upstream/main' into fix_reverse_ops
havogt Feb 26, 2024
9816121
Merge remote-tracking branch 'origin/fix_reverse_ops' into concat_where
havogt Feb 26, 2024
b604a7a
Merge remote-tracking branch 'upstream/main' into concat_where
havogt Feb 26, 2024
cd11b75
Merge commit 'ea852984dbf22ec0f2bb72ed454d7a1392040478' into concat_w…
havogt Mar 5, 2024
05fd105
default format
havogt Mar 5, 2024
0b1f12c
Merge commit '4c8f706a9f3cceff946f128022390c406523a7a1' into concat_w…
havogt Mar 5, 2024
d3db930
Merge commit '77a205b6b31d9854e0e15d01d91349047ec0c426' into concat_w…
havogt Mar 5, 2024
f22c149
fix tests
havogt Mar 6, 2024
80f21ff
describe algorithm
havogt Mar 7, 2024
c9dfc8d
add docstring to concat_where, but test not working
havogt Mar 7, 2024
14acc84
add tests
havogt Mar 7, 2024
071a9e0
switch back to list[tuple]
havogt Mar 7, 2024
a6186fc
add unit_range.is_empty
havogt Mar 11, 2024
6413753
add more tests
havogt Mar 11, 2024
c798f97
address more review comments
havogt Mar 11, 2024
54428f7
steal refactoring from nfarabullini/as_offset_embedded
havogt Mar 11, 2024
c178f25
move to experimental
havogt Mar 11, 2024
a99561e
address review comments
havogt Mar 13, 2024
5755d62
add a test (wip)
havogt Mar 13, 2024
4ebf2d6
add todos
havogt Mar 13, 2024
2c89af3
fix refactoring bugs
havogt Mar 14, 2024
8b268ed
add tests in field_operators
havogt Mar 14, 2024
b15a0aa
Merge remote-tracking branch 'upstream/main' into concat_where
havogt Mar 14, 2024
c2738b6
add type ignore
havogt Mar 14, 2024
70274a7
Merge remote-tracking branch 'upstream/main' into concat_where
havogt Mar 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,10 @@ def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]:
# classmethod since TypeGuards requires the guarded obj as separate argument
return all(UnitRange.is_finite(rng) for rng in obj.ranges)

@property
def is_empty(self) -> bool:
return any(rng == UnitRange(0, 0) for rng in self.ranges)
havogt marked this conversation as resolved.
Show resolved Hide resolved

@overload
def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ...

Expand Down
45 changes: 44 additions & 1 deletion src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,57 @@ def _absolute_sub_domain(
return common.Domain(*named_ranges)


def intersect_domains(*domains: common.Domain) -> common.Domain:
def domain_intersection(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a suggestion, I think this could be transformed into the .intersection() method in the Domain class to follow the same API as set (https://devdocs.io/python~3.12/library/stdtypes#set)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe adding that makes sense, however I like more that this function works also when domains is empty

*domains: common.Domain,
) -> common.Domain:
"""
Return the intersection of the given domains.

Example:
>>> I = common.Dimension("I")
>>> domain_intersection(
... common.domain({I: (0, 5)}), common.domain({I: (1, 3)})
... ) # doctest: +ELLIPSIS
Domain(dims=(Dimension(value='I', ...), ranges=(UnitRange(1, 3),))
havogt marked this conversation as resolved.
Show resolved Hide resolved
"""
return functools.reduce(
operator.and_,
domains,
common.Domain(dims=tuple(), ranges=tuple()),
)


def intersect_domains(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find this function confusing specially given the previous one. I think the name could contain the name restrict instead of intersection (e.g. restrict_to_lower_bound, restrict_to_intersection) and the ignore_dims argument shouldn't be optional, because otherwise is just a simple intersection of domains.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Changed the name.
  • In generic context allowing ignore_dims to be optional makes sense to me. Maybe I could change the API to only accept tuples and default to empty tuple. What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't like that much that ignore_dims is optional but it's ok. ignore_dims accepting only tuples makes sense to me, but it's not strictly needed and it's orthogonal to my original complaint because it'll still have a default value.

*domains: common.Domain,
ignore_dims: Optional[common.Dimension | tuple[common.Dimension, ...]] = None,
) -> tuple[common.Domain, ...]:
"""
Return the with each other intersected domains, ignoring 'ignore_dims' dimensions for the intersection.

Example:
havogt marked this conversation as resolved.
Show resolved Hide resolved
>>> I = common.Dimension("I")
>>> J = common.Dimension("J")
>>> res = intersect_domains(
... common.domain({I: (0, 5), J: (1, 2)}),
... common.domain({I: (1, 3), J: (0, 3)}),
... ignore_dims=J,
... )
>>> assert res == (common.domain({I: (1, 3), J: (1, 2)}), common.domain({I: (1, 3), J: (0, 3)}))
"""
ignore_dims_tuple = ignore_dims if isinstance(ignore_dims, tuple) else (ignore_dims,)
intersection_without_ignore_dims = domain_intersection(*[
common.Domain(*[(d, r) for d, r in domain if d not in ignore_dims_tuple])
for domain in domains
])
return tuple(
common.Domain(*[
(d, r if d in ignore_dims_tuple else intersection_without_ignore_dims[d][1])
for d, r in domain
])
for domain in domains
)


def iterate_domain(domain: common.Domain):
for i in itertools.product(*[list(r) for r in domain.ranges]):
yield tuple(zip(domain.dims, i))
Expand Down
8 changes: 8 additions & 0 deletions src/gt4py/next/embedded/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,11 @@ def __init__(
self.indices = indices
self.index = index
self.dim = dim


class NonContiguousDomain(gt4py_exceptions.GT4PyError):
msg: str
egparedes marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, msg: str):
super().__init__(f"Operation would result in a non-contiguous domain: `{msg}`.")
self.msg = msg
207 changes: 190 additions & 17 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@
import functools
from collections.abc import Callable, Sequence
from types import ModuleType
from typing import ClassVar
from typing import ClassVar, Iterable

import numpy as np
from numpy import typing as npt

from gt4py._core import definitions as core_defs
from gt4py.eve.extended_typing import Any, Never, Optional, ParamSpec, TypeAlias, TypeVar
from gt4py.next import common
from gt4py.next.embedded import common as embedded_common, context as embedded_context
from gt4py.next.embedded import (
common as embedded_common,
context as embedded_context,
exceptions as embedded_exceptions,
)
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import embedded as itir_embedded

Expand All @@ -42,20 +46,22 @@
jnp: Optional[ModuleType] = None # type:ignore[no-redef]


def _get_nd_array_class(*fields: common.Field | core_defs.Scalar) -> type[NdArrayField]:
for f in fields:
if isinstance(f, NdArrayField):
return f.__class__
raise AssertionError("No 'NdArrayField' found in the arguments.")


def _make_builtin(
builtin_name: str, array_builtin_name: str, reverse=False
) -> Callable[..., NdArrayField]:
def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField:
first = None
for f in fields:
if isinstance(f, NdArrayField):
first = f
break
assert first is not None
xp = first.__class__.array_ns
cls_ = _get_nd_array_class(*fields)
xp = cls_.array_ns
op = getattr(xp, array_builtin_name)

domain_intersection = embedded_common.intersect_domains(*[
domain_intersection = embedded_common.domain_intersection(*[
f.domain for f in fields if common.is_field(f)
])

Expand All @@ -76,7 +82,7 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField:
if reverse:
transformed.reverse()
new_data = op(*transformed)
return first.__class__.from_array(new_data, domain=domain_intersection)
return cls_.from_array(new_data, domain=domain_intersection)

_builtin_op.__name__ = builtin_name
return _builtin_op
Expand Down Expand Up @@ -423,10 +429,7 @@ def inverse_image(
if relative_ranges is None:
raise ValueError("Restriction generates non-contiguous dimensions.")

new_dims = [
common.named_range((d, rr + ar.start))
for d, ar, rr in zip(self.domain.dims, self.domain.ranges, relative_ranges)
]
new_dims = _relative_ranges_to_domain(relative_ranges, self.domain)

self._cache[cache_key] = new_dims

Expand All @@ -448,6 +451,14 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field:
__getitem__ = restrict


def _relative_ranges_to_domain(
havogt marked this conversation as resolved.
Show resolved Hide resolved
relative_ranges: Sequence[common.UnitRange], domain: common.Domain
) -> common.Domain:
return common.Domain(
dims=domain.dims, ranges=[rr + ar.start for ar, rr in zip(domain.ranges, relative_ranges)]
)


def _hypercube(
index_array: core_defs.NDArrayObject,
image_range: common.UnitRange,
Expand Down Expand Up @@ -519,6 +530,168 @@ def _hypercube(
NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where"))


def _compute_mask_ranges(
mask: core_defs.NDArrayObject,
) -> list[tuple[bool, common.UnitRange]]:
"""Take a 1-dimensional mask and return a sequence of mappings from boolean values to ranges."""
egparedes marked this conversation as resolved.
Show resolved Hide resolved
# TODO: does it make sense to upgrade this naive algorithm to numpy?
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if used in k, probably not relevant, if used in the horizontal, probably this is a performance problem

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe some timeit microbenchmarks with typical (both small and large) domain sizes would help to figure out if it'd pay off.

assert mask.ndim == 1
cur = bool(mask[0].item())
havogt marked this conversation as resolved.
Show resolved Hide resolved
ind = 0
res = []
for i in range(1, mask.shape[0]):
if (mask_i := bool(mask[i].item())) != cur:
res.append((cur, common.UnitRange(ind, i)))
cur = mask_i
ind = i
res.append((cur, common.UnitRange(ind, mask.shape[0])))
return res


def _trim_empty_domains(
lst: Iterable[tuple[bool, common.Domain]],
) -> list[tuple[bool, common.Domain]]:
"""Remove empty domains from beginning and end of the list."""
lst = list(lst)
if not lst:
return lst
if lst[0][1].is_empty:
return _trim_empty_domains(lst[1:])
if lst[-1][1].is_empty:
return _trim_empty_domains(lst[:-1])
return lst


def _to_field(
value: common.Field | core_defs.Scalar, nd_array_field_type: type[NdArrayField]
) -> common.Field:
# TODO(havogt): this function is only to workaround broadcasting of scalars, once we have a ConstantField, we can broadcast to that directly
return (
value
if common.is_field(value)
else nd_array_field_type.from_array(
nd_array_field_type.array_ns.asarray(value), domain=common.Domain()
)
)


def _intersect_fields(
*fields: common.Field | core_defs.Scalar,
ignore_dims: Optional[common.Dimension | tuple[common.Dimension, ...]] = None,
) -> tuple[common.Field, ...]:
# TODO(havogt): this function could be moved to common, but then requires a broadcast implementation for all field implementations
nd_array_class = _get_nd_array_class(*fields)
promoted_dims = common.promote_dims(*[f.domain.dims for f in fields if common.is_field(f)])
havogt marked this conversation as resolved.
Show resolved Hide resolved
broadcasted_fields = [_broadcast(_to_field(f, nd_array_class), promoted_dims) for f in fields]

intersected_domains = embedded_common.intersect_domains(
*[f.domain for f in broadcasted_fields], ignore_dims=ignore_dims
)

return tuple(
nd_array_class.from_array(
f.ndarray[_get_slices_from_domain_slice(f.domain, intersected_domain)],
domain=intersected_domain,
)
for f, intersected_domain in zip(broadcasted_fields, intersected_domains, strict=True)
)


def _concat_domains(*domains: common.Domain, dim: common.Dimension) -> Optional[common.Domain]:
havogt marked this conversation as resolved.
Show resolved Hide resolved
if not domains:
return common.Domain()
dim_start = domains[0][dim][1].start
dim_stop = dim_start
for domain in domains:
if not domain[dim][1].start == dim_stop:
return None
else:
dim_stop = domain[dim][1].stop
return domains[0].replace(dim, (dim, common.UnitRange(dim_start, dim_stop)))


def _concat(*fields: common.Field, dim: common.Dimension) -> common.Field:
# TODO(havogt): this function could be extended to a general concat
# currently only concatenate along the given dimension and requires the fields to be ordered

if (
len(fields) > 1
and not embedded_common.domain_intersection(*[f.domain for f in fields]).is_empty
):
raise ValueError("Fields to concatenate must not overlap.")
new_domain = _concat_domains(*[f.domain for f in fields], dim=dim)
if new_domain is None:
raise embedded_exceptions.NonContiguousDomain(f"Cannot concatenate fields along {dim}.")
nd_array_class = _get_nd_array_class(*fields)
return nd_array_class.from_array(
nd_array_class.array_ns.concatenate(
[nd_array_class.array_ns.broadcast_to(f.ndarray, f.domain.shape) for f in fields],
axis=new_domain.dim_index(dim),
),
domain=new_domain,
)


def _concat_where(
havogt marked this conversation as resolved.
Show resolved Hide resolved
mask_field: common.Field, true_field: common.Field, false_field: common.Field
) -> common.Field:
cls_ = _get_nd_array_class(mask_field, true_field, false_field)
xp = cls_.array_ns
if mask_field.domain.ndim != 1:
raise NotImplementedError(
"'concat_where': Can only concatenate fields with a 1-dimensional mask."
)
mask_dim = mask_field.domain.dims[0]

# intersect the field in dimensions orthogonal to the mask, then all slices in the mask field have same domain
t_broadcasted, f_broadcasted = _intersect_fields(true_field, false_field, ignore_dims=mask_dim)

# compute the consecutive ranges (first relative, then domain) of true and false values
mask_values_to_relative_range_mapping: Iterable[tuple[bool, common.UnitRange]] = (
_compute_mask_ranges(mask_field.ndarray)
)
mask_values_to_domain_mapping: Iterable[tuple[bool, common.Domain]] = (
havogt marked this conversation as resolved.
Show resolved Hide resolved
(mask, _relative_ranges_to_domain((relative_range,), mask_field.domain))
for mask, relative_range in mask_values_to_relative_range_mapping
)
# mask domains intersected with the respective fields
mask_values_to_intersected_domains_mapping: Iterable[tuple[bool, common.Domain]] = (
(
mask_value,
embedded_common.domain_intersection(
t_broadcasted.domain if mask_value else f_broadcasted.domain, mask_domain
),
)
for mask_value, mask_domain in mask_values_to_domain_mapping
)

# remove the empty domains from the beginning and end
mask_values_to_intersected_domains_mapping = _trim_empty_domains(
mask_values_to_intersected_domains_mapping
)
if any(d.is_empty for _, d in mask_values_to_intersected_domains_mapping):
raise embedded_exceptions.NonContiguousDomain(
f"In 'concat_where', cannot concatenate the following 'Domain's: {[d for _, d in mask_values_to_intersected_domains_mapping]}."
)

# slice the fields with the domain ranges
transformed = [
t_broadcasted[d] if v else f_broadcasted[d]
for v, d in mask_values_to_intersected_domains_mapping
]

# stack the fields together
if transformed:
return _concat(*transformed, dim=mask_dim)
else:
result_domain = common.Domain((mask_dim, common.UnitRange(0, 0)))
result_array = xp.empty(result_domain.shape)
return cls_.from_array(result_array, domain=result_domain)


NdArrayField.register_builtin_func(fbuiltins.concat_where, _concat_where) # type: ignore[arg-type] # tuples are handled in the base implementation


def _make_reduction(
builtin_name: str, array_builtin_name: str, initial_value_op: Callable
) -> Callable[
Expand Down Expand Up @@ -635,7 +808,7 @@ def __setitem__(
common._field.register(jnp.ndarray, JaxArrayField.from_array)


def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field:
def _broadcast(field: common.Field, new_dimensions: Sequence[common.Dimension]) -> common.Field:
if field.domain.dims == new_dimensions:
return field
domain_slice: list[slice | None] = []
Expand All @@ -645,7 +818,7 @@ def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]
domain_slice.append(slice(None))
named_ranges.append((dim, field.domain[pos][1]))
else:
domain_slice.append(np.newaxis)
domain_slice.append(None) # np.newaxis
named_ranges.append((dim, common.UnitRange.infinite()))
return common._field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges))

Expand Down
15 changes: 10 additions & 5 deletions src/gt4py/next/embedded/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ def scan_loop(hpos):
return res


def _get_out_domain(
out: common.MutableField | tuple[common.MutableField | tuple, ...],
) -> common.Domain:
return embedded_common.domain_intersection(*[
f.domain for f in utils.flatten_nested_tuple((out,))
])


def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any):
if "out" in kwargs:
# called from program or direct field_operator as program
Expand All @@ -102,10 +110,7 @@ def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any):

domain = kwargs.pop("domain", None)

flattened_out: tuple[common.Field, ...] = utils.flatten_nested_tuple((out,))
assert all(f.domain == flattened_out[0].domain for f in flattened_out)

out_domain = common.domain(domain) if domain is not None else flattened_out[0].domain
out_domain = common.domain(domain) if domain is not None else _get_out_domain(out)

new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain)

Expand Down Expand Up @@ -149,7 +154,7 @@ def impl(target: common.MutableField, source: common.Field):
def _intersect_scan_args(
*args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...],
) -> common.Domain:
return embedded_common.intersect_domains(*[
return embedded_common.domain_intersection(*[
arg.domain for arg in utils.flatten_nested_tuple(args) if common.is_field(arg)
])

Expand Down
Loading
Loading