diff --git a/jaclang/compiler/absyntree.py b/jaclang/compiler/absyntree.py index b7524db5a..ea38eb14d 100644 --- a/jaclang/compiler/absyntree.py +++ b/jaclang/compiler/absyntree.py @@ -50,6 +50,10 @@ def __init__(self, kid: Sequence[AstNode]) -> None: self.meta: dict[str, str] = {} self.loc: CodeLocInfo = CodeLocInfo(*self.resolve_tok_range()) + # NOTE: This is only applicable for Expr, However adding it there needs to call the constructor in all the + # subclasses, Adding it here, this needs a review. + self.expr_type: str = "" + @property def sym_tab(self) -> SymbolTable: """Get symbol table.""" diff --git a/jaclang/compiler/passes/main/fuse_typeinfo_pass.py b/jaclang/compiler/passes/main/fuse_typeinfo_pass.py index 8e44c5313..5c384b117 100644 --- a/jaclang/compiler/passes/main/fuse_typeinfo_pass.py +++ b/jaclang/compiler/passes/main/fuse_typeinfo_pass.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Callable, TypeVar +from typing import Callable, Optional, TypeVar import jaclang.compiler.absyntree as ast from jaclang.compiler.passes import Pass @@ -14,7 +14,6 @@ from jaclang.utils.helpers import pascal_to_snake from jaclang.vendor.mypy.nodes import Node as VNode # bit of a hack - import mypy.nodes as MypyNodes # noqa N812 import mypy.types as MypyTypes # noqa N812 from mypy.checkexpr import Type as MyType @@ -28,23 +27,33 @@ class FuseTypeInfoPass(Pass): node_type_hash: dict[MypyNodes.Node | VNode, MyType] = {} + # Override this to support enter expression. + def enter_node(self, node: ast.AstNode) -> None: + """Run on entering node.""" + if hasattr(self, f"enter_{pascal_to_snake(type(node).__name__)}"): + getattr(self, f"enter_{pascal_to_snake(type(node).__name__)}")(node) + + # TODO: Make (AstSymbolNode::name_spec.sym_typ and Expr::expr_type) the same + # TODO: Introduce AstTypedNode to be a common parent for Expr and AstSymbolNode + if isinstance(node, ast.Expr): + self.enter_expr(node) + def __debug_print(self, *argv: object) -> None: if settings.fuse_type_info_debug: self.log_info("FuseTypeInfo::", *argv) - def __call_type_handler( - self, node: ast.AstSymbolNode, mypy_type: MypyTypes.ProperType - ) -> None: + def __call_type_handler(self, mypy_type: MypyTypes.Type) -> Optional[str]: mypy_type_name = pascal_to_snake(mypy_type.__class__.__name__) type_handler_name = f"get_type_from_{mypy_type_name}" if hasattr(self, type_handler_name): - getattr(self, type_handler_name)(node, mypy_type) - else: - self.__debug_print( - f'{node.loc}"MypyTypes::{mypy_type.__class__.__name__}" isn\'t supported yet' - ) - - def __set_sym_table_link(self, node: ast.AstSymbolNode) -> None: + return getattr(self, type_handler_name)(mypy_type) + self.__debug_print( + f'"MypyTypes::{mypy_type.__class__.__name__}" isn\'t supported yet' + ) + return None + + # TODO: Need to chsnge node type to be AstNode or a common parent + def __set_type_sym_table_link(self, node: ast.AstSymbolNode) -> None: typ = node.sym_type.split(".") typ_sym_table = self.ir.sym_tab @@ -113,7 +122,7 @@ def node_handler(self: FuseTypeInfoPass, node: T) -> None: # Jac node has only one mypy node linked to it if len(node.gen.mypy_ast) == 1: func(self, node) - self.__set_sym_table_link(node) + self.__set_type_sym_table_link(node) self.__collect_python_dependencies(node) # Jac node has multiple mypy nodes linked to it @@ -137,7 +146,7 @@ def node_handler(self: FuseTypeInfoPass, node: T) -> None: jac_node_str, "has duplicate mypy nodes associated to it" ) func(self, node) - self.__set_sym_table_link(node) + self.__set_type_sym_table_link(node) self.__collect_python_dependencies(node) # Jac node doesn't have mypy nodes linked to it @@ -172,7 +181,9 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None: mypy_node = mypy_node.node if isinstance(mypy_node, (MypyNodes.Var, MypyNodes.FuncDef)): - self.__call_type_handler(node, mypy_node.type) + node.name_spec.sym_type = ( + self.__call_type_handler(mypy_node.type) or node.name_spec.sym_type + ) elif isinstance(mypy_node, MypyNodes.MypyFile): node.name_spec.sym_type = "types.ModuleType" @@ -181,7 +192,10 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None: node.name_spec.sym_type = mypy_node.fullname elif isinstance(mypy_node, MypyNodes.OverloadedFuncDef): - self.__call_type_handler(node, mypy_node.items[0].func.type) + node.name_spec.sym_type = ( + self.__call_type_handler(mypy_node.items[0].func.type) + or node.name_spec.sym_type + ) elif mypy_node is None: node.name_spec.sym_type = "None" @@ -195,19 +209,66 @@ def __collect_type_from_symbol(self, node: ast.AstSymbolNode) -> None: else: if isinstance(mypy_node, MypyNodes.ClassDef): node.name_spec.sym_type = mypy_node.fullname - self.__set_sym_table_link(node) + self.__set_type_sym_table_link(node) elif isinstance(mypy_node, MypyNodes.FuncDef): - self.__call_type_handler(node, mypy_node.type) + node.name_spec.sym_type = ( + self.__call_type_handler(mypy_node.type) or node.name_spec.sym_type + ) elif isinstance(mypy_node, MypyNodes.Argument): - self.__call_type_handler(node, mypy_node.variable.type) + node.name_spec.sym_type = ( + self.__call_type_handler(mypy_node.variable.type) + or node.name_spec.sym_type + ) elif isinstance(mypy_node, MypyNodes.Decorator): - self.__call_type_handler(node, mypy_node.func.type.ret_type) + node.name_spec.sym_type = ( + self.__call_type_handler(mypy_node.func.type.ret_type) + or node.name_spec.sym_type + ) else: self.__debug_print( f'"{node.loc}::{node.__class__.__name__}" mypy node isn\'t supported', type(mypy_node), ) + collection_types_map = { + ast.ListVal: "builtins.list", + ast.SetVal: "builtins.set", + ast.TupleVal: "builtins.tuple", + ast.DictVal: "builtins.dict", + ast.ListCompr: None, + ast.DictCompr: None, + } + + # NOTE (Thakee): Since expression nodes are not AstSymbolNodes, I'm not decorating this with __handle_node + # and IMO instead of checking if it's a symbol node or an expression, we somehow mark expressions as + # valid nodes that can have symbols. At this point I'm leaving this like this and lemme know + # otherwise. + # NOTE (GAMAL): This will be fixed through the AstTypedNode + def enter_expr(self: FuseTypeInfoPass, node: ast.Expr) -> None: + """Enter an expression node.""" + if len(node.gen.mypy_ast) == 0: + return + + # If the corrosponding mypy ast node type has stored here, get the values. + mypy_node = node.gen.mypy_ast[0] + if mypy_node in self.node_type_hash: + mytype: MyType = self.node_type_hash[mypy_node] + node.expr_type = self.__call_type_handler(mytype) or "" + + # Set they symbol type for collection expression. + # + # GenCompr is an instance of ListCompr but we don't handle it here. + # so the isinstace (node, ) doesn't work, I'm going with type(...) == ... + if type(node) in self.collection_types_map: + assert isinstance(node, ast.AtomExpr) # To make mypy happy. + collection_type = self.collection_types_map[type(node)] + if collection_type is not None: + node.name_spec.sym_type = collection_type + if mypy_node in self.node_type_hash: + node.name_spec.sym_type = ( + self.__call_type_handler(mytype) or node.name_spec.sym_type + ) + @__handle_node def enter_name(self, node: ast.NameAtom) -> None: """Pass handler for name nodes.""" @@ -247,7 +308,10 @@ def enter_enum_def(self, node: ast.EnumDef) -> None: def enter_ability(self, node: ast.Ability) -> None: """Pass handler for Ability nodes.""" if isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef): - self.__call_type_handler(node, node.gen.mypy_ast[0].type.ret_type) + node.name_spec.sym_type = ( + self.__call_type_handler(node.gen.mypy_ast[0].type.ret_type) + or node.name_spec.sym_type + ) else: self.__debug_print( f"{node.loc}: Can't get type of an ability from mypy node other than Ability.", @@ -258,7 +322,10 @@ def enter_ability(self, node: ast.Ability) -> None: def enter_ability_def(self, node: ast.AbilityDef) -> None: """Pass handler for AbilityDef nodes.""" if isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef): - self.__call_type_handler(node, node.gen.mypy_ast[0].type.ret_type) + node.name_spec.sym_type = ( + self.__call_type_handler(node.gen.mypy_ast[0].type.ret_type) + or node.name_spec.sym_type + ) else: self.__debug_print( f"{node.loc}: Can't get type of an AbilityDef from mypy node other than FuncDef.", @@ -271,7 +338,10 @@ def enter_param_var(self, node: ast.ParamVar) -> None: if isinstance(node.gen.mypy_ast[0], MypyNodes.Argument): mypy_node: MypyNodes.Argument = node.gen.mypy_ast[0] if mypy_node.variable.type: - self.__call_type_handler(node, mypy_node.variable.type) + node.name_spec.sym_type = ( + self.__call_type_handler(mypy_node.variable.type) + or node.name_spec.sym_type + ) else: self.__debug_print( f"{node.loc}: Can't get parameter value from mypyNode other than Argument" @@ -285,7 +355,9 @@ def enter_has_var(self, node: ast.HasVar) -> None: if isinstance(mypy_node, MypyNodes.AssignmentStmt): n = mypy_node.lvalues[0].node if isinstance(n, (MypyNodes.Var, MypyNodes.FuncDef)): - self.__call_type_handler(node, n.type) + node.name_spec.sym_type = ( + self.__call_type_handler(n.type) or node.name_spec.sym_type + ) else: self.__debug_print( "Getting type of 'AssignmentStmt' is only supported with Var and FuncDef" @@ -310,54 +382,6 @@ def enter_f_string(self, node: ast.FString) -> None: """Pass handler for FString nodes.""" self.__debug_print("Getting type not supported in", type(node)) - @__handle_node - def enter_list_val(self, node: ast.ListVal) -> None: - """Pass handler for ListVal nodes.""" - mypy_node = node.gen.mypy_ast[0] - if mypy_node in self.node_type_hash: - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - else: - node.name_spec.sym_type = "builtins.list" - - @__handle_node - def enter_set_val(self, node: ast.SetVal) -> None: - """Pass handler for SetVal nodes.""" - mypy_node = node.gen.mypy_ast[0] - if mypy_node in self.node_type_hash: - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - else: - node.name_spec.sym_type = "builtins.set" - - @__handle_node - def enter_tuple_val(self, node: ast.TupleVal) -> None: - """Pass handler for TupleVal nodes.""" - mypy_node = node.gen.mypy_ast[0] - if mypy_node in self.node_type_hash: - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - else: - node.name_spec.sym_type = "builtins.tuple" - - @__handle_node - def enter_dict_val(self, node: ast.DictVal) -> None: - """Pass handler for DictVal nodes.""" - mypy_node = node.gen.mypy_ast[0] - if mypy_node in self.node_type_hash: - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - else: - node.name_spec.sym_type = "builtins.dict" - - @__handle_node - def enter_list_compr(self, node: ast.ListCompr) -> None: - """Pass handler for ListCompr nodes.""" - mypy_node = node.gen.mypy_ast[0] - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - - @__handle_node - def enter_dict_compr(self, node: ast.DictCompr) -> None: - """Pass handler for DictCompr nodes.""" - mypy_node = node.gen.mypy_ast[0] - node.name_spec.sym_type = str(self.node_type_hash[mypy_node]) - @__handle_node def enter_index_slice(self, node: ast.IndexSlice) -> None: """Pass handler for IndexSlice nodes.""" @@ -369,10 +393,12 @@ def enter_arch_ref(self, node: ast.ArchRef) -> None: if isinstance(node.gen.mypy_ast[0], MypyNodes.ClassDef): mypy_node: MypyNodes.ClassDef = node.gen.mypy_ast[0] node.name_spec.sym_type = mypy_node.fullname - self.__set_sym_table_link(node) + self.__set_type_sym_table_link(node) elif isinstance(node.gen.mypy_ast[0], MypyNodes.FuncDef): mypy_node2: MypyNodes.FuncDef = node.gen.mypy_ast[0] - self.__call_type_handler(node, mypy_node2.type) + node.name_spec.sym_type = ( + self.__call_type_handler(mypy_node2.type) or node.name_spec.sym_type + ) else: self.__debug_print( f"{node.loc}: Can't get ArchRef value from mypyNode other than ClassDef", @@ -424,42 +450,34 @@ def enter_builtin_type(self, node: ast.BuiltinType) -> None: """Pass handler for BuiltinType nodes.""" self.__collect_type_from_symbol(node) - def get_type_from_instance( - self, node: ast.AstSymbolNode, mypy_type: MypyTypes.Instance - ) -> None: + def get_type_from_instance(self, mypy_type: MypyTypes.Instance) -> Optional[str]: """Get type info from mypy type Instance.""" - node.name_spec.sym_type = str(mypy_type) + return str(mypy_type) def get_type_from_callable_type( - self, node: ast.AstSymbolNode, mypy_type: MypyTypes.CallableType - ) -> None: + self, mypy_type: MypyTypes.CallableType + ) -> Optional[str]: """Get type info from mypy type CallableType.""" - node.name_spec.sym_type = str(mypy_type.ret_type) + return str(mypy_type.ret_type) # TODO: Which overloaded function to get the return value from? def get_type_from_overloaded( - self, node: ast.AstSymbolNode, mypy_type: MypyTypes.Overloaded - ) -> None: + self, mypy_type: MypyTypes.Overloaded + ) -> Optional[str]: """Get type info from mypy type Overloaded.""" - self.__call_type_handler(node, mypy_type.items[0]) + return self.__call_type_handler(mypy_type.items[0]) - def get_type_from_none_type( - self, node: ast.AstSymbolNode, mypy_type: MypyTypes.NoneType - ) -> None: + def get_type_from_none_type(self, mypy_type: MypyTypes.NoneType) -> Optional[str]: """Get type info from mypy type NoneType.""" - node.name_spec.sym_type = "None" + return "None" - def get_type_from_any_type( - self, node: ast.AstSymbolNode, mypy_type: MypyTypes.AnyType - ) -> None: + def get_type_from_any_type(self, mypy_type: MypyTypes.AnyType) -> Optional[str]: """Get type info from mypy type NoneType.""" - node.name_spec.sym_type = "Any" + return "Any" - def get_type_from_tuple_type( - self, node: ast.AstSymbolNode, mypy_type: MypyTypes.TupleType - ) -> None: + def get_type_from_tuple_type(self, mypy_type: MypyTypes.TupleType) -> Optional[str]: """Get type info from mypy type TupleType.""" - node.name_spec.sym_type = "builtins.tuple" + return "builtins.tuple" def exit_assignment(self, node: ast.Assignment) -> None: """Add new symbols in the symbol table in case of self.""" diff --git a/jaclang/compiler/passes/main/tests/test_type_check_pass.py b/jaclang/compiler/passes/main/tests/test_type_check_pass.py index 6f0c863b2..80d4b13e7 100644 --- a/jaclang/compiler/passes/main/tests/test_type_check_pass.py +++ b/jaclang/compiler/passes/main/tests/test_type_check_pass.py @@ -59,6 +59,6 @@ def test_type_coverage(self) -> None: self.assertIn("HasVar - species - Type: builtins.str", out) self.assertIn("myDog - Type: type_info.Dog", out) self.assertIn("Body - Type: type_info.Dog.Body", out) - self.assertEqual(out.count("Type: builtins.str"), 28) + self.assertEqual(out.count("Type: builtins.str"), 39) for i in lis: self.assertNotIn(i, out) diff --git a/jaclang/compiler/passes/utils/mypy_ast_build.py b/jaclang/compiler/passes/utils/mypy_ast_build.py index 89619c36b..904a0e0a4 100644 --- a/jaclang/compiler/passes/utils/mypy_ast_build.py +++ b/jaclang/compiler/passes/utils/mypy_ast_build.py @@ -4,17 +4,20 @@ import ast import os +from types import MethodType from jaclang.compiler.absyntree import AstNode from jaclang.compiler.passes import Pass from jaclang.compiler.passes.main.fuse_typeinfo_pass import ( FuseTypeInfoPass, ) +from jaclang.utils.helpers import pascal_to_snake import mypy.build as myb import mypy.checkexpr as mycke import mypy.errors as mye import mypy.fastparse as myfp +import mypy.nodes as mypy_nodes from mypy.build import BuildSource from mypy.build import BuildSourceSet from mypy.build import FileSystemCache @@ -29,6 +32,55 @@ from mypy.semanal_main import semantic_analysis_for_scc +# All the expression nodes of mypy. +EXPRESSION_NODES = ( + mypy_nodes.AssertTypeExpr, + mypy_nodes.AssignmentExpr, + mypy_nodes.AwaitExpr, + mypy_nodes.BytesExpr, + mypy_nodes.CallExpr, + mypy_nodes.CastExpr, + mypy_nodes.ComparisonExpr, + mypy_nodes.ComplexExpr, + mypy_nodes.ConditionalExpr, + mypy_nodes.DictionaryComprehension, + mypy_nodes.DictExpr, + mypy_nodes.EllipsisExpr, + mypy_nodes.EnumCallExpr, + mypy_nodes.Expression, + mypy_nodes.FloatExpr, + mypy_nodes.GeneratorExpr, + mypy_nodes.IndexExpr, + mypy_nodes.IntExpr, + mypy_nodes.LambdaExpr, + mypy_nodes.ListComprehension, + mypy_nodes.ListExpr, + mypy_nodes.MemberExpr, + mypy_nodes.NamedTupleExpr, + mypy_nodes.NameExpr, + mypy_nodes.NewTypeExpr, + mypy_nodes.OpExpr, + mypy_nodes.ParamSpecExpr, + mypy_nodes.PromoteExpr, + mypy_nodes.RefExpr, + mypy_nodes.RevealExpr, + mypy_nodes.SetComprehension, + mypy_nodes.SetExpr, + mypy_nodes.SliceExpr, + mypy_nodes.StarExpr, + mypy_nodes.StrExpr, + mypy_nodes.SuperExpr, + mypy_nodes.TupleExpr, + mypy_nodes.TypeAliasExpr, + mypy_nodes.TypedDictExpr, + mypy_nodes.TypeVarExpr, + mypy_nodes.TypeVarTupleExpr, + mypy_nodes.UnaryExpr, + mypy_nodes.YieldExpr, + mypy_nodes.YieldFromExpr, +) + + mypy_to_jac_node_map: dict[ tuple[int, int | None, int | None, int | None], list[AstNode] ] = {} @@ -87,63 +139,45 @@ def __init__( """Override to mypy expression checker for direct AST pass through.""" super().__init__(tc, msg, plugin, per_line_checking_time_ns) - def visit_list_expr(self, e: mycke.ListExpr) -> mycke.Type: - """Type check a list expression [...].""" - out = super().visit_list_expr(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_set_expr(self, e: mycke.SetExpr) -> mycke.Type: - """Type check a set expression {...}.""" - out = super().visit_set_expr(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_tuple_expr(self, e: myfp.TupleExpr) -> myb.Type: - """Type check a tuple expression (...).""" - out = super().visit_tuple_expr(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_dict_expr(self, e: myfp.DictExpr) -> myb.Type: - """Type check a dictionary expression {...}.""" - out = super().visit_dict_expr(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_list_comprehension(self, e: myfp.ListComprehension) -> myb.Type: - """Type check a list comprehension.""" - out = super().visit_list_comprehension(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_set_comprehension(self, e: myfp.SetComprehension) -> myb.Type: - """Type check a set comprehension.""" - out = super().visit_set_comprehension(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_generator_expr(self, e: myfp.GeneratorExpr) -> myb.Type: - """Type check a generator expression.""" - out = super().visit_generator_expr(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_dictionary_comprehension( - self, e: myfp.DictionaryComprehension - ) -> myb.Type: - """Type check a dict comprehension.""" - out = super().visit_dictionary_comprehension(e) - FuseTypeInfoPass.node_type_hash[e] = out - return out - - def visit_member_expr( - self, e: myfp.MemberExpr, is_lvalue: bool = False - ) -> myb.Type: - """Type check a member expr.""" - out = super().visit_member_expr(e, is_lvalue) - FuseTypeInfoPass.node_type_hash[e] = out - return out + # For every expression there, create attach a method on this instance (self) named "enter_expr()" + for expr_node in EXPRESSION_NODES: + method_name = "visit_" + pascal_to_snake(expr_node.__name__) + + # We call the super() version of the method so ensure the parent class has the method or else continue. + if not hasattr(mycke.ExpressionChecker, method_name): + continue + + # If the method already overriden then don't override it again here. Continue. Note that the method exists + # on the parent class and if it's also exists on this class and it's a different object that means it's + # overrident method. + if getattr(mycke.ExpressionChecker, method_name) != getattr( + ExpressionChecker, method_name + ): + continue + + # Since the "closure" function bellow captures the method name inside it, we cannot use it directly as the + # "method_name" variable is used inside a loop and by the time the closure close the "method_name" value, + # it'll be changed by the loop, so we need another method ("make_closure") to persist the value. + def make_closure(method_name: str): # noqa: ANN201 + def closure( + self: ExpressionChecker, + e: mycke.Expression, + *args, # noqa: ANN002 + **kwargs, # noqa: ANN003 + ) -> mycke.Type: + # Ignore B023 here since we bind loop variable properly but flake8 raise a false alarm + # (in some version of it), a bug in flake8 (https://github.com/PyCQA/flake8-bugbear/issues/269). + out = getattr(mycke.ExpressionChecker, method_name)( # noqa: B023 + self, e, *args, **kwargs + ) + FuseTypeInfoPass.node_type_hash[e] = out + return out + + return closure + + # Attach the new "visit_expr()" method to this instance. + method = make_closure(method_name) + setattr(self, method_name, MethodType(method, self)) class State(myb.State): diff --git a/jaclang/utils/treeprinter.py b/jaclang/utils/treeprinter.py index 9c3c5f6ef..0ed875cae 100644 --- a/jaclang/utils/treeprinter.py +++ b/jaclang/utils/treeprinter.py @@ -135,6 +135,8 @@ def __node_repr_in_tree(node: AstNode) -> str: ) out += f" SymbolPath: {symbol}" return out + elif isinstance(node, ast.Expr): + return f"{node.__class__.__name__} - Type: {node.expr_type}" else: return f"{node.__class__.__name__}, {access}"