Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next][dace]: Add more debug info to DaCe #1384

Merged
merged 38 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
caebdff
Add more debug info to DaCe
kotsaloscv Nov 7, 2023
f14d091
Merge branch 'main' into more_debinfo
kotsaloscv Nov 27, 2023
bebb122
Add more debug info to DaCe
kotsaloscv Nov 27, 2023
7eeaddb
Add more debug info to DaCe : WIP
kotsaloscv Nov 27, 2023
7447807
Add more debug info to DaCe : WIP
kotsaloscv Nov 29, 2023
6d89149
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Nov 29, 2023
16bc489
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Nov 30, 2023
0ed8090
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 4, 2023
3e38756
merge main
kotsaloscv Dec 4, 2023
0b1fe1a
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 4, 2023
55abb29
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 4, 2023
54dd79d
merge main
kotsaloscv Dec 5, 2023
774b2f5
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 5, 2023
1622866
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 5, 2023
7baedb8
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 5, 2023
223de4e
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
6c636ae
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
93fcb14
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
d6da4ab
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
8460c67
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
5632def
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
b59fd83
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
68dde06
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 13, 2023
a1a91c4
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 14, 2023
1ed9764
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 14, 2023
371dc36
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 14, 2023
bb880dd
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 14, 2023
e0a254f
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Dec 14, 2023
ea2f672
merge main
kotsaloscv Jan 4, 2024
50f96a8
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 4, 2024
9c9c8ae
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
6fb28a1
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
3f4e9d1
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
b29bc5f
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
1a2e978
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
bf33827
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 12, 2024
616a833
merge main
kotsaloscv Jan 23, 2024
371b1da
Add more debug info to DaCe (pass SourceLocation from past/foast to i…
kotsaloscv Jan 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 80 additions & 29 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def visit_FunctionDefinition(
id=node.id,
params=params,
expr=self.visit_BlockStmt(node.body, inner_expr=None),
location=node.location,
) # `expr` is a lifted stencil

def visit_FieldOperator(self, node: foast.FieldOperator, **kwargs) -> itir.FunctionDefinition:
Expand All @@ -89,6 +90,7 @@ def visit_FieldOperator(self, node: foast.FieldOperator, **kwargs) -> itir.Funct
id=func_definition.id,
params=func_definition.params,
expr=new_body,
location=node.location,
)

def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> itir.FunctionDefinition:
Expand All @@ -111,7 +113,9 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> itir.Functio
func_definition.params[0].id,
im.promote_to_const_iterator(func_definition.params[0].id),
)(im.deref(new_body))
definition = itir.Lambda(params=func_definition.params, expr=new_body)
definition = itir.Lambda(
params=func_definition.params, expr=new_body, location=node.location
)
body = im.call(im.call("scan")(definition, forward, init))(
*(param.id for param in definition.params[1:])
)
Expand All @@ -120,6 +124,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> itir.Functio
id=node.id,
params=definition.params[1:],
expr=body,
location=node.location,
)

def visit_Stmt(self, node: foast.Stmt, **kwargs):
Expand All @@ -128,21 +133,24 @@ def visit_Stmt(self, node: foast.Stmt, **kwargs):
def visit_Return(
self, node: foast.Return, *, inner_expr: Optional[itir.Expr], **kwargs
) -> itir.Expr:
return self.visit(node.value, **kwargs)
return_ = self.visit(node.value, **kwargs)
return_.location = node.location
return return_

def visit_BlockStmt(
self, node: foast.BlockStmt, *, inner_expr: Optional[itir.Expr], **kwargs
) -> itir.Expr:
for stmt in reversed(node.stmts):
inner_expr = self.visit(stmt, inner_expr=inner_expr, **kwargs)
assert inner_expr
inner_expr.location = node.location
return inner_expr

def visit_IfStmt(
self, node: foast.IfStmt, *, inner_expr: Optional[itir.Expr], **kwargs
) -> itir.Expr:
# the lowered if call doesn't need to be lifted as the condition can only originate
# from a scalar value (and not a field)
# from a scalar value (and not a field)
assert (
isinstance(node.condition.type, ts.ScalarType)
and node.condition.type.kind == ts.ScalarKind.BOOL
Expand All @@ -167,9 +175,11 @@ def visit_IfStmt(
inner_expr = im.let(sym, im.tuple_get(i, im.ref("__if_stmt_result")))(inner_expr)

# here we assume neither branch returns
return im.let("__if_stmt_result", im.if_(im.deref(cond), true_branch, false_branch))(
return_ = im.let("__if_stmt_result", im.if_(im.deref(cond), true_branch, false_branch))(
kotsaloscv marked this conversation as resolved.
Show resolved Hide resolved
kotsaloscv marked this conversation as resolved.
Show resolved Hide resolved
inner_expr
)
return_.location = node.location
return return_
elif return_kind is StmtReturnKind.CONDITIONAL_RETURN:
common_syms = tuple(im.sym(sym) for sym in common_symbols.keys())
common_symrefs = tuple(im.ref(sym) for sym in common_symbols.keys())
Expand All @@ -183,9 +193,11 @@ def visit_IfStmt(
true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs)
false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs)

return im.let(inner_expr_name, inner_expr_evaluator)(
return_ = im.let(inner_expr_name, inner_expr_evaluator)(
im.if_(im.deref(cond), true_branch, false_branch)
)
return_.location = node.location
return return_

assert return_kind is StmtReturnKind.UNCONDITIONAL_RETURN

Expand All @@ -194,66 +206,81 @@ def visit_IfStmt(
true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs)
false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs)

return im.if_(im.deref(cond), true_branch, false_branch)
return_ = im.if_(im.deref(cond), true_branch, false_branch)
return_.location = node.location
return return_

def visit_Assign(
self, node: foast.Assign, *, inner_expr: Optional[itir.Expr], **kwargs
) -> itir.Expr:
return im.let(self.visit(node.target, **kwargs), self.visit(node.value, **kwargs))(
return_ = im.let(self.visit(node.target, **kwargs), self.visit(node.value, **kwargs))(
inner_expr
)
return_.location = node.location
return return_
kotsaloscv marked this conversation as resolved.
Show resolved Hide resolved

def visit_Symbol(self, node: foast.Symbol, **kwargs) -> itir.Sym:
# TODO(tehrengruber): extend to more types
if isinstance(node.type, ts.FieldType):
kind = "Iterator"
dtype = node.type.dtype.kind.name.lower()
is_list = type_info.is_local_field(node.type)
return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list))
return im.sym(node.id)
return itir.Sym(id=node.id, kind=kind, dtype=(dtype, is_list), location=node.location)
return_ = im.sym(node.id)
return_.location = node.location
return return_

def visit_Name(self, node: foast.Name, **kwargs) -> itir.SymRef:
return im.ref(node.id)

def visit_Subscript(self, node: foast.Subscript, **kwargs) -> itir.Expr:
return im.promote_to_lifted_stencil(lambda tuple_: im.tuple_get(node.index, tuple_))(
return_ = im.promote_to_lifted_stencil(lambda tuple_: im.tuple_get(node.index, tuple_))(
self.visit(node.value, **kwargs)
)
return_.location = node.location
return return_

def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs) -> itir.Expr:
return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))(
return_ = im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))(
*[self.visit(el, **kwargs) for el in node.elts],
)
return_.location = node.location
return return_

def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> itir.Expr:
# TODO(tehrengruber): extend iterator ir to support unary operators
dtype = type_info.extract_dtype(node.type)
if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]:
if dtype.kind != ts.ScalarKind.BOOL:
raise NotImplementedError(f"{node.op} is only supported on `bool`s.")
return self._map("not_", node.operand)
return self._map("not_", node.operand, location=node.location)

return self._map(
node.op.value,
foast.Constant(value="0", type=dtype, location=node.location),
node.operand,
location=node.location,
)

def visit_BinOp(self, node: foast.BinOp, **kwargs) -> itir.FunCall:
return self._map(node.op.value, node.left, node.right)
return self._map(node.op.value, node.left, node.right, location=node.location)

def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs) -> itir.FunCall:
return self._map("if_", node.condition, node.true_expr, node.false_expr)
return self._map(
"if_", node.condition, node.true_expr, node.false_expr, location=node.location
)

def visit_Compare(self, node: foast.Compare, **kwargs) -> itir.FunCall:
return self._map(node.op.value, node.left, node.right)
return self._map(node.op.value, node.left, node.right, location=node.location)

def _visit_shift(self, node: foast.Call, **kwargs) -> itir.Expr:
match node.args[0]:
case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)):
shift_offset = im.shift(offset_name, offset_index)
case foast.Name(id=offset_name):
return im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs))
return_ = im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs))
return_.location = node.location
return return_
case foast.Call(func=foast.Name(id="as_offset")):
func_args = node.args[0]
offset_dim = func_args.args[0]
Expand All @@ -263,9 +290,11 @@ def _visit_shift(self, node: foast.Call, **kwargs) -> itir.Expr:
)
case _:
raise FieldOperatorLoweringError("Unexpected shift arguments!")
return im.lift(im.lambda_("it")(im.deref(shift_offset("it"))))(
return_ = im.lift(im.lambda_("it")(im.deref(shift_offset("it"))))(
self.visit(node.func, **kwargs)
)
return_.location = node.location
return return_

def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr:
if type_info.type_class(node.func.type) is ts.FieldType:
Expand Down Expand Up @@ -296,11 +325,13 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr:
)
call_args = [f"__arg{i}" for i in range(len(lowered_args))]
call_kwargs = [f"__kwarg_{name}" for name in lowered_kwargs.keys()]
return im.lift(
return_ = im.lift(
im.lambda_(*call_args, *call_kwargs)(
im.call(lowered_func)(*call_args, *call_kwargs)
)
)(*lowered_args, *lowered_kwargs.values())
return_.location = node.location
return return_
elif isinstance(node.func.type, ts.FunctionType):
# ITIR has no support for keyword arguments. Instead, we concatenate both positional
# and keyword arguments and use the unique order as given in the function signature.
Expand All @@ -310,7 +341,11 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr:
self.visit(node.kwargs, **kwargs),
use_signature_ordering=True,
)
return im.call(self.visit(node.func, **kwargs))(*lowered_args, *lowered_kwargs.values())
return_ = im.call(self.visit(node.func, **kwargs))(
*lowered_args, *lowered_kwargs.values()
)
return_.location = node.location
return return_

raise AssertionError(
f"Call to object of type {type(node.func.type).__name__} not understood."
Expand All @@ -319,18 +354,22 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr:
def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall:
assert len(node.args) == 2 and isinstance(node.args[1], foast.Name)
obj, new_type = node.args[0], node.args[1].id
return self._process_elements(
return_ = self._process_elements(
lambda x: im.call("cast_")(x, str(new_type)), obj, obj.type, **kwargs
)
return_.location = node.location
return return_

def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall:
return self._map("if_", *node.args)
return self._map("if_", *node.args, location=node.location)

def _visit_broadcast(self, node: foast.Call, **kwargs) -> itir.FunCall:
return self.visit(node.args[0], **kwargs)
return_ = self.visit(node.args[0], **kwargs)
return_.location = node.location
return return_

def _visit_math_built_in(self, node: foast.Call, **kwargs) -> itir.FunCall:
return self._map(self.visit(node.func, **kwargs), *node.args)
return self._map(self.visit(node.func, **kwargs), *node.args, location=node.location)

def _make_reduction_expr(
self,
Expand All @@ -343,7 +382,9 @@ def _make_reduction_expr(
it = self.visit(node.args[0], **kwargs)
assert isinstance(node.kwargs["axis"].type, ts.DimensionType)
val = im.call(im.call("reduce")(op, im.deref(init_expr)))
return im.promote_to_lifted_stencil(val)(it)
return_ = im.promote_to_lifted_stencil(val)(it)
return_.location = node.location
return return_

def _visit_neighbor_sum(self, node: foast.Call, **kwargs) -> itir.FunCall:
dtype = type_info.extract_dtype(node.type)
Expand All @@ -367,10 +408,16 @@ def _visit_type_constr(self, node: foast.Call, **kwargs) -> itir.Expr:
target_type = fbuiltins.BUILTINS[node_kind]
source_type = {**fbuiltins.BUILTINS, "string": str}[node.args[0].type.__str__().lower()]
if target_type is bool and source_type is not bool:
return im.promote_to_const_iterator(
return_ = im.promote_to_const_iterator(
im.literal(str(bool(source_type(node.args[0].value))), "bool")
)
return im.promote_to_const_iterator(im.literal(str(node.args[0].value), node_kind))
return_.location = node.location
return return_
return_ = im.promote_to_const_iterator(
im.literal(str(bool(source_type(node.args[0].value))), "bool")
)
return_.location = node.location
return return_
raise FieldOperatorLoweringError(f"Encountered a type cast, which is not supported: {node}")

def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr:
Expand All @@ -391,15 +438,19 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr:
raise ValueError(f"Unsupported literal type {type_}.")

def visit_Constant(self, node: foast.Constant, **kwargs) -> itir.Expr:
return self._make_literal(node.value, node.type)
return_ = self._make_literal(node.value, node.type)
return_.location = node.location
return return_

def _map(self, op, *args, **kwargs):
def _map(self, op, *args, location=None, **kwargs):
lowered_args = [self.visit(arg, **kwargs) for arg in args]
if any(type_info.contains_local_field(arg.type) for arg in args):
lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)]
op = im.call("map_")(op)

return im.promote_to_lifted_stencil(im.call(op))(*lowered_args)
return_ = im.promote_to_lifted_stencil(im.call(op))(*lowered_args)
return_.location = location
return return_

def _process_elements(
self,
Expand Down
Loading
Loading