From b20ec2c273ace38e1407bbf2751ef2be9848e534 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 20 Oct 2023 21:31:45 +0200 Subject: [PATCH 1/4] [dace] Extend support for offset providers --- .../runners/dace_iterator/itir_to_sdfg.py | 1 - .../runners/dace_iterator/itir_to_tasklet.py | 41 +++++++++++++++---- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 580486aa4a..1f9692356e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -196,7 +196,6 @@ def visit_StencilClosure( self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array] ) -> tuple[dace.SDFG, list[str], list[str]]: assert ItirToSDFG._check_no_lifts(node) - assert ItirToSDFG._check_shift_offsets_are_literals(node) # Create the closure's nested SDFG and single state. closure_sdfg = dace.SDFG(name="closure") diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index b28703feef..bd03783ea2 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -698,9 +698,20 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: offset = tail[0].value assert isinstance(offset, str) - assert isinstance(tail[1], itir.OffsetLiteral) - element = tail[1].value - assert isinstance(element, int) + if isinstance(tail[1], itir.OffsetLiteral): + element = tail[1].value + assert isinstance(element, int) + element_var = unique_var_name() + self.context.body.add_scalar(element_var, dace.dtypes.int64, transient=True) + element_node = self.context.state.add_access(element_var) + tlet_node = self.context.state.add_tasklet( + "get_element", {}, {"__out"}, f"__out = {element}" + ) + self.context.state.add_edge( + tlet_node, "__out", element_node, None, dace.Memlet.simple(element_var, "0") + ) + else: + element_node = self.visit(tail[1])[0] if isinstance(self.offset_provider[offset], NeighborTableOffsetProvider): table = self.offset_provider[offset] @@ -712,20 +723,32 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: args = [ ValueExpr(conn, table.table.dtype), ValueExpr(iterator.indices[shifted_dim], dace.int64), + ValueExpr(element_node, dace.int64), ] internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{internals[1]}, {element}]" - else: + expr = f"{internals[0]}[{internals[1]}, {internals[2]}]" + elif isinstance(self.offset_provider[offset], StridedNeighborOffsetProvider): offset_provider = self.offset_provider[offset] - assert isinstance(offset_provider, StridedNeighborOffsetProvider) shifted_dim = offset_provider.origin_axis.value target_dim = offset_provider.neighbor_axis.value offset_value = iterator.indices[shifted_dim] - args = [ValueExpr(offset_value, dace.int64)] - internals = [f"{offset_value.data}_v"] - expr = f"{internals[0]} * {offset_provider.max_neighbors} + {element}" + args = [ + ValueExpr(offset_value, dace.int64), + ValueExpr(element_node, dace.int64), + ] + internals = [f"{arg.value.data}_v" for arg in args] + expr = f"{internals[0]} * {offset_provider.max_neighbors} + {internals[1]}" + else: + assert isinstance(self.offset_provider[offset], Dimension) + shifted_dim = target_dim = self.offset_provider[offset].value + args = [ + ValueExpr(iterator.indices[shifted_dim], dace.int64), + element_node, + ] + internals = [f"{arg.value.data}_v" for arg in args] + expr = f"{internals[0]} + {internals[1]}" shifted_value = self.add_expr_tasklet( list(zip(args, internals)), expr, dace.dtypes.int64, "ind_addr" From 289cd759cbb305fcf91fa48586de4c843e33f61e Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 25 Oct 2023 09:22:45 +0200 Subject: [PATCH 2/4] [dace] Cleanup shift visitor --- .../runners/dace_iterator/itir_to_tasklet.py | 99 ++++++------------- 1 file changed, 30 insertions(+), 69 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index bd03783ea2..2f32a07ce6 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -478,17 +478,7 @@ def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: return self._visit_deref(node) if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef): if node.fun.fun.id == "shift": - offset = node.fun.args[0] - assert isinstance(offset, itir.OffsetLiteral) - offset_name = offset.value - assert isinstance(offset_name, str) - if offset_name not in self.offset_provider: - raise ValueError(f"offset provider for `{offset_name}` is missing") - offset_provider = self.offset_provider[offset_name] - if isinstance(offset_provider, Dimension): - return self._visit_direct_addressing(node) - else: - return self._visit_indirect_addressing(node) + return self._visit_shift(node) elif node.fun.fun.id == "reduce": return self._visit_reduce(node) @@ -653,39 +643,7 @@ def _make_shift_for_rest(self, rest, iterator): fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), args=[iterator] ) - def _visit_direct_addressing(self, node: itir.FunCall) -> IteratorExpr: - assert isinstance(node.fun, itir.FunCall) - shift = node.fun - assert isinstance(shift, itir.FunCall) - - tail, rest = self._split_shift_args(shift.args) - if rest: - iterator = self.visit(self._make_shift_for_rest(rest, node.args[0])) - else: - iterator = self.visit(node.args[0]) - - assert isinstance(tail[0], itir.OffsetLiteral) - offset = tail[0].value - assert isinstance(offset, str) - shifted_dim = self.offset_provider[offset].value - - assert isinstance(tail[1], itir.OffsetLiteral) - shift_amount = tail[1].value - assert isinstance(shift_amount, int) - - args = [ValueExpr(iterator.indices[shifted_dim], dace.int64)] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]} + {shift_amount}" - shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, dace.dtypes.int64, "dir_addr" - )[0].value - - shifted_index = {dim: value for dim, value in iterator.indices.items()} - shifted_index[shifted_dim] = shifted_value - - return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) - - def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: + def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shift = node.fun assert isinstance(shift, itir.FunCall) tail, rest = self._split_shift_args(shift.args) @@ -699,33 +657,21 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: assert isinstance(offset, str) if isinstance(tail[1], itir.OffsetLiteral): - element = tail[1].value - assert isinstance(element, int) - element_var = unique_var_name() - self.context.body.add_scalar(element_var, dace.dtypes.int64, transient=True) - element_node = self.context.state.add_access(element_var) - tlet_node = self.context.state.add_tasklet( - "get_element", {}, {"__out"}, f"__out = {element}" - ) - self.context.state.add_edge( - tlet_node, "__out", element_node, None, dace.Memlet.simple(element_var, "0") - ) + offset_node = self.visit_OffsetLiteral(tail[1]) else: - element_node = self.visit(tail[1])[0] + offset_node = self.visit(tail[1])[0] if isinstance(self.offset_provider[offset], NeighborTableOffsetProvider): - table = self.offset_provider[offset] - shifted_dim = table.origin_axis.value - target_dim = table.neighbor_axis.value - - conn = self.context.state.add_access(connectivity_identifier(offset)) + offset_provider = self.offset_provider[offset] + connectivity = self.context.state.add_access(connectivity_identifier(offset)) + shifted_dim = offset_provider.origin_axis.value + target_dim = offset_provider.neighbor_axis.value args = [ - ValueExpr(conn, table.table.dtype), + ValueExpr(connectivity, offset_provider.table.dtype), ValueExpr(iterator.indices[shifted_dim], dace.int64), - ValueExpr(element_node, dace.int64), + offset_node, ] - internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{internals[1]}, {internals[2]}]" elif isinstance(self.offset_provider[offset], StridedNeighborOffsetProvider): @@ -733,19 +679,20 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: shifted_dim = offset_provider.origin_axis.value target_dim = offset_provider.neighbor_axis.value - offset_value = iterator.indices[shifted_dim] args = [ - ValueExpr(offset_value, dace.int64), - ValueExpr(element_node, dace.int64), + ValueExpr(iterator.indices[shifted_dim], dace.int64), + offset_node, ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]} * {offset_provider.max_neighbors} + {internals[1]}" else: assert isinstance(self.offset_provider[offset], Dimension) - shifted_dim = target_dim = self.offset_provider[offset].value + + shifted_dim = self.offset_provider[offset].value + target_dim = shifted_dim args = [ ValueExpr(iterator.indices[shifted_dim], dace.int64), - element_node, + offset_node, ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]} + {internals[1]}" @@ -760,6 +707,20 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) + def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> ValueExpr: + offset = node.value + assert isinstance(offset, int) + offset_var = unique_var_name() + self.context.body.add_scalar(offset_var, dace.dtypes.int64, transient=True) + offset_node = self.context.state.add_access(offset_var) + tasklet_node = self.context.state.add_tasklet( + "get_offset", {}, {"__out"}, f"__out = {offset}" + ) + self.context.state.add_edge( + tasklet_node, "__out", offset_node, None, dace.Memlet.simple(offset_var, "0") + ) + return ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype) + def _visit_reduce(self, node: itir.FunCall): result_name = unique_var_name() result_access = self.context.state.add_access(result_name) From 61c538ede8aa0ad8daa43c7cb77d1296c10aaf5d Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 26 Oct 2023 06:40:18 +0200 Subject: [PATCH 3/4] [dace] Minor edit --- .../program_processors/runners/dace_iterator/itir_to_tasklet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 2f32a07ce6..0949c80318 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -698,7 +698,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: expr = f"{internals[0]} + {internals[1]}" shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, dace.dtypes.int64, "ind_addr" + list(zip(args, internals)), expr, dace.dtypes.int64, "shift" )[0].value shifted_index = {dim: value for dim, value in iterator.indices.items()} From 77a2010283f21ee3641ead65c639befeb10e21a1 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Mon, 30 Oct 2023 13:48:44 +0100 Subject: [PATCH 4/4] [dace] Review comments --- .../runners/dace_iterator/itir_to_tasklet.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 0949c80318..1634596afa 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -653,17 +653,13 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: iterator = self.visit(node.args[0]) assert isinstance(tail[0], itir.OffsetLiteral) - offset = tail[0].value - assert isinstance(offset, str) + offset_dim = tail[0].value + assert isinstance(offset_dim, str) + offset_node = self.visit(tail[1])[0] - if isinstance(tail[1], itir.OffsetLiteral): - offset_node = self.visit_OffsetLiteral(tail[1]) - else: - offset_node = self.visit(tail[1])[0] - - if isinstance(self.offset_provider[offset], NeighborTableOffsetProvider): - offset_provider = self.offset_provider[offset] - connectivity = self.context.state.add_access(connectivity_identifier(offset)) + if isinstance(self.offset_provider[offset_dim], NeighborTableOffsetProvider): + offset_provider = self.offset_provider[offset_dim] + connectivity = self.context.state.add_access(connectivity_identifier(offset_dim)) shifted_dim = offset_provider.origin_axis.value target_dim = offset_provider.neighbor_axis.value @@ -674,8 +670,8 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{internals[1]}, {internals[2]}]" - elif isinstance(self.offset_provider[offset], StridedNeighborOffsetProvider): - offset_provider = self.offset_provider[offset] + elif isinstance(self.offset_provider[offset_dim], StridedNeighborOffsetProvider): + offset_provider = self.offset_provider[offset_dim] shifted_dim = offset_provider.origin_axis.value target_dim = offset_provider.neighbor_axis.value @@ -686,9 +682,9 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]} * {offset_provider.max_neighbors} + {internals[1]}" else: - assert isinstance(self.offset_provider[offset], Dimension) + assert isinstance(self.offset_provider[offset_dim], Dimension) - shifted_dim = self.offset_provider[offset].value + shifted_dim = self.offset_provider[offset_dim].value target_dim = shifted_dim args = [ ValueExpr(iterator.indices[shifted_dim], dace.int64), @@ -707,7 +703,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) - def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> ValueExpr: + def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: offset = node.value assert isinstance(offset, int) offset_var = unique_var_name() @@ -719,7 +715,7 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> ValueExpr: self.context.state.add_edge( tasklet_node, "__out", offset_node, None, dace.Memlet.simple(offset_var, "0") ) - return ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype) + return [ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype)] def _visit_reduce(self, node: itir.FunCall): result_name = unique_var_name()