Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next][dace]: iterator-view support to DaCe backend #1790

Merged
merged 143 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 142 commits
Commits
Show all changes
143 commits
Select commit Hold shift + click to select a range
df1847a
scan - working draft
edopao Nov 29, 2024
89ca8f7
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 3, 2024
f22eb64
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 4, 2024
c26d906
Improve utility functions for tuples
edopao Dec 4, 2024
ba0a9ba
Fix for empty field domain
edopao Dec 4, 2024
ac7acf8
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 4, 2024
877d81e
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 4, 2024
8baf6d1
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 4, 2024
784b573
Add exclusive if_ in dataflow
edopao Dec 5, 2024
de9c9de
Better handling of isolated nodes
edopao Dec 5, 2024
14e66e8
Fix field offset in nested SDFG context
edopao Dec 6, 2024
fcfaf72
fix problem with dereferencil of 1D vertical fields inside scan
edopao Dec 6, 2024
79204ee
generalize previous fix to all scan input fields
edopao Dec 6, 2024
5fe461a
minor edit
edopao Dec 6, 2024
a4bde3a
fix out-of-bound access
edopao Dec 6, 2024
c75a8e4
Better handling of isolated nodes
edopao Dec 6, 2024
6f72cac
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 6, 2024
397acae
exclude scan tests on dace backend with optimizations
edopao Dec 6, 2024
acf5ac0
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 6, 2024
a706b27
fix pre-commit
edopao Dec 6, 2024
c22cfc8
fix doctest
edopao Dec 6, 2024
59e0ed5
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 6, 2024
792a8eb
temporarily disable one optimize transformation
edopao Dec 9, 2024
61985f7
Revert "temporarily disable one optimize transformation"
edopao Dec 9, 2024
aa236a2
fix for scan output stride
edopao Dec 10, 2024
9bdc75b
fix previous commit
edopao Dec 10, 2024
746f9d8
converto scalar to array on nsdfg output
edopao Dec 10, 2024
0d894ff
Revert "converto scalar to array on nsdfg output"
edopao Dec 11, 2024
440a474
Split handling of let-statement lambdas from stencil body
edopao Dec 11, 2024
500590b
minor edit
edopao Dec 11, 2024
c56e062
Merge remote-tracking branch 'origin/dace-refact-lambda' into dace-gt…
edopao Dec 12, 2024
5d5992a
use dace auto-optimize on gpu
edopao Dec 12, 2024
c167def
Merge remote-tracking branch 'origin/dace-gtir-scan' into dace-gtir-scan
edopao Dec 12, 2024
eb17345
Revert "use dace auto-optimize on gpu"
edopao Dec 12, 2024
8b163da
make map_strides recursive
edopao Dec 12, 2024
d15213a
rename module alias
edopao Dec 13, 2024
55811dc
review comments
edopao Dec 13, 2024
8f0e515
Merge remote-tracking branch 'origin/dace-refact-lambda' into dace-gt…
edopao Dec 13, 2024
f01d291
add test case for sdfg transformation
edopao Dec 13, 2024
62e1648
review comments (1)
edopao Dec 16, 2024
72e8830
review comments (2)
edopao Dec 16, 2024
39aeb20
Merge branch 'dace-refact-lambda' into dace-gtir-scan
edopao Dec 16, 2024
de4a80e
review comments (2)
edopao Dec 16, 2024
45f9927
Merge remote-tracking branch 'origin/main' into dace-refact-lambda
edopao Dec 16, 2024
3fe538b
Merge remote-tracking branch 'origin/dace-refact-lambda' into dace-gt…
edopao Dec 16, 2024
ee62266
Merge remote-tracking branch 'origin/main' into dace-gtir-scan
edopao Dec 16, 2024
4b0ac60
Propagate strides to nested SDFG when changing transient strides
edopao Dec 16, 2024
f701605
rename function
edopao Dec 16, 2024
a19019f
fix bug
edopao Dec 16, 2024
c03492c
fix previous commit
edopao Dec 16, 2024
310fcce
Test commit
edopao Dec 16, 2024
4b487ea
propagate strides also to destination nested SDFG
edopao Dec 16, 2024
4cf66e7
fix previous commit (skip scalar inner nodes)
edopao Dec 16, 2024
ab7ee5f
fix - do not call free_symbols on int stride
edopao Dec 17, 2024
82cf491
run simplify before gpu transformations
edopao Dec 17, 2024
a0dbea5
undo renaming graph -> state
edopao Dec 17, 2024
9128ffb
increase slurm timeout to 20 minutes
edopao Dec 17, 2024
f940c4e
increase slurm timeout to 30 minutes
edopao Dec 17, 2024
cc0777b
minor edit
edopao Dec 17, 2024
462f3c5
exclude test_ternary_scan from gpu tests
edopao Dec 17, 2024
d9218b6
This are the changes Edoardo implemented to fix some issues in the op…
edopao Dec 17, 2024
9d7e722
First rework.
philip-paul-mueller Dec 18, 2024
1ddd6fe
Updated some commenst.
philip-paul-mueller Dec 18, 2024
95e0007
I want to ignore register, not only consider them.
philip-paul-mueller Dec 18, 2024
f1b7a3f
There was a missing `not` in the check.
philip-paul-mueller Dec 18, 2024
50ad620
Had to update the propagation, to also handle aliasing.
philip-paul-mueller Dec 18, 2024
983022c
In the function for looking for top level accesses the `only_transien…
philip-paul-mueller Dec 18, 2024
e7b1afb
Small reminder of the future.
philip-paul-mueller Dec 18, 2024
df7bd0c
Forgot to export the new SDFG stuff.
philip-paul-mueller Dec 18, 2024
363ab59
Had to update function for actuall renaming of the strides.
philip-paul-mueller Dec 18, 2024
9c19d32
Added a todo to the replacement function.
philip-paul-mueller Dec 18, 2024
9cad1f7
Added a first test to the propagation function.
philip-paul-mueller Dec 18, 2024
2700f53
Modified the function that performs the actuall modification of the s…
philip-paul-mueller Dec 19, 2024
a20d3c0
Updated some tes, but more are missing.
philip-paul-mueller Dec 19, 2024
b5ff462
Subset caching strikes again.
philip-paul-mueller Dec 19, 2024
d326d3b
It seems that the explicit handling of one dimensions is not working.
philip-paul-mueller Dec 19, 2024
252f348
The test must be moved bellow.
philip-paul-mueller Dec 19, 2024
49f8172
The symbol is also needed to be present in the nested SDFG.
philip-paul-mueller Dec 19, 2024
2d6dfc0
Fixed a bug in determining the free symbols that we need.
philip-paul-mueller Dec 19, 2024
6124c6d
Updated the propagation code for the symbols.
philip-paul-mueller Dec 19, 2024
45bcf97
Addressed Edoardo's changes.
philip-paul-mueller Dec 19, 2024
23b0baa
Updated how we get the type of symbols.
philip-paul-mueller Dec 19, 2024
ff05880
New restriction on the update of the symbol mapping.
philip-paul-mueller Dec 19, 2024
43ec33c
Updated the tests, now also made one that has tests for the symbol ma…
philip-paul-mueller Dec 19, 2024
d43153a
Fixed two bug in the stride propagation function.
philip-paul-mueller Dec 19, 2024
2e82bd5
Added a test that ensures that the dependent adding works.
philip-paul-mueller Dec 19, 2024
07e6a5c
Changed the default of `ignore_symbol_mapping` to `True`.
philip-paul-mueller Dec 19, 2024
4bf145b
Added Edoardo's comments.
philip-paul-mueller Dec 19, 2024
2b03bb4
Removed the creation of aliasing if symbol tables are ignored.
philip-paul-mueller Dec 20, 2024
40c225d
Added a test that shows that `ignore_symbol_mapping=False` does produ…
philip-paul-mueller Dec 20, 2024
419a386
Updated the description.
philip-paul-mueller Dec 20, 2024
cc9801b
Applied Edoardo's comment.
philip-paul-mueller Dec 20, 2024
360baae
Added a todo from Edoardo's suggestions.
philip-paul-mueller Dec 20, 2024
f2396c4
Merge remote-tracking branch 'philip/dace-gtir-better-strides' into d…
edopao Dec 20, 2024
a0c37cb
minor edit
edopao Dec 20, 2024
45c69ec
Merge branch 'main' into dace-gtir-scan
edopao Dec 20, 2024
0f9043b
fix for missing symbols in nested sdfg
edopao Dec 20, 2024
059a448
wip - fix iterator tests
edopao Dec 20, 2024
b8fe277
disable tests with sparse fields
edopao Jan 7, 2025
0dd4b4e
disable unsupported features
edopao Jan 8, 2025
2d3238c
fix for if_ lowering
edopao Jan 8, 2025
aec47c8
lowering of tuple_deref
edopao Jan 8, 2025
1f68857
lowering of tuple iterators
edopao Jan 8, 2025
c20728d
allow tuple fields with different size
edopao Jan 8, 2025
d9691a8
Merge remote-tracking branch 'origin/main' into dace-gtir-iterator_view
edopao Jan 8, 2025
312e69c
add scan test marker
edopao Jan 8, 2025
65b4dd2
undo lowering of scan
edopao Jan 8, 2025
61b06b3
Minor edit based on review comments
edopao Jan 8, 2025
62a2a80
ignore atlas tests
edopao Jan 8, 2025
aa9f999
undo scan-related change
edopao Jan 8, 2025
7508e03
Minor edit based on review comments
edopao Jan 8, 2025
43f2e40
fix
edopao Jan 8, 2025
de203f9
Revert "undo scan-related change"
edopao Jan 8, 2025
ab11d77
fix previous commits
edopao Jan 8, 2025
7329c4b
update test skip list
edopao Jan 9, 2025
fab8288
fix gtir dace tests (add tuple symbols)
edopao Jan 9, 2025
46322ac
undo extra change
edopao Jan 9, 2025
2c1156b
remove support for tuple iterator
edopao Jan 9, 2025
ac24404
fix test marker
edopao Jan 9, 2025
a414eba
move 2 nested function definitions to separate helper functions
edopao Jan 9, 2025
015f69c
edit test markers
edopao Jan 9, 2025
f05a730
edit test markers
edopao Jan 9, 2025
2363b62
Revert "edit test markers"
edopao Jan 9, 2025
fd1462d
Merge remote-tracking branch 'origin/main' into dace-gtir-iterator_view
edopao Jan 10, 2025
db94493
remove wrong assert
edopao Jan 10, 2025
f7b18b3
edit code comments
edopao Jan 10, 2025
87b5bd5
add tuple_get
edopao Jan 10, 2025
d93a387
better symbol mapping for lambda nested SDFG
edopao Jan 13, 2025
678b782
Revert "add tuple_get"
edopao Jan 13, 2025
eaaee4e
Merge remote-tracking branch 'origin/main' into dace-gtir-iterator_view
edopao Jan 13, 2025
d4599d2
address review comments
edopao Jan 13, 2025
4a82810
fix
edopao Jan 13, 2025
81b5fd3
fix subset num_elements
edopao Jan 13, 2025
ea44598
address review comments (1)
edopao Jan 14, 2025
981f2c7
fix test markers
edopao Jan 14, 2025
6035ccc
address review comments (2)
edopao Jan 14, 2025
bed7e0d
fix previous commit
edopao Jan 14, 2025
3529964
better tuple symbol tree
edopao Jan 15, 2025
974d643
rename sym_tree to symbol_tree
edopao Jan 15, 2025
6f4ff65
helper function add_temp_array
edopao Jan 15, 2025
46879b7
make _visit_if_branch_result separate function
edopao Jan 15, 2025
4437108
fix doc test
edopao Jan 15, 2025
d7671a7
address review comment (1)
edopao Jan 15, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -237,23 +237,31 @@ markers = [
'requires_dace: tests that require `dace` package',
'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)',
'uses_applied_shifts: tests that require backend support for applied-shifts',
'uses_can_deref: tests that require backend support for can_deref builtin function',
'uses_composite_shifts: tests that use composite shifts in unstructured domain',
'uses_constant_fields: tests that require backend support for constant fields',
'uses_dynamic_offsets: tests that require backend support for dynamic offsets',
'uses_floordiv: tests that require backend support for floor division',
'uses_if_stmts: tests that require backend support for if-statements',
'uses_index_fields: tests that require backend support for index fields',
'uses_ir_if_stmts',
'uses_lift: tests that require backend support for lift builtin function',
'uses_negative_modulo: tests that require backend support for modulo on negative numbers',
'uses_origin: tests that require backend support for domain origin',
'uses_reduce_with_lambda: tests that use lambdas as reduce functions',
'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields',
'uses_scalar_in_domain_and_fo',
'uses_scan: tests that uses scan',
'uses_scan_in_field_operator: tests that require backend support for scan in field operator',
'uses_scan_in_stencil: tests that require backend support for scan in stencil',
'uses_scan_without_field_args: tests that require calls to scan that do not have any fields as arguments',
'uses_scan_nested: tests that use nested scans',
'uses_scan_requiring_projector: tests need a projector implementation in gtfn',
'uses_sparse_fields: tests that require backend support for sparse fields',
'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields',
'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset',
'uses_tuple_args: tests that require backend support for tuple arguments',
'uses_tuple_iterator: tests that require backend support to deref tuple iterators',
'uses_tuple_returns: tests that require backend support for tuple results',
'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields',
'uses_cartesian_shift: tests that use a Cartesian connectivity',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
# SPDX-License-Identifier: BSD-3-Clause
import warnings
from collections.abc import Mapping, Sequence
from typing import Any, Iterable
from typing import Any

import dace
import numpy as np

from gt4py._core import definitions as core_defs
from gt4py.next import common as gtx_common, utils as gtx_utils
from gt4py.next import common as gtx_common

from . import utility as dace_utils

Expand Down Expand Up @@ -46,10 +46,9 @@ def _convert_arg(arg: Any, sdfg_param: str) -> Any:

def _get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]:
sdfg_params: Sequence[str] = sdfg.arg_names
flat_args: Iterable[Any] = gtx_utils.flatten_nested_tuple(tuple(args))
return {
sdfg_param: _convert_arg(arg, sdfg_param)
for sdfg_param, arg in zip(sdfg_params, flat_args, strict=True)
for sdfg_param, arg in zip(sdfg_params, args, strict=True)
}


Expand All @@ -73,17 +72,8 @@ def _get_shape_args(
for name, value in args.items():
for sym, size in zip(arrays[name].shape, value.shape, strict=True):
if isinstance(sym, dace.symbol):
if sym.name not in shape_args:
shape_args[sym.name] = size
elif shape_args[sym.name] != size:
# The same shape symbol is used by all fields of a tuple, because the current assumption is that all fields
# in a tuple have the same dimensions and sizes. Therefore, this if-branch only exists to ensure that array
# size (i.e. the value assigned to the shape symbol) is the same for all fields in a tuple.
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
# TODO(edopao): change to `assert sym.name not in shape_args` to ensure that shape symbols are unique,
# once the assumption on tuples is removed.
raise ValueError(
f"Expected array size {sym.name} for arg {name} to be {shape_args[sym.name]}, got {size}."
)
assert sym.name not in shape_args
shape_args[sym.name] = size
elif sym != size:
raise ValueError(
f"Expected shape {arrays[name].shape} for arg {name}, got {value.shape}."
Expand All @@ -103,15 +93,8 @@ def _get_stride_args(
f"Stride ({stride_size} bytes) for argument '{sym}' must be a multiple of item size ({value.itemsize} bytes)."
)
if isinstance(sym, dace.symbol):
if sym.name not in stride_args:
stride_args[str(sym)] = stride
elif stride_args[sym.name] != stride:
# See above comment in `_get_shape_args`, same for stride symbols of fields in a tuple.
# TODO(edopao): change to `assert sym.name not in stride_args` to ensure that stride symbols are unique,
# once the assumption on tuples is removed.
raise ValueError(
f"Expected array stride {sym.name} for arg {name} to be {stride_args[sym.name]}, got {stride}."
)
assert sym.name not in stride_args
stride_args[sym.name] = stride
elif sym != stride:
raise ValueError(
f"Expected stride {arrays[name].strides} for arg {name}, got {value.strides}."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import ctypes
import dataclasses
from typing import Any
from typing import Any, Sequence

import dace
import factory
Expand Down Expand Up @@ -112,11 +112,13 @@ def decorated_program(
) -> None:
if out is not None:
args = (*args, out)
if len(sdfg.arg_names) > len(args):
args = (*args, *arguments.iter_size_args(args))
flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(tuple(args))
if len(sdfg.arg_names) > len(flat_args):
# The Ahead-of-Time (AOT) workflow for FieldView programs requires domain size arguments.
flat_args = (*flat_args, *arguments.iter_size_args(args))

if sdfg_program._lastargs:
kwargs = dict(zip(sdfg.arg_names, gtx_utils.flatten_nested_tuple(args), strict=True))
kwargs = dict(zip(sdfg.arg_names, flat_args, strict=True))
kwargs.update(dace_backend.get_sdfg_conn_args(sdfg, offset_provider, on_gpu))

use_fast_call = True
Expand Down Expand Up @@ -151,7 +153,7 @@ def decorated_program(
sdfg_args = dace_backend.get_sdfg_args(
sdfg,
offset_provider,
*args,
*flat_args,
check_args=False,
on_gpu=on_gpu,
use_field_canonical_representation=use_field_canonical_representation,
Expand Down
Loading
Loading