Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: SDFGConvertible Program for dace_fieldview backend #1742

Merged
merged 27 commits into from
Jan 13, 2025

Conversation

DropD
Copy link
Contributor

@DropD DropD commented Nov 19, 2024

Description

Add a decrator.Program subclass, which implements SDFGConvertible to dace_fieldview backend, analogous to the one in dace_iterator. Conditionally shadow decorator.Program with it and reactivate the orchestration tests by using dace_fieldview instead of dace_iterator.

One caveat: The toolchain is not ready for pure CompileTimeConnectivities in all cases yet, so the test_sdfgConvertible_connectivities had to be adjusted for the moment.

Requirements

  • All fixes and/or new features come with corresponding tests.
  • Important design decisions have been documented in the approriate ADR inside the docs/development/ADRs/ folder.

If this PR contains code authored by new contributors please make sure:

  • The PR contains an updated version of the AUTHORS.md file adding the names of all the new contributors.

src/gt4py/next/ffront/decorator.py Show resolved Hide resolved
if conn_id not in self.connectivity_tables_data_descriptors:
conn = self.connectivities[name]
self.connectivity_tables_data_descriptors[conn_id] = dace.data.Array(
dtype=dace.int64 if conn.index_type == np.int64 else dace.int32,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that dace dtypes has a utility to parse numpy types:
dace.dtypes.typeclass(conn.index_type)

Copy link
Contributor Author

@DropD DropD Nov 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying that lead to KeyError: dtype('int64') in typeclass().

Copy link
Contributor Author

@DropD DropD Nov 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and trying dace.dtypes.typeclass(conn.index_type.type) leads to wrong stencil results. Edit: this might have instead been caused by auto_optimize=False.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

edit: after merging main, it passes. However according to type hints of NeighborTableOffsetProvider, your version should work and mine should fail. Yet, in the tests, the connectivities are set up in a way so that my version works and yours fails. The current version on the other hand is safe in both cases.

If I can get the test to produce the correct types so the official version works, I will change it. Otherwise I would prefer to leave this one until the neighbor table construction is anyway changed by what @havogt is working on.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could not get it to work, it seems like the type hints are wrong, but there is no point investing time in this when the API will change heavily anyway.

@kotsaloscv kotsaloscv self-requested a review November 20, 2024 09:34
continue
if param.id not in sdfg.gt4py_program_input_fields:
continue
sdfg.offset_providers_per_input_field.setdefault(param.id, []).extend(list(shifts))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

All these properties are needed for the automatic halo exchange placement.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Results from discussion in a separate channel for "offset_providers_per_input_field":

  • it is correct, as only used in icon4py which uses "unstructured" fields. There is max one horizontal dimension per field.
  • trace_shifts should work™️ partially on GTIR (but not accross as_fieldop).
  • currently not actively used in icon4py, so defer implementation and put a TODO instead.

# Add them as dynamic properties to the SDFG
program = typing.cast(
itir.Program, gtir_stage.data
) # we already checked that our backend uses GTIR
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we already checked that our backend uses GTIR

where? otherwise we could make an assert isinstance().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We checked that self.backend is a DaCe backend. Probably should check that it is specifically the dace_fieldview backend.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, thank you for the explanation. No, not needed to check that it is the fieldview backend because the iterator backend will be removed soon (next week).

Comment on lines +84 to +86
connectivities: Optional[common.OffsetProvider] = (
None # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@havogt I had to change this to satisfy mypy. Now it is also consistent with CompileTimeArgs

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense

return {name for name in cls().visit(program) if name in field_param_names}


class InputNamesExtractor(SymbolNameSetExtractor):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you move these utilities to some module in iterator/transforms. I would like to use them in other places.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering about that. Will move the classes as well as the tests.

edopao added a commit that referenced this pull request Dec 3, 2024
The dace orchestration tests are temporarily skipped until #1742 is
merged.
The dace backend with SDFG optimization is temporarily disabled in unit
tests until #1639 is merged.
A second PR will reorganize the files in dace backend module.
@edopao
Copy link
Contributor

edopao commented Dec 4, 2024

If it can help, I got the tests to pass locally with this diff:

diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/program.py b/src/gt4py/next/program_processors/runners/dace_fieldview/program.py
index 803ae866..d65f4fac 100644
--- a/src/gt4py/next/program_processors/runners/dace_fieldview/program.py
+++ b/src/gt4py/next/program_processors/runners/dace_fieldview/program.py
@@ -176,7 +176,7 @@ class Program(decorator.Program, dace.frontend.python.common.SDFGConvertible):
                     conn = self.connectivities[name]
                     assert common.is_neighbor_table(conn)
                     self.connectivity_tables_data_descriptors[conn_id] = dace.data.Array(
-                        dtype=dace.int64 if conn.dtype == np.int64 else dace.int32,
+                        dtype=dace.dtypes.typeclass(conn.dtype.scalar_type),
                         shape=[
                             symbols[dace_utils.field_size_symbol_name(conn_id, 0)],
                             symbols[dace_utils.field_size_symbol_name(conn_id, 1)],
diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py
index 9bde9aca..80cd9eac 100644
--- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py
+++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py
@@ -133,7 +133,7 @@ def test_sdfgConvertible_connectivities(unstructured_case):
         data=xp.asarray([[0, 1], [1, 2], [2, 0]]),
         allocator=allocator,
     )
-    connectivities = {"E2V": e2v.__gt_type__()}
+    connectivities = {"E2V": e2v}
     offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr())
 
     SDFG = sdfg.to_sdfg(connectivities=connectivities)

@DropD
Copy link
Contributor Author

DropD commented Dec 17, 2024

If it can help, I got the tests to pass locally with this diff:

diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/program.py b/src/gt4py/next/program_processors/runners/dace_fieldview/program.py
index 803ae866..d65f4fac 100644
--- a/src/gt4py/next/program_processors/runners/dace_fieldview/program.py
+++ b/src/gt4py/next/program_processors/runners/dace_fieldview/program.py
@@ -176,7 +176,7 @@ class Program(decorator.Program, dace.frontend.python.common.SDFGConvertible):
                     conn = self.connectivities[name]
                     assert common.is_neighbor_table(conn)
                     self.connectivity_tables_data_descriptors[conn_id] = dace.data.Array(
-                        dtype=dace.int64 if conn.dtype == np.int64 else dace.int32,
+                        dtype=dace.dtypes.typeclass(conn.dtype.scalar_type),
                         shape=[
                             symbols[dace_utils.field_size_symbol_name(conn_id, 0)],
                             symbols[dace_utils.field_size_symbol_name(conn_id, 1)],
diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py
index 9bde9aca..80cd9eac 100644
--- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py
+++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py
@@ -133,7 +133,7 @@ def test_sdfgConvertible_connectivities(unstructured_case):
         data=xp.asarray([[0, 1], [1, 2], [2, 0]]),
         allocator=allocator,
     )
-    connectivities = {"E2V": e2v.__gt_type__()}
+    connectivities = {"E2V": e2v}
     offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr())
 
     SDFG = sdfg.to_sdfg(connectivities=connectivities)

This is part of what I ended up doing but I wanted to honor the original purpose of the test, which was to compile the SDFG without knowledge of the connectivity tables and then call it with two different ones.

Copy link
Contributor

@edopao edopao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some suggestions for cleanup, otherwise it looks good to me.

Copy link
Contributor

@kotsaloscv kotsaloscv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, although we need to additionally check it with ICON4Py.

@edopao edopao merged commit 9a56fbd into GridTools:main Jan 13, 2025
21 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants