Skip to content

Commit

Permalink
More fixes after the merge
Browse files Browse the repository at this point in the history
  • Loading branch information
egparedes committed Nov 16, 2023
1 parent bf378d0 commit f8b7bc7
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 21 deletions.
29 changes: 24 additions & 5 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@

DimT = TypeVar("DimT", bound="Dimension")
DimsT = TypeVar("DimsT", bound=Sequence["Dimension"], covariant=True)
ValueT = TypeVar("ValueT", bound=Union[core_defs.Scalar, "Dimension"])


class Infinity(int):
Expand All @@ -63,6 +62,9 @@ def negative(cls) -> Infinity:
return cls(-sys.maxsize)


Tag: TypeAlias = str


@enum.unique
class DimensionKind(StrEnum):
HORIZONTAL = "horizontal"
Expand Down Expand Up @@ -442,18 +444,19 @@ def __gt_domain__(self) -> Domain:
...


ValueType: TypeAlias = core_defs.ScalarT | Dimension


@extended_runtime_checkable
class Field(
NextGTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, core_defs.ScalarT, ValueT]
):
class Field(NextGTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, core_defs.ScalarT]):
__gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher]

@property
def domain(self) -> Domain:
...

@property
def value_type(self) -> type[ValueT]:
def value_type(self) -> ValueType:
...

@property
Expand Down Expand Up @@ -610,6 +613,9 @@ def __abs__(self) -> Never:
def __neg__(self) -> Never:
raise TypeError("ConnectivityField does not support this operation")

def __invert__(self) -> Field:
raise TypeError("ConnectivityField does not support this operation")

def __add__(self, other: Field | DimT) -> Never:
raise TypeError("ConnectivityField does not support this operation")

Expand Down Expand Up @@ -643,6 +649,15 @@ def __rfloordiv__(self, other: Field | DimT) -> Never:
def __pow__(self, other: Field | DimT) -> Never:
raise TypeError("ConnectivityField does not support this operation")

def __and__(self, other: Field | core_defs.ScalarT) -> Field:
raise TypeError("ConnectivityField does not support this operation")

def __or__(self, other: Field | core_defs.ScalarT) -> Field:
raise TypeError("ConnectivityField does not support this operation")

def __xor__(self, other: Field | core_defs.ScalarT) -> Field:
raise TypeError("ConnectivityField does not support this operation")


@functools.singledispatch
def field(
Expand Down Expand Up @@ -690,6 +705,10 @@ class NeighborTable(Connectivity, Protocol):
table: npt.NDArray


OffsetProviderElem: TypeAlias = Dimension | Connectivity
OffsetProvider: TypeAlias = Mapping[Tag, OffsetProviderElem]


@dataclasses.dataclass(frozen=True, eq=True)
class CartesianConnectivity(ConnectivityField[DimsT, DimT]):
offset: int = 0
Expand Down
10 changes: 10 additions & 0 deletions src/gt4py/next/embedded/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,13 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from . import common, context, exceptions, nd_array_field


__all__ = [
"common",
"context",
"exceptions",
"nd_array_field",
]
2 changes: 2 additions & 0 deletions src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import annotations

from typing import Any, Optional, Sequence, cast

from gt4py.next import common
Expand Down
50 changes: 50 additions & 0 deletions src/gt4py/next/embedded/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

from __future__ import annotations

import contextlib
import contextvars as cvars
import types

import gt4py.eve as eve
import gt4py.next.common as common


#: Column range used in column mode (`column_axis != None`) in the current embedded iterator
#: closure execution context.
closure_column_range: cvars.ContextVar[range] = cvars.ContextVar("column_range")

#: Offset provider dict in the current embedded execution context.
offset_provider: cvars.ContextVar[common.OffsetProvider] = cvars.ContextVar(
"offset_provider", default=types.MappingProxyType({})
)


@contextlib.contextmanager
def new_context(
*,
closure_column_range: range = eve.NOTHING,
offset_provider: common.OffsetProvider = eve.NOTHING,
):
import gt4py.next.embedded.context as this_module

# Create new context with provided values
ctx = cvars.copy_context()
if closure_column_range is not eve.NOTHING:
this_module.closure_column_range.set(closure_column_range)
if offset_provider is not eve.NOTHING:
this_module.offset_provider.set(offset_provider)

yield ctx
16 changes: 10 additions & 6 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ def ndarray(self) -> core_defs.NDArrayObject:
def __array__(self, dtype: npt.DTypeLike = None) -> np.ndarray:
return np.asarray(self._ndarray, dtype)

@property
def value_type(self) -> common.ValueType:
return None

@property
def dtype(self) -> core_defs.DType[core_defs.ScalarT]:
return core_defs.dtype(self._ndarray.dtype.type)
Expand All @@ -135,16 +139,16 @@ def from_array(
/,
*,
domain: common.DomainLike,
dtype_like: Optional[core_defs.DType] = None, # TODO define DTypeLike
dtype: Optional[core_defs.DTypeLike] = None,
) -> NdArrayField:
domain = common.domain(domain)
xp = cls.array_ns

xp_dtype = None if dtype_like is None else xp.dtype(core_defs.dtype(dtype_like).scalar_type)
xp_dtype = None if dtype is None else xp.dtype(core_defs.dtype(dtype).scalar_type)
array = xp.asarray(data, dtype=xp_dtype)

if dtype_like is not None:
assert array.dtype.type == core_defs.dtype(dtype_like).scalar_type
if dtype is not None:
assert array.dtype.type == core_defs.dtype(dtype).scalar_type

assert issubclass(array.dtype.type, core_defs.SCALAR_TYPES)

Expand All @@ -164,7 +168,7 @@ def _dim_index(self, dim: common.Dimension):
return i
return None

def _compute_idx_array(self, r: common.UnitRange, connectivity) -> definitions.NDArrayObject:
def _compute_idx_array(self, r: common.UnitRange, connectivity) -> core_defs.NDArrayObject:
if hasattr(connectivity, "ndarray") and connectivity.ndarray is not None:
return NotImplemented # TODO
else:
Expand Down Expand Up @@ -208,7 +212,7 @@ def remap(self: NdArrayField, connectivity) -> NdArrayField:

new_buffer = xp.take(self._ndarray, new_idx_array, axis=dim_idx)
# print(new_buffer)
return self.__class__.from_array(new_buffer, domain=new_domain, value_type=self.value_type)
return self.__class__.from_array(new_buffer, domain=new_domain, dtype=self.dtype)

# dim_idx = self.domain.tag_index(restricted_connectivity.value_type.tag)
# new_domain = self._domain.replace_at(dim_idx, restricted_connectivity.domain)
Expand Down
7 changes: 3 additions & 4 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from gt4py._core import definitions as core_defs
from gt4py.eve import utils as eve_utils
from gt4py.eve.extended_typing import Any, Optional
from gt4py.next import allocators as next_allocators, common
from gt4py.next import allocators as next_allocators, common, embedded as next_embedded
from gt4py.next.common import Dimension, DimensionKind, GridType
from gt4py.next.ffront import (
dialect_ast_enums,
Expand Down Expand Up @@ -285,9 +285,8 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No
if (
self.backend is None # and DEFAULT_BACKEND is None
): # TODO(havogt): for now enable embedded execution by setting DEFAULT_BACKEND to None
common.offset_provider = offset_provider
self.definition(*args, **kwargs)
common.offset_provider = None
with next_embedded.context.new_context(offset_provider=offset_provider) as ctx:
ctx.run(self.definition, *args, **kwargs)
return

rewritten_args, size_args, kwargs = self._process_args(args, kwargs)
Expand Down
12 changes: 6 additions & 6 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

from gt4py._core import definitions as core_defs
from gt4py.eve import extended_typing as xtyping
from gt4py.next import common
from gt4py.next import common, embedded as next_embedded
from gt4py.next.embedded import exceptions as embedded_exceptions
from gt4py.next.iterator import builtins, runtime

Expand All @@ -60,7 +60,7 @@


# Atoms
Tag: TypeAlias = str
Tag: TypeAlias = common.Tag

ArrayIndex: TypeAlias = slice | common.IntIndex
ArrayIndexOrIndices: TypeAlias = ArrayIndex | tuple[ArrayIndex, ...]
Expand Down Expand Up @@ -129,8 +129,8 @@ def mapped_index(
# Offsets
OffsetPart: TypeAlias = Tag | common.IntIndex
CompleteOffset: TypeAlias = tuple[Tag, common.IntIndex]
OffsetProviderElem: TypeAlias = common.Dimension | common.Connectivity
OffsetProvider: TypeAlias = dict[Tag, OffsetProviderElem]
OffsetProviderElem: TypeAlias = common.OffsetProviderElem
OffsetProvider: TypeAlias = common.OffsetProvider

# Positions
SparsePositionEntry = list[int]
Expand Down Expand Up @@ -195,9 +195,9 @@ def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None:


#: Column range used in column mode (`column_axis != None`) in the current closure execution context.
column_range_cvar: cvars.ContextVar[range] = cvars.ContextVar("column_range")
column_range_cvar: cvars.ContextVar[range] = next_embedded.context.closure_column_range
#: Offset provider dict in the current closure execution context.
offset_provider_cvar: cvars.ContextVar[OffsetProvider] = cvars.ContextVar("offset_provider")
offset_provider_cvar: cvars.ContextVar[OffsetProvider] = next_embedded.context.offset_provider


class Column(np.lib.mixins.NDArrayOperatorsMixin):
Expand Down

0 comments on commit f8b7bc7

Please sign in to comment.