Skip to content

Commit

Permalink
round 1 of review with @charles-cooper
Browse files Browse the repository at this point in the history
  • Loading branch information
z80dev committed Oct 27, 2023
1 parent d0522ad commit bfb7986
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 105 deletions.
2 changes: 1 addition & 1 deletion tests/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def bar():
"""
ast.build_ast(src)
functiondef_node = ast.get_internal_function_nodes()[0]
fn_ast = AST.create_new_instance(functiondef_node)
fn_ast = AST.from_node(functiondef_node)
references = fn_ast.find_nodes_referencing_symbol("x")
assert len(references) == 1
assert (
Expand Down
13 changes: 9 additions & 4 deletions tests/test_navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,21 @@


@pytest.fixture
def doc():
doc = Document(uri="examples/Foo.vy")
def ast():
ast = AST()
return ast


@pytest.fixture
def doc(ast):
doc = Document(uri="examples/Foo.vy")
ast.build_ast(doc.source)
return doc


@pytest.fixture
def navigator():
return ASTNavigator()
def navigator(ast):
return ASTNavigator(ast)


def test_find_references_event_name(doc, navigator):
Expand Down
5 changes: 2 additions & 3 deletions vyper_lsp/analyzer/AstAnalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from vyper.compiler import CompilerData
from vyper.exceptions import VyperException
from vyper_lsp.analyzer.BaseAnalyzer import Analyzer
from vyper_lsp.ast import AST
from vyper_lsp.utils import (
get_expression_at_cursor,
get_word_at_cursor,
Expand Down Expand Up @@ -39,9 +38,9 @@


class AstAnalyzer(Analyzer):
def __init__(self, ast=None) -> None:
def __init__(self, ast) -> None:
super().__init__()
self.ast = ast or AST()
self.ast = ast
if get_installed_vyper_version() < min_vyper_version:
self.diagnostics_enabled = False
else:
Expand Down
140 changes: 52 additions & 88 deletions vyper_lsp/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,21 @@
from vyper.ast import VyperNode, nodes
from vyper.compiler import CompilerData

ast = None


class AST:
_instance = None
ast_data = None
ast_data_folded = None
ast_data_unfolded = None

custom_type_node_types = (nodes.StructDef, nodes.EnumDef, nodes.EventDef)

def __new__(cls):
if cls._instance is None:
cls._instance = super(AST, cls).__new__(cls)
cls._instance.ast_data = None
return cls._instance
@classmethod
def from_node(cls, node: VyperNode):
ast = cls()
ast.ast_data = node
ast.ast_data_unfolded = node
ast.ast_data_folded = node
return ast

def update_ast(self, document):
self.build_ast(document.source)
Expand All @@ -45,47 +44,38 @@ def build_ast(self, src: str):
print(f"Error generating folded AST, {e}")
pass

def get_descendants_from_best_ast(self, *args, **kwargs):
@property
def best_ast(self):
if self.ast_data_unfolded:
return self.ast_data_unfolded.get_descendants(*args, **kwargs)
return self.ast_data_unfolded
elif self.ast_data:
return self.ast_data.get_descendants(*args, **kwargs)
return self.ast_data
elif self.ast_data_folded:
return self.ast_data_folded.get_descendants(*args, **kwargs)
return self.ast_data_folded
else:
return None

def get_descendants(self, *args, **kwargs):
if self.best_ast is None:
return []
return self.best_ast.get_descendants(*args, **kwargs)

def get_children_from_best_ast(self, *args, **kwargs):
if self.ast_data_unfolded:
return self.ast_data_unfolded.get_children(*args, **kwargs)
elif self.ast_data:
return self.ast_data.get_children(*args, **kwargs)
elif self.ast_data_folded:
return self.ast_data_folded.get_children(*args, **kwargs)
else:
def get_top_level_nodes(self, *args, **kwargs):
if self.best_ast is None:
return []
return self.best_ast.get_children(*args, **kwargs)

def get_enums(self) -> List[str]:
return [node.name for node in self.get_descendants_from_best_ast(nodes.EnumDef)]
return [node.name for node in self.get_descendants(nodes.EnumDef)]

def get_structs(self) -> List[str]:
if self.ast_data_unfolded is None:
return []

return [
node.name for node in self.get_descendants_from_best_ast(nodes.StructDef)
]
return [node.name for node in self.get_descendants(nodes.StructDef)]

def get_events(self) -> List[str]:
return [
node.name for node in self.get_descendants_from_best_ast(nodes.EventDef)
]
return [node.name for node in self.get_descendants(nodes.EventDef)]

def get_user_defined_types(self):
return [
node.name
for node in self.get_descendants_from_best_ast(self.custom_type_node_types)
]
return [node.name for node in self.get_descendants(self.custom_type_node_types)]

def get_constants(self):
# NOTE: Constants should be fetched from self.ast_data, they are missing
Expand Down Expand Up @@ -123,38 +113,31 @@ def get_state_variables(self):
]

def get_internal_function_nodes(self):
function_nodes = self.get_descendants_from_best_ast(nodes.FunctionDef)
inernal_nodes = []
function_nodes = self.get_descendants(nodes.FunctionDef)
internal_nodes = []

for node in function_nodes:
for decorator in node.decorator_list:
if decorator.id == "internal":
inernal_nodes.append(node)
internal_nodes.append(node)

return inernal_nodes
return internal_nodes

def get_internal_functions(self):
return [node.name for node in self.get_internal_function_nodes()]

def find_nodes_referencing_internal_function(self, function: str):
return self.get_descendants_from_best_ast(
return self.get_descendants(
nodes.Call, {"func.attr": function, "func.value.id": "self"}
)

def find_nodes_referencing_state_variable(self, variable: str):
return self.get_descendants_from_best_ast(
return self.get_descendants(
nodes.Attribute, {"value.id": "self", "attr": variable}
)

def find_nodes_referencing_constant(self, constant: str):
# NOTE: Constants should be fetched from self.ast_data, they are missing
# from self.ast_data_unfolded and self.ast_data_folded
if self.ast_data_unfolded is None:
return []

name_nodes = self.ast_data_unfolded.get_descendants(
nodes.Name, {"id": constant}
)
name_nodes = self.get_descendants(nodes.Name, {"id": constant})
return [
node
for node in name_nodes
Expand All @@ -174,7 +157,7 @@ def get_attributes_for_symbol(self, symbol: str):
return []

def find_function_declaration_node_for_name(self, function: str):
for node in self.get_descendants_from_best_ast(nodes.FunctionDef):
for node in self.get_descendants(nodes.FunctionDef):
name_match = node.name == function
not_interface_declaration = not isinstance(
node.get_ancestor(), nodes.InterfaceDef
Expand All @@ -197,7 +180,7 @@ def find_state_variable_declaration_node_for_name(self, variable: str):
return None

def find_type_declaration_node_for_name(self, symbol: str):
for node in self.get_descendants_from_best_ast(self.custom_type_node_types):
for node in self.get_descendants(self.custom_type_node_types):
if node.name == symbol:
return node
if isinstance(node, nodes.EnumDef):
Expand All @@ -210,61 +193,54 @@ def find_type_declaration_node_for_name(self, symbol: str):
def find_nodes_referencing_enum(self, enum: str):
return_nodes = []

for node in self.get_descendants_from_best_ast(
nodes.AnnAssign, {"annotation.id": enum}
):
for node in self.get_descendants(nodes.AnnAssign, {"annotation.id": enum}):
return_nodes.append(node)
for node in self.get_descendants_from_best_ast(
nodes.Attribute, {"value.id": enum}
):
for node in self.get_descendants(nodes.Attribute, {"value.id": enum}):
return_nodes.append(node)
for node in self.get_descendants_from_best_ast(
nodes.VariableDecl, {"annotation.id": enum}
):
for node in self.get_descendants(nodes.VariableDecl, {"annotation.id": enum}):
return_nodes.append(node)
for node in self.get_descendants(nodes.FunctionDef, {"returns.id": enum}):
return_nodes.append(node)

return return_nodes

def find_nodes_referencing_enum_variant(self, enum: str, variant: str):
return self.get_descendants_from_best_ast(
return self.get_descendants(
nodes.Attribute, {"attr": variant, "value.id": enum}
)

def find_nodes_referencing_struct(self, struct: str):
return_nodes = []

for node in self.get_descendants_from_best_ast(
nodes.AnnAssign, {"annotation.id": struct}
):
for node in self.get_descendants(nodes.AnnAssign, {"annotation.id": struct}):
return_nodes.append(node)
for node in self.get_descendants_from_best_ast(nodes.Call, {"func.id": struct}):
for node in self.get_descendants(nodes.Call, {"func.id": struct}):
return_nodes.append(node)
for node in self.get_descendants_from_best_ast(
nodes.VariableDecl, {"annotation.id": struct}
):
for node in self.get_descendants(nodes.VariableDecl, {"annotation.id": struct}):
return_nodes.append(node)
for node in self.get_descendants_from_best_ast(
nodes.FunctionDef, {"returns.id": struct}
):
for node in self.get_descendants(nodes.FunctionDef, {"returns.id": struct}):
return_nodes.append(node)

return return_nodes

def find_top_level_node_at_pos(self, pos: Position) -> Optional[VyperNode]:
for node in self.get_children_from_best_ast():
if node.lineno <= pos.line and node.end_lineno >= pos.line:
for node in self.get_top_level_nodes():
if node.lineno <= pos.line and pos.line <= node.end_lineno:
return node

def find_nodes_referencing_symbol(self, symbol: str):
# this only runs on subtrees
return_nodes = []

for node in self.get_descendants_from_best_ast(nodes.Name, {"id": symbol}):
for node in self.get_descendants(nodes.Name, {"id": symbol}):
parent = node.get_ancestor()
if isinstance(parent, nodes.Dict):
# skip struct key names
if symbol not in [key.id for key in parent.keys]:
return_nodes.append(node)
elif isinstance(node.get_ancestor(), nodes.AnnAssign):
if node.id == node.get_ancestor().target.id:
elif isinstance(parent, nodes.AnnAssign):
if node.id == parent.target.id:
# lhs of variable declaration
continue
else:
return_nodes.append(node)
Expand All @@ -274,18 +250,6 @@ def find_nodes_referencing_symbol(self, symbol: str):
return return_nodes

def find_node_declaring_symbol(self, symbol: str):
for node in self.get_descendants_from_best_ast(
(nodes.AnnAssign, nodes.VariableDecl)
):
for node in self.get_descendants((nodes.AnnAssign, nodes.VariableDecl)):
if node.target.id == symbol:
return node

@classmethod
def create_new_instance(cls, ast):
# Create a new instance
new_instance = super(AST, cls).__new__(cls)
# Optionally, initialize the new instance
new_instance.ast_data = ast
new_instance.ast_data_unfolded = ast
new_instance.ast_data_folded = ast
return new_instance
7 changes: 4 additions & 3 deletions vyper_lsp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,19 @@

from .ast import AST

ast = AST()

server = LanguageServer("vyper", "v0.0.1")
navigator = ASTNavigator()
navigator = ASTNavigator(ast)

# AstAnalyzer is faster and better, but depends on the locally installed vyper version
# we should keep it around for now and use it when the contract version pragma is missing
# or if the version pragma matches the system version. its much faster so we can run it
# on every keystroke, with sourceanalyzer we should only run it on save
ast_analyzer = AstAnalyzer()
ast_analyzer = AstAnalyzer(ast)
completer = ast_analyzer
source_analyzer = SourceAnalyzer()

ast = AST()

debouncer = Debouncer(wait=0.5)

Expand Down
10 changes: 4 additions & 6 deletions vyper_lsp/navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
#
# the navigator should mainly return Ranges
class ASTNavigator:
def __init__(self, ast=None):
self.ast = ast or AST()
def __init__(self, ast):
self.ast = ast

def find_state_variable_declaration(self, word: str) -> Optional[Range]:
node = self.ast.find_state_variable_declaration_node_for_name(word)
Expand All @@ -30,7 +30,7 @@ def find_state_variable_declaration(self, word: str) -> Optional[Range]:
def find_variable_declaration_under_node(
self, node: VyperNode, symbol: str
) -> Optional[Range]:
decl_node = AST.create_new_instance(node).find_node_declaring_symbol(symbol)
decl_node = AST.from_node(node).find_node_declaring_symbol(symbol)
if decl_node:
range = Range(
start=Position(
Expand Down Expand Up @@ -120,9 +120,7 @@ def find_references(self, doc: Document, pos: Position) -> List[Range]:
)
references.append(range)
elif isinstance(top_level_node, FunctionDef):
refs = AST.create_new_instance(
top_level_node
).find_nodes_referencing_symbol(word)
refs = AST.from_node(top_level_node).find_nodes_referencing_symbol(word)
for ref in refs:
range = Range(
start=Position(line=ref.lineno - 1, character=ref.col_offset),
Expand Down

0 comments on commit bfb7986

Please sign in to comment.