Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into enable_embedded_in_ff…
Browse files Browse the repository at this point in the history
…ront_tests
  • Loading branch information
havogt committed Nov 17, 2023
2 parents 7bc0689 + 39d1c09 commit d5b15c4
Show file tree
Hide file tree
Showing 11 changed files with 175 additions and 30 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,6 @@ markers = [
'requires_dace: tests that require `dace` package',
'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)',
'uses_applied_shifts: tests that require backend support for applied-shifts',
'uses_can_deref: tests that require backend support for can_deref',
'uses_constant_fields: tests that require backend support for constant fields',
'uses_dynamic_offsets: tests that require backend support for dynamic offsets',
'uses_if_stmts: tests that require backend support for if-statements',
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def _np_cp_setitem(
_nd_array_implementations = [np]


@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, eq=False)
class NumPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = np

Expand All @@ -299,7 +299,7 @@ class NumPyArrayField(NdArrayField):
if cp:
_nd_array_implementations.append(cp)

@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, eq=False)
class CuPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = cp

Expand All @@ -311,7 +311,7 @@ class CuPyArrayField(NdArrayField):
if jnp:
_nd_array_implementations.append(jnp)

@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, eq=False)
class JaxArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = jnp

Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ def __call__(
)
else:
# "out" -> field_operator called from program in embedded execution
# TODO(egparedes) put offset_provider in ctxt var here when implementing remap
# TODO(egparedes): put offset_provider in ctxt var here when implementing remap
domain = kwargs.pop("domain", None)
res = self.definition(*args, **kwargs)
_tuple_assign_field(
Expand Down
13 changes: 10 additions & 3 deletions src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,16 @@ def where(


@BuiltInFunction
def astype(field: common.Field | core_defs.ScalarT, type_: type, /) -> common.Field:
assert core_defs.is_scalar_type(field)
return type_(field)
def astype(
value: Field | core_defs.ScalarT | Tuple,
type_: type,
/,
) -> Field | core_defs.ScalarT | Tuple:
if isinstance(value, tuple):
return tuple(astype(v, type_) for v in value)
# default implementation for scalars, Fields are handled via dispatch
assert core_defs.is_scalar_type(value)
return core_defs.dtype(type_).scalar_type(value)


UNARY_MATH_NUMBER_BUILTIN_NAMES = ["abs"]
Expand Down
11 changes: 8 additions & 3 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,10 +823,12 @@ def _visit_min_over(self, node: foast.Call, **kwargs) -> foast.Call:
return self._visit_reduction(node, **kwargs)

def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call:
return_type: ts.TupleType | ts.ScalarType | ts.FieldType
value, new_type = node.args
assert isinstance(
value.type, (ts.FieldType, ts.ScalarType)
value.type, (ts.FieldType, ts.ScalarType, ts.TupleType)
) # already checked using generic mechanism

if not isinstance(new_type, foast.Name) or new_type.id.upper() not in [
kind.name for kind in ts.ScalarKind
]:
Expand All @@ -835,8 +837,11 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call:
f"Invalid call to `astype`. Second argument must be a scalar type, but got {new_type}.",
)

return_type = with_altered_scalar_kind(
value.type, getattr(ts.ScalarKind, new_type.id.upper())
return_type = type_info.apply_to_primitive_constituents(
value.type,
lambda primitive_type: with_altered_scalar_kind(
primitive_type, getattr(ts.ScalarKind, new_type.id.upper())
),
)

return foast.Call(
Expand Down
35 changes: 29 additions & 6 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,12 +317,9 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr:

def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall:
assert len(node.args) == 2 and isinstance(node.args[1], foast.Name)
obj, dtype = node.args[0], node.args[1].id

# TODO check that we test astype that results in a itir.map_ operation
return self._map(
im.lambda_("it")(im.call("cast_")("it", str(dtype))),
obj,
obj, new_type = node.args[0], node.args[1].id
return self._process_elements(
lambda x: im.call("cast_")(x, str(new_type)), obj, obj.type, **kwargs
)

def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall:
Expand Down Expand Up @@ -403,6 +400,32 @@ def _map(self, op, *args, **kwargs):

return im.promote_to_lifted_stencil(im.call(op))(*lowered_args)

def _process_elements(
    self,
    process_func: Callable[[itir.Expr], itir.Expr],
    obj: foast.Expr,
    current_el_type: ts.TypeSpec,
    current_el_expr: itir.Expr | None = None,
    **kwargs,
):
    """
    Recursively apply a processing function to all primitive constituents of a tuple.

    Args:
        process_func: Builds the processed ITIR expression from the ITIR
            expression that selects one primitive element.
        obj: FOAST expression whose (possibly nested tuple) value is processed.
        current_el_type: Type of the element currently being visited; the
            recursion descends while this is a ``ts.TupleType``.
        current_el_expr: ITIR expression selecting the current element. When
            omitted, a fresh reference to the symbol ``expr`` is used, which is
            the parameter name of the lambda built for the primitive case.
        **kwargs: Visitor keyword arguments, forwarded unchanged to the
            recursive calls and to ``self._map``.

    Returns:
        The ITIR expression produced for ``obj``: a lifted ``make_tuple`` of
        the processed elements for tuple types, otherwise the result of
        ``self._map`` applied to the processed element.

    Raises:
        NotImplementedError: If a primitive constituent contains a local field.
    """
    # Build the default per call instead of in the signature: a default value is
    # evaluated once at definition time and would be shared by every invocation.
    if current_el_expr is None:
        current_el_expr = im.ref("expr")

    if isinstance(current_el_type, ts.TupleType):
        # TODO(ninaburg): Refactor to avoid duplicating lowered obj expression for each tuple element.
        return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))(
            *[
                self._process_elements(
                    process_func,
                    obj,
                    el_type,
                    im.tuple_get(i, current_el_expr),
                    **kwargs,
                )
                for i, el_type in enumerate(current_el_type.types)
            ]
        )

    if type_info.contains_local_field(current_el_type):
        raise NotImplementedError("Processing fields with local dimension is not implemented.")

    # The lambda parameter is named "expr" so that the default element
    # expression (and every `tuple_get` chain built on top of it) resolves to it.
    return self._map(im.lambda_("expr")(process_func(current_el_expr)), obj, **kwargs)


class FieldOperatorLoweringError(Exception):
...
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,29 @@ def convert_arg(arg: Any):
return arg


def preprocess_program(program: itir.FencilDefinition, offset_provider: Mapping[str, Any]):
program = apply_common_transforms(
def preprocess_program(
program: itir.FencilDefinition, offset_provider: Mapping[str, Any], lift_mode: LiftMode
):
node = apply_common_transforms(
program,
offset_provider=offset_provider,
lift_mode=LiftMode.FORCE_INLINE,
common_subexpression_elimination=False,
lift_mode=lift_mode,
offset_provider=offset_provider,
unroll_reduce=False,
)
return program
# If we don't unroll, there may be lifts left in the itir which can't be lowered to SDFG.
# In this case, just retry with unrolled reductions.
if all([ItirToSDFG._check_no_lifts(closure) for closure in node.closures]):
fencil_definition = node
else:
fencil_definition = apply_common_transforms(
program,
common_subexpression_elimination=False,
lift_mode=lift_mode,
offset_provider=offset_provider,
unroll_reduce=True,
)
return fencil_definition


def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]:
Expand Down Expand Up @@ -156,11 +171,14 @@ def get_cache_id(
def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
# build parameters
auto_optimize = kwargs.get("auto_optimize", False)
build_cache = kwargs.get("build_cache", None)
build_type = kwargs.get("build_type", "RelWithDebInfo")
run_on_gpu = kwargs.get("run_on_gpu", False)
build_cache = kwargs.get("build_cache", None)
# ITIR parameters
column_axis = kwargs.get("column_axis", None)
lift_mode = (
LiftMode.FORCE_INLINE
) # TODO(edopao): make it configurable once temporaries are supported in DaCe backend
offset_provider = kwargs["offset_provider"]

arg_types = [type_translation.from_value(arg) for arg in args]
Expand All @@ -173,7 +191,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
sdfg = sdfg_program.sdfg
else:
# visit ITIR and generate SDFG
program = preprocess_program(program, offset_provider)
program = preprocess_program(program, offset_provider, lift_mode)
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis)
sdfg = sdfg_genenerator.visit(program)
sdfg.simplify()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,29 @@ def builtin_neighbors(
return [ValueExpr(result_access, iterator.dtype)]


def builtin_can_deref(
    transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr]
) -> list[ValueExpr]:
    """
    Lower a ``can_deref`` call to a tasklet that tests the shifted indices.

    The first argument is expected to be a call whose callee is itself a call
    to the ``shift`` builtin. The shift is visited first to obtain the iterator
    indices, then a single expression tasklet is emitted that requires every
    index to be non-negative.

    Args:
        transformer: Code generator used to visit the shift and to add the tasklet.
        node: The ``can_deref`` call node (not inspected here).
        node_args: Arguments of the call; only the first one is used.

    Returns:
        The value returned by ``transformer.add_expr_tasklet`` for the check.
    """
    shifted_it_call = node_args[0]
    assert isinstance(shifted_it_call, itir.FunCall)
    shift_call = shifted_it_call.fun
    assert isinstance(shift_call, itir.FunCall)
    assert isinstance(shift_call.fun, itir.SymRef)
    assert shift_call.fun.id == "shift"

    # Visiting the shift yields the set of indices that a deref would use.
    shifted_iterator = transformer._visit_shift(shifted_it_call)

    # One tasklet input per iterator dimension; -1 marks an invalid index, so a
    # dereference is only possible when all indices are non-negative.
    index_values = [
        ValueExpr(shifted_iterator.indices[dim], shifted_iterator.dtype)
        for dim in shifted_iterator.dimensions
    ]
    connector_names = [f"{index_value.value.data}_v" for index_value in index_values]
    validity_expr = " && ".join(f"{name} >= 0" for name in connector_names)

    # TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution
    return transformer.add_expr_tasklet(
        list(zip(index_values, connector_names)),
        validity_expr,
        dace.dtypes.bool,
        "can_deref",
    )


def builtin_if(
transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr]
) -> list[ValueExpr]:
Expand Down Expand Up @@ -318,11 +341,12 @@ def builtin_undefined(*args: Any) -> Any:
_GENERAL_BUILTIN_MAPPING: dict[
str, Callable[["PythonTaskletCodegen", itir.Expr, list[itir.Expr]], list[ValueExpr]]
] = {
"make_tuple": builtin_make_tuple,
"tuple_get": builtin_tuple_get,
"if_": builtin_if,
"can_deref": builtin_can_deref,
"cast_": builtin_cast,
"if_": builtin_if,
"make_tuple": builtin_make_tuple,
"neighbors": builtin_neighbors,
"tuple_get": builtin_tuple_get,
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,76 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]:
)


@pytest.mark.uses_tuple_returns
def test_astype_on_tuples(cartesian_case):  # noqa: F811 # fixtures
    """Check that `astype` applied to flat and nested tuples of fields casts every element.

    Each field operator casts a tuple of float fields to `int32` and compares
    the elements against integer fields computed with NumPy, so the expected
    output of both operators is a tuple of all-`True` boolean fields.
    """

    # Helper operator: produces the tuple value that is later passed to `astype`.
    @gtx.field_operator
    def field_op_returning_a_tuple(
        a: cases.IFloatField, b: cases.IFloatField
    ) -> tuple[gtx.Field[[IDim], float], gtx.Field[[IDim], float]]:
        tup = (a, b)
        return tup

    # Flat case: `astype` receives a two-element tuple of fields.
    @gtx.field_operator
    def cast_tuple(
        a: cases.IFloatField,
        b: cases.IFloatField,
        a_asint: cases.IField,
        b_asint: cases.IField,
    ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]:
        result = astype(field_op_returning_a_tuple(a, b), int32)
        return (
            result[0] == a_asint,
            result[1] == b_asint,
        )

    # Nested case: `astype` receives a tuple whose second element is itself a tuple.
    @gtx.field_operator
    def cast_nested_tuple(
        a: cases.IFloatField,
        b: cases.IFloatField,
        a_asint: cases.IField,
        b_asint: cases.IField,
    ) -> tuple[gtx.Field[[IDim], bool], gtx.Field[[IDim], bool], gtx.Field[[IDim], bool]]:
        result = astype((a, field_op_returning_a_tuple(a, b)), int32)
        return (
            result[0] == a_asint,
            result[1][0] == a_asint,
            result[1][1] == b_asint,
        )

    a = cases.allocate(cartesian_case, cast_tuple, "a")()
    b = cases.allocate(cartesian_case, cast_tuple, "b")()
    # Reference casts are computed with NumPy, independently of the operators under test.
    a_asint = gtx.as_field([IDim], np.asarray(a).astype(int32))
    b_asint = gtx.as_field([IDim], np.asarray(b).astype(int32))
    out_tuple = cases.allocate(cartesian_case, cast_tuple, cases.RETURN)()
    out_nested_tuple = cases.allocate(cartesian_case, cast_nested_tuple, cases.RETURN)()

    cases.verify(
        cartesian_case,
        cast_tuple,
        a,
        b,
        a_asint,
        b_asint,
        out=out_tuple,
        ref=(np.full_like(a, True, dtype=bool), np.full_like(b, True, dtype=bool)),
    )

    cases.verify(
        cartesian_case,
        cast_nested_tuple,
        a,
        b,
        a_asint,
        b_asint,
        out=out_nested_tuple,
        ref=(
            np.full_like(a, True, dtype=bool),
            np.full_like(a, True, dtype=bool),
            np.full_like(b, True, dtype=bool),
        ),
    )


def test_astype_bool_field(cartesian_case): # noqa: F811 # fixtures
@gtx.field_operator
def testee(a: cases.IFloatField) -> gtx.Field[[IDim], bool]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -785,8 +785,8 @@ def simple_astype(a: Field[[TDim], float64]):

def test_astype_wrong_value_type():
def simple_astype(a: Field[[TDim], float64]):
# we just use a tuple here but anything that is not a field or scalar works
return astype((1, 2), bool)
# we just use broadcast here but anything that is not a field, scalar or tuple thereof works
return astype(broadcast, bool)

with pytest.raises(errors.DSLError) as exc_info:
_ = FieldOperatorParser.apply_to_function(simple_astype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ def foo(a):


@pytest.mark.parametrize("stencil", [_can_deref, _can_deref_lifted])
@pytest.mark.uses_can_deref
def test_can_deref(program_processor, stencil):
program_processor, validate = program_processor

Expand Down

0 comments on commit d5b15c4

Please sign in to comment.