-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next]: SDFGConvertible Program for dace_fieldview backend #1742
Conversation
src/gt4py/next/program_processors/runners/dace_fieldview/program.py
Outdated
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace_fieldview/program.py
Outdated
Show resolved
Hide resolved
if conn_id not in self.connectivity_tables_data_descriptors: | ||
conn = self.connectivities[name] | ||
self.connectivity_tables_data_descriptors[conn_id] = dace.data.Array( | ||
dtype=dace.int64 if conn.index_type == np.int64 else dace.int32, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that dace dtypes
has a utility to parse numpy types:
dace.dtypes.typeclass(conn.index_type)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Trying that lead to KeyError: dtype('int64')
in typeclass()
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and trying dace.dtypes.typeclass(conn.index_type.type)
leads to wrong stencil results. Edit: this might have instead been caused by auto_optimize=False
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
edit: after merging main, it passes. However according to type hints of NeighborTableOffsetProvider
, your version should work and mine should fail. Yet, in the tests, the connectivities are set up in a way so that my version works and yours fails. The current version on the other hand is safe in both cases.
If I can get the test to produce the correct types so the official version works, I will change it. Otherwise I would prefer to leave this one until the neighbor table construction is anyway changed by what @havogt is working on.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could not get it to work, it seems like the type hints are wrong, but there is no point investing time in this when the API will change heavily anyway.
src/gt4py/next/program_processors/runners/dace_fieldview/program.py
Outdated
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
Outdated
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
Outdated
Show resolved
Hide resolved
continue | ||
if param.id not in sdfg.gt4py_program_input_fields: | ||
continue | ||
sdfg.offset_providers_per_input_field.setdefault(param.id, []).extend(list(shifts)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here.
All these properties are needed for the automatic halo exchange placement.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Results from discussion in a separate channel for "offset_providers_per_input_field":
- it is correct, as only used in
icon4py
which uses "unstructured" fields. There is max one horizontal dimension per field. - trace_shifts should work™️ partially on GTIR (but not accross
as_fieldop
). - currently not actively used in
icon4py
, so defer implementation and put a TODO instead.
src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
Outdated
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace_fieldview/program.py
Outdated
Show resolved
Hide resolved
# Add them as dynamic properties to the SDFG | ||
program = typing.cast( | ||
itir.Program, gtir_stage.data | ||
) # we already checked that our backend uses GTIR |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we already checked that our backend uses GTIR
where? otherwise we could make an assert isinstance()
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We checked that self.backend
is a DaCe backend. Probably should check that it is specifically the dace_fieldview
backend.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, thank you for the explanation. No, not needed to check that it is the fieldview backend because the iterator backend will be removed soon (next week).
src/gt4py/next/program_processors/runners/dace_fieldview/program.py
Outdated
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
Outdated
Show resolved
Hide resolved
connectivities: Optional[common.OffsetProvider] = ( | ||
None # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@havogt I had to change this to satisfy mypy. Now it is also consistent with CompileTimeArgs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
makes sense
return {name for name in cls().visit(program) if name in field_param_names} | ||
|
||
|
||
class InputNamesExtractor(SymbolNameSetExtractor): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you move these utilities to some module in iterator/transforms
. I would like to use them in other places.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was wondering about that. Will move the classes as well as the tests.
If it can help, I got the tests to pass locally with this diff:
|
This is part of what I ended up doing but I wanted to honor the original purpose of the test, which was to compile the SDFG without knowledge of the connectivity tables and then call it with two different ones. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some suggestions for cleanup, otherwise it looks good to me.
tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py
Outdated
Show resolved
Hide resolved
tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py
Outdated
Show resolved
Hide resolved
src/gt4py/next/program_processors/runners/dace_fieldview/program.py
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, although we need to additionally check it with ICON4Py.
Description
Add a
decrator.Program
subclass, which implementsSDFGConvertible
todace_fieldview
backend, analogous to the one indace_iterator
. Conditionally shadowdecorator.Program
with it and reactivate the orchestration tests by usingdace_fieldview
instead ofdace_iterator
.One caveat: The toolchain is not ready for pure
CompileTimeConnectivities
in all cases yet, so thetest_sdfgConvertible_connectivities
had to be adjusted for the moment.Requirements
If this PR contains code authored by new contributors please make sure:
AUTHORS.md
file adding the names of all the new contributors.