Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into dace-gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Nov 16, 2023
2 parents 0ba4fd5 + da1da20 commit aa705ef
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 14 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,6 @@ markers = [
'requires_dace: tests that require `dace` package',
'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)',
'uses_applied_shifts: tests that require backend support for applied-shifts',
'uses_can_deref: tests that require backend support for can_deref',
'uses_constant_fields: tests that require backend support for constant fields',
'uses_dynamic_offsets: tests that require backend support for dynamic offsets',
'uses_if_stmts: tests that require backend support for if-statements',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,29 @@ def convert_arg(arg: Any):
return arg


def preprocess_program(program: itir.FencilDefinition, offset_provider: Mapping[str, Any]):
program = apply_common_transforms(
def preprocess_program(
program: itir.FencilDefinition, offset_provider: Mapping[str, Any], lift_mode: LiftMode
):
node = apply_common_transforms(
program,
offset_provider=offset_provider,
lift_mode=LiftMode.FORCE_INLINE,
common_subexpression_elimination=False,
lift_mode=lift_mode,
offset_provider=offset_provider,
unroll_reduce=False,
)
return program
# If we don't unroll, there may be lifts left in the itir which can't be lowered to SDFG.
# In this case, just retry with unrolled reductions.
if all([ItirToSDFG._check_no_lifts(closure) for closure in node.closures]):
fencil_definition = node
else:
fencil_definition = apply_common_transforms(
program,
common_subexpression_elimination=False,
lift_mode=lift_mode,
offset_provider=offset_provider,
unroll_reduce=True,
)
return fencil_definition


def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]:
Expand Down Expand Up @@ -155,11 +170,14 @@ def get_cache_id(
def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
# build parameters
auto_optimize = kwargs.get("auto_optimize", False)
build_cache = kwargs.get("build_cache", None)
build_type = kwargs.get("build_type", "RelWithDebInfo")
run_on_gpu = kwargs.get("run_on_gpu", False)
build_cache = kwargs.get("build_cache", None)
# ITIR parameters
column_axis = kwargs.get("column_axis", None)
lift_mode = (
LiftMode.FORCE_INLINE
) # TODO(edopao): make it configurable once temporaries are supported in DaCe backend
offset_provider = kwargs["offset_provider"]

arg_types = [type_translation.from_value(arg) for arg in args]
Expand All @@ -172,7 +190,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
sdfg = sdfg_program.sdfg
else:
# visit ITIR and generate SDFG
program = preprocess_program(program, offset_provider)
program = preprocess_program(program, offset_provider, lift_mode)
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, run_on_gpu)
sdfg = sdfg_genenerator.visit(program)
sdfg.simplify()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,29 @@ def builtin_neighbors(
return [ValueExpr(result_access, iterator.dtype)]


def builtin_can_deref(
    transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr]
) -> list[ValueExpr]:
    """Lower the ``can_deref`` builtin to a tasklet that yields a boolean.

    The single argument is expected to be an applied ``shift`` call. The shift
    is visited to obtain the shifted iterator, and a tasklet is emitted that
    evaluates to true only when every index of that iterator is non-negative.

    Args:
        transformer: Code generator used to visit the shift and to add the
            resulting expression tasklet.
        node: The ``can_deref`` call node (not used by this lowering).
        node_args: Arguments of the ``can_deref`` call; only the first one,
            the applied shift, is read.

    Returns:
        The value expression(s) returned by ``transformer.add_expr_tasklet``
        for the boolean result.
    """
    # first visit shift, to get set of indices for deref
    can_deref_callable = node_args[0]
    assert isinstance(can_deref_callable, itir.FunCall)
    shift_callable = can_deref_callable.fun
    assert isinstance(shift_callable, itir.FunCall)
    assert isinstance(shift_callable.fun, itir.SymRef)
    assert shift_callable.fun.id == "shift"
    iterator = transformer._visit_shift(can_deref_callable)

    # create tasklet to check that field indices are non-negative (-1 is invalid)
    args = [ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.dimensions]
    internals = [f"{arg.value.data}_v" for arg in args]
    # The per-index checks must be combined with Python's `and`: the C-style `&&`
    # is not valid Python and only went unnoticed while a single index was checked.
    # NOTE(review): assumes add_expr_tasklet emits a Python tasklet, as the
    # PythonTaskletCodegen name suggests -- confirm.
    expr_code = " and ".join(f"{v} >= 0" for v in internals)

    # TODO(edopao): a select-memlet might allow can_deref to be translated efficiently
    #  into predicated execution
    return transformer.add_expr_tasklet(
        list(zip(args, internals)), expr_code, dace.dtypes.bool, "can_deref"
    )


def builtin_if(
transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr]
) -> list[ValueExpr]:
Expand Down Expand Up @@ -318,11 +341,12 @@ def builtin_undefined(*args: Any) -> Any:
_GENERAL_BUILTIN_MAPPING: dict[
str, Callable[["PythonTaskletCodegen", itir.Expr, list[itir.Expr]], list[ValueExpr]]
] = {
"make_tuple": builtin_make_tuple,
"tuple_get": builtin_tuple_get,
"if_": builtin_if,
"can_deref": builtin_can_deref,
"cast_": builtin_cast,
"if_": builtin_if,
"make_tuple": builtin_make_tuple,
"neighbors": builtin_neighbors,
"tuple_get": builtin_tuple_get,
}


Expand Down
2 changes: 0 additions & 2 deletions tests/next_tests/exclusion_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
# Test markers
REQUIRES_ATLAS = "requires_atlas"
USES_APPLIED_SHIFTS = "uses_applied_shifts"
USES_CAN_DEREF = "uses_can_deref"
USES_CONSTANT_FIELDS = "uses_constant_fields"
USES_DYNAMIC_OFFSETS = "uses_dynamic_offsets"
USES_IF_STMTS = "uses_if_stmts"
Expand Down Expand Up @@ -116,7 +115,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE),
]
DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [
(USES_CAN_DEREF, XFAIL, UNSUPPORTED_MESSAGE),
(USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ def foo(a):


@pytest.mark.parametrize("stencil", [_can_deref, _can_deref_lifted])
@pytest.mark.uses_can_deref
def test_can_deref(program_processor, stencil):
program_processor, validate = program_processor

Expand Down

0 comments on commit aa705ef

Please sign in to comment.