From caebdff196bed04226ec6b00d55c044b9ae6d0e2 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Tue, 7 Nov 2023 08:17:27 +0100 Subject: [PATCH 01/32] Add more debug info to DaCe --- src/gt4py/next/ffront/foast_to_itir.py | 8 ++- src/gt4py/next/ffront/past_to_itir.py | 16 ++++-- src/gt4py/next/iterator/ir.py | 5 +- .../iterator/transforms/inline_fundefs.py | 1 + .../runners/dace_iterator/itir_to_sdfg.py | 18 +++++++ .../runners/dace_iterator/itir_to_tasklet.py | 49 ++++++++++++++----- 6 files changed, 78 insertions(+), 19 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 1902d71b3c..1d0ed0fec0 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -78,6 +78,7 @@ def visit_FunctionDefinition( id=node.id, params=params, expr=self.visit_BlockStmt(node.body, inner_expr=None), + location=node.location, ) # `expr` is a lifted stencil def visit_FieldOperator(self, node: foast.FieldOperator, **kwargs) -> itir.FunctionDefinition: @@ -88,6 +89,7 @@ def visit_FieldOperator(self, node: foast.FieldOperator, **kwargs) -> itir.Funct id=func_definition.id, params=func_definition.params, expr=new_body, + location=node.location, ) def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> itir.FunctionDefinition: @@ -119,6 +121,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> itir.Functio id=node.id, params=definition.params[1:], expr=body, + location=node.location, ) def visit_Stmt(self, node: foast.Stmt, **kwargs): @@ -135,13 +138,14 @@ def visit_BlockStmt( for stmt in reversed(node.stmts): inner_expr = self.visit(stmt, inner_expr=inner_expr, **kwargs) assert inner_expr + inner_expr.location = node.location return inner_expr def visit_IfStmt( self, node: foast.IfStmt, *, inner_expr: Optional[itir.Expr], **kwargs ) -> itir.Expr: # the lowered if call doesn't need to be lifted as the condition can only originate - # from a scalar value (and not a field) + # from a scalar value (and not a field) assert ( isinstance(node.condition.type, ts.ScalarType) and node.condition.type.kind == ts.ScalarKind.BOOL @@ -208,7 +212,7 @@ def visit_Symbol(self, node: foast.Symbol, **kwargs) -> itir.Sym: kind = "Iterator" dtype = node.type.dtype.kind.name.lower() is_list = type_info.is_local_field(node.type) - return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list)) + return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list), location=node.location) return im.sym(node.id) def visit_Name(self, node: foast.Name, **kwargs) -> itir.SymRef: diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 2c5dfc6e2f..857e63d3a7 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -125,6 +125,7 @@ def visit_Program( function_definitions=function_definitions, params=params, closures=closures, + location=node.location, ) def _visit_stencil_call(self, node: past.Call, **kwargs) -> itir.StencilClosure: @@ -151,6 +152,7 @@ def _visit_stencil_call(self, node: past.Call, **kwargs) -> itir.StencilClosure: stencil=itir.SymRef(id=node.func.id), inputs=[*lowered_args, *lowered_kwargs.values()], output=output, + location=node.location, ) def _visit_slice_bound( @@ -186,6 +188,7 @@ def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: return itir.FunCall( fun=itir.SymRef(id="make_tuple"), args=[self._construct_itir_out_arg(el) for el in node.elts], + location=node.location, ) else: raise ValueError( @@ -237,6 +240,7 @@ def _construct_itir_domain_arg( itir.FunCall( fun=itir.SymRef(id="named_range"), args=[itir.AxisLiteral(value=dim.value), lower, upper], + location=out_field.location, ) ) @@ -247,7 +251,7 @@ def _construct_itir_domain_arg( else: raise AssertionError() - return itir.FunCall(fun=itir.SymRef(id=domain_builtin), args=domain_args) + return itir.FunCall(fun=itir.SymRef(id=domain_builtin), args=domain_args, location=out_field.location) def _construct_itir_initialized_domain_arg( self, @@ -343,12 +347,12 @@ def visit_Constant(self, node: past.Constant, **kwargs) -> itir.Literal: f"Scalars of kind {node.type.kind} not supported currently." ) typename = node.type.kind.name.lower() - return itir.Literal(value=str(node.value), type=typename) + return itir.Literal(value=str(node.value), type=typename, location=node.location) raise NotImplementedError("Only scalar literals supported currently.") def visit_Name(self, node: past.Name, **kwargs) -> itir.SymRef: - return itir.SymRef(id=node.id) + return itir.SymRef(id=node.id, location=node.location) def visit_Symbol(self, node: past.Symbol, **kwargs) -> itir.Sym: # TODO(tehrengruber): extend to more types @@ -356,13 +360,14 @@ def visit_Symbol(self, node: past.Symbol, **kwargs) -> itir.Sym: kind = "Iterator" dtype = node.type.dtype.kind.name.lower() is_list = type_info.is_local_field(node.type) - return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list)) - return itir.Sym(id=node.id) + return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list), location=node.location) + return itir.Sym(id=node.id, location=node.location) def visit_BinOp(self, node: past.BinOp, **kwargs) -> itir.FunCall: return itir.FunCall( fun=itir.SymRef(id=node.op.value), args=[self.visit(node.left, **kwargs), self.visit(node.right, **kwargs)], + location=node.location, ) def visit_Call(self, node: past.Call, **kwargs) -> itir.FunCall: @@ -370,6 +375,7 @@ def visit_Call(self, node: past.Call, **kwargs) -> itir.FunCall: return itir.FunCall( fun=itir.SymRef(id=node.func.id), args=[self.visit(node.args[0]), self.visit(node.args[1])], + location=node.location, ) else: raise AssertionError( diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 535648cc47..4df8e8f981 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -17,12 +17,15 @@ import gt4py.eve as eve from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels +from gt4py.eve.concepts import SourceLocation from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.eve.utils import noninstantiable @noninstantiable class Node(eve.Node): + location:Optional[SourceLocation] = None + def __str__(self) -> str: from gt4py.next.iterator.pretty_printer import pformat @@ -61,7 +64,7 @@ def _dtype_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribu @noninstantiable class Expr(Node): - ... + pass class Literal(Expr): diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index 6bf2b60592..57e445a9bc 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -24,6 +24,7 @@ def visit_SymRef(self, node: ir.SymRef, *, symtable: Dict[str, Any]): return ir.Lambda( params=self.generic_visit(symbol.params, symtable=symtable), expr=self.generic_visit(symbol.expr, symtable=symtable), + location=node.location, ) return self.generic_visit(node) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 9e9cc4bf29..92b6e4070e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -134,6 +134,11 @@ def get_output_nodes( def visit_FencilDefinition(self, node: itir.FencilDefinition): program_sdfg = dace.SDFG(name=node.id) + program_sdfg.debuginfo = dace.dtypes.DebugInfo(start_line=node.location.line, + start_column=node.location.column, + end_line=node.location.end_line, + end_column=node.location.end_column, + filename=node.location.filename) last_state = program_sdfg.add_state("program_entry") self.node_types = itir_typing.infer_all(node) @@ -157,6 +162,11 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): closure_sdfg, input_names, output_names = self.visit( closure, array_table=program_sdfg.arrays ) + closure_sdfg.debuginfo = dace.dtypes.DebugInfo(start_line=closure.location.line, + start_column=closure.location.column, + end_line=closure.location.end_line, + end_column=closure.location.end_column, + filename=closure.location.filename) # Create a new state for the closure. last_state = program_sdfg.add_state_after(last_state) @@ -178,6 +188,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): inputs=set(input_names), outputs=set(output_names), symbol_mapping=symbol_mapping, + debuginfo=closure_sdfg.debuginfo, ) # Add access nodes for the program parameters and connect them to the nested SDFG's inputs via edges. @@ -201,6 +212,12 @@ def visit_StencilClosure( closure_sdfg = dace.SDFG(name="closure") closure_state = closure_sdfg.add_state("closure_entry") closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init") + di = dace.dtypes.DebugInfo(start_line=node.location.line, + start_column=node.location.column, + end_line=node.location.end_line, + end_column=node.location.end_column, + filename=node.location.filename) + closure_sdfg.debuginfo = di program_arg_syms: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms) @@ -343,6 +360,7 @@ def visit_StencilClosure( outputs=output_mapping, symbol_mapping=symbol_mapping, output_nodes=output_nodes, + debuginfo=di, ) access_nodes = {edge.data.data: edge.dst for edge in closure_state.out_edges(map_exit)} for edge in closure_state.in_edges(map_exit): diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 5d47cad909..238486e94c 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -23,6 +23,7 @@ from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols import gt4py.eve.codegen +from gt4py.eve.concepts import SourceLocation from gt4py.next import Dimension, StridedNeighborOffsetProvider, type_inference as next_typing from gt4py.next.iterator import ir as itir, type_inference as itir_typing from gt4py.next.iterator.embedded import NeighborTableOffsetProvider @@ -277,7 +278,7 @@ def builtin_if( node_type = transformer.node_types[id(node)] assert isinstance(node_type, itir_typing.Val) type_ = itir_type_as_dace_type(node_type.dtype) - return transformer.add_expr_tasklet(expr_args, expr, type_, "if") + return transformer.add_expr_tasklet(expr_args, expr, type_, "if", location=node.location) def builtin_cast( @@ -291,7 +292,7 @@ def builtin_cast( node_type = transformer.node_types[id(node)] assert isinstance(node_type, itir_typing.Val) type_ = itir_type_as_dace_type(node_type.dtype) - return transformer.add_expr_tasklet(list(zip(args, internals)), expr, type_, "cast") + return transformer.add_expr_tasklet(list(zip(args, internals)), expr, type_, "cast", location=node.location) def builtin_make_tuple( @@ -424,6 +425,7 @@ def visit_Lambda( results: list[ValueExpr] = [] # We are flattening the returned list of value expressions because the multiple outputs of a lamda # should be a list of nodes without tuple structure. Ideally, an ITIR transformation could do this. + node.expr.location = node.location for expr in flatten_list(self.visit(node.expr)): if isinstance(expr, ValueExpr): result_name = unique_var_name() @@ -440,7 +442,7 @@ def visit_Lambda( result = ValueExpr(value=result_access, dtype=expr.dtype) else: # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors - result = self.add_expr_tasklet([], expr.value, expr.dtype, "forward")[0] + result = self.add_expr_tasklet([], expr.value, expr.dtype, "forward", location=node.location)[0] self.context.body.arrays[result.value.data].transient = False results.append(result) @@ -492,7 +494,7 @@ def _visit_call(self, node: itir.FunCall): args = self.visit(node.args) args = [arg if isinstance(arg, Sequence) else [arg] for arg in args] args = list(itertools.chain(*args)) - + node.fun.location = node.location func_context, func_inputs, results = self.visit(node.fun, args=args) nsdfg_inputs = {} @@ -519,12 +521,19 @@ def _visit_call(self, node: itir.FunCall): symbol_mapping = map_nested_sdfg_symbols(self.context.body, func_context.body, nsdfg_inputs) + di = dace.dtypes.DebugInfo(start_line=node.location.line, + start_column=node.location.column, + end_line=node.location.end_line, + end_column=node.location.end_column, + filename=node.location.filename) + nsdfg_node = self.context.state.add_nested_sdfg( func_context.body, None, inputs=set(nsdfg_inputs.keys()), outputs=set(r.value.data for r in results), symbol_mapping=symbol_mapping, + debuginfo=di, ) for name, value in func_inputs: @@ -625,7 +634,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{', '.join(internals[1:])}]" - return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") + return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref", location=node.location) def _split_shift_args( self, args: list[itir.Expr] @@ -691,7 +700,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: expr = f"{internals[0]} + {internals[1]}" shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, dace.dtypes.int64, "shift" + list(zip(args, internals)), expr, dace.dtypes.int64, "shift", location=node.location )[0].value shifted_index = {dim: value for dim, value in iterator.indices.items()} @@ -822,7 +831,7 @@ def _visit_reduce(self, node: itir.FunCall): if not args[i]: args[i] = self.visit(node_arg)[0] - lambda_node = itir.Lambda(expr=fun_node.expr.args[1], params=fun_node.params[1:]) + lambda_node = itir.Lambda(expr=fun_node.expr.args[1], params=fun_node.params[1:], location=node.location) lambda_context, inner_inputs, inner_outputs = self.visit(lambda_node, args=args) # clear context @@ -869,6 +878,9 @@ def _visit_reduce(self, node: itir.FunCall): def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: assert isinstance(node.fun, itir.SymRef) fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] + for arg in node.args: + if hasattr(arg, 'location'): + arg.location = node.location args: list[SymbolExpr | ValueExpr] = list( itertools.chain(*[self.visit(arg) for arg in node.args]) ) @@ -882,7 +894,7 @@ def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: node_type = self.node_types[id(node)] assert isinstance(node_type, itir_typing.Val) type_ = itir_type_as_dace_type(node_type.dtype) - return self.add_expr_tasklet(expr_args, expr, type_, "numeric") + return self.add_expr_tasklet(expr_args, expr, type_, "numeric", location=node.location) def _visit_general_builtin(self, node: itir.FunCall) -> list[ValueExpr]: assert isinstance(node.fun, itir.SymRef) @@ -890,17 +902,26 @@ def _visit_general_builtin(self, node: itir.FunCall) -> list[ValueExpr]: return expr_func(self, node, node.args) def add_expr_tasklet( - self, args: list[tuple[ValueExpr, str]], expr: str, result_type: Any, name: str + self, args: list[tuple[ValueExpr, str]], expr: str, result_type: Any, name: str, location: SourceLocation = None ) -> list[ValueExpr]: result_name = unique_var_name() self.context.body.add_scalar(result_name, result_type, transient=True) result_access = self.context.state.add_access(result_name) + di = None + if location: + di = dace.dtypes.DebugInfo(start_line=location.line, + start_column=location.column, + end_line=location.end_line, + end_column=location.end_column, + filename=location.filename) + expr_tasklet = self.context.state.add_tasklet( name=name, inputs={internal for _, internal in args}, outputs={"__result"}, code=f"__result = {expr}", + debuginfo=di ) for arg, internal in args: @@ -946,7 +967,7 @@ def _visit_closure_callable( input_names: Sequence[str], ) -> Sequence[ValueExpr]: args = [itir.SymRef(id=name) for name in input_names] - fun_node = itir.FunCall(fun=node.stencil, args=args) + fun_node = itir.FunCall(fun=node.stencil, args=args, location=node.location) return tlet_codegen.visit(fun_node) @@ -963,11 +984,17 @@ def closure_to_tasklet_sdfg( state = body.add_state("tasklet_toplevel_entry") symbol_map: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} + di = dace.dtypes.DebugInfo(start_line=node.location.line, + start_column=node.location.column, + end_line=node.location.end_line, + end_column=node.location.end_column, + filename=node.location.filename) + idx_accesses = {} for dim, idx in domain.items(): name = f"{idx}_value" body.add_scalar(name, dtype=dace.int64, transient=True) - tasklet = state.add_tasklet(f"get_{dim}", set(), {"value"}, f"value = {idx}") + tasklet = state.add_tasklet(f"get_{dim}", set(), {"value"}, f"value = {idx}", debuginfo=di) access = state.add_access(name) idx_accesses[dim] = access state.add_edge(tasklet, "value", access, None, dace.Memlet.simple(name, "0")) From bebb122411eaacb2deea79c94af504040d9cc7ca Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Mon, 27 Nov 2023 15:53:51 +0100 Subject: [PATCH 02/32] Add more debug info to DaCe --- src/gt4py/next/ffront/foast_to_itir.py | 87 ++++++++++++++++++-------- 1 file changed, 62 insertions(+), 25 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 8242e917af..f83b4bf67f 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -112,7 +112,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> itir.Functio func_definition.params[0].id, im.promote_to_const_iterator(func_definition.params[0].id), )(im.deref(new_body)) - definition = itir.Lambda(params=func_definition.params, expr=new_body) + definition = itir.Lambda(params=func_definition.params, expr=new_body, location=node.location) body = im.call(im.call("scan")(definition, forward, init))( *(param.id for param in definition.params[1:]) ) @@ -170,9 +170,11 @@ def visit_IfStmt( inner_expr = im.let(sym, im.tuple_get(i, im.ref("__if_stmt_result")))(inner_expr) # here we assume neither branch returns - return im.let("__if_stmt_result", im.if_(im.deref(cond), true_branch, false_branch))( + return_ = im.let("__if_stmt_result", im.if_(im.deref(cond), true_branch, false_branch))( inner_expr ) + return_.location = node.location + return return_ elif return_kind is StmtReturnKind.CONDITIONAL_RETURN: common_syms = tuple(im.sym(sym) for sym in common_symbols.keys()) common_symrefs = tuple(im.ref(sym) for sym in common_symbols.keys()) @@ -186,9 +188,11 @@ def visit_IfStmt( true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - return im.let(inner_expr_name, inner_expr_evaluator)( + return_ = im.let(inner_expr_name, inner_expr_evaluator)( im.if_(im.deref(cond), true_branch, false_branch) ) + return_.location = node.location + return return_ assert return_kind is StmtReturnKind.UNCONDITIONAL_RETURN @@ -197,14 +201,18 @@ def visit_IfStmt( true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - return im.if_(im.deref(cond), true_branch, false_branch) + return_ = im.if_(im.deref(cond), true_branch, false_branch) + return_.location = node.location + return return_ def visit_Assign( self, node: foast.Assign, *, inner_expr: Optional[itir.Expr], **kwargs ) -> itir.Expr: - return im.let(self.visit(node.target, **kwargs), self.visit(node.value, **kwargs))( + return_ = im.let(self.visit(node.target, **kwargs), self.visit(node.value, **kwargs))( inner_expr ) + return_.location = node.location + return return_ def visit_Symbol(self, node: foast.Symbol, **kwargs) -> itir.Sym: # TODO(tehrengruber): extend to more types @@ -213,20 +221,26 @@ def visit_Symbol(self, node: foast.Symbol, **kwargs) -> itir.Sym: dtype = node.type.dtype.kind.name.lower() is_list = type_info.is_local_field(node.type) return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list), location=node.location) - return im.sym(node.id) + return_ = im.sym(node.id) + return_.location = return_ + return return_ def visit_Name(self, node: foast.Name, **kwargs) -> itir.SymRef: return im.ref(node.id) def visit_Subscript(self, node: foast.Subscript, **kwargs) -> itir.Expr: - return im.promote_to_lifted_stencil(lambda tuple_: im.tuple_get(node.index, tuple_))( + return_ = im.promote_to_lifted_stencil(lambda tuple_: im.tuple_get(node.index, tuple_))( self.visit(node.value, **kwargs) ) + return_.location = node.location + return return_ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs) -> itir.Expr: - return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( + return_ = im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( *[self.visit(el, **kwargs) for el in node.elts], ) + return_.location = node.location + return return_ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators @@ -234,29 +248,32 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> itir.Expr: if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: if dtype.kind != ts.ScalarKind.BOOL: raise NotImplementedError(f"{node.op} is only supported on `bool`s.") - return self._map("not_", node.operand) + return self._map("not_", node.operand, location=node.location) return self._map( node.op.value, foast.Constant(value="0", type=dtype, location=node.location), node.operand, + location=node.location, ) def visit_BinOp(self, node: foast.BinOp, **kwargs) -> itir.FunCall: - return self._map(node.op.value, node.left, node.right) + return self._map(node.op.value, node.left, node.right, location=node.location) def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs) -> itir.FunCall: - return self._map("if_", node.condition, node.true_expr, node.false_expr) + return self._map("if_", node.condition, node.true_expr, node.false_expr, location=node.location) def visit_Compare(self, node: foast.Compare, **kwargs) -> itir.FunCall: - return self._map(node.op.value, node.left, node.right) + return self._map(node.op.value, node.left, node.right, location=node.location) def _visit_shift(self, node: foast.Call, **kwargs) -> itir.Expr: match node.args[0]: case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)): shift_offset = im.shift(offset_name, offset_index) case foast.Name(id=offset_name): - return im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs)) + return_ = im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs)) + return_.location = node.location + return return_ case foast.Call(func=foast.Name(id="as_offset")): func_args = node.args[0] offset_dim = func_args.args[0] @@ -266,9 +283,11 @@ def _visit_shift(self, node: foast.Call, **kwargs) -> itir.Expr: ) case _: raise FieldOperatorLoweringError("Unexpected shift arguments!") - return im.lift(im.lambda_("it")(im.deref(shift_offset("it"))))( + return_ = im.lift(im.lambda_("it")(im.deref(shift_offset("it"))))( self.visit(node.func, **kwargs) ) + return_.location = node.location + return return_ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: if type_info.type_class(node.func.type) is ts.FieldType: @@ -299,11 +318,13 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: ) call_args = [f"__arg{i}" for i in range(len(lowered_args))] call_kwargs = [f"__kwarg_{name}" for name in lowered_kwargs.keys()] - return im.lift( + return_ = im.lift( im.lambda_(*call_args, *call_kwargs)( im.call(lowered_func)(*call_args, *call_kwargs) ) )(*lowered_args, *lowered_kwargs.values()) + return_.location = node.location + return return_ elif isinstance(node.func.type, ts.FunctionType): # ITIR has no support for keyword arguments. Instead, we concatenate both positional # and keyword arguments and use the unique order as given in the function signature. @@ -313,7 +334,9 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: self.visit(node.kwargs, **kwargs), use_signature_ordering=True, ) - return im.call(self.visit(node.func, **kwargs))(*lowered_args, *lowered_kwargs.values()) + return_ = im.call(self.visit(node.func, **kwargs))(*lowered_args, *lowered_kwargs.values()) + return_.location = node.location + return return_ raise AssertionError( f"Call to object of type {type(node.func.type).__name__} not understood." @@ -322,18 +345,20 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) obj, new_type = node.args[0], node.args[1].id - return self._process_elements( + return_ = self._process_elements( lambda x: im.call("cast_")(x, str(new_type)), obj, obj.type, **kwargs ) + return_.location = node.location + return return_ def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall: - return self._map("if_", *node.args) + return self._map("if_", *node.args, location=node.location) def _visit_broadcast(self, node: foast.Call, **kwargs) -> itir.FunCall: return self.visit(node.args[0], **kwargs) def _visit_math_built_in(self, node: foast.Call, **kwargs) -> itir.FunCall: - return self._map(self.visit(node.func, **kwargs), *node.args) + return self._map(self.visit(node.func, **kwargs), *node.args, location=node.location) def _make_reduction_expr( self, @@ -346,7 +371,9 @@ def _make_reduction_expr( it = self.visit(node.args[0], **kwargs) assert isinstance(node.kwargs["axis"].type, ts.DimensionType) val = im.call(im.call("reduce")(op, im.deref(init_expr))) - return im.promote_to_lifted_stencil(val)(it) + return_ = im.promote_to_lifted_stencil(val)(it) + return_.location = node.location + return return_ def _visit_neighbor_sum(self, node: foast.Call, **kwargs) -> itir.FunCall: dtype = type_info.extract_dtype(node.type) @@ -370,10 +397,16 @@ def _visit_type_constr(self, node: foast.Call, **kwargs) -> itir.Expr: target_type = fbuiltins.BUILTINS[node_kind] source_type = {**fbuiltins.BUILTINS, "string": str}[node.args[0].type.__str__().lower()] if target_type is bool and source_type is not bool: - return im.promote_to_const_iterator( + return_ = im.promote_to_const_iterator( im.literal(str(bool(source_type(node.args[0].value))), "bool") ) - return im.promote_to_const_iterator(im.literal(str(node.args[0].value), node_kind)) + return_.location = node.location + return return_ + return_ = im.promote_to_const_iterator( + im.literal(str(bool(source_type(node.args[0].value))), "bool") + ) + return_.location = node.location + return return_ raise FieldOperatorLoweringError(f"Encountered a type cast, which is not supported: {node}") def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: @@ -394,15 +427,19 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: raise ValueError(f"Unsupported literal type {type_}.") def visit_Constant(self, node: foast.Constant, **kwargs) -> itir.Expr: - return self._make_literal(node.value, node.type) + return_ = self._make_literal(node.value, node.type) + return_.location = node.location + return return_ - def _map(self, op, *args, **kwargs): + def _map(self, op, *args, location=None, **kwargs): lowered_args = [self.visit(arg, **kwargs) for arg in args] if any(type_info.contains_local_field(arg.type) for arg in args): lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] op = im.call("map_")(op) - return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) + return_ = im.promote_to_lifted_stencil(im.call(op))(*lowered_args) + return_.location = location + return return_ def _process_elements( self, From 7eeaddb5fa37d7a0c55a759100d2ff978f820198 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Mon, 27 Nov 2023 17:10:38 +0100 Subject: [PATCH 03/32] Add more debug info to DaCe : WIP --- src/gt4py/next/ffront/past_to_itir.py | 2 +- src/gt4py/next/iterator/ir.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 857e63d3a7..24b9e5c0a7 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -181,7 +181,7 @@ def _visit_slice_bound( def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: if isinstance(node, past.Name): - return itir.SymRef(id=node.id) + return itir.SymRef(id=node.id, location=node.location) elif isinstance(node, past.Subscript): return self._construct_itir_out_arg(node.value) elif isinstance(node, past.TupleExpr): diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 4df8e8f981..4e92cca57f 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -64,7 +64,7 @@ def _dtype_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribu @noninstantiable class Expr(Node): - pass + ... class Literal(Expr): From 744780758b4d65808ec8fc3c2ab978617736d9d2 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Wed, 29 Nov 2023 13:20:42 +0100 Subject: [PATCH 04/32] Add more debug info to DaCe : WIP --- .../next/iterator/transforms/collapse_list_get.py | 2 ++ .../next/iterator/transforms/collapse_tuple.py | 9 ++------- .../next/iterator/transforms/constant_folding.py | 1 + src/gt4py/next/iterator/transforms/cse.py | 1 + .../next/iterator/transforms/eta_reduction.py | 1 + src/gt4py/next/iterator/transforms/fuse_maps.py | 4 ++++ src/gt4py/next/iterator/transforms/global_tmps.py | 7 +++++++ .../next/iterator/transforms/inline_into_scan.py | 2 +- .../next/iterator/transforms/inline_lambdas.py | 2 ++ .../next/iterator/transforms/inline_lifts.py | 9 ++++++--- src/gt4py/next/iterator/transforms/merge_let.py | 1 + .../next/iterator/transforms/normalize_shifts.py | 1 + .../next/iterator/transforms/propagate_deref.py | 1 + .../iterator/transforms/prune_closure_inputs.py | 1 + .../next/iterator/transforms/remap_symbols.py | 5 +++-- .../iterator/transforms/scan_eta_reduction.py | 2 +- .../next/iterator/transforms/unroll_reduce.py | 3 ++- .../runners/dace_iterator/itir_to_sdfg.py | 8 +++++--- .../runners/dace_iterator/itir_to_tasklet.py | 15 ++++++++------- .../runners/dace_iterator/utility.py | 4 ++-- 20 files changed, 52 insertions(+), 27 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_list_get.py b/src/gt4py/next/iterator/transforms/collapse_list_get.py index 08cbd7313e..4d35568b4d 100644 --- a/src/gt4py/next/iterator/transforms/collapse_list_get.py +++ b/src/gt4py/next/iterator/transforms/collapse_list_get.py @@ -49,8 +49,10 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: args=[it], ) ], + location=node.location, ) if node.args[1].fun == ir.SymRef(id="make_const_list"): + node.args[1].args[0].location = node.location return node.args[1].args[0] return node diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 7d710fc919..393f781276 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -85,13 +85,6 @@ def apply( node_types, ).visit(node) - return cls( - ignore_tuple_size, - collapse_make_tuple_tuple_get, - collapse_tuple_get_make_tuple, - use_global_type_inference, - ).visit(node) - def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: if ( self.collapse_make_tuple_tuple_get @@ -115,6 +108,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: if self.ignore_tuple_size or _get_tuple_size(first_expr, self._node_types) == len( node.args ): + first_expr.location = node.location return first_expr if ( self.collapse_tuple_get_make_tuple @@ -130,5 +124,6 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: assert idx < len( make_tuple_call.args ), f"Index {idx} is out of bounds for tuple of size {len(make_tuple_call.args)}" + node.args[1].args[idx].location = node.location return node.args[1].args[idx] return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index cda422f30d..94623f6e3a 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -46,4 +46,5 @@ def visit_FunCall(self, node: ir.FunCall): arg_values = [getattr(embedded, str(arg.type))(arg.value) for arg in new_node.args] # type: ignore[attr-defined] # arg type already established in if condition new_node = im.literal_from_value(fun(*arg_values)) + new_node.location = node.location return new_node diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 672e23c5e7..4d676a022b 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -379,6 +379,7 @@ def visit_FunCall(self, node: ir.FunCall): result = ir.FunCall( fun=ir.Lambda(params=list(extracted.keys()), expr=new_expr), args=list(extracted.values()), + location=node.location, ) # if the node id is ignored (because its parent is eliminated), but it occurs diff --git a/src/gt4py/next/iterator/transforms/eta_reduction.py b/src/gt4py/next/iterator/transforms/eta_reduction.py index 55b2141499..99043490ba 100644 --- a/src/gt4py/next/iterator/transforms/eta_reduction.py +++ b/src/gt4py/next/iterator/transforms/eta_reduction.py @@ -28,6 +28,7 @@ def visit_Lambda(self, node: ir.Lambda) -> ir.Node: for p, a in zip(node.params, node.expr.args) ) ): + node.expr.fun.location = node.location return self.visit(node.expr.fun) return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/fuse_maps.py b/src/gt4py/next/iterator/transforms/fuse_maps.py index e9fbb0f81d..ea51adeca3 100644 --- a/src/gt4py/next/iterator/transforms/fuse_maps.py +++ b/src/gt4py/next/iterator/transforms/fuse_maps.py @@ -66,6 +66,7 @@ def _as_lambda(self, fun: ir.SymRef | ir.Lambda, param_count: int) -> ir.Lambda: return ir.Lambda( params=params, expr=ir.FunCall(fun=fun, args=[ir.SymRef(id=p.id) for p in params]), + location=fun.location, ) def visit_FunCall(self, node: ir.FunCall, **kwargs): @@ -99,6 +100,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): ir.FunCall( fun=inner_op, args=[ir.SymRef(id=param.id) for param in inner_op.params], + location=node.location, ) ) ) @@ -123,10 +125,12 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): return ir.FunCall( fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args, + location=node.location, ) else: # _is_reduce(node) return ir.FunCall( fun=ir.FunCall(fun=ir.SymRef(id="reduce"), args=[new_op, node.fun.args[1]]), args=new_args, + location=node.location, ) return node diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index e1b697e0bc..059e8bd542 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -265,6 +265,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp stencil=stencil, output=im.ref(tmp_sym.id), inputs=[closure_param_arg_mapping[param.id] for param in lift_expr.args], # type: ignore[attr-defined] + location=closure.location, ) ) @@ -292,6 +293,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp output=current_closure.output, inputs=current_closure.inputs + [ir.SymRef(id=sym.id) for sym in extracted_lifts.keys()], + location=closure.location, ) ) else: @@ -308,6 +310,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp ), params=node.params, tmps=[Temporary(id=tmp.id) for tmp in tmps], + location=node.location, ) @@ -334,6 +337,7 @@ def prune_unused_temporaries(node: FencilWithTemporaries) -> FencilWithTemporari ), params=node.params, tmps=[tmp for tmp in node.tmps if tmp.id not in unused_tmps], + location=node.location, ) @@ -451,6 +455,7 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An stencil=closure.stencil, output=closure.output, inputs=closure.inputs, + location=closure.location, ) else: domain = closure.domain @@ -506,6 +511,7 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An ), params=node.params, tmps=node.tmps, + location=node.location, ) @@ -556,6 +562,7 @@ def convert_type(dtype): tmps=[ Temporary(id=tmp.id, domain=domains[tmp.id], dtype=types[tmp.id]) for tmp in node.tmps ], + location=node.location, ) diff --git a/src/gt4py/next/iterator/transforms/inline_into_scan.py b/src/gt4py/next/iterator/transforms/inline_into_scan.py index fe1eae6e07..198cd02665 100644 --- a/src/gt4py/next/iterator/transforms/inline_into_scan.py +++ b/src/gt4py/next/iterator/transforms/inline_into_scan.py @@ -100,6 +100,6 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): new_scan = ir.FunCall( fun=ir.SymRef(id="scan"), args=[new_scanpass, *original_scan_call.args[1:]] ) - result = ir.FunCall(fun=new_scan, args=[ir.SymRef(id=ref) for ref in refs_in_args]) + result = ir.FunCall(fun=new_scan, args=[ir.SymRef(id=ref) for ref in refs_in_args], location=node.location) return result return self.generic_visit(node, **kwargs) diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index fc268f85e3..3afaf04294 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -97,6 +97,7 @@ def new_name(name): new_expr = RemapSymbolRefs().visit(expr, symbol_map=symbol_map) if all(eligible_params): + new_expr.location = node.location return new_expr else: return ir.FunCall( @@ -109,6 +110,7 @@ def new_name(name): expr=new_expr, ), args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], + location=node.location, ) diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index 8d62450e67..dfda6ea79d 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -154,6 +154,7 @@ def visit_FunCall( ir.FunCall( fun=self.generic_visit(node.fun, is_scan_pass_context=_is_scan(node), **kwargs), args=self.generic_visit(node.args, **kwargs), + location=node.location, ) if recurse else node @@ -167,7 +168,7 @@ def visit_FunCall( self.visit(ir.FunCall(fun=shift, args=[arg]), recurse=False, **kwargs) for arg in lift_call.args # type: ignore[attr-defined] # lift_call already asserted to be of type ir.FunCall ] - result = ir.FunCall(fun=lift_call.fun, args=new_args) # type: ignore[attr-defined] # lift_call already asserted to be of type ir.FunCall + result = ir.FunCall(fun=lift_call.fun, args=new_args, location=node.location) # type: ignore[attr-defined] # lift_call already asserted to be of type ir.FunCall return self.visit(result, recurse=False, **kwargs) elif self.flags & self.Flag.INLINE_DEREF_LIFT and node.fun == ir.SymRef(id="deref"): assert len(node.args) == 1 @@ -184,7 +185,7 @@ def visit_FunCall( assert len(node.args[0].fun.args) == 1 f = node.args[0].fun.args[0] args = node.args[0].args - new_node = ir.FunCall(fun=f, args=args) + new_node = ir.FunCall(fun=f, args=args, location=node.location) if isinstance(f, ir.Lambda): new_node = inline_lambda(new_node, opcount_preserving=True) return self.visit(new_node, **kwargs) @@ -199,13 +200,14 @@ def visit_FunCall( assert len(node.args[0].fun.args) == 1 args = node.args[0].args if len(args) == 0: - return ir.Literal(value="True", type="bool") + return ir.Literal(value="True", type="bool", location=node.location) res = ir.FunCall(fun=ir.SymRef(id="can_deref"), args=[args[0]]) for arg in args[1:]: res = ir.FunCall( fun=ir.SymRef(id="and_"), args=[res, ir.FunCall(fun=ir.SymRef(id="can_deref"), args=[arg])], + location=node.location, ) return res elif ( @@ -253,6 +255,7 @@ def visit_FunCall( ) new_stencil = im.lambda_(*new_arg_exprs.keys())(inlined_call) + new_stencil.location = node.location return im.lift(new_stencil)(*new_arg_exprs.values()) return node diff --git a/src/gt4py/next/iterator/transforms/merge_let.py b/src/gt4py/next/iterator/transforms/merge_let.py index 7426617ac8..b669b8d609 100644 --- a/src/gt4py/next/iterator/transforms/merge_let.py +++ b/src/gt4py/next/iterator/transforms/merge_let.py @@ -64,5 +64,6 @@ def visit_FunCall(self, node: itir.FunCall): params=outer_lambda.params + inner_lambda.params, expr=inner_lambda.expr ), args=outer_lambda_args + inner_lambda_args, + location=node.location, ) return node diff --git a/src/gt4py/next/iterator/transforms/normalize_shifts.py b/src/gt4py/next/iterator/transforms/normalize_shifts.py index efc9064612..6c63fe9c33 100644 --- a/src/gt4py/next/iterator/transforms/normalize_shifts.py +++ b/src/gt4py/next/iterator/transforms/normalize_shifts.py @@ -36,5 +36,6 @@ def visit_FunCall(self, node: ir.FunCall): fun=ir.SymRef(id="shift"), args=node.args[0].fun.args + node.fun.args ), args=node.args[0].args, + location=node.location, ) return node diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index 54bdafcda8..9384a692e8 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -55,5 +55,6 @@ def visit_FunCall(self, node: ir.FunCall): expr=ir.FunCall(fun=builtin, args=[lambda_fun.expr]), ), args=lambda_args, + location=node.location, ) return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py index 7fd3c50c6e..3f39c44183 100644 --- a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py +++ b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py @@ -37,6 +37,7 @@ def visit_StencilClosure(self, node: ir.StencilClosure) -> ir.StencilClosure: stencil=ir.Lambda(params=params, expr=expr), output=node.output, inputs=inputs, + location=node.location, ) def visit_SymRef(self, node: ir.SymRef, *, unused: set[str], shadowed: set[str]) -> ir.SymRef: diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index cdf3d76173..6b1eb3af41 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -30,6 +30,7 @@ def visit_Lambda(self, node: ir.Lambda, *, symbol_map: Dict[str, ir.Node]): return ir.Lambda( params=node.params, expr=self.visit(node.expr, symbol_map=new_symbol_map), + location=node.location, ) def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override] @@ -46,14 +47,14 @@ def visit_Sym( self, node: ir.Sym, *, name_map: Dict[str, str], active: Optional[Set[str]] = None ): if active and node.id in active: - return ir.Sym(id=name_map.get(node.id, node.id)) + return ir.Sym(id=name_map.get(node.id, node.id), location=node.location,) return node def visit_SymRef( self, node: ir.SymRef, *, name_map: Dict[str, str], active: Optional[Set[str]] = None ): if active and node.id in active: - return ir.SymRef(id=name_map.get(node.id, node.id)) + return ir.SymRef(id=name_map.get(node.id, node.id), location=node.location,) return node def generic_visit( # type: ignore[override] diff --git a/src/gt4py/next/iterator/transforms/scan_eta_reduction.py b/src/gt4py/next/iterator/transforms/scan_eta_reduction.py index 3266c25c4b..96fe0e0210 100644 --- a/src/gt4py/next/iterator/transforms/scan_eta_reduction.py +++ b/src/gt4py/next/iterator/transforms/scan_eta_reduction.py @@ -56,7 +56,7 @@ def visit_Lambda(self, node: ir.Lambda) -> ir.Node: ] new_scanpass = ir.Lambda(params=new_scanpass_params, expr=original_scanpass.expr) result = ir.FunCall( - fun=ir.SymRef(id="scan"), args=[new_scanpass, *node.expr.fun.args[1:]] + fun=ir.SymRef(id="scan"), args=[new_scanpass, *node.expr.fun.args[1:]], location=node.location, ) return result diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index e3084eaba5..16cc2a2ffd 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -159,7 +159,8 @@ def _visit_reduce(self, node: itir.FunCall, **kwargs) -> itir.Expr: for i in range(max_neighbors): expr = itir.FunCall(fun=step, args=[expr, itir.OffsetLiteral(value=i)]) expr = itir.FunCall( - fun=itir.Lambda(params=[itir.Sym(id=step.id)], expr=expr), args=[step_fun] + fun=itir.Lambda(params=[itir.Sym(id=step.id)], expr=expr), args=[step_fun], + location=node.location, ) return expr diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index f5f3344877..c84448a26d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -202,12 +202,14 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): ) # Add access nodes for the program parameters and connect them to the nested SDFG's inputs via edges. - for inner_name, memlet in input_mapping.items(): - access_node = last_state.add_access(inner_name) + for i, (inner_name, memlet) in enumerate(input_mapping.items()): + anode_loc = closure.inputs[i].location + access_node = last_state.add_access(inner_name, dace.dtypes.DebugInfo(anode_loc.line, anode_loc.column, anode_loc.end_line, anode_loc.end_column, anode_loc.filename)) last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet) + anode_loc = closure.output.location for inner_name, memlet in output_mapping.items(): - access_node = last_state.add_access(inner_name) + access_node = last_state.add_access(inner_name, dace.dtypes.DebugInfo(anode_loc.line, anode_loc.column, anode_loc.end_line, anode_loc.end_column, anode_loc.filename)) last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) program_sdfg.validate() diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 2d3018bd2a..006268d298 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -285,7 +285,7 @@ def builtin_can_deref( # TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution return transformer.add_expr_tasklet( - list(zip(args, internals)), expr_code, dace.dtypes.bool, "can_deref" + list(zip(args, internals)), expr_code, dace.dtypes.bool, "can_deref", location=node.location ) @@ -935,10 +935,10 @@ def add_expr_tasklet( di = None if location: di = dace.dtypes.DebugInfo(start_line=location.line, - start_column=location.column, - end_line=location.end_line, - end_column=location.end_column, - filename=location.filename) + start_column=location.column, + end_line=location.end_line, + end_column=location.end_column, + filename=location.filename) expr_tasklet = self.context.state.add_tasklet( name=name, @@ -1022,7 +1022,7 @@ def closure_to_tasklet_sdfg( access = state.add_access(name) idx_accesses[dim] = access state.add_edge(tasklet, "value", access, None, dace.Memlet.simple(name, "0")) - for name, ty in inputs: + for i, (name, ty) in enumerate(inputs): if isinstance(ty, ts.FieldType): ndim = len(ty.dims) shape = [ @@ -1034,7 +1034,8 @@ def closure_to_tasklet_sdfg( dims = [dim.value for dim in ty.dims] dtype = as_dace_type(ty.dtype) body.add_array(name, shape=shape, strides=stride, dtype=dtype) - field = state.add_access(name) + anode_loc = node.inputs[i].location + field = state.add_access(name, dace.dtypes.DebugInfo(anode_loc.line, anode_loc.column, anode_loc.end_line, anode_loc.end_column, anode_loc.filename)) indices = {dim: idx_accesses[dim] for dim in domain.keys()} symbol_map[name] = IteratorExpr(field, indices, dtype, dims) else: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index c17a39ef2d..40be17c170 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -119,11 +119,11 @@ def add_mapped_nested_sdfg( if input_nodes is None: input_nodes = { - memlet.data: state.add_access(memlet.data) for name, memlet in inputs.items() + memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) for name, memlet in inputs.items() } if output_nodes is None: output_nodes = { - memlet.data: state.add_access(memlet.data) for name, memlet in outputs.items() + memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) for name, memlet in outputs.items() } if not inputs: state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet()) From 6d89149dd49839d51fcc3ff6af0312026c609cfb Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Wed, 29 Nov 2023 14:46:51 +0100 Subject: [PATCH 05/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): WIP --- src/gt4py/next/ffront/foast_to_itir.py | 16 ++-- src/gt4py/next/ffront/past_to_itir.py | 6 +- src/gt4py/next/iterator/ir.py | 2 +- .../iterator/transforms/inline_into_scan.py | 6 +- .../next/iterator/transforms/remap_symbols.py | 10 ++- .../iterator/transforms/scan_eta_reduction.py | 4 +- .../next/iterator/transforms/unroll_reduce.py | 3 +- .../runners/dace_iterator/itir_to_sdfg.py | 64 ++++++++++----- .../runners/dace_iterator/itir_to_tasklet.py | 78 +++++++++++++------ .../runners/dace_iterator/utility.py | 6 +- 10 files changed, 138 insertions(+), 57 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index f83b4bf67f..a435d3dcb6 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -112,7 +112,9 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> itir.Functio func_definition.params[0].id, im.promote_to_const_iterator(func_definition.params[0].id), )(im.deref(new_body)) - definition = itir.Lambda(params=func_definition.params, expr=new_body, location=node.location) + definition = itir.Lambda( + params=func_definition.params, expr=new_body, location=node.location + ) body = im.call(im.call("scan")(definition, forward, init))( *(param.id for param in definition.params[1:]) ) @@ -222,7 +224,7 @@ def visit_Symbol(self, node: foast.Symbol, **kwargs) -> itir.Sym: is_list = type_info.is_local_field(node.type) return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list), location=node.location) return_ = im.sym(node.id) - return_.location = return_ + return_.location = node.location return return_ def visit_Name(self, node: foast.Name, **kwargs) -> itir.SymRef: @@ -261,7 +263,9 @@ def visit_BinOp(self, node: foast.BinOp, **kwargs) -> itir.FunCall: return self._map(node.op.value, node.left, node.right, location=node.location) def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs) -> itir.FunCall: - return self._map("if_", node.condition, node.true_expr, node.false_expr, location=node.location) + return self._map( + "if_", node.condition, node.true_expr, node.false_expr, location=node.location + ) def visit_Compare(self, node: foast.Compare, **kwargs) -> itir.FunCall: return self._map(node.op.value, node.left, node.right, location=node.location) @@ -334,7 +338,9 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: self.visit(node.kwargs, **kwargs), use_signature_ordering=True, ) - return_ = im.call(self.visit(node.func, **kwargs))(*lowered_args, *lowered_kwargs.values()) + return_ = im.call(self.visit(node.func, **kwargs))( + *lowered_args, *lowered_kwargs.values() + ) return_.location = node.location return return_ @@ -427,7 +433,7 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: raise ValueError(f"Unsupported literal type {type_}.") def visit_Constant(self, node: foast.Constant, **kwargs) -> itir.Expr: - return_ = self._make_literal(node.value, node.type) + return_ = self._make_literal(node.value, node.type) return_.location = node.location return return_ diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 24b9e5c0a7..6cdfcec6ec 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -240,7 +240,7 @@ def _construct_itir_domain_arg( itir.FunCall( fun=itir.SymRef(id="named_range"), args=[itir.AxisLiteral(value=dim.value), lower, upper], - location=out_field.location, + location=out_field.location, ) ) @@ -251,7 +251,9 @@ def _construct_itir_domain_arg( else: raise AssertionError() - return itir.FunCall(fun=itir.SymRef(id=domain_builtin), args=domain_args, location=out_field.location) + return itir.FunCall( + fun=itir.SymRef(id=domain_builtin), args=domain_args, location=out_field.location + ) def _construct_itir_initialized_domain_arg( self, diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 4e92cca57f..3411dfe6be 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -24,7 +24,7 @@ @noninstantiable class Node(eve.Node): - location:Optional[SourceLocation] = None + location: Optional[SourceLocation] = None def __str__(self) -> str: from gt4py.next.iterator.pretty_printer import pformat diff --git a/src/gt4py/next/iterator/transforms/inline_into_scan.py b/src/gt4py/next/iterator/transforms/inline_into_scan.py index 198cd02665..b0f8d98bd6 100644 --- a/src/gt4py/next/iterator/transforms/inline_into_scan.py +++ b/src/gt4py/next/iterator/transforms/inline_into_scan.py @@ -100,6 +100,10 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): new_scan = ir.FunCall( fun=ir.SymRef(id="scan"), args=[new_scanpass, *original_scan_call.args[1:]] ) - result = ir.FunCall(fun=new_scan, args=[ir.SymRef(id=ref) for ref in refs_in_args], location=node.location) + result = ir.FunCall( + fun=new_scan, + args=[ir.SymRef(id=ref) for ref in refs_in_args], + location=node.location, + ) return result return self.generic_visit(node, **kwargs) diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index 6b1eb3af41..84a57ee2e2 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -47,14 +47,20 @@ def visit_Sym( self, node: ir.Sym, *, name_map: Dict[str, str], active: Optional[Set[str]] = None ): if active and node.id in active: - return ir.Sym(id=name_map.get(node.id, node.id), location=node.location,) + return ir.Sym( + id=name_map.get(node.id, node.id), + location=node.location, + ) return node def visit_SymRef( self, node: ir.SymRef, *, name_map: Dict[str, str], active: Optional[Set[str]] = None ): if active and node.id in active: - return ir.SymRef(id=name_map.get(node.id, node.id), location=node.location,) + return ir.SymRef( + id=name_map.get(node.id, node.id), + location=node.location, + ) return node def generic_visit( # type: ignore[override] diff --git a/src/gt4py/next/iterator/transforms/scan_eta_reduction.py b/src/gt4py/next/iterator/transforms/scan_eta_reduction.py index 96fe0e0210..466b2817be 100644 --- a/src/gt4py/next/iterator/transforms/scan_eta_reduction.py +++ b/src/gt4py/next/iterator/transforms/scan_eta_reduction.py @@ -56,7 +56,9 @@ def visit_Lambda(self, node: ir.Lambda) -> ir.Node: ] new_scanpass = ir.Lambda(params=new_scanpass_params, expr=original_scanpass.expr) result = ir.FunCall( - fun=ir.SymRef(id="scan"), args=[new_scanpass, *node.expr.fun.args[1:]], location=node.location, + fun=ir.SymRef(id="scan"), + args=[new_scanpass, *node.expr.fun.args[1:]], + location=node.location, ) return result diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 16cc2a2ffd..143727708c 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -159,7 +159,8 @@ def _visit_reduce(self, node: itir.FunCall, **kwargs) -> itir.Expr: for i in range(max_neighbors): expr = itir.FunCall(fun=step, args=[expr, itir.OffsetLiteral(value=i)]) expr = itir.FunCall( - fun=itir.Lambda(params=[itir.Sym(id=step.id)], expr=expr), args=[step_fun], + fun=itir.Lambda(params=[itir.Sym(id=step.id)], expr=expr), + args=[step_fun], location=node.location, ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index c84448a26d..0892246491 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -144,11 +144,14 @@ def get_output_nodes( def visit_FencilDefinition(self, node: itir.FencilDefinition): program_sdfg = dace.SDFG(name=node.id) - program_sdfg.debuginfo = dace.dtypes.DebugInfo(start_line=node.location.line, - start_column=node.location.column, - end_line=node.location.end_line, - end_column=node.location.end_column, - filename=node.location.filename) + if node.location: + program_sdfg.debuginfo = dace.dtypes.DebugInfo( + start_line=node.location.line, + start_column=node.location.column, + end_line=node.location.end_line, + end_column=node.location.end_column, + filename=node.location.filename, + ) last_state = program_sdfg.add_state("program_entry") self.node_types = itir_typing.infer_all(node) @@ -172,11 +175,14 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): closure_sdfg, input_names, output_names = self.visit( closure, array_table=program_sdfg.arrays ) - closure_sdfg.debuginfo = dace.dtypes.DebugInfo(start_line=closure.location.line, - start_column=closure.location.column, - end_line=closure.location.end_line, - end_column=closure.location.end_column, - filename=closure.location.filename) + if closure.location: + closure_sdfg.debuginfo = dace.dtypes.DebugInfo( + start_line=closure.location.line, + start_column=closure.location.column, + end_line=closure.location.end_line, + end_column=closure.location.end_column, + filename=closure.location.filename, + ) # Create a new state for the closure. last_state = program_sdfg.add_state_after(last_state) @@ -204,12 +210,30 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): # Add access nodes for the program parameters and connect them to the nested SDFG's inputs via edges. for i, (inner_name, memlet) in enumerate(input_mapping.items()): anode_loc = closure.inputs[i].location - access_node = last_state.add_access(inner_name, dace.dtypes.DebugInfo(anode_loc.line, anode_loc.column, anode_loc.end_line, anode_loc.end_column, anode_loc.filename)) + di = None + if anode_loc: + di = dace.dtypes.DebugInfo( + anode_loc.line, + anode_loc.column, + anode_loc.end_line, + anode_loc.end_column, + anode_loc.filename, + ) + access_node = last_state.add_access(inner_name, di) last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet) anode_loc = closure.output.location + di = None + if anode_loc: + di = dace.dtypes.DebugInfo( + anode_loc.line, + anode_loc.column, + anode_loc.end_line, + anode_loc.end_column, + anode_loc.filename, + ) for inner_name, memlet in output_mapping.items(): - access_node = last_state.add_access(inner_name, dace.dtypes.DebugInfo(anode_loc.line, anode_loc.column, anode_loc.end_line, anode_loc.end_column, anode_loc.filename)) + access_node = last_state.add_access(inner_name, di) last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) program_sdfg.validate() @@ -224,12 +248,14 @@ def visit_StencilClosure( closure_sdfg = dace.SDFG(name="closure") closure_state = closure_sdfg.add_state("closure_entry") closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init") - di = dace.dtypes.DebugInfo(start_line=node.location.line, - start_column=node.location.column, - end_line=node.location.end_line, - end_column=node.location.end_column, - filename=node.location.filename) - closure_sdfg.debuginfo = di + if node.location: + closure_sdfg.debuginfo = dace.dtypes.DebugInfo( + start_line=node.location.line, + start_column=node.location.column, + end_line=node.location.end_line, + end_column=node.location.end_column, + filename=node.location.filename, + ) program_arg_syms: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms) @@ -374,7 +400,7 @@ def visit_StencilClosure( outputs=output_mapping, symbol_mapping=symbol_mapping, output_nodes=output_nodes, - debuginfo=di, + debuginfo=closure_sdfg.debuginfo, ) access_nodes = {edge.data.data: edge.dst for edge in closure_state.out_edges(map_exit)} for edge in closure_state.in_edges(map_exit): diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 006268d298..eaa3ff7c27 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -315,7 +315,9 @@ def builtin_cast( node_type = transformer.node_types[id(node)] assert isinstance(node_type, itir_typing.Val) type_ = itir_type_as_dace_type(node_type.dtype) - return transformer.add_expr_tasklet(list(zip(args, internals)), expr, type_, "cast", location=node.location) + return transformer.add_expr_tasklet( + list(zip(args, internals)), expr, type_, "cast", location=node.location + ) def builtin_make_tuple( @@ -466,7 +468,9 @@ def visit_Lambda( result = ValueExpr(value=result_access, dtype=expr.dtype) else: # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors - result = self.add_expr_tasklet([], expr.value, expr.dtype, "forward", location=node.location)[0] + result = self.add_expr_tasklet( + [], expr.value, expr.dtype, "forward", location=node.location + )[0] self.context.body.arrays[result.value.data].transient = False results.append(result) @@ -545,11 +549,15 @@ def _visit_call(self, node: itir.FunCall): symbol_mapping = map_nested_sdfg_symbols(self.context.body, func_context.body, nsdfg_inputs) - di = dace.dtypes.DebugInfo(start_line=node.location.line, - start_column=node.location.column, - end_line=node.location.end_line, - end_column=node.location.end_column, - filename=node.location.filename) + di = None + if node.location: + di = dace.dtypes.DebugInfo( + start_line=node.location.line, + start_column=node.location.column, + end_line=node.location.end_line, + end_column=node.location.end_column, + filename=node.location.filename, + ) nsdfg_node = self.context.state.add_nested_sdfg( func_context.body, @@ -658,7 +666,9 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{', '.join(internals[1:])}]" - return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref", location=node.location) + return self.add_expr_tasklet( + list(zip(args, internals)), expr, iterator.dtype, "deref", location=node.location + ) def _split_shift_args( self, args: list[itir.Expr] @@ -855,7 +865,9 @@ def _visit_reduce(self, node: itir.FunCall): if not args[i]: args[i] = self.visit(node_arg)[0] - lambda_node = itir.Lambda(expr=fun_node.expr.args[1], params=fun_node.params[1:], location=node.location) + lambda_node = itir.Lambda( + expr=fun_node.expr.args[1], params=fun_node.params[1:], location=node.location + ) lambda_context, inner_inputs, inner_outputs = self.visit(lambda_node, args=args) # clear context @@ -903,7 +915,7 @@ def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: assert isinstance(node.fun, itir.SymRef) fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] for arg in node.args: - if hasattr(arg, 'location'): + if hasattr(arg, "location"): arg.location = node.location args: list[SymbolExpr | ValueExpr] = list( itertools.chain(*[self.visit(arg) for arg in node.args]) @@ -926,7 +938,12 @@ def _visit_general_builtin(self, node: itir.FunCall) -> list[ValueExpr]: return expr_func(self, node, node.args) def add_expr_tasklet( - self, args: list[tuple[ValueExpr, str]], expr: str, result_type: Any, name: str, location: SourceLocation = None + self, + args: list[tuple[ValueExpr, str]], + expr: str, + result_type: Any, + name: str, + location: Optional[SourceLocation] = None, ) -> list[ValueExpr]: result_name = unique_var_name() self.context.body.add_scalar(result_name, result_type, transient=True) @@ -934,18 +951,20 @@ def add_expr_tasklet( di = None if location: - di = dace.dtypes.DebugInfo(start_line=location.line, - start_column=location.column, - end_line=location.end_line, - end_column=location.end_column, - filename=location.filename) + di = dace.dtypes.DebugInfo( + start_line=location.line, + start_column=location.column, + end_line=location.end_line, + end_column=location.end_column, + filename=location.filename, + ) expr_tasklet = self.context.state.add_tasklet( name=name, inputs={internal for _, internal in args}, outputs={"__result"}, code=f"__result = {expr}", - debuginfo=di + debuginfo=di, ) for arg, internal in args: @@ -1008,11 +1027,15 @@ def closure_to_tasklet_sdfg( state = body.add_state("tasklet_toplevel_entry") symbol_map: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} - di = dace.dtypes.DebugInfo(start_line=node.location.line, - start_column=node.location.column, - end_line=node.location.end_line, - end_column=node.location.end_column, - filename=node.location.filename) + di = None + if node.location: + di = dace.dtypes.DebugInfo( + start_line=node.location.line, + start_column=node.location.column, + end_line=node.location.end_line, + end_column=node.location.end_column, + filename=node.location.filename, + ) idx_accesses = {} for dim, idx in domain.items(): @@ -1035,7 +1058,16 @@ def closure_to_tasklet_sdfg( dtype = as_dace_type(ty.dtype) body.add_array(name, shape=shape, strides=stride, dtype=dtype) anode_loc = node.inputs[i].location - field = state.add_access(name, dace.dtypes.DebugInfo(anode_loc.line, anode_loc.column, anode_loc.end_line, anode_loc.end_column, anode_loc.filename)) + di = None + if anode_loc: + di = dace.dtypes.DebugInfo( + anode_loc.line, + anode_loc.column, + anode_loc.end_line, + anode_loc.end_column, + anode_loc.filename, + ) + field = state.add_access(name, di) indices = {dim: idx_accesses[dim] for dim in domain.keys()} symbol_map[name] = IteratorExpr(field, indices, dtype, dims) else: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 40be17c170..c5d4c8538d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -119,11 +119,13 @@ def add_mapped_nested_sdfg( if input_nodes is None: input_nodes = { - memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) for name, memlet in inputs.items() + memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) + for name, memlet in inputs.items() } if output_nodes is None: output_nodes = { - memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) for name, memlet in outputs.items() + memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) + for name, memlet in outputs.items() } if not inputs: state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet()) From 16bc4891582fe83801f8b72b24cfd6a2a134f9e1 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Thu, 30 Nov 2023 11:17:32 +0100 Subject: [PATCH 06/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): WIP --- src/gt4py/next/ffront/foast_to_itir.py | 8 ++++++-- src/gt4py/next/ffront/past_to_itir.py | 11 +++++++++-- src/gt4py/next/iterator/transforms/unroll_reduce.py | 13 +++++++++---- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index a435d3dcb6..6c6d1a6af8 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -132,7 +132,9 @@ def visit_Stmt(self, node: foast.Stmt, **kwargs): def visit_Return( self, node: foast.Return, *, inner_expr: Optional[itir.Expr], **kwargs ) -> itir.Expr: - return self.visit(node.value, **kwargs) + return_ = self.visit(node.value, **kwargs) + return_.location = node.location + return return_ def visit_BlockStmt( self, node: foast.BlockStmt, *, inner_expr: Optional[itir.Expr], **kwargs @@ -361,7 +363,9 @@ def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall: return self._map("if_", *node.args, location=node.location) def _visit_broadcast(self, node: foast.Call, **kwargs) -> itir.FunCall: - return self.visit(node.args[0], **kwargs) + return_ = self.visit(node.args[0], **kwargs) + return_.location = node.location + return return_ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> itir.FunCall: return self._map(self.visit(node.func, **kwargs), *node.args, location=node.location) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 6cdfcec6ec..81dbfda743 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -177,13 +177,17 @@ def _visit_slice_bound( lowered_bound = self.visit(slice_bound, **kwargs) else: raise AssertionError("Expected `None` or `past.Constant`.") + if slice_bound: + lowered_bound.location = slice_bound.location return lowered_bound def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: if isinstance(node, past.Name): return itir.SymRef(id=node.id, location=node.location) elif isinstance(node, past.Subscript): - return self._construct_itir_out_arg(node.value) + return_ = self._construct_itir_out_arg(node.value) + return_.location = node.location + return return_ elif isinstance(node, past.TupleExpr): return itir.FunCall( fun=itir.SymRef(id="make_tuple"), @@ -269,7 +273,10 @@ def _construct_itir_initialized_domain_arg( f"Expected {dim}, but got {keys_dims_types} " ) - return [self.visit(bound) for bound in node_domain.values_[dim_i].elts] + return_ = [self.visit(bound) for bound in node_domain.values_[dim_i].elts] + for i, bound in enumerate(node_domain.values_[dim_i].elts): + return_[i].location = bound.location + return return_ @staticmethod def _compute_field_slice(node: past.Subscript): diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 143727708c..463b26165f 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -100,27 +100,32 @@ def _get_connectivity( def _make_shift(offsets: list[itir.Expr], iterator: itir.Expr) -> itir.FunCall: return itir.FunCall( - fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=offsets), args=[iterator] + fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=offsets), + args=[iterator], + location=iterator.location, ) def _make_deref(iterator: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="deref"), args=[iterator]) + return itir.FunCall(fun=itir.SymRef(id="deref"), args=[iterator], location=iterator.location) def _make_can_deref(iterator: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="can_deref"), args=[iterator]) + return itir.FunCall( + fun=itir.SymRef(id="can_deref"), args=[iterator], location=iterator.location + ) def _make_if(cond: itir.Expr, true_expr: itir.Expr, false_expr: itir.Expr) -> itir.FunCall: return itir.FunCall( fun=itir.SymRef(id="if_"), args=[cond, true_expr, false_expr], + location=cond.location, ) def _make_list_get(offset: itir.Expr, expr: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="list_get"), args=[offset, expr]) + return itir.FunCall(fun=itir.SymRef(id="list_get"), args=[offset, expr], location=expr.location) @dataclasses.dataclass(frozen=True) From 0ed809009a980b9391aa061ee9f8708b2fb3d4a9 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Mon, 4 Dec 2023 09:43:47 +0100 Subject: [PATCH 07/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG) --- .../runners/dace_iterator/itir_to_sdfg.py | 108 +++++------ .../runners/dace_iterator/itir_to_tasklet.py | 179 ++++++++++-------- .../runners/dace_iterator/utility.py | 19 +- 3 files changed, 168 insertions(+), 138 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 0892246491..8b4946c6a0 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -37,6 +37,7 @@ connectivity_identifier, create_memlet_at, create_memlet_full, + dace_debuginfo, filter_neighbor_tables, flatten_list, get_sorted_dims, @@ -144,14 +145,7 @@ def get_output_nodes( def visit_FencilDefinition(self, node: itir.FencilDefinition): program_sdfg = dace.SDFG(name=node.id) - if node.location: - program_sdfg.debuginfo = dace.dtypes.DebugInfo( - start_line=node.location.line, - start_column=node.location.column, - end_line=node.location.end_line, - end_column=node.location.end_column, - filename=node.location.filename, - ) + program_sdfg.debuginfo = dace_debuginfo(node) last_state = program_sdfg.add_state("program_entry") self.node_types = itir_typing.infer_all(node) @@ -175,14 +169,6 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): closure_sdfg, input_names, output_names = self.visit( closure, array_table=program_sdfg.arrays ) - if closure.location: - closure_sdfg.debuginfo = dace.dtypes.DebugInfo( - start_line=closure.location.line, - start_column=closure.location.column, - end_line=closure.location.end_line, - end_column=closure.location.end_column, - filename=closure.location.filename, - ) # Create a new state for the closure. last_state = program_sdfg.add_state_after(last_state) @@ -209,31 +195,15 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): # Add access nodes for the program parameters and connect them to the nested SDFG's inputs via edges. for i, (inner_name, memlet) in enumerate(input_mapping.items()): - anode_loc = closure.inputs[i].location - di = None - if anode_loc: - di = dace.dtypes.DebugInfo( - anode_loc.line, - anode_loc.column, - anode_loc.end_line, - anode_loc.end_column, - anode_loc.filename, - ) - access_node = last_state.add_access(inner_name, di) + access_node = last_state.add_access( + inner_name, debuginfo=dace_debuginfo(closure.inputs[i]) + ) last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet) - anode_loc = closure.output.location - di = None - if anode_loc: - di = dace.dtypes.DebugInfo( - anode_loc.line, - anode_loc.column, - anode_loc.end_line, - anode_loc.end_column, - anode_loc.filename, - ) for inner_name, memlet in output_mapping.items(): - access_node = last_state.add_access(inner_name, di) + access_node = last_state.add_access( + inner_name, debuginfo=dace_debuginfo(closure.output) + ) last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) program_sdfg.validate() @@ -246,16 +216,9 @@ def visit_StencilClosure( # Create the closure's nested SDFG and single state. closure_sdfg = dace.SDFG(name="closure") + closure_sdfg.debuginfo = dace_debuginfo(node) closure_state = closure_sdfg.add_state("closure_entry") closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init") - if node.location: - closure_sdfg.debuginfo = dace.dtypes.DebugInfo( - start_line=node.location.line, - start_column=node.location.column, - end_line=node.location.end_line, - end_column=node.location.end_column, - filename=node.location.filename, - ) program_arg_syms: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms) @@ -284,8 +247,8 @@ def visit_StencilClosure( transient=True, ) closure_init_state.add_nedge( - closure_init_state.add_access(name), - closure_init_state.add_access(transient_name), + closure_init_state.add_access(name, debuginfo=closure_sdfg.debuginfo), + closure_init_state.add_access(transient_name, debuginfo=closure_sdfg.debuginfo), create_memlet_full(name, closure_sdfg.arrays[name]), ) input_transients_mapping[name] = transient_name @@ -321,9 +284,15 @@ def visit_StencilClosure( out_name = unique_var_name() closure_sdfg.add_scalar(out_name, dtype, transient=True) out_tasklet = closure_init_state.add_tasklet( - f"get_{name}", {}, {"__result"}, f"__result = {name}" + f"get_{name}", + {}, + {"__result"}, + f"__result = {name}", + debuginfo=closure_sdfg.debuginfo, + ) + access = closure_init_state.add_access( + out_name, debuginfo=closure_sdfg.debuginfo ) - access = closure_init_state.add_access(out_name) value = ValueExpr(access, dtype) memlet = dace.Memlet.simple(out_name, "0") closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) @@ -400,20 +369,20 @@ def visit_StencilClosure( outputs=output_mapping, symbol_mapping=symbol_mapping, output_nodes=output_nodes, - debuginfo=closure_sdfg.debuginfo, + debuginfo=nsdfg.debuginfo, ) access_nodes = {edge.data.data: edge.dst for edge in closure_state.out_edges(map_exit)} for edge in closure_state.in_edges(map_exit): memlet = edge.data if memlet.data not in output_connectors_mapping: continue - transient_access = closure_state.add_access(memlet.data) + transient_access = closure_state.add_access(memlet.data, debuginfo=nsdfg.debuginfo) closure_state.add_edge( nsdfg_node, edge.src_conn, transient_access, None, - dace.Memlet.simple(memlet.data, output_subset), + dace.Memlet.simple(memlet.data, output_subset, debuginfo=nsdfg.debuginfo), ) inner_memlet = dace.Memlet.simple( memlet.data, output_subset, other_subset_str=memlet.subset @@ -462,6 +431,7 @@ def _visit_scan_stencil_closure( # the scan operator is implemented as an SDFG to be nested in the closure SDFG scan_sdfg = dace.SDFG(name="scan") + scan_sdfg.debuginfo = dace_debuginfo(node) # create a state machine for lambda call over the scan dimension start_state = scan_sdfg.add_state("start") @@ -525,6 +495,7 @@ def _visit_scan_stencil_closure( inputs=set(lambda_input_names) | set(connectivity_names), outputs={connector.value.label for connector in lambda_outputs}, symbol_mapping=symbol_mapping, + debuginfo=lambda_context.body.debuginfo, ) # the carry value of the scan operator exists in the scope of the scan sdfg @@ -533,9 +504,13 @@ def _visit_scan_stencil_closure( scan_sdfg.add_scalar(scan_carry_name, dtype=as_dace_type(scan_dtype), transient=True) carry_init_tasklet = start_state.add_tasklet( - "get_carry_init_value", {}, {"__result"}, f"__result = {init_carry_value}" + "get_carry_init_value", + {}, + {"__result"}, + f"__result = {init_carry_value}", + debuginfo=scan_sdfg.debuginfo, ) - carry_node1 = start_state.add_access(scan_carry_name) + carry_node1 = start_state.add_access(scan_carry_name, debuginfo=scan_sdfg.debuginfo) start_state.add_edge( carry_init_tasklet, "__result", @@ -544,7 +519,9 @@ def _visit_scan_stencil_closure( dace.Memlet.simple(scan_carry_name, "0"), ) - carry_node2 = lambda_state.add_access(scan_carry_name) + carry_node2 = lambda_state.add_access( + scan_carry_name, debuginfo=lambda_context.body.debuginfo + ) lambda_state.add_memlet_path( carry_node2, scan_inner_node, @@ -560,7 +537,7 @@ def _visit_scan_stencil_closure( else: memlet = dace.Memlet.simple(data_name, "0") lambda_state.add_memlet_path( - lambda_state.add_access(data_name), + lambda_state.add_access(data_name, debuginfo=lambda_context.body.debuginfo), scan_inner_node, memlet=memlet, src_conn=None, @@ -568,7 +545,9 @@ def _visit_scan_stencil_closure( ) for inner_name, memlet in connectivity_mapping.items(): - access_node = lambda_state.add_access(inner_name) + access_node = lambda_state.add_access( + inner_name, debuginfo=lambda_context.body.debuginfo + ) lambda_state.add_memlet_path( access_node, scan_inner_node, @@ -591,7 +570,7 @@ def _visit_scan_stencil_closure( ) lambda_state.add_memlet_path( scan_inner_node, - lambda_state.add_access(data_name), + lambda_state.add_access(data_name, debuginfo=lambda_context.body.debuginfo), memlet=dace.Memlet.simple(data_name, f"i_{scan_dim}"), src_conn=lambda_connector.value.label, dst_conn=None, @@ -599,12 +578,17 @@ def _visit_scan_stencil_closure( # add state to scan SDFG to update the carry value at each loop iteration lambda_update_state = scan_sdfg.add_state_after(lambda_state, "lambda_update") - result_node = lambda_update_state.add_access(output_names[0]) - carry_node3 = lambda_update_state.add_access(scan_carry_name) + result_node = lambda_update_state.add_access(output_names[0], debuginfo=scan_sdfg.debuginfo) + carry_node3 = lambda_update_state.add_access(scan_carry_name, debuginfo=scan_sdfg.debuginfo) lambda_update_state.add_memlet_path( result_node, carry_node3, - memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"), + memlet=dace.Memlet.simple( + output_names[0], + f"i_{scan_dim}", + other_subset_str="0", + debuginfo=scan_sdfg.debuginfo, + ), ) return scan_sdfg, map_ranges, scan_dim_index diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index eaa3ff7c27..a0e2d18f87 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -23,7 +23,6 @@ from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols import gt4py.eve.codegen -from gt4py.eve.concepts import SourceLocation from gt4py.next import Dimension, StridedNeighborOffsetProvider, type_inference as next_typing from gt4py.next.iterator import ir as itir, type_inference as itir_typing from gt4py.next.iterator.embedded import NeighborTableOffsetProvider @@ -37,6 +36,7 @@ connectivity_identifier, create_memlet_at, create_memlet_full, + dace_debuginfo, filter_neighbor_tables, flatten_list, map_nested_sdfg_symbols, @@ -177,6 +177,7 @@ def __init__( def builtin_neighbors( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) offset_literal, data = node_args assert isinstance(offset_literal, itir.OffsetLiteral) offset_dim = offset_literal.value @@ -198,7 +199,7 @@ def builtin_neighbors( result_name = unique_var_name() sdfg.add_array(result_name, dtype=iterator.dtype, shape=(table.max_neighbors,), transient=True) - result_access = state.add_access(result_name) + result_access = state.add_access(result_name, debuginfo=di) table_name = connectivity_identifier(offset_dim) @@ -207,23 +208,26 @@ def builtin_neighbors( me, mx = state.add_map( f"{offset_dim}_neighbors_map", ndrange={index_name: f"0:{table.max_neighbors}"}, + debuginfo=di, ) shift_tasklet = state.add_tasklet( "shift", code=f"__result = __table[__idx, {index_name}]", inputs={"__table", "__idx"}, outputs={"__result"}, + debuginfo=di, ) data_access_tasklet = state.add_tasklet( "data_access", code="__result = __field[__idx]", inputs={"__field", "__idx"}, outputs={"__result"}, + debuginfo=di, ) idx_name = unique_var_name() sdfg.add_scalar(idx_name, dace.int64, transient=True) state.add_memlet_path( - state.add_access(table_name), + state.add_access(table_name, debuginfo=di), me, shift_tasklet, memlet=create_memlet_full(table_name, sdfg.arrays[table_name]), @@ -233,7 +237,7 @@ def builtin_neighbors( iterator.indices[shifted_dim], me, shift_tasklet, - memlet=dace.Memlet.simple(iterator.indices[shifted_dim].data, "0"), + memlet=dace.Memlet.simple(iterator.indices[shifted_dim].data, "0", debuginfo=di), dst_conn="__idx", ) state.add_edge( @@ -259,7 +263,7 @@ def builtin_neighbors( data_access_tasklet, mx, result_access, - memlet=dace.Memlet.simple(result_name, index_name), + memlet=dace.Memlet.simple(result_name, index_name, debuginfo=di), src_conn="__result", ) @@ -269,6 +273,7 @@ def builtin_neighbors( def builtin_can_deref( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) # first visit shift, to get set of indices for deref can_deref_callable = node_args[0] assert isinstance(can_deref_callable, itir.FunCall) @@ -285,13 +290,18 @@ def builtin_can_deref( # TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution return transformer.add_expr_tasklet( - list(zip(args, internals)), expr_code, dace.dtypes.bool, "can_deref", location=node.location + list(zip(args, internals)), + expr_code, + dace.dtypes.bool, + "can_deref", + dace_debuginfo=di, ) def builtin_if( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) args = [arg for li in transformer.visit(node_args) for arg in li] expr_args = [(arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr)] internals = [ @@ -301,12 +311,19 @@ def builtin_if( node_type = transformer.node_types[id(node)] assert isinstance(node_type, itir_typing.Val) type_ = itir_type_as_dace_type(node_type.dtype) - return transformer.add_expr_tasklet(expr_args, expr, type_, "if", location=node.location) + return transformer.add_expr_tasklet( + expr_args, + expr, + type_, + "if", + dace_debuginfo=di, + ) def builtin_cast( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) args = [transformer.visit(node_args[0])[0]] internals = [f"{arg.value.data}_v" for arg in args] target_type = node_args[1] @@ -316,7 +333,11 @@ def builtin_cast( assert isinstance(node_type, itir_typing.Val) type_ = itir_type_as_dace_type(node_type.dtype) return transformer.add_expr_tasklet( - list(zip(args, internals)), expr, type_, "cast", location=node.location + list(zip(args, internals)), + expr, + type_, + "cast", + dace_debuginfo=di, ) @@ -391,17 +412,19 @@ def visit_Lambda( # Create the SDFG for the function's body prev_context = self.context context_sdfg = dace.SDFG(func_name) + di = dace_debuginfo(node) + context_sdfg.debuginfo = di context_state = context_sdfg.add_state(f"{func_name}_entry", True) symbol_map: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} value: ValueExpr | IteratorExpr for param, arg in symbols.items(): if isinstance(arg, ValueExpr): - value = ValueExpr(context_state.add_access(param), arg.dtype) + value = ValueExpr(context_state.add_access(param, debuginfo=di), arg.dtype) else: assert isinstance(arg, IteratorExpr) - field = context_state.add_access(param) + field = context_state.add_access(param, debuginfo=di) indices = { - dim: context_state.add_access(f"__{param}_i_{dim}") + dim: context_state.add_access(f"__{param}_i_{dim}", debuginfo=di) for dim in arg.indices.keys() } value = IteratorExpr(field, indices, arg.dtype, arg.dimensions) @@ -456,7 +479,7 @@ def visit_Lambda( if isinstance(expr, ValueExpr): result_name = unique_var_name() self.context.body.add_scalar(result_name, expr.dtype, transient=True) - result_access = self.context.state.add_access(result_name) + result_access = self.context.state.add_access(result_name, debuginfo=di) self.context.state.add_edge( expr.value, None, @@ -469,7 +492,11 @@ def visit_Lambda( else: # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors result = self.add_expr_tasklet( - [], expr.value, expr.dtype, "forward", location=node.location + [], + expr.value, + expr.dtype, + "forward", + dace_debuginfo=di, )[0] self.context.body.arrays[result.value.data].transient = False results.append(result) @@ -484,7 +511,9 @@ def visit_Lambda( def visit_SymRef(self, node: itir.SymRef) -> list[ValueExpr | SymbolExpr] | IteratorExpr: if node.id not in self.context.symbol_map: - acc = self.context.state.add_access(node.id) + acc = self.context.state.add_access( + node.id, debuginfo=dace_debuginfo(node, self.context.body.debuginfo) + ) node_type = self.node_types[id(node)] assert isinstance(node_type, Val) self.context.symbol_map[node.id] = ValueExpr( @@ -501,6 +530,7 @@ def visit_Literal(self, node: itir.Literal) -> list[SymbolExpr]: return [SymbolExpr(node.value, itir_type_as_dace_type(node_type.dtype))] def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: + node.fun.location = node.location if isinstance(node.fun, itir.SymRef) and node.fun.id == "deref": return self._visit_deref(node) if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef): @@ -549,23 +579,13 @@ def _visit_call(self, node: itir.FunCall): symbol_mapping = map_nested_sdfg_symbols(self.context.body, func_context.body, nsdfg_inputs) - di = None - if node.location: - di = dace.dtypes.DebugInfo( - start_line=node.location.line, - start_column=node.location.column, - end_line=node.location.end_line, - end_column=node.location.end_column, - filename=node.location.filename, - ) - nsdfg_node = self.context.state.add_nested_sdfg( func_context.body, None, inputs=set(nsdfg_inputs.keys()), outputs=set(r.value.data for r in results), symbol_mapping=symbol_mapping, - debuginfo=di, + debuginfo=dace_debuginfo(node, self.context.body.debuginfo), ) for name, value in func_inputs: @@ -585,14 +605,14 @@ def _visit_call(self, node: itir.FunCall): for conn, _ in neighbor_tables: var = connectivity_identifier(conn) memlet = nsdfg_inputs[var] - access = self.context.state.add_access(var) + access = self.context.state.add_access(var, debuginfo=nsdfg_node.debuginfo) self.context.state.add_edge(access, None, nsdfg_node, var, memlet) result_exprs = [] for result in results: name = unique_var_name() self.context.body.add_scalar(name, result.dtype, transient=True) - result_access = self.context.state.add_access(name) + result_access = self.context.state.add_access(name, debuginfo=nsdfg_node.debuginfo) result_exprs.append(ValueExpr(result_access, result.dtype)) memlet = create_memlet_full(name, self.context.body.arrays[name]) self.context.state.add_edge(nsdfg_node, result.value.data, result_access, None, memlet) @@ -600,6 +620,7 @@ def _visit_call(self, node: itir.FunCall): return result_exprs def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: + di = dace_debuginfo(node, self.context.body.debuginfo) iterator = self.visit(node.args[0]) if not isinstance(iterator, IteratorExpr): # already a list of ValueExpr @@ -616,13 +637,14 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: shape=(self.context.reduce_limit,), transient=True, ) - result_access = self.context.state.add_access(result_name) + result_access = self.context.state.add_access(result_name, debuginfo=di) # generate unique map index name to avoid conflict with other maps inside same state index_name = unique_name("__deref_idx") me, mx = self.context.state.add_map( "deref_map", ndrange={index_name: f"0:{self.context.reduce_limit}"}, + debuginfo=di, ) # if dim is not found in iterator indices, we take the neighbor index over the reduction domain @@ -640,6 +662,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: inputs=set(internals), outputs={"__result"}, code=f"__result = {args[0].value.data}_v[{', '.join(flat_index)}]", + debuginfo=di, ) for arg, internal in zip(args, internals): @@ -667,7 +690,11 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{', '.join(internals[1:])}]" return self.add_expr_tasklet( - list(zip(args, internals)), expr, iterator.dtype, "deref", location=node.location + list(zip(args, internals)), + expr, + iterator.dtype, + "deref", + dace_debuginfo=di, ) def _split_shift_args( @@ -684,6 +711,7 @@ def _make_shift_for_rest(self, rest, iterator): ) def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: + di = dace_debuginfo(node, self.context.body.debuginfo) shift = node.fun assert isinstance(shift, itir.FunCall) tail, rest = self._split_shift_args(shift.args) @@ -699,7 +727,9 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: if isinstance(self.offset_provider[offset_dim], NeighborTableOffsetProvider): offset_provider = self.offset_provider[offset_dim] - connectivity = self.context.state.add_access(connectivity_identifier(offset_dim)) + connectivity = self.context.state.add_access( + connectivity_identifier(offset_dim), debuginfo=di + ) shifted_dim = offset_provider.origin_axis.value target_dim = offset_provider.neighbor_axis.value @@ -734,7 +764,11 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: expr = f"{internals[0]} + {internals[1]}" shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, dace.dtypes.int64, "shift", location=node.location + list(zip(args, internals)), + expr, + dace.dtypes.int64, + "shift", + dace_debuginfo=di, )[0].value shifted_index = {dim: value for dim, value in iterator.indices.items()} @@ -744,13 +778,14 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: + di = dace_debuginfo(node, self.context.body.debuginfo) offset = node.value assert isinstance(offset, int) offset_var = unique_var_name() self.context.body.add_scalar(offset_var, dace.dtypes.int64, transient=True) - offset_node = self.context.state.add_access(offset_var) + offset_node = self.context.state.add_access(offset_var, debuginfo=di) tasklet_node = self.context.state.add_tasklet( - "get_offset", {}, {"__out"}, f"__out = {offset}" + "get_offset", {}, {"__out"}, f"__out = {offset}", debuginfo=di ) self.context.state.add_edge( tasklet_node, "__out", offset_node, None, dace.Memlet.simple(offset_var, "0") @@ -758,8 +793,9 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: return [ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype)] def _visit_reduce(self, node: itir.FunCall): + di = dace_debuginfo(node, self.context.body.debuginfo) result_name = unique_var_name() - result_access = self.context.state.add_access(result_name) + result_access = self.context.state.add_access(result_name, debuginfo=di) if len(node.args) == 1: assert ( @@ -787,6 +823,7 @@ def _visit_reduce(self, node: itir.FunCall): code=f"__result = {init}\nfor __idx in range({reduce_array_desc.shape[0]}):\n __result = {op_str}", inputs={"__values"}, outputs={"__result"}, + debuginfo=di, ) self.context.state.add_edge( args[0].value, @@ -844,7 +881,7 @@ def _visit_reduce(self, node: itir.FunCall): init_value = get_reduce_identity_value(op_name.id, result_dtype) init_state = self.context.body.add_state_before(self.context.state, "init") init_tasklet = init_state.add_tasklet( - "init_reduce", {}, {"__out"}, f"__out = {init_value}" + "init_reduce", {}, {"__out"}, f"__out = {init_value}", debuginfo=di ) init_state.add_edge( init_tasklet, @@ -869,6 +906,7 @@ def _visit_reduce(self, node: itir.FunCall): expr=fun_node.expr.args[1], params=fun_node.params[1:], location=node.location ) lambda_context, inner_inputs, inner_outputs = self.visit(lambda_node, args=args) + lambda_context.body.debuginfo = di # clear context self.context.reduce_limit = 0 @@ -903,6 +941,7 @@ def _visit_reduce(self, node: itir.FunCall): symbol_mapping=symbol_mapping, input_nodes={arg.value.data: arg.value for arg in args}, output_nodes={result_name: result_access}, + debuginfo=di, ) # we apply map fusion only to the nested-SDFG which is generated for the reduction operator @@ -930,7 +969,13 @@ def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: node_type = self.node_types[id(node)] assert isinstance(node_type, itir_typing.Val) type_ = itir_type_as_dace_type(node_type.dtype) - return self.add_expr_tasklet(expr_args, expr, type_, "numeric", location=node.location) + return self.add_expr_tasklet( + expr_args, + expr, + type_, + "numeric", + dace_debuginfo=dace_debuginfo(node), + ) def _visit_general_builtin(self, node: itir.FunCall) -> list[ValueExpr]: assert isinstance(node.fun, itir.SymRef) @@ -943,21 +988,12 @@ def add_expr_tasklet( expr: str, result_type: Any, name: str, - location: Optional[SourceLocation] = None, + dace_debuginfo: Optional[dace.dtypes.DebugInfo] = None, ) -> list[ValueExpr]: + di = dace_debuginfo if dace_debuginfo else self.context.body.debuginfo result_name = unique_var_name() self.context.body.add_scalar(result_name, result_type, transient=True) - result_access = self.context.state.add_access(result_name) - - di = None - if location: - di = dace.dtypes.DebugInfo( - start_line=location.line, - start_column=location.column, - end_line=location.end_line, - end_column=location.end_column, - filename=location.filename, - ) + result_access = self.context.state.add_access(result_name, debuginfo=dace_debuginfo) expr_tasklet = self.context.state.add_tasklet( name=name, @@ -982,7 +1018,7 @@ def add_expr_tasklet( ) self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet) - memlet = dace.Memlet.simple(result_access.data, "0") + memlet = dace.Memlet.simple(result_access.data, "0", debuginfo=di) self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet) return [ValueExpr(result_access, result_type)] @@ -998,7 +1034,12 @@ def _visit_scan_closure_callable( ) -> tuple[Context, Sequence[tuple[str, ValueExpr]], Sequence[ValueExpr]]: stencil = cast(FunCall, node.stencil) assert isinstance(stencil.args[0], Lambda) - fun_node = itir.Lambda(expr=stencil.args[0].expr, params=stencil.args[0].params) + location_ = stencil.args[0].location + fun_node = itir.Lambda( + expr=stencil.args[0].expr, + params=stencil.args[0].params, + location=location_ if location_ else node.location, + ) args = list(itertools.chain(tlet_codegen.visit(node.output), *tlet_codegen.visit(node.inputs))) return tlet_codegen.visit(fun_node, args=args) @@ -1010,7 +1051,10 @@ def _visit_closure_callable( input_names: Sequence[str], ) -> Sequence[ValueExpr]: args = [itir.SymRef(id=name) for name in input_names] - fun_node = itir.FunCall(fun=node.stencil, args=args, location=node.location) + location_ = node.stencil.location + fun_node = itir.FunCall( + fun=node.stencil, args=args, location=location_ if location_ else node.location + ) return tlet_codegen.visit(fun_node) @@ -1024,25 +1068,18 @@ def closure_to_tasklet_sdfg( node_types: dict[int, next_typing.Type], ) -> tuple[Context, Sequence[tuple[str, ValueExpr]], Sequence[ValueExpr]]: body = dace.SDFG("tasklet_toplevel") + body.debuginfo = dace_debuginfo(node) state = body.add_state("tasklet_toplevel_entry") symbol_map: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {} - di = None - if node.location: - di = dace.dtypes.DebugInfo( - start_line=node.location.line, - start_column=node.location.column, - end_line=node.location.end_line, - end_column=node.location.end_column, - filename=node.location.filename, - ) - idx_accesses = {} for dim, idx in domain.items(): name = f"{idx}_value" body.add_scalar(name, dtype=dace.int64, transient=True) - tasklet = state.add_tasklet(f"get_{dim}", set(), {"value"}, f"value = {idx}", debuginfo=di) - access = state.add_access(name) + tasklet = state.add_tasklet( + f"get_{dim}", set(), {"value"}, f"value = {idx}", debuginfo=body.debuginfo + ) + access = state.add_access(name, debuginfo=body.debuginfo) idx_accesses[dim] = access state.add_edge(tasklet, "value", access, None, dace.Memlet.simple(name, "0")) for i, (name, ty) in enumerate(inputs): @@ -1057,24 +1094,16 @@ def closure_to_tasklet_sdfg( dims = [dim.value for dim in ty.dims] dtype = as_dace_type(ty.dtype) body.add_array(name, shape=shape, strides=stride, dtype=dtype) - anode_loc = node.inputs[i].location - di = None - if anode_loc: - di = dace.dtypes.DebugInfo( - anode_loc.line, - anode_loc.column, - anode_loc.end_line, - anode_loc.end_column, - anode_loc.filename, - ) - field = state.add_access(name, di) + field = state.add_access(name, debuginfo=dace_debuginfo(node.inputs[i])) indices = {dim: idx_accesses[dim] for dim in domain.keys()} symbol_map[name] = IteratorExpr(field, indices, dtype, dims) else: assert isinstance(ty, ts.ScalarType) dtype = as_dace_type(ty) body.add_scalar(name, dtype=dtype) - symbol_map[name] = ValueExpr(state.add_access(name), dtype) + symbol_map[name] = ValueExpr( + state.add_access(name, debuginfo=dace_debuginfo(node.inputs[i])), dtype + ) for arr, name in connectivities: shape = [dace.symbol(f"{unique_var_name()}_shp{i}", dtype=dace.int64) for i in range(2)] stride = [dace.symbol(f"{unique_var_name()}_strd{i}", dtype=dace.int64) for i in range(2)] diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index c5d4c8538d..fb677367b0 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -12,15 +12,32 @@ # # SPDX-License-Identifier: GPL-3.0-or-later import itertools -from typing import Any, Sequence +from typing import Any, Optional, Sequence import dace from gt4py.next import Dimension from gt4py.next.iterator.embedded import NeighborTableOffsetProvider +from gt4py.next.iterator.ir import Node from gt4py.next.type_system import type_specifications as ts +def dace_debuginfo( + node: Node, debuginfo: Optional[dace.dtypes.DebugInfo] = None +) -> Optional[dace.dtypes.DebugInfo]: + if node.location: + di = dace.dtypes.DebugInfo( + start_line=node.location.line, + start_column=node.location.column, + end_line=node.location.end_line, + end_column=node.location.end_column, + filename=node.location.filename, + ) + else: + di = debuginfo + return di + + def as_dace_type(type_: ts.ScalarType): if type_.kind == ts.ScalarKind.BOOL: return dace.bool_ From 0b1fe1a635353d8bfbcc78272e907f0dd702dcf9 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Mon, 4 Dec 2023 10:46:14 +0100 Subject: [PATCH 08/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG) --- .../runners/dace_iterator/itir_to_sdfg.py | 10 ++++++++-- .../runners/dace_iterator/itir_to_tasklet.py | 19 ++++++++++++++----- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index ce2a1986ee..232f70f82f 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -449,7 +449,11 @@ def _visit_scan_stencil_closure( # tasklet for initialization of carry carry_init_tasklet = start_state.add_tasklet( - "get_carry_init_value", {}, {"__result"}, f"__result = {init_carry_value}", debuginfo=scan_sdfg.debuginfo + "get_carry_init_value", + {}, + {"__result"}, + f"__result = {init_carry_value}", + debuginfo=scan_sdfg.debuginfo, ) start_state.add_edge( carry_init_tasklet, @@ -556,7 +560,9 @@ def _visit_scan_stencil_closure( lambda_update_state = scan_sdfg.add_state_after(lambda_state, "lambda_update") lambda_update_state.add_memlet_path( lambda_update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo), - lambda_update_state.add_access(scan_carry_name, debuginfo=lambda_context.body.debuginfo), + lambda_update_state.add_access( + scan_carry_name, debuginfo=lambda_context.body.debuginfo + ), memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"), ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 6599ae418d..1075e9cb8f 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -432,7 +432,9 @@ def _add_symbol(self, param, arg): # create storage in lambda sdfg self._sdfg.add_scalar(param, dtype=arg.dtype) # update table of lambda symbol - self._symbol_map[param] = ValueExpr(self._state.add_access(param, debuginfo=self._sdfg.debuginfo), arg.dtype) + self._symbol_map[param] = ValueExpr( + self._state.add_access(param, debuginfo=self._sdfg.debuginfo), arg.dtype + ) elif isinstance(arg, IteratorExpr): # create storage in lambda sdfg ndims = len(arg.dimensions) @@ -444,7 +446,8 @@ def _add_symbol(self, param, arg): # update table of lambda symbol field = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) indices = { - dim: self._state.add_access(index_arg, debuginfo=self._sdfg.debuginfo) for dim, index_arg in index_names.items() + dim: self._state.add_access(index_arg, debuginfo=self._sdfg.debuginfo) + for dim, index_arg in index_names.items() } self._symbol_map[param] = IteratorExpr(field, indices, arg.dtype, arg.dimensions) else: @@ -580,7 +583,9 @@ def visit_Lambda( if isinstance(expr, ValueExpr): result_name = unique_var_name() lambda_sdfg.add_scalar(result_name, expr.dtype, transient=True) - result_access = lambda_state.add_access(result_name, debuginfo=lambda_sdfg.debuginfo) + result_access = lambda_state.add_access( + result_name, debuginfo=lambda_sdfg.debuginfo + ) lambda_state.add_nedge( expr.value, result_access, @@ -590,7 +595,9 @@ def visit_Lambda( result = ValueExpr(value=result_access, dtype=expr.dtype) else: # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors - result = lambda_taskgen.add_expr_tasklet([], expr.value, expr.dtype, "forward", dace_debuginfo=lambda_sdfg.debuginfo)[0] + result = lambda_taskgen.add_expr_tasklet( + [], expr.value, expr.dtype, "forward", dace_debuginfo=lambda_sdfg.debuginfo + )[0] lambda_sdfg.arrays[result.value.data].transient = False results.append(result) @@ -1202,7 +1209,9 @@ def closure_to_tasklet_sdfg( for dim, idx in domain.items(): name = f"{idx}_value" body.add_scalar(name, dtype=_INDEX_DTYPE, transient=True) - tasklet = state.add_tasklet(f"get_{dim}", set(), {"value"}, f"value = {idx}", debuginfo=body.debuginfo) + tasklet = state.add_tasklet( + f"get_{dim}", set(), {"value"}, f"value = {idx}", debuginfo=body.debuginfo + ) access = state.add_access(name, debuginfo=body.debuginfo) idx_accesses[dim] = access state.add_edge(tasklet, "value", access, None, dace.Memlet.simple(name, "0")) From 55abb29bb6452645a3386dbb11dfcb24db3497a0 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Mon, 4 Dec 2023 14:38:54 +0100 Subject: [PATCH 09/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): WIP --- .../runners/dace_iterator/itir_to_sdfg.py | 10 +++------- .../runners/dace_iterator/itir_to_tasklet.py | 15 +++++++-------- .../runners/dace_iterator/utility.py | 13 +++++++------ 3 files changed, 17 insertions(+), 21 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 232f70f82f..51376fe445 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -200,16 +200,12 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): ) # Add access nodes for the program parameters and connect them to the nested SDFG's inputs via edges. - for i, (inner_name, memlet) in enumerate(input_mapping.items()): - access_node = last_state.add_access( - inner_name, debuginfo=dace_debuginfo(closure.inputs[i]) - ) + for inner_name, memlet in input_mapping.items(): + access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo) last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet) for inner_name, memlet in output_mapping.items(): - access_node = last_state.add_access( - inner_name, debuginfo=dace_debuginfo(closure.output) - ) + access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo) last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) program_sdfg.validate() diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 1075e9cb8f..f690920964 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -803,7 +803,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: result_name = unique_var_name() self.context.body.add_array(result_name, result_shape, iterator.dtype, transient=True) result_array = self.context.body.arrays[result_name] - result_node = self.context.state.add_access(result_name) + result_node = self.context.state.add_access(result_name, debuginfo=di) deref_connectors = ["_inp"] + [ f"_i_{dim}" for dim in sorted_dims if dim in iterator.indices @@ -838,8 +838,8 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: for dim, size in zip(sorted_dims, field_array.shape) ) deref_access_state.add_nedge( - deref_access_state.add_access("_inp"), - deref_access_state.add_access("_out"), + deref_access_state.add_access("_inp", debuginfo=di), + deref_access_state.add_access("_out", debuginfo=di), dace.Memlet( data="_out", subset=subsets.Range.from_array(result_array), @@ -852,6 +852,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: self.context.body, inputs=set(deref_connectors), outputs={"_out"}, + debuginfo=di, ) for connector, node, memlet in zip(deref_connectors, deref_nodes, deref_memlets): self.context.state.add_edge(node, None, deref_node, connector, memlet) @@ -1215,23 +1216,21 @@ def closure_to_tasklet_sdfg( access = state.add_access(name, debuginfo=body.debuginfo) idx_accesses[dim] = access state.add_edge(tasklet, "value", access, None, dace.Memlet.simple(name, "0")) - for i, (name, ty) in enumerate(inputs): + for name, ty in inputs: if isinstance(ty, ts.FieldType): ndim = len(ty.dims) shape, strides = new_array_symbols(name, ndim) dims = [dim.value for dim in ty.dims] dtype = as_dace_type(ty.dtype) body.add_array(name, shape=shape, strides=strides, dtype=dtype) - field = state.add_access(name, debuginfo=dace_debuginfo(node.inputs[i])) + field = state.add_access(name, debuginfo=body.debuginfo) indices = {dim: idx_accesses[dim] for dim in domain.keys()} symbol_map[name] = IteratorExpr(field, indices, dtype, dims) else: assert isinstance(ty, ts.ScalarType) dtype = as_dace_type(ty) body.add_scalar(name, dtype=dtype) - symbol_map[name] = ValueExpr( - state.add_access(name, debuginfo=dace_debuginfo(node.inputs[i])), dtype - ) + symbol_map[name] = ValueExpr(state.add_access(name, debuginfo=body.debuginfo), dtype) for arr, name in connectivities: shape, strides = new_array_symbols(name, ndim=2) body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index d61971893f..121fb6850c 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -25,13 +25,14 @@ def dace_debuginfo( node: Node, debuginfo: Optional[dace.dtypes.DebugInfo] = None ) -> Optional[dace.dtypes.DebugInfo]: - if node.location: + location = node.location + if location: di = dace.dtypes.DebugInfo( - start_line=node.location.line, - start_column=node.location.column, - end_line=node.location.end_line, - end_column=node.location.end_column, - filename=node.location.filename, + start_line=location.line, + start_column=location.column if location.column else 0, + end_line=location.end_line if location.end_line else -1, + end_column=location.end_column if location.end_column else 0, + filename=location.filename, ) else: di = debuginfo From 774b2f57b7c500099694d695412cc85ec7f63789 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Tue, 5 Dec 2023 13:57:35 +0100 Subject: [PATCH 10/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): WIP --- src/gt4py/next/ffront/foast_to_itir.py | 5 +++-- src/gt4py/next/iterator/ir_utils/ir_makers.py | 16 ++++++++-------- .../runners/dace_iterator/__init__.py | 4 ++++ 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index a24957544d..6f0c81f193 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -49,6 +49,7 @@ class FieldOperatorLowering(NodeTranslator): Examples -------- >>> from gt4py.next.ffront.func_to_foast import FieldOperatorParser + >>> from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering >>> from gt4py.next import Field, Dimension, float64 >>> >>> IDim = Dimension("IDim") @@ -61,8 +62,8 @@ class FieldOperatorLowering(NodeTranslator): >>> lowered.id SymbolName('fieldop') - >>> lowered.params - [Sym(id=SymbolName('inp'), kind='Iterator', dtype=('float64', False))] + >>> lowered.params # doctest: +ELLIPSIS + [Sym(location=..., id=SymbolName('inp'), kind='Iterator', dtype=('float64', False))] """ uid_generator: UIDGenerator = dataclasses.field(default_factory=UIDGenerator) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index f7086ada0c..3950745c60 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -26,7 +26,7 @@ def sym(sym_or_name: Union[str, itir.Sym]) -> itir.Sym: Examples -------- >>> sym("a") - Sym(id=SymbolName('a'), kind=None, dtype=None) + Sym(location=None, id=SymbolName('a'), kind=None, dtype=None) >>> sym(itir.Sym(id="b")) Sym(id=SymbolName('b'), kind=None, dtype=None) @@ -43,7 +43,7 @@ def ref(ref_or_name: Union[str, itir.SymRef]) -> itir.SymRef: Examples -------- >>> ref("a") - SymRef(id=SymbolRef('a')) + SymRef(location=None, id=SymbolRef('a')) >>> ref(itir.SymRef(id="b")) SymRef(id=SymbolRef('b')) @@ -60,7 +60,7 @@ def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> iti Examples -------- >>> ensure_expr("a") - SymRef(id=SymbolRef('a')) + SymRef(location=None, id=SymbolRef('a')) >>> ensure_expr(3) Literal(value='3', type='int32') @@ -83,7 +83,7 @@ def ensure_offset(str_or_offset: Union[str, int, itir.OffsetLiteral]) -> itir.Of Examples -------- >>> ensure_offset("V2E") - OffsetLiteral(value='V2E') + OffsetLiteral(location=None, value='V2E') >>> ensure_offset(itir.OffsetLiteral(value="J")) OffsetLiteral(value='J') @@ -100,7 +100,7 @@ class lambda_: Examples -------- >>> lambda_("a")(deref("a")) # doctest: +ELLIPSIS - Lambda(params=[Sym(id=SymbolName('a'), kind=None, dtype=None)], expr=FunCall(fun=SymRef(id=SymbolRef('deref')), args=[SymRef(id=SymbolRef('a'))])) + Lambda(location=None, params=[Sym(location=None, id=SymbolName('a'), kind=None, dtype=None)], expr=FunCall(location=None, fun=SymRef(location=None, id=SymbolRef('deref')), args=[SymRef(location=None, id=SymbolRef('a'))])) """ def __init__(self, *args): @@ -117,7 +117,7 @@ class call: Examples -------- >>> call("plus")(1, 1) - FunCall(fun=SymRef(id=SymbolRef('plus')), args=[Literal(value='1', type='int32'), Literal(value='1', type='int32')]) + FunCall(location=None, fun=SymRef(location=None, id=SymbolRef('plus')), args=[Literal(location=None, value='1', type='int32'), Literal(location=None, value='1', type='int32')]) """ def __init__(self, expr): @@ -264,7 +264,7 @@ def shift(offset, value=None): Examples -------- >>> shift("i", 0)("a") - FunCall(fun=FunCall(fun=SymRef(id=SymbolRef('shift')), args=[OffsetLiteral(value='i'), OffsetLiteral(value=0)]), args=[SymRef(id=SymbolRef('a'))]) + FunCall(location=None, fun=FunCall(location=None, fun=SymRef(location=None, id=SymbolRef('shift')), args=[OffsetLiteral(location=None, value='i'), OffsetLiteral(location=None, value=0)]), args=[SymRef(location=None, id=SymbolRef('a'))]) >>> shift("V2E")("b") FunCall(fun=FunCall(fun=SymRef(id=SymbolRef('shift')), args=[OffsetLiteral(value='V2E')]), args=[SymRef(id=SymbolRef('b'))]) @@ -286,7 +286,7 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: Make a literal node from a value. >>> literal_from_value(1.) - Literal(value='1.0', type='float64') + Literal(location=None, value='1.0', type='float64') >>> literal_from_value(1) Literal(value='1', type='int32') >>> literal_from_value(2147483648) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 34ba2d2d95..41c354214f 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -219,6 +219,10 @@ def build_sdfg_from_itir( program = preprocess_program(program, offset_provider, lift_mode) sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, on_gpu) sdfg = sdfg_genenerator.visit(program) + # for nested_sdfg in sdfg.all_sdfgs_recursive(): + # if not nested_sdfg.debuginfo: + # warnings.warn(f"{nested_sdfg} does not have debuginfo. + # Consider adding them in the corresponding nested sdfg.") sdfg.simplify() # run DaCe auto-optimization heuristics From 16228660445a45d55db3d5d4e8be60c3022daba6 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Tue, 5 Dec 2023 14:22:58 +0100 Subject: [PATCH 11/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): WIP --- src/gt4py/next/ffront/foast_to_itir.py | 118 +++++++++--------- src/gt4py/next/ffront/past_to_itir.py | 12 +- src/gt4py/next/iterator/ir_utils/ir_makers.py | 18 +-- .../runners/dace_iterator/__init__.py | 9 +- 4 files changed, 79 insertions(+), 78 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 6f0c81f193..f7bfd4a826 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -134,9 +134,9 @@ def visit_Stmt(self, node: foast.Stmt, **kwargs): def visit_Return( self, node: foast.Return, *, inner_expr: Optional[itir.Expr], **kwargs ) -> itir.Expr: - return_ = self.visit(node.value, **kwargs) - return_.location = node.location - return return_ + itir_node = self.visit(node.value, **kwargs) + itir_node.location = node.location + return itir_node def visit_BlockStmt( self, node: foast.BlockStmt, *, inner_expr: Optional[itir.Expr], **kwargs @@ -176,11 +176,11 @@ def visit_IfStmt( inner_expr = im.let(sym, im.tuple_get(i, im.ref("__if_stmt_result")))(inner_expr) # here we assume neither branch returns - return_ = im.let("__if_stmt_result", im.if_(im.deref(cond), true_branch, false_branch))( - inner_expr - ) - return_.location = node.location - return return_ + itir_node = im.let( + "__if_stmt_result", im.if_(im.deref(cond), true_branch, false_branch) + )(inner_expr) + itir_node.location = node.location + return itir_node elif return_kind is StmtReturnKind.CONDITIONAL_RETURN: common_syms = tuple(im.sym(sym) for sym in common_symbols.keys()) common_symrefs = tuple(im.ref(sym) for sym in common_symbols.keys()) @@ -194,11 +194,11 @@ def visit_IfStmt( true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - return_ = im.let(inner_expr_name, inner_expr_evaluator)( + itir_node = im.let(inner_expr_name, inner_expr_evaluator)( im.if_(im.deref(cond), true_branch, false_branch) ) - return_.location = node.location - return return_ + itir_node.location = node.location + return itir_node assert return_kind is StmtReturnKind.UNCONDITIONAL_RETURN @@ -207,18 +207,18 @@ def visit_IfStmt( true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - return_ = im.if_(im.deref(cond), true_branch, false_branch) - return_.location = node.location - return return_ + itir_node = im.if_(im.deref(cond), true_branch, false_branch) + itir_node.location = node.location + return itir_node def visit_Assign( self, node: foast.Assign, *, inner_expr: Optional[itir.Expr], **kwargs ) -> itir.Expr: - return_ = im.let(self.visit(node.target, **kwargs), self.visit(node.value, **kwargs))( + itir_node = im.let(self.visit(node.target, **kwargs), self.visit(node.value, **kwargs))( inner_expr ) - return_.location = node.location - return return_ + itir_node.location = node.location + return itir_node def visit_Symbol(self, node: foast.Symbol, **kwargs) -> itir.Sym: # TODO(tehrengruber): extend to more types @@ -227,26 +227,26 @@ def visit_Symbol(self, node: foast.Symbol, **kwargs) -> itir.Sym: dtype = node.type.dtype.kind.name.lower() is_list = type_info.is_local_field(node.type) return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list), location=node.location) - return_ = im.sym(node.id) - return_.location = node.location - return return_ + itir_node = im.sym(node.id) + itir_node.location = node.location + return itir_node def visit_Name(self, node: foast.Name, **kwargs) -> itir.SymRef: return im.ref(node.id) def visit_Subscript(self, node: foast.Subscript, **kwargs) -> itir.Expr: - return_ = im.promote_to_lifted_stencil(lambda tuple_: im.tuple_get(node.index, tuple_))( + itir_node = im.promote_to_lifted_stencil(lambda tuple_: im.tuple_get(node.index, tuple_))( self.visit(node.value, **kwargs) ) - return_.location = node.location - return return_ + itir_node.location = node.location + return itir_node def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs) -> itir.Expr: - return_ = im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( + itir_node = im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( *[self.visit(el, **kwargs) for el in node.elts], ) - return_.location = node.location - return return_ + itir_node.location = node.location + return itir_node def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators @@ -279,9 +279,9 @@ def _visit_shift(self, node: foast.Call, **kwargs) -> itir.Expr: case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)): shift_offset = im.shift(offset_name, offset_index) case foast.Name(id=offset_name): - return_ = im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs)) - return_.location = node.location - return return_ + itir_node = im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs)) + itir_node.location = node.location + return itir_node case foast.Call(func=foast.Name(id="as_offset")): func_args = node.args[0] offset_dim = func_args.args[0] @@ -291,11 +291,11 @@ def _visit_shift(self, node: foast.Call, **kwargs) -> itir.Expr: ) case _: raise FieldOperatorLoweringError("Unexpected shift arguments!") - return_ = im.lift(im.lambda_("it")(im.deref(shift_offset("it"))))( + itir_node = im.lift(im.lambda_("it")(im.deref(shift_offset("it"))))( self.visit(node.func, **kwargs) ) - return_.location = node.location - return return_ + itir_node.location = node.location + return itir_node def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: if type_info.type_class(node.func.type) is ts.FieldType: @@ -326,13 +326,13 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: ) call_args = [f"__arg{i}" for i in range(len(lowered_args))] call_kwargs = [f"__kwarg_{name}" for name in lowered_kwargs.keys()] - return_ = im.lift( + itir_node = im.lift( im.lambda_(*call_args, *call_kwargs)( im.call(lowered_func)(*call_args, *call_kwargs) ) )(*lowered_args, *lowered_kwargs.values()) - return_.location = node.location - return return_ + itir_node.location = node.location + return itir_node elif isinstance(node.func.type, ts.FunctionType): # ITIR has no support for keyword arguments. Instead, we concatenate both positional # and keyword arguments and use the unique order as given in the function signature. @@ -342,11 +342,11 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: self.visit(node.kwargs, **kwargs), use_signature_ordering=True, ) - return_ = im.call(self.visit(node.func, **kwargs))( + itir_node = im.call(self.visit(node.func, **kwargs))( *lowered_args, *lowered_kwargs.values() ) - return_.location = node.location - return return_ + itir_node.location = node.location + return itir_node raise AssertionError( f"Call to object of type {type(node.func.type).__name__} not understood." @@ -355,19 +355,19 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) obj, new_type = node.args[0], node.args[1].id - return_ = self._process_elements( + itir_node = self._process_elements( lambda x: im.call("cast_")(x, str(new_type)), obj, obj.type, **kwargs ) - return_.location = node.location - return return_ + itir_node.location = node.location + return itir_node def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall: return self._map("if_", *node.args, location=node.location) def _visit_broadcast(self, node: foast.Call, **kwargs) -> itir.FunCall: - return_ = self.visit(node.args[0], **kwargs) - return_.location = node.location - return return_ + itir_node = self.visit(node.args[0], **kwargs) + itir_node.location = node.location + return itir_node def _visit_math_built_in(self, node: foast.Call, **kwargs) -> itir.FunCall: return self._map(self.visit(node.func, **kwargs), *node.args, location=node.location) @@ -383,9 +383,9 @@ def _make_reduction_expr( it = self.visit(node.args[0], **kwargs) assert isinstance(node.kwargs["axis"].type, ts.DimensionType) val = im.call(im.call("reduce")(op, im.deref(init_expr))) - return_ = im.promote_to_lifted_stencil(val)(it) - return_.location = node.location - return return_ + itir_node = im.promote_to_lifted_stencil(val)(it) + itir_node.location = node.location + return itir_node def _visit_neighbor_sum(self, node: foast.Call, **kwargs) -> itir.FunCall: dtype = type_info.extract_dtype(node.type) @@ -409,16 +409,16 @@ def _visit_type_constr(self, node: foast.Call, **kwargs) -> itir.Expr: target_type = fbuiltins.BUILTINS[node_kind] source_type = {**fbuiltins.BUILTINS, "string": str}[node.args[0].type.__str__().lower()] if target_type is bool and source_type is not bool: - return_ = im.promote_to_const_iterator( + itir_node = im.promote_to_const_iterator( im.literal(str(bool(source_type(node.args[0].value))), "bool") ) - return_.location = node.location - return return_ - return_ = im.promote_to_const_iterator( + itir_node.location = node.location + return itir_node + itir_node = im.promote_to_const_iterator( im.literal(str(bool(source_type(node.args[0].value))), "bool") ) - return_.location = node.location - return return_ + itir_node.location = node.location + return itir_node raise FieldOperatorLoweringError(f"Encountered a type cast, which is not supported: {node}") def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: @@ -439,9 +439,9 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: raise ValueError(f"Unsupported literal type {type_}.") def visit_Constant(self, node: foast.Constant, **kwargs) -> itir.Expr: - return_ = self._make_literal(node.value, node.type) - return_.location = node.location - return return_ + itir_node = self._make_literal(node.value, node.type) + itir_node.location = node.location + return itir_node def _map(self, op, *args, location=None, **kwargs): lowered_args = [self.visit(arg, **kwargs) for arg in args] @@ -449,9 +449,9 @@ def _map(self, op, *args, location=None, **kwargs): lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] op = im.call("map_")(op) - return_ = im.promote_to_lifted_stencil(im.call(op))(*lowered_args) - return_.location = location - return return_ + itir_node = im.promote_to_lifted_stencil(im.call(op))(*lowered_args) + itir_node.location = location + return itir_node def _process_elements( self, diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 81dbfda743..6aad336198 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -185,9 +185,9 @@ def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: if isinstance(node, past.Name): return itir.SymRef(id=node.id, location=node.location) elif isinstance(node, past.Subscript): - return_ = self._construct_itir_out_arg(node.value) - return_.location = node.location - return return_ + itir_node = self._construct_itir_out_arg(node.value) + itir_node.location = node.location + return itir_node elif isinstance(node, past.TupleExpr): return itir.FunCall( fun=itir.SymRef(id="make_tuple"), @@ -273,10 +273,10 @@ def _construct_itir_initialized_domain_arg( f"Expected {dim}, but got {keys_dims_types} " ) - return_ = [self.visit(bound) for bound in node_domain.values_[dim_i].elts] + itir_node = [self.visit(bound) for bound in node_domain.values_[dim_i].elts] for i, bound in enumerate(node_domain.values_[dim_i].elts): - return_[i].location = bound.location - return return_ + itir_node[i].location = bound.location + return itir_node @staticmethod def _compute_field_slice(node: past.Subscript): diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 3950745c60..b7016ff662 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -29,7 +29,7 @@ def sym(sym_or_name: Union[str, itir.Sym]) -> itir.Sym: Sym(location=None, id=SymbolName('a'), kind=None, dtype=None) >>> sym(itir.Sym(id="b")) - Sym(id=SymbolName('b'), kind=None, dtype=None) + Sym(location=None, id=SymbolName('b'), kind=None, dtype=None) """ if isinstance(sym_or_name, itir.Sym): return sym_or_name @@ -46,7 +46,7 @@ def ref(ref_or_name: Union[str, itir.SymRef]) -> itir.SymRef: SymRef(location=None, id=SymbolRef('a')) >>> ref(itir.SymRef(id="b")) - SymRef(id=SymbolRef('b')) + SymRef(location=None, id=SymbolRef('b')) """ if isinstance(ref_or_name, itir.SymRef): return ref_or_name @@ -63,10 +63,10 @@ def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> iti SymRef(location=None, id=SymbolRef('a')) >>> ensure_expr(3) - Literal(value='3', type='int32') + Literal(location=None, value='3', type='int32') >>> ensure_expr(itir.OffsetLiteral(value="i")) - OffsetLiteral(value='i') + OffsetLiteral(location=None, value='i') """ if isinstance(literal_or_expr, str): return ref(literal_or_expr) @@ -86,7 +86,7 @@ def ensure_offset(str_or_offset: Union[str, int, itir.OffsetLiteral]) -> itir.Of OffsetLiteral(location=None, value='V2E') >>> ensure_offset(itir.OffsetLiteral(value="J")) - OffsetLiteral(value='J') + OffsetLiteral(location=None, value='J') """ if isinstance(str_or_offset, (str, int)): return itir.OffsetLiteral(value=str_or_offset) @@ -267,7 +267,7 @@ def shift(offset, value=None): FunCall(location=None, fun=FunCall(location=None, fun=SymRef(location=None, id=SymbolRef('shift')), args=[OffsetLiteral(location=None, value='i'), OffsetLiteral(location=None, value=0)]), args=[SymRef(location=None, id=SymbolRef('a'))]) >>> shift("V2E")("b") - FunCall(fun=FunCall(fun=SymRef(id=SymbolRef('shift')), args=[OffsetLiteral(value='V2E')]), args=[SymRef(id=SymbolRef('b'))]) + FunCall(location=None, fun=FunCall(location=None, fun=SymRef(location=None, id=SymbolRef('shift')), args=[OffsetLiteral(location=None, value='V2E')]), args=[SymRef(location=None, id=SymbolRef('b'))]) """ offset = ensure_offset(offset) args = [offset] @@ -288,11 +288,11 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: >>> literal_from_value(1.) Literal(location=None, value='1.0', type='float64') >>> literal_from_value(1) - Literal(value='1', type='int32') + Literal(location=None, value='1', type='int32') >>> literal_from_value(2147483648) - Literal(value='2147483648', type='int64') + Literal(location=None, value='2147483648', type='int64') >>> literal_from_value(True) - Literal(value='True', type='bool') + Literal(location=None, value='True', type='bool') """ if not isinstance(val, core_defs.Scalar): # type: ignore[arg-type] # mypy bug #11673 raise ValueError(f"Value must be a scalar, but got {type(val).__name__}") diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 41c354214f..f9f39f4bd7 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -219,10 +219,11 @@ def build_sdfg_from_itir( program = preprocess_program(program, offset_provider, lift_mode) sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, on_gpu) sdfg = sdfg_genenerator.visit(program) - # for nested_sdfg in sdfg.all_sdfgs_recursive(): - # if not nested_sdfg.debuginfo: - # warnings.warn(f"{nested_sdfg} does not have debuginfo. - # Consider adding them in the corresponding nested sdfg.") + for nested_sdfg in sdfg.all_sdfgs_recursive(): + if not nested_sdfg.debuginfo: + warnings.warn( + f"{nested_sdfg} does not have debuginfo. Consider adding them in the corresponding nested sdfg." + ) sdfg.simplify() # run DaCe auto-optimization heuristics From 7baedb8d4c99c3df80b803fca9a6e86af7037378 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Tue, 5 Dec 2023 19:55:32 +0100 Subject: [PATCH 12/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): WIP --- src/gt4py/next/ffront/past_to_itir.py | 1 - .../next/iterator/transforms/global_tmps.py | 4 +- .../next/iterator/transforms/inline_lifts.py | 8 +++- .../runners/dace_iterator/itir_to_tasklet.py | 39 ++++++++++++------- 4 files changed, 33 insertions(+), 19 deletions(-) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 6aad336198..8ec96fcec9 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -244,7 +244,6 @@ def _construct_itir_domain_arg( itir.FunCall( fun=itir.SymRef(id="named_range"), args=[itir.AxisLiteral(value=dim.value), lower, upper], - location=out_field.location, ) ) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index c1b7a64ab9..a701a2ea31 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -143,9 +143,11 @@ def canonicalize_applied_lift(closure_params: list[str], node: ir.FunCall) -> ir if any(not isinstance(it_arg, ir.SymRef) for it_arg in it_args): used_closure_params = collect_symbol_refs(node) assert not (set(used_closure_params) - set(closure_params)) - return im.lift(im.lambda_(*used_closure_params)(im.call(stencil)(*it_args)))( + itir_node = im.lift(im.lambda_(*used_closure_params)(im.call(stencil)(*it_args)))( *used_closure_params ) + itir_node.location = node.location + return itir_node return node diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index 24e671b190..f7a3abf897 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -103,7 +103,9 @@ def _transform_and_extract_lift_args( extracted_args[new_symbol] = arg new_args.append(ir.SymRef(id=new_symbol.id)) - return (im.lift(inner_stencil)(*new_args), extracted_args) + itir_node = im.lift(inner_stencil)(*new_args) + itir_node.location = node.location + return (itir_node, extracted_args) # TODO(tehrengruber): This pass has many different options that should be written as dedicated @@ -257,6 +259,8 @@ def visit_FunCall( new_stencil = im.lambda_(*new_arg_exprs.keys())(inlined_call) new_stencil.location = node.location - return im.lift(new_stencil)(*new_arg_exprs.values()) + itir_node = im.lift(new_stencil)(*new_arg_exprs.values()) + itir_node.location = node.location + return itir_node return node diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 8cd2fc0e08..de71138181 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -303,13 +303,15 @@ def builtin_can_deref( # Returning a SymbolExpr would be preferable, but it requires update to type-checking. result_name = unique_var_name() transformer.context.body.add_scalar(result_name, dace.dtypes.bool, transient=True) - result_node = transformer.context.state.add_access(result_name) + result_node = transformer.context.state.add_access(result_name, debuginfo=di) transformer.context.state.add_edge( - transformer.context.state.add_tasklet("can_always_deref", {}, {"_out"}, "_out = True"), + transformer.context.state.add_tasklet( + "can_always_deref", {}, {"_out"}, "_out = True", debuginfo=di + ), "_out", result_node, None, - dace.Memlet.simple(result_name, "0"), + dace.Memlet.simple(result_name, "0", debuginfo=di), ) return [ValueExpr(result_node, dace.dtypes.bool)] @@ -353,6 +355,7 @@ def builtin_if( def builtin_list_get( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) args = list(itertools.chain(*transformer.visit(node_args))) assert len(args) == 2 # index node @@ -367,7 +370,9 @@ def builtin_list_get( arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args ] expr = f"{internals[1]}[{internals[0]}]" - return transformer.add_expr_tasklet(expr_args, expr, args[1].dtype, "list_get") + return transformer.add_expr_tasklet( + expr_args, expr, args[1].dtype, "list_get", dace_debuginfo=di + ) def builtin_cast( @@ -511,7 +516,7 @@ def visit_SymRef(self, node: itir.SymRef): if param not in _GENERAL_BUILTIN_MAPPING and param not in self._symbol_map: node_type = self._node_types[id(node)] assert isinstance(node_type, Val) - access_node = self._state.add_access(param) + access_node = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) self._symbol_map[param] = ValueExpr( access_node, dtype=itir_type_as_dace_type(node_type.dtype) ) @@ -778,6 +783,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: # we create a nested sdfg in order to access the index scalar values as symbols in a memlet subset deref_sdfg = dace.SDFG("deref") + deref_sdfg.debuginfo = di deref_sdfg.add_array( "_inp", field_array.shape, iterator.dtype, strides=field_array.strides ) @@ -836,7 +842,9 @@ def _split_shift_args( def _make_shift_for_rest(self, rest, iterator): return itir.FunCall( - fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), args=[iterator] + fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), + args=[iterator], + location=iterator.location, ) def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: @@ -982,7 +990,9 @@ def _visit_reduce(self, node: itir.FunCall): reduce_input_name, nreduce_shape, reduce_dtype, transient=True ) - lambda_node = itir.Lambda(expr=fun_node.expr.args[1], params=fun_node.params[1:]) + lambda_node = itir.Lambda( + expr=fun_node.expr.args[1], params=fun_node.params[1:], location=node.location + ) lambda_context, inner_inputs, inner_outputs = self.visit( lambda_node, args=args, use_neighbor_tables=False ) @@ -1039,9 +1049,6 @@ def _visit_reduce(self, node: itir.FunCall): def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: assert isinstance(node.fun, itir.SymRef) fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] - for arg in node.args: - if hasattr(arg, "location"): - arg.location = node.location args: list[SymbolExpr | ValueExpr] = list( itertools.chain(*[self.visit(arg) for arg in node.args]) ) @@ -1060,7 +1067,7 @@ def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: expr, type_, "numeric", - dace_debuginfo=dace_debuginfo(node), + dace_debuginfo=dace_debuginfo(node, self.context.body.debuginfo), ) def _visit_general_builtin(self, node: itir.FunCall) -> list[ValueExpr]: @@ -1079,7 +1086,7 @@ def add_expr_tasklet( di = dace_debuginfo if dace_debuginfo else self.context.body.debuginfo result_name = unique_var_name() self.context.body.add_scalar(result_name, result_type, transient=True) - result_access = self.context.state.add_access(result_name, debuginfo=dace_debuginfo) + result_access = self.context.state.add_access(result_name, debuginfo=di) expr_tasklet = self.context.state.add_tasklet( name=name, @@ -1163,10 +1170,12 @@ def closure_to_tasklet_sdfg( if is_scan(node.stencil): stencil = cast(FunCall, node.stencil) assert isinstance(stencil.args[0], Lambda) - lambda_node = itir.Lambda(expr=stencil.args[0].expr, params=stencil.args[0].params) - fun_node = itir.FunCall(fun=lambda_node, args=args) + lambda_node = itir.Lambda( + expr=stencil.args[0].expr, params=stencil.args[0].params, location=node.location + ) + fun_node = itir.FunCall(fun=lambda_node, args=args, location=node.location) else: - fun_node = itir.FunCall(fun=node.stencil, args=args) + fun_node = itir.FunCall(fun=node.stencil, args=args, location=node.location) results = translator.visit(fun_node) for r in results: From 223de4eeda4f5aa2472efb3cb44eb6248031b83f Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Wed, 13 Dec 2023 11:47:15 +0100 Subject: [PATCH 13/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors [WIP] --- src/gt4py/eve/visitors.py | 8 ++ src/gt4py/next/ffront/foast_to_itir.py | 115 +++++------------- src/gt4py/next/ffront/past_to_itir.py | 14 +-- .../iterator/transforms/collapse_list_get.py | 5 +- .../iterator/transforms/collapse_tuple.py | 5 +- .../iterator/transforms/constant_folding.py | 4 +- src/gt4py/next/iterator/transforms/cse.py | 6 +- .../next/iterator/transforms/eta_reduction.py | 3 +- .../next/iterator/transforms/global_tmps.py | 14 +-- .../iterator/transforms/inline_fundefs.py | 6 +- .../iterator/transforms/inline_into_scan.py | 10 +- .../iterator/transforms/inline_lambdas.py | 3 +- .../next/iterator/transforms/inline_lifts.py | 15 +-- .../next/iterator/transforms/merge_let.py | 4 +- .../iterator/transforms/normalize_shifts.py | 4 +- .../iterator/transforms/propagate_deref.py | 4 +- .../transforms/prune_closure_inputs.py | 4 +- .../next/iterator/transforms/remap_symbols.py | 16 +-- .../iterator/transforms/scan_eta_reduction.py | 10 +- .../iterator/transforms/symbol_ref_utils.py | 3 +- .../next/iterator/transforms/trace_shifts.py | 3 +- .../next/iterator/transforms/unroll_reduce.py | 7 +- 22 files changed, 98 insertions(+), 165 deletions(-) diff --git a/src/gt4py/eve/visitors.py b/src/gt4py/eve/visitors.py index fe5f9e1474..769576ed80 100644 --- a/src/gt4py/eve/visitors.py +++ b/src/gt4py/eve/visitors.py @@ -196,3 +196,11 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: ) return copy.deepcopy(node, memo=memo) + + +class PreserveLocation(NodeVisitor): + def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: + result = super().visit(node, **kwargs) + if hasattr(node, "location") and hasattr(result, "location"): + result.location = node.location + return result diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index f7bfd4a826..77acda2efa 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -17,6 +17,7 @@ from gt4py.eve import NodeTranslator from gt4py.eve.utils import UIDGenerator +from gt4py.eve.visitors import PreserveLocation from gt4py.next.ffront import ( dialect_ast_enums, fbuiltins, @@ -39,7 +40,7 @@ def promote_to_list( @dataclasses.dataclass -class FieldOperatorLowering(NodeTranslator): +class FieldOperatorLowering(PreserveLocation, NodeTranslator): """ Lower FieldOperator AST (FOAST) to Iterator IR (ITIR). @@ -49,7 +50,6 @@ class FieldOperatorLowering(NodeTranslator): Examples -------- >>> from gt4py.next.ffront.func_to_foast import FieldOperatorParser - >>> from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering >>> from gt4py.next import Field, Dimension, float64 >>> >>> IDim = Dimension("IDim") @@ -80,7 +80,6 @@ def visit_FunctionDefinition( id=node.id, params=params, expr=self.visit_BlockStmt(node.body, inner_expr=None), - location=node.location, ) # `expr` is a lifted stencil def visit_FieldOperator(self, node: foast.FieldOperator, **kwargs) -> itir.FunctionDefinition: @@ -91,7 +90,6 @@ def visit_FieldOperator(self, node: foast.FieldOperator, **kwargs) -> itir.Funct id=func_definition.id, params=func_definition.params, expr=new_body, - location=node.location, ) def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> itir.FunctionDefinition: @@ -114,9 +112,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> itir.Functio func_definition.params[0].id, im.promote_to_const_iterator(func_definition.params[0].id), )(im.deref(new_body)) - definition = itir.Lambda( - params=func_definition.params, expr=new_body, location=node.location - ) + definition = itir.Lambda(params=func_definition.params, expr=new_body) body = im.call(im.call("scan")(definition, forward, init))( *(param.id for param in definition.params[1:]) ) @@ -125,7 +121,6 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> itir.Functio id=node.id, params=definition.params[1:], expr=body, - location=node.location, ) def visit_Stmt(self, node: foast.Stmt, **kwargs): @@ -134,9 +129,7 @@ def visit_Stmt(self, node: foast.Stmt, **kwargs): def visit_Return( self, node: foast.Return, *, inner_expr: Optional[itir.Expr], **kwargs ) -> itir.Expr: - itir_node = self.visit(node.value, **kwargs) - itir_node.location = node.location - return itir_node + return self.visit(node.value, **kwargs) def visit_BlockStmt( self, node: foast.BlockStmt, *, inner_expr: Optional[itir.Expr], **kwargs @@ -144,7 +137,6 @@ def visit_BlockStmt( for stmt in reversed(node.stmts): inner_expr = self.visit(stmt, inner_expr=inner_expr, **kwargs) assert inner_expr - inner_expr.location = node.location return inner_expr def visit_IfStmt( @@ -176,11 +168,9 @@ def visit_IfStmt( inner_expr = im.let(sym, im.tuple_get(i, im.ref("__if_stmt_result")))(inner_expr) # here we assume neither branch returns - itir_node = im.let( - "__if_stmt_result", im.if_(im.deref(cond), true_branch, false_branch) - )(inner_expr) - itir_node.location = node.location - return itir_node + return im.let("__if_stmt_result", im.if_(im.deref(cond), true_branch, false_branch))( + inner_expr + ) elif return_kind is StmtReturnKind.CONDITIONAL_RETURN: common_syms = tuple(im.sym(sym) for sym in common_symbols.keys()) common_symrefs = tuple(im.ref(sym) for sym in common_symbols.keys()) @@ -194,11 +184,9 @@ def visit_IfStmt( true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - itir_node = im.let(inner_expr_name, inner_expr_evaluator)( + return im.let(inner_expr_name, inner_expr_evaluator)( im.if_(im.deref(cond), true_branch, false_branch) ) - itir_node.location = node.location - return itir_node assert return_kind is StmtReturnKind.UNCONDITIONAL_RETURN @@ -207,18 +195,14 @@ def visit_IfStmt( true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - itir_node = im.if_(im.deref(cond), true_branch, false_branch) - itir_node.location = node.location - return itir_node + return im.if_(im.deref(cond), true_branch, false_branch) def visit_Assign( self, node: foast.Assign, *, inner_expr: Optional[itir.Expr], **kwargs ) -> itir.Expr: - itir_node = im.let(self.visit(node.target, **kwargs), self.visit(node.value, **kwargs))( + return im.let(self.visit(node.target, **kwargs), self.visit(node.value, **kwargs))( inner_expr ) - itir_node.location = node.location - return itir_node def visit_Symbol(self, node: foast.Symbol, **kwargs) -> itir.Sym: # TODO(tehrengruber): extend to more types @@ -226,27 +210,21 @@ def visit_Symbol(self, node: foast.Symbol, **kwargs) -> itir.Sym: kind = "Iterator" dtype = node.type.dtype.kind.name.lower() is_list = type_info.is_local_field(node.type) - return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list), location=node.location) - itir_node = im.sym(node.id) - itir_node.location = node.location - return itir_node + return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list)) + return im.sym(node.id) def visit_Name(self, node: foast.Name, **kwargs) -> itir.SymRef: return im.ref(node.id) def visit_Subscript(self, node: foast.Subscript, **kwargs) -> itir.Expr: - itir_node = im.promote_to_lifted_stencil(lambda tuple_: im.tuple_get(node.index, tuple_))( + return im.promote_to_lifted_stencil(lambda tuple_: im.tuple_get(node.index, tuple_))( self.visit(node.value, **kwargs) ) - itir_node.location = node.location - return itir_node def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs) -> itir.Expr: - itir_node = im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( + return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))( *[self.visit(el, **kwargs) for el in node.elts], ) - itir_node.location = node.location - return itir_node def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators @@ -254,34 +232,29 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> itir.Expr: if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: if dtype.kind != ts.ScalarKind.BOOL: raise NotImplementedError(f"{node.op} is only supported on `bool`s.") - return self._map("not_", node.operand, location=node.location) + return self._map("not_", node.operand) return self._map( node.op.value, foast.Constant(value="0", type=dtype, location=node.location), node.operand, - location=node.location, ) def visit_BinOp(self, node: foast.BinOp, **kwargs) -> itir.FunCall: - return self._map(node.op.value, node.left, node.right, location=node.location) + return self._map(node.op.value, node.left, node.right) def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs) -> itir.FunCall: - return self._map( - "if_", node.condition, node.true_expr, node.false_expr, location=node.location - ) + return self._map("if_", node.condition, node.true_expr, node.false_expr) def visit_Compare(self, node: foast.Compare, **kwargs) -> itir.FunCall: - return self._map(node.op.value, node.left, node.right, location=node.location) + return self._map(node.op.value, node.left, node.right) def _visit_shift(self, node: foast.Call, **kwargs) -> itir.Expr: match node.args[0]: case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)): shift_offset = im.shift(offset_name, offset_index) case foast.Name(id=offset_name): - itir_node = im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs)) - itir_node.location = node.location - return itir_node + return im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs)) case foast.Call(func=foast.Name(id="as_offset")): func_args = node.args[0] offset_dim = func_args.args[0] @@ -291,11 +264,9 @@ def _visit_shift(self, node: foast.Call, **kwargs) -> itir.Expr: ) case _: raise FieldOperatorLoweringError("Unexpected shift arguments!") - itir_node = im.lift(im.lambda_("it")(im.deref(shift_offset("it"))))( + return im.lift(im.lambda_("it")(im.deref(shift_offset("it"))))( self.visit(node.func, **kwargs) ) - itir_node.location = node.location - return itir_node def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: if type_info.type_class(node.func.type) is ts.FieldType: @@ -326,13 +297,11 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: ) call_args = [f"__arg{i}" for i in range(len(lowered_args))] call_kwargs = [f"__kwarg_{name}" for name in lowered_kwargs.keys()] - itir_node = im.lift( + return im.lift( im.lambda_(*call_args, *call_kwargs)( im.call(lowered_func)(*call_args, *call_kwargs) ) )(*lowered_args, *lowered_kwargs.values()) - itir_node.location = node.location - return itir_node elif isinstance(node.func.type, ts.FunctionType): # ITIR has no support for keyword arguments. Instead, we concatenate both positional # and keyword arguments and use the unique order as given in the function signature. @@ -342,11 +311,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: self.visit(node.kwargs, **kwargs), use_signature_ordering=True, ) - itir_node = im.call(self.visit(node.func, **kwargs))( - *lowered_args, *lowered_kwargs.values() - ) - itir_node.location = node.location - return itir_node + return im.call(self.visit(node.func, **kwargs))(*lowered_args, *lowered_kwargs.values()) raise AssertionError( f"Call to object of type {type(node.func.type).__name__} not understood." @@ -355,22 +320,18 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) obj, new_type = node.args[0], node.args[1].id - itir_node = self._process_elements( + return self._process_elements( lambda x: im.call("cast_")(x, str(new_type)), obj, obj.type, **kwargs ) - itir_node.location = node.location - return itir_node def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall: - return self._map("if_", *node.args, location=node.location) + return self._map("if_", *node.args) def _visit_broadcast(self, node: foast.Call, **kwargs) -> itir.FunCall: - itir_node = self.visit(node.args[0], **kwargs) - itir_node.location = node.location - return itir_node + return self.visit(node.args[0], **kwargs) def _visit_math_built_in(self, node: foast.Call, **kwargs) -> itir.FunCall: - return self._map(self.visit(node.func, **kwargs), *node.args, location=node.location) + return self._map(self.visit(node.func, **kwargs), *node.args) def _make_reduction_expr( self, @@ -383,9 +344,7 @@ def _make_reduction_expr( it = self.visit(node.args[0], **kwargs) assert isinstance(node.kwargs["axis"].type, ts.DimensionType) val = im.call(im.call("reduce")(op, im.deref(init_expr))) - itir_node = im.promote_to_lifted_stencil(val)(it) - itir_node.location = node.location - return itir_node + return im.promote_to_lifted_stencil(val)(it) def _visit_neighbor_sum(self, node: foast.Call, **kwargs) -> itir.FunCall: dtype = type_info.extract_dtype(node.type) @@ -409,16 +368,10 @@ def _visit_type_constr(self, node: foast.Call, **kwargs) -> itir.Expr: target_type = fbuiltins.BUILTINS[node_kind] source_type = {**fbuiltins.BUILTINS, "string": str}[node.args[0].type.__str__().lower()] if target_type is bool and source_type is not bool: - itir_node = im.promote_to_const_iterator( + return im.promote_to_const_iterator( im.literal(str(bool(source_type(node.args[0].value))), "bool") ) - itir_node.location = node.location - return itir_node - itir_node = im.promote_to_const_iterator( - im.literal(str(bool(source_type(node.args[0].value))), "bool") - ) - itir_node.location = node.location - return itir_node + return im.promote_to_const_iterator(im.literal(str(node.args[0].value), node_kind)) raise FieldOperatorLoweringError(f"Encountered a type cast, which is not supported: {node}") def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: @@ -439,19 +392,15 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: raise ValueError(f"Unsupported literal type {type_}.") def visit_Constant(self, node: foast.Constant, **kwargs) -> itir.Expr: - itir_node = self._make_literal(node.value, node.type) - itir_node.location = node.location - return itir_node + return self._make_literal(node.value, node.type) - def _map(self, op, *args, location=None, **kwargs): + def _map(self, op, *args, **kwargs): lowered_args = [self.visit(arg, **kwargs) for arg in args] if any(type_info.contains_local_field(arg.type) for arg in args): lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] op = im.call("map_")(op) - itir_node = im.promote_to_lifted_stencil(im.call(op))(*lowered_args) - itir_node.location = location - return itir_node + return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) def _process_elements( self, diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 8ec96fcec9..996bd4f41e 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -17,6 +17,7 @@ from typing import Optional, cast from gt4py.eve import NodeTranslator, concepts, traits +from gt4py.eve.visitors import PreserveLocation from gt4py.next.common import Dimension, DimensionKind, GridType from gt4py.next.ffront import program_ast as past, type_specifications as ts_ffront from gt4py.next.iterator import ir as itir @@ -40,7 +41,7 @@ def _flatten_tuple_expr( raise ValueError("Only `past.Name`, `past.Subscript` or `past.TupleExpr`s thereof are allowed.") -class ProgramLowering(traits.VisitorWithSymbolTableTrait, NodeTranslator): +class ProgramLowering(PreserveLocation, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ Lower Program AST (PAST) to Iterator IR (ITIR). @@ -125,7 +126,6 @@ def visit_Program( function_definitions=function_definitions, params=params, closures=closures, - location=node.location, ) def _visit_stencil_call(self, node: past.Call, **kwargs) -> itir.StencilClosure: @@ -355,12 +355,12 @@ def visit_Constant(self, node: past.Constant, **kwargs) -> itir.Literal: f"Scalars of kind {node.type.kind} not supported currently." ) typename = node.type.kind.name.lower() - return itir.Literal(value=str(node.value), type=typename, location=node.location) + return itir.Literal(value=str(node.value), type=typename) raise NotImplementedError("Only scalar literals supported currently.") def visit_Name(self, node: past.Name, **kwargs) -> itir.SymRef: - return itir.SymRef(id=node.id, location=node.location) + return itir.SymRef(id=node.id) def visit_Symbol(self, node: past.Symbol, **kwargs) -> itir.Sym: # TODO(tehrengruber): extend to more types @@ -368,14 +368,13 @@ def visit_Symbol(self, node: past.Symbol, **kwargs) -> itir.Sym: kind = "Iterator" dtype = node.type.dtype.kind.name.lower() is_list = type_info.is_local_field(node.type) - return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list), location=node.location) - return itir.Sym(id=node.id, location=node.location) + return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list)) + return itir.Sym(id=node.id) def visit_BinOp(self, node: past.BinOp, **kwargs) -> itir.FunCall: return itir.FunCall( fun=itir.SymRef(id=node.op.value), args=[self.visit(node.left, **kwargs), self.visit(node.right, **kwargs)], - location=node.location, ) def visit_Call(self, node: past.Call, **kwargs) -> itir.FunCall: @@ -383,7 +382,6 @@ def visit_Call(self, node: past.Call, **kwargs) -> itir.FunCall: return itir.FunCall( fun=itir.SymRef(id=node.func.id), args=[self.visit(node.args[0]), self.visit(node.args[1])], - location=node.location, ) else: raise AssertionError( diff --git a/src/gt4py/next/iterator/transforms/collapse_list_get.py b/src/gt4py/next/iterator/transforms/collapse_list_get.py index 4d35568b4d..79d7cf6f5f 100644 --- a/src/gt4py/next/iterator/transforms/collapse_list_get.py +++ b/src/gt4py/next/iterator/transforms/collapse_list_get.py @@ -13,10 +13,11 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py import eve +from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir -class CollapseListGet(eve.NodeTranslator): +class CollapseListGet(PreserveLocation, eve.NodeTranslator): """Simplifies expressions containing `list_get`. Examples @@ -49,10 +50,8 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: args=[it], ) ], - location=node.location, ) if node.args[1].fun == ir.SymRef(id="make_const_list"): - node.args[1].args[0].location = node.location return node.args[1].args[0] return node diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 393f781276..6cface6ac9 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -15,6 +15,7 @@ from typing import Optional from gt4py import eve +from gt4py.eve.visitors import PreserveLocation from gt4py.next import type_inference from gt4py.next.iterator import ir, type_inference as it_type_inference @@ -45,7 +46,7 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t @dataclass(frozen=True) -class CollapseTuple(eve.NodeTranslator): +class CollapseTuple(PreserveLocation, eve.NodeTranslator): """ Simplifies `make_tuple`, `tuple_get` calls. @@ -108,7 +109,6 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: if self.ignore_tuple_size or _get_tuple_size(first_expr, self._node_types) == len( node.args ): - first_expr.location = node.location return first_expr if ( self.collapse_tuple_get_make_tuple @@ -124,6 +124,5 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: assert idx < len( make_tuple_call.args ), f"Index {idx} is out of bounds for tuple of size {len(make_tuple_call.args)}" - node.args[1].args[idx].location = node.location return node.args[1].args[idx] return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 020f02475e..aa5fbf9eb3 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -13,11 +13,12 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator +from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import embedded, ir from gt4py.next.iterator.ir_utils import ir_makers as im -class ConstantFolding(NodeTranslator): +class ConstantFolding(PreserveLocation, NodeTranslator): @classmethod def apply(cls, node: ir.Node) -> ir.Node: return cls().visit(node) @@ -47,5 +48,4 @@ def visit_FunCall(self, node: ir.FunCall): arg_values = [getattr(embedded, str(arg.type))(arg.value) for arg in new_node.args] # type: ignore[attr-defined] # arg type already established in if condition new_node = im.literal_from_value(fun(*arg_values)) - new_node.location = node.location return new_node diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index c282673a92..c7000d86d0 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -19,12 +19,13 @@ from gt4py.eve import NodeTranslator, NodeVisitor, SymbolTableTrait, VisitorWithSymbolTableTrait from gt4py.eve.utils import UIDGenerator +from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda @dataclasses.dataclass -class _NodeReplacer(NodeTranslator): +class _NodeReplacer(PreserveLocation, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type",) expr_map: dict[int, ir.SymRef] @@ -341,7 +342,7 @@ def extract_subexpression( @dataclasses.dataclass(frozen=True) -class CommonSubexpressionElimination(NodeTranslator): +class CommonSubexpressionElimination(PreserveLocation, NodeTranslator): """ Perform common subexpression elimination. @@ -379,7 +380,6 @@ def visit_FunCall(self, node: ir.FunCall): result = ir.FunCall( fun=ir.Lambda(params=list(extracted.keys()), expr=new_expr), args=list(extracted.values()), - location=node.location, ) # if the node id is ignored (because its parent is eliminated), but it occurs diff --git a/src/gt4py/next/iterator/transforms/eta_reduction.py b/src/gt4py/next/iterator/transforms/eta_reduction.py index 99043490ba..0b47795961 100644 --- a/src/gt4py/next/iterator/transforms/eta_reduction.py +++ b/src/gt4py/next/iterator/transforms/eta_reduction.py @@ -13,10 +13,11 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator +from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir -class EtaReduction(NodeTranslator): +class EtaReduction(PreserveLocation, NodeTranslator): """Eta reduction: simplifies `λ(args...) → f(args...)` to `f`.""" def visit_Lambda(self, node: ir.Lambda) -> ir.Node: diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index a701a2ea31..115c8313a8 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -22,6 +22,7 @@ from gt4py.eve import Coerced, NodeTranslator from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator +from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir, type_inference from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift @@ -143,11 +144,9 @@ def canonicalize_applied_lift(closure_params: list[str], node: ir.FunCall) -> ir if any(not isinstance(it_arg, ir.SymRef) for it_arg in it_args): used_closure_params = collect_symbol_refs(node) assert not (set(used_closure_params) - set(closure_params)) - itir_node = im.lift(im.lambda_(*used_closure_params)(im.call(stencil)(*it_args)))( + return im.lift(im.lambda_(*used_closure_params)(im.call(stencil)(*it_args)))( *used_closure_params ) - itir_node.location = node.location - return itir_node return node @@ -268,7 +267,6 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp stencil=stencil, output=im.ref(tmp_sym.id), inputs=[closure_param_arg_mapping[param.id] for param in lift_expr.args], # type: ignore[attr-defined] - location=closure.location, ) ) @@ -296,7 +294,6 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp output=current_closure.output, inputs=current_closure.inputs + [ir.SymRef(id=sym.id) for sym in extracted_lifts.keys()], - location=closure.location, ) ) else: @@ -313,7 +310,6 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp ), params=node.params, tmps=[Temporary(id=tmp.id) for tmp in tmps], - location=node.location, ) @@ -340,7 +336,6 @@ def prune_unused_temporaries(node: FencilWithTemporaries) -> FencilWithTemporari ), params=node.params, tmps=[tmp for tmp in node.tmps if tmp.id not in unused_tmps], - location=node.location, ) @@ -458,7 +453,6 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An stencil=closure.stencil, output=closure.output, inputs=closure.inputs, - location=closure.location, ) else: domain = closure.domain @@ -514,7 +508,6 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An ), params=node.params, tmps=node.tmps, - location=node.location, ) @@ -565,14 +558,13 @@ def convert_type(dtype): tmps=[ Temporary(id=tmp.id, domain=domains[tmp.id], dtype=types[tmp.id]) for tmp in node.tmps ], - location=node.location, ) # TODO(tehrengruber): Add support for dynamic shifts (e.g. the distance is a symbol). This can be # tricky: For every lift statement that is dynamically shifted we can not compute bounds anymore # and hence also not extract as a temporary. -class CreateGlobalTmps(NodeTranslator): +class CreateGlobalTmps(PreserveLocation, NodeTranslator): """Main entry point for introducing global temporaries. Transforms an existing iterator IR fencil into a fencil with global temporaries. diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index 57e445a9bc..e03d4b7ed9 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -15,16 +15,16 @@ from typing import Any, Dict, Set from gt4py.eve import NOTHING, NodeTranslator +from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir -class InlineFundefs(NodeTranslator): +class InlineFundefs(PreserveLocation, NodeTranslator): def visit_SymRef(self, node: ir.SymRef, *, symtable: Dict[str, Any]): if node.id in symtable and isinstance((symbol := symtable[node.id]), ir.FunctionDefinition): return ir.Lambda( params=self.generic_visit(symbol.params, symtable=symtable), expr=self.generic_visit(symbol.expr, symtable=symtable), - location=node.location, ) return self.generic_visit(node) @@ -32,7 +32,7 @@ def visit_FencilDefinition(self, node: ir.FencilDefinition): return self.generic_visit(node, symtable=node.annex.symtable) -class PruneUnreferencedFundefs(NodeTranslator): +class PruneUnreferencedFundefs(PreserveLocation, NodeTranslator): def visit_FunctionDefinition( self, node: ir.FunctionDefinition, *, referenced: Set[str], second_pass: bool ): diff --git a/src/gt4py/next/iterator/transforms/inline_into_scan.py b/src/gt4py/next/iterator/transforms/inline_into_scan.py index b0f8d98bd6..e3e22f32a0 100644 --- a/src/gt4py/next/iterator/transforms/inline_into_scan.py +++ b/src/gt4py/next/iterator/transforms/inline_into_scan.py @@ -16,6 +16,7 @@ from gt4py import eve from gt4py.eve import NodeTranslator, traits +from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir from gt4py.next.iterator.transforms import symbol_ref_utils from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda @@ -53,7 +54,7 @@ def _lambda_and_lift_inliner(node: ir.FunCall) -> ir.FunCall: return inlined -class InlineIntoScan(traits.VisitorWithSymbolTableTrait, NodeTranslator): +class InlineIntoScan(PreserveLocation, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ Inline non-SymRef arguments into the scan. @@ -100,10 +101,5 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): new_scan = ir.FunCall( fun=ir.SymRef(id="scan"), args=[new_scanpass, *original_scan_call.args[1:]] ) - result = ir.FunCall( - fun=new_scan, - args=[ir.SymRef(id=ref) for ref in refs_in_args], - location=node.location, - ) - return result + return ir.FunCall(fun=new_scan, args=[ir.SymRef(id=ref) for ref in refs_in_args]) return self.generic_visit(node, **kwargs) diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index b82f4dfd0a..f5838764c9 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -16,6 +16,7 @@ from typing import Optional from gt4py.eve import NodeTranslator +from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols @@ -122,7 +123,7 @@ def new_name(name): @dataclasses.dataclass -class InlineLambdas(NodeTranslator): +class InlineLambdas(PreserveLocation, NodeTranslator): """Inline lambda calls by substituting every argument by its value.""" PRESERVED_ANNEX_ATTRS = ("type",) diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index f7a3abf897..1dfdd4c843 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -19,6 +19,7 @@ import gt4py.eve as eve from gt4py.eve import NodeTranslator, traits +from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda @@ -112,7 +113,7 @@ def _transform_and_extract_lift_args( # passes. Due to a lack of infrastructure (e.g. no pass manager) to combine passes without # performance degradation we leave everything as one pass for now. @dataclasses.dataclass -class InlineLifts(traits.VisitorWithSymbolTableTrait, NodeTranslator): +class InlineLifts(PreserveLocation, traits.VisitorWithSymbolTableTrait, NodeTranslator): """Inline lifted function calls. Optionally a predicate function can be passed which can enable or disable inlining of specific @@ -171,7 +172,7 @@ def visit_FunCall( self.visit(ir.FunCall(fun=shift, args=[arg]), recurse=False, **kwargs) for arg in lift_call.args # type: ignore[attr-defined] # lift_call already asserted to be of type ir.FunCall ] - result = ir.FunCall(fun=lift_call.fun, args=new_args, location=node.location) # type: ignore[attr-defined] # lift_call already asserted to be of type ir.FunCall + result = ir.FunCall(fun=lift_call.fun, args=new_args) # type: ignore[attr-defined] # lift_call already asserted to be of type ir.FunCall return self.visit(result, recurse=False, **kwargs) elif self.flags & self.Flag.INLINE_DEREF_LIFT and node.fun == ir.SymRef(id="deref"): assert len(node.args) == 1 @@ -188,7 +189,7 @@ def visit_FunCall( assert len(node.args[0].fun.args) == 1 f = node.args[0].fun.args[0] args = node.args[0].args - new_node = ir.FunCall(fun=f, args=args, location=node.location) + new_node = ir.FunCall(fun=f, args=args) if isinstance(f, ir.Lambda): new_node = inline_lambda(new_node, opcount_preserving=True) return self.visit(new_node, **kwargs) @@ -203,14 +204,13 @@ def visit_FunCall( assert len(node.args[0].fun.args) == 1 args = node.args[0].args if len(args) == 0: - return ir.Literal(value="True", type="bool", location=node.location) + return ir.Literal(value="True", type="bool") res = ir.FunCall(fun=ir.SymRef(id="can_deref"), args=[args[0]]) for arg in args[1:]: res = ir.FunCall( fun=ir.SymRef(id="and_"), args=[res, ir.FunCall(fun=ir.SymRef(id="can_deref"), args=[arg])], - location=node.location, ) return res elif ( @@ -258,9 +258,6 @@ def visit_FunCall( ) new_stencil = im.lambda_(*new_arg_exprs.keys())(inlined_call) - new_stencil.location = node.location - itir_node = im.lift(new_stencil)(*new_arg_exprs.values()) - itir_node.location = node.location - return itir_node + return im.lift(new_stencil)(*new_arg_exprs.values()) return node diff --git a/src/gt4py/next/iterator/transforms/merge_let.py b/src/gt4py/next/iterator/transforms/merge_let.py index b669b8d609..140d6574fa 100644 --- a/src/gt4py/next/iterator/transforms/merge_let.py +++ b/src/gt4py/next/iterator/transforms/merge_let.py @@ -13,11 +13,12 @@ # SPDX-License-Identifier: GPL-3.0-or-later import gt4py.eve as eve +from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs -class MergeLet(eve.NodeTranslator): +class MergeLet(PreserveLocation, eve.NodeTranslator): """ Merge let-like statements. @@ -64,6 +65,5 @@ def visit_FunCall(self, node: itir.FunCall): params=outer_lambda.params + inner_lambda.params, expr=inner_lambda.expr ), args=outer_lambda_args + inner_lambda_args, - location=node.location, ) return node diff --git a/src/gt4py/next/iterator/transforms/normalize_shifts.py b/src/gt4py/next/iterator/transforms/normalize_shifts.py index 6c63fe9c33..5545ad3231 100644 --- a/src/gt4py/next/iterator/transforms/normalize_shifts.py +++ b/src/gt4py/next/iterator/transforms/normalize_shifts.py @@ -13,10 +13,11 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator +from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir -class NormalizeShifts(NodeTranslator): +class NormalizeShifts(PreserveLocation, NodeTranslator): def visit_FunCall(self, node: ir.FunCall): node = self.generic_visit(node) if ( @@ -36,6 +37,5 @@ def visit_FunCall(self, node: ir.FunCall): fun=ir.SymRef(id="shift"), args=node.args[0].fun.args + node.fun.args ), args=node.args[0].args, - location=node.location, ) return node diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index 9384a692e8..21d3e333be 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -14,6 +14,7 @@ from gt4py.eve import NodeTranslator from gt4py.eve.pattern_matching import ObjectPattern as P +from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir @@ -22,7 +23,7 @@ # `(λ(...) → plus(multiplies(...), ...))(...)`. -class PropagateDeref(NodeTranslator): +class PropagateDeref(PreserveLocation, NodeTranslator): @classmethod def apply(cls, node: ir.Node): """ @@ -55,6 +56,5 @@ def visit_FunCall(self, node: ir.FunCall): expr=ir.FunCall(fun=builtin, args=[lambda_fun.expr]), ), args=lambda_args, - location=node.location, ) return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py index 3f39c44183..c67f7e4476 100644 --- a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py +++ b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py @@ -13,10 +13,11 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator +from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir -class PruneClosureInputs(NodeTranslator): +class PruneClosureInputs(PreserveLocation, NodeTranslator): """Removes all unused input arguments from a stencil closure.""" def visit_StencilClosure(self, node: ir.StencilClosure) -> ir.StencilClosure: @@ -37,7 +38,6 @@ def visit_StencilClosure(self, node: ir.StencilClosure) -> ir.StencilClosure: stencil=ir.Lambda(params=params, expr=expr), output=node.output, inputs=inputs, - location=node.location, ) def visit_SymRef(self, node: ir.SymRef, *, unused: set[str], shadowed: set[str]) -> ir.SymRef: diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index 84a57ee2e2..2a12d4dde0 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -15,10 +15,11 @@ from typing import Any, Dict, Optional, Set from gt4py.eve import NodeTranslator, SymbolTableTrait +from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir -class RemapSymbolRefs(NodeTranslator): +class RemapSymbolRefs(PreserveLocation, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type",) def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]): @@ -30,7 +31,6 @@ def visit_Lambda(self, node: ir.Lambda, *, symbol_map: Dict[str, ir.Node]): return ir.Lambda( params=node.params, expr=self.visit(node.expr, symbol_map=new_symbol_map), - location=node.location, ) def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override] @@ -40,27 +40,21 @@ def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override] return super().generic_visit(node, **kwargs) -class RenameSymbols(NodeTranslator): +class RenameSymbols(PreserveLocation, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type",) def visit_Sym( self, node: ir.Sym, *, name_map: Dict[str, str], active: Optional[Set[str]] = None ): if active and node.id in active: - return ir.Sym( - id=name_map.get(node.id, node.id), - location=node.location, - ) + return ir.Sym(id=name_map.get(node.id, node.id)) return node def visit_SymRef( self, node: ir.SymRef, *, name_map: Dict[str, str], active: Optional[Set[str]] = None ): if active and node.id in active: - return ir.SymRef( - id=name_map.get(node.id, node.id), - location=node.location, - ) + return ir.SymRef(id=name_map.get(node.id, node.id)) return node def generic_visit( # type: ignore[override] diff --git a/src/gt4py/next/iterator/transforms/scan_eta_reduction.py b/src/gt4py/next/iterator/transforms/scan_eta_reduction.py index 466b2817be..5684a0afb4 100644 --- a/src/gt4py/next/iterator/transforms/scan_eta_reduction.py +++ b/src/gt4py/next/iterator/transforms/scan_eta_reduction.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator +from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir @@ -24,7 +25,7 @@ def _is_scan(node: ir.Node): ) -class ScanEtaReduction(NodeTranslator): +class ScanEtaReduction(PreserveLocation, NodeTranslator): """Applies eta-reduction-like transformation involving scans. Simplifies `λ(x, y) → scan(λ(state, param_y, param_x) → ..., ...)(y, x)` to `scan(λ(state, param_x, param_y) → ..., ...)`. @@ -55,11 +56,8 @@ def visit_Lambda(self, node: ir.Lambda) -> ir.Node: original_scanpass.params[i + 1] for i in new_scanpass_params_idx ] new_scanpass = ir.Lambda(params=new_scanpass_params, expr=original_scanpass.expr) - result = ir.FunCall( - fun=ir.SymRef(id="scan"), - args=[new_scanpass, *node.expr.fun.args[1:]], - location=node.location, + return ir.FunCall( + fun=ir.SymRef(id="scan"), args=[new_scanpass, *node.expr.fun.args[1:]] ) - return result return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 1c587fb9d6..679becd737 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -17,11 +17,12 @@ from typing import Iterable, Optional, Sequence import gt4py.eve as eve +from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir as itir @dataclasses.dataclass -class CountSymbolRefs(eve.NodeVisitor): +class CountSymbolRefs(PreserveLocation, eve.NodeVisitor): ref_counts: dict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int)) @classmethod diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 5c607e7df1..5c7729e3fb 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -17,6 +17,7 @@ from typing import Any, Final, Iterable, Literal from gt4py.eve import NodeTranslator +from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir @@ -235,7 +236,7 @@ def _tuple_get(index, tuple_val): @dataclasses.dataclass(frozen=True) -class TraceShifts(NodeTranslator): +class TraceShifts(PreserveLocation, NodeTranslator): shift_recorder: ShiftRecorder = dataclasses.field(default_factory=ShiftRecorder) def visit_Literal(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index f2762a9cb0..e025464f4c 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -18,6 +18,7 @@ from gt4py.eve import NodeTranslator from gt4py.eve.utils import UIDGenerator +from gt4py.eve.visitors import PreserveLocation from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift @@ -129,7 +130,7 @@ def _make_list_get(offset: itir.Expr, expr: itir.Expr) -> itir.FunCall: @dataclasses.dataclass(frozen=True) -class UnrollReduce(NodeTranslator): +class UnrollReduce(PreserveLocation, NodeTranslator): # we use one UID generator per instance such that the generated ids are # stable across multiple runs (required for caching to properly work) uids: UIDGenerator = dataclasses.field(init=False, repr=False, default_factory=UIDGenerator) @@ -164,9 +165,7 @@ def _visit_reduce(self, node: itir.FunCall, **kwargs) -> itir.Expr: for i in range(max_neighbors): expr = itir.FunCall(fun=step, args=[expr, itir.OffsetLiteral(value=i)]) expr = itir.FunCall( - fun=itir.Lambda(params=[itir.Sym(id=step.id)], expr=expr), - args=[step_fun], - location=node.location, + fun=itir.Lambda(params=[itir.Sym(id=step.id)], expr=expr), args=[step_fun] ) return expr From 93fcb14243cef153bbe1df4e9c0e3d48734d21d3 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Wed, 13 Dec 2023 14:30:28 +0100 Subject: [PATCH 14/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors [WIP] --- src/gt4py/eve/traits.py | 8 ++++++++ src/gt4py/next/ffront/past_to_itir.py | 3 +-- src/gt4py/next/iterator/transforms/inline_into_scan.py | 3 +-- src/gt4py/next/iterator/transforms/inline_lifts.py | 3 +-- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/gt4py/eve/traits.py b/src/gt4py/eve/traits.py index df556c9d7f..fb3efd6412 100644 --- a/src/gt4py/eve/traits.py +++ b/src/gt4py/eve/traits.py @@ -172,3 +172,11 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: kwargs["symtable"] = kwargs["symtable"].parents return result + + +class PreserveLocationWithSymbolTableTrait(VisitorWithSymbolTableTrait): + def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: + result = super().visit(node, **kwargs) + if hasattr(node, "location") and hasattr(result, "location"): + result.location = node.location + return result diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index c2c9f17c4c..f8515364b5 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -17,7 +17,6 @@ from typing import Optional, cast from gt4py.eve import NodeTranslator, concepts, traits -from gt4py.eve.visitors import PreserveLocation from gt4py.next.common import Dimension, DimensionKind, GridType from gt4py.next.ffront import program_ast as past, type_specifications as ts_ffront from gt4py.next.iterator import ir as itir @@ -41,7 +40,7 @@ def _flatten_tuple_expr( raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.") -class ProgramLowering(PreserveLocation, traits.VisitorWithSymbolTableTrait, NodeTranslator): +class ProgramLowering(traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ Lower Program AST (PAST) to Iterator IR (ITIR). diff --git a/src/gt4py/next/iterator/transforms/inline_into_scan.py b/src/gt4py/next/iterator/transforms/inline_into_scan.py index e3e22f32a0..fc6eb767b0 100644 --- a/src/gt4py/next/iterator/transforms/inline_into_scan.py +++ b/src/gt4py/next/iterator/transforms/inline_into_scan.py @@ -16,7 +16,6 @@ from gt4py import eve from gt4py.eve import NodeTranslator, traits -from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir from gt4py.next.iterator.transforms import symbol_ref_utils from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda @@ -54,7 +53,7 @@ def _lambda_and_lift_inliner(node: ir.FunCall) -> ir.FunCall: return inlined -class InlineIntoScan(PreserveLocation, traits.VisitorWithSymbolTableTrait, NodeTranslator): +class InlineIntoScan(traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ Inline non-SymRef arguments into the scan. diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index 1dfdd4c843..65459a51f5 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -19,7 +19,6 @@ import gt4py.eve as eve from gt4py.eve import NodeTranslator, traits -from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda @@ -113,7 +112,7 @@ def _transform_and_extract_lift_args( # passes. Due to a lack of infrastructure (e.g. no pass manager) to combine passes without # performance degradation we leave everything as one pass for now. @dataclasses.dataclass -class InlineLifts(PreserveLocation, traits.VisitorWithSymbolTableTrait, NodeTranslator): +class InlineLifts(traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator): """Inline lifted function calls. Optionally a predicate function can be passed which can enable or disable inlining of specific From d6da4abe6afaae805609fc02b62b39f0ef6ca49e Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Wed, 13 Dec 2023 15:01:59 +0100 Subject: [PATCH 15/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors [WIP] --- src/gt4py/eve/visitors.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gt4py/eve/visitors.py b/src/gt4py/eve/visitors.py index 769576ed80..64c8859499 100644 --- a/src/gt4py/eve/visitors.py +++ b/src/gt4py/eve/visitors.py @@ -24,6 +24,7 @@ from . import concepts, trees from .extended_typing import Any from .type_definitions import NOTHING +from gt4py.next.ffront.field_operator_ast import Name class NodeVisitor: @@ -201,6 +202,6 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: class PreserveLocation(NodeVisitor): def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: result = super().visit(node, **kwargs) - if hasattr(node, "location") and hasattr(result, "location"): + if hasattr(node, "location") and hasattr(result, "location") and not isinstance(node, Name): result.location = node.location return result From 8460c67c3f4d5aa101d9a48e065de898f63d996a Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Wed, 13 Dec 2023 15:19:29 +0100 Subject: [PATCH 16/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors [WIP] --- src/gt4py/eve/visitors.py | 3 ++- src/gt4py/next/ffront/past_to_itir.py | 4 +++- src/gt4py/next/iterator/transforms/inline_into_scan.py | 4 +++- src/gt4py/next/iterator/transforms/inline_lifts.py | 4 +++- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/gt4py/eve/visitors.py b/src/gt4py/eve/visitors.py index 64c8859499..9b1ef1695f 100644 --- a/src/gt4py/eve/visitors.py +++ b/src/gt4py/eve/visitors.py @@ -21,10 +21,11 @@ import copy from typing import ClassVar +from gt4py.next.ffront.field_operator_ast import Name + from . import concepts, trees from .extended_typing import Any from .type_definitions import NOTHING -from gt4py.next.ffront.field_operator_ast import Name class NodeVisitor: diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index f8515364b5..d53383683a 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -40,7 +40,9 @@ def _flatten_tuple_expr( raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.") -class ProgramLowering(traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator): +class ProgramLowering( + traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator +): """ Lower Program AST (PAST) to Iterator IR (ITIR). diff --git a/src/gt4py/next/iterator/transforms/inline_into_scan.py b/src/gt4py/next/iterator/transforms/inline_into_scan.py index fc6eb767b0..d73a09c6e5 100644 --- a/src/gt4py/next/iterator/transforms/inline_into_scan.py +++ b/src/gt4py/next/iterator/transforms/inline_into_scan.py @@ -53,7 +53,9 @@ def _lambda_and_lift_inliner(node: ir.FunCall) -> ir.FunCall: return inlined -class InlineIntoScan(traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator): +class InlineIntoScan( + traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator +): """ Inline non-SymRef arguments into the scan. diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index 65459a51f5..821c543786 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -112,7 +112,9 @@ def _transform_and_extract_lift_args( # passes. Due to a lack of infrastructure (e.g. no pass manager) to combine passes without # performance degradation we leave everything as one pass for now. @dataclasses.dataclass -class InlineLifts(traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator): +class InlineLifts( + traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator +): """Inline lifted function calls. Optionally a predicate function can be passed which can enable or disable inlining of specific From 5632def34ee1e5e70de7563df99e8e48df7fc7a2 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Wed, 13 Dec 2023 15:27:54 +0100 Subject: [PATCH 17/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors [WIP] --- src/gt4py/eve/visitors.py | 4 +--- src/gt4py/next/ffront/foast_to_itir.py | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/gt4py/eve/visitors.py b/src/gt4py/eve/visitors.py index 9b1ef1695f..769576ed80 100644 --- a/src/gt4py/eve/visitors.py +++ b/src/gt4py/eve/visitors.py @@ -21,8 +21,6 @@ import copy from typing import ClassVar -from gt4py.next.ffront.field_operator_ast import Name - from . import concepts, trees from .extended_typing import Any from .type_definitions import NOTHING @@ -203,6 +201,6 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: class PreserveLocation(NodeVisitor): def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: result = super().visit(node, **kwargs) - if hasattr(node, "location") and hasattr(result, "location") and not isinstance(node, Name): + if hasattr(node, "location") and hasattr(result, "location"): result.location = node.location return result diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 6a7dbe7360..9c06c57821 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -15,9 +15,8 @@ import dataclasses from typing import Any, Callable, Optional -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, concepts, extended_typing from gt4py.eve.utils import UIDGenerator -from gt4py.eve.visitors import PreserveLocation from gt4py.next.ffront import ( dialect_ast_enums, fbuiltins, @@ -40,7 +39,7 @@ def promote_to_list( @dataclasses.dataclass -class FieldOperatorLowering(PreserveLocation, NodeTranslator): +class FieldOperatorLowering(NodeTranslator): """ Lower FieldOperator AST (FOAST) to Iterator IR (ITIR). @@ -72,6 +71,16 @@ class FieldOperatorLowering(PreserveLocation, NodeTranslator): def apply(cls, node: foast.LocatedNode) -> itir.Expr: return cls().visit(node) + def visit(self, node: concepts.RootNode, **kwargs: extended_typing.Any) -> extended_typing.Any: + result = super().visit(node, **kwargs) + if ( + hasattr(node, "location") + and hasattr(result, "location") + and not isinstance(node, foast.Name) + ): + result.location = node.location + return result + def visit_FunctionDefinition( self, node: foast.FunctionDefinition, **kwargs ) -> itir.FunctionDefinition: From b59fd8366e9fdcbafa847cb639a5c794ae76d2da Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Wed, 13 Dec 2023 17:14:50 +0100 Subject: [PATCH 18/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors [WIP] --- src/gt4py/next/iterator/transforms/cse.py | 12 ++++++++++-- src/gt4py/next/iterator/transforms/eta_reduction.py | 1 - src/gt4py/next/iterator/transforms/fuse_maps.py | 6 +++--- src/gt4py/next/iterator/transforms/global_tmps.py | 6 ++++++ src/gt4py/next/iterator/transforms/inline_lifts.py | 1 - 5 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 1eba621db3..16c4af23c7 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -17,7 +17,13 @@ import operator import typing -from gt4py.eve import NodeTranslator, NodeVisitor, SymbolTableTrait, VisitorWithSymbolTableTrait +from gt4py.eve import ( + NodeTranslator, + NodeVisitor, + SymbolTableTrait, + VisitorWithSymbolTableTrait, + traits, +) from gt4py.eve.utils import UIDGenerator from gt4py.eve.visitors import PreserveLocation from gt4py.next.iterator import ir @@ -73,7 +79,9 @@ def _is_collectable_expr(node: ir.Node) -> bool: @dataclasses.dataclass -class CollectSubexpressions(VisitorWithSymbolTableTrait, NodeVisitor): +class CollectSubexpressions( + traits.PreserveLocationWithSymbolTableTrait, VisitorWithSymbolTableTrait, NodeVisitor +): @dataclasses.dataclass class SubexpressionData: #: A list of node ids with equal hash and a set of collected child subexpression ids diff --git a/src/gt4py/next/iterator/transforms/eta_reduction.py b/src/gt4py/next/iterator/transforms/eta_reduction.py index 0b47795961..23a55c27af 100644 --- a/src/gt4py/next/iterator/transforms/eta_reduction.py +++ b/src/gt4py/next/iterator/transforms/eta_reduction.py @@ -29,7 +29,6 @@ def visit_Lambda(self, node: ir.Lambda) -> ir.Node: for p, a in zip(node.params, node.expr.args) ) ): - node.expr.fun.location = node.location return self.visit(node.expr.fun) return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/fuse_maps.py b/src/gt4py/next/iterator/transforms/fuse_maps.py index ea51adeca3..e132bb5012 100644 --- a/src/gt4py/next/iterator/transforms/fuse_maps.py +++ b/src/gt4py/next/iterator/transforms/fuse_maps.py @@ -38,7 +38,9 @@ def _is_reduce(node: ir.Node) -> TypeGuard[ir.FunCall]: @dataclasses.dataclass(frozen=True) -class FuseMaps(traits.VisitorWithSymbolTableTrait, NodeTranslator): +class FuseMaps( + traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator +): """ Fuses nested `map_`s. @@ -125,12 +127,10 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): return ir.FunCall( fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args, - location=node.location, ) else: # _is_reduce(node) return ir.FunCall( fun=ir.FunCall(fun=ir.SymRef(id="reduce"), args=[new_op, node.fun.args[1]]), args=new_args, - location=node.location, ) return node diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 115c8313a8..f0afbcb68d 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -267,6 +267,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp stencil=stencil, output=im.ref(tmp_sym.id), inputs=[closure_param_arg_mapping[param.id] for param in lift_expr.args], # type: ignore[attr-defined] + location=current_closure.location, ) ) @@ -294,6 +295,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp output=current_closure.output, inputs=current_closure.inputs + [ir.SymRef(id=sym.id) for sym in extracted_lifts.keys()], + location=current_closure.location, ) ) else: @@ -307,6 +309,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp + [ir.Sym(id=tmp.id) for tmp in tmps] + [ir.Sym(id=AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # value is a global constant closures=list(reversed(closures)), + location=node.location, ), params=node.params, tmps=[Temporary(id=tmp.id) for tmp in tmps], @@ -333,6 +336,7 @@ def prune_unused_temporaries(node: FencilWithTemporaries) -> FencilWithTemporari function_definitions=node.fencil.function_definitions, params=[p for p in node.fencil.params if p.id not in unused_tmps], closures=closures, + location=node.fencil.location, ), params=node.params, tmps=[tmp for tmp in node.tmps if tmp.id not in unused_tmps], @@ -453,6 +457,7 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An stencil=closure.stencil, output=closure.output, inputs=closure.inputs, + location=closure.location, ) else: domain = closure.domain @@ -505,6 +510,7 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An function_definitions=node.fencil.function_definitions, params=node.fencil.params[:-1], # remove `_gtmp_auto_domain` param again closures=list(reversed(closures)), + location=node.fencil.location, ), params=node.params, tmps=node.tmps, diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index 821c543786..ac7aa4ed73 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -159,7 +159,6 @@ def visit_FunCall( ir.FunCall( fun=self.generic_visit(node.fun, is_scan_pass_context=_is_scan(node), **kwargs), args=self.generic_visit(node.args, **kwargs), - location=node.location, ) if recurse else node From 68dde06a7f108519490bb21cccf3637997204a5f Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Wed, 13 Dec 2023 17:49:21 +0100 Subject: [PATCH 19/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors [WIP] --- .../program_processors/runners/dace_iterator/itir_to_tasklet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index b9d22d64d9..ff8b6e27d0 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -713,7 +713,7 @@ def _visit_call(self, node: itir.FunCall): inputs=set(nsdfg_inputs.keys()), outputs=set(r.value.data for r in results), symbol_mapping=symbol_mapping, - debuginfo=dace_debuginfo(node, self.context.body.debuginfo), + debuginfo=dace_debuginfo(node, func_context.body.debuginfo), ) for name, value in func_inputs: From a1a91c4d6ecfe964ea81299c46f46bc95fdecf0e Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Thu, 14 Dec 2023 11:47:22 +0100 Subject: [PATCH 20/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors [WIP] --- .../program_processors/runners/dace_iterator/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index cac91f2516..4a23b38d57 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -31,7 +31,12 @@ from gt4py.next.type_system import type_specifications as ts, type_translation from .itir_to_sdfg import ItirToSDFG -from .utility import connectivity_identifier, filter_neighbor_tables, get_sorted_dims +from .utility import ( + connectivity_identifier, + dace_debuginfo, + filter_neighbor_tables, + get_sorted_dims, +) try: @@ -224,6 +229,7 @@ def build_sdfg_from_itir( warnings.warn( f"{nested_sdfg} does not have debuginfo. Consider adding them in the corresponding nested sdfg." ) + nested_sdfg.debuginfo = dace_debuginfo(program) sdfg.simplify() # run DaCe auto-optimization heuristics From 1ed97643abe5dd3abd3d7f1cd544ed2cd471422d Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Thu, 14 Dec 2023 12:35:47 +0100 Subject: [PATCH 21/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors [WIP] --- .../runners/dace_iterator/__init__.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 4a23b38d57..fa9524955f 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -31,12 +31,7 @@ from gt4py.next.type_system import type_specifications as ts, type_translation from .itir_to_sdfg import ItirToSDFG -from .utility import ( - connectivity_identifier, - dace_debuginfo, - filter_neighbor_tables, - get_sorted_dims, -) +from .utility import connectivity_identifier, filter_neighbor_tables, get_sorted_dims try: @@ -229,7 +224,13 @@ def build_sdfg_from_itir( warnings.warn( f"{nested_sdfg} does not have debuginfo. Consider adding them in the corresponding nested sdfg." ) - nested_sdfg.debuginfo = dace_debuginfo(program) + nested_sdfg.debuginfo = dace.dtypes.DebugInfo( + start_line=0, + start_column=0, + end_line=-1, + end_column=0, + filename=None, + ) sdfg.simplify() # run DaCe auto-optimization heuristics From 371dc367728928e062a96edec569b84dedca26e3 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Thu, 14 Dec 2023 13:12:33 +0100 Subject: [PATCH 22/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors [WIP] --- .../ffront_tests/test_foast_to_itir.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py index 2dd4b91c48..0f9dbe58da 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py @@ -59,6 +59,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(copy_field) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None assert lowered.id == "copy_field" assert lowered.expr == im.ref("inp") @@ -70,6 +71,7 @@ def scalar_arg(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], parsed = FieldOperatorParser.apply_to_function(scalar_arg) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("multiplies")( "alpha", "bar" @@ -84,6 +86,7 @@ def multicopy(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64] parsed = FieldOperatorParser.apply_to_function(multicopy) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("make_tuple")("inp1", "inp2") @@ -96,6 +99,7 @@ def arithmetic(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64 parsed = FieldOperatorParser.apply_to_function(arithmetic) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("plus")("inp1", "inp2") @@ -110,6 +114,7 @@ def shift_by_one(inp: gtx.Field[[IDim], float64]): parsed = FieldOperatorParser.apply_to_function(shift_by_one) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", 1)("it"))))("inp") @@ -124,6 +129,7 @@ def shift_by_one(inp: gtx.Field[[IDim], float64]): parsed = FieldOperatorParser.apply_to_function(shift_by_one) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", -1)("it"))))("inp") @@ -139,6 +145,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(copy_field) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.let( itir.Sym(id=ssa.unique_name("tmp", 0), dtype=("float64", False), kind="Iterator"), "inp" @@ -165,6 +172,7 @@ def unary(inp: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(unary) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.let( itir.Sym(id=ssa.unique_name("tmp", 0), dtype=("float64", False), kind="Iterator"), @@ -194,6 +202,7 @@ def unpacking( parsed = FieldOperatorParser.apply_to_function(unpacking) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None tuple_expr = im.promote_to_lifted_stencil("make_tuple")("inp1", "inp2") tuple_access_0 = im.promote_to_lifted_stencil(lambda x: im.tuple_get(0, x))("__tuple_tmp_0") @@ -223,6 +232,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(copy_field) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.let(ssa.unique_name("tmp", 0), "inp")(ssa.unique_name("tmp", 0)) @@ -247,6 +257,7 @@ def call(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: parsed = FieldOperatorParser.apply_to_function(call) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.lift(im.lambda_("__arg0")(im.call("identity")("__arg0")))("inp") @@ -262,6 +273,7 @@ def temp_tuple(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], int64]): parsed = FieldOperatorParser.apply_to_function(temp_tuple) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None tuple_expr = im.promote_to_lifted_stencil("make_tuple")("a", "b") reference = im.let(ssa.unique_name("tmp", 0), tuple_expr)(ssa.unique_name("tmp", 0)) @@ -275,6 +287,7 @@ def unary_not(cond: gtx.Field[[TDim], "bool"]): parsed = FieldOperatorParser.apply_to_function(unary_not) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("not_")("cond") @@ -287,6 +300,7 @@ def plus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(plus) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("plus")("a", "b") @@ -299,6 +313,7 @@ def scalar_plus_field(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float6 parsed = FieldOperatorParser.apply_to_function(scalar_plus_field) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("plus")( im.promote_to_const_iterator(im.literal("2.0", "float64")), "a" @@ -314,6 +329,7 @@ def scalar_plus_scalar(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int3 parsed = FieldOperatorParser.apply_to_function(scalar_plus_scalar) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.let( ssa.unique_name("tmp", 0), @@ -332,6 +348,7 @@ def mult(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(mult) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("multiplies")("a", "b") @@ -344,6 +361,7 @@ def minus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(minus) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("minus")("a", "b") @@ -356,6 +374,7 @@ def division(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(division) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("divides")("a", "b") @@ -368,6 +387,7 @@ def bit_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): parsed = FieldOperatorParser.apply_to_function(bit_and) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("and_")("a", "b") @@ -380,6 +400,7 @@ def scalar_and(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]: parsed = FieldOperatorParser.apply_to_function(scalar_and) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("and_")( "a", im.promote_to_const_iterator(im.literal("False", "bool")) @@ -394,6 +415,7 @@ def bit_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): parsed = FieldOperatorParser.apply_to_function(bit_or) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("or_")("a", "b") @@ -406,6 +428,7 @@ def comp_scalars() -> bool: parsed = FieldOperatorParser.apply_to_function(comp_scalars) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("greater")( im.promote_to_const_iterator(im.literal("3", "int32")), @@ -421,6 +444,7 @@ def comp_gt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(comp_gt) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("greater")("a", "b") @@ -433,6 +457,7 @@ def comp_lt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(comp_lt) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("less")("a", "b") @@ -445,6 +470,7 @@ def comp_eq(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]): parsed = FieldOperatorParser.apply_to_function(comp_eq) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("eq")("a", "b") @@ -459,6 +485,7 @@ def compare_chain( parsed = FieldOperatorParser.apply_to_function(compare_chain) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("and_")( im.promote_to_lifted_stencil("greater")("a", "b"), @@ -474,6 +501,7 @@ def reduction(edge_f: gtx.Field[[Edge], float64]): parsed = FieldOperatorParser.apply_to_function(reduction) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil( im.call( @@ -496,6 +524,7 @@ def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], fl parsed = FieldOperatorParser.apply_to_function(reduction) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None mapped = im.promote_to_lifted_stencil(im.map_("multiplies"))( im.promote_to_lifted_stencil("make_const_list")( @@ -539,6 +568,7 @@ def int_constrs() -> ( parsed = FieldOperatorParser.apply_to_function(int_constrs) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("make_tuple")( im.promote_to_const_iterator(im.literal("1", "int32")), @@ -575,6 +605,7 @@ def float_constrs() -> ( parsed = FieldOperatorParser.apply_to_function(float_constrs) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("make_tuple")( im.promote_to_const_iterator(im.literal("0.1", "float64")), @@ -595,6 +626,7 @@ def bool_constrs() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: parsed = FieldOperatorParser.apply_to_function(bool_constrs) lowered = FieldOperatorLowering.apply(parsed) + lowered.location = None reference = im.promote_to_lifted_stencil("make_tuple")( im.promote_to_const_iterator(im.literal(str(True), "bool")), From bb880dd9de9de9aecfc8827f2c1ef0bdbc6558c0 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Thu, 14 Dec 2023 13:37:55 +0100 Subject: [PATCH 23/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors [WIP] --- .../ffront_tests/test_foast_to_itir.py | 64 +++++++++---------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py index 0f9dbe58da..02f7321c6c 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py @@ -59,7 +59,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(copy_field) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None assert lowered.id == "copy_field" assert lowered.expr == im.ref("inp") @@ -71,7 +71,7 @@ def scalar_arg(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], parsed = FieldOperatorParser.apply_to_function(scalar_arg) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("multiplies")( "alpha", "bar" @@ -86,7 +86,7 @@ def multicopy(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64] parsed = FieldOperatorParser.apply_to_function(multicopy) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("make_tuple")("inp1", "inp2") @@ -99,7 +99,7 @@ def arithmetic(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64 parsed = FieldOperatorParser.apply_to_function(arithmetic) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("plus")("inp1", "inp2") @@ -114,7 +114,7 @@ def shift_by_one(inp: gtx.Field[[IDim], float64]): parsed = FieldOperatorParser.apply_to_function(shift_by_one) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", 1)("it"))))("inp") @@ -129,7 +129,7 @@ def shift_by_one(inp: gtx.Field[[IDim], float64]): parsed = FieldOperatorParser.apply_to_function(shift_by_one) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", -1)("it"))))("inp") @@ -145,7 +145,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(copy_field) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.let( itir.Sym(id=ssa.unique_name("tmp", 0), dtype=("float64", False), kind="Iterator"), "inp" @@ -172,7 +172,7 @@ def unary(inp: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(unary) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.let( itir.Sym(id=ssa.unique_name("tmp", 0), dtype=("float64", False), kind="Iterator"), @@ -202,7 +202,7 @@ def unpacking( parsed = FieldOperatorParser.apply_to_function(unpacking) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None tuple_expr = im.promote_to_lifted_stencil("make_tuple")("inp1", "inp2") tuple_access_0 = im.promote_to_lifted_stencil(lambda x: im.tuple_get(0, x))("__tuple_tmp_0") @@ -232,7 +232,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(copy_field) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.let(ssa.unique_name("tmp", 0), "inp")(ssa.unique_name("tmp", 0)) @@ -257,7 +257,7 @@ def call(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: parsed = FieldOperatorParser.apply_to_function(call) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.lift(im.lambda_("__arg0")(im.call("identity")("__arg0")))("inp") @@ -273,7 +273,7 @@ def temp_tuple(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], int64]): parsed = FieldOperatorParser.apply_to_function(temp_tuple) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None tuple_expr = im.promote_to_lifted_stencil("make_tuple")("a", "b") reference = im.let(ssa.unique_name("tmp", 0), tuple_expr)(ssa.unique_name("tmp", 0)) @@ -287,7 +287,7 @@ def unary_not(cond: gtx.Field[[TDim], "bool"]): parsed = FieldOperatorParser.apply_to_function(unary_not) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("not_")("cond") @@ -300,7 +300,7 @@ def plus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(plus) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("plus")("a", "b") @@ -313,7 +313,7 @@ def scalar_plus_field(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float6 parsed = FieldOperatorParser.apply_to_function(scalar_plus_field) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("plus")( im.promote_to_const_iterator(im.literal("2.0", "float64")), "a" @@ -329,7 +329,7 @@ def scalar_plus_scalar(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int3 parsed = FieldOperatorParser.apply_to_function(scalar_plus_scalar) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.let( ssa.unique_name("tmp", 0), @@ -348,7 +348,7 @@ def mult(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(mult) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("multiplies")("a", "b") @@ -361,7 +361,7 @@ def minus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(minus) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("minus")("a", "b") @@ -374,7 +374,7 @@ def division(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(division) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("divides")("a", "b") @@ -387,7 +387,7 @@ def bit_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): parsed = FieldOperatorParser.apply_to_function(bit_and) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("and_")("a", "b") @@ -400,7 +400,7 @@ def scalar_and(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]: parsed = FieldOperatorParser.apply_to_function(scalar_and) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("and_")( "a", im.promote_to_const_iterator(im.literal("False", "bool")) @@ -415,7 +415,7 @@ def bit_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): parsed = FieldOperatorParser.apply_to_function(bit_or) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("or_")("a", "b") @@ -428,7 +428,7 @@ def comp_scalars() -> bool: parsed = FieldOperatorParser.apply_to_function(comp_scalars) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("greater")( im.promote_to_const_iterator(im.literal("3", "int32")), @@ -444,7 +444,7 @@ def comp_gt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(comp_gt) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("greater")("a", "b") @@ -457,7 +457,7 @@ def comp_lt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(comp_lt) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("less")("a", "b") @@ -470,7 +470,7 @@ def comp_eq(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]): parsed = FieldOperatorParser.apply_to_function(comp_eq) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("eq")("a", "b") @@ -485,7 +485,7 @@ def compare_chain( parsed = FieldOperatorParser.apply_to_function(compare_chain) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("and_")( im.promote_to_lifted_stencil("greater")("a", "b"), @@ -501,7 +501,7 @@ def reduction(edge_f: gtx.Field[[Edge], float64]): parsed = FieldOperatorParser.apply_to_function(reduction) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil( im.call( @@ -524,7 +524,7 @@ def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], fl parsed = FieldOperatorParser.apply_to_function(reduction) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None mapped = im.promote_to_lifted_stencil(im.map_("multiplies"))( im.promote_to_lifted_stencil("make_const_list")( @@ -568,7 +568,7 @@ def int_constrs() -> ( parsed = FieldOperatorParser.apply_to_function(int_constrs) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("make_tuple")( im.promote_to_const_iterator(im.literal("1", "int32")), @@ -605,7 +605,7 @@ def float_constrs() -> ( parsed = FieldOperatorParser.apply_to_function(float_constrs) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("make_tuple")( im.promote_to_const_iterator(im.literal("0.1", "float64")), @@ -626,7 +626,7 @@ def bool_constrs() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: parsed = FieldOperatorParser.apply_to_function(bool_constrs) lowered = FieldOperatorLowering.apply(parsed) - lowered.location = None + lowered.expr.location = None reference = im.promote_to_lifted_stencil("make_tuple")( im.promote_to_const_iterator(im.literal(str(True), "bool")), From e0a254f65a1e7806ab6f0997933173fa34dc25a3 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Thu, 14 Dec 2023 14:48:21 +0100 Subject: [PATCH 24/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors [WIP] --- src/gt4py/next/ffront/foast_to_itir.py | 6 +- .../runners/dace_iterator/__init__.py | 13 +-- .../ffront_tests/test_foast_to_itir.py | 96 +++++++------------ 3 files changed, 43 insertions(+), 72 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 9c06c57821..23dd6ab306 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -66,10 +66,11 @@ class FieldOperatorLowering(NodeTranslator): """ uid_generator: UIDGenerator = dataclasses.field(default_factory=UIDGenerator) + preserve_location: bool = True @classmethod - def apply(cls, node: foast.LocatedNode) -> itir.Expr: - return cls().visit(node) + def apply(cls, node: foast.LocatedNode, preserve_location: bool = True) -> itir.Expr: + return cls(preserve_location=preserve_location).visit(node) def visit(self, node: concepts.RootNode, **kwargs: extended_typing.Any) -> extended_typing.Any: result = super().visit(node, **kwargs) @@ -77,6 +78,7 @@ def visit(self, node: concepts.RootNode, **kwargs: extended_typing.Any) -> exten hasattr(node, "location") and hasattr(result, "location") and not isinstance(node, foast.Name) + and self.preserve_location ): result.location = node.location return result diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index fa9524955f..bea451afcc 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import hashlib import warnings +from inspect import currentframe, getframeinfo from typing import Any, Mapping, Optional, Sequence import dace @@ -221,15 +222,15 @@ def build_sdfg_from_itir( sdfg = sdfg_genenerator.visit(program) for nested_sdfg in sdfg.all_sdfgs_recursive(): if not nested_sdfg.debuginfo: - warnings.warn( + _, frameinfo = warnings.warn( f"{nested_sdfg} does not have debuginfo. Consider adding them in the corresponding nested sdfg." + ), getframeinfo( + currentframe() # type: ignore ) nested_sdfg.debuginfo = dace.dtypes.DebugInfo( - start_line=0, - start_column=0, - end_line=-1, - end_column=0, - filename=None, + start_line=frameinfo.lineno, + end_line=frameinfo.lineno, + filename=frameinfo.filename, ) sdfg.simplify() diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py index 02f7321c6c..d264dc37c6 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py @@ -58,8 +58,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]): return inp parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) assert lowered.id == "copy_field" assert lowered.expr == im.ref("inp") @@ -70,8 +69,7 @@ def scalar_arg(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], return alpha * bar parsed = FieldOperatorParser.apply_to_function(scalar_arg) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("multiplies")( "alpha", "bar" @@ -85,8 +83,7 @@ def multicopy(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64] return inp1, inp2 parsed = FieldOperatorParser.apply_to_function(multicopy) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("make_tuple")("inp1", "inp2") @@ -98,8 +95,7 @@ def arithmetic(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64 return inp1 + inp2 parsed = FieldOperatorParser.apply_to_function(arithmetic) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("plus")("inp1", "inp2") @@ -113,8 +109,7 @@ def shift_by_one(inp: gtx.Field[[IDim], float64]): return inp(Ioff[1]) parsed = FieldOperatorParser.apply_to_function(shift_by_one) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", 1)("it"))))("inp") @@ -128,8 +123,7 @@ def shift_by_one(inp: gtx.Field[[IDim], float64]): return inp(Ioff[-1]) parsed = FieldOperatorParser.apply_to_function(shift_by_one) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", -1)("it"))))("inp") @@ -144,8 +138,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]): return tmp2 parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.let( itir.Sym(id=ssa.unique_name("tmp", 0), dtype=("float64", False), kind="Iterator"), "inp" @@ -171,8 +164,7 @@ def unary(inp: gtx.Field[[TDim], float64]): return tmp parsed = FieldOperatorParser.apply_to_function(unary) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.let( itir.Sym(id=ssa.unique_name("tmp", 0), dtype=("float64", False), kind="Iterator"), @@ -201,8 +193,7 @@ def unpacking( return tmp1 parsed = FieldOperatorParser.apply_to_function(unpacking) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) tuple_expr = im.promote_to_lifted_stencil("make_tuple")("inp1", "inp2") tuple_access_0 = im.promote_to_lifted_stencil(lambda x: im.tuple_get(0, x))("__tuple_tmp_0") @@ -231,8 +222,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]): return tmp parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.let(ssa.unique_name("tmp", 0), "inp")(ssa.unique_name("tmp", 0)) @@ -256,8 +246,7 @@ def call(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: return identity(inp) parsed = FieldOperatorParser.apply_to_function(call) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.lift(im.lambda_("__arg0")(im.call("identity")("__arg0")))("inp") @@ -272,8 +261,7 @@ def temp_tuple(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], int64]): return tmp parsed = FieldOperatorParser.apply_to_function(temp_tuple) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) tuple_expr = im.promote_to_lifted_stencil("make_tuple")("a", "b") reference = im.let(ssa.unique_name("tmp", 0), tuple_expr)(ssa.unique_name("tmp", 0)) @@ -286,8 +274,7 @@ def unary_not(cond: gtx.Field[[TDim], "bool"]): return not cond parsed = FieldOperatorParser.apply_to_function(unary_not) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("not_")("cond") @@ -299,8 +286,7 @@ def plus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a + b parsed = FieldOperatorParser.apply_to_function(plus) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("plus")("a", "b") @@ -312,8 +298,7 @@ def scalar_plus_field(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float6 return 2.0 + a parsed = FieldOperatorParser.apply_to_function(scalar_plus_field) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("plus")( im.promote_to_const_iterator(im.literal("2.0", "float64")), "a" @@ -328,8 +313,7 @@ def scalar_plus_scalar(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int3 return a + tmp parsed = FieldOperatorParser.apply_to_function(scalar_plus_scalar) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.let( ssa.unique_name("tmp", 0), @@ -347,8 +331,7 @@ def mult(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a * b parsed = FieldOperatorParser.apply_to_function(mult) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("multiplies")("a", "b") @@ -360,8 +343,7 @@ def minus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a - b parsed = FieldOperatorParser.apply_to_function(minus) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("minus")("a", "b") @@ -373,8 +355,7 @@ def division(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a / b parsed = FieldOperatorParser.apply_to_function(division) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("divides")("a", "b") @@ -386,8 +367,7 @@ def bit_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): return a & b parsed = FieldOperatorParser.apply_to_function(bit_and) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("and_")("a", "b") @@ -399,8 +379,7 @@ def scalar_and(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]: return a & False parsed = FieldOperatorParser.apply_to_function(scalar_and) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("and_")( "a", im.promote_to_const_iterator(im.literal("False", "bool")) @@ -414,8 +393,7 @@ def bit_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): return a | b parsed = FieldOperatorParser.apply_to_function(bit_or) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("or_")("a", "b") @@ -427,8 +405,7 @@ def comp_scalars() -> bool: return 3 > 4 parsed = FieldOperatorParser.apply_to_function(comp_scalars) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("greater")( im.promote_to_const_iterator(im.literal("3", "int32")), @@ -443,8 +420,7 @@ def comp_gt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a > b parsed = FieldOperatorParser.apply_to_function(comp_gt) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("greater")("a", "b") @@ -456,8 +432,7 @@ def comp_lt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a < b parsed = FieldOperatorParser.apply_to_function(comp_lt) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("less")("a", "b") @@ -469,8 +444,7 @@ def comp_eq(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]): return a == b parsed = FieldOperatorParser.apply_to_function(comp_eq) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("eq")("a", "b") @@ -484,8 +458,7 @@ def compare_chain( return a > b > c parsed = FieldOperatorParser.apply_to_function(compare_chain) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("and_")( im.promote_to_lifted_stencil("greater")("a", "b"), @@ -500,8 +473,7 @@ def reduction(edge_f: gtx.Field[[Edge], float64]): return neighbor_sum(edge_f(V2E), axis=V2EDim) parsed = FieldOperatorParser.apply_to_function(reduction) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil( im.call( @@ -523,8 +495,7 @@ def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], fl return neighbor_sum(1.1 * (e1_nbh + e2), axis=V2EDim) parsed = FieldOperatorParser.apply_to_function(reduction) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) mapped = im.promote_to_lifted_stencil(im.map_("multiplies"))( im.promote_to_lifted_stencil("make_const_list")( @@ -567,8 +538,7 @@ def int_constrs() -> ( return 1, int32(1), int64(1), int32("1"), int64("1") parsed = FieldOperatorParser.apply_to_function(int_constrs) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("make_tuple")( im.promote_to_const_iterator(im.literal("1", "int32")), @@ -604,8 +574,7 @@ def float_constrs() -> ( ) parsed = FieldOperatorParser.apply_to_function(float_constrs) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("make_tuple")( im.promote_to_const_iterator(im.literal("0.1", "float64")), @@ -625,8 +594,7 @@ def bool_constrs() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: return True, False, bool(True), bool(False), bool(0), bool(5), bool("True"), bool("False") parsed = FieldOperatorParser.apply_to_function(bool_constrs) - lowered = FieldOperatorLowering.apply(parsed) - lowered.expr.location = None + lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) reference = im.promote_to_lifted_stencil("make_tuple")( im.promote_to_const_iterator(im.literal(str(True), "bool")), From 50f96a8d2f3cce90f1d6f9f52dc7010e4523f6ac Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Thu, 4 Jan 2024 17:04:39 +0100 Subject: [PATCH 25/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors --- .../program_processors/runners/dace_iterator/utility.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 5a5cc2297a..971c1bbdf2 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -27,16 +27,14 @@ def dace_debuginfo( ) -> Optional[dace.dtypes.DebugInfo]: location = node.location if location: - di = dace.dtypes.DebugInfo( + return dace.dtypes.DebugInfo( start_line=location.line, start_column=location.column if location.column else 0, end_line=location.end_line if location.end_line else -1, end_column=location.end_column if location.end_column else 0, filename=location.filename, ) - else: - di = debuginfo - return di + return debuginfo def as_dace_type(type_: ts.ScalarType): From 9c9c8aebc12804bbaa2b5cf32c57f7d954681e74 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Fri, 12 Jan 2024 13:48:56 +0100 Subject: [PATCH 26/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors --- src/gt4py/eve/__init__.py | 3 ++- src/gt4py/eve/traits.py | 8 ------- src/gt4py/eve/visitors.py | 2 +- src/gt4py/next/ffront/foast_to_itir.py | 7 +----- src/gt4py/next/ffront/past_to_itir.py | 15 +++++------- .../iterator/transforms/collapse_list_get.py | 4 ++-- .../iterator/transforms/collapse_tuple.py | 4 ++-- .../iterator/transforms/constant_folding.py | 4 ++-- src/gt4py/next/iterator/transforms/cse.py | 18 ++++---------- .../next/iterator/transforms/eta_reduction.py | 4 ++-- .../next/iterator/transforms/fuse_maps.py | 6 ++--- .../next/iterator/transforms/global_tmps.py | 4 ++-- .../iterator/transforms/inline_fundefs.py | 6 ++--- .../iterator/transforms/inline_into_scan.py | 6 ++--- .../iterator/transforms/inline_lambdas.py | 4 ++-- .../next/iterator/transforms/inline_lifts.py | 24 +++++++++---------- .../next/iterator/transforms/merge_let.py | 4 ++-- .../iterator/transforms/normalize_shifts.py | 4 ++-- .../iterator/transforms/propagate_deref.py | 4 ++-- .../transforms/prune_closure_inputs.py | 4 ++-- .../next/iterator/transforms/remap_symbols.py | 6 ++--- .../iterator/transforms/scan_eta_reduction.py | 4 ++-- .../iterator/transforms/symbol_ref_utils.py | 4 ++-- .../next/iterator/transforms/trace_shifts.py | 4 ++-- .../next/iterator/transforms/unroll_reduce.py | 4 ++-- 25 files changed, 64 insertions(+), 93 deletions(-) diff --git a/src/gt4py/eve/__init__.py b/src/gt4py/eve/__init__.py index 617a889e28..2dd5183b74 100644 --- a/src/gt4py/eve/__init__.py +++ b/src/gt4py/eve/__init__.py @@ -70,7 +70,7 @@ walk_values, ) from .type_definitions import NOTHING, ConstrainedStr, Enum, IntEnum, NothingType, StrEnum -from .visitors import NodeTranslator, NodeVisitor +from .visitors import NodeTranslator, NodeVisitor, PreserveLocationVisitor __all__ = [ @@ -132,4 +132,5 @@ # visitors "NodeTranslator", "NodeVisitor", + "PreserveLocationVisitor", ] diff --git a/src/gt4py/eve/traits.py b/src/gt4py/eve/traits.py index fb3efd6412..df556c9d7f 100644 --- a/src/gt4py/eve/traits.py +++ b/src/gt4py/eve/traits.py @@ -172,11 +172,3 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: kwargs["symtable"] = kwargs["symtable"].parents return result - - -class PreserveLocationWithSymbolTableTrait(VisitorWithSymbolTableTrait): - def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: - result = super().visit(node, **kwargs) - if hasattr(node, "location") and hasattr(result, "location"): - result.location = node.location - return result diff --git a/src/gt4py/eve/visitors.py b/src/gt4py/eve/visitors.py index 769576ed80..c3b9f3abf3 100644 --- a/src/gt4py/eve/visitors.py +++ b/src/gt4py/eve/visitors.py @@ -198,7 +198,7 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: return copy.deepcopy(node, memo=memo) -class PreserveLocation(NodeVisitor): +class PreserveLocationVisitor(NodeVisitor): def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: result = super().visit(node, **kwargs) if hasattr(node, "location") and hasattr(result, "location"): diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 23dd6ab306..ba771ed5b5 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -74,12 +74,7 @@ def apply(cls, node: foast.LocatedNode, preserve_location: bool = True) -> itir. def visit(self, node: concepts.RootNode, **kwargs: extended_typing.Any) -> extended_typing.Any: result = super().visit(node, **kwargs) - if ( - hasattr(node, "location") - and hasattr(result, "location") - and not isinstance(node, foast.Name) - and self.preserve_location - ): + if hasattr(node, "location") and hasattr(result, "location") and self.preserve_location: result.location = node.location return result diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index d53383683a..50ddb0401b 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -16,7 +16,7 @@ from typing import Optional, cast -from gt4py.eve import NodeTranslator, concepts, traits +from gt4py.eve import NodeTranslator, PreserveLocationVisitor, concepts, traits from gt4py.next.common import Dimension, DimensionKind, GridType from gt4py.next.ffront import program_ast as past, type_specifications as ts_ffront from gt4py.next.iterator import ir as itir @@ -40,9 +40,7 @@ def _flatten_tuple_expr( raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.") -class ProgramLowering( - traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator -): +class ProgramLowering(PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ Lower Program AST (PAST) to Iterator IR (ITIR). @@ -256,7 +254,9 @@ def _construct_itir_domain_arg( raise AssertionError() return itir.FunCall( - fun=itir.SymRef(id=domain_builtin), args=domain_args, location=out_field.location + fun=itir.SymRef(id=domain_builtin), + args=domain_args, + location=(node_domain or out_field).location, ) def _construct_itir_initialized_domain_arg( @@ -273,10 +273,7 @@ def _construct_itir_initialized_domain_arg( f"expected '{dim}', got '{keys_dims_types}'." ) - itir_node = [self.visit(bound) for bound in node_domain.values_[dim_i].elts] - for i, bound in enumerate(node_domain.values_[dim_i].elts): - itir_node[i].location = bound.location - return itir_node + return [self.visit(bound) for bound in node_domain.values_[dim_i].elts] @staticmethod def _compute_field_slice(node: past.Subscript): diff --git a/src/gt4py/next/iterator/transforms/collapse_list_get.py b/src/gt4py/next/iterator/transforms/collapse_list_get.py index 79d7cf6f5f..58f047e9b0 100644 --- a/src/gt4py/next/iterator/transforms/collapse_list_get.py +++ b/src/gt4py/next/iterator/transforms/collapse_list_get.py @@ -13,11 +13,11 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py import eve -from gt4py.eve.visitors import PreserveLocation +from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir -class CollapseListGet(PreserveLocation, eve.NodeTranslator): +class CollapseListGet(PreserveLocationVisitor, eve.NodeTranslator): """Simplifies expressions containing `list_get`. Examples diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 038b22a762..247b442853 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -15,7 +15,7 @@ from typing import Optional from gt4py import eve -from gt4py.eve.visitors import PreserveLocation +from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next import type_inference from gt4py.next.iterator import ir, type_inference as it_type_inference @@ -49,7 +49,7 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t @dataclass(frozen=True) -class CollapseTuple(PreserveLocation, eve.NodeTranslator): +class CollapseTuple(PreserveLocationVisitor, eve.NodeTranslator): """ Simplifies `make_tuple`, `tuple_get` calls. diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index aa5fbf9eb3..6c70f7013e 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -13,12 +13,12 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator -from gt4py.eve.visitors import PreserveLocation +from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import embedded, ir from gt4py.next.iterator.ir_utils import ir_makers as im -class ConstantFolding(PreserveLocation, NodeTranslator): +class ConstantFolding(PreserveLocationVisitor, NodeTranslator): @classmethod def apply(cls, node: ir.Node) -> ir.Node: return cls().visit(node) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 16c4af23c7..460f2cdbc3 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -17,21 +17,15 @@ import operator import typing -from gt4py.eve import ( - NodeTranslator, - NodeVisitor, - SymbolTableTrait, - VisitorWithSymbolTableTrait, - traits, -) +from gt4py.eve import NodeTranslator, NodeVisitor, SymbolTableTrait, VisitorWithSymbolTableTrait from gt4py.eve.utils import UIDGenerator -from gt4py.eve.visitors import PreserveLocation +from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda @dataclasses.dataclass -class _NodeReplacer(PreserveLocation, NodeTranslator): +class _NodeReplacer(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type",) expr_map: dict[int, ir.SymRef] @@ -79,9 +73,7 @@ def _is_collectable_expr(node: ir.Node) -> bool: @dataclasses.dataclass -class CollectSubexpressions( - traits.PreserveLocationWithSymbolTableTrait, VisitorWithSymbolTableTrait, NodeVisitor -): +class CollectSubexpressions(PreserveLocationVisitor, VisitorWithSymbolTableTrait, NodeVisitor): @dataclasses.dataclass class SubexpressionData: #: A list of node ids with equal hash and a set of collected child subexpression ids @@ -350,7 +342,7 @@ def extract_subexpression( @dataclasses.dataclass(frozen=True) -class CommonSubexpressionElimination(PreserveLocation, NodeTranslator): +class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator): """ Perform common subexpression elimination. diff --git a/src/gt4py/next/iterator/transforms/eta_reduction.py b/src/gt4py/next/iterator/transforms/eta_reduction.py index 23a55c27af..c146538554 100644 --- a/src/gt4py/next/iterator/transforms/eta_reduction.py +++ b/src/gt4py/next/iterator/transforms/eta_reduction.py @@ -13,11 +13,11 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator -from gt4py.eve.visitors import PreserveLocation +from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir -class EtaReduction(PreserveLocation, NodeTranslator): +class EtaReduction(PreserveLocationVisitor, NodeTranslator): """Eta reduction: simplifies `λ(args...) → f(args...)` to `f`.""" def visit_Lambda(self, node: ir.Lambda) -> ir.Node: diff --git a/src/gt4py/next/iterator/transforms/fuse_maps.py b/src/gt4py/next/iterator/transforms/fuse_maps.py index e132bb5012..2afa417743 100644 --- a/src/gt4py/next/iterator/transforms/fuse_maps.py +++ b/src/gt4py/next/iterator/transforms/fuse_maps.py @@ -15,7 +15,7 @@ import dataclasses from typing import TypeGuard -from gt4py.eve import NodeTranslator, traits +from gt4py.eve import NodeTranslator, PreserveLocationVisitor, traits from gt4py.eve.utils import UIDGenerator from gt4py.next.iterator import ir from gt4py.next.iterator.transforms import inline_lambdas @@ -38,9 +38,7 @@ def _is_reduce(node: ir.Node) -> TypeGuard[ir.FunCall]: @dataclasses.dataclass(frozen=True) -class FuseMaps( - traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator -): +class FuseMaps(PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ Fuses nested `map_`s. diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index f0afbcb68d..2609e35735 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -22,7 +22,7 @@ from gt4py.eve import Coerced, NodeTranslator from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator -from gt4py.eve.visitors import PreserveLocation +from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir, type_inference from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift @@ -570,7 +570,7 @@ def convert_type(dtype): # TODO(tehrengruber): Add support for dynamic shifts (e.g. the distance is a symbol). This can be # tricky: For every lift statement that is dynamically shifted we can not compute bounds anymore # and hence also not extract as a temporary. -class CreateGlobalTmps(PreserveLocation, NodeTranslator): +class CreateGlobalTmps(PreserveLocationVisitor, NodeTranslator): """Main entry point for introducing global temporaries. Transforms an existing iterator IR fencil into a fencil with global temporaries. diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index e03d4b7ed9..c0176202ea 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -15,11 +15,11 @@ from typing import Any, Dict, Set from gt4py.eve import NOTHING, NodeTranslator -from gt4py.eve.visitors import PreserveLocation +from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir -class InlineFundefs(PreserveLocation, NodeTranslator): +class InlineFundefs(PreserveLocationVisitor, NodeTranslator): def visit_SymRef(self, node: ir.SymRef, *, symtable: Dict[str, Any]): if node.id in symtable and isinstance((symbol := symtable[node.id]), ir.FunctionDefinition): return ir.Lambda( @@ -32,7 +32,7 @@ def visit_FencilDefinition(self, node: ir.FencilDefinition): return self.generic_visit(node, symtable=node.annex.symtable) -class PruneUnreferencedFundefs(PreserveLocation, NodeTranslator): +class PruneUnreferencedFundefs(PreserveLocationVisitor, NodeTranslator): def visit_FunctionDefinition( self, node: ir.FunctionDefinition, *, referenced: Set[str], second_pass: bool ): diff --git a/src/gt4py/next/iterator/transforms/inline_into_scan.py b/src/gt4py/next/iterator/transforms/inline_into_scan.py index d73a09c6e5..6c9fb52d2a 100644 --- a/src/gt4py/next/iterator/transforms/inline_into_scan.py +++ b/src/gt4py/next/iterator/transforms/inline_into_scan.py @@ -15,7 +15,7 @@ from typing import Sequence, TypeGuard from gt4py import eve -from gt4py.eve import NodeTranslator, traits +from gt4py.eve import NodeTranslator, PreserveLocationVisitor, traits from gt4py.next.iterator import ir from gt4py.next.iterator.transforms import symbol_ref_utils from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda @@ -53,9 +53,7 @@ def _lambda_and_lift_inliner(node: ir.FunCall) -> ir.FunCall: return inlined -class InlineIntoScan( - traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator -): +class InlineIntoScan(PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ Inline non-SymRef arguments into the scan. diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index f5838764c9..a9f00afa5b 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -16,7 +16,7 @@ from typing import Optional from gt4py.eve import NodeTranslator -from gt4py.eve.visitors import PreserveLocation +from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols @@ -123,7 +123,7 @@ def new_name(name): @dataclasses.dataclass -class InlineLambdas(PreserveLocation, NodeTranslator): +class InlineLambdas(PreserveLocationVisitor, NodeTranslator): """Inline lambda calls by substituting every argument by its value.""" PRESERVED_ANNEX_ATTRS = ("type",) diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index ac7aa4ed73..97d4c01244 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -18,7 +18,7 @@ from typing import Optional import gt4py.eve as eve -from gt4py.eve import NodeTranslator, traits +from gt4py.eve import NodeTranslator, PreserveLocationVisitor, traits from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda @@ -40,10 +40,10 @@ def _generate_unique_symbol( else: desired_name = f"__arg{arg_idx}" - new_symbol = ir.Sym(id=desired_name) + new_symbol = desired_name # make unique while new_symbol.id in occupied_names or new_symbol in occupied_symbols: - new_symbol = ir.Sym(id=new_symbol.id + "_") + new_symbol = new_symbol + "_" return new_symbol @@ -73,7 +73,7 @@ def _is_scan(node: ir.FunCall): def _transform_and_extract_lift_args( node: ir.FunCall, symtable: dict[eve.SymbolName, ir.Sym], - extracted_args: dict[ir.Sym, ir.Expr], + extracted_args: dict[eve.SymbolName, ir.Expr], ): """ Transform and extract non-symbol arguments of a lifted stencil call. @@ -89,8 +89,8 @@ def _transform_and_extract_lift_args( new_args = [] for i, arg in enumerate(node.args): if isinstance(arg, ir.SymRef): - sym = ir.Sym(id=arg.id) - assert sym not in extracted_args or extracted_args[sym] == arg + sym = arg.id + assert sym not in extracted_args or extracted_args[sym].id == arg.id extracted_args[sym] = arg new_args.append(arg) else: @@ -101,7 +101,7 @@ def _transform_and_extract_lift_args( ) assert new_symbol not in extracted_args extracted_args[new_symbol] = arg - new_args.append(ir.SymRef(id=new_symbol.id)) + new_args.append(ir.SymRef(id=new_symbol)) itir_node = im.lift(inner_stencil)(*new_args) itir_node.location = node.location @@ -112,9 +112,7 @@ def _transform_and_extract_lift_args( # passes. Due to a lack of infrastructure (e.g. no pass manager) to combine passes without # performance degradation we leave everything as one pass for now. @dataclasses.dataclass -class InlineLifts( - traits.PreserveLocationWithSymbolTableTrait, traits.VisitorWithSymbolTableTrait, NodeTranslator -): +class InlineLifts(PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): """Inline lifted function calls. Optionally a predicate function can be passed which can enable or disable inlining of specific @@ -228,7 +226,7 @@ def visit_FunCall( # TODO(tehrengruber): we currently only inlining opcount preserving, but what we # actually want is to inline whenever the argument is not shifted. This is # currently beyond the capabilities of the inliner and the shift tracer. - new_arg_exprs: dict[ir.Sym, ir.Expr] = {} + new_arg_exprs: dict[eve.SymbolName, ir.Expr] = {} inlined_args = [] for i, (arg, eligible) in enumerate(zip(node.args, eligible_lifted_args)): if eligible: @@ -239,7 +237,7 @@ def visit_FunCall( inlined_args.append(inlined_arg) else: if isinstance(arg, ir.SymRef): - new_arg_sym = ir.Sym(id=arg.id) + new_arg_sym = arg.id else: new_arg_sym = _generate_unique_symbol( desired_name=(stencil, i), @@ -248,7 +246,7 @@ def visit_FunCall( ) new_arg_exprs[new_arg_sym] = arg - inlined_args.append(ir.SymRef(id=new_arg_sym.id)) + inlined_args.append(ir.SymRef(id=new_arg_sym)) inlined_call = self.visit( inline_lambda( diff --git a/src/gt4py/next/iterator/transforms/merge_let.py b/src/gt4py/next/iterator/transforms/merge_let.py index 140d6574fa..5b96e9cbf5 100644 --- a/src/gt4py/next/iterator/transforms/merge_let.py +++ b/src/gt4py/next/iterator/transforms/merge_let.py @@ -13,12 +13,12 @@ # SPDX-License-Identifier: GPL-3.0-or-later import gt4py.eve as eve -from gt4py.eve.visitors import PreserveLocation +from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs -class MergeLet(PreserveLocation, eve.NodeTranslator): +class MergeLet(PreserveLocationVisitor, eve.NodeTranslator): """ Merge let-like statements. diff --git a/src/gt4py/next/iterator/transforms/normalize_shifts.py b/src/gt4py/next/iterator/transforms/normalize_shifts.py index 5545ad3231..d9bb96a81d 100644 --- a/src/gt4py/next/iterator/transforms/normalize_shifts.py +++ b/src/gt4py/next/iterator/transforms/normalize_shifts.py @@ -13,11 +13,11 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator -from gt4py.eve.visitors import PreserveLocation +from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir -class NormalizeShifts(PreserveLocation, NodeTranslator): +class NormalizeShifts(PreserveLocationVisitor, NodeTranslator): def visit_FunCall(self, node: ir.FunCall): node = self.generic_visit(node) if ( diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index 21d3e333be..c3d338f926 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -14,7 +14,7 @@ from gt4py.eve import NodeTranslator from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.eve.visitors import PreserveLocation +from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir @@ -23,7 +23,7 @@ # `(λ(...) → plus(multiplies(...), ...))(...)`. -class PropagateDeref(PreserveLocation, NodeTranslator): +class PropagateDeref(PreserveLocationVisitor, NodeTranslator): @classmethod def apply(cls, node: ir.Node): """ diff --git a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py index c67f7e4476..0dd77768d5 100644 --- a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py +++ b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py @@ -13,11 +13,11 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator -from gt4py.eve.visitors import PreserveLocation +from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir -class PruneClosureInputs(PreserveLocation, NodeTranslator): +class PruneClosureInputs(PreserveLocationVisitor, NodeTranslator): """Removes all unused input arguments from a stencil closure.""" def visit_StencilClosure(self, node: ir.StencilClosure) -> ir.StencilClosure: diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index 2a12d4dde0..8d810cf624 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -15,11 +15,11 @@ from typing import Any, Dict, Optional, Set from gt4py.eve import NodeTranslator, SymbolTableTrait -from gt4py.eve.visitors import PreserveLocation +from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir -class RemapSymbolRefs(PreserveLocation, NodeTranslator): +class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type",) def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]): @@ -40,7 +40,7 @@ def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override] return super().generic_visit(node, **kwargs) -class RenameSymbols(PreserveLocation, NodeTranslator): +class RenameSymbols(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type",) def visit_Sym( diff --git a/src/gt4py/next/iterator/transforms/scan_eta_reduction.py b/src/gt4py/next/iterator/transforms/scan_eta_reduction.py index 5684a0afb4..93ab4b52db 100644 --- a/src/gt4py/next/iterator/transforms/scan_eta_reduction.py +++ b/src/gt4py/next/iterator/transforms/scan_eta_reduction.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py.eve import NodeTranslator -from gt4py.eve.visitors import PreserveLocation +from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir @@ -25,7 +25,7 @@ def _is_scan(node: ir.Node): ) -class ScanEtaReduction(PreserveLocation, NodeTranslator): +class ScanEtaReduction(PreserveLocationVisitor, NodeTranslator): """Applies eta-reduction-like transformation involving scans. Simplifies `λ(x, y) → scan(λ(state, param_y, param_x) → ..., ...)(y, x)` to `scan(λ(state, param_x, param_y) → ..., ...)`. diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 679becd737..20b2650de5 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -17,12 +17,12 @@ from typing import Iterable, Optional, Sequence import gt4py.eve as eve -from gt4py.eve.visitors import PreserveLocation +from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir as itir @dataclasses.dataclass -class CountSymbolRefs(PreserveLocation, eve.NodeVisitor): +class CountSymbolRefs(PreserveLocationVisitor, eve.NodeVisitor): ref_counts: dict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int)) @classmethod diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 5c7729e3fb..29428c3cf9 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -17,7 +17,7 @@ from typing import Any, Final, Iterable, Literal from gt4py.eve import NodeTranslator -from gt4py.eve.visitors import PreserveLocation +from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir @@ -236,7 +236,7 @@ def _tuple_get(index, tuple_val): @dataclasses.dataclass(frozen=True) -class TraceShifts(PreserveLocation, NodeTranslator): +class TraceShifts(PreserveLocationVisitor, NodeTranslator): shift_recorder: ShiftRecorder = dataclasses.field(default_factory=ShiftRecorder) def visit_Literal(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 811c0179fb..4f66252298 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -18,7 +18,7 @@ from gt4py.eve import NodeTranslator from gt4py.eve.utils import UIDGenerator -from gt4py.eve.visitors import PreserveLocation +from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift @@ -130,7 +130,7 @@ def _make_list_get(offset: itir.Expr, expr: itir.Expr) -> itir.FunCall: @dataclasses.dataclass(frozen=True) -class UnrollReduce(PreserveLocation, NodeTranslator): +class UnrollReduce(PreserveLocationVisitor, NodeTranslator): # we use one UID generator per instance such that the generated ids are # stable across multiple runs (required for caching to properly work) uids: UIDGenerator = dataclasses.field(init=False, repr=False, default_factory=UIDGenerator) From 6fb28a104ad36138c4320b9957aef819266478b1 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Fri, 12 Jan 2024 14:33:41 +0100 Subject: [PATCH 27/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors --- src/gt4py/eve/visitors.py | 8 +++++++- src/gt4py/next/ffront/foast_to_itir.py | 10 ++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/gt4py/eve/visitors.py b/src/gt4py/eve/visitors.py index c3b9f3abf3..c0a0054f5a 100644 --- a/src/gt4py/eve/visitors.py +++ b/src/gt4py/eve/visitors.py @@ -199,8 +199,14 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: class PreserveLocationVisitor(NodeVisitor): + preserve_location: bool = True + + def __init__(self, preserve_location: bool = True) -> None: + super().__init__() + self.preserve_location = preserve_location + def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: result = super().visit(node, **kwargs) - if hasattr(node, "location") and hasattr(result, "location"): + if hasattr(node, "location") and hasattr(result, "location") and self.preserve_location: result.location = node.location return result diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index ba771ed5b5..4a88553532 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -15,7 +15,7 @@ import dataclasses from typing import Any, Callable, Optional -from gt4py.eve import NodeTranslator, concepts, extended_typing +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.eve.utils import UIDGenerator from gt4py.next.ffront import ( dialect_ast_enums, @@ -39,7 +39,7 @@ def promote_to_list( @dataclasses.dataclass -class FieldOperatorLowering(NodeTranslator): +class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): """ Lower FieldOperator AST (FOAST) to Iterator IR (ITIR). @@ -72,12 +72,6 @@ class FieldOperatorLowering(NodeTranslator): def apply(cls, node: foast.LocatedNode, preserve_location: bool = True) -> itir.Expr: return cls(preserve_location=preserve_location).visit(node) - def visit(self, node: concepts.RootNode, **kwargs: extended_typing.Any) -> extended_typing.Any: - result = super().visit(node, **kwargs) - if hasattr(node, "location") and hasattr(result, "location") and self.preserve_location: - result.location = node.location - return result - def visit_FunctionDefinition( self, node: foast.FunctionDefinition, **kwargs ) -> itir.FunctionDefinition: From 3f4e9d11fe4f167128c8b9acd201164510728fa0 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Fri, 12 Jan 2024 16:11:21 +0100 Subject: [PATCH 28/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors --- src/gt4py/next/iterator/ir.py | 2 +- .../next/iterator/transforms/inline_lifts.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index fb6199d7f1..d363151500 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -24,7 +24,7 @@ @noninstantiable class Node(eve.Node): - location: Optional[SourceLocation] = None + location: Optional[SourceLocation] = eve.field(default=None, repr=False) def __str__(self) -> str: from gt4py.next.iterator.pretty_printer import pformat diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index 97d4c01244..07d09f16cf 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -40,10 +40,10 @@ def _generate_unique_symbol( else: desired_name = f"__arg{arg_idx}" - new_symbol = desired_name + new_symbol = ir.Sym(id=desired_name) # make unique while new_symbol.id in occupied_names or new_symbol in occupied_symbols: - new_symbol = new_symbol + "_" + new_symbol = ir.Sym(id=new_symbol.id + "_") return new_symbol @@ -73,7 +73,7 @@ def _is_scan(node: ir.FunCall): def _transform_and_extract_lift_args( node: ir.FunCall, symtable: dict[eve.SymbolName, ir.Sym], - extracted_args: dict[eve.SymbolName, ir.Expr], + extracted_args: dict[ir.Sym, ir.Expr], ): """ Transform and extract non-symbol arguments of a lifted stencil call. @@ -89,8 +89,8 @@ def _transform_and_extract_lift_args( new_args = [] for i, arg in enumerate(node.args): if isinstance(arg, ir.SymRef): - sym = arg.id - assert sym not in extracted_args or extracted_args[sym].id == arg.id + sym = ir.Sym(id=arg.id) + assert sym not in extracted_args or extracted_args[sym] == arg extracted_args[sym] = arg new_args.append(arg) else: @@ -101,7 +101,7 @@ def _transform_and_extract_lift_args( ) assert new_symbol not in extracted_args extracted_args[new_symbol] = arg - new_args.append(ir.SymRef(id=new_symbol)) + new_args.append(ir.SymRef(id=new_symbol.id)) itir_node = im.lift(inner_stencil)(*new_args) itir_node.location = node.location @@ -226,7 +226,7 @@ def visit_FunCall( # TODO(tehrengruber): we currently only inlining opcount preserving, but what we # actually want is to inline whenever the argument is not shifted. This is # currently beyond the capabilities of the inliner and the shift tracer. - new_arg_exprs: dict[eve.SymbolName, ir.Expr] = {} + new_arg_exprs: dict[ir.Sym, ir.Expr] = {} inlined_args = [] for i, (arg, eligible) in enumerate(zip(node.args, eligible_lifted_args)): if eligible: @@ -237,7 +237,7 @@ def visit_FunCall( inlined_args.append(inlined_arg) else: if isinstance(arg, ir.SymRef): - new_arg_sym = arg.id + new_arg_sym = ir.Sym(id=arg.id) else: new_arg_sym = _generate_unique_symbol( desired_name=(stencil, i), @@ -246,7 +246,7 @@ def visit_FunCall( ) new_arg_exprs[new_arg_sym] = arg - inlined_args.append(ir.SymRef(id=new_arg_sym)) + inlined_args.append(ir.SymRef(id=new_arg_sym.id)) inlined_call = self.visit( inline_lambda( From b29bc5f628666065ef4a93564ec0f38a76a55a12 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Fri, 12 Jan 2024 16:34:56 +0100 Subject: [PATCH 29/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors --- src/gt4py/next/ffront/foast_to_itir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 4a88553532..fae944a101 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -62,7 +62,7 @@ class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): >>> lowered.id SymbolName('fieldop') >>> lowered.params # doctest: +ELLIPSIS - [Sym(location=..., id=SymbolName('inp'), kind='Iterator', dtype=('float64', False))] + [Sym(id=SymbolName('inp'), kind='Iterator', dtype=('float64', False))] """ uid_generator: UIDGenerator = dataclasses.field(default_factory=UIDGenerator) From 1a2e9781fd1b2530addb38c5e9a6a2c6a52a6f31 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Fri, 12 Jan 2024 16:35:13 +0100 Subject: [PATCH 30/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 6237b7761a..94a2646422 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -26,10 +26,10 @@ def sym(sym_or_name: Union[str, itir.Sym]) -> itir.Sym: Examples -------- >>> sym("a") - Sym(location=None, id=SymbolName('a'), kind=None, dtype=None) + Sym(id=SymbolName('a'), kind=None, dtype=None) >>> sym(itir.Sym(id="b")) - Sym(location=None, id=SymbolName('b'), kind=None, dtype=None) + Sym(id=SymbolName('b'), kind=None, dtype=None) """ if isinstance(sym_or_name, itir.Sym): return sym_or_name @@ -43,10 +43,10 @@ def ref(ref_or_name: Union[str, itir.SymRef]) -> itir.SymRef: Examples -------- >>> ref("a") - SymRef(location=None, id=SymbolRef('a')) + SymRef(id=SymbolRef('a')) >>> ref(itir.SymRef(id="b")) - SymRef(location=None, id=SymbolRef('b')) + SymRef(id=SymbolRef('b')) """ if isinstance(ref_or_name, itir.SymRef): return ref_or_name @@ -60,13 +60,13 @@ def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> iti Examples -------- >>> ensure_expr("a") - SymRef(location=None, id=SymbolRef('a')) + SymRef(id=SymbolRef('a')) >>> ensure_expr(3) - Literal(location=None, value='3', type='int32') + Literal(value='3', type='int32') >>> ensure_expr(itir.OffsetLiteral(value="i")) - OffsetLiteral(location=None, value='i') + OffsetLiteral(value='i') """ if isinstance(literal_or_expr, str): return ref(literal_or_expr) @@ -83,10 +83,10 @@ def ensure_offset(str_or_offset: Union[str, int, itir.OffsetLiteral]) -> itir.Of Examples -------- >>> ensure_offset("V2E") - OffsetLiteral(location=None, value='V2E') + OffsetLiteral(value='V2E') >>> ensure_offset(itir.OffsetLiteral(value="J")) - OffsetLiteral(location=None, value='J') + OffsetLiteral(value='J') """ if isinstance(str_or_offset, (str, int)): return itir.OffsetLiteral(value=str_or_offset) @@ -100,7 +100,7 @@ class lambda_: Examples -------- >>> lambda_("a")(deref("a")) # doctest: +ELLIPSIS - Lambda(location=None, params=[Sym(location=None, id=SymbolName('a'), kind=None, dtype=None)], expr=FunCall(location=None, fun=SymRef(location=None, id=SymbolRef('deref')), args=[SymRef(location=None, id=SymbolRef('a'))])) + Lambda(params=[Sym(id=SymbolName('a'), kind=None, dtype=None)], expr=FunCall(fun=SymRef(id=SymbolRef('deref')), args=[SymRef(id=SymbolRef('a'))])) """ def __init__(self, *args): @@ -117,7 +117,7 @@ class call: Examples -------- >>> call("plus")(1, 1) - FunCall(location=None, fun=SymRef(location=None, id=SymbolRef('plus')), args=[Literal(location=None, value='1', type='int32'), Literal(location=None, value='1', type='int32')]) + FunCall(fun=SymRef(id=SymbolRef('plus')), args=[Literal(value='1', type='int32'), Literal(value='1', type='int32')]) """ def __init__(self, expr): @@ -264,10 +264,10 @@ def shift(offset, value=None): Examples -------- >>> shift("i", 0)("a") - FunCall(location=None, fun=FunCall(location=None, fun=SymRef(location=None, id=SymbolRef('shift')), args=[OffsetLiteral(location=None, value='i'), OffsetLiteral(location=None, value=0)]), args=[SymRef(location=None, id=SymbolRef('a'))]) + FunCall(fun=FunCall(fun=SymRef(id=SymbolRef('shift')), args=[OffsetLiteral(value='i'), OffsetLiteral(value=0)]), args=[SymRef(id=SymbolRef('a'))]) >>> shift("V2E")("b") - FunCall(location=None, fun=FunCall(location=None, fun=SymRef(location=None, id=SymbolRef('shift')), args=[OffsetLiteral(location=None, value='V2E')]), args=[SymRef(location=None, id=SymbolRef('b'))]) + FunCall(fun=FunCall(fun=SymRef(id=SymbolRef('shift')), args=[OffsetLiteral(value='V2E')]), args=[SymRef(id=SymbolRef('b'))]) """ offset = ensure_offset(offset) args = [offset] @@ -286,13 +286,13 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: Make a literal node from a value. >>> literal_from_value(1.) - Literal(location=None, value='1.0', type='float64') + Literal(value='1.0', type='float64') >>> literal_from_value(1) - Literal(location=None, value='1', type='int32') + Literal(value='1', type='int32') >>> literal_from_value(2147483648) - Literal(location=None, value='2147483648', type='int64') + Literal(value='2147483648', type='int64') >>> literal_from_value(True) - Literal(location=None, value='True', type='bool') + Literal(value='True', type='bool') """ if not isinstance(val, core_defs.Scalar): # type: ignore[arg-type] # mypy bug #11673 raise ValueError(f"Value must be a scalar, got '{type(val).__name__}'.") From bf338274361f1ccb34285ba201c2157d6fd68de6 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Fri, 12 Jan 2024 17:06:37 +0100 Subject: [PATCH 31/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors --- src/gt4py/next/iterator/ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index d363151500..37abbec9e7 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -24,7 +24,7 @@ @noninstantiable class Node(eve.Node): - location: Optional[SourceLocation] = eve.field(default=None, repr=False) + location: Optional[SourceLocation] = eve.field(default=None, repr=False, compare=False) def __str__(self) -> str: from gt4py.next.iterator.pretty_printer import pformat From 371b1da161950e185c24491c8ef2315f38e2b0f8 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Tue, 23 Jan 2024 17:38:40 +0100 Subject: [PATCH 32/32] Add more debug info to DaCe (pass SourceLocation from past/foast to itir, and from itir to the SDFG): Preserve Location through Visitors --- src/gt4py/eve/__init__.py | 11 +++- src/gt4py/eve/traits.py | 8 +++ src/gt4py/eve/visitors.py | 14 ---- src/gt4py/next/ffront/foast_to_itir.py | 5 +- src/gt4py/next/ffront/past_to_itir.py | 6 +- .../iterator/transforms/collapse_list_get.py | 3 +- .../iterator/transforms/collapse_tuple.py | 3 +- .../iterator/transforms/constant_folding.py | 3 +- src/gt4py/next/iterator/transforms/cse.py | 9 ++- .../next/iterator/transforms/eta_reduction.py | 3 +- .../next/iterator/transforms/fuse_maps.py | 4 +- .../next/iterator/transforms/global_tmps.py | 3 +- .../iterator/transforms/inline_fundefs.py | 3 +- .../iterator/transforms/inline_into_scan.py | 6 +- .../iterator/transforms/inline_lambdas.py | 3 +- .../next/iterator/transforms/inline_lifts.py | 6 +- .../next/iterator/transforms/merge_let.py | 3 +- .../iterator/transforms/normalize_shifts.py | 3 +- .../iterator/transforms/propagate_deref.py | 3 +- .../transforms/prune_closure_inputs.py | 3 +- .../next/iterator/transforms/remap_symbols.py | 3 +- .../iterator/transforms/scan_eta_reduction.py | 3 +- .../iterator/transforms/symbol_ref_utils.py | 3 +- .../next/iterator/transforms/trace_shifts.py | 3 +- .../next/iterator/transforms/unroll_reduce.py | 3 +- .../ffront_tests/test_foast_to_itir.py | 64 +++++++++---------- 26 files changed, 87 insertions(+), 94 deletions(-) diff --git a/src/gt4py/eve/__init__.py b/src/gt4py/eve/__init__.py index 2dd5183b74..e726db1f1a 100644 --- a/src/gt4py/eve/__init__.py +++ b/src/gt4py/eve/__init__.py @@ -58,7 +58,12 @@ field, frozenmodel, ) -from .traits import SymbolTableTrait, ValidatedSymbolTableTrait, VisitorWithSymbolTableTrait +from .traits import ( + PreserveLocationVisitor, + SymbolTableTrait, + ValidatedSymbolTableTrait, + VisitorWithSymbolTableTrait, +) from .trees import ( bfs_walk_items, bfs_walk_values, @@ -70,7 +75,7 @@ walk_values, ) from .type_definitions import NOTHING, ConstrainedStr, Enum, IntEnum, NothingType, StrEnum -from .visitors import NodeTranslator, NodeVisitor, PreserveLocationVisitor +from .visitors import NodeTranslator, NodeVisitor __all__ = [ @@ -113,6 +118,7 @@ "SymbolTableTrait", "ValidatedSymbolTableTrait", "VisitorWithSymbolTableTrait", + "PreserveLocationVisitor", # trees "bfs_walk_items", "bfs_walk_values", @@ -132,5 +138,4 @@ # visitors "NodeTranslator", "NodeVisitor", - "PreserveLocationVisitor", ] diff --git a/src/gt4py/eve/traits.py b/src/gt4py/eve/traits.py index df556c9d7f..aacae804d8 100644 --- a/src/gt4py/eve/traits.py +++ b/src/gt4py/eve/traits.py @@ -172,3 +172,11 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: kwargs["symtable"] = kwargs["symtable"].parents return result + + +class PreserveLocationVisitor(visitors.NodeVisitor): + def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: + result = super().visit(node, **kwargs) + if hasattr(node, "location") and hasattr(result, "location"): + result.location = node.location + return result diff --git a/src/gt4py/eve/visitors.py b/src/gt4py/eve/visitors.py index c0a0054f5a..fe5f9e1474 100644 --- a/src/gt4py/eve/visitors.py +++ b/src/gt4py/eve/visitors.py @@ -196,17 +196,3 @@ def generic_visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: ) return copy.deepcopy(node, memo=memo) - - -class PreserveLocationVisitor(NodeVisitor): - preserve_location: bool = True - - def __init__(self, preserve_location: bool = True) -> None: - super().__init__() - self.preserve_location = preserve_location - - def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: - result = super().visit(node, **kwargs) - if hasattr(node, "location") and hasattr(result, "location") and self.preserve_location: - result.location = node.location - return result diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index fae944a101..0c9ab4ab27 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -66,11 +66,10 @@ class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): """ uid_generator: UIDGenerator = dataclasses.field(default_factory=UIDGenerator) - preserve_location: bool = True @classmethod - def apply(cls, node: foast.LocatedNode, preserve_location: bool = True) -> itir.Expr: - return cls(preserve_location=preserve_location).visit(node) + def apply(cls, node: foast.LocatedNode) -> itir.Expr: + return cls().visit(node) def visit_FunctionDefinition( self, node: foast.FunctionDefinition, **kwargs diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 50ddb0401b..ed239e0436 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -16,7 +16,7 @@ from typing import Optional, cast -from gt4py.eve import NodeTranslator, PreserveLocationVisitor, concepts, traits +from gt4py.eve import NodeTranslator, concepts, traits from gt4py.next.common import Dimension, DimensionKind, GridType from gt4py.next.ffront import program_ast as past, type_specifications as ts_ffront from gt4py.next.iterator import ir as itir @@ -40,7 +40,9 @@ def _flatten_tuple_expr( raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.") -class ProgramLowering(PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): +class ProgramLowering( + traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator +): """ Lower Program AST (PAST) to Iterator IR (ITIR). diff --git a/src/gt4py/next/iterator/transforms/collapse_list_get.py b/src/gt4py/next/iterator/transforms/collapse_list_get.py index 58f047e9b0..6acb8a79c4 100644 --- a/src/gt4py/next/iterator/transforms/collapse_list_get.py +++ b/src/gt4py/next/iterator/transforms/collapse_list_get.py @@ -13,11 +13,10 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py import eve -from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir -class CollapseListGet(PreserveLocationVisitor, eve.NodeTranslator): +class CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator): """Simplifies expressions containing `list_get`. Examples diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 247b442853..42bbf28909 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -15,7 +15,6 @@ from typing import Optional from gt4py import eve -from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next import type_inference from gt4py.next.iterator import ir, type_inference as it_type_inference @@ -49,7 +48,7 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t @dataclass(frozen=True) -class CollapseTuple(PreserveLocationVisitor, eve.NodeTranslator): +class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator): """ Simplifies `make_tuple`, `tuple_get` calls. diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 6c70f7013e..696a87a197 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -12,8 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator -from gt4py.eve.visitors import PreserveLocationVisitor +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import embedded, ir from gt4py.next.iterator.ir_utils import ir_makers as im diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 460f2cdbc3..f9cf272c45 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -17,9 +17,14 @@ import operator import typing -from gt4py.eve import NodeTranslator, NodeVisitor, SymbolTableTrait, VisitorWithSymbolTableTrait +from gt4py.eve import ( + NodeTranslator, + NodeVisitor, + PreserveLocationVisitor, + SymbolTableTrait, + VisitorWithSymbolTableTrait, +) from gt4py.eve.utils import UIDGenerator -from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda diff --git a/src/gt4py/next/iterator/transforms/eta_reduction.py b/src/gt4py/next/iterator/transforms/eta_reduction.py index c146538554..93702a6c96 100644 --- a/src/gt4py/next/iterator/transforms/eta_reduction.py +++ b/src/gt4py/next/iterator/transforms/eta_reduction.py @@ -12,8 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator -from gt4py.eve.visitors import PreserveLocationVisitor +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir diff --git a/src/gt4py/next/iterator/transforms/fuse_maps.py b/src/gt4py/next/iterator/transforms/fuse_maps.py index 2afa417743..694dcd6a61 100644 --- a/src/gt4py/next/iterator/transforms/fuse_maps.py +++ b/src/gt4py/next/iterator/transforms/fuse_maps.py @@ -15,7 +15,7 @@ import dataclasses from typing import TypeGuard -from gt4py.eve import NodeTranslator, PreserveLocationVisitor, traits +from gt4py.eve import NodeTranslator, traits from gt4py.eve.utils import UIDGenerator from gt4py.next.iterator import ir from gt4py.next.iterator.transforms import inline_lambdas @@ -38,7 +38,7 @@ def _is_reduce(node: ir.Node) -> TypeGuard[ir.FunCall]: @dataclasses.dataclass(frozen=True) -class FuseMaps(PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): +class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ Fuses nested `map_`s. diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 7ad55d0a87..c423a3c277 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -19,10 +19,9 @@ import gt4py.eve as eve import gt4py.next as gtx -from gt4py.eve import Coerced, NodeTranslator +from gt4py.eve import Coerced, NodeTranslator, PreserveLocationVisitor from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator -from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next import common from gt4py.next.iterator import ir, type_inference from gt4py.next.iterator.ir_utils import ir_makers as im diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index c0176202ea..a53232745f 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -14,8 +14,7 @@ from typing import Any, Dict, Set -from gt4py.eve import NOTHING, NodeTranslator -from gt4py.eve.visitors import PreserveLocationVisitor +from gt4py.eve import NOTHING, NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir diff --git a/src/gt4py/next/iterator/transforms/inline_into_scan.py b/src/gt4py/next/iterator/transforms/inline_into_scan.py index 6c9fb52d2a..a1c9a2eb5b 100644 --- a/src/gt4py/next/iterator/transforms/inline_into_scan.py +++ b/src/gt4py/next/iterator/transforms/inline_into_scan.py @@ -15,7 +15,7 @@ from typing import Sequence, TypeGuard from gt4py import eve -from gt4py.eve import NodeTranslator, PreserveLocationVisitor, traits +from gt4py.eve import NodeTranslator, traits from gt4py.next.iterator import ir from gt4py.next.iterator.transforms import symbol_ref_utils from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda @@ -53,7 +53,9 @@ def _lambda_and_lift_inliner(node: ir.FunCall) -> ir.FunCall: return inlined -class InlineIntoScan(PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): +class InlineIntoScan( + traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator +): """ Inline non-SymRef arguments into the scan. diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index a9f00afa5b..0b89fe6d98 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -15,8 +15,7 @@ import dataclasses from typing import Optional -from gt4py.eve import NodeTranslator -from gt4py.eve.visitors import PreserveLocationVisitor +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index 07d09f16cf..d6146d9fc8 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -18,7 +18,7 @@ from typing import Optional import gt4py.eve as eve -from gt4py.eve import NodeTranslator, PreserveLocationVisitor, traits +from gt4py.eve import NodeTranslator, traits from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda @@ -112,7 +112,9 @@ def _transform_and_extract_lift_args( # passes. Due to a lack of infrastructure (e.g. no pass manager) to combine passes without # performance degradation we leave everything as one pass for now. @dataclasses.dataclass -class InlineLifts(PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): +class InlineLifts( + traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator +): """Inline lifted function calls. Optionally a predicate function can be passed which can enable or disable inlining of specific diff --git a/src/gt4py/next/iterator/transforms/merge_let.py b/src/gt4py/next/iterator/transforms/merge_let.py index 5b96e9cbf5..bcfc6b2a17 100644 --- a/src/gt4py/next/iterator/transforms/merge_let.py +++ b/src/gt4py/next/iterator/transforms/merge_let.py @@ -13,12 +13,11 @@ # SPDX-License-Identifier: GPL-3.0-or-later import gt4py.eve as eve -from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs -class MergeLet(PreserveLocationVisitor, eve.NodeTranslator): +class MergeLet(eve.PreserveLocationVisitor, eve.NodeTranslator): """ Merge let-like statements. diff --git a/src/gt4py/next/iterator/transforms/normalize_shifts.py b/src/gt4py/next/iterator/transforms/normalize_shifts.py index d9bb96a81d..c70dc1ccd1 100644 --- a/src/gt4py/next/iterator/transforms/normalize_shifts.py +++ b/src/gt4py/next/iterator/transforms/normalize_shifts.py @@ -12,8 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator -from gt4py.eve.visitors import PreserveLocationVisitor +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index c3d338f926..783e54ede0 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -12,9 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir diff --git a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py index 0dd77768d5..1e637a0bfb 100644 --- a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py +++ b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py @@ -12,8 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator -from gt4py.eve.visitors import PreserveLocationVisitor +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index 8d810cf624..431dd6cd7a 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -14,8 +14,7 @@ from typing import Any, Dict, Optional, Set -from gt4py.eve import NodeTranslator, SymbolTableTrait -from gt4py.eve.visitors import PreserveLocationVisitor +from gt4py.eve import NodeTranslator, PreserveLocationVisitor, SymbolTableTrait from gt4py.next.iterator import ir diff --git a/src/gt4py/next/iterator/transforms/scan_eta_reduction.py b/src/gt4py/next/iterator/transforms/scan_eta_reduction.py index 93ab4b52db..d93b4242ab 100644 --- a/src/gt4py/next/iterator/transforms/scan_eta_reduction.py +++ b/src/gt4py/next/iterator/transforms/scan_eta_reduction.py @@ -12,8 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator -from gt4py.eve.visitors import PreserveLocationVisitor +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 20b2650de5..05d137e8c4 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -17,12 +17,11 @@ from typing import Iterable, Optional, Sequence import gt4py.eve as eve -from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next.iterator import ir as itir @dataclasses.dataclass -class CountSymbolRefs(PreserveLocationVisitor, eve.NodeVisitor): +class CountSymbolRefs(eve.PreserveLocationVisitor, eve.NodeVisitor): ref_counts: dict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int)) @classmethod diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 29428c3cf9..082987ac96 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -16,8 +16,7 @@ from collections.abc import Callable from typing import Any, Final, Iterable, Literal -from gt4py.eve import NodeTranslator -from gt4py.eve.visitors import PreserveLocationVisitor +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 4f66252298..3c878b2b00 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -16,9 +16,8 @@ from collections.abc import Iterable, Iterator from typing import TypeGuard -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.eve.utils import UIDGenerator -from gt4py.eve.visitors import PreserveLocationVisitor from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py index d264dc37c6..2dd4b91c48 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py @@ -58,7 +58,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]): return inp parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) assert lowered.id == "copy_field" assert lowered.expr == im.ref("inp") @@ -69,7 +69,7 @@ def scalar_arg(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], return alpha * bar parsed = FieldOperatorParser.apply_to_function(scalar_arg) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("multiplies")( "alpha", "bar" @@ -83,7 +83,7 @@ def multicopy(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64] return inp1, inp2 parsed = FieldOperatorParser.apply_to_function(multicopy) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("make_tuple")("inp1", "inp2") @@ -95,7 +95,7 @@ def arithmetic(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64 return inp1 + inp2 parsed = FieldOperatorParser.apply_to_function(arithmetic) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("plus")("inp1", "inp2") @@ -109,7 +109,7 @@ def shift_by_one(inp: gtx.Field[[IDim], float64]): return inp(Ioff[1]) parsed = FieldOperatorParser.apply_to_function(shift_by_one) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", 1)("it"))))("inp") @@ -123,7 +123,7 @@ def shift_by_one(inp: gtx.Field[[IDim], float64]): return inp(Ioff[-1]) parsed = FieldOperatorParser.apply_to_function(shift_by_one) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", -1)("it"))))("inp") @@ -138,7 +138,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]): return tmp2 parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.let( itir.Sym(id=ssa.unique_name("tmp", 0), dtype=("float64", False), kind="Iterator"), "inp" @@ -164,7 +164,7 @@ def unary(inp: gtx.Field[[TDim], float64]): return tmp parsed = FieldOperatorParser.apply_to_function(unary) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.let( itir.Sym(id=ssa.unique_name("tmp", 0), dtype=("float64", False), kind="Iterator"), @@ -193,7 +193,7 @@ def unpacking( return tmp1 parsed = FieldOperatorParser.apply_to_function(unpacking) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) tuple_expr = im.promote_to_lifted_stencil("make_tuple")("inp1", "inp2") tuple_access_0 = im.promote_to_lifted_stencil(lambda x: im.tuple_get(0, x))("__tuple_tmp_0") @@ -222,7 +222,7 @@ def copy_field(inp: gtx.Field[[TDim], float64]): return tmp parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.let(ssa.unique_name("tmp", 0), "inp")(ssa.unique_name("tmp", 0)) @@ -246,7 +246,7 @@ def call(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: return identity(inp) parsed = FieldOperatorParser.apply_to_function(call) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.lift(im.lambda_("__arg0")(im.call("identity")("__arg0")))("inp") @@ -261,7 +261,7 @@ def temp_tuple(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], int64]): return tmp parsed = FieldOperatorParser.apply_to_function(temp_tuple) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) tuple_expr = im.promote_to_lifted_stencil("make_tuple")("a", "b") reference = im.let(ssa.unique_name("tmp", 0), tuple_expr)(ssa.unique_name("tmp", 0)) @@ -274,7 +274,7 @@ def unary_not(cond: gtx.Field[[TDim], "bool"]): return not cond parsed = FieldOperatorParser.apply_to_function(unary_not) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("not_")("cond") @@ -286,7 +286,7 @@ def plus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a + b parsed = FieldOperatorParser.apply_to_function(plus) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("plus")("a", "b") @@ -298,7 +298,7 @@ def scalar_plus_field(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float6 return 2.0 + a parsed = FieldOperatorParser.apply_to_function(scalar_plus_field) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("plus")( im.promote_to_const_iterator(im.literal("2.0", "float64")), "a" @@ -313,7 +313,7 @@ def scalar_plus_scalar(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int3 return a + tmp parsed = FieldOperatorParser.apply_to_function(scalar_plus_scalar) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.let( ssa.unique_name("tmp", 0), @@ -331,7 +331,7 @@ def mult(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a * b parsed = FieldOperatorParser.apply_to_function(mult) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("multiplies")("a", "b") @@ -343,7 +343,7 @@ def minus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a - b parsed = FieldOperatorParser.apply_to_function(minus) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("minus")("a", "b") @@ -355,7 +355,7 @@ def division(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a / b parsed = FieldOperatorParser.apply_to_function(division) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("divides")("a", "b") @@ -367,7 +367,7 @@ def bit_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): return a & b parsed = FieldOperatorParser.apply_to_function(bit_and) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("and_")("a", "b") @@ -379,7 +379,7 @@ def scalar_and(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]: return a & False parsed = FieldOperatorParser.apply_to_function(scalar_and) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("and_")( "a", im.promote_to_const_iterator(im.literal("False", "bool")) @@ -393,7 +393,7 @@ def bit_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): return a | b parsed = FieldOperatorParser.apply_to_function(bit_or) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("or_")("a", "b") @@ -405,7 +405,7 @@ def comp_scalars() -> bool: return 3 > 4 parsed = FieldOperatorParser.apply_to_function(comp_scalars) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("greater")( im.promote_to_const_iterator(im.literal("3", "int32")), @@ -420,7 +420,7 @@ def comp_gt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a > b parsed = FieldOperatorParser.apply_to_function(comp_gt) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("greater")("a", "b") @@ -432,7 +432,7 @@ def comp_lt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): return a < b parsed = FieldOperatorParser.apply_to_function(comp_lt) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("less")("a", "b") @@ -444,7 +444,7 @@ def comp_eq(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]): return a == b parsed = FieldOperatorParser.apply_to_function(comp_eq) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("eq")("a", "b") @@ -458,7 +458,7 @@ def compare_chain( return a > b > c parsed = FieldOperatorParser.apply_to_function(compare_chain) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("and_")( im.promote_to_lifted_stencil("greater")("a", "b"), @@ -473,7 +473,7 @@ def reduction(edge_f: gtx.Field[[Edge], float64]): return neighbor_sum(edge_f(V2E), axis=V2EDim) parsed = FieldOperatorParser.apply_to_function(reduction) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil( im.call( @@ -495,7 +495,7 @@ def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], fl return neighbor_sum(1.1 * (e1_nbh + e2), axis=V2EDim) parsed = FieldOperatorParser.apply_to_function(reduction) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) mapped = im.promote_to_lifted_stencil(im.map_("multiplies"))( im.promote_to_lifted_stencil("make_const_list")( @@ -538,7 +538,7 @@ def int_constrs() -> ( return 1, int32(1), int64(1), int32("1"), int64("1") parsed = FieldOperatorParser.apply_to_function(int_constrs) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("make_tuple")( im.promote_to_const_iterator(im.literal("1", "int32")), @@ -574,7 +574,7 @@ def float_constrs() -> ( ) parsed = FieldOperatorParser.apply_to_function(float_constrs) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("make_tuple")( im.promote_to_const_iterator(im.literal("0.1", "float64")), @@ -594,7 +594,7 @@ def bool_constrs() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: return True, False, bool(True), bool(False), bool(0), bool(5), bool("True"), bool("False") parsed = FieldOperatorParser.apply_to_function(bool_constrs) - lowered = FieldOperatorLowering.apply(parsed, preserve_location=False) + lowered = FieldOperatorLowering.apply(parsed) reference = im.promote_to_lifted_stencil("make_tuple")( im.promote_to_const_iterator(im.literal(str(True), "bool")),