Skip to content

Commit

Permalink
fix[lang]: use folded node for typechecking (#4365)
Browse files Browse the repository at this point in the history
This commit addresses several issues in the frontend where valid code
fails to compile because typechecking was performed on non-literal
AST nodes, specifically in `slice()` and `raw_log()` builtins. This is
fixed by using the folded node for typechecking instead.

Additionally, folding is applied for the argument to `convert()`, which
results in the typechecker being able to reject more invalid programs.
  • Loading branch information
tserg authored Nov 23, 2024
1 parent f38b61a commit bd876b1
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 6 deletions.
17 changes: 17 additions & 0 deletions tests/functional/codegen/features/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,6 +1254,23 @@ def foo():
assert log.topics == [event_id, topic1, topic2, topic3]


valid_list = [
# test constant folding inside raw_log
"""
topic: constant(bytes32) = 0x1212121212121210212801291212121212121210121212121212121212121212
@external
def foo():
raw_log([[topic]][0], b'')
"""
]


@pytest.mark.parametrize("code", valid_list)
def test_raw_log_pass(code):
assert compile_code(code) is not None


fail_list = [
(
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ def foo():
def foo():
a: bytes32 = keccak256("ѓtest")
""",
# test constant folding inside of `convert()`
"""
BAR: constant(uint16) = 256
@external
def foo():
a: uint8 = convert(BAR, uint8)
""",
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ def foo():
"""
a: constant(address) = 0x3cd751e6b0078be393132286c442345e5dc49699
""",
# test constant folding inside `convert()`
"""
BAR: constant(Bytes[5]) = b"vyper"
@external
def foo():
a: Bytes[4] = convert(BAR, Bytes[4])
""",
]


Expand Down
16 changes: 16 additions & 0 deletions tests/functional/syntax/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,22 @@ def foo(inp: Bytes[10]) -> Bytes[4]:
def foo() -> Bytes[10]:
return slice(b"badmintonzzz", 1, 10)
""",
# test constant folding for `slice()` `length` argument
"""
@external
def foo():
x: Bytes[32] = slice(msg.data, 0, 31 + 1)
""",
"""
@external
def foo(a: address):
x: Bytes[32] = slice(a.code, 0, 31 + 1)
""",
"""
@external
def foo(inp: Bytes[5], start: uint256) -> Bytes[3]:
return slice(inp, 0, 1 + 1)
""",
]


Expand Down
2 changes: 1 addition & 1 deletion vyper/builtins/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def to_flag(expr, arg, out_typ):
def convert(expr, context):
assert len(expr.args) == 2, "bad typecheck: convert"

arg_ast = expr.args[0]
arg_ast = expr.args[0].reduced()
arg = Expr(arg_ast, context).ir_node
original_arg = arg

Expand Down
7 changes: 4 additions & 3 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def fetch_call_return(self, node):

arg = node.args[0]
start_expr = node.args[1]
length_expr = node.args[2]
length_expr = node.args[2].reduced()

# CMC 2022-03-22 NOTE slight code duplication with semantics/analysis/local
is_adhoc_slice = arg.get("attr") == "code" or (
Expand Down Expand Up @@ -1257,7 +1257,8 @@ def fetch_call_return(self, node):
def infer_arg_types(self, node, expected_return_typ=None):
self._validate_arg_types(node)

if not isinstance(node.args[0], vy_ast.List) or len(node.args[0].elements) > 4:
arg = node.args[0].reduced()
if not isinstance(arg, vy_ast.List) or len(arg.elements) > 4:
raise InvalidType("Expecting a list of 0-4 topics as first argument", node.args[0])

# return a concrete type for `data`
Expand All @@ -1269,7 +1270,7 @@ def infer_arg_types(self, node, expected_return_typ=None):
def build_IR(self, expr, args, kwargs, context):
context.check_is_not_constant(f"use {self._id}", expr)

topics_length = len(expr.args[0].elements)
topics_length = len(expr.args[0].reduced().elements)
topics = args[0].args
topics = [unwrap_location(topic) for topic in topics]

Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _validate_address_code(node: vy_ast.Attribute, value_type: VyperType) -> Non
parent = node.get_ancestor()
if isinstance(parent, vy_ast.Call):
ok_func = isinstance(parent.func, vy_ast.Name) and parent.func.id == "slice"
ok_args = len(parent.args) == 3 and isinstance(parent.args[2], vy_ast.Int)
ok_args = len(parent.args) == 3 and isinstance(parent.args[2].reduced(), vy_ast.Int)
if ok_func and ok_args:
return

Expand All @@ -154,7 +154,7 @@ def _validate_msg_data_attribute(node: vy_ast.Attribute) -> None:
"msg.data is only allowed inside of the slice, len or raw_call functions", node
)
if parent.get("func.id") == "slice":
ok_args = len(parent.args) == 3 and isinstance(parent.args[2], vy_ast.Int)
ok_args = len(parent.args) == 3 and isinstance(parent.args[2].reduced(), vy_ast.Int)
if not ok_args:
raise StructureException(
"slice(msg.data) must use a compile-time constant for length argument", parent
Expand Down

0 comments on commit bd876b1

Please sign in to comment.