-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next] Embedded field remove __array__ #1366
Conversation
d7f31ee
to
8234894
Compare
return impl | ||
|
||
|
||
# TODO(havogt): consider moving to module like `field_utils` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am introducing this in #1365
@@ -1119,13 +1122,6 @@ def _invalid_unpack() -> tuple[int32, float64, int32]: | |||
|
|||
|
|||
def test_constant_closure_vars(cartesian_case): | |||
if cartesian_case.backend is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is why we are doing the PR
src/gt4py/next/utils.py
Outdated
|
||
def is_tuple_of(v: Any, t: type[_T]) -> TypeGuard[tuple[_T, ...]]: | ||
return isinstance(v, tuple) and all(isinstance(e, t) for e in v) | ||
|
||
|
||
def apply_to_tuple_elements(fun: Callable[_P, _R]) -> Callable[..., _R | tuple[_R | tuple, ...]]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def apply_to_tuple_elements(fun: Callable[_P, _R]) -> Callable[..., _R | tuple[_R | tuple, ...]]: | |
def deep_map(fun: Callable[_P, _R]) -> Callable[..., _R | tuple[_R | tuple, ...]]: |
maybe?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Either deep_map or tree_map (jax calls tree_map
to something similar)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll take tree_map then
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A couple of optional suggestions/comments but otherwise it looks good to me.
src/gt4py/next/common.py
Outdated
@@ -464,6 +464,9 @@ def dtype(self) -> core_defs.DType[core_defs.ScalarT]: | |||
def ndarray(self) -> core_defs.NDArrayObject: | |||
... | |||
|
|||
def asnumpy(self) -> np.ndarray: | |||
return np.asarray(self.ndarray) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be default implementation or just leave it blank?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
forgot to remove...
src/gt4py/next/utils.py
Outdated
|
||
def is_tuple_of(v: Any, t: type[_T]) -> TypeGuard[tuple[_T, ...]]: | ||
return isinstance(v, tuple) and all(isinstance(e, t) for e in v) | ||
|
||
|
||
def apply_to_tuple_elements(fun: Callable[_P, _R]) -> Callable[..., _R | tuple[_R | tuple, ...]]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Either deep_map or tree_map (jax calls tree_map
to something similar)
Changes to the quickstart guide to use `field.asnumpy()` (introduced in #1366) instead of `np.asarray(field)`. The quickstart guide is still broken though since the embedded backend (used by default) does not support skip neighbors connectivities.
Add
.asnumpy
toField
. Implicit conversion via__array__
creates a problem, because expressionnp.float*field
will return ndarray instead of field, becausenp.float
's multiply operator willasarray(rhs)
.