Skip to content

Commit

Permalink
feat: more frontend optimizations (#3785)
Browse files Browse the repository at this point in the history
- optimize `VyperNode.get_descendants()` and `get_children()`
- get rid of `sort_nodes()`, we can guarantee ordering the old fashioned
  way (topsort)
- optimize `VyperNode.__hash__()` and `VyperNode.__init__()`
- optimize `IntegerT.compare_type()`

optimizes front-end compilation time by another 25%
  • Loading branch information
charles-cooper authored Feb 19, 2024
1 parent 1fc819c commit 4177314
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 44 deletions.
83 changes: 44 additions & 39 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"node_id",
"_metadata",
"_original_node",
"_cache_descendants",
)
NODE_SRC_ATTRIBUTES = (
"col_offset",
Expand Down Expand Up @@ -211,15 +212,17 @@ def _node_filter(node, filters):
return True


def _sort_nodes(node_iterable):
# sorting function for VyperNode.get_children
def _apply_filters(node_iter, node_type, filters, reverse):
ret = node_iter
if node_type is not None:
ret = (i for i in ret if isinstance(i, node_type))
if filters is not None:
ret = (i for i in ret if _node_filter(i, filters))

def sortkey(key):
return float("inf") if key is None else key

return sorted(
node_iterable, key=lambda k: (sortkey(k.lineno), sortkey(k.col_offset), k.node_id)
)
ret = list(ret)
if reverse:
ret.reverse()
return ret


def _raise_syntax_exc(error_msg: str, ast_struct: dict) -> None:
Expand Down Expand Up @@ -257,10 +260,13 @@ class VyperNode:
"""

__slots__ = NODE_BASE_ATTRIBUTES + NODE_SRC_ATTRIBUTES

_public_slots = [i for i in __slots__ if not i.startswith("_")]
_only_empty_fields: tuple = ()
_translated_fields: dict = {}

def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict):
# this function is performance-sensitive
"""
AST node initializer method.
Expand All @@ -275,21 +281,19 @@ def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict):
Dictionary of fields to be included within the node.
"""
self.set_parent(parent)
self._children: set = set()
self._children: list = []
self._metadata: NodeMetadata = NodeMetadata()
self._original_node = None
self._cache_descendants = None

for field_name in NODE_SRC_ATTRIBUTES:
# when a source offset is not available, use the parent's source offset
value = kwargs.get(field_name)
if kwargs.get(field_name) is None:
value = kwargs.pop(field_name, None)
if value is None:
value = getattr(parent, field_name, None)
setattr(self, field_name, value)

for field_name, value in kwargs.items():
if field_name in NODE_SRC_ATTRIBUTES:
continue

if field_name in self._translated_fields:
field_name = self._translated_fields[field_name]

Expand All @@ -309,7 +313,7 @@ def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict):

# add to children of parent last to ensure an accurate hash is generated
if parent is not None:
parent._children.add(self)
parent._children.append(self)

# set parent, can be useful when inserting copied nodes into the AST
def set_parent(self, parent: "VyperNode"):
Expand Down Expand Up @@ -338,7 +342,7 @@ def from_node(cls, node: "VyperNode", **kwargs) -> "VyperNode":
-------
Vyper node instance
"""
ast_struct = {i: getattr(node, i) for i in VyperNode.__slots__ if not i.startswith("_")}
ast_struct = {i: getattr(node, i) for i in VyperNode._public_slots}
ast_struct.update(ast_type=cls.__name__, **kwargs)
return cls(**ast_struct)

Expand All @@ -355,10 +359,11 @@ def get_fields(cls) -> set:
return set(i for i in slot_fields if not i.startswith("_"))

def __hash__(self):
values = [getattr(self, i, None) for i in VyperNode.__slots__ if not i.startswith("_")]
values = [getattr(self, i, None) for i in VyperNode._public_slots]
return hash(tuple(values))

def __deepcopy__(self, memo):
# default implementation of deepcopy is a hotspot
return pickle.loads(pickle.dumps(self))

def __eq__(self, other):
Expand Down Expand Up @@ -537,14 +542,7 @@ def get_children(
list
Child nodes matching the filter conditions.
"""
children = _sort_nodes(self._children)
if node_type is not None:
children = [i for i in children if isinstance(i, node_type)]
if reverse:
children.reverse()
if filters is None:
return children
return [i for i in children if _node_filter(i, filters)]
return _apply_filters(iter(self._children), node_type, filters, reverse)

def get_descendants(
self,
Expand All @@ -553,6 +551,7 @@ def get_descendants(
include_self: bool = False,
reverse: bool = False,
) -> list:
# this function is performance-sensitive
"""
Return a list of descendant nodes of this node which match the given filter(s).
Expand Down Expand Up @@ -589,19 +588,25 @@ def get_descendants(
list
Descendant nodes matching the filter conditions.
"""
children = self.get_children(node_type, filters)
for node in self.get_children():
children.extend(node.get_descendants(node_type, filters))
if (
include_self
and (not node_type or isinstance(self, node_type))
and _node_filter(self, filters)
):
children.append(self)
result = _sort_nodes(children)
if reverse:
result.reverse()
return result
ret = self._get_descendants(include_self)
return _apply_filters(ret, node_type, filters, reverse)

def _get_descendants(self, include_self=True):
# get descendants in topsort order
if self._cache_descendants is None:
ret = [self]
for node in self._children:
ret.extend(node._get_descendants())

self._cache_descendants = ret

ret = iter(self._cache_descendants)

if not include_self:
s = next(ret) # pop
assert s is self

return ret

def get(self, field_str: str) -> Any:
"""
Expand Down Expand Up @@ -669,7 +674,7 @@ def add_to_body(self, node: VyperNode) -> None:
self.body.append(node)
node._depth = self._depth + 1
node._parent = self
self._children.add(node)
self._children.append(node)

def remove_from_body(self, node: VyperNode) -> None:
"""
Expand Down
1 change: 1 addition & 0 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ def get_expr_info(node: vy_ast.ExprNode, is_callable: bool = False) -> ExprInfo:


def get_common_types(*nodes: vy_ast.VyperNode, filter_fn: Callable = None) -> List:
# this function is a performance hotspot
"""
Return a list of common possible types between one or more nodes.
Expand Down
16 changes: 11 additions & 5 deletions vyper/semantics/types/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,17 @@ def abi_type(self) -> ABIType:
return ABI_GIntM(self.bits, self.is_signed)

def compare_type(self, other: VyperType) -> bool:
if not super().compare_type(other):
return False
assert isinstance(other, IntegerT) # mypy

return self.is_signed == other.is_signed and self.bits == other.bits
# this function is performance sensitive
# originally:
# if not super().compare_type(other):
# return False
# return self.is_signed == other.is_signed and self.bits == other.bits

return ( # noqa: E721
self.__class__ == other.__class__
and self.is_signed == other.is_signed # type: ignore
and self.bits == other.bits # type: ignore
)


# helper function for readability.
Expand Down

0 comments on commit 4177314

Please sign in to comment.