Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next][dace]: lowering of scan to SDFG #1776

Open
wants to merge 97 commits into
base: main
Choose a base branch
from

Conversation

edopao
Copy link
Contributor

@edopao edopao commented Dec 6, 2024

This PR contains two types of contributions:

  • The lowering of the scan builtin function.
  • An extension of the lowering from GTIR to SDFG, with support for some iterator patterns needed by the scan (tuple of iterators, exclusive if in local view).

@edopao edopao changed the title feat[next][dace]: add lowering of scan to SDFG feat[next][dace]: lowering of scan to SDFG Dec 6, 2024
philip-paul-mueller and others added 28 commits December 18, 2024 14:47
Before the function had a special mode in which it performed the renaming through the `symbol_mapping`.
However, this made testing a bit harder and so I decided that there should be a flag to disable this.
There are some functioanlity missing, but it is looking good.
…trides.

However, it is not yet fully tested, tehy are on their wa.
It also seems that it inferes with something.
Because a scalar has a shape of `(1,)` but a stride of `()`.
Thus we have first to handle this case.

However, now we are back at the index stuff, let's fix it.
However, it still seems to fail in some cases.
The type is now a bit better estimated.
The type are now extracted from the stuff we get from `free_symbols`.
I realized that allowing this is not very safe.
I also added a test to show that.
Copy link
Contributor

@philip-paul-mueller philip-paul-mueller left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code is quite complex, so I think a refactoring is needed to make it simpler.

@@ -192,16 +220,39 @@ def _parse_fieldop_arg(
state: dace.SDFGState,
sdfg_builder: gtir_sdfg.SDFGBuilder,
domain: FieldopDomain,
) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr:
by_value: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think an explanation of what by_value means would be helpful.

Comment on lines +224 to +235
) -> (
gtir_dataflow.IteratorExpr
| gtir_dataflow.MemletExpr
| gtir_dataflow.ValueExpr
| tuple[
gtir_dataflow.IteratorExpr
| gtir_dataflow.MemletExpr
| gtir_dataflow.ValueExpr
| tuple[Any, ...],
...,
]
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a suggestion, but would it make sense to define an alias for the {Iter, Memelte, Value}Expr, however, I am not sure if this is possible at all?
This type annotation looks like something straight from C++ metaprogramming hell (first level).

raise ValueError(f"Received {node} as argument to field operator, expected a field.")
def get_arg_value(
arg: FieldopData,
) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this function ever returns an IteratorExpr?
This is more for my own selfish curiosity.

Comment on lines +316 to +317
domain_indices = _get_domain_indices(domain_dims, domain_offset)
domain_subset = dace_subsets.Range.from_indices(domain_indices)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am still highly advocating to change _get_domain_indices() to return Range objects.

Comment on lines +323 to +325
domain_subset = dace_subsets.Range(
domain_subset[:scan_dim_index] + domain_subset[scan_dim_index + 1 :]
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
domain_subset = dace_subsets.Range(
domain_subset[:scan_dim_index] + domain_subset[scan_dim_index + 1 :]
)
domain_subset = dace_subsets.pop([scan_dim_index])

Be aware that DaCe will not create empty ranges, but it will have a length of one, but just with zeros (see DaCe code).
So you will have to adapt the check below, however, I would do it.

# SDFG data containers with name prefix '__tmp' are expected to be transients
inner_data = (
arg_data.replace("__tmp", "__input")
if arg_data.startswith("__tmp")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a very picky comment.
You explicitly check if the variable starts with __tmp, but replace will replace every occurence of __tmp.

Comment on lines +658 to +665
try:
inner_desc = nsdfg.data(inner_data)
assert not inner_desc.transient
except KeyError:
inner_desc = arg_desc.clone()
inner_desc.transient = False
nsdfg.add_datadesc(inner_data, inner_desc)
input_memlets[inner_data] = (arg_node, arg_subset)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks strange, you should rewrite it with an if.


if arg_subset:
# symbols used in memlet subset are not automatically mapped to the parent SDFG
nsdfg_symbol_mapping.update({sym: sym for sym in arg_subset.free_symbols})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
nsdfg_symbol_mapping.update({sym: sym for sym in arg_subset.free_symbols})
nsdfg_symbol_mapping.update({str(sym): sym for sym in arg_subset.free_symbols})

As the signature of nsdfg_symbol_mapping says that the key is str.
Btw, if I am not mistaken, this variable was defined two levels above.

Comment on lines +707 to +710
try:
output_desc = nsdfg.data(output_data)
assert not output_desc.transient
except KeyError:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using an if is clearer.

Comment on lines +1553 to +1555
def make_output_edge(
output_expr: ValueExpr | MemletExpr | SymbolExpr,
) -> DataflowOutputEdge:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you needing two inner functions here.
I do not really understand what has changed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants