feat[next][dace]: lowering of scan to SDFG #1776

Open
wants to merge 97 commits into
base: main
Changes from all commits
97 commits
df1847a
scan - working draft
edopao Nov 29, 2024
89ca8f7
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 3, 2024
f22eb64
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 4, 2024
c26d906
Improve utility functions for tuples
edopao Dec 4, 2024
ba0a9ba
Fix for empty field domain
edopao Dec 4, 2024
ac7acf8
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 4, 2024
877d81e
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 4, 2024
8baf6d1
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 4, 2024
784b573
Add exclusive if_ in dataflow
edopao Dec 5, 2024
de9c9de
Better handling of isolated nodes
edopao Dec 5, 2024
14e66e8
Fix field offset in nested SDFG context
edopao Dec 6, 2024
fcfaf72
fix problem with dereferencing of 1D vertical fields inside scan
edopao Dec 6, 2024
79204ee
generalize previous fix to all scan input fields
edopao Dec 6, 2024
5fe461a
minor edit
edopao Dec 6, 2024
a4bde3a
fix out-of-bound access
edopao Dec 6, 2024
c75a8e4
Better handling of isolated nodes
edopao Dec 6, 2024
6f72cac
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 6, 2024
397acae
exclude scan tests on dace backend with optimizations
edopao Dec 6, 2024
acf5ac0
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 6, 2024
a706b27
fix pre-commit
edopao Dec 6, 2024
c22cfc8
fix doctest
edopao Dec 6, 2024
59e0ed5
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 6, 2024
792a8eb
temporarily disable one optimize transformation
edopao Dec 9, 2024
61985f7
Revert "temporarily disable one optimize transformation"
edopao Dec 9, 2024
aa236a2
fix for scan output stride
edopao Dec 10, 2024
9bdc75b
fix previous commit
edopao Dec 10, 2024
746f9d8
convert scalar to array on nsdfg output
edopao Dec 10, 2024
0d894ff
Revert "convert scalar to array on nsdfg output"
edopao Dec 11, 2024
440a474
Split handling of let-statement lambdas from stencil body
edopao Dec 11, 2024
500590b
minor edit
edopao Dec 11, 2024
c56e062
Merge remote-tracking branch 'origin/dace-refact-lambda' into dace-gt…
edopao Dec 12, 2024
5d5992a
use dace auto-optimize on gpu
edopao Dec 12, 2024
c167def
Merge remote-tracking branch 'origin/dace-gtir-scan' into dace-gtir-scan
edopao Dec 12, 2024
eb17345
Revert "use dace auto-optimize on gpu"
edopao Dec 12, 2024
8b163da
make map_strides recursive
edopao Dec 12, 2024
d15213a
rename module alias
edopao Dec 13, 2024
55811dc
review comments
edopao Dec 13, 2024
8f0e515
Merge remote-tracking branch 'origin/dace-refact-lambda' into dace-gt…
edopao Dec 13, 2024
f01d291
add test case for sdfg transformation
edopao Dec 13, 2024
62e1648
review comments (1)
edopao Dec 16, 2024
72e8830
review comments (2)
edopao Dec 16, 2024
39aeb20
Merge branch 'dace-refact-lambda' into dace-gtir-scan
edopao Dec 16, 2024
de4a80e
review comments (2)
edopao Dec 16, 2024
45f9927
Merge remote-tracking branch 'origin/main' into dace-refact-lambda
edopao Dec 16, 2024
3fe538b
Merge remote-tracking branch 'origin/dace-refact-lambda' into dace-gt…
edopao Dec 16, 2024
ee62266
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 16, 2024
4b0ac60
Propagate strides to nested SDFG when changing transient strides
edopao Dec 16, 2024
f701605
rename function
edopao Dec 16, 2024
a19019f
fix bug
edopao Dec 16, 2024
c03492c
fix previous commit
edopao Dec 16, 2024
310fcce
Test commit
edopao Dec 16, 2024
4b487ea
propagate strides also to destination nested SDFG
edopao Dec 16, 2024
4cf66e7
fix previous commit (skip scalar inner nodes)
edopao Dec 16, 2024
ab7ee5f
fix - do not call free_symbols on int stride
edopao Dec 17, 2024
82cf491
run simplify before gpu transformations
edopao Dec 17, 2024
a0dbea5
undo renaming graph -> state
edopao Dec 17, 2024
9128ffb
increase slurm timeout to 20 minutes
edopao Dec 17, 2024
f940c4e
increase slurm timeout to 30 minutes
edopao Dec 17, 2024
cc0777b
minor edit
edopao Dec 17, 2024
462f3c5
exclude test_ternary_scan from gpu tests
edopao Dec 17, 2024
d9218b6
These are the changes Edoardo implemented to fix some issues in the op…
edopao Dec 17, 2024
9d7e722
First rework.
philip-paul-mueller Dec 18, 2024
1ddd6fe
Updated some comments.
philip-paul-mueller Dec 18, 2024
95e0007
I want to ignore registers, not only consider them.
philip-paul-mueller Dec 18, 2024
f1b7a3f
There was a missing `not` in the check.
philip-paul-mueller Dec 18, 2024
50ad620
Had to update the propagation, to also handle aliasing.
philip-paul-mueller Dec 18, 2024
983022c
In the function for looking for top level accesses the `only_transien…
philip-paul-mueller Dec 18, 2024
e7b1afb
Small reminder of the future.
philip-paul-mueller Dec 18, 2024
df7bd0c
Forgot to export the new SDFG stuff.
philip-paul-mueller Dec 18, 2024
363ab59
Had to update the function for the actual renaming of the strides.
philip-paul-mueller Dec 18, 2024
9c19d32
Added a todo to the replacement function.
philip-paul-mueller Dec 18, 2024
9cad1f7
Added a first test to the propagation function.
philip-paul-mueller Dec 18, 2024
2700f53
Modified the function that performs the actual modification of the s…
philip-paul-mueller Dec 19, 2024
a20d3c0
Updated some tests, but more are missing.
philip-paul-mueller Dec 19, 2024
b5ff462
Subset caching strikes again.
philip-paul-mueller Dec 19, 2024
d326d3b
It seems that the explicit handling of one dimension is not working.
philip-paul-mueller Dec 19, 2024
252f348
The test must be moved below.
philip-paul-mueller Dec 19, 2024
49f8172
The symbol also needs to be present in the nested SDFG.
philip-paul-mueller Dec 19, 2024
2d6dfc0
Fixed a bug in determining the free symbols that we need.
philip-paul-mueller Dec 19, 2024
6124c6d
Updated the propagation code for the symbols.
philip-paul-mueller Dec 19, 2024
45bcf97
Addressed Edoardo's changes.
philip-paul-mueller Dec 19, 2024
23b0baa
Updated how we get the type of symbols.
philip-paul-mueller Dec 19, 2024
ff05880
New restriction on the update of the symbol mapping.
philip-paul-mueller Dec 19, 2024
43ec33c
Updated the tests, now also made one that has tests for the symbol ma…
philip-paul-mueller Dec 19, 2024
d43153a
Fixed two bugs in the stride propagation function.
philip-paul-mueller Dec 19, 2024
2e82bd5
Added a test that ensures that the dependent adding works.
philip-paul-mueller Dec 19, 2024
07e6a5c
Changed the default of `ignore_symbol_mapping` to `True`.
philip-paul-mueller Dec 19, 2024
4bf145b
Added Edoardo's comments.
philip-paul-mueller Dec 19, 2024
2b03bb4
Removed the creation of aliasing if symbol tables are ignored.
philip-paul-mueller Dec 20, 2024
40c225d
Added a test that shows that `ignore_symbol_mapping=False` does produ…
philip-paul-mueller Dec 20, 2024
419a386
Updated the description.
philip-paul-mueller Dec 20, 2024
cc9801b
Applied Edoardo's comment.
philip-paul-mueller Dec 20, 2024
360baae
Added a todo from Edoardo's suggestions.
philip-paul-mueller Dec 20, 2024
f2396c4
Merge remote-tracking branch 'philip/dace-gtir-better-strides' into d…
edopao Dec 20, 2024
a0c37cb
minor edit
edopao Dec 20, 2024
45c69ec
Merge branch 'main' into dace-gtir-scan
edopao Dec 20, 2024
0f9043b
fix for missing symbols in nested sdfg
edopao Dec 20, 2024
1 change: 1 addition & 0 deletions pyproject.toml
@@ -254,6 +254,7 @@ markers = [
'uses_scan_without_field_args: tests that require calls to scan that do not have any fields as arguments',
'uses_scan_nested: tests that use nested scans',
'uses_scan_requiring_projector: tests need a projector implementation in gtfn',
'uses_scan_1d_field: tests scan on a 1D vertical field',
'uses_sparse_fields: tests that require backend support for sparse fields',
'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields',
'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset',

Large diffs are not rendered by default.

Large diffs are not rendered by default.

101 changes: 64 additions & 37 deletions src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py
@@ -22,6 +22,7 @@
from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union

import dace
from dace.sdfg import utils as dace_sdfg_utils

from gt4py import eve
from gt4py.eve import concepts
@@ -111,6 +112,21 @@ def get_symbol_type(self, symbol_name: str) -> ts.DataType:
"""Retrieve the GT4Py type of a symbol used in the SDFG."""
...

@abc.abstractmethod
def is_column_dimension(self, dim: gtx_common.Dimension) -> bool:
"""Check if the given dimension is the column dimension."""
...
Comment on lines +116 to +118

Whether a dimension is a column should not depend on the SDFG builder.
Why is it an abstract method? I would have guessed it is more a static one, or that it lives in some type traits package.


@abc.abstractmethod
def nested_context(
self,
sdfg: dace.SDFG,
global_symbols: dict[str, ts.DataType],
field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]],
) -> SDFGBuilder:
"""Create a new empty context, useful to build a nested SDFG."""

I would expand on this a little bit more, as handling nested SDFGs is probably the most important, if not the only, application of this function.

...

@abc.abstractmethod
def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any:
"""Visit a node of the GT4Py IR."""
Expand Down Expand Up @@ -149,15 +165,6 @@ def _collect_symbols_in_domain_expressions(
)


def _get_tuple_type(data: tuple[gtir_builtin_translators.FieldopResult, ...]) -> ts.TupleType:
"""
Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes.
"""
return ts.TupleType(
types=[_get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data]
)


@dataclasses.dataclass(frozen=True)
class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder):
"""Provides translation capability from a GTIR program to a DaCe SDFG.
Expand All @@ -173,6 +180,7 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder):
"""

offset_provider_type: gtx_common.OffsetProviderType
column_dim: Optional[gtx_common.Dimension]
global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=lambda: {})
field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = dataclasses.field(
default_factory=lambda: {}
Expand All @@ -199,6 +207,25 @@ def make_field(
def get_symbol_type(self, symbol_name: str) -> ts.DataType:
return self.global_symbols[symbol_name]

def is_column_dimension(self, dim: gtx_common.Dimension) -> bool:
assert self.column_dim
return dim == self.column_dim

def nested_context(
self,
sdfg: dace.SDFG,
global_symbols: dict[str, ts.DataType],
field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]],
) -> SDFGBuilder:
Comment on lines +214 to +219

I would also modify all places where you create nested SDFGs so that they use this function.
(If you do and I just did not see the changes, ignore this comment.)

nsdfg_builder = GTIRToSDFG(
self.offset_provider_type, self.column_dim, global_symbols, field_offsets
)
nsdfg_params = [
gtir.Sym(id=p_name, type=p_type) for p_name, p_type in global_symbols.items()
]
nsdfg_builder._add_sdfg_params(sdfg, node_params=nsdfg_params, symbolic_arguments=None)
return nsdfg_builder

def unique_nsdfg_name(self, sdfg: dace.SDFG, prefix: str) -> str:
nsdfg_list = [
nsdfg.label for nsdfg in sdfg.all_sdfgs_recursive() if nsdfg.label.startswith(prefix)
@@ -277,10 +304,11 @@ def _add_storage(
"""
if isinstance(gt_type, ts.TupleType):
tuple_fields = []
for tname, ttype in dace_gtir_utils.get_tuple_fields(name, gt_type, flatten=True):
for sym in dace_gtir_utils.flatten_tuple_fields(name, gt_type):
assert isinstance(sym.type, ts.DataType)
tuple_fields.extend(
self._add_storage(
sdfg, symbolic_arguments, tname, ttype, transient, tuple_name=name
sdfg, symbolic_arguments, sym.id, sym.type, transient, tuple_name=name
)
)
return tuple_fields
@@ -379,7 +407,7 @@ def _add_sdfg_params(
self,
sdfg: dace.SDFG,
node_params: Sequence[gtir.Sym],
symbolic_arguments: set[str],
symbolic_arguments: Optional[set[str]],

What applications does accepting None have?

) -> list[str]:
"""
Helper function to add storage for node parameters and connectivity tables.
@@ -389,6 +417,9 @@
except when they are listed in 'symbolic_arguments', in which case they
will be represented in the SDFG as DaCe symbols.
"""
if symbolic_arguments is None:
symbolic_arguments = set()

# add non-transient arrays and/or SDFG symbols for the program arguments
sdfg_args = []
for param in node_params:
@@ -436,7 +467,7 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG:
assert len(self.field_offsets) == 0

sdfg = dace.SDFG(node.id)
sdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo)
sdfg.debuginfo = dace_utils.debug_info(node)

# DaCe requires C-compatible strings for the names of data containers,
such as arrays and scalars. GT4Py uses a unicode symbol ('ᐞ') as name
@@ -619,32 +650,23 @@ def visit_Lambda(
(str(param.id), arg) for param, arg in zip(node.params, args, strict=True)
]

def flatten_tuples(
name: str,
arg: gtir_builtin_translators.FieldopResult,
) -> list[tuple[str, gtir_builtin_translators.FieldopData]]:
if isinstance(arg, tuple):
tuple_type = _get_tuple_type(arg)
tuple_field_names = [
arg_name for arg_name, _ in dace_gtir_utils.get_tuple_fields(name, tuple_type)
]
tuple_args = zip(tuple_field_names, arg, strict=True)
return list(
itertools.chain(*[flatten_tuples(fname, farg) for fname, farg in tuple_args])
)
else:
return [(name, arg)]

lambda_arg_nodes = dict(
itertools.chain(*[flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping])
itertools.chain(
*[
gtir_builtin_translators.flatten_tuples(pname, arg)
for pname, arg in lambda_args_mapping
]
)
)

# inherit symbols from parent scope but possibly override with local symbols
lambda_symbols = {
sym: self.global_symbols[sym]
for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys())
} | {
pname: _get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type
pname: gtir_builtin_translators.get_tuple_type(arg)
if isinstance(arg, tuple)
else arg.gt_type
for pname, arg in lambda_args_mapping
}

@@ -659,12 +681,12 @@ def get_field_domain_offset(
elif field_domain_offset := self.field_offsets.get(p_name, None):
return {p_name: field_domain_offset}
elif isinstance(p_type, ts.TupleType):
p_fields = dace_gtir_utils.get_tuple_fields(p_name, p_type, flatten=True)
tsyms = dace_gtir_utils.flatten_tuple_fields(p_name, p_type)
return functools.reduce(
lambda field_offsets, field: (
field_offsets | get_field_domain_offset(field[0], field[1])
lambda field_offsets, sym: (
field_offsets | get_field_domain_offset(sym.id, sym.type) # type: ignore[arg-type]
),
p_fields,
tsyms,
{},
)
return {}
@@ -676,7 +698,7 @@

# lower let-statement lambda node as a nested SDFG
lambda_translator = GTIRToSDFG(
self.offset_provider_type, lambda_symbols, lambda_field_offsets
self.offset_provider_type, self.column_dim, lambda_symbols, lambda_field_offsets

Here you should use your nested context I guess.

)
nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda"))
nsdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo)
@@ -853,6 +875,7 @@ def visit_SymRef(
def build_sdfg_from_gtir(
ir: gtir.Program,
offset_provider_type: gtx_common.OffsetProviderType,
column_dim: Optional[gtx_common.Dimension] = None,
) -> dace.SDFG:
"""
Receives a GTIR program and lowers it to a DaCe SDFG.
@@ -863,15 +886,19 @@
Args:
ir: The GTIR program node to be lowered to SDFG
offset_provider_type: The definitions of offset providers used by the program node
column_dim: Vertical dimension used for scan expressions.

I would call that thing scan_dim, because column_dim is just too broad and you are only using it in the scan context.
This makes me wonder: why does the scan node specify its own scan dimension?


Returns:
An SDFG in the DaCe canonical form (simplified)
"""

ir = gtir_type_inference.infer(ir, offset_provider_type=offset_provider_type)
ir = ir_prune_casts.PruneCasts().visit(ir)
sdfg_genenerator = GTIRToSDFG(offset_provider_type)
sdfg_genenerator = GTIRToSDFG(offset_provider_type, column_dim)
sdfg = sdfg_genenerator.visit(ir)
assert isinstance(sdfg, dace.SDFG)

# TODO(edopao): remove inlining when DaCe transformations support LoopRegion construct
dace_sdfg_utils.inline_loop_blocks(sdfg)

return sdfg
56 changes: 34 additions & 22 deletions src/gt4py/next/program_processors/runners/dace_fieldview/utility.py
@@ -8,14 +8,14 @@

from __future__ import annotations

import itertools
from typing import Dict, TypeVar

import dace

from gt4py import eve
from gt4py.next import common as gtx_common
from gt4py.next import common as gtx_common, utils as gtx_utils
from gt4py.next.iterator import ir as gtir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.type_system import type_specifications as ts


@@ -27,35 +27,47 @@ def get_map_variable(dim: gtx_common.Dimension) -> str:
return f"i_{dim.value}_gtx_{dim.kind}{suffix}"


def get_tuple_fields(
tuple_name: str, tuple_type: ts.TupleType, flatten: bool = False
) -> list[tuple[str, ts.DataType]]:
def make_symbol_tuple(tuple_name: str, tuple_type: ts.TupleType) -> tuple[gtir.Sym, ...]:
"""
Creates a list of names with the corresponding data type for all elements of the given tuple.
Creates a tuple representation of the symbols corresponding to the tuple fields.
The constructed tuple preserves the nested nature of the type, if any.

Examples
--------
>>> sty = ts.ScalarType(kind=ts.ScalarKind.INT32)
>>> fty = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))
>>> t = ts.TupleType(types=[sty, ts.TupleType(types=[fty, sty])])
>>> assert get_tuple_fields("a", t) == [("a_0", sty), ("a_1", ts.TupleType(types=[fty, sty]))]
>>> assert get_tuple_fields("a", t, flatten=True) == [
... ("a_0", sty),
... ("a_1_0", fty),
... ("a_1_1", sty),
... ]
>>> assert make_symbol_tuple("a", t) == (
... im.sym("a_0", sty),
... (im.sym("a_1_0", fty), im.sym("a_1_1", sty)),
... )
"""
fields = [(f"{tuple_name}_{i}", field_type) for i, field_type in enumerate(tuple_type.types)]
if flatten:
expanded_fields = [
get_tuple_fields(field_name, field_type)
if isinstance(field_type, ts.TupleType)
else [(field_name, field_type)]
for field_name, field_type in fields
]
return list(itertools.chain(*expanded_fields))
else:
return fields
return tuple(
make_symbol_tuple(field_name, field_type) # type: ignore[misc]
if isinstance(field_type, ts.TupleType)
else im.sym(field_name, field_type)
for field_name, field_type in fields
)


def flatten_tuple_fields(tuple_name: str, tuple_type: ts.TupleType) -> list[gtir.Sym]:
"""
Creates a list of symbols, annotated with the data type, for all elements of the given tuple.

Examples
--------
>>> sty = ts.ScalarType(kind=ts.ScalarKind.INT32)
>>> fty = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))
>>> t = ts.TupleType(types=[sty, ts.TupleType(types=[fty, sty])])
>>> assert flatten_tuple_fields("a", t) == [
... im.sym("a_0", sty),
... im.sym("a_1_0", fty),
... im.sym("a_1_1", sty),
... ]
"""
symbol_tuple = make_symbol_tuple(tuple_name, tuple_type)
return list(gtx_utils.flatten_nested_tuple(symbol_tuple))
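
The naming scheme these helpers implement can be sketched in isolation. The following is a hypothetical, simplified re-implementation that tracks only the generated names (not gt4py `Sym` nodes or type objects); types are modeled purely by shape, with any non-tuple value standing in for a scalar or field type:

```python
def make_name_tuple(tuple_name, field_types):
    """Mirror the nested structure of `field_types`, naming element i
    as '<tuple_name>_<i>'; nested tuples recurse with the new prefix."""
    return tuple(
        make_name_tuple(f"{tuple_name}_{i}", t)
        if isinstance(t, tuple)
        else f"{tuple_name}_{i}"
        for i, t in enumerate(field_types)
    )


def flatten_name_tuple(names):
    """Depth-first flattening of the nested name tuple into a flat list."""
    flat = []
    for item in names:
        flat.extend(flatten_name_tuple(item) if isinstance(item, tuple) else [item])
    return flat


# Mirrors the doctest above: t = (int32, (float32, int32))
t = ("int32", ("float32", "int32"))
assert make_name_tuple("a", t) == ("a_0", ("a_1_0", "a_1_1"))
assert flatten_name_tuple(make_name_tuple("a", t)) == ["a_0", "a_1_0", "a_1_1"]
```

The real `flatten_tuple_fields` follows the same two-step design: build the structure-preserving tuple first, then delegate flattening to the generic `gtx_utils.flatten_nested_tuple`.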


def replace_invalid_symbols(sdfg: dace.SDFG, ir: gtir.Program) -> gtir.Program:
@@ -47,13 +47,13 @@ def generate_sdfg(
self,
ir: itir.Program,
offset_provider: common.OffsetProvider,
column_axis: Optional[common.Dimension],
column_dim: Optional[common.Dimension],
auto_opt: bool,
on_gpu: bool,
) -> dace.SDFG:
ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider)
sdfg = gtir_sdfg.build_sdfg_from_gtir(
ir, offset_provider_type=common.offset_provider_to_type(offset_provider)
ir, common.offset_provider_to_type(offset_provider), column_dim
)

if auto_opt:
14 changes: 11 additions & 3 deletions tests/next_tests/definitions.py
@@ -100,6 +100,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
USES_SCAN_WITHOUT_FIELD_ARGS = "uses_scan_without_field_args"
USES_SCAN_NESTED = "uses_scan_nested"
USES_SCAN_REQUIRING_PROJECTOR = "uses_scan_requiring_projector"
USES_SCAN_1D_FIELD = "uses_scan_1d_field"
USES_SPARSE_FIELDS = "uses_sparse_fields"
USES_SPARSE_FIELDS_AS_OUTPUT = "uses_sparse_fields_as_output"
USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields"
@@ -134,7 +135,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
]
DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [
(USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE),
(USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE),
(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE),
]
EMBEDDED_SKIP_LIST = [
@@ -169,9 +169,17 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST,
EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST,
OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST,
OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST,
OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST
+ [
# dace issue https://github.com/spcl/dace/issues/1773
(USES_SCAN_1D_FIELD, XFAIL, UNSUPPORTED_MESSAGE),
],
OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST,
OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST,
OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST
+ [
# dace issue https://github.com/spcl/dace/issues/1773
(USES_SCAN_1D_FIELD, XFAIL, UNSUPPORTED_MESSAGE),
],
ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST
+ [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)],
ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST
@@ -818,6 +818,7 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField:


@pytest.mark.uses_scan
@pytest.mark.uses_scan_1d_field
def test_ternary_scan(cartesian_case):
@gtx.scan_operator(axis=KDim, forward=True, init=0.0)
def simple_scan_operator(carry: float, a: float) -> float: