Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next] Embedded field remove __array__ #1366

Merged
merged 30 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
faca090
feat[next] Enable embedded field view in ffront_tests
havogt Nov 16, 2023
7931cfd
broadcast for scalars
havogt Nov 16, 2023
4734c84
implement astype
havogt Nov 16, 2023
e1463d0
support binary builtins for scalars
havogt Nov 16, 2023
f1047dc
support domain
havogt Nov 16, 2023
9ac0ddd
add __ne__, __eq__
havogt Nov 16, 2023
f8682ed
fix typo
havogt Nov 16, 2023
42805f7
this is the typo, the other was improve alloc
havogt Nov 16, 2023
ec0a0d5
cleanup import in fbuiltin
havogt Nov 16, 2023
ac28ea0
fix test case
havogt Nov 16, 2023
89e05ea
fix/ignore typing
havogt Nov 16, 2023
cdbaf0b
remove __array__ from field
havogt Nov 16, 2023
43cb189
some more...
havogt Nov 16, 2023
520da0d
fix some iterator tests
havogt Nov 16, 2023
1e7b0ed
added asnumpy where needed
nfarabullini Nov 17, 2023
8234894
Merge remote-tracking branch 'upstream/main' into remove_array2
havogt Nov 20, 2023
5fbf826
remove test exclusion
havogt Nov 20, 2023
0742650
undo unintendend change
havogt Nov 20, 2023
3b9e35d
undo unrelated change
havogt Nov 20, 2023
8a6e6c2
undo change
havogt Nov 20, 2023
5f52a6c
fix column stencil test
havogt Nov 20, 2023
d2e840b
fix ndarray test
havogt Nov 20, 2023
c394ed0
fix 2 more
havogt Nov 20, 2023
894e41d
astype tuple
havogt Nov 20, 2023
89c429d
2 more
havogt Nov 20, 2023
27999c1
fix fvm nabla
havogt Nov 20, 2023
fb76d11
refactor asnumpy
havogt Nov 20, 2023
753417e
address review comments
havogt Nov 21, 2023
117e2c9
Update src/gt4py/next/utils.py
havogt Nov 21, 2023
9671c54
fix formatting
havogt Nov 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,10 @@ def ndarray(self) -> core_defs.NDArrayObject:
def __str__(self) -> str:
return f"⟨{self.domain!s} → {self.dtype}⟩"

@abc.abstractmethod
def asnumpy(self) -> np.ndarray:
...

@abc.abstractmethod
def remap(self, index_field: Field) -> Field:
...
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,11 @@ def __gt_origin__(self) -> tuple[int, ...]:
def ndarray(self) -> core_defs.NDArrayObject:
return self._ndarray

def __array__(self, dtype: npt.DTypeLike = None) -> np.ndarray:
def asnumpy(self) -> np.ndarray:
if self.array_ns == cp:
return np.asarray(cp.asnumpy(self._ndarray), dtype)
return cp.asnumpy(self._ndarray)
else:
return np.asarray(self._ndarray, dtype)
return np.asarray(self._ndarray)

@property
def dtype(self) -> core_defs.DType[core_defs.ScalarT]:
Expand Down
6 changes: 6 additions & 0 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,9 @@ def dtype(self) -> core_defs.Int32DType:
def ndarray(self) -> core_defs.NDArrayObject:
raise AttributeError("Cannot get `ndarray` of an infinite Field.")

def asnumpy(self) -> np.ndarray:
raise NotImplementedError()

def remap(self, index_field: common.Field) -> common.Field:
# TODO can be implemented by constructing and ndarray (but do we know of which kind?)
raise NotImplementedError()
Expand Down Expand Up @@ -1180,6 +1183,9 @@ def dtype(self) -> core_defs.DType[core_defs.ScalarT]:
def ndarray(self) -> core_defs.NDArrayObject:
raise AttributeError("Cannot get `ndarray` of an infinite Field.")

def asnumpy(self) -> np.ndarray:
raise NotImplementedError()

def remap(self, index_field: common.Field) -> common.Field:
# TODO can be implemented by constructing and ndarray (but do we know of which kind?)
raise NotImplementedError()
Expand Down
40 changes: 39 additions & 1 deletion src/gt4py/next/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from typing import Any, ClassVar, TypeGuard, TypeVar
import functools
from typing import Any, Callable, ClassVar, ParamSpec, TypeGuard, TypeVar, cast

import numpy as np

from gt4py.next import common


class RecursionGuard:
Expand Down Expand Up @@ -53,6 +58,39 @@ def __exit__(self, *exc):

_T = TypeVar("_T")

_P = ParamSpec("_P")
_R = TypeVar("_R")


def is_tuple_of(v: Any, t: type[_T]) -> TypeGuard[tuple[_T, ...]]:
return isinstance(v, tuple) and all(isinstance(e, t) for e in v)


def tree_map(fun: Callable[_P, _R]) -> Callable[..., _R | tuple[_R | tuple, ...]]:
"""Apply `fun` to each entry of (possibly nested) tuples.

Examples:
>>> tree_map(lambda x: x + 1)(((1, 2), 3))
((2, 3), 4)

>>> tree_map(lambda x, y: x + y)(((1, 2), 3), ((4, 5), 6))
((5, 7), 9)
"""

@functools.wraps(fun)
def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]:
if isinstance(args[0], tuple):
assert all(isinstance(arg, tuple) and len(args[0]) == len(arg) for arg in args)
return tuple(impl(*arg) for arg in zip(*args))

return fun(
*cast(_P.args, args)
) # mypy doesn't understand that `args` at this point is of type `_P.args`

return impl


# TODO(havogt): consider moving to module like `field_utils`
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am introducing this in #1365

@tree_map
def asnumpy(field: common.Field | np.ndarray) -> np.ndarray:
return field.asnumpy() if common.is_field(field) else field # type: ignore[return-value] # mypy doesn't understand the condition
13 changes: 6 additions & 7 deletions tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from gt4py._core import definitions as core_defs
from gt4py.eve import extended_typing as xtyping
from gt4py.eve.extended_typing import Self
from gt4py.next import common, constructors
from gt4py.next import common, constructors, utils
from gt4py.next.ffront import decorator
from gt4py.next.program_processors import processor_interface as ppi
from gt4py.next.type_system import type_specifications as ts, type_translation
Expand Down Expand Up @@ -435,14 +435,13 @@ def verify(
run(case, fieldview_prog, *args, offset_provider=offset_provider)

out_comp = out or inout
out_comp_str = str(out_comp)
assert out_comp is not None
if hasattr(out_comp, "ndarray"):
out_comp_str = str(out_comp.ndarray)
assert comparison(ref, out_comp), (
out_comp_ndarray = utils.asnumpy(out_comp)
ref_ndarray = utils.asnumpy(ref)
assert comparison(ref_ndarray, out_comp_ndarray), (
f"Verification failed:\n"
f"\tcomparison={comparison.__name__}(ref, out)\n"
f"\tref = {ref}\n\tout = {out_comp_str}"
f"\tref = {ref_ndarray}\n\tout = {str(out_comp_ndarray)}"
)


Expand All @@ -468,7 +467,7 @@ def verify_with_default_data(
``comparison(ref, <out | inout>)`` and should return a boolean.
"""
inps, kwfields = get_default_data(case, fieldop)
ref_args = tuple(i.__array__() if common.is_field(i) else i for i in inps)
ref_args = tuple(i.asnumpy() if common.is_field(i) else i for i in inps)
verify(
case,
fieldop,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ def testee(a: IField, b: IField, c: IField) -> IField:
*pos_args, **kw_args, out=out, offset_provider=cartesian_case.offset_provider
)

expected = np.asarray(args["a"]) * 2 * np.asarray(args["b"]) - np.asarray(args["c"])
expected = args["a"] * 2 * args["b"] - args["c"]

assert np.allclose(out, expected)
assert np.allclose(out.asnumpy(), expected.asnumpy())


@pytest.mark.parametrize("arg_spec", _generate_arg_permutations(("a", "b", "out")))
Expand All @@ -89,9 +89,9 @@ def testee(a: IField, b: IField, out: IField):
*pos_args, **kw_args, offset_provider=cartesian_case.offset_provider
)

expected = np.asarray(args["a"]) + 2 * np.asarray(args["b"])
expected = args["a"] + 2 * args["b"]

assert np.allclose(args["out"], expected)
assert np.allclose(args["out"].asnumpy(), expected.asnumpy())


def test_call_field_operator_from_field_operator(cartesian_case):
Expand Down Expand Up @@ -177,9 +177,7 @@ def testee(a: IJKFloatField, b: IJKFloatField) -> IJKFloatField:
a, b, out = (
cases.allocate(cartesian_case, testee, name)() for name in ("a", "b", cases.RETURN)
)
expected = (1.0 + 3.0 + 5.0 + 7.0) * np.add.accumulate(
np.asarray(a) + 2.0 * np.asarray(b), axis=2
)
expected = (1.0 + 3.0 + 5.0 + 7.0) * np.add.accumulate(a.asnumpy() + 2.0 * b.asnumpy(), axis=2)

cases.verify(cartesian_case, testee, a, b, out=out, ref=expected)

Expand Down Expand Up @@ -210,7 +208,7 @@ def testee(
for name in ("out1", "out2", "out3", "out4")
)

ref = np.add.accumulate(np.asarray(a) + 2 * np.asarray(b), axis=2)
ref = np.add.accumulate(a.asnumpy() + 2 * b.asnumpy(), axis=2)

cases.verify(
cartesian_case,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,8 @@ def cast_nested_tuple(

a = cases.allocate(cartesian_case, cast_tuple, "a")()
b = cases.allocate(cartesian_case, cast_tuple, "b")()
a_asint = cartesian_case.as_field([IDim], np.asarray(a).astype(int32))
b_asint = cartesian_case.as_field([IDim], np.asarray(b).astype(int32))
a_asint = cartesian_case.as_field([IDim], a.asnumpy().astype(int32))
b_asint = cartesian_case.as_field([IDim], b.asnumpy().astype(int32))
out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)()
out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)()

Expand All @@ -384,7 +384,10 @@ def cast_nested_tuple(
a_asint,
b_asint,
out=out_tuple,
ref=(np.full_like(a, True, dtype=bool), np.full_like(b, True, dtype=bool)),
ref=(
np.full_like(a.asnumpy(), True, dtype=bool),
np.full_like(b.asnumpy(), True, dtype=bool),
),
)

cases.verify(
Expand All @@ -396,9 +399,9 @@ def cast_nested_tuple(
b_asint,
out=out_nested_tuple,
ref=(
np.full_like(a, True, dtype=bool),
np.full_like(a, True, dtype=bool),
np.full_like(b, True, dtype=bool),
np.full_like(a.asnumpy(), True, dtype=bool),
np.full_like(a.asnumpy(), True, dtype=bool),
np.full_like(b.asnumpy(), True, dtype=bool),
),
)

Expand Down Expand Up @@ -473,7 +476,7 @@ def testee(a: cases.IKField, offset_field: cases.IKField) -> gtx.Field[[IDim, KD
comparison=lambda out, ref: np.all(out == ref),
)

assert np.allclose(out, ref)
assert np.allclose(out.asnumpy(), ref)


def test_nested_tuple_return(cartesian_case):
Expand Down Expand Up @@ -846,8 +849,8 @@ def program_domain(a: cases.IField, out: cases.IField):
a = cases.allocate(cartesian_case, program_domain, "a")()
out = cases.allocate(cartesian_case, program_domain, "out")()

ref = np.asarray(out).copy() # ensure we are not overwriting `out` outside of the domain
ref[1:9] = a[1:9] * 2
ref = out.asnumpy().copy() # ensure we are not overwriting out outside of the domain
ref[1:9] = a.asnumpy()[1:9] * 2

cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref)

Expand Down Expand Up @@ -881,8 +884,8 @@ def program_domain(
inp = cases.allocate(cartesian_case, program_domain, "inp")()
out = cases.allocate(cartesian_case, fieldop_domain, cases.RETURN)()

ref = np.asarray(out).copy()
ref[lower_i : int(upper_i / 2)] = inp[lower_i : int(upper_i / 2)] * 2
ref = out.asnumpy().copy()
ref[lower_i : int(upper_i / 2)] = inp.asnumpy()[lower_i : int(upper_i / 2)] * 2

cases.verify(
cartesian_case,
Expand Down Expand Up @@ -924,9 +927,9 @@ def program_domain(
a = cases.allocate(cartesian_case, program_domain, "a")()
out = cases.allocate(cartesian_case, program_domain, "out")()

ref = np.asarray(out).copy()
ref = out.asnumpy().copy()
ref[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] = (
a[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2
a.asnumpy()[1 * lower_i : upper_i + 0, lower_j - 0 : upper_j] * 2
)

cases.verify(
Expand Down Expand Up @@ -964,10 +967,10 @@ def program_domain_tuple(
out0 = cases.allocate(cartesian_case, program_domain_tuple, "out0")()
out1 = cases.allocate(cartesian_case, program_domain_tuple, "out1")()

ref0 = np.asarray(out0).copy()
ref0[1:9, 4:6] = inp0[1:9, 4:6] + inp1[1:9, 4:6]
ref1 = np.asarray(out1).copy()
ref1[1:9, 4:6] = inp1[1:9, 4:6]
ref0 = out0.asnumpy().copy()
ref0[1:9, 4:6] = inp0.asnumpy()[1:9, 4:6] + inp1.asnumpy()[1:9, 4:6]
ref1 = out1.asnumpy().copy()
ref1[1:9, 4:6] = inp1.asnumpy()[1:9, 4:6]

cases.verify(
cartesian_case,
Expand Down Expand Up @@ -995,7 +998,7 @@ def fieldop_where_k_offset(
)()
out = cases.allocate(cartesian_case, fieldop_where_k_offset, "inp")()

ref = np.where(np.asarray(k_index) > 0, np.roll(inp, 1, axis=1), 2)
ref = np.where(k_index.asnumpy() > 0, np.roll(inp.asnumpy(), 1, axis=1), 2)

cases.verify(cartesian_case, fieldop_where_k_offset, inp, k_index, out=out, ref=ref)

Expand Down Expand Up @@ -1119,13 +1122,6 @@ def _invalid_unpack() -> tuple[int32, float64, int32]:


def test_constant_closure_vars(cartesian_case):
if cartesian_case.backend is None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is why we are doing the PR

# >>> field = gtx.zeros(domain)
# >>> np.int32(1)*field # steals the buffer from the field
# array([0.])

# TODO(havogt): remove `__array__`` from `NdArrayField`
pytest.xfail("Bug: Binary operation between np datatype and Field returns ndarray.")
from gt4py.eve.utils import FrozenNamespace

constants = FrozenNamespace(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ def conditional_nested_tuple(
b,
out=cases.allocate(cartesian_case, conditional_nested_tuple, cases.RETURN)(),
ref=np.where(
mask,
((a, b), (b, a)),
mask.asnumpy(),
((a.asnumpy(), b.asnumpy()), (b.asnumpy(), a.asnumpy())),
((np.full(size, 5.0), np.full(size, 7.0)), (np.full(size, 7.0), np.full(size, 5.0))),
),
)
Expand Down Expand Up @@ -219,7 +219,15 @@ def conditional(
b = cases.allocate(cartesian_case, conditional, "b")()
out = cases.allocate(cartesian_case, conditional, cases.RETURN)()

cases.verify(cartesian_case, conditional, mask, a, b, out=out, ref=np.where(mask, a, b))
cases.verify(
cartesian_case,
conditional,
mask,
a,
b,
out=out,
ref=np.where(mask.asnumpy(), a.asnumpy(), b.asnumpy()),
)


def test_conditional_promotion(cartesian_case):
Expand All @@ -231,10 +239,9 @@ def conditional_promotion(mask: cases.IBoolField, a: cases.IFloatField) -> cases
mask = cartesian_case.as_field([IDim], np.random.choice(a=[False, True], size=(size)))
a = cases.allocate(cartesian_case, conditional_promotion, "a")()
out = cases.allocate(cartesian_case, conditional_promotion, cases.RETURN)()
ref = np.where(mask.asnumpy(), a.asnumpy(), 10.0)

cases.verify(
cartesian_case, conditional_promotion, mask, a, out=out, ref=np.where(mask, a, 10.0)
)
cases.verify(cartesian_case, conditional_promotion, mask, a, out=out, ref=ref)


def test_conditional_compareop(cartesian_case):
Expand Down Expand Up @@ -279,7 +286,7 @@ def conditional_program(
b,
out,
inout=out,
ref=np.where(mask, a, b)[1:],
ref=np.where(mask.asnumpy(), a.asnumpy(), b.asnumpy())[1:],
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,4 @@ def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inp

builtin_field_op(*inps, out=out, offset_provider={})

assert np.allclose(np.asarray(out), expected)
assert np.allclose(out.asnumpy(), expected)
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def prog(

cases.run(cartesian_case, prog, a, b, out_a, out_b, offset_provider={})

assert np.allclose((a, b), (out_a, out_b))
assert np.allclose((a.asnumpy(), b.asnumpy()), (out_a.asnumpy(), out_b.asnumpy()))


def test_tuple_program_return_constructed_inside_with_slicing(cartesian_case):
Expand All @@ -178,7 +178,9 @@ def prog(

cases.run(cartesian_case, prog, a, b, out_a, out_b, offset_provider={})

assert np.allclose((a[1:], b[1:]), (out_a[1:], out_b[1:]))
assert np.allclose(
(a[1:].asnumpy(), b[1:].asnumpy()), (out_a[1:].asnumpy(), out_b[1:].asnumpy())
)
assert out_a[0] == 0 and out_b[0] == 0


Expand Down Expand Up @@ -209,7 +211,9 @@ def prog(

cases.run(cartesian_case, prog, a, b, c, out_a, out_b, out_c, offset_provider={})

assert np.allclose((a, b, c), (out_a, out_b, out_c))
assert np.allclose(
(a.asnumpy(), b.asnumpy(), c.asnumpy()), (out_a.asnumpy(), out_b.asnumpy(), out_c.asnumpy())
)


def test_wrong_argument_type(cartesian_case, copy_program_def):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,10 @@ def if_without_else(
out = cases.allocate(cartesian_case, if_without_else, cases.RETURN)()

ref = {
(True, True): np.asarray(a) + 2,
(True, False): np.asarray(a),
(False, True): np.asarray(b) + 1,
(False, False): np.asarray(b) + 1,
(True, True): a.asnumpy() + 2,
(True, False): a.asnumpy(),
(False, True): b.asnumpy() + 1,
(False, False): b.asnumpy() + 1,
}

cases.verify(
Expand Down
Loading