Skip to content

Commit

Permalink
move prefold to new file
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Oct 21, 2023
1 parent b5b2063 commit e04ecf5
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 85 deletions.
80 changes: 80 additions & 0 deletions vyper/ast/pre_typecheck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from decimal import Decimal
from typing import Any

from vyper import ast as vy_ast
from vyper.exceptions import UnfoldableNode, VyperException
from vyper.semantics.namespace import get_namespace

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.namespace
begins an import cycle.


def prefold(node: vy_ast.VyperNode) -> Any:
if isinstance(node, vy_ast.Attribute):
val = prefold(node.value)
# constant struct members
if isinstance(val, dict):
return val[node.attr]
return None
elif isinstance(node, vy_ast.BinOp):
assert isinstance(node, vy_ast.BinOp)
left = prefold(node.left)
right = prefold(node.right)
if not (isinstance(left, type(right)) and isinstance(left, (int, Decimal))):
return None
return node.op._op(left, right)
elif isinstance(node, vy_ast.BoolOp):
values = [prefold(i) for i in node.values]
if not all(isinstance(v, bool) for v in values):
return None
return node.op._op(values)
elif isinstance(node, vy_ast.Call):
# constant structs
if len(node.args) == 1 and isinstance(node.args[0], vy_ast.Dict):
return prefold(node.args[0])

from vyper.builtins.functions import DISPATCH_TABLE

# builtins
if isinstance(node.func, vy_ast.Name):
call_type = DISPATCH_TABLE.get(node.func.id)
if call_type and hasattr(call_type, "evaluate"):
try:
return call_type.evaluate(node).value # type: ignore
except (UnfoldableNode, VyperException):
pass
elif isinstance(node, vy_ast.Compare):
left = prefold(node.left)

if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)):
if not isinstance(node.right, (vy_ast.List, vy_ast.Tuple)):
return None

right = [prefold(i) for i in node.right.elements]
if left is None or len(set([type(i) for i in right])) > 1:
return None
return node.op._op(left, right)

right = prefold(node.right)
if not (isinstance(left, type(right)) and isinstance(left, (int, Decimal))):
return None
return node.op._op(left, right)
elif isinstance(node, vy_ast.Constant):
return node.value
elif isinstance(node, vy_ast.Dict):
values = [prefold(v) for v in node.values]
if any(v is None for v in values):
return None
return {k.id: v for (k, v) in zip(node.keys, values)}
elif isinstance(node, (vy_ast.List, vy_ast.Tuple)):
val = [prefold(e) for e in node.elements]
if None in val:
return None
return val
elif isinstance(node, vy_ast.Name):
ns = get_namespace()
return ns._constants.get(node.id, None)
elif isinstance(node, vy_ast.UnaryOp):
operand = prefold(node.operand)
if not isinstance(operand, int):
return None
return node.op._op(operand)

return None
3 changes: 2 additions & 1 deletion vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import functools
from typing import Dict

from vyper.ast.pre_typecheck import prefold

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.ast.pre_typecheck
begins an import cycle.
from vyper.ast.validation import validate_call_args
from vyper.codegen.expr import Expr
from vyper.codegen.ir_node import IRnode
from vyper.exceptions import CompilerPanic, TypeMismatch
from vyper.semantics.analysis.utils import get_exact_type_from_node, validate_expected_type
from vyper.semantics.types import TYPE_T, KwargSettings, VyperType
from vyper.semantics.types.utils import prefold, type_from_annotation
from vyper.semantics.types.utils import type_from_annotation


def process_arg(arg, expected_arg_type, context):
Expand Down
3 changes: 2 additions & 1 deletion vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from vyper import ast as vy_ast
from vyper.abi_types import ABI_Tuple
from vyper.ast.pre_typecheck import prefold

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.ast.pre_typecheck
begins an import cycle.
from vyper.ast.validation import validate_call_args
from vyper.codegen.abi_encoder import abi_encode
from vyper.codegen.context import Context, VariableRecord
Expand Down Expand Up @@ -82,7 +83,7 @@
UINT8_T,
UINT256_T,
)
from vyper.semantics.types.utils import prefold, type_from_annotation
from vyper.semantics.types.utils import type_from_annotation
from vyper.utils import (
DECIMAL_DIVISOR,
EIP_170_LIMIT,
Expand Down
3 changes: 2 additions & 1 deletion vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from vyper import ast as vy_ast
from vyper.ast.metadata import NodeMetadata
from vyper.ast.pre_typecheck import prefold

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.ast.pre_typecheck
begins an import cycle.
from vyper.ast.validation import validate_call_args
from vyper.exceptions import (
ExceptionList,
Expand Down Expand Up @@ -51,7 +52,7 @@
is_type_t,
)
from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT, StateMutability
from vyper.semantics.types.utils import prefold, type_from_annotation
from vyper.semantics.types.utils import type_from_annotation


def validate_functions(vy_module: vy_ast.Module) -> None:
Expand Down
3 changes: 2 additions & 1 deletion vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import vyper.builtins.interfaces
from vyper import ast as vy_ast
from vyper.ast.pre_typecheck import prefold
from vyper.evm.opcodes import version_check
from vyper.exceptions import (
CallViolation,
Expand All @@ -28,7 +29,7 @@
from vyper.semantics.namespace import Namespace, get_namespace
from vyper.semantics.types import EnumT, EventT, InterfaceT, StructT
from vyper.semantics.types.function import ContractFunctionT
from vyper.semantics.types.utils import prefold, type_from_annotation
from vyper.semantics.types.utils import type_from_annotation
from vyper.typing import InterfaceDict


Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Callable, List

from vyper import ast as vy_ast
from vyper.ast.pre_typecheck import prefold

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.ast.pre_typecheck
begins an import cycle.
from vyper.exceptions import (
CompilerPanic,
InvalidLiteral,
Expand All @@ -24,7 +25,6 @@
from vyper.semantics.types.bytestrings import BytesT, StringT
from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT
from vyper.semantics.types.subscriptable import DArrayT, SArrayT, TupleT
from vyper.semantics.types.utils import prefold
from vyper.utils import checksum_encode, int_to_fourbytes


Expand Down
3 changes: 2 additions & 1 deletion vyper/semantics/types/subscriptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@

from vyper import ast as vy_ast
from vyper.abi_types import ABI_DynamicArray, ABI_StaticArray, ABI_Tuple, ABIType
from vyper.ast.pre_typecheck import prefold

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.ast.pre_typecheck
begins an import cycle.
from vyper.exceptions import ArrayIndexException, InvalidType, StructureException
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.types.base import VyperType
from vyper.semantics.types.primitives import IntegerT
from vyper.semantics.types.shortcuts import UINT256_T
from vyper.semantics.types.utils import get_index_value, prefold, type_from_annotation
from vyper.semantics.types.utils import get_index_value, type_from_annotation


class _SubscriptableT(VyperType):
Expand Down
80 changes: 1 addition & 79 deletions vyper/semantics/types/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from decimal import Decimal
from typing import Any

from vyper import ast as vy_ast
from vyper.ast.pre_typecheck import prefold

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.ast.pre_typecheck
begins an import cycle.
from vyper.exceptions import (
ArrayIndexException,
InstantiationException,
InvalidType,
StructureException,
UnfoldableNode,
UnknownType,
VyperException,
)
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
from vyper.semantics.data_locations import DataLocation
Expand Down Expand Up @@ -135,80 +131,6 @@ def _failwith(type_name):
return typ_


def prefold(node: vy_ast.VyperNode) -> Any:
if isinstance(node, vy_ast.Attribute):
val = prefold(node.value)
# constant struct members
if isinstance(val, dict):
return val[node.attr]
return None
elif isinstance(node, vy_ast.BinOp):
assert isinstance(node, vy_ast.BinOp)
left = prefold(node.left)
right = prefold(node.right)
if not (isinstance(left, type(right)) and isinstance(left, (int, Decimal))):
return None
return node.op._op(left, right)
elif isinstance(node, vy_ast.BoolOp):
values = [prefold(i) for i in node.values]
if not all(isinstance(v, bool) for v in values):
return None
return node.op._op(values)
elif isinstance(node, vy_ast.Call):
# constant structs
if len(node.args) == 1 and isinstance(node.args[0], vy_ast.Dict):
return prefold(node.args[0])

from vyper.builtins.functions import DISPATCH_TABLE

# builtins
if isinstance(node.func, vy_ast.Name):
call_type = DISPATCH_TABLE.get(node.func.id)
if call_type and hasattr(call_type, "evaluate"):
try:
return call_type.evaluate(node).value # type: ignore
except (UnfoldableNode, VyperException):
pass
elif isinstance(node, vy_ast.Compare):
left = prefold(node.left)

if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)):
if not isinstance(node.right, (vy_ast.List, vy_ast.Tuple)):
return None

right = [prefold(i) for i in node.right.elements]
if left is None or len(set([type(i) for i in right])) > 1:
return None
return node.op._op(left, right)

right = prefold(node.right)
if not (isinstance(left, type(right)) and isinstance(left, (int, Decimal))):
return None
return node.op._op(left, right)
elif isinstance(node, vy_ast.Constant):
return node.value
elif isinstance(node, vy_ast.Dict):
values = [prefold(v) for v in node.values]
if any(v is None for v in values):
return None
return {k.id: v for (k, v) in zip(node.keys, values)}
elif isinstance(node, (vy_ast.List, vy_ast.Tuple)):
val = [prefold(e) for e in node.elements]
if None in val:
return None
return val
elif isinstance(node, vy_ast.Name):
ns = get_namespace()
return ns._constants.get(node.id, None)
elif isinstance(node, vy_ast.UnaryOp):
operand = prefold(node.operand)
if not isinstance(operand, int):
return None
return node.op._op(operand)

return None


def get_index_value(node: vy_ast.Index, constants: dict) -> int:
"""
Return the literal value for a `Subscript` index.
Expand Down

0 comments on commit e04ecf5

Please sign in to comment.