From 03f4b1ad310c8594c0ccf263221886121decfd9b Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 31 Jul 2024 09:15:59 +0200 Subject: [PATCH] Inside a Map there can not be a library node for fusion. --- .../dace_fieldview/transformations/map_fusion_helper.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index bd6328c396..0ab9259292 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -477,7 +477,7 @@ def partition_first_outputs( # Certain nodes need more than one element as input. As explained # above, in this situation we assume that we can naturally decompose # them iff the node does not consume that whole intermediate. - # Furthermore, it can not be a dynamic map range. + # Furthermore, it can not be a dynamic map range or a library node. intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) consumers = util.find_downstream_consumers(state=state, begin=intermediate_node) for consumer_node, feed_edge in consumers: @@ -488,6 +488,9 @@ def partition_first_outputs( return None if consumer_node is map_entry_2: # Dynamic map range. return None + if isinstance(consumer_node, nodes.LibraryNode): + # TODO(phimuell): Allow some library nodes. + return None # Note that "remove" has a special meaning here, regardless of the # output of the check function, from within the second map we remove @@ -520,6 +523,9 @@ def partition_first_outputs( return None if consumer_node is map_entry_2: # Dynamic map range return None + if isinstance(consumer_node, nodes.LibraryNode): + # TODO(phimuell): Allow some library nodes. + return None else: # Ensure that there is no path that leads to the second map. after_intermdiate_node = util.all_nodes_between(