From fd2366fd0e1a73d568f59570baa6791dac4b7838 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 31 Jul 2024 14:01:18 +0200 Subject: [PATCH] Applied Edoardo's comments. --- .../transformations/auto_opt.py | 1 + .../transformations/k_blocking.py | 4 +- .../transformations/map_promoter.py | 10 ++-- .../transformations/map_serial_fusion.py | 56 +++++++++---------- .../dace_fieldview/transformations/util.py | 55 +++++++++--------- 5 files changed, 65 insertions(+), 61 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index f13117f50e..6c249340f9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -190,6 +190,7 @@ def gt_auto_optimize( Something along the line "Fuse if operational intensity goes up, but not if we have too much internal space (register pressure). - Create a custom array elimination pass that honors rule 1. + - Check if a pipeline could be used to speed up some computations. """ device = dace.DeviceType.GPU if gpu else dace.DeviceType.CPU diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py index 2227c82729..165a3acafd 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py @@ -33,7 +33,7 @@ class KBlocking(transformation.SingleStateTransformation): dimension, that is commonly called "k", but identified with `block_dim`. All dimensions except `k` are unaffected by this transformation. In the outer - Map the will replace the `k` range, currently `k = 0:N`, with + Map will be replace the `k` range, currently `k = 0:N`, with `__coarse_k = 0:N:B`, where `N` is the original size of the range and `B` is the block size, passed as `blocking_size`. The transformation also handles the case if `N % B != 0`. @@ -231,7 +231,7 @@ def apply( # of the node in one go. relocated_nodes.add(edge_dst) - # In order to be useful we have to temporary store the data the + # In order to be useful we have to temporarily store the data the # independent node generates assert graph.out_degree(edge_dst) == 1 # TODO(phimuell): Lift if isinstance(edge_dst, nodes.AccessNode): diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index 9331b983ed..2f7aff7d9a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -36,7 +36,7 @@ class BaseMapPromoter(transformation.SingleStateTransformation): The transformation operates on two Maps, first the "source map". This map describes the Map that should be used as template. The second one is "map to promote". After the transformation the "map to promote" will have the same - map parameter than the "source map" has. + map parameter as the "source map" has. In order to properly work, the parameters of "source map" must be a strict superset of the ones of "map to promote". Furthermore, this transformation @@ -52,6 +52,8 @@ class BaseMapPromoter(transformation.SingleStateTransformation): promote_vertical: If `True` promote vertical dimensions; `True` by default. promote_local: If `True` promote local dimensions; `True` by default. promote_horizontal: If `True` promote horizontal dimensions; `False` by default. + promote_all: Do not impose any restriction on what to promote. The only + reasonable value is `True` or `None`. Note: This ignores tiling. @@ -311,9 +313,9 @@ class SerialMapPromoter(BaseMapPromoter): def expressions(cls) -> Any: """Get the match expressions. - The function generates two different match expression. The first match - describes the case where the top map must be promoted, while the second - case is the second/lower map must be promoted. + The function generates two match expressions. The first match describes + the case where the top map must be promoted, while the second case is + the second/lower map must be promoted. """ return [ dace.sdfg.utils.node_path_graph( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py index 66be18e6c0..a17bcf4bd1 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py @@ -35,12 +35,12 @@ class SerialMapFusion(map_fusion_helper.MapFusionHelper): Things that are improved, compared to the native DaCe implementation: - Nested Maps. - - Temporary arrays and the correct propagation of their Memelts. + - Temporary arrays and the correct propagation of their Memlets. - Top Maps that have multiple outputs. Conceptually this transformation removes the exit of the first or upper map - and the entry of the lower or second map and then rewriting the connections - appropriate. + and the entry of the lower or second map and then rewrites the connections + appropriately. This transformation assumes that an SDFG obeys the structure that is outlined [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). For that @@ -70,9 +70,9 @@ def expressions(cls) -> Any: The transformation matches the exit node of the top Map that is connected to an access node that again is connected to the entry node of the second Map. - An important note is, that the transformation operates not just on these nodes, - but more or less anything that has an outgoing connection of the first Map, - and is connected to the second map. + An important note is, that the transformation operates not just on the + matched nodes, but more or less on anything that has an incoming connection + from the first Map or an outgoing connection to the second Map entry. """ return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] @@ -86,7 +86,7 @@ def can_be_applied( """Tests if the matched Maps can be merged. The two Maps are mergeable iff: - - The `can_be_fused()` of the base succeed, which checks some basic constrains. + - The `can_be_fused()` of the base succeed, which checks some basic constraints. - The decomposition exists and at least one of the intermediate sets is not empty. """ @@ -102,7 +102,8 @@ def can_be_applied( return False # Two maps can be serially fused if the node decomposition exists and - # there at least one of the intermediate output sets is not empty. + # at least one of the intermediate output sets is not empty. The state + # of the pure outputs is irrelevant for serial map fusion. output_partition = self.partition_first_outputs( state=graph, sdfg=sdfg, @@ -111,7 +112,8 @@ def can_be_applied( ) if output_partition is None: return False - if not (output_partition[1] or output_partition[2]): + _, exclusive_outputs, shared_outputs = output_partition + if not (exclusive_outputs or shared_outputs): return False return True @@ -216,12 +218,12 @@ def handle_intermediate_set( state: The state in which the map is processed. sdfg: The SDFG that should be optimized. map_exit_1: The exit of the first/top map. - map_entry_1: The entry of the second map. + map_entry_2: The entry of the second map. map_exit_2: The exit of the second map. is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. Notes: - Before the transformation the `state` does not be to be valid and + Before the transformation the `state` does not have to be valid and after this function has run the state is (most likely) invalid. Todo: @@ -303,7 +305,7 @@ def handle_intermediate_set( ) new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) - # New we will reroute the output Memlet, thus it will no longer going + # New we will reroute the output Memlet, thus it will no longer pass # through the Map exit but through the newly created intermediate. # we will delete the previous edge later. pre_exit_memlet: dace.Memlet = pre_exit_edge.data @@ -314,11 +316,11 @@ def handle_intermediate_set( assert pre_exit_memlet.data == inter_name new_pre_exit_memlet.data = new_inter_name - # Now we have to fix the subset of the Memlet. - # Before the subset of the Memlet dependent on the Map variables, + # Now we have to modify the subset of the Memlet. + # Before the subset of the Memlet was dependent on the Map variables, # however, this is no longer the case, as we removed them. This change # has to be reflected in the Memlet. - # NOTE: Assert above ensures that the bellow is correct. + # NOTE: Assert above ensures that the below is correct. new_pre_exit_memlet.replace(memlet_repl) if is_scalar: new_pre_exit_memlet.subset = "0" @@ -350,14 +352,14 @@ def handle_intermediate_set( producer_edge.data.replace(memlet_repl) if is_scalar: producer_edge.data.dst_subset = "0" - else: - if producer_edge.data.dst_subset is not None: - producer_edge.data.dst_subset.pop(squeezed_dims) + elif producer_edge.data.dst_subset is not None: + producer_edge.data.dst_subset.pop(squeezed_dims) # Now after we have handled the input of the new intermediate node, - # we must handle its output. For this we have to "inject" the temporary - # in the second map. We do this by finding the input connectors on the - # map entry, such that we know where we have to reroute inside the Map. + # we must handle its output. For this we have to "inject" the newly + # created intermediate into the second map. We do this by finding + # the input connectors on the map entry, such that we know where we + # have to reroute inside the Map. # NOTE: Assumes that map (if connected is the direct neighbour). conn_names: set[str] = set() for inter_node_out_edge in state.out_edges(inter_node): @@ -386,7 +388,7 @@ def handle_intermediate_set( # Memlet and the correctness of the code below. new_inner_memlet = copy.deepcopy(inner_edge.data) new_inner_memlet.replace(memlet_repl) - new_inner_memlet.data = new_inter_name # Because of the assert above, this will not chenge the direction. + new_inner_memlet.data = new_inter_name # Because of the assert above, this will not change the direction. # Now remove the old edge, that started the second map entry. # Also add the new edge that started at the new intermediate. @@ -402,9 +404,8 @@ def handle_intermediate_set( # Now we do subset modification to ensure that nothing failed. if is_scalar: new_inner_memlet.src_subset = "0" - else: - if new_inner_memlet.src_subset is not None: - new_inner_memlet.src_subset.pop(squeezed_dims) + elif new_inner_memlet.src_subset is not None: + new_inner_memlet.src_subset.pop(squeezed_dims) # Now clean the Memlets of that tree to use the new intermediate node. for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(): @@ -413,9 +414,8 @@ def handle_intermediate_set( consumer_edge.data.data = new_inter_name if is_scalar: consumer_edge.data.src_subset = "0" - else: - if consumer_edge.data.subset is not None: - consumer_edge.data.subset.pop(squeezed_dims) + elif consumer_edge.data.subset is not None: + consumer_edge.data.subset.pop(squeezed_dims) # The edge that leaves the second map entry was already deleted. # We will now delete the edges that brought the data. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py index 9e4f09c722..897aaeecab 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py @@ -14,17 +14,19 @@ """Common functionality for the transformations/optimization pipeline.""" -from typing import Iterable +from typing import Iterable, Union import dace -from dace.sdfg import graph as dace_graph, nodes +from dace.sdfg import graph as dace_graph, nodes as dace_nodes -def is_nested_sdfg(sdfg: dace.SDFG) -> bool: - """Tests if `sdfg` is a neseted sdfg.""" +def is_nested_sdfg( + sdfg: Union[dace.SDFG, dace.SDFGState, dace_nodes.NestedSDFG], +) -> bool: + """Tests if `sdfg` is a NestedSDFG.""" if isinstance(sdfg, dace.SDFGState): sdfg = sdfg.parent - if isinstance(sdfg, dace.nodes.NestedSDFG): + if isinstance(sdfg, dace_nodes.NestedSDFG): return True elif isinstance(sdfg, dace.SDFG): if sdfg.parent_nsdfg_node is not None: @@ -36,17 +38,17 @@ def is_nested_sdfg(sdfg: dace.SDFG) -> bool: def all_nodes_between( graph: dace.SDFG | dace.SDFGState, - begin: nodes.Node, - end: nodes.Node, + begin: dace_nodes.Node, + end: dace_nodes.Node, reverse: bool = False, -) -> set[nodes.Node] | None: +) -> set[dace_nodes.Node] | None: """Find all nodes that are reachable from `begin` but bound by `end`. - Essentially the function starts a DFS at `begin`, which is never part of the - returned set, if at a node an edge is found that lead to `end`, the function - will ignore this edge. However, it will find every node that is reachable - from `begin` that is reachable by a path that does not visit `end`. - In case `end` is never found the function will return `None`. + Essentially the function starts a DFS at `begin`. If an edge is found that lead + to `end`, this edge is ignored. It will thus found any node that is reachable + from `begin` by a path that does not involve `end`. The returned set will + never contain `end` nor `begin`. In case `end` is never found the function + will return `None`. If `reverse` is set to `True` the function will start exploring at `end` and follows the outgoing edges, i.e. the meaning of `end` and `begin` are swapped. @@ -58,12 +60,11 @@ def all_nodes_between( reverse: Perform a backward DFS. Notes: - - The returned set will never contain the node `begin`. - The returned set will also contain the nodes of path that starts at `begin` and ends at a node that is not `end`. """ - def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + def next_nodes(node: dace_nodes.Node) -> Iterable[dace_nodes.Node]: if reverse: return (edge.src for edge in graph.in_edges(node)) return (edge.dst for edge in graph.out_edges(node)) @@ -71,12 +72,12 @@ def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: if reverse: begin, end = end, begin - to_visit: list[nodes.Node] = [begin] - seen: set[nodes.Node] = set() + to_visit: list[dace_nodes.Node] = [begin] + seen: set[dace_nodes.Node] = set() found_end: bool = False while len(to_visit) > 0: - n: nodes.Node = to_visit.pop() + n: dace_nodes.Node = to_visit.pop() if n == end: found_end = True continue @@ -94,10 +95,10 @@ def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: def find_downstream_consumers( state: dace.SDFGState, - begin: nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], + begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], only_tasklets: bool = False, reverse: bool = False, -) -> set[tuple[nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: +) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: """Find all downstream connectors of `begin`. A consumer, in for this function, is any node that is neither an entry nor @@ -123,17 +124,17 @@ def find_downstream_consumers( else: to_visit = list(state.out_edges(begin)) seen: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - found: set[tuple[nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]] = set() + found: set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]] = set() while len(to_visit) != 0: curr_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = to_visit.pop() - next_node: nodes.Node = curr_edge.src if reverse else curr_edge.dst + next_node: dace_nodes.Node = curr_edge.src if reverse else curr_edge.dst if curr_edge in seen: continue seen.add(curr_edge) - if isinstance(next_node, (nodes.MapEntry, nodes.MapExit)): + if isinstance(next_node, (dace_nodes.MapEntry, dace_nodes.MapExit)): if reverse: target_conn = curr_edge.src_conn[4:] new_edges = state.in_edges_by_connector(curr_edge.src, "IN_" + target_conn) @@ -141,7 +142,7 @@ def find_downstream_consumers( # In forward mode a Map entry could also mean the definition of a # dynamic map range. if (not curr_edge.dst_conn.startswith("IN_")) and isinstance( - next_node, nodes.MapEntry + next_node, dace_nodes.MapEntry ): # This edge defines a dynamic map range, which is a consumer if not only_tasklets: @@ -152,7 +153,7 @@ def find_downstream_consumers( to_visit.extend(new_edges) del new_edges else: - if only_tasklets and (not isinstance(next_node, nodes.Tasklet)): + if only_tasklets and (not isinstance(next_node, dace_nodes.Tasklet)): continue found.add((next_node, curr_edge)) @@ -161,9 +162,9 @@ def find_downstream_consumers( def find_upstream_producers( state: dace.SDFGState, - begin: nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], + begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], only_tasklets: bool = False, -) -> set[tuple[nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: +) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: """Same as `find_downstream_consumers()` but with `reverse` set to `True`.""" return find_downstream_consumers( state=state,