Skip to content
This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Expression type info #590

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions jaclang/compiler/absyntree.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def __init__(self, kid: Sequence[AstNode]) -> None:
self.meta: dict[str, str] = {}
self.loc: CodeLocInfo = CodeLocInfo(*self.resolve_tok_range())

# NOTE: This is only applicable for Expr, However adding it there needs to call the constructor in all the
# subclasses, Adding it here, this needs a review.
self.expr_type: str = ""

@property
def sym_tab(self) -> SymbolTable:
"""Get symbol table."""
Expand Down
197 changes: 107 additions & 90 deletions jaclang/compiler/passes/main/fuse_typeinfo_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@

from __future__ import annotations

from typing import Callable, TypeVar
from typing import Callable, Optional, TypeVar

import jaclang.compiler.absyntree as ast
from jaclang.compiler.passes import Pass
from jaclang.settings import settings
from jaclang.utils.helpers import pascal_to_snake
from jaclang.vendor.mypy.nodes import Node as VNode # bit of a hack


import mypy.nodes as MypyNodes # noqa N812
import mypy.types as MypyTypes # noqa N812
from mypy.checkexpr import Type as MyType
Expand All @@ -28,21 +27,27 @@ class FuseTypeInfoPass(Pass):

node_type_hash: dict[MypyNodes.Node | VNode, MyType] = {}

# Override this to support enter expression.
def enter_node(self, node: ast.AstNode) -> None:
"""Run on entering node."""
if hasattr(self, f"enter_{pascal_to_snake(type(node).__name__)}"):
getattr(self, f"enter_{pascal_to_snake(type(node).__name__)}")(node)
elif isinstance(node, ast.Expr):
self.enter_expr(node)

def __debug_print(self, *argv: object) -> None:
if settings.fuse_type_info_debug:
self.log_info("FuseTypeInfo::", *argv)

def __call_type_handler(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.ProperType
) -> None:
def __call_type_handler(self, mypy_type: MypyTypes.Type) -> Optional[str]:
mypy_type_name = pascal_to_snake(mypy_type.__class__.__name__)
type_handler_name = f"get_type_from_{mypy_type_name}"
if hasattr(self, type_handler_name):
getattr(self, type_handler_name)(node, mypy_type)
else:
self.__debug_print(
f'{node.loc}"MypyTypes::{mypy_type.__class__.__name__}" isn\'t supported yet'
)
return getattr(self, type_handler_name)(mypy_type)
self.__debug_print(
f'"MypyTypes::{mypy_type.__class__.__name__}" isn\'t supported yet'
)
return None

def __set_sym_table_link(self, node: ast.AstSymbolNode) -> None:
typ = node.sym_type.split(".")
Expand Down Expand Up @@ -172,7 +177,9 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
mypy_node = mypy_node.node

if isinstance(mypy_node, (MypyNodes.Var, MypyNodes.FuncDef)):
self.__call_type_handler(node, mypy_node.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.type) or node.name_spec.sym_type
)

elif isinstance(mypy_node, MypyNodes.MypyFile):
node.name_spec.sym_type = "types.ModuleType"
Expand All @@ -181,7 +188,10 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
node.name_spec.sym_type = mypy_node.fullname

elif isinstance(mypy_node, MypyNodes.OverloadedFuncDef):
self.__call_type_handler(node, mypy_node.items[0].func.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.items[0].func.type)
or node.name_spec.sym_type
)

elif mypy_node is None:
node.name_spec.sym_type = "None"
Expand All @@ -197,17 +207,67 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None:
node.name_spec.sym_type = mypy_node.fullname
self.__set_sym_table_link(node)
elif isinstance(mypy_node, MypyNodes.FuncDef):
self.__call_type_handler(node, mypy_node.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.type) or node.name_spec.sym_type
)
elif isinstance(mypy_node, MypyNodes.Argument):
self.__call_type_handler(node, mypy_node.variable.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.variable.type)
or node.name_spec.sym_type
)
elif isinstance(mypy_node, MypyNodes.Decorator):
self.__call_type_handler(node, mypy_node.func.type.ret_type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.func.type.ret_type)
or node.name_spec.sym_type
)
else:
self.__debug_print(
f'"{node.loc}::{node.__class__.__name__}" mypy node isn\'t supported',
type(mypy_node),
)

# NOTE: Since expression nodes are not AstSymbolNodes, I'm not decorating this with __handle_node
# and IMO instead of checking if it's a symbol node or an expression, we somehow mark expressions as
# valid nodes that can have symbols. At this point I'm leaving this like this and lemme know
# otherwise.
def enter_expr(self: FuseTypeInfoPass, node: ast.Expr) -> None:
"""
Enter an expression node.

This function is dynamically bound as a method on insntace of this class, since the
group of functions to handle expressions has a the exact same logic.
"""
if len(node.gen.mypy_ast) == 0:
return

# If the corrosponding mypy ast node type has stored here, get the values.
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
mytype: MyType = self.node_type_hash[mypy_node]
node.expr_type = self.__call_type_handler(mytype) or ""

# TODO: Maybe move this out of the function otherwise it'll construct this dict every time it entered an
# expression. Time and memory wasted here.
collection_types_map = {
ast.ListVal: "builtins.list",
ast.SetVal: "builtins.set",
ast.TupleVal: "builtins.tuple",
ast.DictVal: "builtins.dict",
ast.ListCompr: None,
ast.DictCompr: None,
}

# Set they symbol type for collection expression.
if type(node) in tuple(collection_types_map.keys()):
assert isinstance(node, ast.AtomExpr) # To make mypy happy.
collection_type = collection_types_map[type(node)]
if collection_type is not None:
node.name_spec.sym_type = collection_type
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = (
self.__call_type_handler(mytype) or node.name_spec.sym_type
)

@__handle_node
def enter_name(self, node: ast.NameAtom) -> None:
"""Pass handler for name nodes."""
Expand Down Expand Up @@ -247,7 +307,10 @@ def enter_enum_def(self, node: ast.EnumDef) -> None:
def enter_ability(self, node: ast.Ability) -> None:
"""Pass handler for Ability nodes."""
if isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef):
self.__call_type_handler(node, node.gen.mypy_ast[0].type.ret_type)
node.name_spec.sym_type = (
self.__call_type_handler(node.gen.mypy_ast[0].type.ret_type)
or node.name_spec.sym_type
)
else:
self.__debug_print(
f"{node.loc}: Can't get type of an ability from mypy node other than Ability.",
Expand All @@ -258,7 +321,10 @@ def enter_ability(self, node: ast.Ability) -> None:
def enter_ability_def(self, node: ast.AbilityDef) -> None:
"""Pass handler for AbilityDef nodes."""
if isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef):
self.__call_type_handler(node, node.gen.mypy_ast[0].type.ret_type)
node.name_spec.sym_type = (
self.__call_type_handler(node.gen.mypy_ast[0].type.ret_type)
or node.name_spec.sym_type
)
else:
self.__debug_print(
f"{node.loc}: Can't get type of an AbilityDef from mypy node other than FuncDef.",
Expand All @@ -271,7 +337,10 @@ def enter_param_var(self, node: ast.ParamVar) -> None:
if isinstance(node.gen.mypy_ast[0], MypyNodes.Argument):
mypy_node: MypyNodes.Argument = node.gen.mypy_ast[0]
if mypy_node.variable.type:
self.__call_type_handler(node, mypy_node.variable.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node.variable.type)
or node.name_spec.sym_type
)
else:
self.__debug_print(
f"{node.loc}: Can't get parameter value from mypyNode other than Argument"
Expand All @@ -285,7 +354,9 @@ def enter_has_var(self, node: ast.HasVar) -> None:
if isinstance(mypy_node, MypyNodes.AssignmentStmt):
n = mypy_node.lvalues[0].node
if isinstance(n, (MypyNodes.Var, MypyNodes.FuncDef)):
self.__call_type_handler(node, n.type)
node.name_spec.sym_type = (
self.__call_type_handler(n.type) or node.name_spec.sym_type
)
else:
self.__debug_print(
"Getting type of 'AssignmentStmt' is only supported with Var and FuncDef"
Expand All @@ -310,54 +381,6 @@ def enter_f_string(self, node: ast.FString) -> None:
"""Pass handler for FString nodes."""
self.__debug_print("Getting type not supported in", type(node))

@__handle_node
def enter_list_val(self, node: ast.ListVal) -> None:
"""Pass handler for ListVal nodes."""
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
else:
node.name_spec.sym_type = "builtins.list"

@__handle_node
def enter_set_val(self, node: ast.SetVal) -> None:
"""Pass handler for SetVal nodes."""
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
else:
node.name_spec.sym_type = "builtins.set"

@__handle_node
def enter_tuple_val(self, node: ast.TupleVal) -> None:
"""Pass handler for TupleVal nodes."""
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
else:
node.name_spec.sym_type = "builtins.tuple"

@__handle_node
def enter_dict_val(self, node: ast.DictVal) -> None:
"""Pass handler for DictVal nodes."""
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
else:
node.name_spec.sym_type = "builtins.dict"

@__handle_node
def enter_list_compr(self, node: ast.ListCompr) -> None:
"""Pass handler for ListCompr nodes."""
mypy_node = node.gen.mypy_ast[0]
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])

@__handle_node
def enter_dict_compr(self, node: ast.DictCompr) -> None:
"""Pass handler for DictCompr nodes."""
mypy_node = node.gen.mypy_ast[0]
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])

@__handle_node
def enter_index_slice(self, node: ast.IndexSlice) -> None:
"""Pass handler for IndexSlice nodes."""
Expand All @@ -372,7 +395,9 @@ def enter_arch_ref(self, node: ast.ArchRef) -> None:
self.__set_sym_table_link(node)
elif isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef):
mypy_node2: MypyNodes.FuncDef = node.gen.mypy_ast[0]
self.__call_type_handler(node, mypy_node2.type)
node.name_spec.sym_type = (
self.__call_type_handler(mypy_node2.type) or node.name_spec.sym_type
)
else:
self.__debug_print(
f"{node.loc}: Can't get ArchRef value from mypyNode other than ClassDef",
Expand Down Expand Up @@ -424,42 +449,34 @@ def enter_builtin_type(self, node: ast.BuiltinType) -> None:
"""Pass handler for BuiltinType nodes."""
self.__collect_type_from_symbol(node)

def get_type_from_instance(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.Instance
) -> None:
def get_type_from_instance(self, mypy_type: MypyTypes.Instance) -> Optional[str]:
"""Get type info from mypy type Instance."""
node.name_spec.sym_type = str(mypy_type)
return str(mypy_type)

def get_type_from_callable_type(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.CallableType
) -> None:
self, mypy_type: MypyTypes.CallableType
) -> Optional[str]:
"""Get type info from mypy type CallableType."""
node.name_spec.sym_type = str(mypy_type.ret_type)
return str(mypy_type.ret_type)

# TODO: Which overloaded function to get the return value from?
def get_type_from_overloaded(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.Overloaded
) -> None:
self, mypy_type: MypyTypes.Overloaded
) -> Optional[str]:
"""Get type info from mypy type Overloaded."""
self.__call_type_handler(node, mypy_type.items[0])
return self.__call_type_handler(mypy_type.items[0])

def get_type_from_none_type(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.NoneType
) -> None:
def get_type_from_none_type(self, mypy_type: MypyTypes.NoneType) -> Optional[str]:
"""Get type info from mypy type NoneType."""
node.name_spec.sym_type = "None"
return "None"

def get_type_from_any_type(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.AnyType
) -> None:
def get_type_from_any_type(self, mypy_type: MypyTypes.AnyType) -> Optional[str]:
"""Get type info from mypy type NoneType."""
node.name_spec.sym_type = "Any"
return "Any"

def get_type_from_tuple_type(
self, node: ast.AstSymbolNode, mypy_type: MypyTypes.TupleType
) -> None:
def get_type_from_tuple_type(self, mypy_type: MypyTypes.TupleType) -> Optional[str]:
"""Get type info from mypy type TupleType."""
node.name_spec.sym_type = "builtins.tuple"
return "builtins.tuple"

def exit_assignment(self, node: ast.Assignment) -> None:
"""Add new symbols in the symbol table in case of self."""
Expand Down
2 changes: 1 addition & 1 deletion jaclang/compiler/passes/main/tests/test_type_check_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ def test_type_coverage(self) -> None:
self.assertIn("HasVar - species - Type: builtins.str", out)
self.assertIn("myDog - Type: type_info.Dog", out)
self.assertIn("Body - Type: type_info.Dog.Body", out)
self.assertEqual(out.count("Type: builtins.str"), 28)
self.assertEqual(out.count("Type: builtins.str"), 39)
for i in lis:
self.assertNotIn(i, out)
Loading
Loading