From bd876b114bc34643a7d210b319f69642ce80f018 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Sun, 24 Nov 2024 04:34:57 +0800 Subject: [PATCH] fix[lang]: use folded node for typechecking (#4365) This commit addresses several issues in the frontend where valid code fails to compile because typechecking was performed on non-literal AST nodes, specifically in `slice()` and `raw_log()` builtins. This is fixed by using the folded node for typechecking instead. Additionally, folding is applied for the argument to `convert()`, which results in the typechecker being able to reject more invalid programs. --- .../functional/codegen/features/test_logging.py | 17 +++++++++++++++++ .../test_invalid_literal_exception.py | 8 ++++++++ .../exceptions/test_type_mismatch_exception.py | 8 ++++++++ tests/functional/syntax/test_slice.py | 16 ++++++++++++++++ vyper/builtins/_convert.py | 2 +- vyper/builtins/functions.py | 7 ++++--- vyper/semantics/analysis/local.py | 4 ++-- 7 files changed, 56 insertions(+), 6 deletions(-) diff --git a/tests/functional/codegen/features/test_logging.py b/tests/functional/codegen/features/test_logging.py index 2bb646e6ef..87d848fae5 100644 --- a/tests/functional/codegen/features/test_logging.py +++ b/tests/functional/codegen/features/test_logging.py @@ -1254,6 +1254,23 @@ def foo(): assert log.topics == [event_id, topic1, topic2, topic3] +valid_list = [ + # test constant folding inside raw_log + """ +topic: constant(bytes32) = 0x1212121212121210212801291212121212121210121212121212121212121212 + +@external +def foo(): + raw_log([[topic]][0], b'') + """ +] + + +@pytest.mark.parametrize("code", valid_list) +def test_raw_log_pass(code): + assert compile_code(code) is not None + + fail_list = [ ( """ diff --git a/tests/functional/syntax/exceptions/test_invalid_literal_exception.py b/tests/functional/syntax/exceptions/test_invalid_literal_exception.py index a0cf10ad02..f3fd73fbfc 100644 --- a/tests/functional/syntax/exceptions/test_invalid_literal_exception.py +++ b/tests/functional/syntax/exceptions/test_invalid_literal_exception.py @@ -36,6 +36,14 @@ def foo(): def foo(): a: bytes32 = keccak256("ั“test") """, + # test constant folding inside of `convert()` + """ +BAR: constant(uint16) = 256 + +@external +def foo(): + a: uint8 = convert(BAR, uint8) + """, ] diff --git a/tests/functional/syntax/exceptions/test_type_mismatch_exception.py b/tests/functional/syntax/exceptions/test_type_mismatch_exception.py index 76c5c481f0..63e0eb6d11 100644 --- a/tests/functional/syntax/exceptions/test_type_mismatch_exception.py +++ b/tests/functional/syntax/exceptions/test_type_mismatch_exception.py @@ -47,6 +47,14 @@ def foo(): """ a: constant(address) = 0x3cd751e6b0078be393132286c442345e5dc49699 """, + # test constant folding inside `convert()` + """ +BAR: constant(Bytes[5]) = b"vyper" + +@external +def foo(): + a: Bytes[4] = convert(BAR, Bytes[4]) + """, ] diff --git a/tests/functional/syntax/test_slice.py b/tests/functional/syntax/test_slice.py index 6bb666527e..6a091c9da3 100644 --- a/tests/functional/syntax/test_slice.py +++ b/tests/functional/syntax/test_slice.py @@ -53,6 +53,22 @@ def foo(inp: Bytes[10]) -> Bytes[4]: def foo() -> Bytes[10]: return slice(b"badmintonzzz", 1, 10) """, + # test constant folding for `slice()` `length` argument + """ +@external +def foo(): + x: Bytes[32] = slice(msg.data, 0, 31 + 1) + """, + """ +@external +def foo(a: address): + x: Bytes[32] = slice(a.code, 0, 31 + 1) + """, + """ +@external +def foo(inp: Bytes[5], start: uint256) -> Bytes[3]: + return slice(inp, 0, 1 + 1) + """, ] diff --git a/vyper/builtins/_convert.py b/vyper/builtins/_convert.py index aa53dee429..a494e4a344 100644 --- a/vyper/builtins/_convert.py +++ b/vyper/builtins/_convert.py @@ -463,7 +463,7 @@ def to_flag(expr, arg, out_typ): def convert(expr, context): assert len(expr.args) == 2, "bad typecheck: convert" - arg_ast = expr.args[0] + arg_ast = expr.args[0].reduced() arg = Expr(arg_ast, context).ir_node original_arg = arg diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 674efda7ce..9ed74b8cfe 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -305,7 +305,7 @@ def fetch_call_return(self, node): arg = node.args[0] start_expr = node.args[1] - length_expr = node.args[2] + length_expr = node.args[2].reduced() # CMC 2022-03-22 NOTE slight code duplication with semantics/analysis/local is_adhoc_slice = arg.get("attr") == "code" or ( @@ -1257,7 +1257,8 @@ def fetch_call_return(self, node): def infer_arg_types(self, node, expected_return_typ=None): self._validate_arg_types(node) - if not isinstance(node.args[0], vy_ast.List) or len(node.args[0].elements) > 4: + arg = node.args[0].reduced() + if not isinstance(arg, vy_ast.List) or len(arg.elements) > 4: raise InvalidType("Expecting a list of 0-4 topics as first argument", node.args[0]) # return a concrete type for `data` @@ -1269,7 +1270,7 @@ def infer_arg_types(self, node, expected_return_typ=None): def build_IR(self, expr, args, kwargs, context): context.check_is_not_constant(f"use {self._id}", expr) - topics_length = len(expr.args[0].elements) + topics_length = len(expr.args[0].reduced().elements) topics = args[0].args topics = [unwrap_location(topic) for topic in topics] diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 809c6532c6..461326d72d 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -136,7 +136,7 @@ def _validate_address_code(node: vy_ast.Attribute, value_type: VyperType) -> Non parent = node.get_ancestor() if isinstance(parent, vy_ast.Call): ok_func = isinstance(parent.func, vy_ast.Name) and parent.func.id == "slice" - ok_args = len(parent.args) == 3 and isinstance(parent.args[2], vy_ast.Int) + ok_args = len(parent.args) == 3 and isinstance(parent.args[2].reduced(), vy_ast.Int) if ok_func and ok_args: return @@ -154,7 +154,7 @@ def _validate_msg_data_attribute(node: vy_ast.Attribute) -> None: "msg.data is only allowed inside of the slice, len or raw_call functions", node ) if parent.get("func.id") == "slice": - ok_args = len(parent.args) == 3 and isinstance(parent.args[2], vy_ast.Int) + ok_args = len(parent.args) == 3 and isinstance(parent.args[2].reduced(), vy_ast.Int) if not ok_args: raise StructureException( "slice(msg.data) must use a compile-time constant for length argument", parent