Skip to content
This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Commit

Permalink
expression typeinfo implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
ThakeeNathees committed Aug 24, 2024
1 parent 4fdfafd commit b99ee72
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 106 deletions.
4 changes: 4 additions & 0 deletions jaclang/compiler/absyntree.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def __init__(self, kid: Sequence[AstNode]) -> None:
self.meta: dict[str, str] = {}
self.loc: CodeLocInfo = CodeLocInfo(*self.resolve_tok_range())

# NOTE: This is only applicable for Expr, However adding it there needs to call the constructor in all the
# subclasses, Adding it here, this needs a review.
self.expr_type: str = ""

@property
def sym_tab(self) -> SymbolTable:
"""Get symbol table."""
Expand Down
121 changes: 72 additions & 49 deletions jaclang/compiler/passes/main/fuse_typeinfo_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from __future__ import annotations

from types import MethodType
from typing import Callable, TypeVar

import jaclang.compiler.absyntree as ast
Expand All @@ -14,7 +15,6 @@
from jaclang.utils.helpers import pascal_to_snake
from jaclang.vendor.mypy.nodes import Node as VNode # bit of a hack


import mypy.nodes as MypyNodes # noqa N812
import mypy.types as MypyTypes # noqa N812
from mypy.checkexpr import Type as MyType
Expand All @@ -23,11 +23,82 @@
T = TypeVar("T", bound=ast.AstSymbolNode)


# List of expression nodes which we'll be extracting the type info from.
JAC_EXPR_NODES = (
ast.AwaitExpr,
ast.BinaryExpr,
ast.CompareExpr,
ast.BoolExpr,
ast.LambdaExpr,
ast.UnaryExpr,
ast.IfElseExpr,
ast.AtomTrailer,
ast.AtomUnit,
ast.YieldExpr,
ast.YieldExpr,
ast.FuncCall,
ast.EdgeRefTrailer,
ast.ListVal,
ast.SetVal,
ast.TupleVal,
ast.DictVal,
ast.ListCompr,
ast.DictCompr,
)


class FuseTypeInfoPass(Pass):
"""Python and bytecode file self.__debug_printing pass."""

node_type_hash: dict[MypyNodes.Node | VNode, MyType] = {}

@staticmethod
def enter_expr(self: FuseTypeInfoPass, node: ast.Expr) -> None:
"""
Enter an expression node.
This function is dynamically bound as a method on insntace of this class, since the
group of functions to handle expressions has a the exact same logic.
"""
if len(node.gen.mypy_ast) == 0:
return

# If the corrosponding mypy ast node type has stored here, get the values.
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
mytype: MyType = self.node_type_hash[mypy_node]
node.expr_type = str(mytype)

# TODO: Maybe move this out of the function otherwise it'll construct this dict every time it entered an
# expression. Time and memory wasted here.
collection_types_map = {
ast.ListVal: "builtins.list",
ast.SetVal: "builtins.set",
ast.TupleVal: "builtins.tuple",
ast.DictVal: "builtins.dict",
ast.ListCompr: None,
ast.DictCompr: None,
}

# Set they symbol type for collection expression.
if type(node) in tuple(collection_types_map.keys()):
assert isinstance(node, ast.AtomExpr) # To make mypy happy.
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(mytype)
collection_type = collection_types_map[type(node)]
if collection_type is not None:
node.name_spec.sym_type = collection_type

def __init__(self, *args, **kwargs) -> None: # noqa: ANN002, ANN003
"""Initialize the FuseTpeInfoPass instance."""
for expr_node in JAC_EXPR_NODES:
method_name = "enter_" + pascal_to_snake(expr_node.__name__)
method = MethodType(
FuseTypeInfoPass.__handle_node(FuseTypeInfoPass.enter_expr), self
)
setattr(self, method_name, method)
super().__init__(*args, **kwargs)

def __debug_print(self, *argv: object) -> None:
if settings.fuse_type_info_debug:
self.log_info("FuseTypeInfo::", *argv)
Expand Down Expand Up @@ -310,54 +381,6 @@ def enter_f_string(self, node: ast.FString) -> None:
"""Pass handler for FString nodes."""
self.__debug_print("Getting type not supported in", type(node))

@__handle_node
def enter_list_val(self, node: ast.ListVal) -> None:
"""Pass handler for ListVal nodes."""
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
else:
node.name_spec.sym_type = "builtins.list"

@__handle_node
def enter_set_val(self, node: ast.SetVal) -> None:
"""Pass handler for SetVal nodes."""
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
else:
node.name_spec.sym_type = "builtins.set"

@__handle_node
def enter_tuple_val(self, node: ast.TupleVal) -> None:
"""Pass handler for TupleVal nodes."""
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
else:
node.name_spec.sym_type = "builtins.tuple"

@__handle_node
def enter_dict_val(self, node: ast.DictVal) -> None:
"""Pass handler for DictVal nodes."""
mypy_node = node.gen.mypy_ast[0]
if mypy_node in self.node_type_hash:
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])
else:
node.name_spec.sym_type = "builtins.dict"

@__handle_node
def enter_list_compr(self, node: ast.ListCompr) -> None:
"""Pass handler for ListCompr nodes."""
mypy_node = node.gen.mypy_ast[0]
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])

@__handle_node
def enter_dict_compr(self, node: ast.DictCompr) -> None:
"""Pass handler for DictCompr nodes."""
mypy_node = node.gen.mypy_ast[0]
node.name_spec.sym_type = str(self.node_type_hash[mypy_node])

@__handle_node
def enter_index_slice(self, node: ast.IndexSlice) -> None:
"""Pass handler for IndexSlice nodes."""
Expand Down
148 changes: 91 additions & 57 deletions jaclang/compiler/passes/utils/mypy_ast_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@

import ast
import os
from types import MethodType

from jaclang.compiler.absyntree import AstNode
from jaclang.compiler.passes import Pass
from jaclang.compiler.passes.main.fuse_typeinfo_pass import (
FuseTypeInfoPass,
)
from jaclang.utils.helpers import pascal_to_snake

import mypy.build as myb
import mypy.checkexpr as mycke
import mypy.errors as mye
import mypy.fastparse as myfp
import mypy.nodes as mypy_nodes
from mypy.build import BuildSource
from mypy.build import BuildSourceSet
from mypy.build import FileSystemCache
Expand All @@ -29,6 +32,55 @@
from mypy.semanal_main import semantic_analysis_for_scc


# All the expression nodes of mypy.
EXPRESSION_NODES = (
mypy_nodes.AssertTypeExpr,
mypy_nodes.AssignmentExpr,
mypy_nodes.AwaitExpr,
mypy_nodes.BytesExpr,
mypy_nodes.CallExpr,
mypy_nodes.CastExpr,
mypy_nodes.ComparisonExpr,
mypy_nodes.ComplexExpr,
mypy_nodes.ConditionalExpr,
mypy_nodes.DictionaryComprehension,
mypy_nodes.DictExpr,
mypy_nodes.EllipsisExpr,
mypy_nodes.EnumCallExpr,
mypy_nodes.Expression,
mypy_nodes.FloatExpr,
mypy_nodes.GeneratorExpr,
mypy_nodes.IndexExpr,
mypy_nodes.IntExpr,
mypy_nodes.LambdaExpr,
mypy_nodes.ListComprehension,
mypy_nodes.ListExpr,
mypy_nodes.MemberExpr,
mypy_nodes.NamedTupleExpr,
mypy_nodes.NameExpr,
mypy_nodes.NewTypeExpr,
mypy_nodes.OpExpr,
mypy_nodes.ParamSpecExpr,
mypy_nodes.PromoteExpr,
mypy_nodes.RefExpr,
mypy_nodes.RevealExpr,
mypy_nodes.SetComprehension,
mypy_nodes.SetExpr,
mypy_nodes.SliceExpr,
mypy_nodes.StarExpr,
mypy_nodes.StrExpr,
mypy_nodes.SuperExpr,
mypy_nodes.TupleExpr,
mypy_nodes.TypeAliasExpr,
mypy_nodes.TypedDictExpr,
mypy_nodes.TypeVarExpr,
mypy_nodes.TypeVarTupleExpr,
mypy_nodes.UnaryExpr,
mypy_nodes.YieldExpr,
mypy_nodes.YieldFromExpr,
)


mypy_to_jac_node_map: dict[
tuple[int, int | None, int | None, int | None], list[AstNode]
] = {}
Expand Down Expand Up @@ -87,63 +139,45 @@ def __init__(
"""Override to mypy expression checker for direct AST pass through."""
super().__init__(tc, msg, plugin, per_line_checking_time_ns)

def visit_list_expr(self, e: mycke.ListExpr) -> mycke.Type:
"""Type check a list expression [...]."""
out = super().visit_list_expr(e)
FuseTypeInfoPass.node_type_hash[e] = out
return out

def visit_set_expr(self, e: mycke.SetExpr) -> mycke.Type:
"""Type check a set expression {...}."""
out = super().visit_set_expr(e)
FuseTypeInfoPass.node_type_hash[e] = out
return out

def visit_tuple_expr(self, e: myfp.TupleExpr) -> myb.Type:
"""Type check a tuple expression (...)."""
out = super().visit_tuple_expr(e)
FuseTypeInfoPass.node_type_hash[e] = out
return out

def visit_dict_expr(self, e: myfp.DictExpr) -> myb.Type:
"""Type check a dictionary expression {...}."""
out = super().visit_dict_expr(e)
FuseTypeInfoPass.node_type_hash[e] = out
return out

def visit_list_comprehension(self, e: myfp.ListComprehension) -> myb.Type:
"""Type check a list comprehension."""
out = super().visit_list_comprehension(e)
FuseTypeInfoPass.node_type_hash[e] = out
return out

def visit_set_comprehension(self, e: myfp.SetComprehension) -> myb.Type:
"""Type check a set comprehension."""
out = super().visit_set_comprehension(e)
FuseTypeInfoPass.node_type_hash[e] = out
return out

def visit_generator_expr(self, e: myfp.GeneratorExpr) -> myb.Type:
"""Type check a generator expression."""
out = super().visit_generator_expr(e)
FuseTypeInfoPass.node_type_hash[e] = out
return out

def visit_dictionary_comprehension(
self, e: myfp.DictionaryComprehension
) -> myb.Type:
"""Type check a dict comprehension."""
out = super().visit_dictionary_comprehension(e)
FuseTypeInfoPass.node_type_hash[e] = out
return out

def visit_member_expr(
self, e: myfp.MemberExpr, is_lvalue: bool = False
) -> myb.Type:
"""Type check a member expr."""
out = super().visit_member_expr(e, is_lvalue)
FuseTypeInfoPass.node_type_hash[e] = out
return out
# For every expression there, create attach a method on this instance (self) named "enter_expr()"
for expr_node in EXPRESSION_NODES:
method_name = "visit_" + pascal_to_snake(expr_node.__name__)

# We call the super() version of the method so ensure the parent class has the method or else continue.
if not hasattr(mycke.ExpressionChecker, method_name):
continue

# If the method already overriden then don't override it again here. Continue. Note that the method exists
# on the parent class and if it's also exists on this class and it's a different object that means it's
# overrident method.
if getattr(mycke.ExpressionChecker, method_name) != getattr(
ExpressionChecker, method_name
):
continue

# Since the "closure" function bellow captures the method name inside it, we cannot use it directly as the
# "method_name" variable is used inside a loop and by the time the closure close the "method_name" value,
# it'll be changed by the loop, so we need another method ("make_closure") to persist the value.
def make_closure(method_name: str): # noqa: ANN201
def closure(
self: ExpressionChecker,
e: mycke.Expression,
*args, # noqa: ANN002
**kwargs, # noqa: ANN003
) -> mycke.Type:
# Ignore B023 here since we bind loop variable properly but flake8 raise a false alarm
# (in some version of it), a bug in flake8 (https://github.com/PyCQA/flake8-bugbear/issues/269).
out = getattr(mycke.ExpressionChecker, method_name)( # noqa: B023
self, e, *args, **kwargs
)
FuseTypeInfoPass.node_type_hash[e] = out
return out

return closure

# Attach the new "visit_expr()" method to this instance.
method = make_closure(method_name)
setattr(self, method_name, MethodType(method, self))


class State(myb.State):
Expand Down
2 changes: 2 additions & 0 deletions jaclang/utils/treeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def __node_repr_in_tree(node: AstNode) -> str:
)
out += f" SymbolPath: {symbol}"
return out
elif isinstance(node, ast.Expr):
return f"{node.__class__.__name__} - Type: {node.expr_type}"
elif isinstance(node, Token):
return f"{node.__class__.__name__} - {node.value}, {access}"
elif (
Expand Down

0 comments on commit b99ee72

Please sign in to comment.