Skip to content

Commit

Permalink
feat[next]: extend embedded implementation of premap() (#1501)
Browse files Browse the repository at this point in the history
Extend the implementation of the `premap` field operation (previously
named `remap`, conceptually equivalent to a Contravariant Functor's
`contramap`) to support more efficient implementations of different use
cases depending on the contents of the connectivity field.

### Added
- `gt4py.eve`: new typing aliases and minor utilities

### Changed
- `gt4py.next.common`:
	- new typing aliases.
- small refactoring of `Domain` to support creation of subdomains via
slicing using the `.slice_at` attribute. The actual implementation comes
from the now deleted
`gt4py.next.embedded.nd_array_field._relative_ranges_to_domain()`
function.
	- refactor `ConnectivityKind` to represent all known use cases
- extend `CartesianConnectivity` to support translation and relocations
	- rename `remap` to `premap`

- `gt4py.next.embedded.nd_array_field`:
- full refactoring of `premap()` (old `remap`) and add usage
documentation
- some renamings (`_hypercube()` -> `_hyperslice()`,
`_compute_mask_ranges()` -> `_compute_mask_slices()`

### Removed

- `gt4py.next.embedded.nd_array_field`: `_relative_ranges_to_domain()`
function moved to an `Domain` attribute in `gt4py.next.common`

---------

Co-authored-by: Hannes Vogt <[email protected]>
  • Loading branch information
egparedes and havogt authored May 16, 2024
1 parent c89bd81 commit 723b81f
Show file tree
Hide file tree
Showing 10 changed files with 791 additions and 230 deletions.
2 changes: 1 addition & 1 deletion src/gt4py/_core/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import functools
import math
import numbers
from typing import overload

import numpy as np
import numpy.typing as npt
Expand All @@ -42,6 +41,7 @@
TypeVar,
Union,
cast,
overload,
)


Expand Down
12 changes: 10 additions & 2 deletions src/gt4py/eve/extended_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,16 @@ def __dir__() -> List[str]:
return self_func.__cached_dir


_T = TypeVar("_T")

# -- Common type aliases --
NoArgsCallable = Callable[[], Any]

_A = TypeVar("_A", contravariant=True)
_R = TypeVar("_R", covariant=True)


class ArgsOnlyCallable(Protocol[_A, _R]):
def __call__(self, *args: _A) -> _R: ...


# -- Typing annotations --
if _sys.version_info >= (3, 9):
Expand Down Expand Up @@ -367,6 +372,9 @@ def has_type_parameters(cls: Type) -> bool:
return issubclass(cls, Generic) and len(getattr(cls, "__parameters__", [])) > 0 # type: ignore[arg-type] # Generic not considered as a class


_T = TypeVar("_T")


def get_actual_type(obj: _T) -> Type[_T]:
"""Return type of an object (also working for GenericAlias instances which pretend to be an actual type)."""
return StdGenericAliasType if isinstance(obj, StdGenericAliasType) else type(obj)
Expand Down
32 changes: 32 additions & 0 deletions src/gt4py/eve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from . import extended_typing as xtyping
from .extended_typing import (
Any,
ArgsOnlyCallable,
Callable,
Collection,
Dict,
Expand Down Expand Up @@ -84,6 +85,15 @@
T = TypeVar("T")


def first(iterable: Iterable[T], *, default: Union[T, NothingType] = NOTHING) -> T:
try:
return next(iter(iterable))
except StopIteration as error:
if default is not NOTHING:
return cast(T, default)
raise error


def isinstancechecker(type_info: Union[Type, Iterable[Type]]) -> Callable[[Any], bool]:
"""Return a callable object that checks if operand is an instance of `type_info`.
Expand Down Expand Up @@ -227,9 +237,31 @@ def itemgetter_(key: Any, default: Any = NOTHING) -> Callable[[Any], Any]:


_P = ParamSpec("_P")
_S = TypeVar("_S")
_T = TypeVar("_T")


@dataclasses.dataclass(frozen=True)
class IndexerCallable(Generic[_S, _T]):
"""
An indexer class applying the wrapped function to the index arguments.
Examples:
>>> indexer = IndexerCallable(lambda x: x**2)
>>> indexer[3]
9
>>> indexer = IndexerCallable(lambda a, b: a + b)
>>> indexer[3, 4]
7
"""

func: ArgsOnlyCallable[_S, _T]

def __getitem__(self, key: _S | Tuple[_S, ...]) -> _T:
return self.func(*key) if isinstance(key, tuple) else self.func(key)


class fluid_partial(functools.partial):
"""Create a `functools.partial` with support for multiple applications calling `.partial()`."""

Expand Down
Loading

0 comments on commit 723b81f

Please sign in to comment.