Skip to content
This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Expression type info #590

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions jaclang/compiler/absyntree.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def __init__(self, kid: Sequence[AstNode]) -> None:
self.meta: dict[str, str] = {}
self.loc: CodeLocInfo = CodeLocInfo(*self.resolve_tok_range())

# NOTE: This is only applicable for Expr, However adding it there needs to call the constructor in all the
# subclasses, Adding it here, this needs a review.
self.expr_type: str = ""

@property
def sym_tab(self) -> SymbolTable:
"""Get symbol table."""
Expand Down
210 changes: 114 additions & 96 deletions jaclang/compiler/passes/main/fuse_typeinfo_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@

from __future__ import annotations

from typing import Callable, TypeVar
from typing import Callable, Optional, TypeVar

import jaclang.compiler.absyntree as ast
from jaclang.compiler.passes import Pass
from jaclang.settings import settings
from jaclang.utils.helpers import pascal_to_snake
from jaclang.vendor.mypy.nodes import Node as VNode # bit of a hack


import mypy.nodes as MypyNodes # noqa N812
import mypy.types as MypyTypes # noqa N812
from mypy.checkexpr import Type as MyType
Expand All @@ -28,23 +27,33 @@ class FuseTypeInfoPass(Pass):

node_type_hash: dict[MypyNodes.Node | VNode, MyType] = {}

# Override this to support enter expression.
def enter_node(self, node: ast.AstNode) -> None:
"""Run on entering node."""
if hasattr(self, f"enter_{pascal_to_snake(type(node).__name__)}"):
getattr(self, f"enter_{pascal_to_snake(type(node).__name__)}")(node)

# TODO: Make (AstSymbolNode::name_spec.sym_typ and Expr::expr_type) the same
# TODO: Introduce AstTypedNode to be a common parent for Expr and AstSymbolNode
if isinstance(node, ast.Expr):
self.enter_expr(node)

def __debug_print(self, *argv: object) -> None:
if settings.fuse_type_info_debug:
self.log_info("FuseTypeInfo::", *argv)

def __call_type_handler(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.ProperType
) -> None:
def __call_type_handler(self, mypy_type: MypyTypes.Type) -> Optional[str]:
mypy_type_name = pascal_to_snake(mypy_type.__class__.__name__)
type_handler_name = f"get_type_from_{mypy_type_name}"
if hasattr(self, type_handler_name):
getattr(self, type_handler_name)(node, mypy_type)
else:
self.__debug_print(
f'{node.loc}"MypyTypes::{mypy_type.__class__.__name__}" isn\'t supported yet'
)

def __set_sym_table_link(self, node: ast.AstSymbolNode) -> None:
return getattr(self, type_handler_name)(mypy_type)
self.__debug_print(
f'"MypyTypes::{mypy_type.__class__.__name__}" isn\'t supported yet'
)
return None

# TODO: Need to chsnge node type to be AstNode or a common parent
def __set_type_sym_table_link(self, node: ast.AstSymbolNode) -> None:
typ = node.sym_type.split(".")
typ_sym_table = self.ir.sym_tab

Expand Down Expand Up @@ -113,7 +122,7 @@ def node_handler(self: FuseTypeInfoPass, node: T) -> None:
# Jac node has only one mypy node linked to it
if len(node.gen.mypy_ast) == 1:
func(self, node)
self.__set_sym_table_link(node)
self.__set_type_sym_table_link(node)
self.__collect_python_dependencies(node)

# Jac node has multiple mypy nodes linked to it
Expand All @@ -137,7 +146,7 @@ def node_handler(self: FuseTypeInfoPass, node: T) -> None:
jac_node_str, "has duplicate mypy nodes associated to it"
)
func(self, node)
self.__set_sym_table_link(node)
self.__set_type_sym_table_link(node)
self.__collect_python_dependencies(node)

# Jac node doesn't have mypy nodes linked to it
Expand Down Expand Up @@ -172,7 +181,9 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
mypy_node = mypy_node.node

if isinstance(mypy_node, (MypyNodes.Var, MypyNodes.FuncDef)):
self.__call_type_handler(node, mypy_node.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.type) or node.name_spec.sym_type
)

elif isinstance(mypy_node, MypyNodes.MypyFile):
node.name_spec.sym_type = "types.ModuleType"
Expand All @@ -181,7 +192,10 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
node.name_spec.sym_type = mypy_node.fullname

elif isinstance(mypy_node, MypyNodes.OverloadedFuncDef):
self.__call_type_handler(node, mypy_node.items[0].func.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.items[0].func.type)
or node.name_spec.sym_type
)

elif mypy_node is None:
node.name_spec.sym_type = "None"
Expand All @@ -195,19 +209,66 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
else:
if isinstance(mypy_node, MypyNodes.ClassDef):
node.name_spec.sym_type = mypy_node.fullname
self.__set_sym_table_link(node)
self.__set_type_sym_table_link(node)
elif isinstance(mypy_node, MypyNodes.FuncDef):
self.__call_type_handler(node, mypy_node.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.type) or node.name_spec.sym_type
)
elif isinstance(mypy_node, MypyNodes.Argument):
self.__call_type_handler(node, mypy_node.variable.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.variable.type)
or node.name_spec.sym_type
)
elif isinstance(mypy_node, MypyNodes.Decorator):
self.__call_type_handler(node, mypy_node.func.type.ret_type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.func.type.ret_type)
or node.name_spec.sym_type
)
else:
self.__debug_print(
f'"{node.loc}::{node.__class__.__name__}" mypy node isn\'t supported',
type(mypy_node),
)

collection_types_map = {
ast.ListVal: "builtins.list",
ast.SetVal: "builtins.set",
ast.TupleVal: "builtins.tuple",
ast.DictVal: "builtins.dict",
ast.ListCompr: None,
ast.DictCompr: None,
}

# NOTE (Thakee): Since expression nodes are not AstSymbolNodes, I'm not decorating this with __handle_node
# and IMO instead of checking if it's a symbol node or an expression, we somehow mark expressions as
# valid nodes that can have symbols. At this point I'm leaving this like this and lemme know
# otherwise.
# NOTE (GAMAL): This will be fixed through the AstTypedNode
def enter_expr(self: FuseTypeInfoPass, node: ast.Expr) -> None:
"""Enter an expression node."""
if len(node.gen.mypy_ast) == 0:
return

# If the corrosponding mypy ast node type has stored here, get the values.
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
mytype: MyType = self.node_type_hash[mypy_node]
node.expr_type = self.__call_type_handler(mytype) or ""

# Set they symbol type for collection expression.
#
# GenCompr is an instance of ListCompr but we don't handle it here.
# so the isinstace (node, <classes>) doesn't work, I'm going with type(...) == ...
if type(node) in self.collection_types_map:
assert isinstance(node, ast.AtomExpr) # To make mypy happy.
collection_type = self.collection_types_map[type(node)]
if collection_type is not None:
node.name_spec.sym_type = collection_type
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = (
self.__call_type_handler(mytype) or node.name_spec.sym_type
)

@__handle_node
def enter_name(self, node: ast.NameAtom) -> None:
"""Pass handler for name nodes."""
Expand Down Expand Up @@ -247,7 +308,10 @@ def enter_enum_def(self, node: ast.EnumDef) -> None:
def enter_ability(self, node: ast.Ability) -> None:
"""Pass handler for Ability nodes."""
if isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef):
self.__call_type_handler(node, node.gen.mypy_ast[0].type.ret_type)
node.name_spec.sym_type = (
self.__call_type_handler(node.gen.mypy_ast[0].type.ret_type)
or node.name_spec.sym_type
)
else:
self.__debug_print(
f"{node.loc}: Can't get type of an ability from mypy node other than Ability.",
Expand All @@ -258,7 +322,10 @@ def enter_ability(self, node: ast.Ability) -> None:
def enter_ability_def(self, node: ast.AbilityDef) -> None:
"""Pass handler for AbilityDef nodes."""
if isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef):
self.__call_type_handler(node, node.gen.mypy_ast[0].type.ret_type)
node.name_spec.sym_type = (
self.__call_type_handler(node.gen.mypy_ast[0].type.ret_type)
or node.name_spec.sym_type
)
else:
self.__debug_print(
f"{node.loc}: Can't get type of an AbilityDef from mypy node other than FuncDef.",
Expand All @@ -271,7 +338,10 @@ def enter_param_var(self, node: ast.ParamVar) -> None:
if isinstance(node.gen.mypy_ast[0], MypyNodes.Argument):
mypy_node: MypyNodes.Argument = node.gen.mypy_ast[0]
if mypy_node.variable.type:
self.__call_type_handler(node, mypy_node.variable.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.variable.type)
or node.name_spec.sym_type
)
else:
self.__debug_print(
f"{node.loc}: Can't get parameter value from mypyNode other than Argument"
Expand All @@ -285,7 +355,9 @@ def enter_has_var(self, node: ast.HasVar) -> None:
if isinstance(mypy_node, MypyNodes.AssignmentStmt):
n = mypy_node.lvalues[0].node
if isinstance(n, (MypyNodes.Var, MypyNodes.FuncDef)):
self.__call_type_handler(node, n.type)
node.name_spec.sym_type = (
self.__call_type_handler(n.type) or node.name_spec.sym_type
)
else:
self.__debug_print(
"Getting type of 'AssignmentStmt' is only supported with Var and FuncDef"
Expand All @@ -310,54 +382,6 @@ def enter_f_string(self, node: ast.FString) -> None:
"""Pass handler for FString nodes."""
self.__debug_print("Getting type not supported in", type(node))

@__handle_node
def enter_list_val(self, node: ast.ListVal) -> None:
"""Pass handler for ListVal nodes."""
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
else:
node.name_spec.sym_type = "builtins.list"

@__handle_node
def enter_set_val(self, node: ast.SetVal) -> None:
"""Pass handler for SetVal nodes."""
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
else:
node.name_spec.sym_type = "builtins.set"

@__handle_node
def enter_tuple_val(self, node: ast.TupleVal) -> None:
"""Pass handler for TupleVal nodes."""
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
else:
node.name_spec.sym_type = "builtins.tuple"

@__handle_node
def enter_dict_val(self, node: ast.DictVal) -> None:
"""Pass handler for DictVal nodes."""
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
else:
node.name_spec.sym_type = "builtins.dict"

@__handle_node
def enter_list_compr(self, node: ast.ListCompr) -> None:
"""Pass handler for ListCompr nodes."""
mypy_node = node.gen.mypy_ast[0]
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])

@__handle_node
def enter_dict_compr(self, node: ast.DictCompr) -> None:
"""Pass handler for DictCompr nodes."""
mypy_node = node.gen.mypy_ast[0]
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])

@__handle_node
def enter_index_slice(self, node: ast.IndexSlice) -> None:
"""Pass handler for IndexSlice nodes."""
Expand All @@ -369,10 +393,12 @@ def enter_arch_ref(self, node: ast.ArchRef) -> None:
if isinstance(node.gen.mypy_ast[0], MypyNodes.ClassDef):
mypy_node: MypyNodes.ClassDef = node.gen.mypy_ast[0]
node.name_spec.sym_type = mypy_node.fullname
self.__set_sym_table_link(node)
self.__set_type_sym_table_link(node)
elif isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef):
mypy_node2: MypyNodes.FuncDef = node.gen.mypy_ast[0]
self.__call_type_handler(node, mypy_node2.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node2.type) or node.name_spec.sym_type
)
else:
self.__debug_print(
f"{node.loc}: Can't get ArchRef value from mypyNode other than ClassDef",
Expand Down Expand Up @@ -424,42 +450,34 @@ def enter_builtin_type(self, node: ast.BuiltinType) -> None:
"""Pass handler for BuiltinType nodes."""
self.__collect_type_from_symbol(node)

def get_type_from_instance(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.Instance
) -> None:
def get_type_from_instance(self, mypy_type: MypyTypes.Instance) -> Optional[str]:
"""Get type info from mypy type Instance."""
node.name_spec.sym_type = str(mypy_type)
return str(mypy_type)

def get_type_from_callable_type(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.CallableType
) -> None:
self, mypy_type: MypyTypes.CallableType
) -> Optional[str]:
"""Get type info from mypy type CallableType."""
node.name_spec.sym_type = str(mypy_type.ret_type)
return str(mypy_type.ret_type)

# TODO: Which overloaded function to get the return value from?
def get_type_from_overloaded(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.Overloaded
) -> None:
self, mypy_type: MypyTypes.Overloaded
) -> Optional[str]:
"""Get type info from mypy type Overloaded."""
self.__call_type_handler(node, mypy_type.items[0])
return self.__call_type_handler(mypy_type.items[0])

def get_type_from_none_type(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.NoneType
) -> None:
def get_type_from_none_type(self, mypy_type: MypyTypes.NoneType) -> Optional[str]:
"""Get type info from mypy type NoneType."""
node.name_spec.sym_type = "None"
return "None"

def get_type_from_any_type(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.AnyType
) -> None:
def get_type_from_any_type(self, mypy_type: MypyTypes.AnyType) -> Optional[str]:
"""Get type info from mypy type NoneType."""
node.name_spec.sym_type = "Any"
return "Any"

def get_type_from_tuple_type(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.TupleType
) -> None:
def get_type_from_tuple_type(self, mypy_type: MypyTypes.TupleType) -> Optional[str]:
"""Get type info from mypy type TupleType."""
node.name_spec.sym_type = "builtins.tuple"
return "builtins.tuple"

def exit_assignment(self, node: ast.Assignment) -> None:
"""Add new symbols in the symbol table in case of self."""
Expand Down
2 changes: 1 addition & 1 deletion jaclang/compiler/passes/main/tests/test_type_check_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ def test_type_coverage(self) -> None:
self.assertIn("HasVar - species - Type: builtins.str", out)
self.assertIn("myDog - Type: type_info.Dog", out)
self.assertIn("Body - Type: type_info.Dog.Body", out)
self.assertEqual(out.count("Type: builtins.str"), 28)
self.assertEqual(out.count("Type: builtins.str"), 39)
for i in lis:
self.assertNotIn(i, out)
Loading