Skip to content

Commit

Permalink
fix for missing symbols in nested sdfg
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Dec 20, 2024
1 parent 45c69ec commit 0f9043b
Showing 1 changed file with 1 addition and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -664,10 +664,6 @@ def visit_arg(arg: IteratorExpr | DataExpr) -> IteratorExpr | ValueExpr:
nsdfg.add_datadesc(inner_data, inner_desc)
input_memlets[inner_data] = (arg_node, arg_subset)

if arg_subset:
# symbols used in memlet subset are not automatically mapped to the parent SDFG
nsdfg_symbol_mapping.update({sym: sym for sym in arg_subset.free_symbols})

inner_node = state.add_access(inner_data)
if isinstance(arg, IteratorExpr):
return IteratorExpr(inner_node, arg.gt_dtype, arg.field_domain, arg.indices)
Expand Down Expand Up @@ -750,7 +746,7 @@ def construct_output(
self.sdfg,
inputs=set(input_memlets.keys()),
outputs=outputs,
symbol_mapping=nsdfg_symbol_mapping,
symbol_mapping=nsdfg_symbol_mapping | {str(sym): sym for sym in nsdfg.free_symbols},
)

for inner, (src_node, src_subset) in input_memlets.items():
Expand Down

0 comments on commit 0f9043b

Please sign in to comment.