-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next][dace]: lowering of scan to SDFG #1776
base: main
Are you sure you want to change the base?
Conversation
Before the function had a special mode in which it performed the renaming through the `symbol_mapping`. However, this made testing a bit harder and so I decided that there should be a flag to disable this.
There are some functioanlity missing, but it is looking good.
…trides. However, it is not yet fully tested, tehy are on their wa.
It also seems that it inferes with something.
Because a scalar has a shape of `(1,)` but a stride of `()`. Thus we have first to handle this case. However, now we are back at the index stuff, let's fix it.
However, it still seems to fail in some cases.
The type is now a bit better estimated.
The type are now extracted from the stuff we get from `free_symbols`.
I realized that allowing this is not very safe. I also added a test to show that.
…ces errors in certain cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code is quite complex, so I think a refactoring is needed to make it simpler.
@@ -192,16 +220,39 @@ def _parse_fieldop_arg( | |||
state: dace.SDFGState, | |||
sdfg_builder: gtir_sdfg.SDFGBuilder, | |||
domain: FieldopDomain, | |||
) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: | |||
by_value: bool = False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think an explanation of what by_value
means would be helpful.
) -> ( | ||
gtir_dataflow.IteratorExpr | ||
| gtir_dataflow.MemletExpr | ||
| gtir_dataflow.ValueExpr | ||
| tuple[ | ||
gtir_dataflow.IteratorExpr | ||
| gtir_dataflow.MemletExpr | ||
| gtir_dataflow.ValueExpr | ||
| tuple[Any, ...], | ||
..., | ||
] | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just a suggestion, but would it make sense to define an alias for the {Iter, Memelte, Value}Expr
, however, I am not sure if this is possible at all?
This type annotation looks like something straight from C++ metaprogramming hell (first level).
raise ValueError(f"Received {node} as argument to field operator, expected a field.") | ||
def get_arg_value( | ||
arg: FieldopData, | ||
) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this function ever returns an IteratorExpr
?
This is more for my own selfish curiosity.
domain_indices = _get_domain_indices(domain_dims, domain_offset) | ||
domain_subset = dace_subsets.Range.from_indices(domain_indices) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am still highly advocating to change _get_domain_indices()
to return Range
objects.
domain_subset = dace_subsets.Range( | ||
domain_subset[:scan_dim_index] + domain_subset[scan_dim_index + 1 :] | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
domain_subset = dace_subsets.Range( | |
domain_subset[:scan_dim_index] + domain_subset[scan_dim_index + 1 :] | |
) | |
domain_subset = dace_subsets.pop([scan_dim_index]) |
Be aware that DaCe will not create empty ranges, but it will have a length of one, but just with zeros (see DaCe code).
So you will have to adapt the check below, however, I would do it.
# SDFG data containers with name prefix '__tmp' are expected to be transients | ||
inner_data = ( | ||
arg_data.replace("__tmp", "__input") | ||
if arg_data.startswith("__tmp") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a very picky comment.
You explicitly check if the variable starts with __tmp
, but replace will replace every occurence of __tmp
.
try: | ||
inner_desc = nsdfg.data(inner_data) | ||
assert not inner_desc.transient | ||
except KeyError: | ||
inner_desc = arg_desc.clone() | ||
inner_desc.transient = False | ||
nsdfg.add_datadesc(inner_data, inner_desc) | ||
input_memlets[inner_data] = (arg_node, arg_subset) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks strange, you should rewrite it with an if
.
|
||
if arg_subset: | ||
# symbols used in memlet subset are not automatically mapped to the parent SDFG | ||
nsdfg_symbol_mapping.update({sym: sym for sym in arg_subset.free_symbols}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nsdfg_symbol_mapping.update({sym: sym for sym in arg_subset.free_symbols}) | |
nsdfg_symbol_mapping.update({str(sym): sym for sym in arg_subset.free_symbols}) |
As the signature of nsdfg_symbol_mapping
says that the key is str
.
Btw, if I am not mistaken, this variable was defined two levels above.
try: | ||
output_desc = nsdfg.data(output_data) | ||
assert not output_desc.transient | ||
except KeyError: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using an if
is clearer.
def make_output_edge( | ||
output_expr: ValueExpr | MemletExpr | SymbolExpr, | ||
) -> DataflowOutputEdge: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are you needing two inner functions here.
I do not really understand what has changed.
This PR contains two types of contributions: