Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor[next]: itir embedded: cleaner closure run #1521

Merged
merged 7 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions src/gt4py/next/embedded/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,8 @@
#: closure execution context.
closure_column_range: cvars.ContextVar[common.NamedRange] = cvars.ContextVar("column_range")

_undefined_offset_provider: common.OffsetProvider = {}

#: Offset provider dict in the current embedded execution context.
offset_provider: cvars.ContextVar[common.OffsetProvider] = cvars.ContextVar(
"offset_provider", default=_undefined_offset_provider
)
offset_provider: cvars.ContextVar[common.OffsetProvider] = cvars.ContextVar("offset_provider")


@contextlib.contextmanager
Expand All @@ -41,6 +37,8 @@ def new_context(
closure_column_range: common.NamedRange | eve.NothingType = eve.NOTHING,
offset_provider: common.OffsetProvider | eve.NothingType = eve.NOTHING,
) -> Generator[cvars.Context, None, None]:
"""Create a new context, updating the provided values."""

import gt4py.next.embedded.context as this_module

updates: list[tuple[cvars.ContextVar[Any], Any]] = []
Expand All @@ -62,4 +60,4 @@ def ctx_updater(*args: tuple[cvars.ContextVar[Any], Any]) -> None:


def within_context() -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rename it to something like within_valid_context() and add a docstring to this function also.

return offset_provider.get() is not _undefined_offset_provider
return offset_provider.get(eve.NOTHING) is not eve.NOTHING
124 changes: 63 additions & 61 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from __future__ import annotations

import abc
import contextvars as cvars
import copy
import dataclasses
import itertools
Expand All @@ -28,6 +27,7 @@
import numpy as np
import numpy.typing as npt

from gt4py import eve
from gt4py._core import definitions as core_defs
from gt4py.eve import extended_typing as xtyping
from gt4py.eve.extended_typing import (
Expand All @@ -52,8 +52,8 @@
overload,
runtime_checkable,
)
from gt4py.next import common, embedded as next_embedded
from gt4py.next.embedded import exceptions as embedded_exceptions
from gt4py.next import common
from gt4py.next.embedded import context as embedded_context, exceptions as embedded_exceptions
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import builtins, runtime

Expand Down Expand Up @@ -191,12 +191,6 @@ class MutableLocatedField(LocatedField, Protocol):
def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: ...


#: Column range used in column mode (`column_axis != None`) in the current closure execution context.
column_range_cvar: cvars.ContextVar[common.NamedRange] = next_embedded.context.closure_column_range
#: Offset provider dict in the current closure execution context.
offset_provider_cvar: cvars.ContextVar[OffsetProvider] = next_embedded.context.offset_provider


class Column(np.lib.mixins.NDArrayOperatorsMixin):
"""Represents a column when executed in column mode (`column_axis != None`).

Expand All @@ -207,7 +201,7 @@ class Column(np.lib.mixins.NDArrayOperatorsMixin):
def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None:
self.kstart = kstart
assert isinstance(data, (np.ndarray, Scalar)) # type: ignore # mypy bug #11673
column_range: common.NamedRange = column_range_cvar.get()
column_range: common.NamedRange = embedded_context.closure_column_range.get()
self.data = (
data if isinstance(data, np.ndarray) else np.full(len(column_range.unit_range), data)
)
Expand Down Expand Up @@ -751,7 +745,7 @@ def _make_tuple(
except embedded_exceptions.IndexOutOfBounds:
return _UNDEFINED
else:
column_range = column_range_cvar.get().unit_range
column_range = embedded_context.closure_column_range.get().unit_range
assert column_range is not None

col: list[
Expand Down Expand Up @@ -796,7 +790,7 @@ class MDIterator:

def shift(self, *offsets: OffsetPart) -> MDIterator:
complete_offsets = group_offsets(*offsets)
offset_provider = offset_provider_cvar.get()
offset_provider = embedded_context.offset_provider.get()
assert offset_provider is not None
return MDIterator(
self.field,
Expand All @@ -821,8 +815,8 @@ def deref(self) -> Any:
if not all(axis.value in shifted_pos.keys() for axis in axes if axis is not None):
raise IndexError("Iterator position doesn't point to valid location for its field.")
slice_column = dict[Tag, range]()
column_range = column_range_cvar.get()
if self.column_axis is not None:
column_range = embedded_context.closure_column_range.get()
assert column_range is not None
k_pos = shifted_pos.pop(self.column_axis)
assert isinstance(k_pos, int)
Expand Down Expand Up @@ -862,7 +856,7 @@ def make_in_iterator(
init = [None] * sparse_dimensions.count(sparse_dim)
new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused
if column_axis is not None:
column_range = column_range_cvar.get().unit_range
column_range = embedded_context.closure_column_range.get().unit_range
# if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted
assert column_range is not None
new_pos[column_axis] = column_range.start
Expand Down Expand Up @@ -1303,7 +1297,7 @@ def __getitem__(self, _):
def neighbors(offset: runtime.Offset, it: ItIterator) -> _List:
offset_str = offset.value if isinstance(offset, runtime.Offset) else offset
assert isinstance(offset_str, str)
offset_provider = offset_provider_cvar.get()
offset_provider = embedded_context.offset_provider.get()
assert offset_provider is not None
connectivity = offset_provider[offset_str]
assert isinstance(connectivity, common.Connectivity)
Expand Down Expand Up @@ -1359,7 +1353,7 @@ class SparseListIterator:
offsets: Sequence[OffsetPart] = dataclasses.field(default_factory=list, kw_only=True)

def deref(self) -> Any:
offset_provider = offset_provider_cvar.get()
offset_provider = embedded_context.offset_provider.get()
assert offset_provider is not None
connectivity = offset_provider[self.list_offset]
assert isinstance(connectivity, common.Connectivity)
Expand All @@ -1376,12 +1370,6 @@ def shift(self, *offsets: OffsetPart) -> SparseListIterator:
return SparseListIterator(self.it, self.list_offset, offsets=[*offsets, *self.offsets])


@dataclasses.dataclass(frozen=True)
class ColumnDescriptor:
axis: str
col_range: range # TODO(havogt) introduce range type that doesn't have step


@dataclasses.dataclass(frozen=True)
class ScanArgIterator:
wrapped_iter: ItIterator
Expand Down Expand Up @@ -1480,7 +1468,7 @@ def _column_dtype(elem: Any) -> np.dtype:
@builtins.scan.register(EMBEDDED)
def scan(scan_pass, is_forward: bool, init):
def impl(*iters: ItIterator):
column_range = column_range_cvar.get().unit_range
column_range = embedded_context.closure_column_range.get().unit_range
if column_range is None:
raise RuntimeError("Column range is not defined, cannot scan.")

Expand Down Expand Up @@ -1508,64 +1496,78 @@ def _validate_domain(domain: Domain, offset_provider: OffsetProvider) -> None:
)


def fendef_embedded(fun: Callable[..., None], *args: Any, **kwargs: Any):
if "offset_provider" not in kwargs:
raise RuntimeError("'offset_provider' not provided.")

offset_provider = kwargs["offset_provider"]

@runtime.closure.register(EMBEDDED)
def closure(
domain_: Domain,
sten: Callable[..., Any],
out, #: MutableLocatedField,
ins: list[common.Field],
) -> None:
_validate_domain(domain_, kwargs["offset_provider"])
domain: dict[Tag, range] = _dimension_to_tag(domain_)
if not (isinstance(out, common.Field) or is_tuple_of_field(out)):
raise TypeError("'Out' needs to be a located field.")

column_range = None
column: Optional[ColumnDescriptor] = None
if kwargs.get("column_axis") and kwargs["column_axis"].value in domain:
column_axis = kwargs["column_axis"]
column = ColumnDescriptor(column_axis.value, domain[column_axis.value])
del domain[column_axis.value]

@runtime.closure.register(EMBEDDED)
def closure(
domain_: Domain,
sten: Callable[..., Any],
out, #: MutableLocatedField,
egparedes marked this conversation as resolved.
Show resolved Hide resolved
ins: list[common.Field],
) -> None:
assert embedded_context.within_context()
offset_provider = embedded_context.offset_provider.get()
_validate_domain(domain_, offset_provider)
domain: dict[Tag, range] = _dimension_to_tag(domain_)
if not (isinstance(out, common.Field) or is_tuple_of_field(out)):
raise TypeError("'Out' needs to be a located field.")

column_range: common.NamedRange | eve.NothingType = eve.NOTHING
if (col_range_placeholder := embedded_context.closure_column_range.get(None)) is not None:
assert (
col_range_placeholder.unit_range.is_empty()
) # check it's just the placeholder with empty range
column_axis = col_range_placeholder.dim
if column_axis is not None and column_axis.value in domain:
column_range = common.NamedRange(
column_axis, common.UnitRange(column.col_range.start, column.col_range.stop)
column_axis,
common.UnitRange(domain[column_axis.value].start, domain[column_axis.value].stop),
)
del domain[column_axis.value]

out = as_tuple_field(out) if is_tuple_of_field(out) else _wrap_field(out)
out = as_tuple_field(out) if is_tuple_of_field(out) else _wrap_field(out)

def _closure_runner():
# Set context variables before executing the closure
column_range_cvar.set(column_range)
offset_provider_cvar.set(offset_provider)
with embedded_context.new_context(closure_column_range=column_range) as ctx:

def _iterate():
for pos in _domain_iterator(domain):
promoted_ins = [promote_scalars(inp) for inp in ins]
ins_iters = list(
make_in_iterator(inp, pos, column_axis=column.axis if column else None)
make_in_iterator(
inp,
pos,
column_axis=column_range.dim.value
if column_range is not eve.NOTHING
else None,
)
for inp in promoted_ins
)
res = sten(*ins_iters)

if column is None:
if column_range is eve.NOTHING:
assert _is_concrete_position(pos)
out.field_setitem(pos, res)
else:
col_pos = pos.copy()
for k in column.col_range:
col_pos[column.axis] = k
for k in column_range.unit_range:
col_pos[column_range.dim.value] = k
assert _is_concrete_position(col_pos)
out.field_setitem(col_pos, res[k])

ctx = cvars.copy_context()
ctx.run(_closure_runner)
ctx.run(_iterate)


def fendef_embedded(fun: Callable[..., None], *args: Any, **kwargs: Any):
if "offset_provider" not in kwargs:
raise RuntimeError("'offset_provider' not provided.")

context_vars = {"offset_provider": kwargs["offset_provider"]}
if "column_axis" in kwargs:
context_vars["closure_column_range"] = common.NamedRange(
kwargs["column_axis"],
common.UnitRange(0, 0), # empty: indicates column operation, will update later
)

fun(*args)
with embedded_context.new_context(**context_vars) as ctx:
ctx.run(fun, *args)


runtime.fendef_embedded = fendef_embedded
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import contextvars as cvars
import threading
from typing import Any, Callable, Optional

import numpy as np
import pytest

from gt4py.next import common
from gt4py.next.embedded import context as embedded_context
from gt4py.next.iterator import embedded


Expand All @@ -30,8 +30,8 @@ def _run_within_context(
offset_provider: Optional[embedded.OffsetProvider] = None,
) -> Any:
def wrapped_func():
embedded.column_range_cvar.set(column_range)
embedded.offset_provider_cvar.set(offset_provider)
embedded_context.closure_column_range.set(column_range)
embedded_context.offset_provider.set(offset_provider)
func()

cvars.copy_context().run(wrapped_func)
Expand Down Expand Up @@ -59,7 +59,7 @@ def test_func(data_a: int, data_b: int):
assert res.kstart == 1

# Setting an invalid column_range here shouldn't affect other contexts
embedded.column_range_cvar.set(range(2, 999))
embedded_context.closure_column_range.set(range(2, 999))
_run_within_context(
lambda: test_func(2, 3),
column_range=common.NamedRange(
Expand Down
Loading