Skip to content

Commit

Permalink
refactor: reimplement AST folding (#3669)
Browse files Browse the repository at this point in the history
this commit reimplements AST folding. fundamentally, it changes AST
folding from a mutating pass to an annotation pass. this brings
several benefits:
- typechecking is easier, because folding does not have to reason at all
  about types. type checking happens on both the folded and unfolded
  nodes, so intermediate values are type-checked.
- correctness in general is easier, because the AST is not mutated.
  there is also some incidental performance benefit, although that is
  not necessarily the focus here.
- the vyper frontend is now nearly mutation-free. only the getter AST
  expansion pass remains.

note that we cannot push folding past the typechecking stage entirely,
because some type checking operations depend on having folded values
(e.g., `range()` expressions, or type expressions with integer
parameters).

the approach taken in this commit is to change constant folding to be
annotating, rather than mutating. this way, type-checking can operate on
the original AST (and check for the folded values where needed).
intermediate values are also type-checked, so expressions like
`x: uint128 = 2**128 + 1 - 1` are caught by the typechecker.

summary of changes:
- `evaluate()` is renamed to `_try_fold()`. a new utility function
  called `get_folded_value()` caches folded values and is threaded
  through the codebase.
- `pre_typecheck` is added, which extracts `constant` variables and runs
  `get_folded_value()` on all nodes.
- a new `Modifiability` enum replaces the old (confusing) `is_constant`
  and `is_immutable` attributes on ExprInfo.
- `ExprInfo.is_transient` is removed, and handled by adding `TRANSIENT`
  to the `DataLocation` enum.
- the old `check_literal` and `check_kwargable` utility functions are
  replaced with a more general (and more correct) `check_modifiability`
  function
- several utility functions (e.g., `_validate_numeric_bounds()`) related
  to ad-hoc type-checking (which previously happened during constant
  folding) are removed.
- `CompilerData.vyper_module_folded` is renamed to
  `annotated_vyper_module`
- the AST output options are now `ast` and `annotated_ast`.
- `None` literals are now banned in AST validation instead of during
  analysis.

---------

Co-authored-by: Charles Cooper <[email protected]>
  • Loading branch information
tserg and charles-cooper authored Dec 31, 2023
1 parent 87db3c1 commit 56c4c9d
Show file tree
Hide file tree
Showing 81 changed files with 1,464 additions and 1,230 deletions.
31 changes: 31 additions & 0 deletions tests/functional/builtins/codegen/test_keccak256.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from vyper.utils import hex_to_int


def test_hash_code(get_contract_with_gas_estimation, keccak):
hash_code = """
@external
Expand Down Expand Up @@ -80,3 +83,31 @@ def try32(inp: bytes32) -> bool:
assert c.tryy(b"\x35" * 33) is True

print("Passed KECCAK256 hash test")


def test_hash_constant_bytes32(get_contract_with_gas_estimation, keccak):
    """Check ``keccak256`` applied to a ``bytes32`` constant.

    The contract declares ``BAR`` as a constant initialised with
    ``keccak256(FOO)`` and returns it from ``foo()``. The returned value is
    compared with the digest that the injected ``keccak`` helper produces
    for the same 32 bytes.
    """
    hex_val = "0x1234567890123456789012345678901234567890123456789012345678901234"
    code = f"""
FOO: constant(bytes32) = {hex_val}
BAR: constant(bytes32) = keccak256(FOO)
@external
def foo() -> bytes32:
    x: bytes32 = BAR
    return x
    """
    c = get_contract_with_gas_estimation(code)
    # hex_val is a 0x-prefixed hex literal; hash its 32-byte big-endian form
    assert "0x" + c.foo().hex() == keccak(hex_to_int(hex_val).to_bytes(32, "big")).hex()


def test_hash_constant_string(get_contract_with_gas_estimation, keccak):
    """Check ``keccak256`` applied to a ``String[66]`` constant.

    The same 66-character text used as a hex literal in the bytes32 variant
    is embedded here as a quoted string, so the expected digest is taken
    over its encoded characters rather than over 32 raw bytes.
    """
    str_val = "0x1234567890123456789012345678901234567890123456789012345678901234"
    code = f"""
FOO: constant(String[66]) = "{str_val}"
BAR: constant(bytes32) = keccak256(FOO)
@external
def foo() -> bytes32:
    x: bytes32 = BAR
    return x
    """
    c = get_contract_with_gas_estimation(code)
    assert "0x" + c.foo().hex() == keccak(str_val.encode()).hex()
30 changes: 30 additions & 0 deletions tests/functional/builtins/codegen/test_sha256.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import pytest

from vyper.utils import hex_to_int

pytestmark = pytest.mark.usefixtures("memory_mocker")


Expand Down Expand Up @@ -77,3 +79,31 @@ def bar() -> bytes32:
c.set(test_val, transact={})
assert c.a() == test_val
assert c.bar() == hashlib.sha256(test_val).digest()


def test_sha256_constant_bytes32(get_contract_with_gas_estimation):
    """Check ``sha256`` applied to a ``bytes32`` constant.

    The contract declares ``BAR`` as a constant initialised with
    ``sha256(FOO)`` and returns it from ``foo()``. The result is compared
    with ``hashlib.sha256`` over the same 32 bytes.
    """
    hex_val = "0x1234567890123456789012345678901234567890123456789012345678901234"
    code = f"""
FOO: constant(bytes32) = {hex_val}
BAR: constant(bytes32) = sha256(FOO)
@external
def foo() -> bytes32:
    x: bytes32 = BAR
    return x
    """
    c = get_contract_with_gas_estimation(code)
    # hex_val is a 0x-prefixed hex literal; hash its 32-byte big-endian form
    assert c.foo() == hashlib.sha256(hex_to_int(hex_val).to_bytes(32, "big")).digest()


def test_sha256_constant_string(get_contract_with_gas_estimation):
    """Check ``sha256`` applied to a ``String[66]`` constant.

    The same 66-character text used as a hex literal in the bytes32 variant
    is embedded here as a quoted string, so the expected digest is
    ``hashlib.sha256`` over its encoded characters.
    """
    str_val = "0x1234567890123456789012345678901234567890123456789012345678901234"
    code = f"""
FOO: constant(String[66]) = "{str_val}"
BAR: constant(bytes32) = sha256(FOO)
@external
def foo() -> bytes32:
    x: bytes32 = BAR
    return x
    """
    c = get_contract_with_gas_estimation(code)
    assert c.foo() == hashlib.sha256(str_val.encode()).digest()
7 changes: 1 addition & 6 deletions tests/functional/builtins/codegen/test_unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,11 @@ def bar() -> decimal:

def test_negation_int128(get_contract):
code = """
a: constant(int128) = -2**127
@external
def foo() -> int128:
return -2**127
a: constant(int128) = min_value(int128)
@external
def bar() -> int128:
return -(a+1)
"""
c = get_contract(code)
assert c.foo() == -(2**127)
assert c.bar() == 2**127 - 1
10 changes: 5 additions & 5 deletions tests/functional/builtins/folding/test_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from vyper import ast as vy_ast
from vyper.builtins import functions as vy_fn
from vyper.exceptions import OverflowException
from vyper.exceptions import InvalidType


@pytest.mark.fuzzing
Expand All @@ -21,7 +21,7 @@ def foo(a: int256) -> int256:

vyper_ast = vy_ast.parse_to_ast(f"abs({a})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE["abs"].evaluate(old_node)
new_node = vy_fn.DISPATCH_TABLE["abs"]._try_fold(old_node)

assert contract.foo(a) == new_node.value == abs(a)

Expand All @@ -35,7 +35,7 @@ def test_abs_upper_bound_folding(get_contract, a):
def foo(a: int256) -> int256:
return abs({a})
"""
with pytest.raises(OverflowException):
with pytest.raises(InvalidType):
get_contract(source)


Expand All @@ -55,7 +55,7 @@ def test_abs_lower_bound_folded(get_contract, tx_failed):
source = """
@external
def foo() -> int256:
return abs(-2**255)
return abs(min_value(int256))
"""
with pytest.raises(OverflowException):
with pytest.raises(InvalidType):
get_contract(source)
2 changes: 1 addition & 1 deletion tests/functional/builtins/folding/test_addmod_mulmod.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ def foo(a: uint256, b: uint256, c: uint256) -> uint256:

vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({a}, {b}, {c})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node)
new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node)

assert contract.foo(a, b, c) == new_node.value
11 changes: 7 additions & 4 deletions tests/functional/builtins/folding/test_bitwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
st_sint256 = st.integers(min_value=-(2**255), max_value=2**255 - 1)


# TODO: move this file to tests/unit/ast/nodes/test_fold_bitwise.py


@pytest.mark.fuzzing
@settings(max_examples=50)
@pytest.mark.parametrize("op", ["&", "|", "^"])
Expand All @@ -28,7 +31,7 @@ def foo(a: uint256, b: uint256) -> uint256:

vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}")
old_node = vyper_ast.body[0].value
new_node = old_node.evaluate()
new_node = old_node.get_folded_value()

assert contract.foo(a, b) == new_node.value

Expand All @@ -49,7 +52,7 @@ def foo(a: uint256, b: uint256) -> uint256:
old_node = vyper_ast.body[0].value

try:
new_node = old_node.evaluate()
new_node = old_node.get_folded_value()
# force bounds check, no-op because validate_numeric_bounds
# already does this, but leave in for hygiene (in case
# more types are added).
Expand Down Expand Up @@ -79,7 +82,7 @@ def foo(a: int256, b: uint256) -> int256:
old_node = vyper_ast.body[0].value

try:
new_node = old_node.evaluate()
new_node = old_node.get_folded_value()
validate_expected_type(new_node, INT256_T) # force bounds check
# compile time behavior does not match runtime behavior.
# compile-time will throw on OOB, runtime will wrap.
Expand All @@ -104,6 +107,6 @@ def foo(a: uint256) -> uint256:

vyper_ast = vy_ast.parse_to_ast(f"~{value}")
old_node = vyper_ast.body[0].value
new_node = old_node.evaluate()
new_node = old_node.get_folded_value()

assert contract.foo(value) == new_node.value
2 changes: 1 addition & 1 deletion tests/functional/builtins/folding/test_epsilon.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ def foo() -> {typ_name}:

vyper_ast = vy_ast.parse_to_ast(f"epsilon({typ_name})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE["epsilon"].evaluate(old_node)
new_node = vy_fn.DISPATCH_TABLE["epsilon"]._try_fold(old_node)

assert contract.foo() == new_node.value
2 changes: 1 addition & 1 deletion tests/functional/builtins/folding/test_floor_ceil.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@ def foo(a: decimal) -> int256:

vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node)
new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node)

assert contract.foo(value) == new_node.value
4 changes: 2 additions & 2 deletions tests/functional/builtins/folding/test_fold_as_wei_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def foo(a: decimal) -> uint256:

vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value:.10f}, '{denom}')")
old_node = vyper_ast.body[0].value
new_node = vy_fn.AsWeiValue().evaluate(old_node)
new_node = vy_fn.AsWeiValue()._try_fold(old_node)

assert contract.foo(value) == new_node.value

Expand All @@ -51,6 +51,6 @@ def foo(a: uint256) -> uint256:

vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value}, '{denom}')")
old_node = vyper_ast.body[0].value
new_node = vy_fn.AsWeiValue().evaluate(old_node)
new_node = vy_fn.AsWeiValue()._try_fold(old_node)

assert contract.foo(value) == new_node.value
6 changes: 3 additions & 3 deletions tests/functional/builtins/folding/test_keccak_sha.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def foo(a: String[100]) -> bytes32:

vyper_ast = vy_ast.parse_to_ast(f"{fn_name}('''{value}''')")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node)
new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node)

assert f"0x{contract.foo(value).hex()}" == new_node.value

Expand All @@ -41,7 +41,7 @@ def foo(a: Bytes[100]) -> bytes32:

vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node)
new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node)

assert f"0x{contract.foo(value).hex()}" == new_node.value

Expand All @@ -62,6 +62,6 @@ def foo(a: Bytes[100]) -> bytes32:

vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node)
new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node)

assert f"0x{contract.foo(value).hex()}" == new_node.value
6 changes: 3 additions & 3 deletions tests/functional/builtins/folding/test_len.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def foo(a: String[1024]) -> uint256:

vyper_ast = vy_ast.parse_to_ast(f"len('{value}')")
old_node = vyper_ast.body[0].value
new_node = vy_fn.Len().evaluate(old_node)
new_node = vy_fn.Len()._try_fold(old_node)

assert contract.foo(value) == new_node.value

Expand All @@ -35,7 +35,7 @@ def foo(a: Bytes[1024]) -> uint256:

vyper_ast = vy_ast.parse_to_ast(f"len(b'{value}')")
old_node = vyper_ast.body[0].value
new_node = vy_fn.Len().evaluate(old_node)
new_node = vy_fn.Len()._try_fold(old_node)

assert contract.foo(value.encode()) == new_node.value

Expand All @@ -53,6 +53,6 @@ def foo(a: Bytes[1024]) -> uint256:

vyper_ast = vy_ast.parse_to_ast(f"len({value})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.Len().evaluate(old_node)
new_node = vy_fn.Len()._try_fold(old_node)

assert contract.foo(value) == new_node.value
6 changes: 3 additions & 3 deletions tests/functional/builtins/folding/test_min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def foo(a: decimal, b: decimal) -> decimal:

vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node)
new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node)

assert contract.foo(left, right) == new_node.value

Expand All @@ -50,7 +50,7 @@ def foo(a: int128, b: int128) -> int128:

vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node)
new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node)

assert contract.foo(left, right) == new_node.value

Expand All @@ -69,6 +69,6 @@ def foo(a: uint256, b: uint256) -> uint256:

vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE[fn_name].evaluate(old_node)
new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node)

assert contract.foo(left, right) == new_node.value
2 changes: 1 addition & 1 deletion tests/functional/builtins/folding/test_powmod.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ def foo(a: uint256, b: uint256) -> uint256:

vyper_ast = vy_ast.parse_to_ast(f"pow_mod256({a}, {b})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.PowMod256().evaluate(old_node)
new_node = vy_fn.PowMod256()._try_fold(old_node)

assert contract.foo(a, b) == new_node.value
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,48 @@ def foo(a: address = empty(address)):
def foo(a: int112 = min_value(int112)):
self.A = a
""",
"""
struct X:
x: int128
y: address
BAR: constant(X) = X({x: 1, y: 0x0000000000000000000000000000000000012345})
@external
def out_literals(a: int128 = BAR.x + 1) -> X:
return BAR
""",
"""
struct X:
x: int128
y: address
struct Y:
x: X
y: uint256
BAR: constant(X) = X({x: 1, y: 0x0000000000000000000000000000000000012345})
FOO: constant(Y) = Y({x: BAR, y: 256})
@external
def out_literals(a: int128 = FOO.x.x + 1) -> Y:
return FOO
""",
"""
struct Bar:
a: bool
BAR: constant(Bar) = Bar({a: True})
@external
def foo(x: bool = True and not BAR.a):
pass
""",
"""
struct Bar:
a: uint256
BAR: constant(Bar) = Bar({ a: 123 })
@external
def foo(x: bool = BAR.a + 1 > 456):
pass
""",
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def test_int128_too_long(get_contract, tx_failed):
contract_1 = """
@external
def foo() -> int256:
return (2**255)-1
return max_value(int256)
"""

c = get_contract(contract_1)
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/codegen/test_call_graph_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def foo():
t = CompilerData(code)

# check the .called_functions data structure on foo() directly
foo = t.vyper_module_folded.get_children(vy_ast.FunctionDef, filters={"name": "foo"})[0]
foo = t.annotated_vyper_module.get_children(vy_ast.FunctionDef, filters={"name": "foo"})[0]
foo_t = foo._metadata["func_type"]
assert [f.name for f in foo_t.called_functions] == func_names

Expand Down
2 changes: 1 addition & 1 deletion tests/functional/codegen/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def ok() -> {typ}:
@external
def should_fail() -> int256:
return -2**255 # OOB for all int/uint types with less than 256 bits
return min_value(int256)
"""

code = f"""
Expand Down
Loading

0 comments on commit 56c4c9d

Please sign in to comment.