From 813a2f435cacf509d43be8e109498f7526d06d0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Fri, 25 Oct 2024 17:25:06 +0200 Subject: [PATCH] Modified `SDFGState.unordered_arglist()` (#1708) This PR fixes the way how arguments are detected in scopes. Technically this only affects GPU code generation, but it is a side effect of how the code is generated. In GPU mode a `Map` is translated into one kernel, for this reason a signature must be computed (this is the reason why CPU code generation is not affected, no function call is produced). To compute this signature the `unsorted_arglist()` function scans what is needed. However, this was implemented not correctly. Assume that AccessNode for array `A` is outside the map and inside the map a temporary scalar, `tmp_in` is defined and initialized to `tmp_in = A[__i0, __i1]`, see also this image: ![argliost_situation](https://github.com/user-attachments/assets/fdf54dea-4ef5-49be-8ce2-33b78ce5962d) If the `data` property of the Memlet that connects the MapEntry with the AccessNode for `tmp_in` is referencing `A` then the (old) function would find that `A` is needed inside, although there is no AccessNode for `A` inside the map. If however, this Memlet referrers `tmp_in` (which is not super standard, but should be allowed), then the old version would not pick up. This would then lead to a code generation error. This PR modifies the function such that such cases are handled. This is done by following all edges that are adjacent to the MapEntry (from the inside) to where the are actually originate. --- dace/sdfg/state.py | 60 ++++++-- tests/codegen/argumet_signature_test.py | 197 ++++++++++++++++++++++++ 2 files changed, 247 insertions(+), 10 deletions(-) create mode 100644 tests/codegen/argumet_signature_test.py diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 09e7607d65..b982dfd718 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -849,6 +849,8 @@ def unordered_arglist(self, for node in self.nodes(): if isinstance(node, nd.AccessNode): descs[node.data] = node.desc(sdfg) + # NOTE: In case of multiple nodes of the same data this will + # override previously found nodes. descs_with_nodes[node.data] = node if isinstance(node.desc(sdfg), dt.Scalar): scalars_with_nodes.add(node.data) @@ -865,19 +867,57 @@ def unordered_arglist(self, else: data_args[node.data] = desc - # Add data arguments from memlets, if do not appear in any of the nodes - # (i.e., originate externally) + # Add data arguments from memlets, if do not appear in any of the nodes (i.e., originate externally) + # TODO: Investigate is scanning the adjacent edges of the input and output connectors is better. for edge in self.edges(): - if edge.data.data is not None and edge.data.data not in descs: - desc = sdfg.arrays[edge.data.data] - if isinstance(desc, dt.Scalar): - # Ignore code->code edges. - if (isinstance(edge.src, nd.CodeNode) and isinstance(edge.dst, nd.CodeNode)): - continue + if edge.data.is_empty(): + continue + + elif edge.data.data not in descs: + # The edge reads data from the outside, and the Memlet is directly indicating what is read. + if (isinstance(edge.src, nd.CodeNode) and isinstance(edge.dst, nd.CodeNode)): + continue # Ignore code->code edges. + additional_descs = {edge.data.data: sdfg.arrays[edge.data.data]} + + elif isinstance(edge.dst, (nd.AccessNode, nd.CodeNode)) and isinstance(edge.src, nd.EntryNode): + # Special case from the above; An AccessNode reads data from the Outside, but + # the Memlet references the data on the inside. Thus we have to follow the data + # to where it originates from. + # NOTE: We have to use a memlet path, because we have to go "against the flow" + # Furthermore, in a valid SDFG the data will only come from one source anyway. + top_source_edge = self.graph.memlet_path(edge)[0] + if not isinstance(top_source_edge.src, nd.AccessNode): + continue + additional_descs = ( + {top_source_edge.src.data: top_source_edge.src.desc(sdfg)} + if top_source_edge.src.data not in descs + else {} + ) + + elif isinstance(edge.dst, nd.ExitNode) and isinstance(edge.src, (nd.AccessNode, nd.CodeNode)): + # Same case as above, but for outgoing Memlets. + # NOTE: We have to use a memlet tree here, because the data could potentially + # go to multiple sources. We have to do it this way, because if we would call + # `memlet_tree()` here, then we would just get the edge back. + additional_descs = {} + connector_to_look = "OUT_" + edge.dst_conn[3:] + for oedge in self.graph.out_edges_by_connector(edge.dst, connector_to_look): + if ( + (not oedge.data.is_empty()) and (oedge.data.data not in descs) + and (oedge.data.data not in additional_descs) + ): + additional_descs[oedge.data.data] = sdfg.arrays[oedge.data.data] + + else: + # Case is ignored. + continue - scalar_args[edge.data.data] = desc + # Now processing the list of newly found data. + for aname, additional_desc in additional_descs.items(): + if isinstance(additional_desc, dt.Scalar): + scalar_args[aname] = additional_desc else: - data_args[edge.data.data] = desc + data_args[aname] = additional_desc # Loop over locally-used data descriptors for name, desc in descs.items(): diff --git a/tests/codegen/argumet_signature_test.py b/tests/codegen/argumet_signature_test.py new file mode 100644 index 0000000000..376724439f --- /dev/null +++ b/tests/codegen/argumet_signature_test.py @@ -0,0 +1,197 @@ +import dace +import copy + +def test_argument_signature_test(): + """Tests if the argument signature is computed correctly. + + The test is focused on if data dependencies are picked up if they are only + referenced indirectly. This effect is only directly visible for GPU. + The test also runs on GPU, but will only compile for GPU. + """ + + def make_sdfg() -> dace.SDFG: + sdfg = dace.SDFG("Repr") + state = sdfg.add_state(is_start_block=True) + N = dace.symbol(sdfg.add_symbol("N", dace.int32)) + for name in "BC": + sdfg.add_array( + name=name, + dtype=dace.float64, + shape=(N, N), + strides=(N, 1), + transient=False, + ) + + # `A` uses a stride that is not used by any of the other arrays. + # However, the stride is used if we want to index array `A`. + second_stride_A = dace.symbol(sdfg.add_symbol("second_stride_A", dace.int32)) + sdfg.add_array( + name="A", + dtype=dace.float64, + shape=(N,), + strides=(second_stride_A,), + transient=False, + + ) + + # Also array `D` uses a stride that is not used by any other array. + second_stride_D = dace.symbol(sdfg.add_symbol("second_stride_D", dace.int32)) + sdfg.add_array( + name="D", + dtype=dace.float64, + shape=(N, N), + strides=(second_stride_D, 1), + transient=False, + + ) + + # Simplest way to generate a mapped Tasklet, we will later modify it. + state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:N", "__i1": "0:N"}, + inputs={ + "__in0": dace.Memlet("A[__i1]"), + "__in1": dace.Memlet("B[__i0, __i1]"), + }, + code="__out = __in0 + __in1", + outputs={"__out": dace.Memlet("C[__i0, __i1]")}, + external_edges=True, + ) + + # Instead of going from the MapEntry to the Tasklet we will go through + # an temporary AccessNode that is only used inside the map scope. + # Thus there is no direct reference to `A` inside the map scope, that would + # need `second_stride_A`. + sdfg.add_scalar("tmp_in", transient=True, dtype=dace.float64) + tmp_in = state.add_access("tmp_in") + for e in state.edges(): + if e.dst_conn == "__in0": + iedge = e + break + state.add_edge( + iedge.src, + iedge.src_conn, + tmp_in, + None, + # The important thing is that the Memlet, that connects the MapEntry with the + # AccessNode, does not refers to the memory outside (its source) but to the transient + # inside (its destination) + dace.Memlet(data="tmp_in", subset="0", other_subset="__i1"), # This does not work! + #dace.Memlet(data="A", subset="__i1", other_subset="0"), # This would work! + ) + state.add_edge( + tmp_in, + None, + iedge.dst, + iedge.dst_conn, + dace.Memlet(f"{tmp_in.data}[0]"), + ) + state.remove_edge(iedge) + + # Here we are doing something similar as for `A`, but this time for the output. + # The output of the Tasklet is stored inside a temporary scalar. + # From that scalar we then go to `C`, here the Memlet on the inside is still + # referring to `C`, thus it is referenced directly. + # We also add a second output that goes to `D` , but the inner Memlet does + # not refer to `D` but to the temporary. Thus there is no direct mention of + # `D` inside the map scope. + sdfg.add_scalar("tmp_out", transient=True, dtype=dace.float64) + tmp_out = state.add_access("tmp_out") + for e in state.edges(): + if e.src_conn == "__out": + oedge = e + assert oedge.data.data == "C" + break + + state.add_edge( + oedge.src, + oedge.src_conn, + tmp_out, + None, + dace.Memlet(data="tmp_out", subset="0"), + ) + state.add_edge( + tmp_out, + None, + oedge.dst, + oedge.dst_conn, + dace.Memlet(data="C", subset="__i0, __i1"), + ) + + # Now we create a new output that uses `tmp_out` but goes into `D`. + # The memlet on the inside will not use `D` but `tmp_out`. + state.add_edge( + tmp_out, + None, + oedge.dst, + "IN_D", + dace.Memlet(data=tmp_out.data, subset="0", other_subset="__i1, __i0"), + ) + state.add_edge( + oedge.dst, + "OUT_D", + state.add_access("D"), + None, + dace.Memlet(data="D", subset="__i0, __i1", other_subset="0"), + ) + oedge.dst.add_in_connector("IN_D", force=True) + oedge.dst.add_out_connector("OUT_D", force=True) + state.remove_edge(oedge) + + # Without this the test does not work properly + # It is related to [Issue#1703](https://github.com/spcl/dace/issues/1703) + sdfg.validate() + for edge in state.edges(): + edge.data.try_initialize(edge=edge, sdfg=sdfg, state=state) + + for array in sdfg.arrays.values(): + if isinstance(array, dace.data.Array): + array.storage = dace.StorageType.GPU_Global + else: + array.storage = dace.StorageType.Register + sdfg.apply_gpu_transformations(simplify=False) + sdfg.validate() + + return sdfg + + # Build the SDFG + sdfg = make_sdfg() + + map_entry = None + for state in sdfg.states(): + for node in state.nodes(): + if isinstance(node, dace.nodes.MapEntry): + map_entry = node + break + if map_entry is not None: + break + + # Now get the argument list of the map. + res_arglist = { k:v for k, v in state.scope_subgraph(map_entry).arglist().items()} + + ref_arglist = { + 'A': dace.data.Array, + 'B': dace.data.Array, + 'C': dace.data.Array, + 'D': dace.data.Array, + 'N': dace.data.Scalar, + 'second_stride_A': dace.data.Scalar, + 'second_stride_D': dace.data.Scalar, + } + + assert len(ref_arglist) == len(res_arglist), f"Expected {len(ref_arglist)} but got {len(res_arglist)}" + for aname in ref_arglist.keys(): + atype_ref = ref_arglist[aname] + atype_res = res_arglist[aname] + assert isinstance(atype_res, atype_ref), f"Expected '{aname}' to have type {atype_ref}, but it had {type(atype_res)}." + + # If we have cupy we will also compile it. + try: + import cupy as cp + except ImportError: + return + + csdfg = sdfg.compile() + +if __name__ == "__main__": + test_argument_signature_test()