From bf7b3462fccfbf2f64a04542f8165b4efe51be5e Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 15 Jan 2024 09:38:16 -0500 Subject: [PATCH] move all check_modifiability checks to be after a validate_expected_type (which calls validate_call_args) --- vyper/builtins/_signatures.py | 19 ++++++++++++------- vyper/semantics/analysis/local.py | 7 ++++++- vyper/semantics/types/function.py | 17 ++--------------- 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index d2aefb2fd4..0244f0ad19 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -88,7 +88,9 @@ class BuiltinFunctionT(VyperType): _is_terminus = False # helper function to deal with TYPE_DEFINITIONs - def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None: + def _validate_single( + self, arg: vy_ast.VyperNode, expected_type: VyperType, modifiability: Modifiability + ) -> None: # TODO using "TYPE_DEFINITION" is a kludge in derived classes, # refactor me. if expected_type == "TYPE_DEFINITION": @@ -97,6 +99,9 @@ def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> N type_from_annotation(arg) else: validate_expected_type(arg, expected_type) + if not check_modifiability(arg, modifiability): + # CMC 2024-01-15 TODO: change to StateAccessViolation + raise TypeMismatch("Value must be literal", arg) def _validate_arg_types(self, node: vy_ast.Call) -> None: num_args = len(self._inputs) # the number of args the signature indicates @@ -109,15 +114,15 @@ def _validate_arg_types(self, node: vy_ast.Call) -> None: validate_call_args(node, expect_num_args, list(self._kwargs.keys())) for arg, (_, expected) in zip(node.args, self._inputs): - self._validate_single(arg, expected) + self._validate_single(arg, expected, Modifiability.MODIFIABLE) for kwarg in node.keywords: kwarg_settings = self._kwargs[kwarg.arg] - if kwarg_settings.require_literal and not check_modifiability( - kwarg.value, Modifiability.CONSTANT - ): - raise TypeMismatch("Value must be literal", kwarg.value) - self._validate_single(kwarg.value, kwarg_settings.typ) + + modifiability = Modifiability.MODIFIABLE + if kwarg_settings.require_literal: + modifiability = Modifiability.CONSTANT + self._validate_single(kwarg.value, kwarg_settings.typ, modifiability) # typecheck varargs. we don't have type info from the signature, # so ensure that the types of the args can be inferred exactly. diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index c4af5b1e3a..e136078a16 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -19,6 +19,7 @@ from vyper.semantics.analysis.base import Modifiability, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.utils import ( + check_modifiability, get_common_types, get_exact_type_from_node, get_expr_info, @@ -214,7 +215,11 @@ def analyze(self): # visit default args assert self.func.n_keyword_args == len(self.fn_node.args.defaults) for kwarg in self.func.keyword_args: - self.expr_visitor.visit(kwarg.default_value, kwarg.typ) + value = kwarg.default_value + self.expr_visitor.visit(value, kwarg.typ) + # CMC 2024-01-15 move these check_modifiability checks into expr visitor + if not check_modifiability(value, Modifiability.RUNTIME_CONSTANT): + raise StateAccessViolation("Value must be literal or environment variable", value) def visit(self, node): super().visit(node) diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 7c77560e49..39210deaab 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -13,20 +13,10 @@ CompilerPanic, FunctionDeclarationException, InvalidType, - StateAccessViolation, StructureException, ) -from vyper.semantics.analysis.base import ( - FunctionVisibility, - Modifiability, - StateMutability, - StorageSlot, -) -from vyper.semantics.analysis.utils import ( - check_modifiability, - get_exact_type_from_node, - validate_expected_type, -) +from vyper.semantics.analysis.base import FunctionVisibility, StateMutability, StorageSlot +from vyper.semantics.analysis.utils import get_exact_type_from_node, validate_expected_type from vyper.semantics.data_locations import DataLocation from vyper.semantics.types.base import KwargSettings, VyperType from vyper.semantics.types.primitives import BoolT @@ -701,9 +691,6 @@ def _parse_args( positional_args.append(PositionalArg(argname, type_, ast_source=arg)) else: value = funcdef.args.defaults[i - n_positional_args] - if not check_modifiability(value, Modifiability.RUNTIME_CONSTANT): - raise StateAccessViolation("Value must be literal or environment variable", value) - validate_expected_type(value, type_) keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg)) argnames.add(argname)