From a59d575fb8adf6d9379422ab72417c26d806b754 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 19 Dec 2023 19:19:38 -0500 Subject: [PATCH 01/27] wip - module variables kill ModuleInfo, can just use VarInfo now --- vyper/codegen/expr.py | 47 +++++++++++++++--------------- vyper/compiler/input_bundle.py | 5 ++++ vyper/semantics/analysis/base.py | 20 +------------ vyper/semantics/analysis/module.py | 7 +++-- vyper/semantics/analysis/utils.py | 5 +--- vyper/semantics/types/module.py | 22 ++++++++++++-- vyper/semantics/types/user.py | 3 +- vyper/semantics/types/utils.py | 10 ++----- 8 files changed, 59 insertions(+), 60 deletions(-) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index d5ca5aceee..87e3dcd74d 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -242,6 +242,7 @@ def parse_Attribute(self): eval_code = ["extcodesize", addr] output_type = UINT256_T else: + assert self.expr.attr == "is_contract" eval_code = ["gt", ["extcodesize", addr], 0] output_type = BoolT() return IRnode.from_list(eval_code, typ=output_type) @@ -258,20 +259,6 @@ def parse_Attribute(self): if addr.value == "address": # for `self.code` return IRnode.from_list(["~selfcode"], typ=BytesT(0)) return IRnode.from_list(["~extcode", addr], typ=BytesT(0)) - # self.x: global attribute - elif isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id == "self": - varinfo = self.context.globals[self.expr.attr] - location = TRANSIENT if varinfo.is_transient else STORAGE - - ret = IRnode.from_list( - varinfo.position.position, - typ=varinfo.typ, - location=location, - annotation="self." + self.expr.attr, - ) - ret._referenced_variables = {varinfo} - - return ret # Reserved keywords elif ( @@ -327,17 +314,29 @@ def parse_Attribute(self): "chain.id is unavailable prior to istanbul ruleset", self.expr ) return IRnode.from_list(["chainid"], typ=UINT256_T) + # Other variables - else: - sub = Expr(self.expr.value, self.context).ir_node - # contract type - if isinstance(sub.typ, InterfaceT): - # MyInterface.address - assert self.expr.attr == "address" - sub.typ = typ - return sub - if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: - return get_element_ptr(sub, self.expr.attr) + + # self.x: module-level variable + if isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id == "self": + varinfo = self.context.globals[self.expr.attr] + location = TRANSIENT if varinfo.is_transient else STORAGE + + ret = IRnode.from_list(varinfo.position.position, typ=varinfo.typ, location=location) + ret._referenced_variables = {varinfo} + + return ret + + sub = Expr(self.expr.value, self.context).ir_node + + # interface type + if isinstance(sub.typ, InterfaceT): + # MyInterface.address + assert self.expr.attr == "address" + sub.typ = typ + return sub + if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: + return get_element_ptr(sub, self.expr.attr) def parse_Subscript(self): sub = Expr(self.expr.value, self.context).ir_node diff --git a/vyper/compiler/input_bundle.py b/vyper/compiler/input_bundle.py index 27170f0a56..8ec82bd918 100644 --- a/vyper/compiler/input_bundle.py +++ b/vyper/compiler/input_bundle.py @@ -70,6 +70,11 @@ def __init__(self, search_paths): # share the same lifetime as this input bundle. self._cache = lambda: None + # very strict equality! we don't want to accidentally compare + # two input bundles as being equal when they aren't the same. + def __eq__(self, other): + return self is other + def _normalize_path(self, path): raise NotImplementedError(f"not implemented! {self.__class__}._normalize_path()") diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 4d1b1cdbab..daebc4cef7 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -154,23 +154,9 @@ class AnalysisResult: pass -@dataclass -class ModuleInfo(AnalysisResult): - module_t: "ModuleT" - - @property - def module_node(self): - return self.module_t._module - - # duck type, conform to interface of VarInfo and ExprInfo - @property - def typ(self): - return self.module_t - - @dataclass class ImportInfo(AnalysisResult): - typ: Union[ModuleInfo, "InterfaceT"] + typ: Union["ModuleT", "InterfaceT"] alias: str # the name in the namespace qualified_module_name: str # for error messages # source_id: int @@ -245,10 +231,6 @@ def from_varinfo(cls, var_info: VarInfo) -> "ExprInfo": is_immutable=var_info.is_immutable, ) - @classmethod - def from_moduleinfo(cls, module_info: ModuleInfo) -> "ExprInfo": - return cls(module_info.module_t) - def copy_with_type(self, typ: VyperType) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 7aa661aec3..99d105f2cb 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -20,7 +20,7 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import ImportInfo, ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ImportInfo, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import validate_functions @@ -277,6 +277,7 @@ def visit_VariableDecl(self, node): node.target._metadata["varinfo"] = var_info # TODO maybe put this in the global namespace node._metadata["type"] = type_ + # TODO: maybe this code can be removed def _finalize(): # add the variable name to `self` namespace if the variable is either # 1. a public constant or immutable; or @@ -396,7 +397,7 @@ def _add_import( ) self.namespace[alias] = module_info - # load an InterfaceT or ModuleInfo from an import. + # load an InterfaceT or ModuleT from an import. # raises FileNotFoundError def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str) -> Any: # the directory this (currently being analyzed) module is in @@ -437,7 +438,7 @@ def _load_import_helper( is_interface=False, ) - return ModuleInfo(module_t) + return module_t except FileNotFoundError as e: # escape `e` from the block scope, it can make things diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 1785afd92d..5ecc89c612 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -17,7 +17,7 @@ ZeroDivisionException, ) from vyper.semantics import types -from vyper.semantics.analysis.base import ExprInfo, ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ExprInfo, VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType @@ -71,9 +71,6 @@ def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: if isinstance(info, VarInfo): return ExprInfo.from_varinfo(info) - if isinstance(info, ModuleInfo): - return ExprInfo.from_moduleinfo(info) - raise CompilerPanic("unreachable!", node) if isinstance(node, vy_ast.Attribute): diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index 4622482951..7218f4feda 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -1,6 +1,7 @@ from functools import cached_property from typing import Optional +from vyper.semantics.data_locations import DataLocation from vyper import ast as vy_ast from vyper.abi_types import ABI_Address, ABIType from vyper.ast.validation import validate_call_args @@ -250,6 +251,16 @@ def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": # Datatype to store all module information. class ModuleT(VyperType): + _attribute_in_annotation = True + + # disallow everything but storage + _invalid_locations = ( + DataLocation.UNSET, + DataLocation.CALLDATA, + DataLocation.CODE, + DataLocation.MEMORY, + ) + def __init__(self, module: vy_ast.Module, name: Optional[str] = None): super().__init__() @@ -289,8 +300,15 @@ def __eq__(self, other): def __hash__(self): return hash(id(self)) - def get_type_member(self, key: str, node: vy_ast.VyperNode) -> "VyperType": - return self._helper.get_member(key, node) + def __repr__(self): + resolved_path = self._module.resolved_path + if self._id == resolved_path: + return f"module {self._id}" + else: + return f"module {self._id} (loaded from '{self._module.resolved_path}')" + + def get_type_member(self, attr: str, node: vy_ast.VyperNode) -> VyperType: + return self._helper.get_type_member(attr, node) # this is a property, because the function set changes after AST expansion @property diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index ef7e1d0eb4..92639e8377 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -355,7 +355,8 @@ def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT": return cls(struct_name, members, ast_def=base_node) def __repr__(self): - return f"{self._id} declaration object" + arg_types = ",".join(repr(t) for t in self.members.values()) + return f"struct {self._id}({arg_types})" @property def size_in_bytes(self): diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index 8d68a9fa01..54e8c21976 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -123,18 +123,14 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: raise InvalidType(err_msg, node) try: - module_or_interface = namespace[node.value.id] # type: ignore + module_or_interface = namespace[node.value.id] except UndeclaredDefinition: raise InvalidType(err_msg, node) from None - interface = module_or_interface - if hasattr(module_or_interface, "module_t"): # i.e., it's a ModuleInfo - interface = module_or_interface.module_t.interface - - if not interface._attribute_in_annotation: + if not module_or_interface._attribute_in_annotation: raise InvalidType(err_msg, node) - type_t = interface.get_type_member(node.attr, node) + type_t = module_or_interface.get_type_member(node.attr, node) assert isinstance(type_t, TYPE_T) # sanity check return type_t.typedef From e353a15e3ec471aa049189eef15ef744e7804f34 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 20 Dec 2023 09:08:32 -0500 Subject: [PATCH 02/27] add size_in_bytes to module --- vyper/semantics/types/module.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index 7218f4feda..b38043da41 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -307,6 +307,10 @@ def __repr__(self): else: return f"module {self._id} (loaded from '{self._module.resolved_path}')" + @property + def size_in_bytes(self): + return sum(v.typ.size_in_bytes for v in self.variables.values()) + def get_type_member(self, attr: str, node: vy_ast.VyperNode) -> VyperType: return self._helper.get_type_member(attr, node) From 33dffafcc439a9b0c2635ee8bee75a53b3f4c351 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 15 Dec 2023 17:22:02 -0500 Subject: [PATCH 03/27] wip - storage allocator --- vyper/codegen/expr.py | 33 ++++++++---- vyper/semantics/analysis/base.py | 47 +++++++++++------ vyper/semantics/analysis/data_positions.py | 61 +++++++++++++++------- vyper/semantics/analysis/local.py | 6 +++ 4 files changed, 102 insertions(+), 45 deletions(-) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 87e3dcd74d..c44b924fb8 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -36,6 +36,8 @@ VyperException, tag_exceptions, ) +from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.utils import get_expr_info from vyper.semantics.types import ( AddressT, BoolT, @@ -231,8 +233,9 @@ def parse_Attribute(self): else: seq = ["balance", addr] return IRnode.from_list(seq, typ=UINT256_T) + # x.codesize: codesize of address x - elif self.expr.attr == "codesize" or self.expr.attr == "is_contract": + if self.expr.attr == "codesize" or self.expr.attr == "is_contract": addr = Expr.parse_value_expr(self.expr.value, self.context) if addr.typ == AddressT(): if self.expr.attr == "codesize": @@ -246,13 +249,15 @@ def parse_Attribute(self): eval_code = ["gt", ["extcodesize", addr], 0] output_type = BoolT() return IRnode.from_list(eval_code, typ=output_type) + # x.codehash: keccak of address x - elif self.expr.attr == "codehash": + if self.expr.attr == "codehash": addr = Expr.parse_value_expr(self.expr.value, self.context) if addr.typ == AddressT(): return IRnode.from_list(["extcodehash", addr], typ=BYTES32_T) + # x.code: codecopy/extcodecopy of address x - elif self.expr.attr == "code": + if self.expr.attr == "code": addr = Expr.parse_value_expr(self.expr.value, self.context) if addr.typ == AddressT(): # These adhoc nodes will be replaced with a valid node in `Slice.build_IR` @@ -261,7 +266,7 @@ def parse_Attribute(self): return IRnode.from_list(["~extcode", addr], typ=BytesT(0)) # Reserved keywords - elif ( + if ( isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id in ENVIRONMENT_VARIABLES ): key = f"{self.expr.value.id}.{self.expr.attr}" @@ -315,26 +320,32 @@ def parse_Attribute(self): ) return IRnode.from_list(["chainid"], typ=UINT256_T) - # Other variables - - # self.x: module-level variable - if isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id == "self": - varinfo = self.context.globals[self.expr.attr] + # self.x: global storage variable or immutable + if (varinfo := self.expr._metadata.get("variable_access")) is not None: + assert isinstance(varinfo, VarInfo) + # TODO: handle immutables location = TRANSIENT if varinfo.is_transient else STORAGE - ret = IRnode.from_list(varinfo.position.position, typ=varinfo.typ, location=location) + module_t = self.context.module_ctx + ret = IRnode.from_list( + varinfo.position.position, + typ=varinfo.typ, + location=location, + annotation=self.expr.node_source_code, + ) ret._referenced_variables = {varinfo} return ret + # if we have gotten here, it's an instance of an interface or struct sub = Expr(self.expr.value, self.context).ir_node - # interface type if isinstance(sub.typ, InterfaceT): # MyInterface.address assert self.expr.attr == "address" sub.typ = typ return sub + if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: return get_element_ptr(sub, self.expr.attr) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index daebc4cef7..c3ee493cc6 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,5 +1,5 @@ import enum -from dataclasses import dataclass +from dataclasses import dataclass, asdict from typing import TYPE_CHECKING, Dict, List, Optional, Union from vyper import ast as vy_ast @@ -128,7 +128,6 @@ def __repr__(self): class StorageSlot(DataPosition): - __slots__ = ("position",) _location = DataLocation.STORAGE def __init__(self, position): @@ -163,6 +162,9 @@ class ImportInfo(AnalysisResult): input_bundle: InputBundle node: vy_ast.VyperNode + def __eq__(self, other): + return self is other + @dataclass class VarInfo: @@ -189,19 +191,31 @@ def __hash__(self): return hash(id(self)) def __post_init__(self): - self._modification_count = 0 + self._reads = [] + self._writes = [] + self.position = None # the location provided by the allocator def set_position(self, position: DataPosition) -> None: - if hasattr(self, "position"): + if self.position is not None: raise CompilerPanic("Position was already assigned") if self.location != position._location: - if self.location == DataLocation.UNSET: - self.location = position._location - else: - raise CompilerPanic("Incompatible locations") + raise CompilerPanic(f"Incompatible locations: {self.location}, {position._location}") self.position = position +# class for imported variables. this is important for distinguishing +# how in the import graph variables were imported. +@dataclass(kw_only=True) +class ImportedVariable(VarInfo): + import_info: ImportInfo # a reference to how this variable was imported + + @classmethod + def from_varinfo(cls, var_info: Union["ImportedVariable",VarInfo], import_info: ImportInfo): + dict_fields = asdict(var_info) + + dict_fields["import_info"] = import_info + return cls(**dict_fields) + @dataclass class ExprInfo: """ @@ -209,26 +223,26 @@ class ExprInfo: """ typ: VyperType - var_info: Optional[VarInfo] = None location: DataLocation = DataLocation.UNSET is_constant: bool = False is_immutable: bool = False + _var_info: Optional[VarInfo] = None def __post_init__(self): should_match = ("typ", "location", "is_constant", "is_immutable") - if self.var_info is not None: + if self._var_info is not None: for attr in should_match: - if getattr(self.var_info, attr) != getattr(self, attr): + if getattr(self._var_info, attr) != getattr(self, attr): raise CompilerPanic("Bad analysis: non-matching {attr}: {self}") @classmethod def from_varinfo(cls, var_info: VarInfo) -> "ExprInfo": return cls( var_info.typ, - var_info=var_info, location=var_info.location, is_constant=var_info.is_constant, is_immutable=var_info.is_immutable, + _var_info=var_info, ) def copy_with_type(self, typ: VyperType) -> "ExprInfo": @@ -264,12 +278,15 @@ def validate_modification(self, node: vy_ast.VyperNode, mutability: StateMutabil if self.is_immutable: if node.get_ancestor(vy_ast.FunctionDef).get("name") != "__init__": raise ImmutableViolation("Immutable value cannot be written to", node) - # TODO: we probably want to remove this restriction. - if self.var_info._modification_count: # type: ignore + + if len(self._var_info._writes) > 0: raise ImmutableViolation( "Immutable value cannot be modified after assignment", node ) - self.var_info._modification_count += 1 # type: ignore + self._var_info._writes.append(node) + + if self.location == DataLocation.STORAGE: + self._var_info._writes.append(node) if isinstance(node, vy_ast.AugAssign): self.typ.validate_numeric_op(node) diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 88679a4b09..b4d83ba02b 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -6,6 +6,7 @@ from vyper.semantics.analysis.base import CodeOffset, StorageSlot from vyper.typing import StorageLayout from vyper.utils import ceil32 +from vyper.semantics.analysis.base import ImportedVariable def set_data_positions( @@ -111,12 +112,11 @@ def set_storage_slots_with_overrides( ) # Iterate through variables - for node in vyper_module.get_children(vy_ast.VariableDecl): - # Ignore immutable parameters - if node.get("annotation.func.id") == "immutable": + for varinfo in vyper_module._metadata["type"].variables.values(): + if varinfo.is_immutable: continue - varinfo = node.target._metadata["varinfo"] + assert isinstance((node := varinfo.decl_node), vy_ast.VariableDecl) # Expect to find this variable within the storage layout overrides if node.target.id in storage_layout_overrides: @@ -127,6 +127,7 @@ def set_storage_slots_with_overrides( reserved_slots.reserve_slot_range(var_slot, storage_length, node.target.id) varinfo.set_position(StorageSlot(var_slot)) + # TODO: FIXME! ret[node.target.id] = {"type": str(varinfo.typ), "slot": var_slot} else: raise StorageLayoutException( @@ -164,52 +165,58 @@ def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: ret: Dict[str, Dict] = {} - for node in vyper_module.get_children(vy_ast.FunctionDef): - type_ = node._metadata["func_type"] + for funcdef in vyper_module.get_children(vy_ast.FunctionDef): + type_ = funcdef._metadata["func_type"] if type_.nonreentrant is None: continue - variable_name = f"nonreentrant.{type_.nonreentrant}" + keyname = f"nonreentrant.{type_.nonreentrant}" # a nonreentrant key can appear many times in a module but it # only takes one slot. after the first time we see it, do not # increment the storage slot. - if variable_name in ret: - _slot = ret[variable_name]["slot"] + if keyname in ret: + _slot = ret[keyname]["slot"] type_.set_reentrancy_key_position(StorageSlot(_slot)) continue # TODO use one byte - or bit - per reentrancy key # requires either an extra SLOAD or caching the value of the # location in memory at entrance - slot = allocator.allocate_slot(1, variable_name) + slot = allocator.allocate_slot(1, keyname) type_.set_reentrancy_key_position(StorageSlot(slot)) # TODO this could have better typing but leave it untyped until # we nail down the format better - ret[variable_name] = {"type": "nonreentrant lock", "slot": slot} + ret[keyname] = {"type": "nonreentrant lock", "slot": slot} - for node in vyper_module.get_children(vy_ast.VariableDecl): + for varinfo in vyper_module._metadata["type"].variables.values(): # skip non-storage variables - if node.is_constant or node.is_immutable: + if varinfo.is_constant or varinfo.is_immutable: continue - varinfo = node.target._metadata["varinfo"] type_ = varinfo.typ + assert isinstance((vardecl := varinfo.decl_node), vy_ast.VariableDecl) + + varname = vardecl.target.id + # CMC 2021-07-23 note that HashMaps get assigned a slot here. # I'm not sure if it's safe to avoid allocating that slot # for HashMaps because downstream code might use the slot # ID as a salt. n_slots = type_.storage_size_in_words - slot = allocator.allocate_slot(n_slots, node.target.id) + slot = allocator.allocate_slot(n_slots, varname) varinfo.set_position(StorageSlot(slot)) + if isinstance(varinfo, ImportedVariable): + varname = varinfo.import_info.qualified_module_name + "." + varname + assert varname not in ret # this could have better typing but leave it untyped until # we understand the use case better - ret[node.target.id] = {"type": str(type_), "slot": slot} + ret[varname] = {"type": str(type_), "slot": slot} return ret @@ -226,8 +233,10 @@ def set_code_offsets(vyper_module: vy_ast.Module) -> Dict: ret = {} offset = 0 - for node in vyper_module.get_children(vy_ast.VariableDecl, filters={"is_immutable": True}): - varinfo = node.target._metadata["varinfo"] + for varinfo in vyper_module._metadata["type"].variables.values(): + if not varinfo.is_immutable: + continue + type_ = varinfo.typ varinfo.set_position(CodeOffset(offset)) @@ -235,7 +244,21 @@ def set_code_offsets(vyper_module: vy_ast.Module) -> Dict: # this could have better typing but leave it untyped until # we understand the use case better - ret[node.target.id] = {"type": str(type_), "offset": offset, "length": len_} + output_dict = {"type": str(type_), "offset": offset, "length": len_} + + # put it into the storage layout + + # sanity check. there are ways to construct varinfo with no + # decl_node but they shouldn't make it here + assert isinstance(varinfo.decl_node, vy_ast.VariableDecl) + name = varinfo.decl_node.target.id + + # XXX: FIXME + if isinstance(varinfo, ImportedVariable): + name = varinfo.import_info.qualified_module_name + "." + name + + assert name not in ret + ret[name] = output_dict offset += len_ diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 974c14f261..a11a58a99a 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -608,6 +608,12 @@ def visit(self, node, typ): # annotate node._metadata["type"] = typ + # tag variable accesses + info = get_expr_info(node) + if (var_info := info._var_info) is not None: + node._metadata["variable_access"] = var_info + var_info._reads.append(node) + def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_data_attribute(node) From 654256bccd98d93535f485c89062e848c063fc8a Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 20 Dec 2023 09:08:54 -0500 Subject: [PATCH 04/27] remove ImportedVariable thing --- vyper/semantics/analysis/base.py | 13 ------------- vyper/semantics/analysis/data_positions.py | 7 ------- 2 files changed, 20 deletions(-) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index c3ee493cc6..b5e10f0a95 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -203,19 +203,6 @@ def set_position(self, position: DataPosition) -> None: self.position = position -# class for imported variables. this is important for distinguishing -# how in the import graph variables were imported. -@dataclass(kw_only=True) -class ImportedVariable(VarInfo): - import_info: ImportInfo # a reference to how this variable was imported - - @classmethod - def from_varinfo(cls, var_info: Union["ImportedVariable",VarInfo], import_info: ImportInfo): - dict_fields = asdict(var_info) - - dict_fields["import_info"] = import_info - return cls(**dict_fields) - @dataclass class ExprInfo: """ diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index b4d83ba02b..60f2bc90d2 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -6,7 +6,6 @@ from vyper.semantics.analysis.base import CodeOffset, StorageSlot from vyper.typing import StorageLayout from vyper.utils import ceil32 -from vyper.semantics.analysis.base import ImportedVariable def set_data_positions( @@ -211,8 +210,6 @@ def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: varinfo.set_position(StorageSlot(slot)) - if isinstance(varinfo, ImportedVariable): - varname = varinfo.import_info.qualified_module_name + "." + varname assert varname not in ret # this could have better typing but leave it untyped until # we understand the use case better @@ -253,10 +250,6 @@ def set_code_offsets(vyper_module: vy_ast.Module) -> Dict: assert isinstance(varinfo.decl_node, vy_ast.VariableDecl) name = varinfo.decl_node.target.id - # XXX: FIXME - if isinstance(varinfo, ImportedVariable): - name = varinfo.import_info.qualified_module_name + "." + name - assert name not in ret ret[name] = output_dict From bcc03d6112e61779957d63105e26744d4aa72650 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 20 Dec 2023 18:40:23 -0500 Subject: [PATCH 05/27] wip - add get_element_ptr for module, fix some logic in Expr.parse_Attribute --- vyper/codegen/context.py | 1 + vyper/codegen/core.py | 33 +++++++++++++++++++++++++++++++ vyper/codegen/expr.py | 33 +++++++++++++++++-------------- vyper/exceptions.py | 4 ++-- vyper/semantics/types/__init__.py | 2 +- 5 files changed, 55 insertions(+), 18 deletions(-) diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index dea30faabc..5d1ab0b8c1 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -76,6 +76,7 @@ def __init__( self.in_range_expr = False # store module context + # note the module_ctx is the type of the current compilation target! self.module_ctx = module_ctx # full function type diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index e1d3ea12b4..4f1eb5eaf6 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -12,6 +12,7 @@ BoolT, BytesM_T, BytesT, + ModuleT, DArrayT, DecimalT, HashMapT, @@ -441,6 +442,35 @@ def _getelemptr_abi_helper(parent, member_t, ofst, clamp=True): annotation=f"{parent}{ofst}", ) +# get a variable out of a module +def _get_element_ptr_module(parent, key): + # note that this implementation is substantially similar to + # the StructT pathway through get_element_ptr_tuplelike and + # has potential to be refactored. + module_t = parent.typ + assert isinstance(module_t, ModuleT) + + assert isinstance(key, str) + typ = module_t.variables[key].typ + attrs = list(module_t.variables.keys()) + index = attrs.index(key) + annotation = key + + ofst = 0 # offset from parent start + + assert parent.location == STORAGE, parent.location + + for i in range(index): + ofst += module_t.variables[attrs[i]].typ.storage_size_in_words + + return IRnode.from_list( + add_ofst(parent, ofst), + typ=typ, + location=parent.location, + encoding=parent.encoding, + annotation=annotation, + ) + # TODO simplify this code, especially the ABI decoding def _get_element_ptr_tuplelike(parent, key): @@ -590,6 +620,9 @@ def get_element_ptr(parent, key, array_bounds_check=True): if is_tuple_like(typ): ret = _get_element_ptr_tuplelike(parent, key) + elif isinstance(typ, ModuleT): + ret = _get_element_ptr_module(parent, key) + elif isinstance(typ, HashMapT): ret = _get_element_ptr_mapping(parent, key) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index c44b924fb8..2a9296b90c 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -171,7 +171,8 @@ def parse_NameConstant(self): # Variable names def parse_Name(self): if self.expr.id == "self": - return IRnode.from_list(["address"], typ=AddressT()) + # TODO: have `self` return a module type + return IRnode.from_list(["self"], typ=AddressT()) elif self.expr.id in self.context.vars: var = self.context.vars[self.expr.id] ret = IRnode.from_list( @@ -221,7 +222,7 @@ def parse_Attribute(self): return IRnode.from_list(value, typ=typ) # x.balance: balance of address x - if self.expr.attr == "balance": + elif self.expr.attr == "balance": addr = Expr.parse_value_expr(self.expr.value, self.context) if addr.typ == AddressT(): if ( @@ -235,7 +236,7 @@ def parse_Attribute(self): return IRnode.from_list(seq, typ=UINT256_T) # x.codesize: codesize of address x - if self.expr.attr == "codesize" or self.expr.attr == "is_contract": + elif self.expr.attr == "codesize" or self.expr.attr == "is_contract": addr = Expr.parse_value_expr(self.expr.value, self.context) if addr.typ == AddressT(): if self.expr.attr == "codesize": @@ -251,13 +252,13 @@ def parse_Attribute(self): return IRnode.from_list(eval_code, typ=output_type) # x.codehash: keccak of address x - if self.expr.attr == "codehash": + elif self.expr.attr == "codehash": addr = Expr.parse_value_expr(self.expr.value, self.context) if addr.typ == AddressT(): return IRnode.from_list(["extcodehash", addr], typ=BYTES32_T) # x.code: codecopy/extcodecopy of address x - if self.expr.attr == "code": + elif self.expr.attr == "code": addr = Expr.parse_value_expr(self.expr.value, self.context) if addr.typ == AddressT(): # These adhoc nodes will be replaced with a valid node in `Slice.build_IR` @@ -266,7 +267,8 @@ def parse_Attribute(self): return IRnode.from_list(["~extcode", addr], typ=BytesT(0)) # Reserved keywords - if ( + elif ( + # TODO: use type information here isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id in ENVIRONMENT_VARIABLES ): key = f"{self.expr.value.id}.{self.expr.attr}" @@ -321,20 +323,21 @@ def parse_Attribute(self): return IRnode.from_list(["chainid"], typ=UINT256_T) # self.x: global storage variable or immutable - if (varinfo := self.expr._metadata.get("variable_access")) is not None: + elif (varinfo := self.expr._metadata.get("variable_access")) is not None: assert isinstance(varinfo, VarInfo) + # TODO: handle immutables location = TRANSIENT if varinfo.is_transient else STORAGE - module_t = self.context.module_ctx - ret = IRnode.from_list( - varinfo.position.position, - typ=varinfo.typ, - location=location, - annotation=self.expr.node_source_code, - ) - ret._referenced_variables = {varinfo} + module_ptr = Expr(self.expr.value, self.context).ir_node + + global_t = self.context.module_ctx + if module_ptr.value == "self": + # TODO: self.context.self_ptr + module_ptr = IRnode.from_list(0, typ=global_t, location=STORAGE) + ret = get_element_ptr(module_ptr, self.expr.attr) + ret._referenced_variables = {varinfo} return ret # if we have gotten here, it's an instance of an interface or struct diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 4846b1c3b1..2cd2c6d167 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -366,7 +366,7 @@ def tag_exceptions( yield except _BaseVyperException as e: if not e.annotations and not e.lineno: - raise e.with_annotation(node) from None - raise e from None + raise e.with_annotation(node) + raise e except Exception as e: raise fallback_exception_type(fallback_message, node) from e diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index 1fef6a706e..2ae4dd8454 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -2,7 +2,7 @@ from .base import TYPE_T, KwargSettings, VyperType, is_type_t from .bytestrings import BytesT, StringT, _BytestringT from .function import MemberFunctionT -from .module import InterfaceT +from .module import InterfaceT, ModuleT from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT from .subscriptable import DArrayT, HashMapT, SArrayT, TupleT from .user import EnumT, EventT, StructT From e9b867aab03ca4c4582d3f1537081d0447d5575e Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 20 Dec 2023 21:07:05 -0500 Subject: [PATCH 06/27] call set_data_positions recursively --- vyper/codegen/core.py | 3 ++- vyper/codegen/expr.py | 6 +++--- vyper/compiler/phases.py | 27 +++++++++------------------ vyper/semantics/analysis/base.py | 2 +- vyper/semantics/analysis/module.py | 18 ++++++++++++++++-- vyper/semantics/types/module.py | 2 +- 6 files changed, 32 insertions(+), 26 deletions(-) diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 4f1eb5eaf6..2e0fc0192c 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -12,12 +12,12 @@ BoolT, BytesM_T, BytesT, - ModuleT, DArrayT, DecimalT, HashMapT, IntegerT, InterfaceT, + ModuleT, StructT, TupleT, _BytestringT, @@ -442,6 +442,7 @@ def _getelemptr_abi_helper(parent, member_t, ofst, clamp=True): annotation=f"{parent}{ofst}", ) + # get a variable out of a module def _get_element_ptr_module(parent, key): # note that this implementation is substantially similar to diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 2a9296b90c..6d798d995e 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -37,7 +37,6 @@ tag_exceptions, ) from vyper.semantics.analysis.base import VarInfo -from vyper.semantics.analysis.utils import get_expr_info from vyper.semantics.types import ( AddressT, BoolT, @@ -269,7 +268,8 @@ def parse_Attribute(self): # Reserved keywords elif ( # TODO: use type information here - isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id in ENVIRONMENT_VARIABLES + isinstance(self.expr.value, vy_ast.Name) + and self.expr.value.id in ENVIRONMENT_VARIABLES ): key = f"{self.expr.value.id}.{self.expr.attr}" if key == "msg.sender": @@ -334,7 +334,7 @@ def parse_Attribute(self): global_t = self.context.module_ctx if module_ptr.value == "self": # TODO: self.context.self_ptr - module_ptr = IRnode.from_list(0, typ=global_t, location=STORAGE) + module_ptr = IRnode.from_list(0, typ=global_t, location=location) ret = get_element_ptr(module_ptr, self.expr.attr) ret._referenced_variables = {varinfo} diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index edffa9a85e..ab4e910d3e 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -2,7 +2,7 @@ import warnings from functools import cached_property from pathlib import Path, PurePath -from typing import Optional, Tuple +from typing import Optional from vyper import ast as vy_ast from vyper.codegen import module @@ -12,7 +12,7 @@ from vyper.compiler.settings import OptimizationLevel, Settings from vyper.exceptions import StructureException from vyper.ir import compile_ir, optimizer -from vyper.semantics import set_data_positions, validate_semantics +from vyper.semantics import validate_semantics from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.typing import StorageLayout @@ -161,21 +161,16 @@ def vyper_module_unfolded(self) -> vy_ast.Module: return generate_unfolded_ast(self.vyper_module, self.input_bundle) - @cached_property - def _folded_module(self): + @property + def vyper_module_folded(self) -> vy_ast.Module: return generate_folded_ast( self.vyper_module, self.input_bundle, self.storage_layout_override ) - @property - def vyper_module_folded(self) -> vy_ast.Module: - module, storage_layout = self._folded_module - return module - @property def storage_layout(self) -> StorageLayout: - module, storage_layout = self._folded_module - return storage_layout + module = self.vyper_module_folded + return module._metadata["variables_layout"] @property def global_ctx(self) -> ModuleT: @@ -264,7 +259,7 @@ def generate_folded_ast( vyper_module: vy_ast.Module, input_bundle: InputBundle, storage_layout_overrides: StorageLayout = None, -) -> Tuple[vy_ast.Module, StorageLayout]: +) -> vy_ast.Module: """ Perform constant folding operations on the Vyper AST. @@ -277,8 +272,6 @@ def generate_folded_ast( ------- vy_ast.Module Folded Vyper AST - StorageLayout - Layout of variables in storage """ vy_ast.validation.validate_literal_nodes(vyper_module) @@ -287,11 +280,9 @@ def generate_folded_ast( vy_ast.folding.fold(vyper_module_folded) with input_bundle.search_path(Path(vyper_module.resolved_path).parent): - validate_semantics(vyper_module_folded, input_bundle) - - symbol_tables = set_data_positions(vyper_module_folded, storage_layout_overrides) + validate_semantics(vyper_module_folded, input_bundle, storage_layout_overrides) - return vyper_module_folded, symbol_tables + return vyper_module_folded def generate_ir_nodes(global_ctx: ModuleT, optimize: OptimizationLevel) -> tuple[IRnode, IRnode]: diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index b5e10f0a95..9f3414f4be 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,5 +1,5 @@ import enum -from dataclasses import dataclass, asdict +from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Union from vyper import ast as vy_ast diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 99d105f2cb..c87a049f47 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -22,6 +22,7 @@ ) from vyper.semantics.analysis.base import ImportInfo, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase +from vyper.semantics.analysis.data_positions import set_data_positions from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import validate_functions from vyper.semantics.analysis.utils import ( @@ -37,8 +38,17 @@ from vyper.semantics.types.utils import type_from_annotation -def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT: - return validate_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface) +# TODO: rename to `analyze_vyper` +def validate_semantics( + module_ast, input_bundle, storage_layout_overrides=None, is_interface=False +) -> ModuleT: + return validate_semantics_r( + module_ast, + input_bundle, + ImportGraph(), + is_interface, + storage_layout_overrides=storage_layout_overrides, + ) def validate_semantics_r( @@ -46,6 +56,7 @@ def validate_semantics_r( input_bundle: InputBundle, import_graph: ImportGraph, is_interface: bool, + storage_layout_overrides: Any = None, ) -> ModuleT: """ Analyze a Vyper module AST node, add all module-level objects to the @@ -65,6 +76,9 @@ def validate_semantics_r( if not is_interface: validate_functions(module_ast) + layout = set_data_positions(module_ast, storage_layout_overrides) + module_ast._metadata["variables_layout"] = layout + return ret diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index b38043da41..7e841fde7a 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -1,13 +1,13 @@ from functools import cached_property from typing import Optional -from vyper.semantics.data_locations import DataLocation from vyper import ast as vy_ast from vyper.abi_types import ABI_Address, ABIType from vyper.ast.validation import validate_call_args from vyper.exceptions import InterfaceViolation, NamespaceCollision, StructureException from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids +from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType from vyper.semantics.types.function import ContractFunctionT From 037d5a66ad12cdace387bad79564e40c4972b71a Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 20 Dec 2023 21:07:50 -0500 Subject: [PATCH 07/27] add a sanity check --- vyper/codegen/core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 2e0fc0192c..6a57b7ea9b 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -464,6 +464,9 @@ def _get_element_ptr_module(parent, key): for i in range(index): ofst += module_t.variables[attrs[i]].typ.storage_size_in_words + # calculated the same way both ways + assert ofst == module_t.variables[key].position.position + return IRnode.from_list( add_ofst(parent, ofst), typ=typ, From 3512e3ff3829b21ed5f936009630a45da32538e4 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 21 Dec 2023 10:54:38 -0500 Subject: [PATCH 08/27] rename some size calculators and add immutable_bytes_required to VyperType storage_size_in_words -> storage_slots_required immutable_section_bytes -> immutable_bytes_required this establishes a convention to find the size of a type: `typ.{location}_{slots_or_bytes}_required` this maybe indicates that a better API will be `size_in(location: DataLocation)` but we will see how the design evolves --- .../codegen/types/test_node_types.py | 15 ++++++----- .../semantics/types/test_size_in_bytes.py | 10 +++---- vyper/codegen/core.py | 6 ++--- vyper/codegen/module.py | 2 +- vyper/semantics/analysis/data_positions.py | 8 +++--- vyper/semantics/types/base.py | 26 ++++++++++++++++--- vyper/semantics/types/bytestrings.py | 2 +- vyper/semantics/types/module.py | 16 +++++++----- vyper/semantics/types/subscriptable.py | 14 +++++----- vyper/semantics/types/user.py | 4 +-- 10 files changed, 62 insertions(+), 41 deletions(-) diff --git a/tests/functional/codegen/types/test_node_types.py b/tests/functional/codegen/types/test_node_types.py index b6561ae8eb..96f5cf8650 100644 --- a/tests/functional/codegen/types/test_node_types.py +++ b/tests/functional/codegen/types/test_node_types.py @@ -12,6 +12,7 @@ ) # TODO: this module should be merged in with other tests/functional/semantics/types/ tests. +# and moved to tests/unit/! def test_bytearray_node_type(): @@ -51,17 +52,17 @@ def test_canonicalize_type(): def test_type_storage_sizes(): - assert IntegerT(True, 128).storage_size_in_words == 1 - assert BytesT(12).storage_size_in_words == 2 - assert BytesT(33).storage_size_in_words == 3 - assert SArrayT(IntegerT(True, 128), 10).storage_size_in_words == 10 + assert IntegerT(True, 128).storage_slots_required == 1 + assert BytesT(12).storage_slots_required == 2 + assert BytesT(33).storage_slots_required == 3 + assert SArrayT(IntegerT(True, 128), 10).storage_slots_required == 10 tuple_ = TupleT([IntegerT(True, 128), DecimalT()]) - assert tuple_.storage_size_in_words == 2 + assert tuple_.storage_slots_required == 2 struct_ = StructT("Foo", {"a": IntegerT(True, 128), "b": DecimalT()}) - assert struct_.storage_size_in_words == 2 + assert struct_.storage_slots_required == 2 # Don't allow unknown types. with raises(Exception): - _ = int.storage_size_in_words + _ = int.storage_slots_required diff --git a/tests/unit/semantics/types/test_size_in_bytes.py b/tests/unit/semantics/types/test_size_in_bytes.py index 69250fdfdf..244c52b3e0 100644 --- a/tests/unit/semantics/types/test_size_in_bytes.py +++ b/tests/unit/semantics/types/test_size_in_bytes.py @@ -11,7 +11,7 @@ def test_base_types(build_node, type_str): node = build_node(type_str) type_definition = type_from_annotation(node) - assert type_definition.size_in_bytes == 32 + assert type_definition._size_in_bytes == 32 @pytest.mark.parametrize("type_str", BYTESTRING_TYPES) @@ -20,7 +20,7 @@ def test_array_value_types(build_node, type_str, length, size): node = build_node(f"{type_str}[{length}]") type_definition = type_from_annotation(node) - assert type_definition.size_in_bytes == size + assert type_definition._size_in_bytes == size @pytest.mark.parametrize("type_str", BASE_TYPES) @@ -29,7 +29,7 @@ def test_dynamic_array_lengths(build_node, type_str, length): node = build_node(f"DynArray[{type_str}, {length}]") type_definition = type_from_annotation(node) - assert type_definition.size_in_bytes == 32 + length * 32 + assert type_definition._size_in_bytes == 32 + length * 32 @pytest.mark.parametrize("type_str", BASE_TYPES) @@ -38,7 +38,7 @@ def test_base_types_as_arrays(build_node, type_str, length): node = build_node(f"{type_str}[{length}]") type_definition = type_from_annotation(node) - assert type_definition.size_in_bytes == length * 32 + assert type_definition._size_in_bytes == length * 32 @pytest.mark.parametrize("type_str", BASE_TYPES) @@ -49,4 +49,4 @@ def test_base_types_as_multidimensional_arrays(build_node, type_str, first, seco type_definition = type_from_annotation(node) - assert type_definition.size_in_bytes == first * second * 32 + assert type_definition._size_in_bytes == first * second * 32 diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 6a57b7ea9b..bc582e0bc4 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -462,7 +462,7 @@ def _get_element_ptr_module(parent, key): assert parent.location == STORAGE, parent.location for i in range(index): - ofst += module_t.variables[attrs[i]].typ.storage_size_in_words + ofst += module_t.variables[attrs[i]].typ.storage_slots_required # calculated the same way both ways assert ofst == module_t.variables[key].position.position @@ -519,7 +519,7 @@ def _get_element_ptr_tuplelike(parent, key): if parent.location.word_addressable: for i in range(index): - ofst += typ.member_types[attrs[i]].storage_size_in_words + ofst += typ.member_types[attrs[i]].storage_slots_required elif parent.location.byte_addressable: for i in range(index): ofst += typ.member_types[attrs[i]].memory_bytes_required @@ -586,7 +586,7 @@ def _get_element_ptr_array(parent, key, array_bounds_check): return _getelemptr_abi_helper(parent, subtype, ofst) if parent.location.word_addressable: - element_size = subtype.storage_size_in_words + element_size = subtype.storage_slots_required elif parent.location.byte_addressable: element_size = subtype.memory_bytes_required else: diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index ef861e3953..91bd393c6c 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -472,7 +472,7 @@ def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: runtime.extend(internal_functions_ir) deploy_code: List[Any] = ["seq"] - immutables_len = module_ctx.immutable_section_bytes + immutables_len = module_ctx.immutable_bytes_required if init_function: # cleanly rerun codegen for internal functions with `is_ctor_ctx=True` init_func_t = init_function._metadata["func_type"] diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 60f2bc90d2..9ae632ff11 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -120,7 +120,7 @@ def set_storage_slots_with_overrides( # Expect to find this variable within the storage layout overrides if node.target.id in storage_layout_overrides: var_slot = storage_layout_overrides[node.target.id]["slot"] - storage_length = varinfo.typ.storage_size_in_words + storage_length = varinfo.typ.storage_slots_required # Ensure that all required storage slots are reserved, and prevents other variables # from using these slots reserved_slots.reserve_slot_range(var_slot, storage_length, node.target.id) @@ -205,7 +205,7 @@ def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: # I'm not sure if it's safe to avoid allocating that slot # for HashMaps because downstream code might use the slot # ID as a salt. - n_slots = type_.storage_size_in_words + n_slots = type_.storage_slots_required slot = allocator.allocate_slot(n_slots, varname) varinfo.set_position(StorageSlot(slot)) @@ -234,11 +234,11 @@ def set_code_offsets(vyper_module: vy_ast.Module) -> Dict: if not varinfo.is_immutable: continue + len_ = ceil32(type_.immutable_bytes_required) + type_ = varinfo.typ varinfo.set_position(CodeOffset(offset)) - len_ = ceil32(type_.size_in_bytes) - # this could have better typing but leave it untyped until # we understand the use case better output_dict = {"type": str(type_), "offset": offset, "length": len_} diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index d22d9bfff9..2c1caa0351 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -13,6 +13,7 @@ UnknownAttribute, ) from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions +from vyper.semantics.data_locations import DataLocation # Some fake type with an overridden `compare_type` which accepts any RHS @@ -68,7 +69,7 @@ class VyperType: _supports_external_calls: bool = False _attribute_in_annotation: bool = False - size_in_bytes = 32 # default; override for larger types + _size_in_bytes = 32 # default; override for larger types def __init__(self, members: Optional[Dict] = None) -> None: self.members: Dict = {} @@ -120,22 +121,39 @@ def abi_type(self) -> ABIType: @property def memory_bytes_required(self) -> int: + if DataLocation.MEMORY in self._invalid_locations: + raise CompilerPanic(f"{self} cannot be instantiated in memory!") # alias for API compatibility with codegen - return self.size_in_bytes + return self._size_in_bytes @property - def storage_size_in_words(self) -> int: + def storage_slots_required(self) -> int: # consider renaming if other word-addressable address spaces are # added to EVM or exist in other arches """ Returns the number of words required to allocate in storage for this type """ - r = self.memory_bytes_required + if DataLocation.STORAGE in self._invalid_locations: + raise CompilerPanic(f"{self} cannot be instantiated in storage!") + + r = self._size_in_bytes if r % 32 != 0: raise CompilerPanic("Memory bytes must be multiple of 32") return r // 32 + @property + def immutable_bytes_required(self) -> int: + """ + Returns the number of bytes required when instantiating this type + in the immutables section + """ + # sanity check the type can actually be instantiated as an immutable + if DataLocation.CODE in self._invalid_locations: + raise CompilerPanic(f"{self} cannot be an immutable!") + + return self._size_in_bytes + @property def canonical_abi_type(self) -> str: """ diff --git a/vyper/semantics/types/bytestrings.py b/vyper/semantics/types/bytestrings.py index e3c381ac69..dee7a8e9be 100644 --- a/vyper/semantics/types/bytestrings.py +++ b/vyper/semantics/types/bytestrings.py @@ -67,7 +67,7 @@ def validate_literal(self, node: vy_ast.Constant) -> None: raise CompilerPanic("unreachable") @property - def size_in_bytes(self): + def _size_in_bytes(self): # the first slot (32 bytes) stores the actual length, and then we reserve # enough additional slots to store the data if it uses the max available length # because this data type is single-bytes, we make it so it takes the max 32 byte diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index 7e841fde7a..7172e26314 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -307,10 +307,6 @@ def __repr__(self): else: return f"module {self._id} (loaded from '{self._module.resolved_path}')" - @property - def size_in_bytes(self): - return sum(v.typ.size_in_bytes for v in self.variables.values()) - def get_type_member(self, attr: str, node: vy_ast.VyperNode) -> VyperType: return self._helper.get_type_member(attr, node) @@ -341,13 +337,21 @@ def variables(self): # `x: uint256` is a private storage variable named x return {s.target.id: s.target._metadata["varinfo"] for s in self.variable_decls} + @cached_property + def storage_slots_required(self): + return sum(v.typ.storage_slots_required for v in self.variables.values()) + @cached_property def immutables(self): return [t for t in self.variables.values() if t.is_immutable] @cached_property - def immutable_section_bytes(self): - return sum([imm.typ.memory_bytes_required for imm in self.immutables]) + def immutable_bytes_required(self): + # note: super().immutable_bytes_required checks that + # `DataLocations.CODE not in self._invalid_locations`; this is ok because + # ModuleT is a bit of a hybrid - it can't be declared as an immutable, but + # it can have immutable members. + return sum(imm.typ.immutable_bytes_required for imm in self.immutables) @cached_property def interface(self): diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 46dffbdec4..e8129eaa2f 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -174,10 +174,9 @@ def to_abi_arg(self, name: str = "") -> Dict[str, Any]: ret["type"] += f"[{self.length}]" return _set_first_key(ret, "name", name) - # TODO rename to `memory_bytes_required` @property - def size_in_bytes(self): - return self.value_type.size_in_bytes * self.length + def _size_in_bytes(self): + return self.value_type._size_in_bytes * self.length @property def subtype(self): @@ -257,11 +256,10 @@ def to_abi_arg(self, name: str = "") -> Dict[str, Any]: ret["type"] += "[]" return _set_first_key(ret, "name", name) - # TODO rename me to memory_bytes_required @property - def size_in_bytes(self): + def _size_in_bytes(self): # one length word + size of the array items - return 32 + self.value_type.size_in_bytes * self.length + return 32 + self.value_type._size_in_bytes * self.length def compare_type(self, other): # TODO allow static array to be assigned to dyn array? @@ -355,8 +353,8 @@ def to_abi_arg(self, name: str = "") -> dict: return {"name": name, "type": "tuple", "components": components} @property - def size_in_bytes(self): - return sum(i.size_in_bytes for i in self.member_types) + def _size_in_bytes(self): + return sum(i._size_in_bytes for i in self.member_types) def validate_index_type(self, node): if not isinstance(node, vy_ast.Int): diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index 92639e8377..b4ba417167 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -359,8 +359,8 @@ def __repr__(self): return f"struct {self._id}({arg_types})" @property - def size_in_bytes(self): - return sum(i.size_in_bytes for i in self.member_types.values()) + def _size_in_bytes(self): + return sum(i._size_in_bytes for i in self.member_types.values()) @property def abi_type(self) -> ABIType: From 86c299a275e396294eafe2d897e45b15a7e3b918 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 21 Dec 2023 12:04:06 -0500 Subject: [PATCH 09/27] add a comment --- vyper/semantics/types/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 2c1caa0351..4b8bbc3094 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -69,6 +69,10 @@ class VyperType: _supports_external_calls: bool = False _attribute_in_annotation: bool = False + # _size_in_bytes is an internal property that is used + # to calculate sizes required in various locations. it can + # be used by subclasses, but does not have to be. it should + # *not* by used by external consumers of VyperType! _size_in_bytes = 32 # default; override for larger types def __init__(self, members: Optional[Dict] = None) -> None: From 9e689b9bf7aa88acfce39550b7d18252699a32a3 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 21 Dec 2023 12:55:14 -0500 Subject: [PATCH 10/27] add Context.self_ptr helper - rename module_ctx to compilation_target --- vyper/codegen/context.py | 29 ++++++++++---- vyper/codegen/expr.py | 5 +-- vyper/codegen/function_definitions/common.py | 4 +- vyper/codegen/module.py | 40 ++++++++++---------- 4 files changed, 45 insertions(+), 33 deletions(-) diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 5d1ab0b8c1..290a5c7e01 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -6,7 +6,7 @@ from vyper.codegen.ir_node import Encoding from vyper.evm.address_space import MEMORY, AddrSpace from vyper.exceptions import CompilerPanic, StateAccessViolation -from vyper.semantics.types import VyperType +from vyper.semantics.types import VyperType, ModuleT class Constancy(enum.Enum): @@ -48,7 +48,7 @@ def __repr__(self): class Context: def __init__( self, - module_ctx, + compilation_target, memory_allocator, vars_=None, forvars=None, @@ -59,9 +59,6 @@ def __init__( # In-memory variables, in the form (name, memory location, type) self.vars = vars_ or {} - # Global variables, in the form (name, storage location, type) - self.globals = module_ctx.variables - # Variables defined in for loops, e.g. for i in range(6): ... self.forvars = forvars or {} @@ -75,9 +72,8 @@ def __init__( # Whether we are currently parsing a range expression self.in_range_expr = False - # store module context - # note the module_ctx is the type of the current compilation target! - self.module_ctx = module_ctx + # the type information for the current compilation target + self.compilation_target: ModuleT = compilation_target # full function type self.func_t = func_t @@ -95,6 +91,23 @@ def __init__( # either the constructor, or called from the constructor self.is_ctor_context = is_ctor_context + def self_ptr(self): + func_module = self.func_t.ast_def._parent + assert isinstance(func_module, vy_ast.Module) + + module_t = func_module._metadata["type"] + module_is_compilation_target = (module_t == self.compilation_target) + + if module_is_compilation_target: + # return 0 for the special case where compilation target is self + return IRnode.from_list(0, typ=module_t) + + # otherwise, the function compilation context takes a `self_ptr` + # argument in the calling convention + # TODO: probably need to track immutables and storage variables + # separately + return IRnode.from_list("self_ptr", typ=module_t) + def is_constant(self): return self.constancy is Constancy.Constant or self.in_assertion or self.in_range_expr diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 6d798d995e..858c9d687a 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -330,11 +330,8 @@ def parse_Attribute(self): location = TRANSIENT if varinfo.is_transient else STORAGE module_ptr = Expr(self.expr.value, self.context).ir_node - - global_t = self.context.module_ctx if module_ptr.value == "self": - # TODO: self.context.self_ptr - module_ptr = IRnode.from_list(0, typ=global_t, location=location) + module_ptr = self.context.self_ptr ret = get_element_ptr(module_ptr, self.expr.attr) ret._referenced_variables = {varinfo} diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 454ba9c8cd..f8283c8539 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -101,7 +101,7 @@ class InternalFuncIR(FuncIR): # TODO: should split this into external and internal ir generation? def generate_ir_for_function( - code: vy_ast.FunctionDef, module_ctx: ModuleT, is_ctor_context: bool = False + code: vy_ast.FunctionDef, compilation_target: ModuleT, is_ctor_context: bool = False ) -> FuncIR: """ Parse a function and produce IR code for the function, includes: @@ -133,7 +133,7 @@ def generate_ir_for_function( context = Context( vars_=None, - module_ctx=module_ctx, + compilation_target=compilation_target, memory_allocator=memory_allocator, constancy=Constancy.Mutable if func_t.is_mutable else Constancy.Constant, func_t=func_t, diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 91bd393c6c..15f4f8f1cc 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -15,6 +15,8 @@ def _topsort(functions): # single pass to get a global topological sort of functions (so that each # function comes after each of its callees). + # note this function can return functions from other modules if + # they are reachable! ret = OrderedSet() for func_ast in functions: fn_t = func_ast._metadata["func_type"] @@ -104,12 +106,12 @@ def _ir_for_internal_function(func_ast, *args, **kwargs): return generate_ir_for_function(func_ast, *args, **kwargs).func_ir -def _generate_external_entry_points(external_functions, module_ctx): +def _generate_external_entry_points(external_functions, compilation_target): entry_points = {} # map from ABI sigs to ir code sig_of = {} # reverse map from method ids to abi sig for code in external_functions: - func_ir = generate_ir_for_function(code, module_ctx) + func_ir = generate_ir_for_function(code, compilation_target) for abi_sig, entry_point in func_ir.entry_points.items(): method_id = method_id_int(abi_sig) assert abi_sig not in entry_points @@ -131,13 +133,13 @@ def _generate_external_entry_points(external_functions, module_ctx): # into a bucket (of about 8-10 items), and then uses perfect hash # to select the final function. # costs about 212 gas for typical function and 8 bytes of code (+ ~87 bytes of global overhead) -def _selector_section_dense(external_functions, module_ctx): +def _selector_section_dense(external_functions, compilation_target): function_irs = [] if len(external_functions) == 0: return IRnode.from_list(["seq"]) - entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) + entry_points, sig_of = _generate_external_entry_points(external_functions, compilation_target) # generate the label so the jumptable works for abi_sig, entry_point in entry_points.items(): @@ -282,13 +284,13 @@ def _selector_section_dense(external_functions, module_ctx): # a bucket, and then descends into linear search from there. # costs about 126 gas for typical (nonpayable, >0 args, avg bucket size 1.5) # function and 24 bytes of code (+ ~23 bytes of global overhead) -def _selector_section_sparse(external_functions, module_ctx): +def _selector_section_sparse(external_functions, compilation_target): ret = ["seq"] if len(external_functions) == 0: return ret - entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) + entry_points, sig_of = _generate_external_entry_points(external_functions, compilation_target) n_buckets, buckets = jumptable_utils.generate_sparse_jumptable_buckets(entry_points.keys()) @@ -385,14 +387,14 @@ def _selector_section_sparse(external_functions, module_ctx): # O(n) linear search for the method id # mainly keep this in for backends which cannot handle the indirect jump # in selector_section_dense and selector_section_sparse -def _selector_section_linear(external_functions, module_ctx): +def _selector_section_linear(external_functions, compilation_target): ret = ["seq"] if len(external_functions) == 0: return ret ret.append(["if", ["lt", "calldatasize", 4], ["goto", "fallback"]]) - entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) + entry_points, sig_of = _generate_external_entry_points(external_functions, compilation_target) dispatcher = ["seq"] @@ -421,10 +423,10 @@ def _selector_section_linear(external_functions, module_ctx): # take a ModuleT, and generate the runtime and deploy IR -def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: +def generate_ir_for_module(compilation_target: ModuleT) -> tuple[IRnode, IRnode]: # order functions so that each function comes after all of its callees - function_defs = _topsort(module_ctx.function_defs) - reachable = _globally_reachable_functions(module_ctx.function_defs) + function_defs = _topsort(compilation_target.function_defs) + reachable = _globally_reachable_functions(compilation_target.function_defs) runtime_functions = [f for f in function_defs if not _is_constructor(f)] init_function = next((f for f in function_defs if _is_constructor(f)), None) @@ -442,7 +444,7 @@ def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: for func_ast in internal_functions: # compile it so that _ir_info is populated (whether or not it makes # it into the final IR artifact) - func_ir = _ir_for_internal_function(func_ast, module_ctx, False) + func_ir = _ir_for_internal_function(func_ast, compilation_target, False) # only include it in the IR if it is reachable from an external # function. @@ -450,16 +452,16 @@ def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: internal_functions_ir.append(IRnode.from_list(func_ir)) if core._opt_none(): - selector_section = _selector_section_linear(external_functions, module_ctx) + selector_section = _selector_section_linear(external_functions, compilation_target) # dense vs sparse global overhead is amortized after about 4 methods. # (--debug will force dense selector table anyway if _opt_codesize is selected.) elif core._opt_codesize() and (len(external_functions) > 4 or _is_debug_mode()): - selector_section = _selector_section_dense(external_functions, module_ctx) + selector_section = _selector_section_dense(external_functions, compilation_target) else: - selector_section = _selector_section_sparse(external_functions, module_ctx) + selector_section = _selector_section_sparse(external_functions, compilation_target) if default_function: - fallback_ir = _ir_for_fallback_or_ctor(default_function, module_ctx) + fallback_ir = _ir_for_fallback_or_ctor(default_function, compilation_target) else: fallback_ir = IRnode.from_list( ["revert", 0, 0], annotation="Default function", error_msg="fallback function" @@ -472,7 +474,7 @@ def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: runtime.extend(internal_functions_ir) deploy_code: List[Any] = ["seq"] - immutables_len = module_ctx.immutable_bytes_required + immutables_len = compilation_target.immutable_bytes_required if init_function: # cleanly rerun codegen for internal functions with `is_ctor_ctx=True` init_func_t = init_function._metadata["func_type"] @@ -484,13 +486,13 @@ def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: # unreachable code, delete it continue - func_ir = _ir_for_internal_function(f, module_ctx, is_ctor_context=True) + func_ir = _ir_for_internal_function(f, compilation_target, is_ctor_context=True) ctor_internal_func_irs.append(func_ir) # generate init_func_ir after callees to ensure they have analyzed # memory usage. # TODO might be cleaner to separate this into an _init_ir helper func - init_func_ir = _ir_for_fallback_or_ctor(init_function, module_ctx, is_ctor_context=True) + init_func_ir = _ir_for_fallback_or_ctor(init_function, compilation_target, is_ctor_context=True) # pass the amount of memory allocated for the init function # so that deployment does not clobber while preparing immutables From 8becda2aab92ffe671f030ecdce3ffa5483e76ef Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 21 Dec 2023 12:58:15 -0500 Subject: [PATCH 11/27] add a note --- vyper/semantics/analysis/local.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index a11a58a99a..d1804eee6a 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -172,6 +172,7 @@ def _validate_pure_access(node: vy_ast.Attribute, typ: VyperType) -> None: def _validate_self_reference(node: vy_ast.Name) -> None: # CMC 2023-10-19 this detector seems sus, things like `a.b(self)` could slip through + # TODO: this is now wrong, we can have things like `self.module.foo` if node.id == "self" and not isinstance(node.get_ancestor(), vy_ast.Attribute): raise StateAccessViolation("not allowed to query self in pure functions", node) From 9ac072b988d667ece28a9e31bc9e8a0cf91d7768 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 21 Dec 2023 15:46:16 -0500 Subject: [PATCH 12/27] improve Context.self_ptr --- vyper/codegen/context.py | 17 ++++++++++------- vyper/codegen/expr.py | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 290a5c7e01..9bf9b3d1cc 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -3,8 +3,10 @@ from dataclasses import dataclass from typing import Any, Optional -from vyper.codegen.ir_node import Encoding -from vyper.evm.address_space import MEMORY, AddrSpace +import vyper.ast as vy_ast + +from vyper.codegen.ir_node import Encoding, IRnode +from vyper.evm.address_space import MEMORY, AddrSpace, STORAGE, IMMUTABLES from vyper.exceptions import CompilerPanic, StateAccessViolation from vyper.semantics.types import VyperType, ModuleT @@ -91,7 +93,7 @@ def __init__( # either the constructor, or called from the constructor self.is_ctor_context = is_ctor_context - def self_ptr(self): + def self_ptr(self, location): func_module = self.func_t.ast_def._parent assert isinstance(func_module, vy_ast.Module) @@ -100,13 +102,14 @@ def self_ptr(self): if module_is_compilation_target: # return 0 for the special case where compilation target is self - return IRnode.from_list(0, typ=module_t) + return IRnode.from_list(0, typ=module_t, location=location) # otherwise, the function compilation context takes a `self_ptr` # argument in the calling convention - # TODO: probably need to track immutables and storage variables - # separately - return IRnode.from_list("self_ptr", typ=module_t) + if location == STORAGE: + return IRnode.from_list("self_ptr_storage", typ=module_t, location=location) + if location == IMMUTABLES: + return IRnode.from_list("self_ptr_code", typ=module_t, location=location) def is_constant(self): return self.constancy is Constancy.Constant or self.in_assertion or self.in_range_expr diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 858c9d687a..2c8162f491 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -331,7 +331,7 @@ def parse_Attribute(self): module_ptr = Expr(self.expr.value, self.context).ir_node if module_ptr.value == "self": - module_ptr = self.context.self_ptr + module_ptr = self.context.self_ptr(location) ret = get_element_ptr(module_ptr, self.expr.attr) ret._referenced_variables = {varinfo} From 4e102bf0c6cb6e1f34a333e2be2ec618973d98d5 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 21 Dec 2023 16:12:35 -0500 Subject: [PATCH 13/27] add function variable read/writes analysis --- vyper/codegen/self_call.py | 3 ++- vyper/semantics/analysis/base.py | 13 +++++++++---- vyper/semantics/analysis/local.py | 13 +++++++------ vyper/semantics/data_locations.py | 1 - vyper/semantics/types/function.py | 24 ++++++++++++++++++++++++ 5 files changed, 42 insertions(+), 12 deletions(-) diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index f53e4a81b4..dffe48184b 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -93,7 +93,8 @@ def ir_for_self_call(stmt_expr, context): goto_op = ["goto", func_t._ir_info.internal_function_label(context.is_ctor_context)] # pass return buffer to subroutine if return_buffer is not None: - goto_op += [return_buffer] + goto_op.append(return_buffer) + # pass return label to subroutine goto_op.append(["symbol", return_label]) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 9f3414f4be..37e6ce1462 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -262,18 +262,23 @@ def validate_modification(self, node: vy_ast.VyperNode, mutability: StateMutabil raise ImmutableViolation("Cannot write to calldata", node) if self.is_constant: raise ImmutableViolation("Constant value cannot be written to", node) + + func_node = node.get_ancestor(vy_ast.FunctionDef) + func_t = func_node._metadata["func_type"] + if self.is_immutable: - if node.get_ancestor(vy_ast.FunctionDef).get("name") != "__init__": + if func_node.name != "__init__": raise ImmutableViolation("Immutable value cannot be written to", node) if len(self._var_info._writes) > 0: raise ImmutableViolation( "Immutable value cannot be modified after assignment", node ) - self._var_info._writes.append(node) - if self.location == DataLocation.STORAGE: - self._var_info._writes.append(node) + # tag it in the metadata + node._metadata["variable_write"] = self._var_info + self._var_info._writes.append(node) + func_t._variable_writes.append(node) if isinstance(node, vy_ast.AugAssign): self.typ.validate_numeric_op(node) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index d1804eee6a..d649058c3d 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -264,17 +264,17 @@ def _assign_helper(self, node): if isinstance(node.value, vy_ast.Tuple): raise StructureException("Right-hand side of assignment cannot be a tuple", node.value) - target = get_expr_info(node.target) - if isinstance(target.typ, HashMapT): + target_info = get_expr_info(node.target) + if isinstance(target_info.typ, HashMapT): raise StructureException( "Left-hand side of assignment cannot be a HashMap without a key", node ) - validate_expected_type(node.value, target.typ) - target.validate_modification(node, self.func.mutability) + validate_expected_type(node.value, target_info.typ) + target_info.validate_modification(node, self.func.mutability) - self.expr_visitor.visit(node.value, target.typ) - self.expr_visitor.visit(node.target, target.typ) + self.expr_visitor.visit(node.value, target_info.typ) + self.expr_visitor.visit(node.target, target_info.typ) def visit_Assign(self, node): self._assign_helper(node) @@ -614,6 +614,7 @@ def visit(self, node, typ): if (var_info := info._var_info) is not None: node._metadata["variable_access"] = var_info var_info._reads.append(node) + self.func._variable_reads.append(node) def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_data_attribute(node) diff --git a/vyper/semantics/data_locations.py b/vyper/semantics/data_locations.py index 2f259b1766..07e8435686 100644 --- a/vyper/semantics/data_locations.py +++ b/vyper/semantics/data_locations.py @@ -7,5 +7,4 @@ class DataLocation(enum.Enum): STORAGE = 2 CALLDATA = 3 CODE = 4 - # XXX: needed for separate transient storage allocator # TRANSIENT = 5 diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 34206546fd..969cfc4ac3 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -107,10 +107,34 @@ def __init__( # recursively reachable from this function self.reachable_internal_functions: OrderedSet[ContractFunctionT] = OrderedSet() + # list of variables read in this function + self._variable_reads: list[vy_ast.VyperNode] = [] + # list of variables written in this function + self._variable_writes: list[vy_ast.VyperNode] = [] + # to be populated during codegen self._ir_info: Any = None self._function_id: Optional[int] = None + def touches_location(self, location): + for r in self._variable_reads: + if r._metadata["variable_access"].location == location: + return True + for w in self._variable_writes: + if w._metadata["variable_write"].location == location: + return True + return False + + @property + def touched_locations(self): + # return the DataLocations of touched module variables + ret = [] + possible_locations = [DataLocation.STORAGE, DataLocation.CODE] + for location in possible_locations: + if self.touches_location(location): + ret.append(location) + return ret + @cached_property def call_site_kwargs(self): # special kwargs that are allowed in call site From 9f100d98e2c20b719624a9ad9611b498b2e33844 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 21 Dec 2023 20:14:37 -0500 Subject: [PATCH 14/27] calculate pointer things --- vyper/codegen/context.py | 7 +- vyper/codegen/core.py | 12 ++++ vyper/codegen/expr.py | 3 +- .../function_definitions/internal_function.py | 5 ++ vyper/codegen/module.py | 4 +- vyper/codegen/self_call.py | 65 ++++++++++++++++--- vyper/ir/compile_ir.py | 9 ++- vyper/semantics/analysis/data_positions.py | 3 +- 8 files changed, 90 insertions(+), 18 deletions(-) diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 9bf9b3d1cc..334d988dd4 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -4,11 +4,10 @@ from typing import Any, Optional import vyper.ast as vy_ast - from vyper.codegen.ir_node import Encoding, IRnode -from vyper.evm.address_space import MEMORY, AddrSpace, STORAGE, IMMUTABLES +from vyper.evm.address_space import IMMUTABLES, MEMORY, STORAGE, AddrSpace from vyper.exceptions import CompilerPanic, StateAccessViolation -from vyper.semantics.types import VyperType, ModuleT +from vyper.semantics.types import ModuleT, VyperType class Constancy(enum.Enum): @@ -98,7 +97,7 @@ def self_ptr(self, location): assert isinstance(func_module, vy_ast.Module) module_t = func_module._metadata["type"] - module_is_compilation_target = (module_t == self.compilation_target) + module_is_compilation_target = module_t == self.compilation_target if module_is_compilation_target: # return 0 for the special case where compilation target is self diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 798a84dd6f..d61c6641c2 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -7,6 +7,7 @@ from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT from vyper.evm.opcodes import version_check from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch +from vyper.semantics.data_locations import DataLocation from vyper.semantics.types import ( AddressT, BoolT, @@ -65,6 +66,17 @@ def is_array_like(typ): return ret +def data_location_to_addr_space(s: DataLocation): + if s == DataLocation.STORAGE: + return STORAGE + if s == DataLocation.MEMORY: + return MEMORY + if s == DataLocation.CODE: + return IMMUTABLES + + raise CompilerPanic("unreachable") # pragma: nocover + + def get_type_for_exact_size(n_bytes): """Create a type which will take up exactly n_bytes. Used for allocating internal buffers. diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 2c8162f491..f30ce50e0a 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -323,7 +323,7 @@ def parse_Attribute(self): return IRnode.from_list(["chainid"], typ=UINT256_T) # self.x: global storage variable or immutable - elif (varinfo := self.expr._metadata.get("variable_access")) is not None: + if (varinfo := self.expr._metadata.get("variable_access")) is not None: assert isinstance(varinfo, VarInfo) # TODO: handle immutables @@ -334,6 +334,7 @@ def parse_Attribute(self): module_ptr = self.context.self_ptr(location) ret = get_element_ptr(module_ptr, self.expr.attr) + # TODO: take referenced variables info from analysis ret._referenced_variables = {varinfo} return ret diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py index cf01dbdab4..043b01f4ed 100644 --- a/vyper/codegen/function_definitions/internal_function.py +++ b/vyper/codegen/function_definitions/internal_function.py @@ -50,6 +50,11 @@ def generate_ir_for_internal_function( cleanup_label = func_t._ir_info.exit_sequence_label stack_args = ["var_list"] + + for location in func_t.touched_locations: + location_name = location.name.lower() + stack_args.append(f"self_ptr_{location_name}") + if func_t.return_type: stack_args += ["return_buffer"] stack_args += ["return_pc"] diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index f2bbffcfe0..8ce8a262f9 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -494,7 +494,9 @@ def generate_ir_for_module(compilation_target: ModuleT) -> tuple[IRnode, IRnode] # generate init_func_ir after callees to ensure they have analyzed # memory usage. # TODO might be cleaner to separate this into an _init_ir helper func - init_func_ir = _ir_for_fallback_or_ctor(init_function, compilation_target, is_ctor_context=True) + init_func_ir = _ir_for_fallback_or_ctor( + init_function, compilation_target, is_ctor_context=True + ) # pass the amount of memory allocated for the init function # so that deployment does not clobber while preparing immutables diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index dffe48184b..5e06588d30 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -1,7 +1,14 @@ -from vyper.codegen.core import _freshname, eval_once_check, make_setter +from vyper import ast as vy_ast +from vyper.codegen.core import ( + _freshname, + data_location_to_addr_space, + eval_once_check, + get_element_ptr, + make_setter, +) from vyper.codegen.ir_node import IRnode from vyper.evm.address_space import MEMORY -from vyper.exceptions import StateAccessViolation +from vyper.exceptions import CompilerPanic, StateAccessViolation from vyper.semantics.types.subscriptable import TupleT @@ -20,7 +27,42 @@ def _align_kwargs(func_t, args_ir): return [i.default_value for i in unprovided_kwargs] -def ir_for_self_call(stmt_expr, context): +def _get_self_ptr_for_location(node: vy_ast.Attribute, context, location): + # resolve something like self.x.y.z to a pointer + if isinstance(node.value, vy_ast.Name): + # base case + if node.value.id == "self": + ptr = context.self_ptr(location) + else: # pragma: nocover + raise CompilerPanic("unreachable!", node.value) + else: + # recurse + ptr = _get_self_ptr_for_location(node.value, context, location) + + return get_element_ptr(ptr, node.attr) + + +def _calculate_self_ptr_requirements(call_expr, func_t, context): + ret = [] + + module_t = func_t.ast_def._parent._metadata["type"] + if module_t == context.compilation_target: + # we don't need to pass a pointer + return ret + + func_expr = call_expr.func + assert isinstance(func_expr, vy_ast.Attribute) + pointer_expr = func_expr.value + for location in func_t.touched_locations: + codegen_location = data_location_to_addr_space(location) + + # self.foo.bar.baz() => pointer_expr == `self.foo.bar` + ret.append(_get_self_ptr_for_location(pointer_expr, context, codegen_location)) + + return ret + + +def ir_for_self_call(call_expr, context): from vyper.codegen.expr import Expr # TODO rethink this circular import # ** Internal Call ** @@ -30,10 +72,10 @@ def ir_for_self_call(stmt_expr, context): # - push jumpdest (callback ptr) and return buffer location # - jump to label # - (private function will fill return buffer and jump back) - method_name = stmt_expr.func.attr - func_t = stmt_expr.func._metadata["type"] + method_name = call_expr.func.attr + func_t = call_expr.func._metadata["type"] - pos_args_ir = [Expr(x, context).ir_node for x in stmt_expr.args] + pos_args_ir = [Expr(x, context).ir_node for x in call_expr.args] default_vals = _align_kwargs(func_t, pos_args_ir) default_vals_ir = [Expr(x, context).ir_node for x in default_vals] @@ -49,7 +91,7 @@ def ir_for_self_call(stmt_expr, context): raise StateAccessViolation( f"May not call state modifying function " f"'{method_name}' within {context.pp_constancy()}.", - stmt_expr, + call_expr, ) # note: internal_function_label asserts `func_t.is_internal`. @@ -91,6 +133,11 @@ def ir_for_self_call(stmt_expr, context): copy_args = make_setter(args_dst, args_as_tuple) goto_op = ["goto", func_t._ir_info.internal_function_label(context.is_ctor_context)] + + # if needed, pass pointers to the callee + for self_ptr in _calculate_self_ptr_requirements(call_expr, func_t, context): + goto_op.append(self_ptr) + # pass return buffer to subroutine if return_buffer is not None: goto_op.append(return_buffer) @@ -99,7 +146,7 @@ def ir_for_self_call(stmt_expr, context): goto_op.append(["symbol", return_label]) call_sequence = ["seq"] - call_sequence.append(eval_once_check(_freshname(stmt_expr.node_source_code))) + call_sequence.append(eval_once_check(_freshname(call_expr.node_source_code))) call_sequence.extend([copy_args, goto_op, ["label", return_label, ["var_list"], "pass"]]) if return_buffer is not None: # push return buffer location to stack @@ -109,7 +156,7 @@ def ir_for_self_call(stmt_expr, context): call_sequence, typ=func_t.return_type, location=MEMORY, - annotation=stmt_expr.get("node_source_code"), + annotation=call_expr.get("node_source_code"), add_gas_estimate=func_t._ir_info.gas_estimate, ) o.is_self_call = True diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 8ce8c887f1..7e94d57b55 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -112,6 +112,9 @@ def calc_mem_ofst_size(ctor_mem_size): def _rewrite_return_sequences(ir_node, label_params=None): args = ir_node.args + # special values which should be popped at the end of function execution + POPPABLE_PARAMS = ("return_buffer", "self_ptr_storage", "self_ptr_immutables") + if ir_node.value == "return": if args[0].value == "ret_ofst" and args[1].value == "ret_len": ir_node.args[0].value = "pass" @@ -126,8 +129,10 @@ def _rewrite_return_sequences(ir_node, label_params=None): ir_node.value = "seq" _t = ["seq"] - if "return_buffer" in label_params: - _t.append(["pop", "pass"]) + + for s in POPPABLE_PARAMS: + if s in label_params: + _t.append(["pop", "pass"]) dest = args[0].value # works for both internal and external exit_to diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 9ae632ff11..291d9b7af2 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -234,9 +234,10 @@ def set_code_offsets(vyper_module: vy_ast.Module) -> Dict: if not varinfo.is_immutable: continue + type_ = varinfo.typ + len_ = ceil32(type_.immutable_bytes_required) - type_ = varinfo.typ varinfo.set_position(CodeOffset(offset)) # this could have better typing but leave it untyped until From 2d056996ab81c9157e23c308fc5b97b547175e70 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 21 Dec 2023 21:07:17 -0500 Subject: [PATCH 15/27] quash mypy --- vyper/codegen/self_call.py | 9 ++++----- vyper/semantics/analysis/base.py | 3 +++ 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index 5e06588d30..2fb413db67 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -30,12 +30,11 @@ def _align_kwargs(func_t, args_ir): def _get_self_ptr_for_location(node: vy_ast.Attribute, context, location): # resolve something like self.x.y.z to a pointer if isinstance(node.value, vy_ast.Name): - # base case - if node.value.id == "self": - ptr = context.self_ptr(location) - else: # pragma: nocover - raise CompilerPanic("unreachable!", node.value) + # base case - we should always end up at self. sanity check this! + assert node.value.id == "self" + ptr = context.self_ptr(location) else: + assert isinstance(node.value, vy_ast.Attribute) # mypy hint # recurse ptr = _get_self_ptr_for_location(node.value, context, location) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 37e6ce1462..4def7e6d2e 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -266,6 +266,9 @@ def validate_modification(self, node: vy_ast.VyperNode, mutability: StateMutabil func_node = node.get_ancestor(vy_ast.FunctionDef) func_t = func_node._metadata["func_type"] + assert isinstance(func_node, vy_ast.FunctionDef) # mypy hint + assert self._var_info is not None # mypy hint + if self.is_immutable: if func_node.name != "__init__": raise ImmutableViolation("Immutable value cannot be written to", node) From 1585bdcd890f0c6af55fc028900ca7ad5dbbc542 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 21 Dec 2023 21:16:16 -0500 Subject: [PATCH 16/27] wip - handle immutables --- vyper/codegen/expr.py | 18 ++++++++++-------- vyper/codegen/self_call.py | 2 +- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index f30ce50e0a..6858c83193 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -185,13 +185,9 @@ def parse_Name(self): ret._referenced_variables = {var} return ret - # TODO: use self.expr._expr_info - elif self.expr.id in self.context.globals: - varinfo = self.context.globals[self.expr.id] + elif (varinfo := self.expr._metadata.get("variable_access")) is not None: assert varinfo.is_immutable, "not an immutable!" - ofst = varinfo.position.offset - if self.context.is_ctor_context: mutable = True location = IMMUTABLES @@ -199,10 +195,16 @@ def parse_Name(self): mutable = False location = DATA - ret = IRnode.from_list( - ofst, typ=varinfo.typ, location=location, annotation=self.expr.id, mutable=mutable - ) + module_ptr = self.context.self_ptr(location) + + ret = get_element_ptr(module_ptr, self.expr.id) + + assert ret.typ == varinfo.typ + assert ret.location == location + + ret.mutable = mutable ret._referenced_variables = {varinfo} + return ret # x.y or x[5] diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index 2fb413db67..6eb60ded6c 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -8,7 +8,7 @@ ) from vyper.codegen.ir_node import IRnode from vyper.evm.address_space import MEMORY -from vyper.exceptions import CompilerPanic, StateAccessViolation +from vyper.exceptions import StateAccessViolation from vyper.semantics.types.subscriptable import TupleT From bf6e99c990d571aa60f91a2e39ac2b5b41314b61 Mon Sep 17 00:00:00 2001 From: Alberto Date: Sat, 23 Dec 2023 17:09:50 +0100 Subject: [PATCH 17/27] feat: replace `enum` with `flag` keyword (#3697) per title, replace `enum` with `flag` as it more closely models https://docs.python.org/3/library/enum.html#enum.IntFlag than regular enums. allow `enum` for now (for backwards compatibility) but convert to `flag` internally and issue a warning --- docs/types.rst | 22 ++++---- .../builtins/codegen/test_convert.py | 8 +-- .../codegen/features/test_assignment.py | 4 +- .../codegen/features/test_clampers.py | 8 +-- .../codegen/types/test_dynamic_array.py | 10 ++-- .../types/{test_enum.py => test_flag.py} | 28 +++++----- tests/functional/syntax/test_dynamic_array.py | 4 +- .../syntax/{test_enum.py => test_flag.py} | 52 +++++++++---------- tests/functional/syntax/test_public.py | 4 +- vyper/ast/folding.py | 2 +- vyper/ast/grammar.lark | 10 +++- vyper/ast/identifiers.py | 1 + vyper/ast/nodes.py | 21 +++++++- vyper/ast/nodes.pyi | 2 +- vyper/ast/pre_parser.py | 5 +- vyper/builtins/_convert.py | 10 ++-- vyper/codegen/core.py | 10 ++-- vyper/codegen/expr.py | 12 ++--- vyper/exceptions.py | 4 +- vyper/semantics/analysis/local.py | 4 +- vyper/semantics/analysis/module.py | 6 +-- vyper/semantics/analysis/utils.py | 4 +- vyper/semantics/types/__init__.py | 2 +- vyper/semantics/types/user.py | 14 ++--- 24 files changed, 137 insertions(+), 110 deletions(-) rename tests/functional/codegen/types/{test_enum.py => test_flag.py} (93%) rename tests/functional/syntax/{test_enum.py => test_flag.py} (74%) diff --git a/docs/types.rst b/docs/types.rst index d669e6946d..0ad13967e9 100644 --- a/docs/types.rst +++ b/docs/types.rst @@ -376,22 +376,22 @@ On the ABI level the Fixed-size bytes array is annotated as ``string``. example_str: String[100] = "Test String" -Enums +Flags ----- -**Keyword:** ``enum`` +**Keyword:** ``flag`` -Enums are custom defined types. An enum must have at least one member, and can hold up to a maximum of 256 members. +Flags are custom defined types. A flag must have at least one member, and can hold up to a maximum of 256 members. The members are represented by ``uint256`` values in the form of 2\ :sup:`n` where ``n`` is the index of the member in the range ``0 <= n <= 255``. .. code-block:: python - # Defining an enum with two members - enum Roles: + # Defining a flag with two members + flag Roles: ADMIN USER - # Declaring an enum variable + # Declaring a flag variable role: Roles = Roles.ADMIN # Returning a member @@ -426,13 +426,13 @@ Operator Description ``~x`` Bitwise not ============= ====================== -Enum members can be combined using the above bitwise operators. While enum members have values that are power of two, enum member combinations may not. +Flag members can be combined using the above bitwise operators. While flag members have values that are power of two, flag member combinations may not. -The ``in`` and ``not in`` operators can be used in conjunction with enum member combinations to check for membership. +The ``in`` and ``not in`` operators can be used in conjunction with flag member combinations to check for membership. .. code-block:: python - enum Roles: + flag Roles: MANAGER ADMIN USER @@ -447,7 +447,7 @@ The ``in`` and ``not in`` operators can be used in conjunction with enum member def bar(a: Roles) -> bool: return a not in (Roles.MANAGER | Roles.USER) -Note that ``in`` is not the same as strict equality (``==``). ``in`` checks that *any* of the flags on two enum objects are simultaneously set, while ``==`` checks that two enum objects are bit-for-bit equal. +Note that ``in`` is not the same as strict equality (``==``). ``in`` checks that *any* of the flags on two flag objects are simultaneously set, while ``==`` checks that two flag objects are bit-for-bit equal. The following code uses bitwise operations to add and revoke permissions from a given ``Roles`` object. @@ -488,7 +488,7 @@ Fixed-size Lists Fixed-size lists hold a finite number of elements which belong to a specified type. -Lists can be declared with ``_name: _ValueType[_Integer]``, except ``Bytes[N]``, ``String[N]`` and enums. +Lists can be declared with ``_name: _ValueType[_Integer]``, except ``Bytes[N]``, ``String[N]`` and flags. .. code-block:: python diff --git a/tests/functional/builtins/codegen/test_convert.py b/tests/functional/builtins/codegen/test_convert.py index b5ce613235..99dae4a932 100644 --- a/tests/functional/builtins/codegen/test_convert.py +++ b/tests/functional/builtins/codegen/test_convert.py @@ -486,10 +486,10 @@ def test_memory_variable_convert(x: {i_typ}) -> {o_typ}: @pytest.mark.parametrize("typ", ["uint8", "int128", "int256", "uint256"]) @pytest.mark.parametrize("val", [1, 2, 2**128, 2**256 - 1, 2**256 - 2]) -def test_enum_conversion(get_contract_with_gas_estimation, assert_compile_failed, val, typ): +def test_flag_conversion(get_contract_with_gas_estimation, assert_compile_failed, val, typ): roles = "\n ".join([f"ROLE_{i}" for i in range(256)]) contract = f""" -enum Roles: +flag Roles: {roles} @external @@ -510,11 +510,11 @@ def bar(a: uint256) -> Roles: @pytest.mark.parametrize("typ", ["uint8", "int128", "int256", "uint256"]) @pytest.mark.parametrize("val", [1, 2, 3, 4, 2**128, 2**256 - 1, 2**256 - 2]) -def test_enum_conversion_2( +def test_flag_conversion_2( get_contract_with_gas_estimation, assert_compile_failed, assert_tx_failed, val, typ ): contract = f""" -enum Status: +flag Status: STARTED PAUSED STOPPED diff --git a/tests/functional/codegen/features/test_assignment.py b/tests/functional/codegen/features/test_assignment.py index cd26659a5c..9af7058250 100644 --- a/tests/functional/codegen/features/test_assignment.py +++ b/tests/functional/codegen/features/test_assignment.py @@ -66,7 +66,7 @@ def bar(x: {typ}) -> {typ}: def test_internal_assign_struct(get_contract_with_gas_estimation): code = """ -enum Bar: +flag Bar: BAD BAK BAZ @@ -92,7 +92,7 @@ def bar(x: Foo) -> Foo: def test_internal_assign_struct_member(get_contract_with_gas_estimation): code = """ -enum Bar: +flag Bar: BAD BAK BAZ diff --git a/tests/functional/codegen/features/test_clampers.py b/tests/functional/codegen/features/test_clampers.py index 08ad349c09..263f10a89c 100644 --- a/tests/functional/codegen/features/test_clampers.py +++ b/tests/functional/codegen/features/test_clampers.py @@ -187,9 +187,9 @@ def foo(s: bool) -> bool: @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @pytest.mark.parametrize("value", [0] + [2**i for i in range(5)]) -def test_enum_clamper_passing(w3, get_contract, value, evm_version): +def test_flag_clamper_passing(w3, get_contract, value, evm_version): code = """ -enum Roles: +flag Roles: USER STAFF ADMIN @@ -207,9 +207,9 @@ def foo(s: Roles) -> Roles: @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @pytest.mark.parametrize("value", [2**i for i in range(5, 256)]) -def test_enum_clamper_failing(w3, assert_tx_failed, get_contract, value, evm_version): +def test_flag_clamper_failing(w3, assert_tx_failed, get_contract, value, evm_version): code = """ -enum Roles: +flag Roles: USER STAFF ADMIN diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index 9231d1979f..d793a56d6e 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -102,7 +102,7 @@ def foo6() -> DynArray[DynArray[String[32], 2], 2]: def test_list_output_tester_code(get_contract_with_gas_estimation): list_output_tester_code = """ -enum Foobar: +flag Foobar: FOO BAR @@ -1247,13 +1247,13 @@ def test_append_pop_complex(get_contract, assert_tx_failed, code_template, check """ code = struct_def + "\n" + code elif subtype == "DynArray[Foobar, 3]": - enum_def = """ -enum Foobar: + flag_def = """ +flag Foobar: FOO BAR BAZ """ - code = enum_def + "\n" + code + code = flag_def + "\n" + code test_data = [2 ** (i - 1) for i in test_data] c = get_contract(code) @@ -1292,7 +1292,7 @@ def foo() -> (uint256, DynArray[uint256, 3], DynArray[uint256, 2]): def test_list_of_structs_arg(get_contract): code = """ -enum Foobar: +flag Foobar: FOO BAR diff --git a/tests/functional/codegen/types/test_enum.py b/tests/functional/codegen/types/test_flag.py similarity index 93% rename from tests/functional/codegen/types/test_enum.py rename to tests/functional/codegen/types/test_flag.py index c66efff566..03c22134ed 100644 --- a/tests/functional/codegen/types/test_enum.py +++ b/tests/functional/codegen/types/test_flag.py @@ -1,6 +1,6 @@ def test_values_should_be_increasing_ints(get_contract): code = """ -enum Action: +flag Action: BUY SELL CANCEL @@ -26,9 +26,9 @@ def cancel() -> Action: assert c.cancel() == 4 -def test_enum_storage(get_contract): +def test_flag_storage(get_contract): code = """ -enum Actions: +flag Actions: BUY SELL CANCEL @@ -49,7 +49,7 @@ def set_and_get(a: Actions) -> Actions: def test_eq_neq(get_contract): code = """ -enum Roles: +flag Roles: USER STAFF ADMIN @@ -76,7 +76,7 @@ def is_not_boss(a: Roles) -> bool: def test_bitwise(get_contract, assert_tx_failed): code = """ -enum Roles: +flag Roles: USER STAFF ADMIN @@ -147,7 +147,7 @@ def binv_arg(a: Roles) -> Roles: def test_augassign_storage(get_contract, w3, assert_tx_failed): code = """ -enum Roles: +flag Roles: ADMIN MINTER @@ -214,9 +214,9 @@ def checkMinter(minter: address): assert_tx_failed(lambda: c.checkMinter(admin_address)) -def test_in_enum(get_contract_with_gas_estimation): +def test_in_flag(get_contract_with_gas_estimation): code = """ -enum Roles: +flag Roles: USER STAFF ADMIN @@ -259,9 +259,9 @@ def baz(a: Roles) -> bool: assert c.baz(0b01000) is False # Roles.MANAGER should fail -def test_struct_with_enum(get_contract_with_gas_estimation): +def test_struct_with_flag(get_contract_with_gas_estimation): code = """ -enum Foobar: +flag Foobar: FOO BAR @@ -270,17 +270,17 @@ def test_struct_with_enum(get_contract_with_gas_estimation): b: Foobar @external -def get_enum_from_struct() -> Foobar: +def get_flag_from_struct() -> Foobar: f: Foo = Foo({a: 1, b: Foobar.BAR}) return f.b """ c = get_contract_with_gas_estimation(code) - assert c.get_enum_from_struct() == 2 + assert c.get_flag_from_struct() == 2 -def test_mapping_with_enum(get_contract_with_gas_estimation): +def test_mapping_with_flag(get_contract_with_gas_estimation): code = """ -enum Foobar: +flag Foobar: FOO BAR diff --git a/tests/functional/syntax/test_dynamic_array.py b/tests/functional/syntax/test_dynamic_array.py index 0c23bf67da..99a01a17c8 100644 --- a/tests/functional/syntax/test_dynamic_array.py +++ b/tests/functional/syntax/test_dynamic_array.py @@ -34,12 +34,12 @@ def test_block_fail(assert_compile_failed, get_contract, bad_code, exc): valid_list = [ """ -enum Foo: +flag Foo: FE FI bar: DynArray[Foo, 10] - """, # dynamic arrays of enums are allowed, but not static arrays + """, # dynamic arrays of flags are allowed, but not static arrays """ bar: DynArray[Bytes[30], 10] """, # dynamic arrays of bytestrings are allowed, but not static arrays diff --git a/tests/functional/syntax/test_enum.py b/tests/functional/syntax/test_flag.py similarity index 74% rename from tests/functional/syntax/test_enum.py rename to tests/functional/syntax/test_flag.py index 9bb74fb675..22309502b7 100644 --- a/tests/functional/syntax/test_enum.py +++ b/tests/functional/syntax/test_flag.py @@ -2,7 +2,7 @@ from vyper import compiler from vyper.exceptions import ( - EnumDeclarationException, + FlagDeclarationException, InvalidOperation, NamespaceCollision, StructureException, @@ -16,7 +16,7 @@ event Action: pass -enum Action: +flag Action: BUY SELL """, @@ -24,23 +24,23 @@ ), ( """ -enum Action: +flag Action: pass """, - EnumDeclarationException, + FlagDeclarationException, ), ( """ -enum Action: +flag Action: BUY BUY """, - EnumDeclarationException, + FlagDeclarationException, ), - ("enum Foo:\n" + "\n".join([f" member{i}" for i in range(257)]), EnumDeclarationException), + ("flag Foo:\n" + "\n".join([f" member{i}" for i in range(257)]), FlagDeclarationException), ( """ -enum Roles: +flag Roles: USER STAFF ADMIN @@ -53,20 +53,20 @@ def foo(x: Roles) -> bool: ), ( """ -enum Roles: +flag Roles: USER STAFF ADMIN @external def foo(x: Roles) -> Roles: - return x.USER # can't dereference on enum instance + return x.USER # can't dereference on flag instance """, StructureException, ), ( """ -enum Roles: +flag Roles: USER STAFF ADMIN @@ -79,28 +79,28 @@ def foo(x: Roles) -> bool: ), ( """ -enum Functions: +flag Functions: def foo():nonpayable """, - EnumDeclarationException, + FlagDeclarationException, ), ( """ -enum Numbers: +flag Numbers: a:constant(uint256) = a """, - EnumDeclarationException, + FlagDeclarationException, ), ( """ -enum Numbers: +flag Numbers: 12 """, - EnumDeclarationException, + FlagDeclarationException, ), ( """ -enum Roles: +flag Roles: ADMIN USER @@ -112,9 +112,9 @@ def foo() -> Roles: ), ( """ -enum A: +flag A: a -enum B: +flag B: a b @@ -135,12 +135,12 @@ def test_fail_cases(bad_code): valid_list = [ """ -enum Action: +flag Action: BUY SELL """, """ -enum Action: +flag Action: BUY SELL @external @@ -148,7 +148,7 @@ def run() -> Action: return Action.BUY """, """ -enum Action: +flag Action: BUY SELL @@ -163,16 +163,16 @@ def run() -> Order: amount: 10**18 }) """, - "enum Foo:\n" + "\n".join([f" member{i}" for i in range(256)]), + "flag Foo:\n" + "\n".join([f" member{i}" for i in range(256)]), """ a: constant(uint256) = 1 -enum A: +flag A: a """, ] @pytest.mark.parametrize("good_code", valid_list) -def test_enum_success(good_code): +def test_flag_success(good_code): assert compiler.compile_code(good_code) is not None diff --git a/tests/functional/syntax/test_public.py b/tests/functional/syntax/test_public.py index 68575ebd41..71bff753f4 100644 --- a/tests/functional/syntax/test_public.py +++ b/tests/functional/syntax/test_public.py @@ -30,9 +30,9 @@ def foo() -> int128: x: public(HashMap[uint256, Foo]) """, - # expansion of public user-defined enum + # expansion of public user-defined flag """ -enum Foo: +flag Foo: BAR x: public(HashMap[uint256, Foo]) diff --git a/vyper/ast/folding.py b/vyper/ast/folding.py index 38d58f6fd0..087708a356 100644 --- a/vyper/ast/folding.py +++ b/vyper/ast/folding.py @@ -246,7 +246,7 @@ def replace_constant( continue # do not replace enum members - if node.get_ancestor(vy_ast.EnumDef): + if node.get_ancestor(vy_ast.FlagDef): continue try: diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 15367ce94a..7889473b19 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -10,7 +10,8 @@ module: ( DOCSTRING | interface_def | constant_def | variable_def - | enum_def + | enum_def // TODO deprecate at some point in favor of flag + | flag_def | event_def | function_def | immutable_def @@ -76,12 +77,19 @@ indexed_event_arg: NAME ":" "indexed" "(" type ")" event_body: _NEWLINE _INDENT (((event_member | indexed_event_arg ) _NEWLINE)+ | _PASS _NEWLINE) _DEDENT event_def: _EVENT_DECL NAME ":" ( event_body | _PASS ) +// TODO deprecate in favor of flag // Enums _ENUM_DECL: "enum" enum_member: NAME enum_body: _NEWLINE _INDENT (enum_member _NEWLINE)+ _DEDENT enum_def: _ENUM_DECL NAME ":" enum_body +// Flags +_FLAG_DECL: "flag" +flag_member: NAME +flag_body: _NEWLINE _INDENT (flag_member _NEWLINE)+ _DEDENT +flag_def: _FLAG_DECL NAME ":" flag_body + // Types array_def: (NAME | array_def | dyn_array_def) "[" _expr "]" dyn_array_def: "DynArray" "[" (NAME | array_def | dyn_array_def) "," _expr "]" diff --git a/vyper/ast/identifiers.py b/vyper/ast/identifiers.py index 985b04e5cd..7d42727066 100644 --- a/vyper/ast/identifiers.py +++ b/vyper/ast/identifiers.py @@ -69,6 +69,7 @@ def validate_identifier(attr, ast_node=None): "struct", "event", "enum", + "flag" # EVM operations "unreachable", # special functions (no name mangling) diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 3bccc5f141..dba9f2a22d 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -4,6 +4,7 @@ import decimal import operator import sys +import warnings from typing import Any, Optional, Union from vyper.ast.metadata import NodeMetadata @@ -18,6 +19,7 @@ SyntaxException, TypeMismatch, UnfoldableNode, + VyperException, ZeroDivisionException, ) from vyper.utils import MAX_DECIMAL_PLACES, SizeLimits, annotate_source_code @@ -78,6 +80,11 @@ def get_node( else: ast_struct["ast_type"] = "VariableDecl" + enum_warn = False + if ast_struct["ast_type"] == "EnumDef": + enum_warn = True + ast_struct["ast_type"] = "FlagDef" + vy_class = getattr(sys.modules[__name__], ast_struct["ast_type"], None) if not vy_class: if ast_struct["ast_type"] == "Delete": @@ -92,7 +99,17 @@ def get_node( ast_struct, ) - return vy_class(parent=parent, **ast_struct) + node = vy_class(parent=parent, **ast_struct) + + # TODO: Putting this after node creation to pretty print, remove after enum deprecation + if enum_warn: + # TODO: hack to pretty print, logic should be factored out of exception + pretty_printed_node = str(VyperException("", node)) + warnings.warn( + f"enum will be deprecated in a future release, use flag instead. {pretty_printed_node}", + stacklevel=2, + ) + return node def compare_nodes(left_node: "VyperNode", right_node: "VyperNode") -> bool: @@ -725,7 +742,7 @@ class Log(Stmt): __slots__ = ("value",) -class EnumDef(TopLevel): +class FlagDef(TopLevel): __slots__ = ("name", "body") diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 05784aed0f..47856b6021 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -81,7 +81,7 @@ class Return(VyperNode): ... class Log(VyperNode): value: VyperNode = ... -class EnumDef(VyperNode): +class FlagDef(VyperNode): body: list = ... name: str = ... diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index 9d96efea5e..b949a242bb 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -44,7 +44,8 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: # compound statements that are replaced with `class` -VYPER_CLASS_TYPES = {"enum", "event", "interface", "struct"} +# TODO remove enum in favor of flag +VYPER_CLASS_TYPES = {"flag", "enum", "event", "interface", "struct"} # simple statements or expressions that are replaced with `yield` VYPER_EXPRESSION_TYPES = {"log"} @@ -55,7 +56,7 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: Re-formats a vyper source string into a python source string and performs some validation. More specifically, - * Translates "interface", "struct", "enum, and "event" keywords into python "class" keyword + * Translates "interface", "struct", "flag", and "event" keywords into python "class" keyword * Validates "@version" pragma against current compiler version * Prevents direct use of python "class" keyword * Prevents use of python semi-colon statement separator diff --git a/vyper/builtins/_convert.py b/vyper/builtins/_convert.py index e09f5f3174..998cbbc9f6 100644 --- a/vyper/builtins/_convert.py +++ b/vyper/builtins/_convert.py @@ -14,7 +14,7 @@ int_clamp, is_bytes_m_type, is_decimal_type, - is_enum_type, + is_flag_type, is_integer_type, sar, shl, @@ -35,7 +35,7 @@ BytesM_T, BytesT, DecimalT, - EnumT, + FlagT, IntegerT, StringT, ) @@ -277,7 +277,7 @@ def to_bool(expr, arg, out_typ): return IRnode.from_list(["iszero", ["iszero", arg]], typ=out_typ) -@_input_types(IntegerT, DecimalT, BytesM_T, AddressT, BoolT, EnumT, BytesT) +@_input_types(IntegerT, DecimalT, BytesM_T, AddressT, BoolT, FlagT, BytesT) def to_int(expr, arg, out_typ): return _to_int(expr, arg, out_typ) @@ -305,7 +305,7 @@ def _to_int(expr, arg, out_typ): elif is_decimal_type(arg.typ): arg = _fixed_to_int(arg, out_typ) - elif is_enum_type(arg.typ): + elif is_flag_type(arg.typ): if out_typ != UINT256_T: _FAIL(arg.typ, out_typ, expr) # pretend enum is uint256 @@ -468,7 +468,7 @@ def convert(expr, context): ret = to_bool(arg_ast, arg, out_typ) elif out_typ == AddressT(): ret = to_address(arg_ast, arg, out_typ) - elif is_enum_type(out_typ): + elif is_flag_type(out_typ): ret = to_enum(arg_ast, arg, out_typ) elif is_integer_type(out_typ): ret = to_int(arg_ast, arg, out_typ) diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index d61c6641c2..f705e9deed 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -25,7 +25,7 @@ ) from vyper.semantics.types.shortcuts import BYTES32_T, INT256_T, UINT256_T from vyper.semantics.types.subscriptable import SArrayT -from vyper.semantics.types.user import EnumT +from vyper.semantics.types.user import FlagT from vyper.utils import GAS_COPY_WORD, GAS_IDENTITY, GAS_IDENTITYWORD, ceil32 DYNAMIC_ARRAY_OVERHEAD = 1 @@ -47,8 +47,8 @@ def is_decimal_type(typ): return isinstance(typ, DecimalT) -def is_enum_type(typ): - return isinstance(typ, EnumT) +def is_flag_type(typ): + return isinstance(typ, FlagT) def is_tuple_like(typ): @@ -878,7 +878,7 @@ def needs_clamp(t, encoding): raise CompilerPanic("unreachable") # pragma: notest if isinstance(t, (_BytestringT, DArrayT)): return True - if isinstance(t, EnumT): + if isinstance(t, FlagT): return len(t._enum_members) < 256 if isinstance(t, SArrayT): return needs_clamp(t.value_type, encoding) @@ -1187,7 +1187,7 @@ def clamp_basetype(ir_node): # copy of the input ir_node = unwrap_location(ir_node) - if isinstance(t, EnumT): + if isinstance(t, FlagT): bits = len(t._enum_members) # assert x >> bits == 0 ret = int_clamp(ir_node, bits, signed=False) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 6858c83193..c46f8cec1b 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -12,7 +12,7 @@ getpos, is_array_like, is_bytes_m_type, - is_enum_type, + is_flag_type, is_numeric_type, is_tuple_like, pop_dyn_array, @@ -43,7 +43,7 @@ BytesT, DArrayT, DecimalT, - EnumT, + FlagT, HashMapT, InterfaceT, SArrayT, @@ -213,7 +213,7 @@ def parse_Attribute(self): # MyEnum.foo if ( - isinstance(typ, EnumT) + isinstance(typ, FlagT) and isinstance(self.expr.value, vy_ast.Name) and typ.name == self.expr.value.id ): @@ -397,7 +397,7 @@ def parse_BinOp(self): # This should be unreachable due to the type check pass if left.typ != right.typ: raise TypeCheckFailure(f"unreachable, {left.typ} != {right.typ}", self.expr) - assert is_numeric_type(left.typ) or is_enum_type(left.typ) + assert is_numeric_type(left.typ) or is_flag_type(left.typ) out_typ = left.typ @@ -529,7 +529,7 @@ def parse_Compare(self): if is_array_like(right.typ): return self.build_in_comparator() else: - assert isinstance(right.typ, EnumT), right.typ + assert isinstance(right.typ, FlagT), right.typ intersection = ["and", left, right] if isinstance(self.expr.op, vy_ast.In): return IRnode.from_list(["iszero", ["iszero", intersection]], typ=BoolT()) @@ -646,7 +646,7 @@ def parse_UnaryOp(self): return IRnode.from_list(["iszero", operand], typ=BoolT()) if isinstance(self.expr.op, vy_ast.Invert): - if isinstance(operand.typ, EnumT): + if isinstance(operand.typ, FlagT): n_members = len(operand.typ._enum_members) # use (xor 0b11..1 operand) to flip all the bits in # `operand`. `mask` could be a very large constant and diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 2cd2c6d167..0c549ec10f 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -176,8 +176,8 @@ class FunctionDeclarationException(VyperException): """Invalid function declaration.""" -class EnumDeclarationException(VyperException): - """Invalid enum declaration.""" +class FlagDeclarationException(VyperException): + """Invalid flag declaration.""" class EventDeclarationException(VyperException): diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index d649058c3d..8f6103a217 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -38,8 +38,8 @@ AddressT, BoolT, DArrayT, - EnumT, EventT, + FlagT, HashMapT, IntegerT, SArrayT, @@ -708,7 +708,7 @@ def visit_Compare(self, node: vy_ast.Compare, typ: VyperType) -> None: validate_expected_type(node.right, rtyp) else: rtyp = get_exact_type_from_node(node.right) - if isinstance(rtyp, EnumT): + if isinstance(rtyp, FlagT): # enum membership - `some_enum in other_enum` ltyp = rtyp else: diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index c87a049f47..a9bd3a3c6c 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -32,7 +32,7 @@ ) from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace -from vyper.semantics.types import EnumT, EventT, InterfaceT, StructT +from vyper.semantics.types import EventT, FlagT, InterfaceT, StructT from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.semantics.types.utils import type_from_annotation @@ -341,8 +341,8 @@ def _validate_self_namespace(): return _finalize() - def visit_EnumDef(self, node): - obj = EnumT.from_EnumDef(node) + def visit_FlagDef(self, node): + obj = FlagT.from_FlagDef(node) self.namespace[node.name] = obj def visit_EventDef(self, node): diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 5ecc89c612..7a83b44ca6 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -245,13 +245,13 @@ def types_from_Compare(self, node): # comparisons, e.g. `x < y` # TODO fixme circular import - from vyper.semantics.types.user import EnumT + from vyper.semantics.types.user import FlagT if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): # x in y left = self.get_possible_types_from_node(node.left) right = self.get_possible_types_from_node(node.right) - if any(isinstance(t, EnumT) for t in left): + if any(isinstance(t, FlagT) for t in left): types_list = get_common_types(node.left, node.right) _validate_op(node, types_list, "validate_comparator") return [BoolT()] diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index 2ae4dd8454..a8586a62dd 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -5,7 +5,7 @@ from .module import InterfaceT, ModuleT from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT from .subscriptable import DArrayT, HashMapT, SArrayT, TupleT -from .user import EnumT, EventT, StructT +from .user import EventT, FlagT, StructT def _get_primitive_types(): diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index b4ba417167..45fe8bbcc6 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -5,8 +5,8 @@ from vyper.abi_types import ABI_GIntM, ABI_Tuple, ABIType from vyper.ast.validation import validate_call_args from vyper.exceptions import ( - EnumDeclarationException, EventDeclarationException, + FlagDeclarationException, InvalidAttribute, NamespaceCollision, StructureException, @@ -43,7 +43,7 @@ def __hash__(self): # note: enum behaves a lot like uint256, or uints in general. -class EnumT(_UserType): +class FlagT(_UserType): # this is a carveout because currently we allow dynamic arrays of # enums, but not static arrays of enums _as_darray = True @@ -52,7 +52,7 @@ class EnumT(_UserType): def __init__(self, name: str, members: dict) -> None: if len(members.keys()) > 256: - raise EnumDeclarationException("Enums are limited to 256 members!") + raise FlagDeclarationException("Enums are limited to 256 members!") super().__init__(members=None) @@ -103,7 +103,7 @@ def validate_comparator(self, node): # return f"{self.name}({','.join(v.canonical_abi_type for v in self.arguments)})" @classmethod - def from_EnumDef(cls, base_node: vy_ast.EnumDef) -> "EnumT": + def from_FlagDef(cls, base_node: vy_ast.FlagDef) -> "FlagT": """ Generate an `Enum` object from a Vyper ast node. @@ -118,15 +118,15 @@ def from_EnumDef(cls, base_node: vy_ast.EnumDef) -> "EnumT": members: dict = {} if len(base_node.body) == 1 and isinstance(base_node.body[0], vy_ast.Pass): - raise EnumDeclarationException("Enum must have members", base_node) + raise FlagDeclarationException("Enum must have members", base_node) for i, node in enumerate(base_node.body): if not isinstance(node, vy_ast.Expr) or not isinstance(node.value, vy_ast.Name): - raise EnumDeclarationException("Invalid syntax for enum member", node) + raise FlagDeclarationException("Invalid syntax for enum member", node) member_name = node.value.id if member_name in members: - raise EnumDeclarationException( + raise FlagDeclarationException( f"Enum member '{member_name}' has already been declared", node.value ) From 1824321b70b0602223e47607bc32a2f12f89c015 Mon Sep 17 00:00:00 2001 From: Daniel Schiavini Date: Sat, 23 Dec 2023 18:42:19 +0100 Subject: [PATCH 18/27] refactor: make `assert_tx_failed` a contextmanager (#3706) rename `assert_tx_failed` to `tx_failed` and change it into a context manager which has a similar API to `pytest.raises()`. --------- Co-authored-by: Charles Cooper --- docs/testing-contracts-ethtester.rst | 4 +- tests/conftest.py | 28 +-- .../builtins/codegen/test_abi_decode.py | 25 +- .../builtins/codegen/test_addmod.py | 5 +- .../builtins/codegen/test_as_wei_value.py | 16 +- .../builtins/codegen/test_convert.py | 13 +- .../builtins/codegen/test_create_functions.py | 59 ++--- .../builtins/codegen/test_extract32.py | 19 +- .../builtins/codegen/test_minmax.py | 2 +- .../builtins/codegen/test_mulmod.py | 5 +- .../builtins/codegen/test_raw_call.py | 17 +- .../functional/builtins/codegen/test_send.py | 8 +- .../functional/builtins/codegen/test_slice.py | 13 +- .../functional/builtins/codegen/test_unary.py | 5 +- tests/functional/builtins/folding/test_abs.py | 7 +- .../test_default_function.py | 19 +- .../test_default_parameters.py | 5 +- .../calling_convention/test_erc20_abi.py | 18 +- .../test_external_contract_calls.py | 150 +++++++----- ...test_modifiable_external_contract_calls.py | 21 +- .../calling_convention/test_return_tuple.py | 7 +- .../environment_variables/test_blockhash.py | 10 +- .../features/decorators/test_nonreentrant.py | 18 +- .../features/decorators/test_payable.py | 23 +- .../features/decorators/test_private.py | 2 +- .../features/iteration/test_for_range.py | 10 +- .../features/iteration/test_range_in.py | 10 +- .../codegen/features/test_assert.py | 38 +-- .../features/test_assert_unreachable.py | 24 +- .../codegen/features/test_clampers.py | 92 ++++---- .../functional/codegen/features/test_init.py | 5 +- .../codegen/features/test_logging.py | 81 ++++--- .../codegen/features/test_reverting.py | 27 +-- .../codegen/integration/test_escrow.py | 10 +- tests/functional/codegen/test_interfaces.py | 32 ++- .../functional/codegen/test_selector_table.py | 16 +- .../codegen/test_stateless_modules.py | 5 +- .../codegen/types/numbers/test_constants.py | 15 +- .../codegen/types/numbers/test_decimals.py | 33 ++- .../codegen/types/numbers/test_exponents.py | 31 ++- .../codegen/types/numbers/test_modulo.py | 5 +- .../codegen/types/numbers/test_signed_ints.py | 68 ++++-- .../types/numbers/test_unsigned_ints.py | 28 ++- tests/functional/codegen/types/test_bytes.py | 5 +- .../codegen/types/test_dynamic_array.py | 51 ++-- tests/functional/codegen/types/test_flag.py | 34 ++- tests/functional/codegen/types/test_lists.py | 33 ++- tests/functional/codegen/types/test_string.py | 14 +- .../examples/auctions/test_blind_auction.py | 38 ++- .../auctions/test_simple_open_auction.py | 19 +- .../examples/company/test_company.py | 49 ++-- .../crowdfund/test_crowdfund_example.py | 8 +- .../test_on_chain_market_maker.py | 25 +- .../name_registry/test_name_registry.py | 5 +- .../test_safe_remote_purchase.py | 34 ++- .../examples/storage/test_advanced_storage.py | 22 +- .../examples/tokens/test_erc1155.py | 223 +++++++++--------- .../functional/examples/tokens/test_erc20.py | 108 ++++++--- .../functional/examples/tokens/test_erc721.py | 103 ++++---- .../functional/examples/voting/test_ballot.py | 29 ++- .../functional/examples/wallet/test_wallet.py | 22 +- .../ast/nodes/test_evaluate_binop_decimal.py | 10 +- .../unit/ast/nodes/test_evaluate_binop_int.py | 15 +- 63 files changed, 1051 insertions(+), 825 deletions(-) diff --git a/docs/testing-contracts-ethtester.rst b/docs/testing-contracts-ethtester.rst index 992cdc312a..1b7e9e3263 100644 --- a/docs/testing-contracts-ethtester.rst +++ b/docs/testing-contracts-ethtester.rst @@ -55,9 +55,9 @@ To test events and failed transactions we expand our simple storage contract to Next, we take a look at the two fixtures that will allow us to read the event logs and to check for failed transactions. -.. literalinclude:: ../tests/base_conftest.py +.. literalinclude:: ../tests/conftest.py :language: python - :pyobject: assert_tx_failed + :pyobject: tx_failed The fixture to assert failed transactions defaults to check for a ``TransactionFailed`` exception, but can be used to check for different exceptions too, as shown below. Also note that the chain gets reverted to the state before the failed transaction. diff --git a/tests/conftest.py b/tests/conftest.py index 925a025a4a..51b4b4459a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import json import logging +from contextlib import contextmanager from functools import wraps import hypothesis @@ -411,23 +412,6 @@ def assert_compile_failed(function_to_test, exception=Exception): return assert_compile_failed -# TODO this should not be a fixture -@pytest.fixture -def search_for_sublist(): - def search_for_sublist(ir, sublist): - _list = ir.to_list() if hasattr(ir, "to_list") else ir - if _list == sublist: - return True - if isinstance(_list, list): - for i in _list: - ret = search_for_sublist(i, sublist) - if ret is True: - return ret - return False - - return search_for_sublist - - @pytest.fixture def create2_address_of(keccak): def _f(_addr, _salt, _initcode): @@ -484,16 +468,16 @@ def get_logs(tx_hash, c, event_name): return get_logs -# TODO replace me with function like `with anchor_state()` @pytest.fixture(scope="module") -def assert_tx_failed(tester): - def assert_tx_failed(function_to_test, exception=TransactionFailed, exc_text=None): +def tx_failed(tester): + @contextmanager + def fn(exception=TransactionFailed, exc_text=None): snapshot_id = tester.take_snapshot() with pytest.raises(exception) as excinfo: - function_to_test() + yield excinfo tester.revert_to_snapshot(snapshot_id) if exc_text: # TODO test equality assert exc_text in str(excinfo.value), (exc_text, excinfo.value) - return assert_tx_failed + return fn diff --git a/tests/functional/builtins/codegen/test_abi_decode.py b/tests/functional/builtins/codegen/test_abi_decode.py index 242841e1cf..69bfef63ea 100644 --- a/tests/functional/builtins/codegen/test_abi_decode.py +++ b/tests/functional/builtins/codegen/test_abi_decode.py @@ -331,7 +331,7 @@ def abi_decode(x: Bytes[32]) -> uint256: b"\x01" * 96, # Length of byte array is beyond size bound of output type ], ) -def test_clamper(get_contract, assert_tx_failed, input_): +def test_clamper(get_contract, tx_failed, input_): contract = """ @external def abi_decode(x: Bytes[96]) -> (uint256, uint256): @@ -341,10 +341,11 @@ def abi_decode(x: Bytes[96]) -> (uint256, uint256): return a, b """ c = get_contract(contract) - assert_tx_failed(lambda: c.abi_decode(input_)) + with tx_failed(): + c.abi_decode(input_) -def test_clamper_nested_uint8(get_contract, assert_tx_failed): +def test_clamper_nested_uint8(get_contract, tx_failed): # check that _abi_decode clamps on word-types even when it is in a nested expression # decode -> validate uint8 -> revert if input >= 256 -> cast back to uint256 contract = """ @@ -355,10 +356,11 @@ def abi_decode(x: uint256) -> uint256: """ c = get_contract(contract) assert c.abi_decode(255) == 255 - assert_tx_failed(lambda: c.abi_decode(256)) + with tx_failed(): + c.abi_decode(256) -def test_clamper_nested_bytes(get_contract, assert_tx_failed): +def test_clamper_nested_bytes(get_contract, tx_failed): # check that _abi_decode clamps dynamic even when it is in a nested expression # decode -> validate Bytes[20] -> revert if len(input) > 20 -> convert back to -> add 1 contract = """ @@ -369,7 +371,8 @@ def abi_decode(x: Bytes[96]) -> Bytes[21]: """ c = get_contract(contract) assert c.abi_decode(abi.encode("(bytes)", (b"bc",))) == b"abc" - assert_tx_failed(lambda: c.abi_decode(abi.encode("(bytes)", (b"a" * 22,)))) + with tx_failed(): + c.abi_decode(abi.encode("(bytes)", (b"a" * 22,))) @pytest.mark.parametrize( @@ -381,7 +384,7 @@ def abi_decode(x: Bytes[96]) -> Bytes[21]: ("Bytes[5]", b"\x01" * 192), ], ) -def test_clamper_dynamic(get_contract, assert_tx_failed, output_typ, input_): +def test_clamper_dynamic(get_contract, tx_failed, output_typ, input_): contract = f""" @external def abi_decode(x: Bytes[192]) -> {output_typ}: @@ -390,7 +393,8 @@ def abi_decode(x: Bytes[192]) -> {output_typ}: return a """ c = get_contract(contract) - assert_tx_failed(lambda: c.abi_decode(input_)) + with tx_failed(): + c.abi_decode(input_) @pytest.mark.parametrize( @@ -422,7 +426,7 @@ def abi_decode(x: Bytes[160]) -> uint256: ("Bytes[5]", "address", b"\x01" * 128), ], ) -def test_clamper_dynamic_tuple(get_contract, assert_tx_failed, output_typ1, output_typ2, input_): +def test_clamper_dynamic_tuple(get_contract, tx_failed, output_typ1, output_typ2, input_): contract = f""" @external def abi_decode(x: Bytes[224]) -> ({output_typ1}, {output_typ2}): @@ -432,7 +436,8 @@ def abi_decode(x: Bytes[224]) -> ({output_typ1}, {output_typ2}): return a, b """ c = get_contract(contract) - assert_tx_failed(lambda: c.abi_decode(input_)) + with tx_failed(): + c.abi_decode(input_) FAIL_LIST = [ diff --git a/tests/functional/builtins/codegen/test_addmod.py b/tests/functional/builtins/codegen/test_addmod.py index b3135660bb..00745c0cdb 100644 --- a/tests/functional/builtins/codegen/test_addmod.py +++ b/tests/functional/builtins/codegen/test_addmod.py @@ -1,4 +1,4 @@ -def test_uint256_addmod(assert_tx_failed, get_contract_with_gas_estimation): +def test_uint256_addmod(tx_failed, get_contract_with_gas_estimation): uint256_code = """ @external def _uint256_addmod(x: uint256, y: uint256, z: uint256) -> uint256: @@ -11,7 +11,8 @@ def _uint256_addmod(x: uint256, y: uint256, z: uint256) -> uint256: assert c._uint256_addmod(32, 2, 32) == 2 assert c._uint256_addmod((2**256) - 1, 0, 2) == 1 assert c._uint256_addmod(2**255, 2**255, 6) == 4 - assert_tx_failed(lambda: c._uint256_addmod(1, 2, 0)) + with tx_failed(): + c._uint256_addmod(1, 2, 0) def test_uint256_addmod_ext_call( diff --git a/tests/functional/builtins/codegen/test_as_wei_value.py b/tests/functional/builtins/codegen/test_as_wei_value.py index cc27507e7c..522684fa05 100644 --- a/tests/functional/builtins/codegen/test_as_wei_value.py +++ b/tests/functional/builtins/codegen/test_as_wei_value.py @@ -23,7 +23,7 @@ @pytest.mark.parametrize("denom,multiplier", wei_denoms.items()) -def test_wei_uint256(get_contract, assert_tx_failed, denom, multiplier): +def test_wei_uint256(get_contract, tx_failed, denom, multiplier): code = f""" @external def foo(a: uint256) -> uint256: @@ -36,11 +36,12 @@ def foo(a: uint256) -> uint256: assert c.foo(value) == value * (10**multiplier) value = (2**256 - 1) // (10 ** (multiplier - 1)) - assert_tx_failed(lambda: c.foo(value)) + with tx_failed(): + c.foo(value) @pytest.mark.parametrize("denom,multiplier", wei_denoms.items()) -def test_wei_int128(get_contract, assert_tx_failed, denom, multiplier): +def test_wei_int128(get_contract, tx_failed, denom, multiplier): code = f""" @external def foo(a: int128) -> uint256: @@ -54,7 +55,7 @@ def foo(a: int128) -> uint256: @pytest.mark.parametrize("denom,multiplier", wei_denoms.items()) -def test_wei_decimal(get_contract, assert_tx_failed, denom, multiplier): +def test_wei_decimal(get_contract, tx_failed, denom, multiplier): code = f""" @external def foo(a: decimal) -> uint256: @@ -69,7 +70,7 @@ def foo(a: decimal) -> uint256: @pytest.mark.parametrize("value", (-1, -(2**127))) @pytest.mark.parametrize("data_type", ["decimal", "int128"]) -def test_negative_value_reverts(get_contract, assert_tx_failed, value, data_type): +def test_negative_value_reverts(get_contract, tx_failed, value, data_type): code = f""" @external def foo(a: {data_type}) -> uint256: @@ -77,12 +78,13 @@ def foo(a: {data_type}) -> uint256: """ c = get_contract(code) - assert_tx_failed(lambda: c.foo(value)) + with tx_failed(): + c.foo(value) @pytest.mark.parametrize("denom,multiplier", wei_denoms.items()) @pytest.mark.parametrize("data_type", ["decimal", "int128", "uint256"]) -def test_zero_value(get_contract, assert_tx_failed, denom, multiplier, data_type): +def test_zero_value(get_contract, tx_failed, denom, multiplier, data_type): code = f""" @external def foo(a: {data_type}) -> uint256: diff --git a/tests/functional/builtins/codegen/test_convert.py b/tests/functional/builtins/codegen/test_convert.py index 99dae4a932..559e1448ef 100644 --- a/tests/functional/builtins/codegen/test_convert.py +++ b/tests/functional/builtins/codegen/test_convert.py @@ -511,7 +511,7 @@ def bar(a: uint256) -> Roles: @pytest.mark.parametrize("typ", ["uint8", "int128", "int256", "uint256"]) @pytest.mark.parametrize("val", [1, 2, 3, 4, 2**128, 2**256 - 1, 2**256 - 2]) def test_flag_conversion_2( - get_contract_with_gas_estimation, assert_compile_failed, assert_tx_failed, val, typ + get_contract_with_gas_estimation, assert_compile_failed, tx_failed, val, typ ): contract = f""" flag Status: @@ -529,7 +529,8 @@ def foo(a: {typ}) -> Status: if lo <= val <= hi: assert c.foo(val) == val else: - assert_tx_failed(lambda: c.foo(val)) + with tx_failed(): + c.foo(val) else: assert_compile_failed(lambda: get_contract_with_gas_estimation(contract), TypeMismatch) @@ -608,7 +609,7 @@ def foo() -> {t_bytes}: @pytest.mark.parametrize("i_typ,o_typ,val", generate_reverting_cases()) @pytest.mark.fuzzing def test_conversion_failures( - get_contract_with_gas_estimation, assert_compile_failed, assert_tx_failed, i_typ, o_typ, val + get_contract_with_gas_estimation, assert_compile_failed, tx_failed, i_typ, o_typ, val ): """ Test multiple contracts and check for a specific exception. @@ -650,7 +651,8 @@ def foo(): """ c2 = get_contract_with_gas_estimation(contract_2) - assert_tx_failed(lambda: c2.foo()) + with tx_failed(): + c2.foo() contract_3 = f""" @external @@ -659,4 +661,5 @@ def foo(bar: {i_typ}) -> {o_typ}: """ c3 = get_contract_with_gas_estimation(contract_3) - assert_tx_failed(lambda: c3.foo(val)) + with tx_failed(): + c3.foo(val) diff --git a/tests/functional/builtins/codegen/test_create_functions.py b/tests/functional/builtins/codegen/test_create_functions.py index fa7729d98e..afa729ac8a 100644 --- a/tests/functional/builtins/codegen/test_create_functions.py +++ b/tests/functional/builtins/codegen/test_create_functions.py @@ -77,7 +77,7 @@ def test2() -> Bytes[100]: assert c.test2() == b"hello world!" -def test_minimal_proxy_exception(w3, get_contract, assert_tx_failed): +def test_minimal_proxy_exception(w3, get_contract, tx_failed): code = """ interface SubContract: @@ -111,7 +111,8 @@ def test2(a: uint256) -> Bytes[100]: c.test(transact={}) assert c.test2(1) == b"hello world!" - assert_tx_failed(lambda: c.test2(0)) + with tx_failed(): + c.test2(0) GAS_SENT = 30000 tx_hash = c.test2(0, transact={"gas": GAS_SENT}) @@ -122,9 +123,7 @@ def test2(a: uint256) -> Bytes[100]: assert receipt["gasUsed"] < GAS_SENT -def test_create_minimal_proxy_to_create2( - get_contract, create2_address_of, keccak, assert_tx_failed -): +def test_create_minimal_proxy_to_create2(get_contract, create2_address_of, keccak, tx_failed): code = """ main: address @@ -143,20 +142,15 @@ def test(_salt: bytes32) -> address: c.test(salt, transact={}) # revert on collision - assert_tx_failed(lambda: c.test(salt, transact={})) + with tx_failed(): + c.test(salt, transact={}) # test blueprints with various prefixes - 0xfe would block calls to the blueprint # contract, and 0xfe7100 is ERC5202 magic @pytest.mark.parametrize("blueprint_prefix", [b"", b"\xfe", b"\xfe\71\x00"]) def test_create_from_blueprint( - get_contract, - deploy_blueprint_for, - w3, - keccak, - create2_address_of, - assert_tx_failed, - blueprint_prefix, + get_contract, deploy_blueprint_for, w3, keccak, create2_address_of, tx_failed, blueprint_prefix ): code = """ @external @@ -193,7 +187,8 @@ def test2(target: address, salt: bytes32): # extcodesize check zero_address = "0x" + "00" * 20 - assert_tx_failed(lambda: d.test(zero_address)) + with tx_failed(): + d.test(zero_address) # now same thing but with create2 salt = keccak(b"vyper") @@ -209,11 +204,12 @@ def test2(target: address, salt: bytes32): assert HexBytes(test.address) == create2_address_of(d.address, salt, initcode) # can't collide addresses - assert_tx_failed(lambda: d.test2(f.address, salt)) + with tx_failed(): + d.test2(f.address, salt) def test_create_from_blueprint_bad_code_offset( - get_contract, get_contract_from_ir, deploy_blueprint_for, w3, assert_tx_failed + get_contract, get_contract_from_ir, deploy_blueprint_for, w3, tx_failed ): deployer_code = """ BLUEPRINT: immutable(address) @@ -254,15 +250,17 @@ def test(code_ofst: uint256) -> address: d.test(initcode_len - 1) # code_offset=len(blueprint) NOT fine! would EXTCODECOPY empty initcode - assert_tx_failed(lambda: d.test(initcode_len)) + with tx_failed(): + d.test(initcode_len) # code_offset=EIP_170_LIMIT definitely not fine! - assert_tx_failed(lambda: d.test(EIP_170_LIMIT)) + with tx_failed(): + d.test(EIP_170_LIMIT) # test create_from_blueprint with args def test_create_from_blueprint_args( - get_contract, deploy_blueprint_for, w3, keccak, create2_address_of, assert_tx_failed + get_contract, deploy_blueprint_for, w3, keccak, create2_address_of, tx_failed ): code = """ struct Bar: @@ -332,7 +330,8 @@ def should_fail(target: address, arg1: String[129], arg2: Bar): assert test.bar() == BAR # extcodesize check - assert_tx_failed(lambda: d.test("0x" + "00" * 20, FOO, BAR)) + with tx_failed(): + d.test("0x" + "00" * 20, FOO, BAR) # now same thing but with create2 salt = keccak(b"vyper") @@ -359,9 +358,11 @@ def should_fail(target: address, arg1: String[129], arg2: Bar): assert test.bar() == BAR # can't collide addresses - assert_tx_failed(lambda: d.test2(f.address, FOO, BAR, salt)) + with tx_failed(): + d.test2(f.address, FOO, BAR, salt) # ditto - with raw_args - assert_tx_failed(lambda: d.test4(f.address, encoded_args, salt)) + with tx_failed(): + d.test4(f.address, encoded_args, salt) # but creating a contract with different args is ok FOO = "bar" @@ -375,10 +376,11 @@ def should_fail(target: address, arg1: String[129], arg2: Bar): BAR = ("",) sig = keccak("should_fail(address,string,(string))".encode()).hex()[:10] encoded = abi.encode("(address,string,(string))", (f.address, FOO, BAR)).hex() - assert_tx_failed(lambda: w3.eth.send_transaction({"to": d.address, "data": f"{sig}{encoded}"})) + with tx_failed(): + w3.eth.send_transaction({"to": d.address, "data": f"{sig}{encoded}"}) -def test_create_copy_of(get_contract, w3, keccak, create2_address_of, assert_tx_failed): +def test_create_copy_of(get_contract, w3, keccak, create2_address_of, tx_failed): code = """ created_address: public(address) @internal @@ -412,7 +414,8 @@ def test2(target: address, salt: bytes32) -> address: assert w3.eth.get_code(test1) == bytecode # extcodesize check - assert_tx_failed(lambda: c.test("0x" + "00" * 20)) + with tx_failed(): + c.test("0x" + "00" * 20) # test1 = c.test(b"\x01") # assert w3.eth.get_code(test1) == b"\x01" @@ -425,12 +428,14 @@ def test2(target: address, salt: bytes32) -> address: assert HexBytes(test2) == create2_address_of(c.address, salt, vyper_initcode(bytecode)) # can't create2 where contract already exists - assert_tx_failed(lambda: c.test2(c.address, salt, transact={})) + with tx_failed(): + c.test2(c.address, salt, transact={}) # test single byte contract # test2 = c.test2(b"\x01", salt) # assert HexBytes(test2) == create2_address_of(c.address, salt, vyper_initcode(b"\x01")) - # assert_tx_failed(lambda: c.test2(bytecode, salt)) + # with tx_failed(): + # c.test2(bytecode, salt) # XXX: these various tests to check the msize allocator for diff --git a/tests/functional/builtins/codegen/test_extract32.py b/tests/functional/builtins/codegen/test_extract32.py index 6e4ee09abc..a95b57b5ab 100644 --- a/tests/functional/builtins/codegen/test_extract32.py +++ b/tests/functional/builtins/codegen/test_extract32.py @@ -1,4 +1,4 @@ -def test_extract32_extraction(assert_tx_failed, get_contract_with_gas_estimation): +def test_extract32_extraction(tx_failed, get_contract_with_gas_estimation): extract32_code = """ y: Bytes[100] @external @@ -34,18 +34,19 @@ def extrakt32_storage(index: uint256, inp: Bytes[100]) -> bytes32: ) for S, i in test_cases: - expected_result = S[i : i + 32] if 0 <= i <= len(S) - 32 else None - if expected_result is None: - assert_tx_failed(lambda p=(S, i): c.extrakt32(*p)) - else: + if 0 <= i <= len(S) - 32: + expected_result = S[i : i + 32] assert c.extrakt32(S, i) == expected_result assert c.extrakt32_mem(S, i) == expected_result assert c.extrakt32_storage(i, S) == expected_result + else: + with tx_failed(): + c.extrakt32(S, i) print("Passed bytes32 extraction test") -def test_extract32_code(assert_tx_failed, get_contract_with_gas_estimation): +def test_extract32_code(tx_failed, get_contract_with_gas_estimation): extract32_code = """ @external def foo(inp: Bytes[32]) -> int128: @@ -72,7 +73,8 @@ def foq(inp: Bytes[32]) -> address: assert c.foo(b"\x00" * 30 + b"\x01\x01") == 257 assert c.bar(b"\x00" * 30 + b"\x01\x01") == 257 - assert_tx_failed(lambda: c.foo(b"\x80" + b"\x00" * 30)) + with tx_failed(): + c.foo(b"\x80" + b"\x00" * 30) assert c.bar(b"\x80" + b"\x00" * 31) == 2**255 @@ -80,6 +82,7 @@ def foq(inp: Bytes[32]) -> address: assert c.fop(b"crow" * 8) == b"crow" * 8 assert c.foq(b"\x00" * 12 + b"3" * 20) == "0x" + "3" * 40 - assert_tx_failed(lambda: c.foq(b"crow" * 8)) + with tx_failed(): + c.foq(b"crow" * 8) print("Passed extract32 test") diff --git a/tests/functional/builtins/codegen/test_minmax.py b/tests/functional/builtins/codegen/test_minmax.py index da939d605a..f86504522f 100644 --- a/tests/functional/builtins/codegen/test_minmax.py +++ b/tests/functional/builtins/codegen/test_minmax.py @@ -198,7 +198,7 @@ def foo() -> uint256: def test_minmax_var_uint256_negative_int128( - get_contract_with_gas_estimation, assert_tx_failed, assert_compile_failed + get_contract_with_gas_estimation, tx_failed, assert_compile_failed ): from vyper.exceptions import TypeMismatch diff --git a/tests/functional/builtins/codegen/test_mulmod.py b/tests/functional/builtins/codegen/test_mulmod.py index 96477897b9..ba82ebd5b8 100644 --- a/tests/functional/builtins/codegen/test_mulmod.py +++ b/tests/functional/builtins/codegen/test_mulmod.py @@ -1,4 +1,4 @@ -def test_uint256_mulmod(assert_tx_failed, get_contract_with_gas_estimation): +def test_uint256_mulmod(tx_failed, get_contract_with_gas_estimation): uint256_code = """ @external def _uint256_mulmod(x: uint256, y: uint256, z: uint256) -> uint256: @@ -11,7 +11,8 @@ def _uint256_mulmod(x: uint256, y: uint256, z: uint256) -> uint256: assert c._uint256_mulmod(200, 3, 601) == 600 assert c._uint256_mulmod(2**255, 1, 3) == 2 assert c._uint256_mulmod(2**255, 2, 6) == 4 - assert_tx_failed(lambda: c._uint256_mulmod(2, 2, 0)) + with tx_failed(): + c._uint256_mulmod(2, 2, 0) def test_uint256_mulmod_complex(get_contract_with_gas_estimation): diff --git a/tests/functional/builtins/codegen/test_raw_call.py b/tests/functional/builtins/codegen/test_raw_call.py index 5bb23447e4..4d37176cf8 100644 --- a/tests/functional/builtins/codegen/test_raw_call.py +++ b/tests/functional/builtins/codegen/test_raw_call.py @@ -91,7 +91,7 @@ def create_and_return_proxy(inp: address) -> address: # print(f'Gas consumed: {(chain.head_state.receipts[-1].gas_used - chain.head_state.receipts[-2].gas_used - chain.last_tx.intrinsic_gas_used)}') # noqa: E501 -def test_multiple_levels2(assert_tx_failed, get_contract_with_gas_estimation): +def test_multiple_levels2(tx_failed, get_contract_with_gas_estimation): inner_code = """ @external def returnten() -> int128: @@ -114,7 +114,8 @@ def create_and_return_proxy(inp: address) -> address: c2 = get_contract_with_gas_estimation(outer_code) - assert_tx_failed(lambda: c2.create_and_call_returnten(c.address)) + with tx_failed(): + c2.create_and_call_returnten(c.address) print("Passed minimal proxy exception test") @@ -171,7 +172,7 @@ def set(i: int128, owner: address): assert outer_contract.owners(1) == a1 -def test_gas(get_contract, assert_tx_failed): +def test_gas(get_contract, tx_failed): inner_code = """ bar: bytes32 @@ -202,7 +203,8 @@ def foo_call(_addr: address): # manually specifying an insufficient amount should fail outer_contract = get_contract(outer_code.format(", gas=15000")) - assert_tx_failed(lambda: outer_contract.foo_call(inner_contract.address)) + with tx_failed(): + outer_contract.foo_call(inner_contract.address) def test_static_call(get_contract): @@ -323,7 +325,7 @@ def foo(_addr: address) -> bool: assert caller.foo(target.address) is True -def test_static_call_fails_nonpayable(get_contract, assert_tx_failed): +def test_static_call_fails_nonpayable(get_contract, tx_failed): target_source = """ baz: int128 @@ -349,10 +351,11 @@ def foo(_addr: address) -> int128: target = get_contract(target_source) caller = get_contract(caller_source) - assert_tx_failed(lambda: caller.foo(target.address)) + with tx_failed(): + caller.foo(target.address) -def test_checkable_raw_call(get_contract, assert_tx_failed): +def test_checkable_raw_call(get_contract, tx_failed): target_source = """ baz: int128 @external diff --git a/tests/functional/builtins/codegen/test_send.py b/tests/functional/builtins/codegen/test_send.py index 199f708cb4..36f8979556 100644 --- a/tests/functional/builtins/codegen/test_send.py +++ b/tests/functional/builtins/codegen/test_send.py @@ -1,4 +1,4 @@ -def test_send(assert_tx_failed, get_contract): +def test_send(tx_failed, get_contract): send_test = """ @external def foo(): @@ -9,9 +9,11 @@ def fop(): send(msg.sender, 10) """ c = get_contract(send_test, value=10) - assert_tx_failed(lambda: c.foo(transact={})) + with tx_failed(): + c.foo(transact={}) c.fop(transact={}) - assert_tx_failed(lambda: c.fop(transact={})) + with tx_failed(): + c.fop(transact={}) def test_default_gas(get_contract, w3): diff --git a/tests/functional/builtins/codegen/test_slice.py b/tests/functional/builtins/codegen/test_slice.py index 53e092019f..a15a3eeb35 100644 --- a/tests/functional/builtins/codegen/test_slice.py +++ b/tests/functional/builtins/codegen/test_slice.py @@ -41,7 +41,7 @@ def slice_tower_test(inp1: Bytes[50]) -> Bytes[50]: def test_slice_immutable( get_contract, assert_compile_failed, - assert_tx_failed, + tx_failed, opt_level, bytesdata, start, @@ -79,7 +79,8 @@ def _get_contract(): assert_compile_failed(lambda: _get_contract(), ArgumentException) elif start + length > len(bytesdata) or (len(bytesdata) > length_bound): # deploy fail - assert_tx_failed(lambda: _get_contract()) + with tx_failed(): + _get_contract() else: c = _get_contract() assert c.do_splice() == bytesdata[start : start + length] @@ -95,7 +96,7 @@ def _get_contract(): def test_slice_bytes_fuzz( get_contract, assert_compile_failed, - assert_tx_failed, + tx_failed, opt_level, location, bytesdata, @@ -175,10 +176,12 @@ def _get_contract(): assert_compile_failed(lambda: _get_contract(), (ArgumentException, TypeMismatch)) elif location == "code" and len(bytesdata) > length_bound: # deploy fail - assert_tx_failed(lambda: _get_contract()) + with tx_failed(): + _get_contract() elif end > len(bytesdata) or len(bytesdata) > length_bound: c = _get_contract() - assert_tx_failed(lambda: c.do_slice(bytesdata, start, length)) + with tx_failed(): + c.do_slice(bytesdata, start, length) else: c = _get_contract() assert c.do_slice(bytesdata, start, length) == bytesdata[start:end], code diff --git a/tests/functional/builtins/codegen/test_unary.py b/tests/functional/builtins/codegen/test_unary.py index da3823edfe..33f79be233 100644 --- a/tests/functional/builtins/codegen/test_unary.py +++ b/tests/functional/builtins/codegen/test_unary.py @@ -13,14 +13,15 @@ def negate(a: uint256) -> uint256: assert_compile_failed(lambda: get_contract(code), exception=InvalidOperation) -def test_unary_sub_int128_fail(get_contract, assert_tx_failed): +def test_unary_sub_int128_fail(get_contract, tx_failed): code = """@external def negate(a: int128) -> int128: return -(a) """ c = get_contract(code) # This test should revert on overflow condition - assert_tx_failed(lambda: c.negate(-(2**127))) + with tx_failed(): + c.negate(-(2**127)) @pytest.mark.parametrize("val", [-(2**127) + 1, 0, 2**127 - 1]) diff --git a/tests/functional/builtins/folding/test_abs.py b/tests/functional/builtins/folding/test_abs.py index 1c919d7826..a91a4f1ad3 100644 --- a/tests/functional/builtins/folding/test_abs.py +++ b/tests/functional/builtins/folding/test_abs.py @@ -39,7 +39,7 @@ def foo(a: int256) -> int256: get_contract(source) -def test_abs_lower_bound(get_contract, assert_tx_failed): +def test_abs_lower_bound(get_contract, tx_failed): source = """ @external def foo(a: int256) -> int256: @@ -47,10 +47,11 @@ def foo(a: int256) -> int256: """ contract = get_contract(source) - assert_tx_failed(lambda: contract.foo(-(2**255))) + with tx_failed(): + contract.foo(-(2**255)) -def test_abs_lower_bound_folded(get_contract, assert_tx_failed): +def test_abs_lower_bound_folded(get_contract, tx_failed): source = """ @external def foo() -> int256: diff --git a/tests/functional/codegen/calling_convention/test_default_function.py b/tests/functional/codegen/calling_convention/test_default_function.py index f7eef21af7..cf55607877 100644 --- a/tests/functional/codegen/calling_convention/test_default_function.py +++ b/tests/functional/codegen/calling_convention/test_default_function.py @@ -1,4 +1,4 @@ -def test_throw_on_sending(w3, assert_tx_failed, get_contract_with_gas_estimation): +def test_throw_on_sending(w3, tx_failed, get_contract_with_gas_estimation): code = """ x: public(int128) @@ -10,9 +10,8 @@ def __init__(): assert c.x() == 123 assert w3.eth.get_balance(c.address) == 0 - assert_tx_failed( - lambda: w3.eth.send_transaction({"to": c.address, "value": w3.to_wei(0.1, "ether")}) - ) + with tx_failed(): + w3.eth.send_transaction({"to": c.address, "value": w3.to_wei(0.1, "ether")}) assert w3.eth.get_balance(c.address) == 0 @@ -56,7 +55,7 @@ def __default__(): assert w3.eth.get_balance(c.address) == w3.to_wei(0.1, "ether") -def test_basic_default_not_payable(w3, assert_tx_failed, get_contract_with_gas_estimation): +def test_basic_default_not_payable(w3, tx_failed, get_contract_with_gas_estimation): code = """ event Sent: sender: indexed(address) @@ -67,7 +66,8 @@ def __default__(): """ c = get_contract_with_gas_estimation(code) - assert_tx_failed(lambda: w3.eth.send_transaction({"to": c.address, "value": 10**17})) + with tx_failed(): + w3.eth.send_transaction({"to": c.address, "value": 10**17}) def test_multi_arg_default(assert_compile_failed, get_contract_with_gas_estimation): @@ -100,7 +100,7 @@ def __default__(): assert_compile_failed(lambda: get_contract_with_gas_estimation(code)) -def test_zero_method_id(w3, get_logs, get_contract, assert_tx_failed): +def test_zero_method_id(w3, get_logs, get_contract, tx_failed): # test a method with 0x00000000 selector, # expects at least 36 bytes of calldata. code = """ @@ -143,10 +143,11 @@ def _call_with_bytes(hexstr): for i in range(4, 36): # match the full 4 selector bytes, but revert due to malformed (short) calldata - assert_tx_failed(lambda p="0x" + "00" * i: _call_with_bytes(p)) + with tx_failed(): + _call_with_bytes(f"0x{'00' * i}") -def test_another_zero_method_id(w3, get_logs, get_contract, assert_tx_failed): +def test_another_zero_method_id(w3, get_logs, get_contract, tx_failed): # test another zero method id but which only expects 4 bytes of calldata code = """ event Sent: diff --git a/tests/functional/codegen/calling_convention/test_default_parameters.py b/tests/functional/codegen/calling_convention/test_default_parameters.py index a90f5e6624..03f5d9fca2 100644 --- a/tests/functional/codegen/calling_convention/test_default_parameters.py +++ b/tests/functional/codegen/calling_convention/test_default_parameters.py @@ -150,7 +150,7 @@ def foo(a: int128[3] = [1, 2, 3]) -> int128[3]: assert c.foo() == [1, 2, 3] -def test_default_param_clamp(get_contract, monkeypatch, assert_tx_failed): +def test_default_param_clamp(get_contract, monkeypatch, tx_failed): code = """ @external def bar(a: int128, b: int128 = -1) -> (int128, int128): # noqa: E501 @@ -168,7 +168,8 @@ def validate_value(cls, value): monkeypatch.setattr("eth_abi.encoding.NumberEncoder.validate_value", validate_value) assert c.bar(200, 2**127 - 1) == [200, 2**127 - 1] - assert_tx_failed(lambda: c.bar(200, 2**127)) + with tx_failed(): + c.bar(200, 2**127) def test_default_param_private(get_contract): diff --git a/tests/functional/codegen/calling_convention/test_erc20_abi.py b/tests/functional/codegen/calling_convention/test_erc20_abi.py index 4a09ce68fa..b9dc5c663f 100644 --- a/tests/functional/codegen/calling_convention/test_erc20_abi.py +++ b/tests/functional/codegen/calling_convention/test_erc20_abi.py @@ -81,7 +81,7 @@ def test_initial_state(w3, erc20_caller): assert erc20_caller.decimals() == TOKEN_DECIMALS -def test_call_transfer(w3, erc20, erc20_caller, assert_tx_failed): +def test_call_transfer(w3, erc20, erc20_caller, tx_failed): # Basic transfer. erc20.transfer(erc20_caller.address, 10, transact={}) assert erc20.balanceOf(erc20_caller.address) == 10 @@ -90,13 +90,12 @@ def test_call_transfer(w3, erc20, erc20_caller, assert_tx_failed): assert erc20.balanceOf(w3.eth.accounts[1]) == 10 # more than allowed - assert_tx_failed(lambda: erc20_caller.transfer(w3.eth.accounts[1], TOKEN_TOTAL_SUPPLY)) + with tx_failed(): + erc20_caller.transfer(w3.eth.accounts[1], TOKEN_TOTAL_SUPPLY) # Negative transfer value. - assert_tx_failed( - function_to_test=lambda: erc20_caller.transfer(w3.eth.accounts[1], -1), - exception=ValidationError, - ) + with tx_failed(ValidationError): + erc20_caller.transfer(w3.eth.accounts[1], -1) def test_caller_approve_allowance(w3, erc20, erc20_caller): @@ -105,11 +104,10 @@ def test_caller_approve_allowance(w3, erc20, erc20_caller): assert erc20_caller.allowance(w3.eth.accounts[0], erc20_caller.address) == 10 -def test_caller_tranfer_from(w3, erc20, erc20_caller, assert_tx_failed): +def test_caller_tranfer_from(w3, erc20, erc20_caller, tx_failed): # Cannot transfer tokens that are unavailable - assert_tx_failed( - lambda: erc20_caller.transferFrom(w3.eth.accounts[0], erc20_caller.address, 10) - ) + with tx_failed(): + erc20_caller.transferFrom(w3.eth.accounts[0], erc20_caller.address, 10) assert erc20.balanceOf(erc20_caller.address) == 0 assert erc20.approve(erc20_caller.address, 10, transact={}) erc20_caller.transferFrom(w3.eth.accounts[0], erc20_caller.address, 5, transact={}) diff --git a/tests/functional/codegen/calling_convention/test_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_external_contract_calls.py index 12fcde2f4f..0360396f03 100644 --- a/tests/functional/codegen/calling_convention/test_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_external_contract_calls.py @@ -3,6 +3,7 @@ import pytest from eth.codecs import abi +from vyper import compile_code from vyper.exceptions import ( ArgumentException, InvalidType, @@ -94,7 +95,7 @@ def get_array(arg1: address) -> Bytes[3]: assert c2.get_array(c.address) == b"dog" -def test_bytes_too_long(get_contract, assert_tx_failed): +def test_bytes_too_long(get_contract, tx_failed): contract_1 = """ @external def array() -> Bytes[4]: @@ -113,13 +114,14 @@ def get_array(arg1: address) -> Bytes[3]: """ c2 = get_contract(contract_2) - assert_tx_failed(lambda: c2.get_array(c.address)) + with tx_failed(): + c2.get_array(c.address) @pytest.mark.parametrize( "revert_string", ["Mayday, mayday!", "A very long revert string" + "." * 512] ) -def test_revert_propagation(get_contract, assert_tx_failed, revert_string): +def test_revert_propagation(get_contract, tx_failed, revert_string): raiser = f""" @external def run(): @@ -135,7 +137,8 @@ def run(raiser: address): """ c1 = get_contract(raiser) c2 = get_contract(caller) - assert_tx_failed(lambda: c2.run(c1.address), exc_text=revert_string) + with tx_failed(exc_text=revert_string): + c2.run(c1.address) @pytest.mark.parametrize("a,b", [(3, 3), (4, 3), (3, 4), (32, 32), (33, 33), (64, 64)]) @@ -169,7 +172,7 @@ def get_array(arg1: address) -> (Bytes[{a}], int128, Bytes[{b}]): @pytest.mark.parametrize("a,b", [(18, 7), (18, 18), (19, 6), (64, 6), (7, 19)]) @pytest.mark.parametrize("c,d", [(19, 7), (64, 64)]) -def test_tuple_with_bytes_too_long(get_contract, assert_tx_failed, a, c, b, d): +def test_tuple_with_bytes_too_long(get_contract, tx_failed, a, c, b, d): contract_1 = f""" @external def array() -> (Bytes[{c}], int128, Bytes[{d}]): @@ -193,10 +196,11 @@ def get_array(arg1: address) -> (Bytes[{a}], int128, Bytes[{b}]): c2 = get_contract(contract_2) assert c.array() == [b"nineteen characters", 255, b"seven!!"] - assert_tx_failed(lambda: c2.get_array(c.address)) + with tx_failed(): + c2.get_array(c.address) -def test_tuple_with_bytes_too_long_two(get_contract, assert_tx_failed): +def test_tuple_with_bytes_too_long_two(get_contract, tx_failed): contract_1 = """ @external def array() -> (Bytes[30], int128, Bytes[30]): @@ -220,7 +224,8 @@ def get_array(arg1: address) -> (Bytes[30], int128, Bytes[3]): c2 = get_contract(contract_2) assert c.array() == [b"nineteen characters", 255, b"seven!!"] - assert_tx_failed(lambda: c2.get_array(c.address)) + with tx_failed(): + c2.get_array(c.address) @pytest.mark.parametrize("length", [8, 256]) @@ -246,7 +251,7 @@ def bar(arg1: address) -> uint8: assert c2.bar(c.address) == 255 -def test_uint8_too_long(get_contract, assert_tx_failed): +def test_uint8_too_long(get_contract, tx_failed): contract_1 = """ @external def foo() -> uint256: @@ -265,7 +270,8 @@ def bar(arg1: address) -> uint8: """ c2 = get_contract(contract_2) - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("a,b", [(8, 8), (8, 256), (256, 8), (256, 256)]) @@ -298,7 +304,7 @@ def bar(arg1: address) -> (uint{a}, Bytes[3], uint{b}): @pytest.mark.parametrize("a,b", [(8, 256), (256, 8), (256, 256)]) -def test_tuple_with_uint8_too_long(get_contract, assert_tx_failed, a, b): +def test_tuple_with_uint8_too_long(get_contract, tx_failed, a, b): contract_1 = f""" @external def foo() -> (uint{a}, Bytes[3], uint{b}): @@ -322,11 +328,12 @@ def bar(arg1: address) -> (uint8, Bytes[3], uint8): c2 = get_contract(contract_2) assert c.foo() == [int(f"{(2**a)-1}"), b"dog", int(f"{(2**b)-1}")] - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("a,b", [(8, 256), (256, 8)]) -def test_tuple_with_uint8_too_long_two(get_contract, assert_tx_failed, a, b): +def test_tuple_with_uint8_too_long_two(get_contract, tx_failed, a, b): contract_1 = f""" @external def foo() -> (uint{b}, Bytes[3], uint{a}): @@ -350,7 +357,8 @@ def bar(arg1: address) -> (uint{a}, Bytes[3], uint{b}): c2 = get_contract(contract_2) assert c.foo() == [int(f"{(2**b)-1}"), b"dog", int(f"{(2**a)-1}")] - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("length", [128, 256]) @@ -376,7 +384,7 @@ def bar(arg1: address) -> int128: assert c2.bar(c.address) == 1 -def test_int128_too_long(get_contract, assert_tx_failed): +def test_int128_too_long(get_contract, tx_failed): contract_1 = """ @external def foo() -> int256: @@ -395,7 +403,8 @@ def bar(arg1: address) -> int128: """ c2 = get_contract(contract_2) - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("a,b", [(128, 128), (128, 256), (256, 128), (256, 256)]) @@ -428,7 +437,7 @@ def bar(arg1: address) -> (int{a}, Bytes[3], int{b}): @pytest.mark.parametrize("a,b", [(128, 256), (256, 128), (256, 256)]) -def test_tuple_with_int128_too_long(get_contract, assert_tx_failed, a, b): +def test_tuple_with_int128_too_long(get_contract, tx_failed, a, b): contract_1 = f""" @external def foo() -> (int{a}, Bytes[3], int{b}): @@ -452,11 +461,12 @@ def bar(arg1: address) -> (int128, Bytes[3], int128): c2 = get_contract(contract_2) assert c.foo() == [int(f"{(2**(a-1))-1}"), b"dog", int(f"{(2**(b-1))-1}")] - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("a,b", [(128, 256), (256, 128)]) -def test_tuple_with_int128_too_long_two(get_contract, assert_tx_failed, a, b): +def test_tuple_with_int128_too_long_two(get_contract, tx_failed, a, b): contract_1 = f""" @external def foo() -> (int{b}, Bytes[3], int{a}): @@ -480,7 +490,8 @@ def bar(arg1: address) -> (int{a}, Bytes[3], int{b}): c2 = get_contract(contract_2) assert c.foo() == [int(f"{(2**(b-1))-1}"), b"dog", int(f"{(2**(a-1))-1}")] - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("type", ["uint8", "uint256", "int128", "int256"]) @@ -506,7 +517,7 @@ def bar(arg1: address) -> decimal: assert c2.bar(c.address) == Decimal("1e-10") -def test_decimal_too_long(get_contract, assert_tx_failed): +def test_decimal_too_long(get_contract, tx_failed): contract_1 = """ @external def foo() -> uint256: @@ -525,7 +536,8 @@ def bar(arg1: address) -> decimal: """ c2 = get_contract(contract_2) - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("a", ["uint8", "uint256", "int128", "int256"]) @@ -559,7 +571,7 @@ def bar(arg1: address) -> (decimal, Bytes[3], decimal): @pytest.mark.parametrize("a,b", [(8, 256), (256, 8), (256, 256)]) -def test_tuple_with_decimal_too_long(get_contract, assert_tx_failed, a, b): +def test_tuple_with_decimal_too_long(get_contract, tx_failed, a, b): contract_1 = f""" @external def foo() -> (uint{a}, Bytes[3], uint{b}): @@ -583,7 +595,8 @@ def bar(arg1: address) -> (decimal, Bytes[3], decimal): c2 = get_contract(contract_2) assert c.foo() == [2 ** (a - 1), b"dog", 2 ** (b - 1)] - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("type", ["uint8", "uint256", "int128", "int256"]) @@ -609,7 +622,7 @@ def bar(arg1: address) -> bool: assert c2.bar(c.address) is True -def test_bool_too_long(get_contract, assert_tx_failed): +def test_bool_too_long(get_contract, tx_failed): contract_1 = """ @external def foo() -> uint256: @@ -628,7 +641,8 @@ def bar(arg1: address) -> bool: """ c2 = get_contract(contract_2) - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("a", ["uint8", "uint256", "int128", "int256"]) @@ -662,7 +676,7 @@ def bar(arg1: address) -> (bool, Bytes[3], bool): @pytest.mark.parametrize("a", ["uint8", "uint256", "int128", "int256"]) @pytest.mark.parametrize("b", ["uint8", "uint256", "int128", "int256"]) -def test_tuple_with_bool_too_long(get_contract, assert_tx_failed, a, b): +def test_tuple_with_bool_too_long(get_contract, tx_failed, a, b): contract_1 = f""" @external def foo() -> ({a}, Bytes[3], {b}): @@ -686,7 +700,8 @@ def bar(arg1: address) -> (bool, Bytes[3], bool): c2 = get_contract(contract_2) assert c.foo() == [1, b"dog", 2] - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("type", ["uint8", "int128", "uint256", "int256"]) @@ -736,7 +751,7 @@ def bar(arg1: address) -> address: @pytest.mark.parametrize("type", ["uint256", "int256"]) -def test_address_too_long(get_contract, assert_tx_failed, type): +def test_address_too_long(get_contract, tx_failed, type): contract_1 = f""" @external def foo() -> {type}: @@ -755,7 +770,8 @@ def bar(arg1: address) -> address: """ c2 = get_contract(contract_2) - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) @pytest.mark.parametrize("a", ["uint8", "int128", "uint256", "int256"]) @@ -826,7 +842,7 @@ def bar(arg1: address) -> (address, Bytes[3], address): @pytest.mark.parametrize("a", ["uint256", "int256"]) @pytest.mark.parametrize("b", ["uint256", "int256"]) -def test_tuple_with_address_too_long(get_contract, assert_tx_failed, a, b): +def test_tuple_with_address_too_long(get_contract, tx_failed, a, b): contract_1 = f""" @external def foo() -> ({a}, Bytes[3], {b}): @@ -850,7 +866,8 @@ def bar(arg1: address) -> (address, Bytes[3], address): c2 = get_contract(contract_2) assert c.foo() == [(2**160) - 1, b"dog", 2**160] - assert_tx_failed(lambda: c2.bar(c.address)) + with tx_failed(): + c2.bar(c.address) def test_external_contract_call_state_change(get_contract): @@ -1095,7 +1112,7 @@ def _expr(x: address) -> int128: assert c2._expr(c2.address) == 1 -def test_invalid_nonexistent_contract_call(w3, assert_tx_failed, get_contract): +def test_invalid_nonexistent_contract_call(w3, tx_failed, get_contract): contract_1 = """ @external def bar() -> int128: @@ -1115,11 +1132,13 @@ def foo(x: address) -> int128: c2 = get_contract(contract_2) assert c2.foo(c1.address) == 1 - assert_tx_failed(lambda: c2.foo(w3.eth.accounts[0])) - assert_tx_failed(lambda: c2.foo(w3.eth.accounts[3])) + with tx_failed(): + c2.foo(w3.eth.accounts[0]) + with tx_failed(): + c2.foo(w3.eth.accounts[3]) -def test_invalid_contract_reference_declaration(assert_tx_failed, get_contract): +def test_invalid_contract_reference_declaration(tx_failed, get_contract): contract = """ interface Bar: get_magic_number: 1 @@ -1130,19 +1149,21 @@ def test_invalid_contract_reference_declaration(assert_tx_failed, get_contract): def __init__(): pass """ - assert_tx_failed(lambda: get_contract(contract), exception=StructureException) + with tx_failed(exception=StructureException): + get_contract(contract) -def test_invalid_contract_reference_call(assert_tx_failed, get_contract): +def test_invalid_contract_reference_call(tx_failed, get_contract): contract = """ @external def bar(arg1: address, arg2: int128) -> int128: return Foo(arg1).foo(arg2) """ - assert_tx_failed(lambda: get_contract(contract), exception=UndeclaredDefinition) + with pytest.raises(UndeclaredDefinition): + compile_code(contract) -def test_invalid_contract_reference_return_type(assert_tx_failed, get_contract): +def test_invalid_contract_reference_return_type(tx_failed, get_contract): contract = """ interface Foo: def foo(arg2: int128) -> invalid: view @@ -1151,7 +1172,8 @@ def foo(arg2: int128) -> invalid: view def bar(arg1: address, arg2: int128) -> int128: return Foo(arg1).foo(arg2) """ - assert_tx_failed(lambda: get_contract(contract), exception=UnknownType) + with pytest.raises(UnknownType): + compile_code(contract) def test_external_contract_call_declaration_expr(get_contract): @@ -1378,7 +1400,7 @@ def get_lucky(amount_to_send: uint256) -> int128: assert w3.eth.get_balance(c2.address) == 250 -def test_external_call_with_gas(assert_tx_failed, get_contract_with_gas_estimation): +def test_external_call_with_gas(tx_failed, get_contract_with_gas_estimation): contract_1 = """ @external def get_lucky() -> int128: @@ -1406,7 +1428,8 @@ def get_lucky(gas_amount: uint256) -> int128: c2.set_contract(c1.address, transact={}) assert c2.get_lucky(1000) == 656598 - assert_tx_failed(lambda: c2.get_lucky(50)) # too little gas. + with tx_failed(): + c2.get_lucky(50) # too little gas. def test_skip_contract_check(get_contract_with_gas_estimation): @@ -2240,7 +2263,7 @@ def get_array(arg1: address) -> int128[3]: assert c2.get_array(c.address) == [0, 0, 0] -def test_returndatasize_too_short(get_contract, assert_tx_failed): +def test_returndatasize_too_short(get_contract, tx_failed): contract_1 = """ @external def bar(a: int128) -> int128: @@ -2256,10 +2279,11 @@ def foo(_addr: address): """ c1 = get_contract(contract_1) c2 = get_contract(contract_2) - assert_tx_failed(lambda: c2.foo(c1.address)) + with tx_failed(): + c2.foo(c1.address) -def test_returndatasize_empty(get_contract, assert_tx_failed): +def test_returndatasize_empty(get_contract, tx_failed): contract_1 = """ @external def bar(a: int128): @@ -2275,7 +2299,8 @@ def foo(_addr: address) -> int128: """ c1 = get_contract(contract_1) c2 = get_contract(contract_2) - assert_tx_failed(lambda: c2.foo(c1.address)) + with tx_failed(): + c2.foo(c1.address) def test_returndatasize_too_long(get_contract): @@ -2299,7 +2324,7 @@ def foo(_addr: address) -> int128: assert c2.foo(c1.address) == 456 -def test_no_returndata(get_contract, assert_tx_failed): +def test_no_returndata(get_contract, tx_failed): contract_1 = """ @external def bar(a: int128) -> int128: @@ -2321,10 +2346,11 @@ def foo(_addr: address, _addr2: address) -> int128: c2 = get_contract(contract_2) assert c2.foo(c1.address, c1.address) == 123 - assert_tx_failed(lambda: c2.foo(c1.address, "0x1234567890123456789012345678901234567890")) + with tx_failed(): + c2.foo(c1.address, "0x1234567890123456789012345678901234567890") -def test_default_override(get_contract, assert_tx_failed): +def test_default_override(get_contract, tx_failed): bad_erc20_code = """ @external def transfer(receiver: address, amount: uint256): @@ -2358,17 +2384,20 @@ def transferBorked(erc20: ERC20, receiver: address, amount: uint256): c = get_contract(code) # demonstrate transfer failing - assert_tx_failed(lambda: c.transferBorked(bad_erc20.address, c.address, 0)) + with tx_failed(): + c.transferBorked(bad_erc20.address, c.address, 0) # would fail without default_return_value assert c.safeTransfer(bad_erc20.address, c.address, 0) == 7 # check that `default_return_value` does not stomp valid returndata. negative_contract = get_contract(negative_transfer_code) - assert_tx_failed(lambda: c.safeTransfer(negative_contract.address, c.address, 0)) + with tx_failed(): + c.safeTransfer(negative_contract.address, c.address, 0) # default_return_value should fail on EOAs (addresses with no code) random_address = "0x0000000000000000000000000000000000001234" - assert_tx_failed(lambda: c.safeTransfer(random_address, c.address, 1)) + with tx_failed(): + c.safeTransfer(random_address, c.address, 1) # in this case, the extcodesize check runs after the token contract # selfdestructs. however, extcodesize still returns nonzero until @@ -2378,7 +2407,7 @@ def transferBorked(erc20: ERC20, receiver: address, amount: uint256): assert c.safeTransfer(self_destructing_contract.address, c.address, 0) == 7 -def test_default_override2(get_contract, assert_tx_failed): +def test_default_override2(get_contract, tx_failed): bad_code_1 = """ @external def return_64_bytes() -> bool: @@ -2407,7 +2436,8 @@ def bar(foo: Foo): c = get_contract(code) # fails due to returndatasize being nonzero but also lt 64 - assert_tx_failed(lambda: c.bar(bad_1.address)) + with tx_failed(): + c.bar(bad_1.address) c.bar(bad_2.address) @@ -2456,7 +2486,7 @@ def do_stuff(f: Foo) -> uint256: @pytest.mark.parametrize("typ,val", [("address", TEST_ADDR)]) -def test_calldata_clamp(w3, get_contract, assert_tx_failed, keccak, typ, val): +def test_calldata_clamp(w3, get_contract, tx_failed, keccak, typ, val): code = f""" @external def foo(a: {typ}): @@ -2469,7 +2499,8 @@ def foo(a: {typ}): # Static size is short by 1 byte malformed = data[:-2] - assert_tx_failed(lambda: w3.eth.send_transaction({"to": c1.address, "data": malformed})) + with tx_failed(): + w3.eth.send_transaction({"to": c1.address, "data": malformed}) # Static size is exact w3.eth.send_transaction({"to": c1.address, "data": data}) @@ -2479,7 +2510,7 @@ def foo(a: {typ}): @pytest.mark.parametrize("typ,val", [("address", ([TEST_ADDR] * 3, "vyper"))]) -def test_dynamic_calldata_clamp(w3, get_contract, assert_tx_failed, keccak, typ, val): +def test_dynamic_calldata_clamp(w3, get_contract, tx_failed, keccak, typ, val): code = f""" @external def foo(a: DynArray[{typ}, 3], b: String[5]): @@ -2493,7 +2524,8 @@ def foo(a: DynArray[{typ}, 3], b: String[5]): # Dynamic size is short by 1 byte malformed = data[:264] - assert_tx_failed(lambda: w3.eth.send_transaction({"to": c1.address, "data": malformed})) + with tx_failed(): + w3.eth.send_transaction({"to": c1.address, "data": malformed}) # Dynamic size is at least minimum (132 bytes * 2 + 2 (for 0x) = 266) valid = data[:266] diff --git a/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py b/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py index 4c321442f4..e6b2402016 100644 --- a/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py +++ b/tests/functional/codegen/calling_convention/test_modifiable_external_contract_calls.py @@ -1,7 +1,7 @@ from vyper.exceptions import StructureException, SyntaxException, UnknownType -def test_external_contract_call_declaration_expr(get_contract, assert_tx_failed): +def test_external_contract_call_declaration_expr(get_contract, tx_failed): contract_1 = """ lucky: public(int128) @@ -39,11 +39,12 @@ def static_set_lucky(_lucky: int128): c2.modifiable_set_lucky(7, transact={}) assert c1.lucky() == 7 # Fails attempting a state change after a call to a static address - assert_tx_failed(lambda: c2.static_set_lucky(5, transact={})) + with tx_failed(): + c2.static_set_lucky(5, transact={}) assert c1.lucky() == 7 -def test_external_contract_call_declaration_stmt(get_contract, assert_tx_failed): +def test_external_contract_call_declaration_stmt(get_contract, tx_failed): contract_1 = """ lucky: public(int128) @@ -83,11 +84,12 @@ def static_set_lucky(_lucky: int128): c2.modifiable_set_lucky(7, transact={}) assert c1.lucky() == 7 # Fails attempting a state change after a call to a static address - assert_tx_failed(lambda: c2.static_set_lucky(5, transact={})) + with tx_failed(): + c2.static_set_lucky(5, transact={}) assert c1.lucky() == 7 -def test_multiple_contract_state_changes(get_contract, assert_tx_failed): +def test_multiple_contract_state_changes(get_contract, tx_failed): contract_1 = """ lucky: public(int128) @@ -161,9 +163,12 @@ def static_modifiable_set_lucky(_lucky: int128): assert c1.lucky() == 0 c3.modifiable_modifiable_set_lucky(7, transact={}) assert c1.lucky() == 7 - assert_tx_failed(lambda: c3.modifiable_static_set_lucky(6, transact={})) - assert_tx_failed(lambda: c3.static_modifiable_set_lucky(6, transact={})) - assert_tx_failed(lambda: c3.static_static_set_lucky(6, transact={})) + with tx_failed(): + c3.modifiable_static_set_lucky(6, transact={}) + with tx_failed(): + c3.static_modifiable_set_lucky(6, transact={}) + with tx_failed(): + c3.static_static_set_lucky(6, transact={}) assert c1.lucky() == 7 diff --git a/tests/functional/codegen/calling_convention/test_return_tuple.py b/tests/functional/codegen/calling_convention/test_return_tuple.py index b375839147..266555ead6 100644 --- a/tests/functional/codegen/calling_convention/test_return_tuple.py +++ b/tests/functional/codegen/calling_convention/test_return_tuple.py @@ -1,5 +1,6 @@ import pytest +from vyper import compile_code from vyper.exceptions import TypeMismatch pytestmark = pytest.mark.usefixtures("memory_mocker") @@ -152,11 +153,11 @@ def test3() -> (address, int128): assert c.test3() == [c.out_literals()[2], 1] -def test_tuple_return_typecheck(assert_tx_failed, get_contract_with_gas_estimation): +def test_tuple_return_typecheck(tx_failed, get_contract_with_gas_estimation): code = """ @external def getTimeAndBalance() -> (bool, address): return block.timestamp, self.balance """ - - assert_tx_failed(lambda: get_contract_with_gas_estimation(code), TypeMismatch) + with pytest.raises(TypeMismatch): + compile_code(code) diff --git a/tests/functional/codegen/environment_variables/test_blockhash.py b/tests/functional/codegen/environment_variables/test_blockhash.py index b92c17a561..68db053b12 100644 --- a/tests/functional/codegen/environment_variables/test_blockhash.py +++ b/tests/functional/codegen/environment_variables/test_blockhash.py @@ -23,7 +23,7 @@ def foo() -> bytes32: assert_compile_failed(lambda: get_contract_with_gas_estimation(code)) -def test_too_old_blockhash(assert_tx_failed, get_contract_with_gas_estimation, w3): +def test_too_old_blockhash(tx_failed, get_contract_with_gas_estimation, w3): w3.testing.mine(257) code = """ @external @@ -31,14 +31,16 @@ def get_50_blockhash() -> bytes32: return blockhash(block.number - 257) """ c = get_contract_with_gas_estimation(code) - assert_tx_failed(lambda: c.get_50_blockhash()) + with tx_failed(): + c.get_50_blockhash() -def test_non_existing_blockhash(assert_tx_failed, get_contract_with_gas_estimation): +def test_non_existing_blockhash(tx_failed, get_contract_with_gas_estimation): code = """ @external def get_future_blockhash() -> bytes32: return blockhash(block.number + 1) """ c = get_contract_with_gas_estimation(code) - assert_tx_failed(lambda: c.get_future_blockhash()) + with tx_failed(): + c.get_future_blockhash() diff --git a/tests/functional/codegen/features/decorators/test_nonreentrant.py b/tests/functional/codegen/features/decorators/test_nonreentrant.py index 9e74019250..9329605678 100644 --- a/tests/functional/codegen/features/decorators/test_nonreentrant.py +++ b/tests/functional/codegen/features/decorators/test_nonreentrant.py @@ -5,7 +5,7 @@ # TODO test functions in this module across all evm versions # once we have cancun support. -def test_nonreentrant_decorator(get_contract, assert_tx_failed): +def test_nonreentrant_decorator(get_contract, tx_failed): calling_contract_code = """ interface SpecialContract: def unprotected_function(val: String[100], do_callback: bool): nonpayable @@ -98,20 +98,23 @@ def unprotected_function(val: String[100], do_callback: bool): assert reentrant_contract.special_value() == "some value" assert reentrant_contract.protected_view_fn() == "some value" - assert_tx_failed(lambda: reentrant_contract.protected_function("zzz value", True, transact={})) + with tx_failed(): + reentrant_contract.protected_function("zzz value", True, transact={}) reentrant_contract.protected_function2("another value", False, transact={}) assert reentrant_contract.special_value() == "another value" - assert_tx_failed(lambda: reentrant_contract.protected_function2("zzz value", True, transact={})) + with tx_failed(): + reentrant_contract.protected_function2("zzz value", True, transact={}) reentrant_contract.protected_function3("another value", False, transact={}) assert reentrant_contract.special_value() == "another value" - assert_tx_failed(lambda: reentrant_contract.protected_function3("zzz value", True, transact={})) + with tx_failed(): + reentrant_contract.protected_function3("zzz value", True, transact={}) -def test_nonreentrant_decorator_for_default(w3, get_contract, assert_tx_failed): +def test_nonreentrant_decorator_for_default(w3, get_contract, tx_failed): calling_contract_code = """ @external def send_funds(_amount: uint256): @@ -196,9 +199,8 @@ def __default__(): assert w3.eth.get_balance(calling_contract.address) == 2000 # Test protected function with callback to default. - assert_tx_failed( - lambda: reentrant_contract.protected_function("zzz value", True, transact={"value": 1000}) - ) + with tx_failed(): + reentrant_contract.protected_function("zzz value", True, transact={"value": 1000}) def test_disallow_on_init_function(get_contract): diff --git a/tests/functional/codegen/features/decorators/test_payable.py b/tests/functional/codegen/features/decorators/test_payable.py index 4858a7df0d..ced58e1af0 100644 --- a/tests/functional/codegen/features/decorators/test_payable.py +++ b/tests/functional/codegen/features/decorators/test_payable.py @@ -177,14 +177,13 @@ def baz() -> bool: @pytest.mark.parametrize("code", nonpayable_code) -def test_nonpayable_runtime_assertion(w3, keccak, assert_tx_failed, get_contract, code): +def test_nonpayable_runtime_assertion(w3, keccak, tx_failed, get_contract, code): c = get_contract(code) c.foo(transact={"value": 0}) sig = keccak("foo()".encode()).hex()[:10] - assert_tx_failed( - lambda: w3.eth.send_transaction({"to": c.address, "data": sig, "value": 10**18}) - ) + with tx_failed(): + w3.eth.send_transaction({"to": c.address, "data": sig, "value": 10**18}) payable_code = [ @@ -355,7 +354,7 @@ def __default__(): w3.eth.send_transaction({"to": c.address, "value": 100, "data": "0x12345678"}) -def test_nonpayable_default_func_invalid_calldata(get_contract, w3, assert_tx_failed): +def test_nonpayable_default_func_invalid_calldata(get_contract, w3, tx_failed): code = """ @external @payable @@ -369,12 +368,11 @@ def __default__(): c = get_contract(code) w3.eth.send_transaction({"to": c.address, "value": 0, "data": "0x12345678"}) - assert_tx_failed( - lambda: w3.eth.send_transaction({"to": c.address, "value": 100, "data": "0x12345678"}) - ) + with tx_failed(): + w3.eth.send_transaction({"to": c.address, "value": 100, "data": "0x12345678"}) -def test_batch_nonpayable(get_contract, w3, assert_tx_failed): +def test_batch_nonpayable(get_contract, w3, tx_failed): code = """ @external def foo() -> bool: @@ -390,8 +388,5 @@ def __default__(): data = bytes([1, 2, 3, 4]) for i in range(5): calldata = "0x" + data[:i].hex() - assert_tx_failed( - lambda data=calldata: w3.eth.send_transaction( - {"to": c.address, "value": 100, "data": data} - ) - ) + with tx_failed(): + w3.eth.send_transaction({"to": c.address, "value": 100, "data": calldata}) diff --git a/tests/functional/codegen/features/decorators/test_private.py b/tests/functional/codegen/features/decorators/test_private.py index 51e6d90ee1..39ea1bb9ae 100644 --- a/tests/functional/codegen/features/decorators/test_private.py +++ b/tests/functional/codegen/features/decorators/test_private.py @@ -449,7 +449,7 @@ def whoami() -> address: assert logged_addr == addr, "oh no" -def test_nested_static_params_only(get_contract, assert_tx_failed): +def test_nested_static_params_only(get_contract, tx_failed): code1 = """ @internal @view diff --git a/tests/functional/codegen/features/iteration/test_for_range.py b/tests/functional/codegen/features/iteration/test_for_range.py index ed6235d992..96b83ae691 100644 --- a/tests/functional/codegen/features/iteration/test_for_range.py +++ b/tests/functional/codegen/features/iteration/test_for_range.py @@ -14,7 +14,7 @@ def repeat(z: int128) -> int128: assert c.repeat(9) == 54 -def test_range_bound(get_contract, assert_tx_failed): +def test_range_bound(get_contract, tx_failed): code = """ @external def repeat(n: uint256) -> uint256: @@ -28,7 +28,8 @@ def repeat(n: uint256) -> uint256: assert c.repeat(n) == sum(i + 1 for i in range(n)) # check codegen inserts assertion for n greater than bound - assert_tx_failed(lambda: c.repeat(7)) + with tx_failed(): + c.repeat(7) def test_digit_reverser(get_contract_with_gas_estimation): @@ -172,7 +173,7 @@ def test(): @pytest.mark.parametrize("typ", ["uint8", "int128", "uint256"]) -def test_for_range_oob_check(get_contract, assert_tx_failed, typ): +def test_for_range_oob_check(get_contract, tx_failed, typ): code = f""" @external def test(): @@ -181,7 +182,8 @@ def test(): pass """ c = get_contract(code) - assert_tx_failed(lambda: c.test()) + with tx_failed(): + c.test() @pytest.mark.parametrize("typ", ["int128", "uint256"]) diff --git a/tests/functional/codegen/features/iteration/test_range_in.py b/tests/functional/codegen/features/iteration/test_range_in.py index 062cd389a0..7540049778 100644 --- a/tests/functional/codegen/features/iteration/test_range_in.py +++ b/tests/functional/codegen/features/iteration/test_range_in.py @@ -110,7 +110,7 @@ def testin() -> bool: assert_compile_failed(lambda: get_contract_with_gas_estimation(code), TypeMismatch) -def test_ownership(w3, assert_tx_failed, get_contract_with_gas_estimation): +def test_ownership(w3, tx_failed, get_contract_with_gas_estimation): code = """ owners: address[2] @@ -135,7 +135,8 @@ def is_owner() -> bool: assert c.is_owner(call={"from": a1}) is False # no one else is. # only an owner may set another owner. - assert_tx_failed(lambda: c.set_owner(1, a1, call={"from": a1})) + with tx_failed(): + c.set_owner(1, a1, call={"from": a1}) c.set_owner(1, a1, transact={}) assert c.is_owner(call={"from": a1}) is True @@ -145,7 +146,7 @@ def is_owner() -> bool: assert c.is_owner() is False -def test_in_fails_when_types_dont_match(get_contract_with_gas_estimation, assert_tx_failed): +def test_in_fails_when_types_dont_match(get_contract_with_gas_estimation, tx_failed): code = """ @external def testin(x: address) -> bool: @@ -154,4 +155,5 @@ def testin(x: address) -> bool: return True return False """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(code), TypeMismatch) + with tx_failed(TypeMismatch): + get_contract_with_gas_estimation(code) diff --git a/tests/functional/codegen/features/test_assert.py b/tests/functional/codegen/features/test_assert.py index 842b32d815..af189e6dca 100644 --- a/tests/functional/codegen/features/test_assert.py +++ b/tests/functional/codegen/features/test_assert.py @@ -3,12 +3,12 @@ # web3 returns f"execution reverted: {err_str}" -# TODO move exception string parsing logic into assert_tx_failed +# TODO move exception string parsing logic into tx_failed def _fixup_err_str(s): return s.replace("execution reverted: ", "") -def test_assert_refund(w3, get_contract_with_gas_estimation, assert_tx_failed): +def test_assert_refund(w3, get_contract_with_gas_estimation, tx_failed): code = """ @external def foo(): @@ -26,7 +26,7 @@ def foo(): assert tx_receipt["gasUsed"] < gas_sent -def test_assert_reason(w3, get_contract_with_gas_estimation, assert_tx_failed, memory_mocker): +def test_assert_reason(w3, get_contract_with_gas_estimation, tx_failed, memory_mocker): code = """ @external def test(a: int128) -> int128: @@ -132,7 +132,7 @@ def test_valid_assertions(get_contract, code): get_contract(code) -def test_assert_staticcall(get_contract, assert_tx_failed, memory_mocker): +def test_assert_staticcall(get_contract, tx_failed, memory_mocker): foreign_code = """ state: uint256 @external @@ -151,10 +151,11 @@ def test(): c1 = get_contract(foreign_code) c2 = get_contract(code, *[c1.address]) # static call prohibits state change - assert_tx_failed(lambda: c2.test()) + with tx_failed(): + c2.test() -def test_assert_in_for_loop(get_contract, assert_tx_failed, memory_mocker): +def test_assert_in_for_loop(get_contract, tx_failed, memory_mocker): code = """ @external def test(x: uint256[3]) -> bool: @@ -166,12 +167,15 @@ def test(x: uint256[3]) -> bool: c = get_contract(code) c.test([1, 2, 3]) - assert_tx_failed(lambda: c.test([5, 1, 3])) - assert_tx_failed(lambda: c.test([1, 5, 3])) - assert_tx_failed(lambda: c.test([1, 3, 5])) + with tx_failed(): + c.test([5, 1, 3]) + with tx_failed(): + c.test([1, 5, 3]) + with tx_failed(): + c.test([1, 3, 5]) -def test_assert_with_reason_in_for_loop(get_contract, assert_tx_failed, memory_mocker): +def test_assert_with_reason_in_for_loop(get_contract, tx_failed, memory_mocker): code = """ @external def test(x: uint256[3]) -> bool: @@ -183,12 +187,15 @@ def test(x: uint256[3]) -> bool: c = get_contract(code) c.test([1, 2, 3]) - assert_tx_failed(lambda: c.test([5, 1, 3])) - assert_tx_failed(lambda: c.test([1, 5, 3])) - assert_tx_failed(lambda: c.test([1, 3, 5])) + with tx_failed(): + c.test([5, 1, 3]) + with tx_failed(): + c.test([1, 5, 3]) + with tx_failed(): + c.test([1, 3, 5]) -def test_assert_reason_revert_length(w3, get_contract, assert_tx_failed, memory_mocker): +def test_assert_reason_revert_length(w3, get_contract, tx_failed, memory_mocker): code = """ @external def test() -> int128: @@ -196,4 +203,5 @@ def test() -> int128: return 1 """ c = get_contract(code) - assert_tx_failed(lambda: c.test(), exc_text="oops") + with tx_failed(exc_text="oops"): + c.test() diff --git a/tests/functional/codegen/features/test_assert_unreachable.py b/tests/functional/codegen/features/test_assert_unreachable.py index 90ed31a22e..4db00bce7c 100644 --- a/tests/functional/codegen/features/test_assert_unreachable.py +++ b/tests/functional/codegen/features/test_assert_unreachable.py @@ -15,7 +15,7 @@ def foo(): assert tx_receipt["gasUsed"] == gas_sent # Drains all gains sent -def test_basic_unreachable(w3, get_contract, assert_tx_failed): +def test_basic_unreachable(w3, get_contract, tx_failed): code = """ @external def foo(val: int128) -> bool: @@ -28,12 +28,15 @@ def foo(val: int128) -> bool: assert c.foo(2) is True - assert_tx_failed(lambda: c.foo(1), exc_text="Invalid opcode 0xfe") - assert_tx_failed(lambda: c.foo(-1), exc_text="Invalid opcode 0xfe") - assert_tx_failed(lambda: c.foo(-2), exc_text="Invalid opcode 0xfe") + with tx_failed(exc_text="Invalid opcode 0xfe"): + c.foo(1) + with tx_failed(exc_text="Invalid opcode 0xfe"): + c.foo(-1) + with tx_failed(exc_text="Invalid opcode 0xfe"): + c.foo(-2) -def test_basic_call_unreachable(w3, get_contract, assert_tx_failed): +def test_basic_call_unreachable(w3, get_contract, tx_failed): code = """ @view @@ -51,11 +54,13 @@ def foo(val: int128) -> int128: assert c.foo(33) == -123 - assert_tx_failed(lambda: c.foo(1), exc_text="Invalid opcode 0xfe") - assert_tx_failed(lambda: c.foo(-1), exc_text="Invalid opcode 0xfe") + with tx_failed(exc_text="Invalid opcode 0xfe"): + c.foo(1) + with tx_failed(exc_text="Invalid opcode 0xfe"): + c.foo(-1) -def test_raise_unreachable(w3, get_contract, assert_tx_failed): +def test_raise_unreachable(w3, get_contract, tx_failed): code = """ @external def foo(): @@ -64,4 +69,5 @@ def foo(): c = get_contract(code) - assert_tx_failed(lambda: c.foo(), exc_text="Invalid opcode 0xfe") + with tx_failed(exc_text="Invalid opcode 0xfe"): + c.foo() diff --git a/tests/functional/codegen/features/test_clampers.py b/tests/functional/codegen/features/test_clampers.py index 263f10a89c..6db8570fc7 100644 --- a/tests/functional/codegen/features/test_clampers.py +++ b/tests/functional/codegen/features/test_clampers.py @@ -33,7 +33,7 @@ def _make_invalid_dynarray_tx(w3, address, signature, data): w3.eth.send_transaction({"to": address, "data": f"0x{sig}{data}"}) -def test_bytes_clamper(assert_tx_failed, get_contract_with_gas_estimation): +def test_bytes_clamper(tx_failed, get_contract_with_gas_estimation): clamper_test_code = """ @external def foo(s: Bytes[3]) -> Bytes[3]: @@ -43,10 +43,11 @@ def foo(s: Bytes[3]) -> Bytes[3]: c = get_contract_with_gas_estimation(clamper_test_code) assert c.foo(b"ca") == b"ca" assert c.foo(b"cat") == b"cat" - assert_tx_failed(lambda: c.foo(b"cate")) + with tx_failed(): + c.foo(b"cate") -def test_bytes_clamper_multiple_slots(assert_tx_failed, get_contract_with_gas_estimation): +def test_bytes_clamper_multiple_slots(tx_failed, get_contract_with_gas_estimation): clamper_test_code = """ @external def foo(s: Bytes[40]) -> Bytes[40]: @@ -58,10 +59,11 @@ def foo(s: Bytes[40]) -> Bytes[40]: assert c.foo(data[:30]) == data[:30] assert c.foo(data) == data - assert_tx_failed(lambda: c.foo(data + b"!")) + with tx_failed(): + c.foo(data + b"!") -def test_bytes_clamper_on_init(assert_tx_failed, get_contract_with_gas_estimation): +def test_bytes_clamper_on_init(tx_failed, get_contract_with_gas_estimation): clamper_test_code = """ foo: Bytes[3] @@ -77,7 +79,8 @@ def get_foo() -> Bytes[3]: c = get_contract_with_gas_estimation(clamper_test_code, *[b"cat"]) assert c.get_foo() == b"cat" - assert_tx_failed(lambda: get_contract_with_gas_estimation(clamper_test_code, *[b"cats"])) + with tx_failed(): + get_contract_with_gas_estimation(clamper_test_code, *[b"cats"]) @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -99,7 +102,7 @@ def foo(s: bytes{n}) -> bytes{n}: @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @pytest.mark.parametrize("n", list(range(1, 32))) # bytes32 always passes -def test_bytes_m_clamper_failing(w3, get_contract, assert_tx_failed, n, evm_version): +def test_bytes_m_clamper_failing(w3, get_contract, tx_failed, n, evm_version): values = [] values.append(b"\x00" * n + b"\x80") # just one bit set values.append(b"\xff" * n + b"\x80") # n*8 + 1 bits set @@ -118,11 +121,9 @@ def foo(s: bytes{n}) -> bytes{n}: c = get_contract(code, evm_version=evm_version) for v in values: # munge for `_make_tx` - assert_tx_failed( - lambda val=int.from_bytes(v, byteorder="big"): _make_tx( - w3, c.address, f"foo(bytes{n})", [val] - ) - ) + with tx_failed(): + int_value = int.from_bytes(v, byteorder="big") + _make_tx(w3, c.address, f"foo(bytes{n})", [int_value]) @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -144,7 +145,7 @@ def foo(s: int{bits}) -> int{bits}: @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @pytest.mark.parametrize("n", list(range(31))) # int256 does not clamp -def test_sint_clamper_failing(w3, assert_tx_failed, get_contract, n, evm_version): +def test_sint_clamper_failing(w3, tx_failed, get_contract, n, evm_version): bits = 8 * (n + 1) lo, hi = int_bounds(True, bits) values = [-(2**255), 2**255 - 1, lo - 1, hi + 1] @@ -156,7 +157,8 @@ def foo(s: int{bits}) -> int{bits}: c = get_contract(code, evm_version=evm_version) for v in values: - assert_tx_failed(lambda val=v: _make_tx(w3, c.address, f"foo(int{bits})", [val])) + with tx_failed(): + _make_tx(w3, c.address, f"foo(int{bits})", [v]) @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -174,7 +176,7 @@ def foo(s: bool) -> bool: @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @pytest.mark.parametrize("value", [2, 3, 4, 8, 16, 2**256 - 1]) -def test_bool_clamper_failing(w3, assert_tx_failed, get_contract, value, evm_version): +def test_bool_clamper_failing(w3, tx_failed, get_contract, value, evm_version): code = """ @external def foo(s: bool) -> bool: @@ -182,7 +184,8 @@ def foo(s: bool) -> bool: """ c = get_contract(code, evm_version=evm_version) - assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(bool)", [value])) + with tx_failed(): + _make_tx(w3, c.address, "foo(bool)", [value]) @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -207,7 +210,7 @@ def foo(s: Roles) -> Roles: @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @pytest.mark.parametrize("value", [2**i for i in range(5, 256)]) -def test_flag_clamper_failing(w3, assert_tx_failed, get_contract, value, evm_version): +def test_flag_clamper_failing(w3, tx_failed, get_contract, value, evm_version): code = """ flag Roles: USER @@ -222,7 +225,8 @@ def foo(s: Roles) -> Roles: """ c = get_contract(code, evm_version=evm_version) - assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(uint256)", [value])) + with tx_failed(): + _make_tx(w3, c.address, "foo(uint256)", [value]) @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -243,7 +247,7 @@ def foo(s: uint{bits}) -> uint{bits}: @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @pytest.mark.parametrize("n", list(range(31))) # uint256 has no failing cases -def test_uint_clamper_failing(w3, assert_tx_failed, get_contract, evm_version, n): +def test_uint_clamper_failing(w3, tx_failed, get_contract, evm_version, n): bits = 8 * (n + 1) values = [-1, -(2**255), 2**bits] code = f""" @@ -253,7 +257,8 @@ def foo(s: uint{bits}) -> uint{bits}: """ c = get_contract(code, evm_version=evm_version) for v in values: - assert_tx_failed(lambda val=v: _make_tx(w3, c.address, f"foo(uint{bits})", [val])) + with tx_failed(): + _make_tx(w3, c.address, f"foo(uint{bits})", [v]) @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -284,7 +289,7 @@ def foo(s: address) -> address: @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @pytest.mark.parametrize("value", [2**160, 2**256 - 1]) -def test_address_clamper_failing(w3, assert_tx_failed, get_contract, value, evm_version): +def test_address_clamper_failing(w3, tx_failed, get_contract, value, evm_version): code = """ @external def foo(s: address) -> address: @@ -292,7 +297,8 @@ def foo(s: address) -> address: """ c = get_contract(code, evm_version=evm_version) - assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(address)", [value])) + with tx_failed(): + _make_tx(w3, c.address, "foo(address)", [value]) @pytest.mark.parametrize("evm_version", list(EVM_VERSIONS)) @@ -337,7 +343,7 @@ def foo(s: decimal) -> decimal: -187072209578355573530071658587684226515959365500929, # - (2 ** 127 - 1e-10) ], ) -def test_decimal_clamper_failing(w3, assert_tx_failed, get_contract, value, evm_version): +def test_decimal_clamper_failing(w3, tx_failed, get_contract, value, evm_version): code = """ @external def foo(s: decimal) -> decimal: @@ -346,7 +352,8 @@ def foo(s: decimal) -> decimal: c = get_contract(code, evm_version=evm_version) - assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(fixed168x10)", [value])) + with tx_failed(): + _make_tx(w3, c.address, "foo(fixed168x10)", [value]) @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) @@ -366,7 +373,7 @@ def foo(a: uint256, b: int128[5], c: uint256) -> int128[5]: @pytest.mark.parametrize("bad_value", [2**127, -(2**127) - 1, 2**255 - 1, -(2**255)]) @pytest.mark.parametrize("idx", range(5)) -def test_int128_array_clamper_failing(w3, assert_tx_failed, get_contract, bad_value, idx): +def test_int128_array_clamper_failing(w3, tx_failed, get_contract, bad_value, idx): # ensure the invalid value is detected at all locations in the array code = """ @external @@ -378,7 +385,8 @@ def foo(b: int128[5]) -> int128[5]: values[idx] = bad_value c = get_contract(code) - assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(int128[5])", values)) + with tx_failed(): + _make_tx(w3, c.address, "foo(int128[5])", values) @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) @@ -397,7 +405,7 @@ def foo(a: uint256, b: int128[10], c: uint256) -> int128[10]: @pytest.mark.parametrize("bad_value", [2**127, -(2**127) - 1, 2**255 - 1, -(2**255)]) @pytest.mark.parametrize("idx", range(10)) -def test_int128_array_looped_clamper_failing(w3, assert_tx_failed, get_contract, bad_value, idx): +def test_int128_array_looped_clamper_failing(w3, tx_failed, get_contract, bad_value, idx): code = """ @external def foo(b: int128[10]) -> int128[10]: @@ -408,7 +416,8 @@ def foo(b: int128[10]) -> int128[10]: values[idx] = bad_value c = get_contract(code) - assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(int128[10])", values)) + with tx_failed(): + _make_tx(w3, c.address, "foo(int128[10])", values) @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) @@ -427,7 +436,7 @@ def foo(a: uint256, b: int128[6][3][1][8], c: uint256) -> int128[6][3][1][8]: @pytest.mark.parametrize("bad_value", [2**127, -(2**127) - 1, 2**255 - 1, -(2**255)]) @pytest.mark.parametrize("idx", range(12)) -def test_multidimension_array_clamper_failing(w3, assert_tx_failed, get_contract, bad_value, idx): +def test_multidimension_array_clamper_failing(w3, tx_failed, get_contract, bad_value, idx): code = """ @external def foo(b: int128[6][1][2]) -> int128[6][1][2]: @@ -438,7 +447,8 @@ def foo(b: int128[6][1][2]) -> int128[6][1][2]: values[idx] = bad_value c = get_contract(code) - assert_tx_failed(lambda: _make_tx(w3, c.address, "foo(int128[6][1][2]])", values)) + with tx_failed(): + _make_tx(w3, c.address, "foo(int128[6][1][2]])", values) @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) @@ -458,7 +468,7 @@ def foo(a: uint256, b: DynArray[int128, 5], c: uint256) -> DynArray[int128, 5]: @pytest.mark.parametrize("bad_value", [2**127, -(2**127) - 1, 2**255 - 1, -(2**255)]) @pytest.mark.parametrize("idx", range(5)) -def test_int128_dynarray_clamper_failing(w3, assert_tx_failed, get_contract, bad_value, idx): +def test_int128_dynarray_clamper_failing(w3, tx_failed, get_contract, bad_value, idx): # ensure the invalid value is detected at all locations in the array code = """ @external @@ -473,7 +483,8 @@ def foo(b: int128[5]) -> int128[5]: c = get_contract(code) data = _make_dynarray_data(32, 5, values) - assert_tx_failed(lambda: _make_invalid_dynarray_tx(w3, c.address, signature, data)) + with tx_failed(): + _make_invalid_dynarray_tx(w3, c.address, signature, data) @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) @@ -491,7 +502,7 @@ def foo(a: uint256, b: DynArray[int128, 10], c: uint256) -> DynArray[int128, 10] @pytest.mark.parametrize("bad_value", [2**127, -(2**127) - 1, 2**255 - 1, -(2**255)]) @pytest.mark.parametrize("idx", range(10)) -def test_int128_dynarray_looped_clamper_failing(w3, assert_tx_failed, get_contract, bad_value, idx): +def test_int128_dynarray_looped_clamper_failing(w3, tx_failed, get_contract, bad_value, idx): code = """ @external def foo(b: DynArray[int128, 10]) -> DynArray[int128, 10]: @@ -505,7 +516,8 @@ def foo(b: DynArray[int128, 10]) -> DynArray[int128, 10]: data = _make_dynarray_data(32, 10, values) signature = "foo(int128[])" - assert_tx_failed(lambda: _make_invalid_dynarray_tx(w3, c.address, signature, data)) + with tx_failed(): + _make_invalid_dynarray_tx(w3, c.address, signature, data) @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) @@ -527,9 +539,7 @@ def foo( @pytest.mark.parametrize("bad_value", [2**127, -(2**127) - 1, 2**255 - 1, -(2**255)]) @pytest.mark.parametrize("idx", range(4)) -def test_multidimension_dynarray_clamper_failing( - w3, assert_tx_failed, get_contract, bad_value, idx -): +def test_multidimension_dynarray_clamper_failing(w3, tx_failed, get_contract, bad_value, idx): code = """ @external def foo(b: DynArray[DynArray[int128, 2], 2]) -> DynArray[DynArray[int128, 2], 2]: @@ -549,7 +559,8 @@ def foo(b: DynArray[DynArray[int128, 2], 2]) -> DynArray[DynArray[int128, 2], 2] signature = "foo(int128[][])" c = get_contract(code) - assert_tx_failed(lambda: _make_invalid_dynarray_tx(w3, c.address, signature, data)) + with tx_failed(): + _make_invalid_dynarray_tx(w3, c.address, signature, data) @pytest.mark.parametrize("value", [0, 1, -1, 2**127 - 1, -(2**127)]) @@ -570,7 +581,7 @@ def foo( @pytest.mark.parametrize("bad_value", [2**127, -(2**127) - 1, 2**255 - 1, -(2**255)]) @pytest.mark.parametrize("idx", range(10)) -def test_dynarray_list_clamper_failing(w3, assert_tx_failed, get_contract, bad_value, idx): +def test_dynarray_list_clamper_failing(w3, tx_failed, get_contract, bad_value, idx): # ensure the invalid value is detected at all locations in the array code = """ @external @@ -588,4 +599,5 @@ def foo(b: DynArray[int128[5], 2]) -> DynArray[int128[5], 2]: c = get_contract(code) signature = "foo(int128[5][])" - assert_tx_failed(lambda: _make_invalid_dynarray_tx(w3, c.address, signature, data)) + with tx_failed(): + _make_invalid_dynarray_tx(w3, c.address, signature, data) diff --git a/tests/functional/codegen/features/test_init.py b/tests/functional/codegen/features/test_init.py index 29a466e869..fc765f8ab3 100644 --- a/tests/functional/codegen/features/test_init.py +++ b/tests/functional/codegen/features/test_init.py @@ -24,7 +24,7 @@ def __init__(a: uint256): assert "CALLDATALOAD" not in assembly[:ir_return_idx_start] + assembly[ir_return_idx_end:] -def test_init_calls_internal(get_contract, assert_compile_failed, assert_tx_failed): +def test_init_calls_internal(get_contract, assert_compile_failed, tx_failed): code = """ foo: public(uint8) @internal @@ -46,7 +46,8 @@ def baz() -> uint8: n = 6 c = get_contract(code, n) assert c.foo() == n * 7 - assert_tx_failed(lambda: c.baz()) + with tx_failed(): + c.baz() n = 255 assert_compile_failed(lambda: get_contract(code, n)) diff --git a/tests/functional/codegen/features/test_logging.py b/tests/functional/codegen/features/test_logging.py index 84311c41f5..ba09be1991 100644 --- a/tests/functional/codegen/features/test_logging.py +++ b/tests/functional/codegen/features/test_logging.py @@ -3,6 +3,7 @@ import pytest from eth.codecs import abi +from vyper import compile_code from vyper.exceptions import ( ArgumentException, EventDeclarationException, @@ -193,7 +194,7 @@ def bar(): def test_event_logging_cannot_have_more_than_three_topics( - assert_tx_failed, get_contract_with_gas_estimation + tx_failed, get_contract_with_gas_estimation ): loggy_code = """ event MyLog: @@ -203,9 +204,8 @@ def test_event_logging_cannot_have_more_than_three_topics( arg4: indexed(int128) """ - assert_tx_failed( - lambda: get_contract_with_gas_estimation(loggy_code), EventDeclarationException - ) + with pytest.raises(EventDeclarationException): + compile_code(loggy_code) def test_event_logging_with_data(w3, tester, keccak, get_logs, get_contract_with_gas_estimation): @@ -555,7 +555,7 @@ def foo(): assert args.arg2 == {"x": 1, "y": b"abc", "z": {"t": "house", "w": Decimal("13.5")}} -def test_fails_when_input_is_the_wrong_type(assert_tx_failed, get_contract_with_gas_estimation): +def test_fails_when_input_is_the_wrong_type(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: arg1: indexed(int128) @@ -565,10 +565,11 @@ def foo_(): log MyLog(b'yo') """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), InvalidType) + with tx_failed(InvalidType): + get_contract_with_gas_estimation(loggy_code) -def test_fails_when_topic_is_the_wrong_size(assert_tx_failed, get_contract_with_gas_estimation): +def test_fails_when_topic_is_the_wrong_size(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: arg1: indexed(Bytes[3]) @@ -579,12 +580,11 @@ def foo(): log MyLog(b'bars') """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), InvalidType) + with tx_failed(InvalidType): + get_contract_with_gas_estimation(loggy_code) -def test_fails_when_input_topic_is_the_wrong_size( - assert_tx_failed, get_contract_with_gas_estimation -): +def test_fails_when_input_topic_is_the_wrong_size(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: arg1: indexed(Bytes[3]) @@ -594,10 +594,11 @@ def foo(arg1: Bytes[4]): log MyLog(arg1) """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), TypeMismatch) + with tx_failed(TypeMismatch): + get_contract_with_gas_estimation(loggy_code) -def test_fails_when_data_is_the_wrong_size(assert_tx_failed, get_contract_with_gas_estimation): +def test_fails_when_data_is_the_wrong_size(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: arg1: Bytes[3] @@ -607,12 +608,11 @@ def foo(): log MyLog(b'bars') """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), InvalidType) + with tx_failed(InvalidType): + get_contract_with_gas_estimation(loggy_code) -def test_fails_when_input_data_is_the_wrong_size( - assert_tx_failed, get_contract_with_gas_estimation -): +def test_fails_when_input_data_is_the_wrong_size(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: arg1: Bytes[3] @@ -622,7 +622,8 @@ def foo(arg1: Bytes[4]): log MyLog(arg1) """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), TypeMismatch) + with tx_failed(TypeMismatch): + get_contract_with_gas_estimation(loggy_code) def test_topic_over_32_bytes(get_contract_with_gas_estimation): @@ -637,7 +638,7 @@ def foo(): get_contract_with_gas_estimation(loggy_code) -def test_logging_fails_with_over_three_topics(assert_tx_failed, get_contract_with_gas_estimation): +def test_logging_fails_with_over_three_topics(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: arg1: indexed(int128) @@ -650,12 +651,11 @@ def __init__(): log MyLog(1, 2, 3, 4) """ - assert_tx_failed( - lambda: get_contract_with_gas_estimation(loggy_code), EventDeclarationException - ) + with tx_failed(EventDeclarationException): + get_contract_with_gas_estimation(loggy_code) -def test_logging_fails_with_duplicate_log_names(assert_tx_failed, get_contract_with_gas_estimation): +def test_logging_fails_with_duplicate_log_names(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: pass event MyLog: pass @@ -665,12 +665,11 @@ def foo(): log MyLog() """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), NamespaceCollision) + with tx_failed(NamespaceCollision): + get_contract_with_gas_estimation(loggy_code) -def test_logging_fails_with_when_log_is_undeclared( - assert_tx_failed, get_contract_with_gas_estimation -): +def test_logging_fails_with_when_log_is_undeclared(tx_failed, get_contract_with_gas_estimation): loggy_code = """ @external @@ -678,10 +677,11 @@ def foo(): log MyLog() """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), UndeclaredDefinition) + with tx_failed(UndeclaredDefinition): + get_contract_with_gas_estimation(loggy_code) -def test_logging_fails_with_topic_type_mismatch(assert_tx_failed, get_contract_with_gas_estimation): +def test_logging_fails_with_topic_type_mismatch(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: arg1: indexed(int128) @@ -691,10 +691,11 @@ def foo(): log MyLog(self) """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), TypeMismatch) + with tx_failed(TypeMismatch): + get_contract_with_gas_estimation(loggy_code) -def test_logging_fails_with_data_type_mismatch(assert_tx_failed, get_contract_with_gas_estimation): +def test_logging_fails_with_data_type_mismatch(tx_failed, get_contract_with_gas_estimation): loggy_code = """ event MyLog: arg1: Bytes[3] @@ -704,11 +705,12 @@ def foo(): log MyLog(self) """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), TypeMismatch) + with tx_failed(TypeMismatch): + get_contract_with_gas_estimation(loggy_code) def test_logging_fails_when_number_of_arguments_is_greater_than_declaration( - assert_tx_failed, get_contract_with_gas_estimation + tx_failed, get_contract_with_gas_estimation ): loggy_code = """ event MyLog: @@ -718,11 +720,12 @@ def test_logging_fails_when_number_of_arguments_is_greater_than_declaration( def foo(): log MyLog(1, 2) """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), ArgumentException) + with tx_failed(ArgumentException): + get_contract_with_gas_estimation(loggy_code) def test_logging_fails_when_number_of_arguments_is_less_than_declaration( - assert_tx_failed, get_contract_with_gas_estimation + tx_failed, get_contract_with_gas_estimation ): loggy_code = """ event MyLog: @@ -733,7 +736,8 @@ def test_logging_fails_when_number_of_arguments_is_less_than_declaration( def foo(): log MyLog(1) """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(loggy_code), ArgumentException) + with tx_failed(ArgumentException): + get_contract_with_gas_estimation(loggy_code) def test_loggy_code(w3, tester, get_contract_with_gas_estimation): @@ -962,7 +966,7 @@ def set_list(): ] -def test_logging_fails_when_input_is_too_big(assert_tx_failed, get_contract_with_gas_estimation): +def test_logging_fails_when_input_is_too_big(tx_failed, get_contract_with_gas_estimation): code = """ event Bar: _value: indexed(Bytes[32]) @@ -971,7 +975,8 @@ def test_logging_fails_when_input_is_too_big(assert_tx_failed, get_contract_with def foo(inp: Bytes[33]): log Bar(inp) """ - assert_tx_failed(lambda: get_contract_with_gas_estimation(code), TypeMismatch) + with tx_failed(TypeMismatch): + get_contract_with_gas_estimation(code) def test_2nd_var_list_packing(get_logs, get_contract_with_gas_estimation): diff --git a/tests/functional/codegen/features/test_reverting.py b/tests/functional/codegen/features/test_reverting.py index 2cdc727015..f24886ce96 100644 --- a/tests/functional/codegen/features/test_reverting.py +++ b/tests/functional/codegen/features/test_reverting.py @@ -7,7 +7,7 @@ pytestmark = pytest.mark.usefixtures("memory_mocker") -def test_revert_reason(w3, assert_tx_failed, get_contract_with_gas_estimation): +def test_revert_reason(w3, tx_failed, get_contract_with_gas_estimation): reverty_code = """ @external def foo(): @@ -17,14 +17,11 @@ def foo(): revert_bytes = method_id("NoFives()") - assert_tx_failed( - lambda: get_contract_with_gas_estimation(reverty_code).foo(transact={}), - TransactionFailed, - exc_text=f"execution reverted: {revert_bytes}", - ) + with tx_failed(TransactionFailed, exc_text=f"execution reverted: {revert_bytes}"): + get_contract_with_gas_estimation(reverty_code).foo(transact={}) -def test_revert_reason_typed(w3, assert_tx_failed, get_contract_with_gas_estimation): +def test_revert_reason_typed(w3, tx_failed, get_contract_with_gas_estimation): reverty_code = """ @external def foo(): @@ -35,14 +32,11 @@ def foo(): revert_bytes = method_id("NoFives(uint256)") + abi.encode("(uint256)", (5,)) - assert_tx_failed( - lambda: get_contract_with_gas_estimation(reverty_code).foo(transact={}), - TransactionFailed, - exc_text=f"execution reverted: {revert_bytes}", - ) + with tx_failed(TransactionFailed, exc_text=f"execution reverted: {revert_bytes}"): + get_contract_with_gas_estimation(reverty_code).foo(transact={}) -def test_revert_reason_typed_no_variable(w3, assert_tx_failed, get_contract_with_gas_estimation): +def test_revert_reason_typed_no_variable(w3, tx_failed, get_contract_with_gas_estimation): reverty_code = """ @external def foo(): @@ -52,8 +46,5 @@ def foo(): revert_bytes = method_id("NoFives(uint256)") + abi.encode("(uint256)", (5,)) - assert_tx_failed( - lambda: get_contract_with_gas_estimation(reverty_code).foo(transact={}), - TransactionFailed, - exc_text=f"execution reverted: {revert_bytes}", - ) + with tx_failed(TransactionFailed, exc_text=f"execution reverted: {revert_bytes}"): + get_contract_with_gas_estimation(reverty_code).foo(transact={}) diff --git a/tests/functional/codegen/integration/test_escrow.py b/tests/functional/codegen/integration/test_escrow.py index 1578f5a418..70e7cb4594 100644 --- a/tests/functional/codegen/integration/test_escrow.py +++ b/tests/functional/codegen/integration/test_escrow.py @@ -1,7 +1,7 @@ # from ethereum.tools import tester -def test_arbitration_code(w3, get_contract_with_gas_estimation, assert_tx_failed): +def test_arbitration_code(w3, get_contract_with_gas_estimation, tx_failed): arbitration_code = """ buyer: address seller: address @@ -28,13 +28,14 @@ def refund(): a0, a1, a2 = w3.eth.accounts[:3] c = get_contract_with_gas_estimation(arbitration_code, value=1) c.setup(a1, a2, transact={}) - assert_tx_failed(lambda: c.finalize(transact={"from": a1})) + with tx_failed(): + c.finalize(transact={"from": a1}) c.finalize(transact={}) print("Passed escrow test") -def test_arbitration_code_with_init(w3, assert_tx_failed, get_contract_with_gas_estimation): +def test_arbitration_code_with_init(w3, tx_failed, get_contract_with_gas_estimation): arbitration_code_with_init = """ buyer: address seller: address @@ -60,7 +61,8 @@ def refund(): """ a0, a1, a2 = w3.eth.accounts[:3] c = get_contract_with_gas_estimation(arbitration_code_with_init, *[a1, a2], value=1) - assert_tx_failed(lambda: c.finalize(transact={"from": a1})) + with tx_failed(): + c.finalize(transact={"from": a1}) c.finalize(transact={"from": a0}) print("Passed escrow test with initializer") diff --git a/tests/functional/codegen/test_interfaces.py b/tests/functional/codegen/test_interfaces.py index 3544f4a965..65d2df9038 100644 --- a/tests/functional/codegen/test_interfaces.py +++ b/tests/functional/codegen/test_interfaces.py @@ -427,7 +427,7 @@ def test(addr: address): # test data returned from external interface gets clamped @pytest.mark.parametrize("typ", ("int128", "uint8")) -def test_external_interface_int_clampers(get_contract, assert_tx_failed, typ): +def test_external_interface_int_clampers(get_contract, tx_failed, typ): external_contract = f""" @external def ok() -> {typ}: @@ -474,13 +474,16 @@ def test_fail3() -> int256: assert bad_c.should_fail() == -(2**255) assert c.test_ok() == 1 - assert_tx_failed(lambda: c.test_fail()) - assert_tx_failed(lambda: c.test_fail2()) - assert_tx_failed(lambda: c.test_fail3()) + with tx_failed(): + c.test_fail() + with tx_failed(): + c.test_fail2() + with tx_failed(): + c.test_fail3() # test data returned from external interface gets clamped -def test_external_interface_bytes_clampers(get_contract, assert_tx_failed): +def test_external_interface_bytes_clampers(get_contract, tx_failed): external_contract = """ @external def ok() -> Bytes[2]: @@ -522,14 +525,14 @@ def test_fail2() -> Bytes[3]: assert bad_c.should_fail() == b"123" assert c.test_ok() == b"12" - assert_tx_failed(lambda: c.test_fail1()) - assert_tx_failed(lambda: c.test_fail2()) + with tx_failed(): + c.test_fail1() + with tx_failed(): + c.test_fail2() # test data returned from external interface gets clamped -def test_json_abi_bytes_clampers( - get_contract, assert_tx_failed, assert_compile_failed, make_input_bundle -): +def test_json_abi_bytes_clampers(get_contract, tx_failed, assert_compile_failed, make_input_bundle): external_contract = """ @external def returns_Bytes3() -> Bytes[3]: @@ -584,9 +587,12 @@ def test_fail3() -> Bytes[3]: c = get_contract(code, bad_c.address, input_bundle=input_bundle) assert bad_c.returns_Bytes3() == b"123" - assert_tx_failed(lambda: c.test_fail1()) - assert_tx_failed(lambda: c.test_fail2()) - assert_tx_failed(lambda: c.test_fail3()) + with tx_failed(): + c.test_fail1() + with tx_failed(): + c.test_fail2() + with tx_failed(): + c.test_fail3() def test_units_interface(w3, get_contract, make_input_bundle): diff --git a/tests/functional/codegen/test_selector_table.py b/tests/functional/codegen/test_selector_table.py index abea81ced4..94233977c9 100644 --- a/tests/functional/codegen/test_selector_table.py +++ b/tests/functional/codegen/test_selector_table.py @@ -512,9 +512,7 @@ def generate_methods(draw, max_calldata_bytes): # dense selector table packing boundaries at 256 and 65336 @pytest.mark.parametrize("max_calldata_bytes", [255, 256, 65336]) @pytest.mark.fuzzing -def test_selector_table_fuzz( - max_calldata_bytes, opt_level, w3, get_contract, assert_tx_failed, get_logs -): +def test_selector_table_fuzz(max_calldata_bytes, opt_level, w3, get_contract, tx_failed, get_logs): def abi_sig(func_id, calldata_words, n_default_args): params = [] if not calldata_words else [f"uint256[{calldata_words}]"] params.extend(["uint256"] * n_default_args) @@ -600,7 +598,8 @@ def __default__(): else: hexstr = (method_id + argsdata).hex() txdata = {"to": c.address, "data": hexstr, "value": 1} - assert_tx_failed(lambda d=txdata: w3.eth.send_transaction(d)) + with tx_failed(): + w3.eth.send_transaction(txdata) # now do calldatasize check # strip some bytes @@ -610,7 +609,8 @@ def __default__(): if n_calldata_words == 0 and j == 0: # no args, hit default function if default_fn_mutability == "": - assert_tx_failed(lambda p=tx_params: w3.eth.send_transaction(p)) + with tx_failed(): + w3.eth.send_transaction(tx_params) elif default_fn_mutability == "@payable": # we should be able to send eth to it tx_params["value"] = 1 @@ -628,8 +628,10 @@ def __default__(): # check default function reverts tx_params["value"] = 1 - assert_tx_failed(lambda p=tx_params: w3.eth.send_transaction(p)) + with tx_failed(): + w3.eth.send_transaction(tx_params) else: - assert_tx_failed(lambda p=tx_params: w3.eth.send_transaction(p)) + with tx_failed(): + w3.eth.send_transaction(tx_params) _test() diff --git a/tests/functional/codegen/test_stateless_modules.py b/tests/functional/codegen/test_stateless_modules.py index 8e634e5868..2abc164689 100644 --- a/tests/functional/codegen/test_stateless_modules.py +++ b/tests/functional/codegen/test_stateless_modules.py @@ -186,7 +186,7 @@ def qux() -> library.SomeStruct: # test calls to library functions in statement position -def test_library_statement_calls(get_contract, make_input_bundle, assert_tx_failed): +def test_library_statement_calls(get_contract, make_input_bundle, tx_failed): library_source = """ from vyper.interfaces import ERC20 @internal @@ -211,7 +211,8 @@ def foo(x: uint256): assert c.counter() == 7 - assert_tx_failed(lambda: c.foo(8)) + with tx_failed(): + c.foo(8) def test_library_is_typechecked(make_input_bundle): diff --git a/tests/functional/codegen/types/numbers/test_constants.py b/tests/functional/codegen/types/numbers/test_constants.py index 25617651ec..8244bc5487 100644 --- a/tests/functional/codegen/types/numbers/test_constants.py +++ b/tests/functional/codegen/types/numbers/test_constants.py @@ -8,6 +8,13 @@ from vyper.utils import MemoryPositions +def search_for_sublist(ir, sublist): + _list = ir.to_list() if hasattr(ir, "to_list") else ir + if _list == sublist: + return True + return isinstance(_list, list) and any(search_for_sublist(i, sublist) for i in _list) + + def test_builtin_constants(get_contract_with_gas_estimation): code = """ @external @@ -192,7 +199,7 @@ def test() -> Bytes[100]: assert c.test() == test_str -def test_constant_folds(search_for_sublist): +def test_constant_folds(): some_prime = 10013677 code = f""" SOME_CONSTANT: constant(uint256) = 11 + 1 @@ -205,11 +212,9 @@ def test() -> uint256: ret: uint256 = 2**SOME_CONSTANT * SOME_PRIME return ret """ - ir = compile_code(code, output_formats=["ir"])["ir"] - assert search_for_sublist( - ir, ["mstore", [MemoryPositions.RESERVED_MEMORY], [2**12 * some_prime]] - ) + search = ["mstore", [MemoryPositions.RESERVED_MEMORY], [2**12 * some_prime]] + assert search_for_sublist(ir, search) def test_constant_lists(get_contract): diff --git a/tests/functional/codegen/types/numbers/test_decimals.py b/tests/functional/codegen/types/numbers/test_decimals.py index 1418eab063..25dc1f1a1e 100644 --- a/tests/functional/codegen/types/numbers/test_decimals.py +++ b/tests/functional/codegen/types/numbers/test_decimals.py @@ -156,7 +156,7 @@ def iarg() -> uint256: print("Passed fractional multiplication test") -def test_mul_overflow(assert_tx_failed, get_contract_with_gas_estimation): +def test_mul_overflow(tx_failed, get_contract_with_gas_estimation): mul_code = """ @external @@ -170,12 +170,14 @@ def _num_mul(x: decimal, y: decimal) -> decimal: x = Decimal("85070591730234615865843651857942052864") y = Decimal("136112946768375385385349842973") - assert_tx_failed(lambda: c._num_mul(x, y)) + with tx_failed(): + c._num_mul(x, y) x = SizeLimits.MAX_AST_DECIMAL y = 1 + DECIMAL_EPSILON - assert_tx_failed(lambda: c._num_mul(x, y)) + with tx_failed(): + c._num_mul(x, y) assert c._num_mul(x, Decimal(1)) == x @@ -186,7 +188,7 @@ def _num_mul(x: decimal, y: decimal) -> decimal: # division failure modes(!) -def test_div_overflow(get_contract, assert_tx_failed): +def test_div_overflow(get_contract, tx_failed): code = """ @external def foo(x: decimal, y: decimal) -> decimal: @@ -198,32 +200,39 @@ def foo(x: decimal, y: decimal) -> decimal: x = SizeLimits.MIN_AST_DECIMAL y = -DECIMAL_EPSILON - assert_tx_failed(lambda: c.foo(x, y)) - assert_tx_failed(lambda: c.foo(x, Decimal(0))) - assert_tx_failed(lambda: c.foo(y, Decimal(0))) + with tx_failed(): + c.foo(x, y) + with tx_failed(): + c.foo(x, Decimal(0)) + with tx_failed(): + c.foo(y, Decimal(0)) y = Decimal(1) - DECIMAL_EPSILON # 0.999999999 - assert_tx_failed(lambda: c.foo(x, y)) + with tx_failed(): + c.foo(x, y) y = Decimal(-1) - assert_tx_failed(lambda: c.foo(x, y)) + with tx_failed(): + c.foo(x, y) assert c.foo(x, Decimal(1)) == x assert c.foo(x, 1 + DECIMAL_EPSILON) == quantize(x / (1 + DECIMAL_EPSILON)) x = SizeLimits.MAX_AST_DECIMAL - assert_tx_failed(lambda: c.foo(x, DECIMAL_EPSILON)) + with tx_failed(): + c.foo(x, DECIMAL_EPSILON) y = Decimal(1) - DECIMAL_EPSILON - assert_tx_failed(lambda: c.foo(x, y)) + with tx_failed(): + c.foo(x, y) assert c.foo(x, Decimal(1)) == x assert c.foo(x, 1 + DECIMAL_EPSILON) == quantize(x / (1 + DECIMAL_EPSILON)) -def test_decimal_min_max_literals(assert_tx_failed, get_contract_with_gas_estimation): +def test_decimal_min_max_literals(tx_failed, get_contract_with_gas_estimation): code = """ @external def maximum(): diff --git a/tests/functional/codegen/types/numbers/test_exponents.py b/tests/functional/codegen/types/numbers/test_exponents.py index 5726e4c1ca..e958436efb 100644 --- a/tests/functional/codegen/types/numbers/test_exponents.py +++ b/tests/functional/codegen/types/numbers/test_exponents.py @@ -7,7 +7,7 @@ @pytest.mark.fuzzing @pytest.mark.parametrize("power", range(2, 255)) -def test_exp_uint256(get_contract, assert_tx_failed, power): +def test_exp_uint256(get_contract, tx_failed, power): code = f""" @external def foo(a: uint256) -> uint256: @@ -20,12 +20,13 @@ def foo(a: uint256) -> uint256: c = get_contract(code) c.foo(max_base) - assert_tx_failed(lambda: c.foo(max_base + 1)) + with tx_failed(): + c.foo(max_base + 1) @pytest.mark.fuzzing @pytest.mark.parametrize("power", range(2, 127)) -def test_exp_int128(get_contract, assert_tx_failed, power): +def test_exp_int128(get_contract, tx_failed, power): code = f""" @external def foo(a: int128) -> int128: @@ -44,13 +45,15 @@ def foo(a: int128) -> int128: c.foo(max_base) c.foo(min_base) - assert_tx_failed(lambda: c.foo(max_base + 1)) - assert_tx_failed(lambda: c.foo(min_base - 1)) + with tx_failed(): + c.foo(max_base + 1) + with tx_failed(): + c.foo(min_base - 1) @pytest.mark.fuzzing @pytest.mark.parametrize("power", range(2, 15)) -def test_exp_int16(get_contract, assert_tx_failed, power): +def test_exp_int16(get_contract, tx_failed, power): code = f""" @external def foo(a: int16) -> int16: @@ -69,8 +72,10 @@ def foo(a: int16) -> int16: c.foo(max_base) c.foo(min_base) - assert_tx_failed(lambda: c.foo(max_base + 1)) - assert_tx_failed(lambda: c.foo(min_base - 1)) + with tx_failed(): + c.foo(max_base + 1) + with tx_failed(): + c.foo(min_base - 1) @pytest.mark.fuzzing @@ -93,7 +98,7 @@ def foo(a: int16) -> int16: # 256 bits @example(a=2**256 - 1) @settings(max_examples=200) -def test_max_exp(get_contract, assert_tx_failed, a): +def test_max_exp(get_contract, tx_failed, a): code = f""" @external def foo(b: uint256) -> uint256: @@ -108,7 +113,8 @@ def foo(b: uint256) -> uint256: assert a ** (max_power + 1) >= 2**256 c.foo(max_power) - assert_tx_failed(lambda: c.foo(max_power + 1)) + with tx_failed(): + c.foo(max_power + 1) @pytest.mark.fuzzing @@ -128,7 +134,7 @@ def foo(b: uint256) -> uint256: # 128 bits @example(a=2**127 - 1) @settings(max_examples=200) -def test_max_exp_int128(get_contract, assert_tx_failed, a): +def test_max_exp_int128(get_contract, tx_failed, a): code = f""" @external def foo(b: int128) -> int128: @@ -143,4 +149,5 @@ def foo(b: int128) -> int128: assert not -(2**127) <= a ** (max_power + 1) < 2**127 c.foo(max_power) - assert_tx_failed(lambda: c.foo(max_power + 1)) + with tx_failed(): + c.foo(max_power + 1) diff --git a/tests/functional/codegen/types/numbers/test_modulo.py b/tests/functional/codegen/types/numbers/test_modulo.py index 018a406baa..465426cd1d 100644 --- a/tests/functional/codegen/types/numbers/test_modulo.py +++ b/tests/functional/codegen/types/numbers/test_modulo.py @@ -31,14 +31,15 @@ def num_modulo_decimal() -> decimal: assert c.num_modulo_decimal() == Decimal(".5") -def test_modulo_with_input_of_zero(assert_tx_failed, get_contract_with_gas_estimation): +def test_modulo_with_input_of_zero(tx_failed, get_contract_with_gas_estimation): code = """ @external def foo(a: decimal, b: decimal) -> decimal: return a % b """ c = get_contract_with_gas_estimation(code) - assert_tx_failed(lambda: c.foo(Decimal("1"), Decimal("0"))) + with tx_failed(): + c.foo(Decimal("1"), Decimal("0")) def test_literals_vs_evm(get_contract): diff --git a/tests/functional/codegen/types/numbers/test_signed_ints.py b/tests/functional/codegen/types/numbers/test_signed_ints.py index 3e44beb826..52de5b649f 100644 --- a/tests/functional/codegen/types/numbers/test_signed_ints.py +++ b/tests/functional/codegen/types/numbers/test_signed_ints.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize("typ", types) -def test_exponent_base_zero(get_contract, assert_tx_failed, typ): +def test_exponent_base_zero(get_contract, tx_failed, typ): code = f""" @external def foo(x: {typ}) -> {typ}: @@ -25,12 +25,14 @@ def foo(x: {typ}) -> {typ}: assert c.foo(1) == 0 assert c.foo(hi) == 0 - assert_tx_failed(lambda: c.foo(-1)) - assert_tx_failed(lambda: c.foo(lo)) # note: lo < 0 + with tx_failed(): + c.foo(-1) + with tx_failed(): + c.foo(lo) # note: lo < 0 @pytest.mark.parametrize("typ", types) -def test_exponent_base_one(get_contract, assert_tx_failed, typ): +def test_exponent_base_one(get_contract, tx_failed, typ): code = f""" @external def foo(x: {typ}) -> {typ}: @@ -43,8 +45,10 @@ def foo(x: {typ}) -> {typ}: assert c.foo(1) == 1 assert c.foo(hi) == 1 - assert_tx_failed(lambda: c.foo(-1)) - assert_tx_failed(lambda: c.foo(lo)) + with tx_failed(): + c.foo(-1) + with tx_failed(): + c.foo(lo) def test_exponent_base_minus_one(get_contract): @@ -63,7 +67,7 @@ def foo(x: int256) -> int256: # TODO: make this test pass @pytest.mark.parametrize("base", (0, 1)) -def test_exponent_negative_power(get_contract, assert_tx_failed, base): +def test_exponent_negative_power(get_contract, tx_failed, base): # #2985 code = f""" @external @@ -73,7 +77,8 @@ def bar() -> int16: """ c = get_contract(code) # known bug: 2985 - assert_tx_failed(lambda: c.bar()) + with tx_failed(): + c.bar() def test_exponent_min_int16(get_contract): @@ -103,7 +108,7 @@ def foo() -> int256: @pytest.mark.parametrize("typ", types) -def test_exponent(get_contract, assert_tx_failed, typ): +def test_exponent(get_contract, tx_failed, typ): code = f""" @external def foo(x: {typ}) -> {typ}: @@ -116,7 +121,8 @@ def foo(x: {typ}) -> {typ}: test_cases = [0, 1, 3, 4, 126, 127, -1, lo, hi] for x in test_cases: if x * 2 >= typ.bits or x < 0: # out of bounds - assert_tx_failed(lambda p=x: c.foo(p)) + with tx_failed(): + c.foo(x) else: assert c.foo(x) == 4**x @@ -145,7 +151,7 @@ def negative_four() -> {typ}: @pytest.mark.parametrize("typ", types) -def test_num_bound(assert_tx_failed, get_contract_with_gas_estimation, typ): +def test_num_bound(tx_failed, get_contract_with_gas_estimation, typ): lo, hi = typ.ast_bounds num_bound_code = f""" @@ -180,16 +186,22 @@ def _num_min() -> {typ}: assert c._num_sub(lo, 0) == lo assert c._num_add(hi - 1, 1) == hi assert c._num_sub(lo + 1, 1) == lo - assert_tx_failed(lambda: c._num_add(hi, 1)) - assert_tx_failed(lambda: c._num_sub(lo, 1)) - assert_tx_failed(lambda: c._num_add(hi - 1, 2)) - assert_tx_failed(lambda: c._num_sub(lo + 1, 2)) + with tx_failed(): + c._num_add(hi, 1) + with tx_failed(): + c._num_sub(lo, 1) + with tx_failed(): + c._num_add(hi - 1, 2) + with tx_failed(): + c._num_sub(lo + 1, 2) assert c._num_max() == hi assert c._num_min() == lo - assert_tx_failed(lambda: c._num_add3(hi, 1, -1)) + with tx_failed(): + c._num_add3(hi, 1, -1) assert c._num_add3(hi, -1, 1) == hi - 1 + 1 - assert_tx_failed(lambda: c._num_add3(lo, -1, 1)) + with tx_failed(): + c._num_add3(lo, -1, 1) assert c._num_add3(lo, 1, -1) == lo + 1 - 1 @@ -219,7 +231,7 @@ def num_sub() -> {typ}: @pytest.mark.parametrize("op", sorted(ARITHMETIC_OPS.keys())) @pytest.mark.parametrize("typ", types) @pytest.mark.fuzzing -def test_arithmetic_thorough(get_contract, assert_tx_failed, assert_compile_failed, op, typ): +def test_arithmetic_thorough(get_contract, tx_failed, assert_compile_failed, op, typ): # both variables code_1 = f""" @external @@ -304,14 +316,19 @@ def foo() -> {typ}: assert get_contract(code_3).foo(y) == expected assert get_contract(code_4).foo() == expected elif div_by_zero: - assert_tx_failed(lambda p=(x, y): c.foo(*p)) + with tx_failed(): + c.foo(x, y) assert_compile_failed(lambda code=code_2: get_contract(code), ZeroDivisionException) - assert_tx_failed(lambda p=y, code=code_3: get_contract(code).foo(p)) + with tx_failed(): + get_contract(code_3).foo(y) assert_compile_failed(lambda code=code_4: get_contract(code), ZeroDivisionException) else: - assert_tx_failed(lambda p=(x, y): c.foo(*p)) - assert_tx_failed(lambda p=x, code=code_2: get_contract(code).foo(p)) - assert_tx_failed(lambda p=y, code=code_3: get_contract(code).foo(p)) + with tx_failed(): + c.foo(x, y) + with tx_failed(): + get_contract(code_2).foo(x) + with tx_failed(): + get_contract(code_3).foo(y) assert_compile_failed( lambda code=code_4: get_contract(code), (InvalidType, OverflowException) ) @@ -372,7 +389,7 @@ def foo(x: {typ}, y: {typ}) -> bool: @pytest.mark.parametrize("typ", types) -def test_negation(get_contract, assert_tx_failed, typ): +def test_negation(get_contract, tx_failed, typ): code = f""" @external def foo(a: {typ}) -> {typ}: @@ -390,7 +407,8 @@ def foo(a: {typ}) -> {typ}: assert c.foo(2) == -2 assert c.foo(-2) == 2 - assert_tx_failed(lambda: c.foo(lo)) + with tx_failed(): + c.foo(lo) @pytest.mark.parametrize("typ", types) diff --git a/tests/functional/codegen/types/numbers/test_unsigned_ints.py b/tests/functional/codegen/types/numbers/test_unsigned_ints.py index 6c8d114f29..8982065b5d 100644 --- a/tests/functional/codegen/types/numbers/test_unsigned_ints.py +++ b/tests/functional/codegen/types/numbers/test_unsigned_ints.py @@ -85,7 +85,7 @@ def foo(x: {typ}) -> {typ}: @pytest.mark.parametrize("op", sorted(ARITHMETIC_OPS.keys())) @pytest.mark.parametrize("typ", types) @pytest.mark.fuzzing -def test_arithmetic_thorough(get_contract, assert_tx_failed, assert_compile_failed, op, typ): +def test_arithmetic_thorough(get_contract, tx_failed, assert_compile_failed, op, typ): # both variables code_1 = f""" @external @@ -148,17 +148,23 @@ def foo() -> {typ}: assert get_contract(code_3).foo(y) == expected assert get_contract(code_4).foo() == expected elif div_by_zero: - assert_tx_failed(lambda p=(x, y): c.foo(*p)) - assert_compile_failed(lambda code=code_2: get_contract(code), ZeroDivisionException) - assert_tx_failed(lambda p=y, code=code_3: get_contract(code).foo(p)) - assert_compile_failed(lambda code=code_4: get_contract(code), ZeroDivisionException) + with tx_failed(): + c.foo(x, y) + with pytest.raises(ZeroDivisionException): + get_contract(code_2) + with tx_failed(): + get_contract(code_3).foo(y) + with pytest.raises(ZeroDivisionException): + get_contract(code_4) else: - assert_tx_failed(lambda p=(x, y): c.foo(*p)) - assert_tx_failed(lambda code=code_2, p=x: get_contract(code).foo(p)) - assert_tx_failed(lambda p=y, code=code_3: get_contract(code).foo(p)) - assert_compile_failed( - lambda code=code_4: get_contract(code), (InvalidType, OverflowException) - ) + with tx_failed(): + c.foo(x, y) + with tx_failed(): + get_contract(code_2).foo(x) + with tx_failed(): + get_contract(code_3).foo(y) + with pytest.raises((InvalidType, OverflowException)): + get_contract(code_4) COMPARISON_OPS = { diff --git a/tests/functional/codegen/types/test_bytes.py b/tests/functional/codegen/types/test_bytes.py index 01ec75d5c1..1ee9b8d835 100644 --- a/tests/functional/codegen/types/test_bytes.py +++ b/tests/functional/codegen/types/test_bytes.py @@ -3,7 +3,7 @@ from vyper.exceptions import InvalidType, TypeMismatch -def test_test_bytes(get_contract_with_gas_estimation, assert_tx_failed): +def test_test_bytes(get_contract_with_gas_estimation, tx_failed): test_bytes = """ @external def foo(x: Bytes[100]) -> Bytes[100]: @@ -21,7 +21,8 @@ def foo(x: Bytes[100]) -> Bytes[100]: print("Passed max-length bytes test") # test for greater than 100 bytes, should raise exception - assert_tx_failed(lambda: c.foo(b"\x35" * 101)) + with tx_failed(): + c.foo(b"\x35" * 101) print("Passed input-too-long test") diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index d793a56d6e..4ef6874ae9 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -759,27 +759,30 @@ def test_multi4_2() -> DynArray[DynArray[DynArray[DynArray[uint256, 2], 2], 2], assert c.test_multi4_2() == nest4 -def test_uint256_accessor(get_contract_with_gas_estimation, assert_tx_failed): +def test_uint256_accessor(get_contract_with_gas_estimation, tx_failed): code = """ @external def bounds_check_uint256(xs: DynArray[uint256, 3], ix: uint256) -> uint256: return xs[ix] """ c = get_contract_with_gas_estimation(code) - assert_tx_failed(lambda: c.bounds_check_uint256([], 0)) + with tx_failed(): + c.bounds_check_uint256([], 0) assert c.bounds_check_uint256([1], 0) == 1 - assert_tx_failed(lambda: c.bounds_check_uint256([1], 1)) + with tx_failed(): + c.bounds_check_uint256([1], 1) assert c.bounds_check_uint256([1, 2, 3], 0) == 1 assert c.bounds_check_uint256([1, 2, 3], 2) == 3 - assert_tx_failed(lambda: c.bounds_check_uint256([1, 2, 3], 3)) + with tx_failed(): + c.bounds_check_uint256([1, 2, 3], 3) # TODO do bounds checks for nested darrays @pytest.mark.parametrize("list_", ([], [11], [11, 12], [11, 12, 13])) -def test_dynarray_len(get_contract_with_gas_estimation, assert_tx_failed, list_): +def test_dynarray_len(get_contract_with_gas_estimation, tx_failed, list_): code = """ @external def darray_len(xs: DynArray[uint256, 3]) -> uint256: @@ -790,7 +793,7 @@ def darray_len(xs: DynArray[uint256, 3]) -> uint256: assert c.darray_len(list_) == len(list_) -def test_dynarray_too_large(get_contract_with_gas_estimation, assert_tx_failed): +def test_dynarray_too_large(get_contract_with_gas_estimation, tx_failed): code = """ @external def darray_len(xs: DynArray[uint256, 3]) -> uint256: @@ -798,10 +801,11 @@ def darray_len(xs: DynArray[uint256, 3]) -> uint256: """ c = get_contract_with_gas_estimation(code) - assert_tx_failed(lambda: c.darray_len([1, 2, 3, 4])) + with tx_failed(): + c.darray_len([1, 2, 3, 4]) -def test_int128_accessor(get_contract_with_gas_estimation, assert_tx_failed): +def test_int128_accessor(get_contract_with_gas_estimation, tx_failed): code = """ @external def bounds_check_int128(ix: int128) -> uint256: @@ -811,8 +815,10 @@ def bounds_check_int128(ix: int128) -> uint256: c = get_contract_with_gas_estimation(code) assert c.bounds_check_int128(0) == 1 assert c.bounds_check_int128(2) == 3 - assert_tx_failed(lambda: c.bounds_check_int128(3)) - assert_tx_failed(lambda: c.bounds_check_int128(-1)) + with tx_failed(): + c.bounds_check_int128(3) + with tx_failed(): + c.bounds_check_int128(-1) def test_index_exception(get_contract_with_gas_estimation, assert_compile_failed): @@ -1164,12 +1170,13 @@ def test_invalid_append_pop(get_contract, assert_compile_failed, code, exception @pytest.mark.parametrize("code,check_result", append_pop_tests) # TODO change this to fuzz random data @pytest.mark.parametrize("test_data", [[1, 2, 3, 4, 5][:i] for i in range(6)]) -def test_append_pop(get_contract, assert_tx_failed, code, check_result, test_data): +def test_append_pop(get_contract, tx_failed, code, check_result, test_data): c = get_contract(code) expected_result = check_result(test_data) if expected_result is None: # None is sentinel to indicate txn should revert - assert_tx_failed(lambda: c.foo(test_data)) + with tx_failed(): + c.foo(test_data) else: assert c.foo(test_data) == expected_result @@ -1234,7 +1241,7 @@ def foo(x: {typ}) -> {typ}: ["uint256[3]", "DynArray[uint256,3]", "DynArray[uint8, 4]", "Foo", "DynArray[Foobar, 3]"], ) # TODO change this to fuzz random data -def test_append_pop_complex(get_contract, assert_tx_failed, code_template, check_result, subtype): +def test_append_pop_complex(get_contract, tx_failed, code_template, check_result, subtype): code = code_template.format(typ=subtype) test_data = [1, 2, 3] if subtype == "Foo": @@ -1260,7 +1267,8 @@ def test_append_pop_complex(get_contract, assert_tx_failed, code_template, check expected_result = check_result(test_data) if expected_result is None: # None is sentinel to indicate txn should revert - assert_tx_failed(lambda: c.foo(test_data)) + with tx_failed(): + c.foo(test_data) else: assert c.foo(test_data) == expected_result @@ -1330,7 +1338,7 @@ def bar(_baz: DynArray[Foo, 3]) -> String[96]: assert c.bar(c_input) == "Hello world!!!!" -def test_list_of_structs_lists_with_nested_lists(get_contract, assert_tx_failed): +def test_list_of_structs_lists_with_nested_lists(get_contract, tx_failed): code = """ struct Bar: a: DynArray[uint8[2], 2] @@ -1351,7 +1359,8 @@ def foo(x: uint8) -> uint8: """ c = get_contract(code) assert c.foo(17) == 98 - assert_tx_failed(lambda: c.foo(241)) + with tx_failed(): + c.foo(241) def test_list_of_nested_struct_arrays(get_contract): @@ -1622,7 +1631,7 @@ def bar() -> uint256: assert c.bar() == 58 -def test_constant_list(get_contract, assert_tx_failed): +def test_constant_list(get_contract, tx_failed): some_good_primes = [5.0, 11.0, 17.0, 29.0, 37.0, 41.0] code = f""" MY_LIST: constant(DynArray[decimal, 6]) = {some_good_primes} @@ -1634,7 +1643,8 @@ def ix(i: uint256) -> decimal: for i, p in enumerate(some_good_primes): assert c.ix(i) == p # assert oob - assert_tx_failed(lambda: c.ix(len(some_good_primes) + 1)) + with tx_failed(): + c.ix(len(some_good_primes) + 1) def test_public_dynarray(get_contract): @@ -1831,7 +1841,8 @@ def should_revert() -> DynArray[String[65], 2]: @pytest.mark.parametrize("code", dynarray_length_no_clobber_cases) -def test_dynarray_length_no_clobber(get_contract, assert_tx_failed, code): +def test_dynarray_length_no_clobber(get_contract, tx_failed, code): # check that length is not clobbered before dynarray data copy happens c = get_contract(code) - assert_tx_failed(lambda: c.should_revert()) + with tx_failed(): + c.should_revert() diff --git a/tests/functional/codegen/types/test_flag.py b/tests/functional/codegen/types/test_flag.py index 03c22134ed..5da6d57558 100644 --- a/tests/functional/codegen/types/test_flag.py +++ b/tests/functional/codegen/types/test_flag.py @@ -74,7 +74,7 @@ def is_not_boss(a: Roles) -> bool: assert c.is_not_boss(2**4) is False -def test_bitwise(get_contract, assert_tx_failed): +def test_bitwise(get_contract, tx_failed): code = """ flag Roles: USER @@ -134,18 +134,25 @@ def binv_arg(a: Roles) -> Roles: assert c.binv_arg(0b00000) == 0b11111 # LHS is out of bound - assert_tx_failed(lambda: c.bor_arg(32, 3)) - assert_tx_failed(lambda: c.band_arg(32, 3)) - assert_tx_failed(lambda: c.bxor_arg(32, 3)) - assert_tx_failed(lambda: c.binv_arg(32)) + with tx_failed(): + c.bor_arg(32, 3) + with tx_failed(): + c.band_arg(32, 3) + with tx_failed(): + c.bxor_arg(32, 3) + with tx_failed(): + c.binv_arg(32) # RHS - assert_tx_failed(lambda: c.bor_arg(3, 32)) - assert_tx_failed(lambda: c.band_arg(3, 32)) - assert_tx_failed(lambda: c.bxor_arg(3, 32)) + with tx_failed(): + c.bor_arg(3, 32) + with tx_failed(): + c.band_arg(3, 32) + with tx_failed(): + c.bxor_arg(3, 32) -def test_augassign_storage(get_contract, w3, assert_tx_failed): +def test_augassign_storage(get_contract, w3, tx_failed): code = """ flag Roles: ADMIN @@ -190,7 +197,8 @@ def checkMinter(minter: address): assert c.roles(minter_address) == 0b10 # admin is not a minter - assert_tx_failed(lambda: c.checkMinter(admin_address)) + with tx_failed(): + c.checkMinter(admin_address) c.addMinter(admin_address, transact={}) @@ -201,7 +209,8 @@ def checkMinter(minter: address): # revoke minter c.revokeMinter(admin_address, transact={}) assert c.roles(admin_address) == 0b01 - assert_tx_failed(lambda: c.checkMinter(admin_address)) + with tx_failed(): + c.checkMinter(admin_address) # flip minter c.flipMinter(admin_address, transact={}) @@ -211,7 +220,8 @@ def checkMinter(minter: address): # flip minter c.flipMinter(admin_address, transact={}) assert c.roles(admin_address) == 0b01 - assert_tx_failed(lambda: c.checkMinter(admin_address)) + with tx_failed(): + c.checkMinter(admin_address) def test_in_flag(get_contract_with_gas_estimation): diff --git a/tests/functional/codegen/types/test_lists.py b/tests/functional/codegen/types/test_lists.py index 832b679e5e..657c4ba0b8 100644 --- a/tests/functional/codegen/types/test_lists.py +++ b/tests/functional/codegen/types/test_lists.py @@ -353,7 +353,7 @@ def test_multi4() -> uint256[2][2][2][2]: @pytest.mark.parametrize("type_", ["uint8", "uint256"]) -def test_unsigned_accessors(get_contract_with_gas_estimation, assert_tx_failed, type_): +def test_unsigned_accessors(get_contract_with_gas_estimation, tx_failed, type_): code = f""" @external def bounds_check(ix: {type_}) -> uint256: @@ -363,11 +363,12 @@ def bounds_check(ix: {type_}) -> uint256: c = get_contract_with_gas_estimation(code) assert c.bounds_check(0) == 1 assert c.bounds_check(2) == 3 - assert_tx_failed(lambda: c.bounds_check(3)) + with tx_failed(): + c.bounds_check(3) @pytest.mark.parametrize("type_", ["int128", "int256"]) -def test_signed_accessors(get_contract_with_gas_estimation, assert_tx_failed, type_): +def test_signed_accessors(get_contract_with_gas_estimation, tx_failed, type_): code = f""" @external def bounds_check(ix: {type_}) -> uint256: @@ -377,8 +378,10 @@ def bounds_check(ix: {type_}) -> uint256: c = get_contract_with_gas_estimation(code) assert c.bounds_check(0) == 1 assert c.bounds_check(2) == 3 - assert_tx_failed(lambda: c.bounds_check(3)) - assert_tx_failed(lambda: c.bounds_check(-1)) + with tx_failed(): + c.bounds_check(3) + with tx_failed(): + c.bounds_check(-1) def test_list_check_heterogeneous_types(get_contract_with_gas_estimation, assert_compile_failed): @@ -662,7 +665,7 @@ def foo(x: Bar[2][2][2]) -> uint256: ("bool", [True, False, True, False, True, False]), ], ) -def test_constant_list(get_contract, assert_tx_failed, type, value): +def test_constant_list(get_contract, tx_failed, type, value): code = f""" MY_LIST: constant({type}[{len(value)}]) = {value} @external @@ -673,7 +676,8 @@ def ix(i: uint256) -> {type}: for i, p in enumerate(value): assert c.ix(i) == p # assert oob - assert_tx_failed(lambda: c.ix(len(value) + 1)) + with tx_failed(): + c.ix(len(value) + 1) def test_nested_constant_list_accessor(get_contract): @@ -728,7 +732,7 @@ def foo(i: uint256) -> {return_type}: assert_compile_failed(lambda: get_contract(code), TypeMismatch) -def test_constant_list_address(get_contract, assert_tx_failed): +def test_constant_list_address(get_contract, tx_failed): some_good_address = [ "0x0000000000000000000000000000000000012345", "0x0000000000000000000000000000000000023456", @@ -754,10 +758,11 @@ def ix(i: uint256) -> address: for i, p in enumerate(some_good_address): assert c.ix(i) == p # assert oob - assert_tx_failed(lambda: c.ix(len(some_good_address) + 1)) + with tx_failed(): + c.ix(len(some_good_address) + 1) -def test_list_index_complex_expr(get_contract, assert_tx_failed): +def test_list_index_complex_expr(get_contract, tx_failed): # test subscripts where the index is not a literal code = """ @external @@ -771,7 +776,8 @@ def foo(xs: uint256[257], i: uint8) -> uint256: assert c.foo(xs, ix) == xs[ix + 1] # safemath should fail for uint8: 255 + 1. - assert_tx_failed(lambda: c.foo(xs, 255)) + with tx_failed(): + c.foo(xs, 255) @pytest.mark.parametrize( @@ -793,7 +799,7 @@ def foo(xs: uint256[257], i: uint8) -> uint256: ("bool", [[True, False], [True, False], [True, False]]), ], ) -def test_constant_nested_list(get_contract, assert_tx_failed, type, value): +def test_constant_nested_list(get_contract, tx_failed, type, value): code = f""" MY_LIST: constant({type}[{len(value[0])}][{len(value)}]) = {value} @external @@ -805,7 +811,8 @@ def ix(i: uint256, j: uint256) -> {type}: for j, q in enumerate(p): assert c.ix(i, j) == q # assert oob - assert_tx_failed(lambda: c.ix(len(value) + 1, len(value[0]) + 1)) + with tx_failed(): + c.ix(len(value) + 1, len(value[0]) + 1) @pytest.mark.parametrize("storage_type,return_type", itertools.permutations(integer_types, 2)) diff --git a/tests/functional/codegen/types/test_string.py b/tests/functional/codegen/types/test_string.py index 7f1fa71329..9d50f8df38 100644 --- a/tests/functional/codegen/types/test_string.py +++ b/tests/functional/codegen/types/test_string.py @@ -61,7 +61,7 @@ def get(k: String[34]) -> int128: assert c.get("a" * 34) == 6789 -def test_string_slice(get_contract_with_gas_estimation, assert_tx_failed): +def test_string_slice(get_contract_with_gas_estimation, tx_failed): test_slice4 = """ @external def foo(inp: String[10], start: uint256, _len: uint256) -> String[10]: @@ -76,10 +76,14 @@ def foo(inp: String[10], start: uint256, _len: uint256) -> String[10]: assert c.foo("badminton", 1, 0) == "" assert c.foo("badminton", 9, 0) == "" - assert_tx_failed(lambda: c.foo("badminton", 0, 10)) - assert_tx_failed(lambda: c.foo("badminton", 1, 9)) - assert_tx_failed(lambda: c.foo("badminton", 9, 1)) - assert_tx_failed(lambda: c.foo("badminton", 10, 0)) + with tx_failed(): + c.foo("badminton", 0, 10) + with tx_failed(): + c.foo("badminton", 1, 9) + with tx_failed(): + c.foo("badminton", 9, 1) + with tx_failed(): + c.foo("badminton", 10, 0) def test_private_string(get_contract_with_gas_estimation): diff --git a/tests/functional/examples/auctions/test_blind_auction.py b/tests/functional/examples/auctions/test_blind_auction.py index d814ab0cad..dcd4e0bf8b 100644 --- a/tests/functional/examples/auctions/test_blind_auction.py +++ b/tests/functional/examples/auctions/test_blind_auction.py @@ -33,15 +33,15 @@ def test_initial_state(w3, tester, auction_contract): assert auction_contract.highestBidder() is None -def test_late_bid(w3, auction_contract, assert_tx_failed): +def test_late_bid(w3, auction_contract, tx_failed): k1 = w3.eth.accounts[1] # Move time forward past bidding end w3.testing.mine(BIDDING_TIME + TEST_INCREMENT) # Try to bid after bidding has ended - assert_tx_failed( - lambda: auction_contract.bid( + with tx_failed(): + auction_contract.bid( w3.keccak( b"".join( [ @@ -53,10 +53,9 @@ def test_late_bid(w3, auction_contract, assert_tx_failed): ), transact={"value": 200, "from": k1}, ) - ) -def test_too_many_bids(w3, auction_contract, assert_tx_failed): +def test_too_many_bids(w3, auction_contract, tx_failed): k1 = w3.eth.accounts[1] # First 128 bids should be able to be placed successfully @@ -75,8 +74,8 @@ def test_too_many_bids(w3, auction_contract, assert_tx_failed): ) # 129th bid should fail - assert_tx_failed( - lambda: auction_contract.bid( + with tx_failed(): + auction_contract.bid( w3.keccak( b"".join( [ @@ -88,10 +87,9 @@ def test_too_many_bids(w3, auction_contract, assert_tx_failed): ), transact={"value": 128, "from": k1}, ) - ) -def test_early_reval(w3, auction_contract, assert_tx_failed): +def test_early_reval(w3, auction_contract, tx_failed): k1 = w3.eth.accounts[1] # k1 places 1 real bid @@ -119,11 +117,10 @@ def test_early_reval(w3, auction_contract, assert_tx_failed): _values[0] = 100 _fakes[0] = False _secrets[0] = (8675309).to_bytes(32, byteorder="big") - assert_tx_failed( - lambda: auction_contract.reveal( + with tx_failed(): + auction_contract.reveal( _numBids, _values, _fakes, _secrets, transact={"value": 0, "from": k1} ) - ) # Check highest bidder is still empty assert auction_contract.highestBidder() is None @@ -131,7 +128,7 @@ def test_early_reval(w3, auction_contract, assert_tx_failed): assert auction_contract.highestBid() == 0 -def test_late_reveal(w3, auction_contract, assert_tx_failed): +def test_late_reveal(w3, auction_contract, tx_failed): k1 = w3.eth.accounts[1] # k1 places 1 real bid @@ -159,11 +156,10 @@ def test_late_reveal(w3, auction_contract, assert_tx_failed): _values[0] = 100 _fakes[0] = False _secrets[0] = (8675309).to_bytes(32, byteorder="big") - assert_tx_failed( - lambda: auction_contract.reveal( + with tx_failed(): + auction_contract.reveal( _numBids, _values, _fakes, _secrets, transact={"value": 0, "from": k1} ) - ) # Check highest bidder is still empty assert auction_contract.highestBidder() is None @@ -171,14 +167,15 @@ def test_late_reveal(w3, auction_contract, assert_tx_failed): assert auction_contract.highestBid() == 0 -def test_early_end(w3, auction_contract, assert_tx_failed): +def test_early_end(w3, auction_contract, tx_failed): k0 = w3.eth.accounts[0] # Should not be able to end auction before reveal time has ended - assert_tx_failed(lambda: auction_contract.auctionEnd(transact={"value": 0, "from": k0})) + with tx_failed(): + auction_contract.auctionEnd(transact={"value": 0, "from": k0}) -def test_double_end(w3, auction_contract, assert_tx_failed): +def test_double_end(w3, auction_contract, tx_failed): k0 = w3.eth.accounts[0] # Move time forward past bidding and reveal end @@ -188,7 +185,8 @@ def test_double_end(w3, auction_contract, assert_tx_failed): auction_contract.auctionEnd(transact={"value": 0, "from": k0}) # Should not be able to end auction twice - assert_tx_failed(lambda: auction_contract.auctionEnd(transact={"value": 0, "from": k0})) + with tx_failed(): + auction_contract.auctionEnd(transact={"value": 0, "from": k0}) def test_blind_auction(w3, auction_contract): diff --git a/tests/functional/examples/auctions/test_simple_open_auction.py b/tests/functional/examples/auctions/test_simple_open_auction.py index cf0bb8cc20..c80b44d976 100644 --- a/tests/functional/examples/auctions/test_simple_open_auction.py +++ b/tests/functional/examples/auctions/test_simple_open_auction.py @@ -33,17 +33,19 @@ def test_initial_state(w3, tester, auction_contract, auction_start): assert auction_contract.auctionEnd() >= tester.get_block_by_number("latest")["timestamp"] -def test_bid(w3, tester, auction_contract, assert_tx_failed): +def test_bid(w3, tester, auction_contract, tx_failed): k1, k2, k3, k4, k5 = w3.eth.accounts[:5] # Bidder cannot bid 0 - assert_tx_failed(lambda: auction_contract.bid(transact={"value": 0, "from": k1})) + with tx_failed(): + auction_contract.bid(transact={"value": 0, "from": k1}) # Bidder can bid auction_contract.bid(transact={"value": 1, "from": k1}) # Check that highest bidder and highest bid have changed accordingly assert auction_contract.highestBidder() == k1 assert auction_contract.highestBid() == 1 # Bidder bid cannot equal current highest bid - assert_tx_failed(lambda: auction_contract.bid(transact={"value": 1, "from": k1})) + with tx_failed(): + auction_contract.bid(transact={"value": 1, "from": k1}) # Higher bid can replace current highest bid auction_contract.bid(transact={"value": 2, "from": k2}) # Check that highest bidder and highest bid have changed accordingly @@ -72,10 +74,11 @@ def test_bid(w3, tester, auction_contract, assert_tx_failed): assert auction_contract.pendingReturns(k1) == 0 -def test_end_auction(w3, tester, auction_contract, assert_tx_failed): +def test_end_auction(w3, tester, auction_contract, tx_failed): k1, k2, k3, k4, k5 = w3.eth.accounts[:5] # Fails if auction end time has not been reached - assert_tx_failed(lambda: auction_contract.endAuction()) + with tx_failed(): + auction_contract.endAuction() auction_contract.bid(transact={"value": 1 * 10**10, "from": k2}) # Move block timestamp foreward to reach auction end time # tester.time_travel(tester.get_block_by_number('latest')['timestamp'] + EXPIRY) @@ -86,6 +89,8 @@ def test_end_auction(w3, tester, auction_contract, assert_tx_failed): # Beneficiary receives the highest bid assert balance_after_end == balance_before_end + 1 * 10**10 # Bidder cannot bid after auction end time has been reached - assert_tx_failed(lambda: auction_contract.bid(transact={"value": 10, "from": k1})) + with tx_failed(): + auction_contract.bid(transact={"value": 10, "from": k1}) # Auction cannot be ended twice - assert_tx_failed(lambda: auction_contract.endAuction()) + with tx_failed(): + auction_contract.endAuction() diff --git a/tests/functional/examples/company/test_company.py b/tests/functional/examples/company/test_company.py index 71141b8bb5..5933a14e86 100644 --- a/tests/functional/examples/company/test_company.py +++ b/tests/functional/examples/company/test_company.py @@ -9,7 +9,7 @@ def c(w3, get_contract): return contract -def test_overbuy(w3, c, assert_tx_failed): +def test_overbuy(w3, c, tx_failed): # If all the stock has been bought, no one can buy more a1, a2 = w3.eth.accounts[1:3] test_shares = int(c.totalShares() / 2) @@ -19,15 +19,19 @@ def test_overbuy(w3, c, assert_tx_failed): assert c.stockAvailable() == 0 assert c.getHolding(a1) == (test_shares * 2) one_stock = c.price() - assert_tx_failed(lambda: c.buyStock(transact={"from": a1, "value": one_stock})) - assert_tx_failed(lambda: c.buyStock(transact={"from": a2, "value": one_stock})) + with tx_failed(): + c.buyStock(transact={"from": a1, "value": one_stock}) + with tx_failed(): + c.buyStock(transact={"from": a2, "value": one_stock}) -def test_sell_without_stock(w3, c, assert_tx_failed): +def test_sell_without_stock(w3, c, tx_failed): a1, a2 = w3.eth.accounts[1:3] # If you don't have any stock, you can't sell - assert_tx_failed(lambda: c.sellStock(1, transact={"from": a1})) - assert_tx_failed(lambda: c.sellStock(1, transact={"from": a2})) + with tx_failed(): + c.sellStock(1, transact={"from": a1}) + with tx_failed(): + c.sellStock(1, transact={"from": a2}) # But if you do, you can! test_shares = int(c.totalShares()) test_value = int(test_shares * c.price()) @@ -35,48 +39,57 @@ def test_sell_without_stock(w3, c, assert_tx_failed): assert c.getHolding(a1) == test_shares c.sellStock(test_shares, transact={"from": a1}) # But only until you run out - assert_tx_failed(lambda: c.sellStock(1, transact={"from": a1})) + with tx_failed(): + c.sellStock(1, transact={"from": a1}) -def test_oversell(w3, c, assert_tx_failed): +def test_oversell(w3, c, tx_failed): a0, a1, a2 = w3.eth.accounts[:3] # You can't sell more than you own test_shares = int(c.totalShares()) test_value = int(test_shares * c.price()) c.buyStock(transact={"from": a1, "value": test_value}) - assert_tx_failed(lambda: c.sellStock(test_shares + 1, transact={"from": a1})) + with tx_failed(): + c.sellStock(test_shares + 1, transact={"from": a1}) -def test_transfer(w3, c, assert_tx_failed): +def test_transfer(w3, c, tx_failed): # If you don't have any stock, you can't transfer a1, a2 = w3.eth.accounts[1:3] - assert_tx_failed(lambda: c.transferStock(a2, 1, transact={"from": a1})) - assert_tx_failed(lambda: c.transferStock(a1, 1, transact={"from": a2})) + with tx_failed(): + c.transferStock(a2, 1, transact={"from": a1}) + with tx_failed(): + c.transferStock(a1, 1, transact={"from": a2}) # If you transfer, you don't have the stock anymore test_shares = int(c.totalShares()) test_value = int(test_shares * c.price()) c.buyStock(transact={"from": a1, "value": test_value}) assert c.getHolding(a1) == test_shares c.transferStock(a2, test_shares, transact={"from": a1}) - assert_tx_failed(lambda: c.sellStock(1, transact={"from": a1})) + with tx_failed(): + c.sellStock(1, transact={"from": a1}) # But the other person does c.sellStock(test_shares, transact={"from": a2}) -def test_paybill(w3, c, assert_tx_failed): +def test_paybill(w3, c, tx_failed): a0, a1, a2, a3 = w3.eth.accounts[:4] # Only the company can authorize payments - assert_tx_failed(lambda: c.payBill(a2, 1, transact={"from": a1})) + with tx_failed(): + c.payBill(a2, 1, transact={"from": a1}) # A company can only pay someone if it has the money - assert_tx_failed(lambda: c.payBill(a2, 1, transact={"from": a0})) + with tx_failed(): + c.payBill(a2, 1, transact={"from": a0}) # If it has the money, it can pay someone test_value = int(c.totalShares() * c.price()) c.buyStock(transact={"from": a1, "value": test_value}) c.payBill(a2, test_value, transact={"from": a0}) # Until it runs out of money - assert_tx_failed(lambda: c.payBill(a3, 1, transact={"from": a0})) + with tx_failed(): + c.payBill(a3, 1, transact={"from": a0}) # Then no stockholders can sell their stock either - assert_tx_failed(lambda: c.sellStock(1, transact={"from": a1})) + with tx_failed(): + c.sellStock(1, transact={"from": a1}) def test_valuation(w3, c): diff --git a/tests/functional/examples/crowdfund/test_crowdfund_example.py b/tests/functional/examples/crowdfund/test_crowdfund_example.py index 9a08d9241c..e75a88bf48 100644 --- a/tests/functional/examples/crowdfund/test_crowdfund_example.py +++ b/tests/functional/examples/crowdfund/test_crowdfund_example.py @@ -27,7 +27,7 @@ def test_crowdfund_example(c, w3): assert post_bal - pre_bal == 54 -def test_crowdfund_example2(c, w3, assert_tx_failed): +def test_crowdfund_example2(c, w3, tx_failed): a0, a1, a2, a3, a4, a5, a6 = w3.eth.accounts[:7] c.participate(transact={"value": 1, "from": a3}) c.participate(transact={"value": 2, "from": a4}) @@ -39,9 +39,11 @@ def test_crowdfund_example2(c, w3, assert_tx_failed): # assert c.expired() # assert not c.reached() pre_bals = [w3.eth.get_balance(x) for x in [a3, a4, a5, a6]] - assert_tx_failed(lambda: c.refund(transact={"from": a0})) + with tx_failed(): + c.refund(transact={"from": a0}) c.refund(transact={"from": a3}) - assert_tx_failed(lambda: c.refund(transact={"from": a3})) + with tx_failed(): + c.refund(transact={"from": a3}) c.refund(transact={"from": a4}) c.refund(transact={"from": a5}) c.refund(transact={"from": a6}) diff --git a/tests/functional/examples/market_maker/test_on_chain_market_maker.py b/tests/functional/examples/market_maker/test_on_chain_market_maker.py index db9700da3b..235a0ea66f 100644 --- a/tests/functional/examples/market_maker/test_on_chain_market_maker.py +++ b/tests/functional/examples/market_maker/test_on_chain_market_maker.py @@ -31,25 +31,21 @@ def test_initial_state(market_maker): assert market_maker.owner() is None -def test_initiate(w3, market_maker, erc20, assert_tx_failed): +def test_initiate(w3, market_maker, erc20, tx_failed): a0 = w3.eth.accounts[0] - erc20.approve(market_maker.address, w3.to_wei(2, "ether"), transact={}) - market_maker.initiate( - erc20.address, w3.to_wei(1, "ether"), transact={"value": w3.to_wei(2, "ether")} - ) - assert market_maker.totalEthQty() == w3.to_wei(2, "ether") - assert market_maker.totalTokenQty() == w3.to_wei(1, "ether") + ether, ethers = w3.to_wei(1, "ether"), w3.to_wei(2, "ether") + erc20.approve(market_maker.address, ethers, transact={}) + market_maker.initiate(erc20.address, ether, transact={"value": ethers}) + assert market_maker.totalEthQty() == ethers + assert market_maker.totalTokenQty() == ether assert market_maker.invariant() == 2 * 10**36 assert market_maker.owner() == a0 assert erc20.name() == TOKEN_NAME assert erc20.decimals() == TOKEN_DECIMALS # Initiate cannot be called twice - assert_tx_failed( - lambda: market_maker.initiate( - erc20.address, w3.to_wei(1, "ether"), transact={"value": w3.to_wei(2, "ether")} - ) - ) # noqa: E501 + with tx_failed(): + market_maker.initiate(erc20.address, ether, transact={"value": ethers}) def test_eth_to_tokens(w3, market_maker, erc20): @@ -95,7 +91,7 @@ def test_tokens_to_eth(w3, market_maker, erc20): assert market_maker.totalEthQty() == w3.to_wei(1, "ether") -def test_owner_withdraw(w3, market_maker, erc20, assert_tx_failed): +def test_owner_withdraw(w3, market_maker, erc20, tx_failed): a0, a1 = w3.eth.accounts[:2] a0_balance_before = w3.eth.get_balance(a0) # Approve 2 eth transfers. @@ -110,7 +106,8 @@ def test_owner_withdraw(w3, market_maker, erc20, assert_tx_failed): assert erc20.balanceOf(a0) == TOKEN_TOTAL_SUPPLY - w3.to_wei(1, "ether") # Only owner can call ownerWithdraw - assert_tx_failed(lambda: market_maker.ownerWithdraw(transact={"from": a1})) + with tx_failed(): + market_maker.ownerWithdraw(transact={"from": a1}) market_maker.ownerWithdraw(transact={}) assert w3.eth.get_balance(a0) == a0_balance_before # Eth balance restored. assert erc20.balanceOf(a0) == TOKEN_TOTAL_SUPPLY # Tokens returned to a0. diff --git a/tests/functional/examples/name_registry/test_name_registry.py b/tests/functional/examples/name_registry/test_name_registry.py index 26f5844484..a2e92a7c52 100644 --- a/tests/functional/examples/name_registry/test_name_registry.py +++ b/tests/functional/examples/name_registry/test_name_registry.py @@ -1,8 +1,9 @@ -def test_name_registry(w3, get_contract, assert_tx_failed): +def test_name_registry(w3, get_contract, tx_failed): a0, a1 = w3.eth.accounts[:2] with open("examples/name_registry/name_registry.vy") as f: code = f.read() c = get_contract(code) c.register(b"jacques", a0, transact={}) assert c.lookup(b"jacques") == a0 - assert_tx_failed(lambda: c.register(b"jacques", a1)) + with tx_failed(): + c.register(b"jacques", a1) diff --git a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py index 9a806ed885..2cc5dd8d4a 100644 --- a/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py +++ b/tests/functional/examples/safe_remote_purchase/test_safe_remote_purchase.py @@ -31,9 +31,10 @@ def get_balance(): return get_balance -def test_initial_state(w3, assert_tx_failed, get_contract, get_balance, contract_code): +def test_initial_state(w3, tx_failed, get_contract, get_balance, contract_code): # Inital deposit has to be divisible by two - assert_tx_failed(lambda: get_contract(contract_code, value=13)) + with tx_failed(): + get_contract(contract_code, value=13) # Seller puts item up for sale a0_pre_bal, a1_pre_bal = get_balance() c = get_contract(contract_code, value_in_eth=2) @@ -47,30 +48,34 @@ def test_initial_state(w3, assert_tx_failed, get_contract, get_balance, contract assert get_balance() == ((a0_pre_bal - w3.to_wei(2, "ether")), a1_pre_bal) -def test_abort(w3, assert_tx_failed, get_balance, get_contract, contract_code): +def test_abort(w3, tx_failed, get_balance, get_contract, contract_code): a0, a1, a2 = w3.eth.accounts[:3] a0_pre_bal, a1_pre_bal = get_balance() c = get_contract(contract_code, value=w3.to_wei(2, "ether")) assert c.value() == w3.to_wei(1, "ether") # Only sender can trigger refund - assert_tx_failed(lambda: c.abort(transact={"from": a2})) + with tx_failed(): + c.abort(transact={"from": a2}) # Refund works correctly c.abort(transact={"from": a0}) assert get_balance() == (a0_pre_bal, a1_pre_bal) # Purchase in process, no refund possible c = get_contract(contract_code, value=2) c.purchase(transact={"value": 2, "from": a1}) - assert_tx_failed(lambda: c.abort(transact={"from": a0})) + with tx_failed(): + c.abort(transact={"from": a0}) -def test_purchase(w3, get_contract, assert_tx_failed, get_balance, contract_code): +def test_purchase(w3, get_contract, tx_failed, get_balance, contract_code): a0, a1, a2, a3 = w3.eth.accounts[:4] init_bal_a0, init_bal_a1 = get_balance() c = get_contract(contract_code, value=2) # Purchase for too low/high price - assert_tx_failed(lambda: c.purchase(transact={"value": 1, "from": a1})) - assert_tx_failed(lambda: c.purchase(transact={"value": 3, "from": a1})) + with tx_failed(): + c.purchase(transact={"value": 1, "from": a1}) + with tx_failed(): + c.purchase(transact={"value": 3, "from": a1}) # Purchase for the correct price c.purchase(transact={"value": 2, "from": a1}) # Check if buyer is set correctly @@ -80,26 +85,29 @@ def test_purchase(w3, get_contract, assert_tx_failed, get_balance, contract_code # Check balances, both deposits should have been deducted assert get_balance() == (init_bal_a0 - 2, init_bal_a1 - 2) # Allow nobody else to purchase - assert_tx_failed(lambda: c.purchase(transact={"value": 2, "from": a3})) + with tx_failed(): + c.purchase(transact={"value": 2, "from": a3}) -def test_received(w3, get_contract, assert_tx_failed, get_balance, contract_code): +def test_received(w3, get_contract, tx_failed, get_balance, contract_code): a0, a1 = w3.eth.accounts[:2] init_bal_a0, init_bal_a1 = get_balance() c = get_contract(contract_code, value=2) # Can only be called after purchase - assert_tx_failed(lambda: c.received(transact={"from": a1})) + with tx_failed(): + c.received(transact={"from": a1}) # Purchase completed c.purchase(transact={"value": 2, "from": a1}) # Check that e.g. sender cannot trigger received - assert_tx_failed(lambda: c.received(transact={"from": a0})) + with tx_failed(): + c.received(transact={"from": a0}) # Check if buyer can call receive c.received(transact={"from": a1}) # Final check if everything worked. 1 value has been transferred assert get_balance() == (init_bal_a0 + 1, init_bal_a1 - 1) -def test_received_reentrancy(w3, get_contract, assert_tx_failed, get_balance, contract_code): +def test_received_reentrancy(w3, get_contract, tx_failed, get_balance, contract_code): buyer_contract_code = """ interface PurchaseContract: diff --git a/tests/functional/examples/storage/test_advanced_storage.py b/tests/functional/examples/storage/test_advanced_storage.py index 13ffce4f82..313d1a7e5c 100644 --- a/tests/functional/examples/storage/test_advanced_storage.py +++ b/tests/functional/examples/storage/test_advanced_storage.py @@ -18,32 +18,30 @@ def test_initial_state(adv_storage_contract): assert adv_storage_contract.storedData() == INITIAL_VALUE -def test_failed_transactions(w3, adv_storage_contract, assert_tx_failed): +def test_failed_transactions(w3, adv_storage_contract, tx_failed): k1 = w3.eth.accounts[1] # Try to set the storage to a negative amount - assert_tx_failed(lambda: adv_storage_contract.set(-10, transact={"from": k1})) + with tx_failed(): + adv_storage_contract.set(-10, transact={"from": k1}) # Lock the contract by storing more than 100. Then try to change the value adv_storage_contract.set(150, transact={"from": k1}) - assert_tx_failed(lambda: adv_storage_contract.set(10, transact={"from": k1})) + with tx_failed(): + adv_storage_contract.set(10, transact={"from": k1}) # Reset the contract and try to change the value adv_storage_contract.reset(transact={"from": k1}) adv_storage_contract.set(10, transact={"from": k1}) assert adv_storage_contract.storedData() == 10 - # Assert a different exception (ValidationError for non matching argument type) - assert_tx_failed( - lambda: adv_storage_contract.set("foo", transact={"from": k1}), ValidationError - ) + # Assert a different exception (ValidationError for non-matching argument type) + with tx_failed(ValidationError): + adv_storage_contract.set("foo", transact={"from": k1}) # Assert a different exception that contains specific text - assert_tx_failed( - lambda: adv_storage_contract.set(1, 2, transact={"from": k1}), - ValidationError, - "invocation failed due to improper number of arguments", - ) + with tx_failed(ValidationError, "invocation failed due to improper number of arguments"): + adv_storage_contract.set(1, 2, transact={"from": k1}) def test_events(w3, adv_storage_contract, get_logs): diff --git a/tests/functional/examples/tokens/test_erc1155.py b/tests/functional/examples/tokens/test_erc1155.py index abebd024b6..5dc314c037 100644 --- a/tests/functional/examples/tokens/test_erc1155.py +++ b/tests/functional/examples/tokens/test_erc1155.py @@ -29,7 +29,7 @@ @pytest.fixture -def erc1155(get_contract, w3, assert_tx_failed): +def erc1155(get_contract, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] with open("examples/tokens/ERC1155ownable.vy") as f: code = f.read() @@ -41,18 +41,20 @@ def erc1155(get_contract, w3, assert_tx_failed): assert c.balanceOf(a1, 1) == 1 assert c.balanceOf(a1, 2) == 1 assert c.balanceOf(a1, 3) == 1 - assert_tx_failed( - lambda: c.mintBatch(ZERO_ADDRESS, mintBatch, minBatchSetOf10, transact={"from": owner}) - ) - assert_tx_failed(lambda: c.mintBatch(a1, [1, 2, 3], [1, 1], transact={"from": owner})) + with tx_failed(): + c.mintBatch(ZERO_ADDRESS, mintBatch, minBatchSetOf10, transact={"from": owner}) + with tx_failed(): + c.mintBatch(a1, [1, 2, 3], [1, 1], transact={"from": owner}) c.mint(a1, 21, 1, transact={"from": owner}) c.mint(a1, 22, 1, transact={"from": owner}) c.mint(a1, 23, 1, transact={"from": owner}) c.mint(a1, 24, 1, transact={"from": owner}) - assert_tx_failed(lambda: c.mint(a1, 24, 1, transact={"from": a3})) - assert_tx_failed(lambda: c.mint(ZERO_ADDRESS, 24, 1, transact={"from": owner})) + with tx_failed(): + c.mint(a1, 24, 1, transact={"from": a3}) + with tx_failed(): + c.mint(ZERO_ADDRESS, 24, 1, transact={"from": owner}) assert c.balanceOf(a1, 21) == 1 assert c.balanceOf(a1, 22) == 1 @@ -80,69 +82,76 @@ def test_initial_state(erc1155): assert erc1155.supportsInterface(ERC1155_INTERFACE_ID_METADATA) -def test_pause(erc1155, w3, assert_tx_failed): +def test_pause(erc1155, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] # check the pause status, pause, check, unpause, check, with owner and non-owner w3.eth.accounts # this test will check all the function that should not work when paused. assert not erc1155.paused() # try to pause the contract from a non owner account - assert_tx_failed(lambda: erc1155.pause(transact={"from": a1})) + with tx_failed(): + erc1155.pause(transact={"from": a1}) # now pause the contract and check status erc1155.pause(transact={"from": owner}) assert erc1155.paused() # try pausing a paused contract - assert_tx_failed(lambda: erc1155.pause()) + with tx_failed(): + erc1155.pause() # try functions that should not work when paused - assert_tx_failed(lambda: erc1155.setURI(NEW_CONTRACT_URI)) + with tx_failed(): + erc1155.setURI(NEW_CONTRACT_URI) # test burn and burnbatch - assert_tx_failed(lambda: erc1155.burn(21, 1)) - assert_tx_failed(lambda: erc1155.burnBatch([21, 22], [1, 1])) + with tx_failed(): + erc1155.burn(21, 1) + with tx_failed(): + erc1155.burnBatch([21, 22], [1, 1]) # check mint and mintbatch - assert_tx_failed(lambda: erc1155.mint(a1, 21, 1, transact={"from": owner})) - assert_tx_failed( - lambda: erc1155.mintBatch(a1, mintBatch, minBatchSetOf10, transact={"from": owner}) - ) + with tx_failed(): + erc1155.mint(a1, 21, 1, transact={"from": owner}) + with tx_failed(): + erc1155.mintBatch(a1, mintBatch, minBatchSetOf10, transact={"from": owner}) # check safetransferfrom and safebatchtransferfrom - assert_tx_failed( - lambda: erc1155.safeTransferFrom(a1, a2, 21, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) - ) - assert_tx_failed( - lambda: erc1155.safeBatchTransferFrom( + with tx_failed(): + erc1155.safeTransferFrom(a1, a2, 21, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) + with tx_failed(): + erc1155.safeBatchTransferFrom( a1, a2, [21, 22, 23], [1, 1, 1], DUMMY_BYTES32_DATA, transact={"from": a1} ) - ) # check ownership functions - assert_tx_failed(lambda: erc1155.transferOwnership(a1)) - assert_tx_failed(lambda: erc1155.renounceOwnership()) + with tx_failed(): + erc1155.transferOwnership(a1) + with tx_failed(): + erc1155.renounceOwnership() # check approval functions - assert_tx_failed(lambda: erc1155.setApprovalForAll(owner, a5, True)) + with tx_failed(): + erc1155.setApprovalForAll(owner, a5, True) # try and unpause as non-owner - assert_tx_failed(lambda: erc1155.unpause(transact={"from": a1})) + with tx_failed(): + erc1155.unpause(transact={"from": a1}) erc1155.unpause(transact={"from": owner}) assert not erc1155.paused() # try un pausing an unpaused contract - assert_tx_failed(lambda: erc1155.unpause()) + with tx_failed(): + erc1155.unpause() -def test_contractURI(erc1155, w3, assert_tx_failed): +def test_contractURI(erc1155, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] # change contract URI and restore. assert erc1155.contractURI() == CONTRACT_METADATA_URI - assert_tx_failed( - lambda: erc1155.setContractURI(NEW_CONTRACT_METADATA_URI, transact={"from": a1}) - ) + with tx_failed(): + erc1155.setContractURI(NEW_CONTRACT_METADATA_URI, transact={"from": a1}) erc1155.setContractURI(NEW_CONTRACT_METADATA_URI, transact={"from": owner}) assert erc1155.contractURI() == NEW_CONTRACT_METADATA_URI assert erc1155.contractURI() != CONTRACT_METADATA_URI @@ -150,10 +159,11 @@ def test_contractURI(erc1155, w3, assert_tx_failed): assert erc1155.contractURI() != NEW_CONTRACT_METADATA_URI assert erc1155.contractURI() == CONTRACT_METADATA_URI - assert_tx_failed(lambda: erc1155.setContractURI(CONTRACT_METADATA_URI)) + with tx_failed(): + erc1155.setContractURI(CONTRACT_METADATA_URI) -def test_URI(erc1155, w3, assert_tx_failed): +def test_URI(erc1155, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] # change contract URI and restore. assert erc1155.uri(0) == CONTRACT_URI @@ -164,7 +174,8 @@ def test_URI(erc1155, w3, assert_tx_failed): assert erc1155.uri(0) != NEW_CONTRACT_URI assert erc1155.uri(0) == CONTRACT_URI - assert_tx_failed(lambda: erc1155.setURI(CONTRACT_URI)) + with tx_failed(): + erc1155.setURI(CONTRACT_URI) # set contract to dynamic URI erc1155.toggleDynUri(True, transact={"from": owner}) @@ -172,49 +183,41 @@ def test_URI(erc1155, w3, assert_tx_failed): assert erc1155.uri(0) == CONTRACT_DYNURI + str(0) + ".json" -def test_safeTransferFrom_balanceOf_single(erc1155, w3, assert_tx_failed): +def test_safeTransferFrom_balanceOf_single(erc1155, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] assert erc1155.balanceOf(a1, 24) == 1 # transfer by non-owner - assert_tx_failed( - lambda: erc1155.safeTransferFrom(a1, a2, 24, 1, DUMMY_BYTES32_DATA, transact={"from": a2}) - ) + with tx_failed(): + erc1155.safeTransferFrom(a1, a2, 24, 1, DUMMY_BYTES32_DATA, transact={"from": a2}) # transfer to zero address - assert_tx_failed( - lambda: erc1155.safeTransferFrom( - a1, ZERO_ADDRESS, 24, 1, DUMMY_BYTES32_DATA, transact={"from": a1} - ) - ) + with tx_failed(): + erc1155.safeTransferFrom(a1, ZERO_ADDRESS, 24, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) # transfer to self - assert_tx_failed( - lambda: erc1155.safeTransferFrom(a1, a1, 24, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) - ) + with tx_failed(): + erc1155.safeTransferFrom(a1, a1, 24, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) # transfer more than owned - assert_tx_failed( - lambda: erc1155.safeTransferFrom(a1, a2, 24, 500, DUMMY_BYTES32_DATA, transact={"from": a1}) - ) + with tx_failed(): + erc1155.safeTransferFrom(a1, a2, 24, 500, DUMMY_BYTES32_DATA, transact={"from": a1}) # transfer item not owned / not existing - assert_tx_failed( - lambda: erc1155.safeTransferFrom(a1, a2, 500, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) - ) + with tx_failed(): + erc1155.safeTransferFrom(a1, a2, 500, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) erc1155.safeTransferFrom(a1, a2, 21, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) assert erc1155.balanceOf(a2, 21) == 1 # try to transfer item again - assert_tx_failed( - lambda: erc1155.safeTransferFrom(a1, a2, 21, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) - ) + with tx_failed(): + erc1155.safeTransferFrom(a1, a2, 21, 1, DUMMY_BYTES32_DATA, transact={"from": a1}) assert erc1155.balanceOf(a1, 21) == 0 # TODO: mint 20 NFTs [1:20] and check the balance for each -def test_mintBatch_balanceOf(erc1155, w3, assert_tx_failed): # test_mint_batch +def test_mintBatch_balanceOf(erc1155, w3, tx_failed): # test_mint_batch owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] # Use the mint three fixture to mint the tokens. # this test checks the balances of this test @@ -222,7 +225,7 @@ def test_mintBatch_balanceOf(erc1155, w3, assert_tx_failed): # test_mint_batch assert erc1155.balanceOf(a1, i) == 1 -def test_safeBatchTransferFrom_balanceOf_batch(erc1155, w3, assert_tx_failed): # test_mint_batch +def test_safeBatchTransferFrom_balanceOf_batch(erc1155, w3, tx_failed): # test_mint_batch owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] # check a1 balances for NFTs 21-24 @@ -231,67 +234,58 @@ def test_safeBatchTransferFrom_balanceOf_batch(erc1155, w3, assert_tx_failed): assert erc1155.balanceOf(a1, 23) == 1 assert erc1155.balanceOf(a1, 23) == 1 - # try to transfer item from non item owner account - assert_tx_failed( - lambda: erc1155.safeBatchTransferFrom( + # try to transfer item from non-item owner account + with tx_failed(): + erc1155.safeBatchTransferFrom( a1, a2, [21, 22, 23], [1, 1, 1], DUMMY_BYTES32_DATA, transact={"from": a2} ) - ) # try to transfer item to zero address - assert_tx_failed( - lambda: erc1155.safeBatchTransferFrom( + with tx_failed(): + erc1155.safeBatchTransferFrom( a1, ZERO_ADDRESS, [21, 22, 23], [1, 1, 1], DUMMY_BYTES32_DATA, transact={"from": a1} ) - ) # try to transfer item to self - assert_tx_failed( - lambda: erc1155.safeBatchTransferFrom( + with tx_failed(): + erc1155.safeBatchTransferFrom( a1, a1, [21, 22, 23], [1, 1, 1], DUMMY_BYTES32_DATA, transact={"from": a1} ) - ) # try to transfer more items than we own - assert_tx_failed( - lambda: erc1155.safeBatchTransferFrom( + with tx_failed(): + erc1155.safeBatchTransferFrom( a1, a2, [21, 22, 23], [1, 125, 1], DUMMY_BYTES32_DATA, transact={"from": a1} ) - ) # mismatched item and amounts - assert_tx_failed( - lambda: erc1155.safeBatchTransferFrom( + with tx_failed(): + erc1155.safeBatchTransferFrom( a1, a2, [21, 22, 23], [1, 1], DUMMY_BYTES32_DATA, transact={"from": a1} ) - ) # try to transfer nonexisting item - assert_tx_failed( - lambda: erc1155.safeBatchTransferFrom( + with tx_failed(): + erc1155.safeBatchTransferFrom( a1, a2, [21, 22, 500], [1, 1, 1], DUMMY_BYTES32_DATA, transact={"from": a1} ) - ) assert erc1155.safeBatchTransferFrom( a1, a2, [21, 22, 23], [1, 1, 1], DUMMY_BYTES32_DATA, transact={"from": a1} ) # try to transfer again, our balances are zero now, should fail - assert_tx_failed( - lambda: erc1155.safeBatchTransferFrom( + with tx_failed(): + erc1155.safeBatchTransferFrom( a1, a2, [21, 22, 23], [1, 1, 1], DUMMY_BYTES32_DATA, transact={"from": a1} ) - ) - assert_tx_failed( - lambda: erc1155.balanceOfBatch([a2, a2, a2], [21, 22], transact={"from": owner}) - == [1, 1, 1] - ) + with tx_failed(): + erc1155.balanceOfBatch([a2, a2, a2], [21, 22], transact={"from": owner}) assert erc1155.balanceOfBatch([a2, a2, a2], [21, 22, 23]) == [1, 1, 1] assert erc1155.balanceOf(a1, 21) == 0 -def test_mint_one_burn_one(erc1155, w3, assert_tx_failed): +def test_mint_one_burn_one(erc1155, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] # check the balance from an owner and non-owner account @@ -301,20 +295,23 @@ def test_mint_one_burn_one(erc1155, w3, assert_tx_failed): assert erc1155.balanceOf(owner, 25) == 1 # try and burn an item we don't control - assert_tx_failed(lambda: erc1155.burn(25, 1, transact={"from": a3})) + with tx_failed(): + erc1155.burn(25, 1, transact={"from": a3}) # burn an item that contains something we don't own - assert_tx_failed(lambda: erc1155.burn(595, 1, transact={"from": a1})) + with tx_failed(): + erc1155.burn(595, 1, transact={"from": a1}) # burn ah item passing a higher amount than we own - assert_tx_failed(lambda: erc1155.burn(25, 500, transact={"from": a1})) + with tx_failed(): + erc1155.burn(25, 500, transact={"from": a1}) erc1155.burn(25, 1, transact={"from": owner}) assert erc1155.balanceOf(owner, 25) == 0 -def test_mint_batch_burn_batch(erc1155, w3, assert_tx_failed): +def test_mint_batch_burn_batch(erc1155, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] # mint NFTs 11-20 @@ -322,16 +319,20 @@ def test_mint_batch_burn_batch(erc1155, w3, assert_tx_failed): assert erc1155.balanceOfBatch([a3, a3, a3], [11, 12, 13]) == [1, 1, 1] # try and burn a batch we don't control - assert_tx_failed(lambda: erc1155.burnBatch([11, 12], [1, 1])) + with tx_failed(): + erc1155.burnBatch([11, 12], [1, 1]) # ids and amounts array length not matching - assert_tx_failed(lambda: erc1155.burnBatch([1, 2, 3], [1, 1], transact={"from": a1})) + with tx_failed(): + erc1155.burnBatch([1, 2, 3], [1, 1], transact={"from": a1}) # burn a batch that contains something we don't own - assert_tx_failed(lambda: erc1155.burnBatch([2, 3, 595], [1, 1, 1], transact={"from": a1})) + with tx_failed(): + erc1155.burnBatch([2, 3, 595], [1, 1, 1], transact={"from": a1}) # burn a batch passing a higher amount than we own - assert_tx_failed(lambda: erc1155.burnBatch([1, 2, 3], [1, 500, 1], transact={"from": a1})) + with tx_failed(): + erc1155.burnBatch([1, 2, 3], [1, 500, 1], transact={"from": a1}) # burn existing erc1155.burnBatch([11, 12], [1, 1], transact={"from": a3}) @@ -339,18 +340,21 @@ def test_mint_batch_burn_batch(erc1155, w3, assert_tx_failed): assert erc1155.balanceOfBatch([a3, a3, a3], [11, 12, 13]) == [0, 0, 1] # burn again, should revert - assert_tx_failed(lambda: erc1155.burnBatch([11, 12], [1, 1], transact={"from": a3})) + with tx_failed(): + erc1155.burnBatch([11, 12], [1, 1], transact={"from": a3}) - assert lambda: erc1155.balanceOfBatch([a3, a3, a3], [1, 2, 3]) == [0, 0, 1] + assert erc1155.balanceOfBatch([a3, a3, a3], [1, 2, 3]) == [0, 0, 0] -def test_approval_functions(erc1155, w3, assert_tx_failed): # test_mint_batch +def test_approval_functions(erc1155, w3, tx_failed): # test_mint_batch owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] # self-approval by the owner - assert_tx_failed(lambda: erc1155.setApprovalForAll(a5, a5, True, transact={"from": a5})) + with tx_failed(): + erc1155.setApprovalForAll(a5, a5, True, transact={"from": a5}) # let's approve and operator for somebody else's account - assert_tx_failed(lambda: erc1155.setApprovalForAll(owner, a5, True, transact={"from": a3})) + with tx_failed(): + erc1155.setApprovalForAll(owner, a5, True, transact={"from": a3}) # set approval correctly erc1155.setApprovalForAll(owner, a5, True) @@ -362,7 +366,7 @@ def test_approval_functions(erc1155, w3, assert_tx_failed): # test_mint_batch erc1155.setApprovalForAll(owner, a5, False) -def test_max_batch_size_violation(erc1155, w3, assert_tx_failed): +def test_max_batch_size_violation(erc1155, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] TOTAL_BAD_BATCH = 200 ids = [] @@ -371,27 +375,29 @@ def test_max_batch_size_violation(erc1155, w3, assert_tx_failed): ids.append(i) amounts.append(1) - assert_tx_failed(lambda: erc1155.mintBatch(a1, ids, amounts, transact={"from": owner})) + with tx_failed(): + erc1155.mintBatch(a1, ids, amounts, transact={"from": owner}) # Transferring back and forth -def test_ownership_functions(erc1155, w3, assert_tx_failed, tester): +def test_ownership_functions(erc1155, w3, tx_failed, tester): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] print(owner, a1, a2) print("___owner___", erc1155.owner()) # change owner from account 0 to account 1 and back assert erc1155.owner() == owner - assert_tx_failed(lambda: erc1155.transferOwnership(a1, transact={"from": a2})) + with tx_failed(): + erc1155.transferOwnership(a1, transact={"from": a2}) # try to transfer ownership to current owner - assert_tx_failed(lambda: erc1155.transferOwnership(owner)) + with tx_failed(): + erc1155.transferOwnership(owner) # try to transfer ownership to ZERO ADDRESS - assert_tx_failed( - lambda: erc1155.transferOwnership("0x0000000000000000000000000000000000000000") - ) + with tx_failed(): + erc1155.transferOwnership("0x0000000000000000000000000000000000000000") # Transfer ownership to account 1 erc1155.transferOwnership(a1, transact={"from": owner}) @@ -399,11 +405,12 @@ def test_ownership_functions(erc1155, w3, assert_tx_failed, tester): assert erc1155.owner() == a1 -def test_renounce_ownership(erc1155, w3, assert_tx_failed): +def test_renounce_ownership(erc1155, w3, tx_failed): owner, a1, a2, a3, a4, a5 = w3.eth.accounts[0:6] assert erc1155.owner() == owner # try to transfer ownership from non-owner account - assert_tx_failed(lambda: erc1155.renounceOwnership(transact={"from": a2})) + with tx_failed(): + erc1155.renounceOwnership(transact={"from": a2}) erc1155.renounceOwnership(transact={"from": owner}) diff --git a/tests/functional/examples/tokens/test_erc20.py b/tests/functional/examples/tokens/test_erc20.py index cba7769bae..ce507f75f8 100644 --- a/tests/functional/examples/tokens/test_erc20.py +++ b/tests/functional/examples/tokens/test_erc20.py @@ -61,7 +61,7 @@ def test_initial_state(c, w3): assert c.allowance(a2, a3) == 0 -def test_mint_and_burn(c, w3, assert_tx_failed): +def test_mint_and_burn(c, w3, tx_failed): minter, a1, a2 = w3.eth.accounts[0:3] # Test scenario were mints 2 to a1, burns twice (check balance consistency) @@ -70,23 +70,30 @@ def test_mint_and_burn(c, w3, assert_tx_failed): assert c.balanceOf(a1) == 2 c.burn(2, transact={"from": a1}) assert c.balanceOf(a1) == 0 - assert_tx_failed(lambda: c.burn(2, transact={"from": a1})) + with tx_failed(): + c.burn(2, transact={"from": a1}) assert c.balanceOf(a1) == 0 # Test scenario were mintes 0 to a2, burns (check balance consistency, false burn) c.mint(a2, 0, transact={"from": minter}) assert c.balanceOf(a2) == 0 - assert_tx_failed(lambda: c.burn(2, transact={"from": a2})) + with tx_failed(): + c.burn(2, transact={"from": a2}) # Check that a1 cannot burn after depleting their balance - assert_tx_failed(lambda: c.burn(1, transact={"from": a1})) + with tx_failed(): + c.burn(1, transact={"from": a1}) # Check that a1, a2 cannot mint - assert_tx_failed(lambda: c.mint(a1, 1, transact={"from": a1})) - assert_tx_failed(lambda: c.mint(a2, 1, transact={"from": a2})) + with tx_failed(): + c.mint(a1, 1, transact={"from": a1}) + with tx_failed(): + c.mint(a2, 1, transact={"from": a2}) # Check that mint to ZERO_ADDRESS failed - assert_tx_failed(lambda: c.mint(ZERO_ADDRESS, 1, transact={"from": a1})) - assert_tx_failed(lambda: c.mint(ZERO_ADDRESS, 1, transact={"from": minter})) + with tx_failed(): + c.mint(ZERO_ADDRESS, 1, transact={"from": a1}) + with tx_failed(): + c.mint(ZERO_ADDRESS, 1, transact={"from": minter}) -def test_totalSupply(c, w3, assert_tx_failed): +def test_totalSupply(c, w3, tx_failed): # Test total supply initially, after mint, between two burns, and after failed burn minter, a1 = w3.eth.accounts[0:2] assert c.totalSupply() == 0 @@ -96,40 +103,49 @@ def test_totalSupply(c, w3, assert_tx_failed): assert c.totalSupply() == 1 c.burn(1, transact={"from": a1}) assert c.totalSupply() == 0 - assert_tx_failed(lambda: c.burn(1, transact={"from": a1})) + with tx_failed(): + c.burn(1, transact={"from": a1}) assert c.totalSupply() == 0 # Test that 0-valued mint can't affect supply c.mint(a1, 0, transact={"from": minter}) assert c.totalSupply() == 0 -def test_transfer(c, w3, assert_tx_failed): +def test_transfer(c, w3, tx_failed): minter, a1, a2 = w3.eth.accounts[0:3] - assert_tx_failed(lambda: c.burn(1, transact={"from": a2})) + with tx_failed(): + c.burn(1, transact={"from": a2}) c.mint(a1, 2, transact={"from": minter}) c.burn(1, transact={"from": a1}) c.transfer(a2, 1, transact={"from": a1}) - assert_tx_failed(lambda: c.burn(1, transact={"from": a1})) + with tx_failed(): + c.burn(1, transact={"from": a1}) c.burn(1, transact={"from": a2}) - assert_tx_failed(lambda: c.burn(1, transact={"from": a2})) + with tx_failed(): + c.burn(1, transact={"from": a2}) # Ensure transfer fails with insufficient balance - assert_tx_failed(lambda: c.transfer(a1, 1, transact={"from": a2})) + with tx_failed(): + c.transfer(a1, 1, transact={"from": a2}) # Ensure 0-transfer always succeeds c.transfer(a1, 0, transact={"from": a2}) -def test_maxInts(c, w3, assert_tx_failed): +def test_maxInts(c, w3, tx_failed): minter, a1, a2 = w3.eth.accounts[0:3] c.mint(a1, MAX_UINT256, transact={"from": minter}) assert c.balanceOf(a1) == MAX_UINT256 - assert_tx_failed(lambda: c.mint(a1, 1, transact={"from": a1})) - assert_tx_failed(lambda: c.mint(a1, MAX_UINT256, transact={"from": a1})) + with tx_failed(): + c.mint(a1, 1, transact={"from": a1}) + with tx_failed(): + c.mint(a1, MAX_UINT256, transact={"from": a1}) # Check that totalSupply cannot overflow, even when mint to other account - assert_tx_failed(lambda: c.mint(a2, 1, transact={"from": minter})) + with tx_failed(): + c.mint(a2, 1, transact={"from": minter}) # Check that corresponding mint is allowed after burn c.burn(1, transact={"from": a1}) c.mint(a2, 1, transact={"from": minter}) - assert_tx_failed(lambda: c.mint(a2, 1, transact={"from": minter})) + with tx_failed(): + c.mint(a2, 1, transact={"from": minter}) c.transfer(a1, 1, transact={"from": a2}) # Assert that after obtaining max number of tokens, a1 can transfer those but no more assert c.balanceOf(a1) == MAX_UINT256 @@ -150,21 +166,24 @@ def test_maxInts(c, w3, assert_tx_failed): assert c.balanceOf(a1) == 0 -def test_transferFrom_and_Allowance(c, w3, assert_tx_failed): +def test_transferFrom_and_Allowance(c, w3, tx_failed): minter, a1, a2, a3 = w3.eth.accounts[0:4] - assert_tx_failed(lambda: c.burn(1, transact={"from": a2})) + with tx_failed(): + c.burn(1, transact={"from": a2}) c.mint(a1, 1, transact={"from": minter}) c.mint(a2, 1, transact={"from": minter}) c.burn(1, transact={"from": a1}) # This should fail; no allowance or balance (0 always succeeds) - assert_tx_failed(lambda: c.transferFrom(a1, a3, 1, transact={"from": a2})) + with tx_failed(): + c.transferFrom(a1, a3, 1, transact={"from": a2}) c.transferFrom(a1, a3, 0, transact={"from": a2}) # Correct call to approval should update allowance (but not for reverse pair) c.approve(a2, 1, transact={"from": a1}) assert c.allowance(a1, a2) == 1 assert c.allowance(a2, a1) == 0 # transferFrom should succeed when allowed, fail with wrong sender - assert_tx_failed(lambda: c.transferFrom(a1, a3, 1, transact={"from": a3})) + with tx_failed(): + c.transferFrom(a1, a3, 1, transact={"from": a3}) assert c.balanceOf(a2) == 1 c.approve(a1, 1, transact={"from": a2}) c.transferFrom(a2, a3, 1, transact={"from": a1}) @@ -173,7 +192,8 @@ def test_transferFrom_and_Allowance(c, w3, assert_tx_failed): # transferFrom with no funds should fail despite approval c.approve(a1, 1, transact={"from": a2}) assert c.allowance(a2, a1) == 1 - assert_tx_failed(lambda: c.transferFrom(a2, a3, 1, transact={"from": a1})) + with tx_failed(): + c.transferFrom(a2, a3, 1, transact={"from": a1}) # 0-approve should not change balance or allow transferFrom to change balance c.mint(a2, 1, transact={"from": minter}) assert c.allowance(a2, a1) == 1 @@ -181,7 +201,8 @@ def test_transferFrom_and_Allowance(c, w3, assert_tx_failed): assert c.allowance(a2, a1) == 0 c.approve(a1, 0, transact={"from": a2}) assert c.allowance(a2, a1) == 0 - assert_tx_failed(lambda: c.transferFrom(a2, a3, 1, transact={"from": a1})) + with tx_failed(): + c.transferFrom(a2, a3, 1, transact={"from": a1}) # Test that if non-zero approval exists, 0-approval is NOT required to proceed # a non-conformant implementation is described in countermeasures at # https://docs.google.com/document/d/1YLPtQxZu1UAvO9cZ1O2RPXBbT0mooh4DYKjA_jp-RLM/edit#heading=h.m9fhqynw2xvt @@ -198,21 +219,24 @@ def test_transferFrom_and_Allowance(c, w3, assert_tx_failed): assert c.allowance(a2, a1) == 5 -def test_burnFrom_and_Allowance(c, w3, assert_tx_failed): +def test_burnFrom_and_Allowance(c, w3, tx_failed): minter, a1, a2, a3 = w3.eth.accounts[0:4] - assert_tx_failed(lambda: c.burn(1, transact={"from": a2})) + with tx_failed(): + c.burn(1, transact={"from": a2}) c.mint(a1, 1, transact={"from": minter}) c.mint(a2, 1, transact={"from": minter}) c.burn(1, transact={"from": a1}) # This should fail; no allowance or balance (0 always succeeds) - assert_tx_failed(lambda: c.burnFrom(a1, 1, transact={"from": a2})) + with tx_failed(): + c.burnFrom(a1, 1, transact={"from": a2}) c.burnFrom(a1, 0, transact={"from": a2}) # Correct call to approval should update allowance (but not for reverse pair) c.approve(a2, 1, transact={"from": a1}) assert c.allowance(a1, a2) == 1 assert c.allowance(a2, a1) == 0 # transferFrom should succeed when allowed, fail with wrong sender - assert_tx_failed(lambda: c.burnFrom(a2, 1, transact={"from": a3})) + with tx_failed(): + c.burnFrom(a2, 1, transact={"from": a3}) assert c.balanceOf(a2) == 1 c.approve(a1, 1, transact={"from": a2}) c.burnFrom(a2, 1, transact={"from": a1}) @@ -221,7 +245,8 @@ def test_burnFrom_and_Allowance(c, w3, assert_tx_failed): # transferFrom with no funds should fail despite approval c.approve(a1, 1, transact={"from": a2}) assert c.allowance(a2, a1) == 1 - assert_tx_failed(lambda: c.burnFrom(a2, 1, transact={"from": a1})) + with tx_failed(): + c.burnFrom(a2, 1, transact={"from": a1}) # 0-approve should not change balance or allow transferFrom to change balance c.mint(a2, 1, transact={"from": minter}) assert c.allowance(a2, a1) == 1 @@ -229,7 +254,8 @@ def test_burnFrom_and_Allowance(c, w3, assert_tx_failed): assert c.allowance(a2, a1) == 0 c.approve(a1, 0, transact={"from": a2}) assert c.allowance(a2, a1) == 0 - assert_tx_failed(lambda: c.burnFrom(a2, 1, transact={"from": a1})) + with tx_failed(): + c.burnFrom(a2, 1, transact={"from": a1}) # Test that if non-zero approval exists, 0-approval is NOT required to proceed # a non-conformant implementation is described in countermeasures at # https://docs.google.com/document/d/1YLPtQxZu1UAvO9cZ1O2RPXBbT0mooh4DYKjA_jp-RLM/edit#heading=h.m9fhqynw2xvt @@ -245,7 +271,8 @@ def test_burnFrom_and_Allowance(c, w3, assert_tx_failed): c.approve(a1, 5, transact={"from": a2}) assert c.allowance(a2, a1) == 5 # Check that burnFrom to ZERO_ADDRESS failed - assert_tx_failed(lambda: c.burnFrom(ZERO_ADDRESS, 0, transact={"from": a1})) + with tx_failed(): + c.burnFrom(ZERO_ADDRESS, 0, transact={"from": a1}) def test_raw_logs(c, w3, get_log_args): @@ -307,33 +334,36 @@ def test_raw_logs(c, w3, get_log_args): assert args.value == 0 -def test_bad_transfer(c_bad, w3, assert_tx_failed): +def test_bad_transfer(c_bad, w3, tx_failed): # Ensure transfer fails if it would otherwise overflow balance when totalSupply is corrupted minter, a1, a2 = w3.eth.accounts[0:3] c_bad.mint(a1, MAX_UINT256, transact={"from": minter}) c_bad.mint(a2, 1, transact={"from": minter}) - assert_tx_failed(lambda: c_bad.transfer(a1, 1, transact={"from": a2})) + with tx_failed(): + c_bad.transfer(a1, 1, transact={"from": a2}) c_bad.transfer(a2, MAX_UINT256 - 1, transact={"from": a1}) assert c_bad.balanceOf(a1) == 1 assert c_bad.balanceOf(a2) == MAX_UINT256 -def test_bad_burn(c_bad, w3, assert_tx_failed): +def test_bad_burn(c_bad, w3, tx_failed): # Ensure burn fails if it would otherwise underflow balance when totalSupply is corrupted minter, a1 = w3.eth.accounts[0:2] assert c_bad.balanceOf(a1) == 0 c_bad.mint(a1, 2, transact={"from": minter}) assert c_bad.balanceOf(a1) == 2 - assert_tx_failed(lambda: c_bad.burn(3, transact={"from": a1})) + with tx_failed(): + c_bad.burn(3, transact={"from": a1}) -def test_bad_transferFrom(c_bad, w3, assert_tx_failed): +def test_bad_transferFrom(c_bad, w3, tx_failed): # Ensure transferFrom fails if it would otherwise overflow balance when totalSupply is corrupted minter, a1, a2 = w3.eth.accounts[0:3] c_bad.mint(a1, MAX_UINT256, transact={"from": minter}) c_bad.mint(a2, 1, transact={"from": minter}) c_bad.approve(a1, 1, transact={"from": a2}) - assert_tx_failed(lambda: c_bad.transferFrom(a2, a1, 1, transact={"from": a1})) + with tx_failed(): + c_bad.transferFrom(a2, a1, 1, transact={"from": a1}) c_bad.approve(a2, MAX_UINT256 - 1, transact={"from": a1}) assert c_bad.allowance(a1, a2) == MAX_UINT256 - 1 c_bad.transferFrom(a1, a2, MAX_UINT256 - 1, transact={"from": a2}) diff --git a/tests/functional/examples/tokens/test_erc721.py b/tests/functional/examples/tokens/test_erc721.py index ab3c6368c5..c881149baa 100644 --- a/tests/functional/examples/tokens/test_erc721.py +++ b/tests/functional/examples/tokens/test_erc721.py @@ -40,16 +40,18 @@ def test_erc165(w3, c): assert c.supportsInterface(ERC721_SIG) -def test_balanceOf(c, w3, assert_tx_failed): +def test_balanceOf(c, w3, tx_failed): someone = w3.eth.accounts[1] assert c.balanceOf(someone) == 3 - assert_tx_failed(lambda: c.balanceOf(ZERO_ADDRESS)) + with tx_failed(): + c.balanceOf(ZERO_ADDRESS) -def test_ownerOf(c, w3, assert_tx_failed): +def test_ownerOf(c, w3, tx_failed): someone = w3.eth.accounts[1] assert c.ownerOf(SOMEONE_TOKEN_IDS[0]) == someone - assert_tx_failed(lambda: c.ownerOf(INVALID_TOKEN_ID)) + with tx_failed(): + c.ownerOf(INVALID_TOKEN_ID) def test_getApproved(c, w3): @@ -72,32 +74,24 @@ def test_isApprovedForAll(c, w3): assert c.isApprovedForAll(someone, operator) == 1 -def test_transferFrom_by_owner(c, w3, assert_tx_failed, get_logs): +def test_transferFrom_by_owner(c, w3, tx_failed, get_logs): someone, operator = w3.eth.accounts[1:3] # transfer from zero address - assert_tx_failed( - lambda: c.transferFrom( - ZERO_ADDRESS, operator, SOMEONE_TOKEN_IDS[0], transact={"from": someone} - ) - ) + with tx_failed(): + c.transferFrom(ZERO_ADDRESS, operator, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) # transfer to zero address - assert_tx_failed( - lambda: c.transferFrom( - someone, ZERO_ADDRESS, SOMEONE_TOKEN_IDS[0], transact={"from": someone} - ) - ) + with tx_failed(): + c.transferFrom(someone, ZERO_ADDRESS, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) # transfer token without ownership - assert_tx_failed( - lambda: c.transferFrom(someone, operator, OPERATOR_TOKEN_ID, transact={"from": someone}) - ) + with tx_failed(): + c.transferFrom(someone, operator, OPERATOR_TOKEN_ID, transact={"from": someone}) # transfer invalid token - assert_tx_failed( - lambda: c.transferFrom(someone, operator, INVALID_TOKEN_ID, transact={"from": someone}) - ) + with tx_failed(): + c.transferFrom(someone, operator, INVALID_TOKEN_ID, transact={"from": someone}) # transfer by owner tx_hash = c.transferFrom(someone, operator, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) @@ -152,32 +146,24 @@ def test_transferFrom_by_operator(c, w3, get_logs): assert c.balanceOf(operator) == 2 -def test_safeTransferFrom_by_owner(c, w3, assert_tx_failed, get_logs): +def test_safeTransferFrom_by_owner(c, w3, tx_failed, get_logs): someone, operator = w3.eth.accounts[1:3] # transfer from zero address - assert_tx_failed( - lambda: c.safeTransferFrom( - ZERO_ADDRESS, operator, SOMEONE_TOKEN_IDS[0], transact={"from": someone} - ) - ) + with tx_failed(): + c.safeTransferFrom(ZERO_ADDRESS, operator, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) # transfer to zero address - assert_tx_failed( - lambda: c.safeTransferFrom( - someone, ZERO_ADDRESS, SOMEONE_TOKEN_IDS[0], transact={"from": someone} - ) - ) + with tx_failed(): + c.safeTransferFrom(someone, ZERO_ADDRESS, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) # transfer token without ownership - assert_tx_failed( - lambda: c.safeTransferFrom(someone, operator, OPERATOR_TOKEN_ID, transact={"from": someone}) - ) + with tx_failed(): + c.safeTransferFrom(someone, operator, OPERATOR_TOKEN_ID, transact={"from": someone}) # transfer invalid token - assert_tx_failed( - lambda: c.safeTransferFrom(someone, operator, INVALID_TOKEN_ID, transact={"from": someone}) - ) + with tx_failed(): + c.safeTransferFrom(someone, operator, INVALID_TOKEN_ID, transact={"from": someone}) # transfer by owner tx_hash = c.safeTransferFrom( @@ -238,15 +224,12 @@ def test_safeTransferFrom_by_operator(c, w3, get_logs): assert c.balanceOf(operator) == 2 -def test_safeTransferFrom_to_contract(c, w3, assert_tx_failed, get_logs, get_contract): +def test_safeTransferFrom_to_contract(c, w3, tx_failed, get_logs, get_contract): someone = w3.eth.accounts[1] # Can't transfer to a contract that doesn't implement the receiver code - assert_tx_failed( - lambda: c.safeTransferFrom( - someone, c.address, SOMEONE_TOKEN_IDS[0], transact={"from": someone} - ) - ) # noqa: E501 + with tx_failed(): + c.safeTransferFrom(someone, c.address, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) # Only to an address that implements that function receiver = get_contract( @@ -277,17 +260,20 @@ def onERC721Received( assert c.balanceOf(receiver.address) == 1 -def test_approve(c, w3, assert_tx_failed, get_logs): +def test_approve(c, w3, tx_failed, get_logs): someone, operator = w3.eth.accounts[1:3] # approve myself - assert_tx_failed(lambda: c.approve(someone, SOMEONE_TOKEN_IDS[0], transact={"from": someone})) + with tx_failed(): + c.approve(someone, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) # approve token without ownership - assert_tx_failed(lambda: c.approve(operator, OPERATOR_TOKEN_ID, transact={"from": someone})) + with tx_failed(): + c.approve(operator, OPERATOR_TOKEN_ID, transact={"from": someone}) # approve invalid token - assert_tx_failed(lambda: c.approve(operator, INVALID_TOKEN_ID, transact={"from": someone})) + with tx_failed(): + c.approve(operator, INVALID_TOKEN_ID, transact={"from": someone}) tx_hash = c.approve(operator, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) logs = get_logs(tx_hash, c, "Approval") @@ -299,12 +285,13 @@ def test_approve(c, w3, assert_tx_failed, get_logs): assert args.tokenId == SOMEONE_TOKEN_IDS[0] -def test_setApprovalForAll(c, w3, assert_tx_failed, get_logs): +def test_setApprovalForAll(c, w3, tx_failed, get_logs): someone, operator = w3.eth.accounts[1:3] approved = True # setApprovalForAll myself - assert_tx_failed(lambda: c.setApprovalForAll(someone, approved, transact={"from": someone})) + with tx_failed(): + c.setApprovalForAll(someone, approved, transact={"from": someone}) tx_hash = c.setApprovalForAll(operator, approved, transact={"from": someone}) logs = get_logs(tx_hash, c, "ApprovalForAll") @@ -316,14 +303,16 @@ def test_setApprovalForAll(c, w3, assert_tx_failed, get_logs): assert args.approved == approved -def test_mint(c, w3, assert_tx_failed, get_logs): +def test_mint(c, w3, tx_failed, get_logs): minter, someone = w3.eth.accounts[:2] # mint by non-minter - assert_tx_failed(lambda: c.mint(someone, SOMEONE_TOKEN_IDS[0], transact={"from": someone})) + with tx_failed(): + c.mint(someone, SOMEONE_TOKEN_IDS[0], transact={"from": someone}) # mint to zero address - assert_tx_failed(lambda: c.mint(ZERO_ADDRESS, SOMEONE_TOKEN_IDS[0], transact={"from": minter})) + with tx_failed(): + c.mint(ZERO_ADDRESS, SOMEONE_TOKEN_IDS[0], transact={"from": minter}) # mint by minter tx_hash = c.mint(someone, NEW_TOKEN_ID, transact={"from": minter}) @@ -338,11 +327,12 @@ def test_mint(c, w3, assert_tx_failed, get_logs): assert c.balanceOf(someone) == 4 -def test_burn(c, w3, assert_tx_failed, get_logs): +def test_burn(c, w3, tx_failed, get_logs): someone, operator = w3.eth.accounts[1:3] # burn token without ownership - assert_tx_failed(lambda: c.burn(SOMEONE_TOKEN_IDS[0], transact={"from": operator})) + with tx_failed(): + c.burn(SOMEONE_TOKEN_IDS[0], transact={"from": operator}) # burn token by owner tx_hash = c.burn(SOMEONE_TOKEN_IDS[0], transact={"from": someone}) @@ -353,5 +343,6 @@ def test_burn(c, w3, assert_tx_failed, get_logs): assert args.sender == someone assert args.receiver == ZERO_ADDRESS assert args.tokenId == SOMEONE_TOKEN_IDS[0] - assert_tx_failed(lambda: c.ownerOf(SOMEONE_TOKEN_IDS[0])) + with tx_failed(): + c.ownerOf(SOMEONE_TOKEN_IDS[0]) assert c.balanceOf(someone) == 2 diff --git a/tests/functional/examples/voting/test_ballot.py b/tests/functional/examples/voting/test_ballot.py index 4207fe6e4e..9c3a09fc83 100644 --- a/tests/functional/examples/voting/test_ballot.py +++ b/tests/functional/examples/voting/test_ballot.py @@ -33,7 +33,7 @@ def test_initial_state(w3, c): assert c.voters(z0)[0] == 0 # Voter.weight -def test_give_the_right_to_vote(w3, c, assert_tx_failed): +def test_give_the_right_to_vote(w3, c, tx_failed): a0, a1, a2, a3, a4, a5 = w3.eth.accounts[:6] c.giveRightToVote(a1, transact={}) # Check voter given right has weight of 1 @@ -56,7 +56,8 @@ def test_give_the_right_to_vote(w3, c, assert_tx_failed): # Check voter_acount is now 6 assert c.voterCount() == 6 # Check chairperson cannot give the right to vote twice to the same voter - assert_tx_failed(lambda: c.giveRightToVote(a5, transact={})) + with tx_failed(): + c.giveRightToVote(a5, transact={}) # Check voters weight didn't change assert c.voters(a5)[0] == 1 # Voter.weight @@ -127,7 +128,7 @@ def test_forward_weight(w3, c): assert c.voters(a9)[0] == 10 # Voter.weight -def test_block_short_cycle(w3, c, assert_tx_failed): +def test_block_short_cycle(w3, c, tx_failed): a0, a1, a2, a3, a4, a5, a6, a7, a8, a9 = w3.eth.accounts[:10] c.giveRightToVote(a0, transact={}) c.giveRightToVote(a1, transact={}) @@ -141,7 +142,8 @@ def test_block_short_cycle(w3, c, assert_tx_failed): c.delegate(a3, transact={"from": a2}) c.delegate(a4, transact={"from": a3}) # would create a length 5 cycle: - assert_tx_failed(lambda: c.delegate(a0, transact={"from": a4})) + with tx_failed(): + c.delegate(a0, transact={"from": a4}) c.delegate(a5, transact={"from": a4}) # can't detect length 6 cycle, so this works: @@ -150,7 +152,7 @@ def test_block_short_cycle(w3, c, assert_tx_failed): # but this is something the frontend should prevent for user friendliness -def test_delegate(w3, c, assert_tx_failed): +def test_delegate(w3, c, tx_failed): a0, a1, a2, a3, a4, a5, a6 = w3.eth.accounts[:7] c.giveRightToVote(a0, transact={}) c.giveRightToVote(a1, transact={}) @@ -167,9 +169,11 @@ def test_delegate(w3, c, assert_tx_failed): # Delegate's weight is 2 assert c.voters(a0)[0] == 2 # Voter.weight # Voter cannot delegate twice - assert_tx_failed(lambda: c.delegate(a2, transact={"from": a1})) + with tx_failed(): + c.delegate(a2, transact={"from": a1}) # Voter cannot delegate to themselves - assert_tx_failed(lambda: c.delegate(a2, transact={"from": a2})) + with tx_failed(): + c.delegate(a2, transact={"from": a2}) # Voter CAN delegate to someone who hasn't been granted right to vote # Exercise: prevent that c.delegate(a6, transact={"from": a2}) @@ -180,7 +184,7 @@ def test_delegate(w3, c, assert_tx_failed): assert c.voters(a0)[0] == 3 # Voter.weight -def test_vote(w3, c, assert_tx_failed): +def test_vote(w3, c, tx_failed): a0, a1, a2, a3, a4, a5, a6, a7, a8, a9 = w3.eth.accounts[:10] c.giveRightToVote(a0, transact={}) c.giveRightToVote(a1, transact={}) @@ -197,9 +201,11 @@ def test_vote(w3, c, assert_tx_failed): # Vote count changes based on voters weight assert c.proposals(0)[1] == 3 # Proposal.voteCount # Voter cannot vote twice - assert_tx_failed(lambda: c.vote(0)) + with tx_failed(): + c.vote(0) # Voter cannot vote if they've delegated - assert_tx_failed(lambda: c.vote(0, transact={"from": a1})) + with tx_failed(): + c.vote(0, transact={"from": a1}) # Several voters can vote c.vote(1, transact={"from": a4}) c.vote(1, transact={"from": a2}) @@ -207,7 +213,8 @@ def test_vote(w3, c, assert_tx_failed): c.vote(1, transact={"from": a6}) assert c.proposals(1)[1] == 4 # Proposal.voteCount # Can't vote on a non-proposal - assert_tx_failed(lambda: c.vote(2, transact={"from": a7})) + with tx_failed(): + c.vote(2, transact={"from": a7}) def test_winning_proposal(w3, c): diff --git a/tests/functional/examples/wallet/test_wallet.py b/tests/functional/examples/wallet/test_wallet.py index 71f1e5f331..b9db5acee3 100644 --- a/tests/functional/examples/wallet/test_wallet.py +++ b/tests/functional/examples/wallet/test_wallet.py @@ -29,7 +29,7 @@ def _sign(seq, to, value, data, key): return _sign -def test_approve(w3, c, tester, assert_tx_failed, sign): +def test_approve(w3, c, tester, tx_failed, sign): a0, a1, a2, a3, a4, a5, a6 = w3.eth.accounts[:7] k0, k1, k2, k3, k4, k5, k6, k7 = tester.backend.account_keys[:8] @@ -45,24 +45,20 @@ def pack_and_sign(seq, *args): c.approve(0, "0x" + to.hex(), value, data, sigs, transact={"value": value, "from": a1}) # Approve fails if only 2 signatures are given sigs = pack_and_sign(1, k1, 0, k3, 0, 0) - assert_tx_failed( - lambda: c.approve(1, to_address, value, data, sigs, transact={"value": value, "from": a1}) - ) # noqa: E501 + with tx_failed(): + c.approve(1, to_address, value, data, sigs, transact={"value": value, "from": a1}) # Approve fails if an invalid signature is given sigs = pack_and_sign(1, k1, 0, k7, 0, k5) - assert_tx_failed( - lambda: c.approve(1, to_address, value, data, sigs, transact={"value": value, "from": a1}) - ) # noqa: E501 + with tx_failed(): + c.approve(1, to_address, value, data, sigs, transact={"value": value, "from": a1}) # Approve fails if transaction number is incorrect (the first argument should be 1) sigs = pack_and_sign(0, k1, 0, k3, 0, k5) - assert_tx_failed( - lambda: c.approve(0, to_address, value, data, sigs, transact={"value": value, "from": a1}) - ) # noqa: E501 + with tx_failed(): + c.approve(0, to_address, value, data, sigs, transact={"value": value, "from": a1}) # Approve fails if not enough value is sent sigs = pack_and_sign(1, k1, 0, k3, 0, k5) - assert_tx_failed( - lambda: c.approve(1, to_address, value, data, sigs, transact={"value": 0, "from": a1}) - ) # noqa: E501 + with tx_failed(): + c.approve(1, to_address, value, data, sigs, transact={"value": 0, "from": a1}) sigs = pack_and_sign(1, k1, 0, k3, 0, k5) # this call should succeed diff --git a/tests/unit/ast/nodes/test_evaluate_binop_decimal.py b/tests/unit/ast/nodes/test_evaluate_binop_decimal.py index 5c9956caba..44b82e321d 100644 --- a/tests/unit/ast/nodes/test_evaluate_binop_decimal.py +++ b/tests/unit/ast/nodes/test_evaluate_binop_decimal.py @@ -20,7 +20,7 @@ @example(left=Decimal("0.9999999999"), right=Decimal("0.9999999999")) @example(left=Decimal("0.0000000001"), right=Decimal("0.0000000001")) @pytest.mark.parametrize("op", "+-*/%") -def test_binop_decimal(get_contract, assert_tx_failed, op, left, right): +def test_binop_decimal(get_contract, tx_failed, op, left, right): source = f""" @external def foo(a: decimal, b: decimal) -> decimal: @@ -39,7 +39,8 @@ def foo(a: decimal, b: decimal) -> decimal: if is_valid: assert contract.foo(left, right) == new_node.value else: - assert_tx_failed(lambda: contract.foo(left, right)) + with tx_failed(): + contract.foo(left, right) def test_binop_pow(): @@ -57,7 +58,7 @@ def test_binop_pow(): values=st.lists(st_decimals, min_size=2, max_size=10), ops=st.lists(st.sampled_from("+-*/%"), min_size=11, max_size=11), ) -def test_nested(get_contract, assert_tx_failed, values, ops): +def test_nested(get_contract, tx_failed, values, ops): variables = "abcdefghij" input_value = ",".join(f"{i}: decimal" for i in variables[: len(values)]) return_value = " ".join(f"{a} {b}" for a, b in zip(variables[: len(values)], ops)) @@ -83,4 +84,5 @@ def foo({input_value}) -> decimal: if is_valid: assert contract.foo(*values) == expected else: - assert_tx_failed(lambda: contract.foo(*values)) + with tx_failed(): + contract.foo(*values) diff --git a/tests/unit/ast/nodes/test_evaluate_binop_int.py b/tests/unit/ast/nodes/test_evaluate_binop_int.py index 80c9381c0f..405d557f7d 100644 --- a/tests/unit/ast/nodes/test_evaluate_binop_int.py +++ b/tests/unit/ast/nodes/test_evaluate_binop_int.py @@ -16,7 +16,7 @@ @example(left=-1, right=1) @example(left=-1, right=-1) @pytest.mark.parametrize("op", "+-*/%") -def test_binop_int128(get_contract, assert_tx_failed, op, left, right): +def test_binop_int128(get_contract, tx_failed, op, left, right): source = f""" @external def foo(a: int128, b: int128) -> int128: @@ -35,7 +35,8 @@ def foo(a: int128, b: int128) -> int128: if is_valid: assert contract.foo(left, right) == new_node.value else: - assert_tx_failed(lambda: contract.foo(left, right)) + with tx_failed(): + contract.foo(left, right) st_uint64 = st.integers(min_value=0, max_value=2**64) @@ -45,7 +46,7 @@ def foo(a: int128, b: int128) -> int128: @settings(max_examples=50) @given(left=st_uint64, right=st_uint64) @pytest.mark.parametrize("op", "+-*/%") -def test_binop_uint256(get_contract, assert_tx_failed, op, left, right): +def test_binop_uint256(get_contract, tx_failed, op, left, right): source = f""" @external def foo(a: uint256, b: uint256) -> uint256: @@ -64,7 +65,8 @@ def foo(a: uint256, b: uint256) -> uint256: if is_valid: assert contract.foo(left, right) == new_node.value else: - assert_tx_failed(lambda: contract.foo(left, right)) + with tx_failed(): + contract.foo(left, right) @pytest.mark.xfail(reason="need to implement safe exponentiation logic") @@ -94,7 +96,7 @@ def foo(a: uint256, b: uint256) -> uint256: values=st.lists(st.integers(min_value=-256, max_value=256), min_size=2, max_size=10), ops=st.lists(st.sampled_from("+-*/%"), min_size=11, max_size=11), ) -def test_binop_nested(get_contract, assert_tx_failed, values, ops): +def test_binop_nested(get_contract, tx_failed, values, ops): variables = "abcdefghij" input_value = ",".join(f"{i}: int128" for i in variables[: len(values)]) return_value = " ".join(f"{a} {b}" for a, b in zip(variables[: len(values)], ops)) @@ -122,4 +124,5 @@ def foo({input_value}) -> int128: if is_valid: assert contract.foo(*values) == expected else: - assert_tx_failed(lambda: contract.foo(*values)) + with tx_failed(): + contract.foo(*values) From 7489e34581bf013d2ba04cf28d847a443a255a34 Mon Sep 17 00:00:00 2001 From: Daniel Schiavini Date: Sun, 24 Dec 2023 17:10:45 +0100 Subject: [PATCH 19/27] feat: allow `range(x, y, bound=N)` (#3679) - allow range where both start and end arguments are variables, so long as a bound is supplied - ban range expressions of the form `range(x, x + N)` since the new form is cleaner and supersedes it. - also do a bit of refactoring of the codegen for range --------- Co-authored-by: Charles Cooper --- docs/control-structures.rst | 8 +- .../features/iteration/test_for_in_list.py | 19 +- .../features/iteration/test_for_range.py | 116 ++++++++++- .../codegen/integration/test_crowdfund.py | 4 +- .../test_invalid_literal_exception.py | 7 - tests/functional/syntax/test_for_range.py | 197 +++++++++++++++++- vyper/codegen/ir_node.py | 8 +- vyper/codegen/stmt.py | 67 +++--- vyper/exceptions.py | 2 +- vyper/semantics/analysis/local.py | 109 ++++------ 10 files changed, 390 insertions(+), 147 deletions(-) diff --git a/docs/control-structures.rst b/docs/control-structures.rst index 873135709a..2f890bcb2f 100644 --- a/docs/control-structures.rst +++ b/docs/control-structures.rst @@ -287,9 +287,11 @@ Another use of range can be with ``START`` and ``STOP`` bounds. Here, ``START`` and ``STOP`` are literal integers, with ``STOP`` being a greater value than ``START``. ``i`` begins as ``START`` and increments by one until it is equal to ``STOP``. +Finally, it is possible to use ``range`` with runtime `start` and `stop` values as long as a constant `bound` value is provided. +In this case, Vyper checks at runtime that `end - start <= bound`. +``N`` must be a compile-time constant. + .. code-block:: python - for i in range(a, a + N): + for i in range(start, end, bound=N): ... - -``a`` is a variable with an integer type and ``N`` is a literal integer greater than zero. ``i`` begins as ``a`` and increments by one until it is equal to ``a + N``. If ``a + N`` would overflow, execution will revert. diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index fb01cc98eb..bc1a12ae9e 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -1,3 +1,4 @@ +import re from decimal import Decimal import pytest @@ -700,13 +701,16 @@ def foo(): """, StateAccessViolation, ), - """ + ( + """ @external def foo(): a: int128 = 6 for i in range(a,a-3): pass """, + StateAccessViolation, + ), # invalid argument length ( """ @@ -789,10 +793,13 @@ def test_for() -> int128: ), ] +BAD_CODE = [code if isinstance(code, tuple) else (code, StructureException) for code in BAD_CODE] +for_code_regex = re.compile(r"for .+ in (.*):") +bad_code_names = [ + f"{i} {for_code_regex.search(code).group(1)}" for i, (code, _) in enumerate(BAD_CODE) +] + -@pytest.mark.parametrize("code", BAD_CODE) -def test_bad_code(assert_compile_failed, get_contract, code): - err = StructureException - if not isinstance(code, str): - code, err = code +@pytest.mark.parametrize("code,err", BAD_CODE, ids=bad_code_names) +def test_bad_code(assert_compile_failed, get_contract, code, err): assert_compile_failed(lambda: get_contract(code), err) diff --git a/tests/functional/codegen/features/iteration/test_for_range.py b/tests/functional/codegen/features/iteration/test_for_range.py index 96b83ae691..e946447285 100644 --- a/tests/functional/codegen/features/iteration/test_for_range.py +++ b/tests/functional/codegen/features/iteration/test_for_range.py @@ -32,6 +32,102 @@ def repeat(n: uint256) -> uint256: c.repeat(7) +def test_range_bound_constant_end(get_contract, tx_failed): + code = """ +@external +def repeat(n: uint256) -> uint256: + x: uint256 = 0 + for i in range(n, 7, bound=6): + x += i + 1 + return x + """ + c = get_contract(code) + for n in range(1, 5): + assert c.repeat(n) == sum(i + 1 for i in range(n, 7)) + + # check assertion for `start <= end` + with tx_failed(): + c.repeat(8) + # check assertion for `start + bound <= end` + with tx_failed(): + c.repeat(0) + + +def test_range_bound_two_args(get_contract, tx_failed): + code = """ +@external +def repeat(n: uint256) -> uint256: + x: uint256 = 0 + for i in range(1, n, bound=6): + x += i + 1 + return x + """ + c = get_contract(code) + for n in range(1, 8): + assert c.repeat(n) == sum(i + 1 for i in range(1, n)) + + # check assertion for `start <= end` + with tx_failed(): + c.repeat(0) + + # check codegen inserts assertion for `start + bound <= end` + with tx_failed(): + c.repeat(8) + + +def test_range_bound_two_runtime_args(get_contract, tx_failed): + code = """ +@external +def repeat(start: uint256, end: uint256) -> uint256: + x: uint256 = 0 + for i in range(start, end, bound=6): + x += i + return x + """ + c = get_contract(code) + for n in range(0, 7): + assert c.repeat(0, n) == sum(range(0, n)) + assert c.repeat(n, n * 2) == sum(range(n, n * 2)) + + # check assertion for `start <= end` + with tx_failed(): + c.repeat(1, 0) + with tx_failed(): + c.repeat(7, 0) + with tx_failed(): + c.repeat(8, 7) + + # check codegen inserts assertion for `start + bound <= end` + with tx_failed(): + c.repeat(0, 7) + with tx_failed(): + c.repeat(14, 21) + + +def test_range_overflow(get_contract, tx_failed): + code = """ +@external +def get_last(start: uint256, end: uint256) -> uint256: + x: uint256 = 0 + for i in range(start, end, bound=6): + x = i + return x + """ + c = get_contract(code) + UINT_MAX = 2**256 - 1 + assert c.get_last(UINT_MAX, UINT_MAX) == 0 # initial value of x + + for n in range(1, 6): + assert c.get_last(UINT_MAX - n, UINT_MAX) == UINT_MAX - 1 + + # check for `start + bound <= end`, overflow cases + for n in range(1, 7): + with tx_failed(): + c.get_last(UINT_MAX - n, 0) + with tx_failed(): + c.get_last(UINT_MAX, UINT_MAX - n) + + def test_digit_reverser(get_contract_with_gas_estimation): digit_reverser = """ @external @@ -89,7 +185,7 @@ def test_offset_repeater_2(get_contract_with_gas_estimation, typ): @external def sum(frm: {typ}, to: {typ}) -> {typ}: out: {typ} = 0 - for i in range(frm, frm + 101): + for i in range(frm, frm + 101, bound=101): if i == to: break out = out + i @@ -146,26 +242,28 @@ def foo(a: {typ}) -> {typ}: assert c.foo(100) == 31337 -# test that we can get to the upper range of an integer @pytest.mark.parametrize("typ", ["uint8", "int128", "uint256"]) def test_for_range_edge(get_contract, typ): + """ + Check that we can get to the upper range of an integer. + Note that to avoid overflow in the bounds check for range(), + we need to calculate i+1 inside the loop. + """ code = f""" @external def test(): found: bool = False x: {typ} = max_value({typ}) - for i in range(x, x + 1): - if i == max_value({typ}): + for i in range(x - 1, x, bound=1): + if i + 1 == max_value({typ}): found = True - assert found found = False x = max_value({typ}) - 1 - for i in range(x, x + 2): - if i == max_value({typ}): + for i in range(x - 1, x + 1, bound=2): + if i + 1 == max_value({typ}): found = True - assert found """ c = get_contract(code) @@ -178,7 +276,7 @@ def test_for_range_oob_check(get_contract, tx_failed, typ): @external def test(): x: {typ} = max_value({typ}) - for i in range(x, x+2): + for i in range(x, x + 2, bound=2): pass """ c = get_contract(code) diff --git a/tests/functional/codegen/integration/test_crowdfund.py b/tests/functional/codegen/integration/test_crowdfund.py index 2083e62610..671d424d60 100644 --- a/tests/functional/codegen/integration/test_crowdfund.py +++ b/tests/functional/codegen/integration/test_crowdfund.py @@ -52,7 +52,7 @@ def finalize(): @external def refund(): ind: int128 = self.refundIndex - for i in range(ind, ind + 30): + for i in range(ind, ind + 30, bound=30): if i >= self.nextFunderIndex: self.refundIndex = self.nextFunderIndex return @@ -147,7 +147,7 @@ def finalize(): @external def refund(): ind: int128 = self.refundIndex - for i in range(ind, ind + 30): + for i in range(ind, ind + 30, bound=30): if i >= self.nextFunderIndex: self.refundIndex = self.nextFunderIndex return diff --git a/tests/functional/syntax/exceptions/test_invalid_literal_exception.py b/tests/functional/syntax/exceptions/test_invalid_literal_exception.py index 1f4f112252..a0cf10ad02 100644 --- a/tests/functional/syntax/exceptions/test_invalid_literal_exception.py +++ b/tests/functional/syntax/exceptions/test_invalid_literal_exception.py @@ -18,13 +18,6 @@ def foo(): """, """ @external -def foo(x: int128): - y: int128 = 7 - for i in range(x, x + y): - pass - """, - """ -@external def foo(): x: String[100] = "these bytes are nо gооd because the o's are from the Russian alphabet" """, diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index e6f35c1d2d..7c7f9c476d 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -1,7 +1,9 @@ +import re + import pytest from vyper import compiler -from vyper.exceptions import StructureException +from vyper.exceptions import ArgumentException, StateAccessViolation, StructureException fail_list = [ ( @@ -12,33 +14,191 @@ def foo(): pass """, StructureException, + "Invalid syntax for loop iterator", + "a[1]", + ), + ( + """ +@external +def foo(): + x: uint256 = 100 + for _ in range(10, bound=x): + pass + """, + StateAccessViolation, + "Bound must be a literal", + "x", + ), + ( + """ +@external +def foo(): + for _ in range(10, 20, bound=5): + pass + """, + StructureException, + "Please remove the `bound=` kwarg when using range with constants", + "5", + ), + ( + """ +@external +def foo(): + for _ in range(10, 20, bound=0): + pass + """, + StructureException, + "Bound must be at least 1", + "0", + ), + ( + """ +@external +def bar(): + x:uint256 = 1 + for i in range(x,x+1,bound=2,extra=3): + pass + """, + ArgumentException, + "Invalid keyword argument 'extra'", + "extra=3", ), ( """ @external def bar(): - for i in range(1,2,bound=2): + for i in range(0): pass """, StructureException, + "End must be greater than start", + "0", + ), + ( + """ +@external +def bar(): + x:uint256 = 1 + for i in range(x): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", + ), + ( + """ +@external +def bar(): + x:uint256 = 1 + for i in range(0, x): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", + ), + ( + """ +@external +def repeat(n: uint256) -> uint256: + for i in range(0, n * 10): + pass + return n + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "n * 10", ), ( """ @external def bar(): x:uint256 = 1 - for i in range(x,x+1,bound=2): + for i in range(0, x + 1): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x + 1", + ), + ( + """ +@external +def bar(): + for i in range(2, 1): pass """, StructureException, + "End must be greater than start", + "1", + ), + ( + """ +@external +def bar(): + x:uint256 = 1 + for i in range(x, x): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", + ), + ( + """ +@external +def foo(): + x: int128 = 5 + for i in range(x, x + 10): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", + ), + ( + """ +@external +def repeat(n: uint256) -> uint256: + for i in range(n, 6): + pass + return x + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "n", + ), + ( + """ +@external +def foo(x: int128): + y: int128 = 7 + for i in range(x, x + y): + pass + """, + StateAccessViolation, + "Value must be a literal integer, unless a bound is specified", + "x", ), ] +for_code_regex = re.compile(r"for .+ in (.*):") +fail_test_names = [ + ( + f"{i:02d}: {for_code_regex.search(code).group(1)}" # type: ignore[union-attr] + f" raises {type(err).__name__}" + ) + for i, (code, err, msg, src) in enumerate(fail_list) +] -@pytest.mark.parametrize("bad_code", fail_list) -def test_range_fail(bad_code): - with pytest.raises(bad_code[1]): - compiler.compile_code(bad_code[0]) + +@pytest.mark.parametrize("bad_code,error_type,message,source_code", fail_list, ids=fail_test_names) +def test_range_fail(bad_code, error_type, message, source_code): + with pytest.raises(error_type) as exc_info: + compiler.compile_code(bad_code) + assert message == exc_info.value.message + assert source_code == exc_info.value.args[1].node_source_code valid_list = [ @@ -58,7 +218,21 @@ def foo(): @external def foo(): x: int128 = 5 - for i in range(x, x + 10): + for i in range(1, x, bound=4): + pass + """, + """ +@external +def foo(): + x: int128 = 5 + for i in range(x, bound=4): + pass + """, + """ +@external +def foo(): + x: int128 = 5 + for i in range(0, x, bound=4): pass """, """ @@ -72,7 +246,12 @@ def kick_foos(): """, ] +valid_test_names = [ + f"{i} {for_code_regex.search(code).group(1)}" # type: ignore[union-attr] + for i, code in enumerate(valid_list) +] + -@pytest.mark.parametrize("good_code", valid_list) +@pytest.mark.parametrize("good_code", valid_list, ids=valid_test_names) def test_range_success(good_code): assert compiler.compile_code(good_code) is not None diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py index ce26066968..45d93f3067 100644 --- a/vyper/codegen/ir_node.py +++ b/vyper/codegen/ir_node.py @@ -444,11 +444,15 @@ def unique_symbols(self): return ret @property - def is_literal(self): + def is_literal(self) -> bool: return isinstance(self.value, int) or self.value == "multi" + def int_value(self) -> int: + assert isinstance(self.value, int) + return self.value + @property - def is_pointer(self): + def is_pointer(self) -> bool: # not used yet but should help refactor/clarify downstream code # eventually return self.location is not None diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 601597771c..18e5c3d494 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -225,15 +225,6 @@ def parse_Raise(self): else: return IRnode.from_list(["revert", 0, 0], error_msg="user raise") - def _check_valid_range_constant(self, arg_ast_node): - with self.context.range_scope(): - arg_expr = Expr.parse_value_expr(arg_ast_node, self.context) - return arg_expr - - def _get_range_const_value(self, arg_ast_node): - arg_expr = self._check_valid_range_constant(arg_ast_node) - return arg_expr.value - def parse_For(self): with self.context.block_scope(): if self.stmt.get("iter.func.id") == "range": @@ -249,41 +240,37 @@ def _parse_For_range(self): iter_typ = INT256_T # Get arg0 - arg0 = self.stmt.iter.args[0] - num_of_args = len(self.stmt.iter.args) - - kwargs = { - s.arg: Expr.parse_value_expr(s.value, self.context) - for s in self.stmt.iter.keywords or [] - } - - # Type 1 for, e.g. for i in range(10): ... - if num_of_args == 1: - n = Expr.parse_value_expr(arg0, self.context) - start = IRnode.from_list(0, typ=iter_typ) - rounds = n - rounds_bound = kwargs.get("bound", rounds) - - # Type 2 for, e.g. for i in range(100, 110): ... - elif self._check_valid_range_constant(self.stmt.iter.args[1]).is_literal: - arg0_val = self._get_range_const_value(arg0) - arg1_val = self._get_range_const_value(self.stmt.iter.args[1]) - start = IRnode.from_list(arg0_val, typ=iter_typ) - rounds = IRnode.from_list(arg1_val - arg0_val, typ=iter_typ) - rounds_bound = rounds + for_iter: vy_ast.Call = self.stmt.iter + args_len = len(for_iter.args) + if args_len == 1: + arg0, arg1 = (IRnode.from_list(0, typ=iter_typ), for_iter.args[0]) + elif args_len == 2: + arg0, arg1 = for_iter.args + else: # pragma: nocover + raise TypeCheckFailure("unreachable: bad # of arguments to range()") - # Type 3 for, e.g. for i in range(x, x + 10): ... - else: - arg1 = self.stmt.iter.args[1] - rounds = self._get_range_const_value(arg1.right) + with self.context.range_scope(): start = Expr.parse_value_expr(arg0, self.context) - _, hi = start.typ.int_bounds - start = clamp("le", start, hi + 1 - rounds) + end = Expr.parse_value_expr(arg1, self.context) + kwargs = { + s.arg: Expr.parse_value_expr(s.value, self.context) for s in for_iter.keywords + } + + if "bound" in kwargs: + with end.cache_when_complex("end") as (b1, end): + # note: the check for rounds<=rounds_bound happens in asm + # generation for `repeat`. + clamped_start = clamp("le", start, end) + rounds = b1.resolve(IRnode.from_list(["sub", end, clamped_start])) + rounds_bound = kwargs.pop("bound").int_value() + else: + rounds = end.int_value() - start.int_value() rounds_bound = rounds - bound = rounds_bound if isinstance(rounds_bound, int) else rounds_bound.value - if bound < 1: - return + assert len(kwargs) == 0 # sanity check stray keywords + + if rounds_bound < 1: # pragma: nocover + raise TypeCheckFailure("unreachable: unchecked 0 bound") varname = self.stmt.target.id i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=UINT256_T) diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 0c549ec10f..0d4ebdd40b 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -41,7 +41,7 @@ def __init__(self, message="Error Message not found.", *items): Error message to display with the exception. *items : VyperNode | Tuple[str, VyperNode], optional Vyper ast node(s), or tuple of (description, node) indicating where - the exception occured. Source annotations are generated in the order + the exception occurred. Source annotations are generated in the order the nodes are given. A single tuple of (lineno, col_offset) is also understood to support diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 8f6103a217..bdf6680024 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -7,7 +7,6 @@ ExceptionList, FunctionDeclarationException, ImmutableViolation, - InvalidLiteral, InvalidOperation, InvalidType, IteratorException, @@ -356,71 +355,7 @@ def visit_For(self, node): raise IteratorException( "Cannot iterate over the result of a function call", node.iter ) - range_ = node.iter - validate_call_args(range_, (1, 2), kwargs=["bound"]) - - args = range_.args - kwargs = {s.arg: s.value for s in range_.keywords or []} - if len(args) == 1: - # range(CONSTANT) - n = args[0] - bound = kwargs.pop("bound", None) - validate_expected_type(n, IntegerT.any()) - - if bound is None: - if not isinstance(n, vy_ast.Num): - raise StateAccessViolation("Value must be a literal", n) - if n.value <= 0: - raise StructureException("For loop must have at least 1 iteration", args[0]) - type_list = get_possible_types_from_node(n) - - else: - if not isinstance(bound, vy_ast.Num): - raise StateAccessViolation("bound must be a literal", bound) - if bound.value <= 0: - raise StructureException("bound must be at least 1", args[0]) - type_list = get_common_types(n, bound) - - else: - if range_.keywords: - raise StructureException( - "Keyword arguments are not supported for `range(N, M)` and" - "`range(x, x + N)` expressions", - range_.keywords[0], - ) - - validate_expected_type(args[0], IntegerT.any()) - type_list = get_common_types(*args) - if not isinstance(args[0], vy_ast.Constant): - # range(x, x + CONSTANT) - if not isinstance(args[1], vy_ast.BinOp) or not isinstance( - args[1].op, vy_ast.Add - ): - raise StructureException( - "Second element must be the first element plus a literal value", args[0] - ) - if not vy_ast.compare_nodes(args[0], args[1].left): - raise StructureException( - "First and second variable must be the same", args[1].left - ) - if not isinstance(args[1].right, vy_ast.Int): - raise InvalidLiteral("Literal must be an integer", args[1].right) - if args[1].right.value < 1: - raise StructureException( - f"For loop has invalid number of iterations ({args[1].right.value})," - " the value must be greater than zero", - args[1].right, - ) - else: - # range(CONSTANT, CONSTANT) - if not isinstance(args[1], vy_ast.Int): - raise InvalidType("Value must be a literal integer", args[1]) - validate_expected_type(args[1], IntegerT.any()) - if args[0].value >= args[1].value: - raise StructureException("Second value must be > first value", args[1]) - - if not type_list: - raise TypeMismatch("Iterator values are of different types", node.iter) + type_list = _analyse_range_call(node.iter) else: # iteration over a variable or literal list @@ -491,8 +426,8 @@ def visit_For(self, node): try: with NodeMetadata.enter_typechecker_speculation(): - for n in node.body: - self.visit(n) + for stmt in node.body: + self.visit(stmt) except (TypeMismatch, InvalidOperation) as exc: for_loop_exceptions.append(exc) else: @@ -809,3 +744,41 @@ def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: self.visit(node.body, typ) validate_expected_type(node.orelse, typ) self.visit(node.orelse, typ) + + +def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: + """ + Check that the arguments to a range() call are valid. + :param node: call to range() + :return: None + """ + validate_call_args(node, (1, 2), kwargs=["bound"]) + kwargs = {s.arg: s.value for s in node.keywords or []} + start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args + + all_args = (start, end, *kwargs.values()) + for arg1 in all_args: + validate_expected_type(arg1, IntegerT.any()) + + type_list = get_common_types(*all_args) + if not type_list: + raise TypeMismatch("Iterator values are of different types", node) + + if "bound" in kwargs: + bound = kwargs["bound"] + if not isinstance(bound, vy_ast.Num): + raise StateAccessViolation("Bound must be a literal", bound) + if bound.value <= 0: + raise StructureException("Bound must be at least 1", bound) + if isinstance(start, vy_ast.Num) and isinstance(end, vy_ast.Num): + error = "Please remove the `bound=` kwarg when using range with constants" + raise StructureException(error, bound) + else: + for arg in (start, end): + if not isinstance(arg, vy_ast.Num): + error = "Value must be a literal integer, unless a bound is specified" + raise StateAccessViolation(error, arg) + if end.value <= start.value: + raise StructureException("End must be greater than start", end) + + return type_list From 1040f3e7b27d037b68343746b9cfeb97034d9423 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 25 Dec 2023 09:21:38 -0500 Subject: [PATCH 20/27] feat: improve panics in IR generation (#3708) * feat: improve panics in IR generation this QOL commit improves on 91659266c55a by passing through the `__traceback__` field when the exception is modified (instead of using `__cause__` - cf. PEP-3134 regarding the difference) and improves error messages when an IRnode is not returned properly. using `__traceback__` generally results in a better experience because the immediate cause of the exception is displayed when running `vyper -v` instead of needing to scroll up through the exception chain (if the exception chain is reproduced correctly at all in the first place). --------- Co-authored-by: Harry Kalogirou --- vyper/codegen/expr.py | 30 ++++++++++++++---------------- vyper/codegen/stmt.py | 14 ++++++-------- vyper/exceptions.py | 15 +++++++++------ 3 files changed, 29 insertions(+), 30 deletions(-) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index c46f8cec1b..51807d80ba 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -74,19 +74,17 @@ def __init__(self, node, context): self.context = context if isinstance(node, IRnode): - # TODO this seems bad + # this is a kludge for parse_AugAssign to pass in IRnodes + # directly. + # TODO fixme! self.ir_node = node return - fn = getattr(self, f"parse_{type(node).__name__}", None) - if fn is None: - raise TypeCheckFailure(f"Invalid statement node: {type(node).__name__}", node) - - with tag_exceptions(node, fallback_exception_type=CodegenPanic): + fn_name = f"parse_{type(node).__name__}" + with tag_exceptions(node, fallback_exception_type=CodegenPanic, note=fn_name): + fn = getattr(self, fn_name) self.ir_node = fn() - - if self.ir_node is None: - raise TypeCheckFailure(f"{type(node).__name__} node did not produce IR.\n", node) + assert isinstance(self.ir_node, IRnode), self.ir_node self.ir_node.annotation = self.expr.get("node_source_code") self.ir_node.source_pos = getpos(self.expr) @@ -375,9 +373,9 @@ def parse_Subscript(self): index = self.expr.slice.value.n # note: this check should also happen in get_element_ptr if not 0 <= index < len(sub.typ.member_types): - return + raise TypeCheckFailure("unreachable") else: - return + raise TypeCheckFailure("unreachable") ir_node = get_element_ptr(sub, index) ir_node.mutable = sub.mutable @@ -412,13 +410,13 @@ def parse_BinOp(self): new_typ = left.typ if new_typ.bits != 256: # TODO implement me. ["and", 2**bits - 1, shl(right, left)] - return + raise TypeCheckFailure("unreachable") return IRnode.from_list(shl(right, left), typ=new_typ) if isinstance(self.expr.op, vy_ast.RShift): new_typ = left.typ if new_typ.bits != 256: # TODO implement me. promote_signed_int(op(right, left), bits) - return + raise TypeCheckFailure("unreachable") op = shr if not left.typ.is_signed else sar return IRnode.from_list(op(right, left), typ=new_typ) @@ -461,7 +459,7 @@ def build_in_comparator(self): elif isinstance(self.expr.op, vy_ast.NotIn): found, not_found = 0, 1 else: # pragma: no cover - return + raise TypeCheckFailure("unreachable") i = IRnode.from_list(self.context.fresh_varname("in_ix"), typ=UINT256_T) @@ -523,7 +521,7 @@ def parse_Compare(self): right = Expr.parse_value_expr(self.expr.right, self.context) if right.value is None: - return + raise TypeCheckFailure("unreachable") if isinstance(self.expr.op, (vy_ast.In, vy_ast.NotIn)): if is_array_like(right.typ): @@ -575,7 +573,7 @@ def parse_Compare(self): elif left.typ._is_prim_word and right.typ._is_prim_word: if op not in ("eq", "ne"): - return + raise TypeCheckFailure("unreachable") else: # kludge to block behavior in #2638 # TODO actually implement equality for complex types diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index 18e5c3d494..bc29a79734 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -40,16 +40,14 @@ class Stmt: def __init__(self, node: vy_ast.VyperNode, context: Context) -> None: self.stmt = node self.context = context - fn = getattr(self, f"parse_{type(node).__name__}", None) - if fn is None: - raise TypeCheckFailure(f"Invalid statement node: {type(node).__name__}") - with context.internal_memory_scope(): - with tag_exceptions(node, fallback_exception_type=CodegenPanic): + fn_name = f"parse_{type(node).__name__}" + with tag_exceptions(node, fallback_exception_type=CodegenPanic, note=fn_name): + fn = getattr(self, fn_name) + with context.internal_memory_scope(): self.ir_node = fn() - if self.ir_node is None: - raise TypeCheckFailure("Statement node did not produce IR") + assert isinstance(self.ir_node, IRnode), self.ir_node self.ir_node.annotation = self.stmt.get("node_source_code") self.ir_node.source_pos = getpos(self.stmt) @@ -347,7 +345,7 @@ def parse_AugAssign(self): # because of this check, we do not need to check for # make_setter references lhs<->rhs as in parse_Assign - # single word load/stores are atomic. - return + raise TypeCheckFailure("unreachable") with target.cache_when_complex("_loc") as (b, target): rhs = Expr.parse_value_expr( diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 0d4ebdd40b..f216069eab 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -359,14 +359,17 @@ class InvalidABIType(VyperInternalException): @contextlib.contextmanager -def tag_exceptions( - node, fallback_exception_type=CompilerPanic, fallback_message="unhandled exception" -): +def tag_exceptions(node, fallback_exception_type=CompilerPanic, note=None): try: yield except _BaseVyperException as e: if not e.annotations and not e.lineno: - raise e.with_annotation(node) - raise e + tb = e.__traceback__ + raise e.with_annotation(node).with_traceback(tb) + raise e from None except Exception as e: - raise fallback_exception_type(fallback_message, node) from e + tb = e.__traceback__ + fallback_message = "unhandled exception" + if note: + fallback_message += f", {note}" + raise fallback_exception_type(fallback_message, node).with_traceback(tb) From 977851a8a809d4c14f150517fb84baa20c81e910 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Wed, 27 Dec 2023 09:49:08 -0500 Subject: [PATCH 21/27] add special visibility for the __init__ function --- vyper/semantics/analysis/base.py | 13 ++++++++++--- vyper/semantics/types/function.py | 29 ++++++++++++++++++----------- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 4def7e6d2e..8ac630b12c 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -68,9 +68,16 @@ def __ge__(self, other: object) -> bool: class FunctionVisibility(_StringEnum): - # TODO: these can just be enum.auto() right? - EXTERNAL = _StringEnum.auto() - INTERNAL = _StringEnum.auto() + EXTERNAL = enum.auto() + INTERNAL = enum.auto() + CONSTRUCTOR = enum.auto() + + @classmethod + def is_valid_value(cls, value: str) -> bool: + # make CONSTRUCTOR visibility not available to the user + # (although as a design note - maybe `@constructor` should + # indeed be available) + return super().is_valid_value(value) and value != "constructor" class StateMutability(_StringEnum): diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 969cfc4ac3..55f2365dc4 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -291,6 +291,8 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": "function body in an interface can only be ...!", funcdef ) + assert function_visibility is not None # mypy hint + return cls( funcdef.name, positional_args, @@ -334,13 +336,15 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": ) if funcdef.name == "__init__": - if ( - state_mutability in (StateMutability.PURE, StateMutability.VIEW) - or function_visibility == FunctionVisibility.INTERNAL - ): + if state_mutability in (StateMutability.PURE, StateMutability.VIEW): + raise FunctionDeclarationException( + "Constructor cannot be marked as `@pure` or `@view`", funcdef + ) + if function_visibility is not None: raise FunctionDeclarationException( - "Constructor cannot be marked as `@pure`, `@view` or `@internal`", funcdef + "Constructor cannot be marked as `@internal` or `@external`", funcdef ) + function_visibility = FunctionVisibility.CONSTRUCTOR if return_type is not None: raise FunctionDeclarationException( "Constructor may not have a return type", funcdef.returns @@ -352,6 +356,9 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": "Constructor may not use default arguments", funcdef.args.defaults[0] ) + # sanity check + assert function_visibility is not None + return cls( funcdef.name, positional_args, @@ -475,6 +482,10 @@ def is_external(self) -> bool: def is_internal(self) -> bool: return self.visibility == FunctionVisibility.INTERNAL + @property + def is_constructor(self) -> bool: + return self.visibility == FunctionVisibility.CONSTRUCTOR + @property def is_mutable(self) -> bool: return self.mutability > StateMutability.VIEW @@ -483,10 +494,6 @@ def is_mutable(self) -> bool: def is_payable(self) -> bool: return self.mutability == StateMutability.PAYABLE - @property - def is_constructor(self) -> bool: - return self.name == "__init__" - @property def is_fallback(self) -> bool: return self.name == "__default__" @@ -620,7 +627,7 @@ def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]: def _parse_decorators( funcdef: vy_ast.FunctionDef, -) -> tuple[FunctionVisibility, StateMutability, Optional[str]]: +) -> tuple[Optional[FunctionVisibility], StateMutability, Optional[str]]: function_visibility = None state_mutability = None nonreentrant_key = None @@ -675,7 +682,7 @@ def _parse_decorators( else: raise StructureException("Bad decorator syntax", decorator) - if function_visibility is None: + if function_visibility is None and funcdef.name != "__init__": raise FunctionDeclarationException( f"Visibility must be set to one of: {', '.join(FunctionVisibility.values())}", funcdef ) From 8af611bb7b4ac07d9834ef950015e761679961d2 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Thu, 28 Dec 2023 12:49:58 -0500 Subject: [PATCH 22/27] remove unused MemoryOffset, CalldataOffset classes --- vyper/semantics/analysis/base.py | 48 ++++++-------------------------- 1 file changed, 8 insertions(+), 40 deletions(-) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 8ac630b12c..c2f79318d8 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -104,56 +104,24 @@ def from_abi(cls, abi_dict: Dict) -> "StateMutability": # specifying a state mutability modifier at all. Do the same here. +@dataclass class DataPosition: - _location: DataLocation - - -class CalldataOffset(DataPosition): - __slots__ = ("dynamic_offset", "static_offset") - _location = DataLocation.CALLDATA - - def __init__(self, static_offset, dynamic_offset=None): - self.static_offset = static_offset - self.dynamic_offset = dynamic_offset - - def __repr__(self): - if self.dynamic_offset is not None: - return f"" - else: - return f"" - - -class MemoryOffset(DataPosition): - __slots__ = ("offset",) - _location = DataLocation.MEMORY - - def __init__(self, offset): - self.offset = offset - - def __repr__(self): - return f"" + offset: int class StorageSlot(DataPosition): - _location = DataLocation.STORAGE - def __init__(self, position): - self.position = position + @property + def _location(self): + return DataLocation.STORAGE - def __repr__(self): - return f"" class CodeOffset(DataPosition): - __slots__ = ("offset",) - _location = DataLocation.CODE - - def __init__(self, offset): - self.offset = offset - - def __repr__(self): - return f"" + @property + def _location(self): + return DataLocation.CODE # base class for things that are the "result" of analysis class AnalysisResult: From c241e91fe5f8b4e9433c6afa2b6207e50a13d307 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 29 Dec 2023 11:17:37 -0500 Subject: [PATCH 23/27] wip - allow init functions to be called from init func --- vyper/codegen/context.py | 2 +- vyper/codegen/core.py | 8 +- vyper/codegen/expr.py | 6 +- vyper/codegen/function_definitions/common.py | 12 +- vyper/codegen/module.py | 18 ++- vyper/codegen/stmt.py | 6 +- vyper/exceptions.py | 2 +- vyper/semantics/analysis/base.py | 76 +++++++---- vyper/semantics/analysis/data_positions.py | 135 +------------------ vyper/semantics/analysis/module.py | 2 +- vyper/semantics/data_locations.py | 16 ++- vyper/semantics/types/base.py | 15 +++ vyper/semantics/types/function.py | 2 +- vyper/semantics/types/module.py | 3 + 14 files changed, 113 insertions(+), 190 deletions(-) diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 334d988dd4..0ee3a0baef 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -108,7 +108,7 @@ def self_ptr(self, location): if location == STORAGE: return IRnode.from_list("self_ptr_storage", typ=module_t, location=location) if location == IMMUTABLES: - return IRnode.from_list("self_ptr_code", typ=module_t, location=location) + return IRnode.from_list("self_ptr_immutables", typ=module_t, location=location) def is_constant(self): return self.constancy is Constancy.Constant or self.in_assertion or self.in_range_expr diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index f705e9deed..10968d6a40 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -471,13 +471,9 @@ def _get_element_ptr_module(parent, key): ofst = 0 # offset from parent start - assert parent.location == STORAGE, parent.location + assert parent.location in (STORAGE, IMMUTABLES, DATA), parent.location - for i in range(index): - ofst += module_t.variables[attrs[i]].typ.storage_slots_required - - # calculated the same way both ways - assert ofst == module_t.variables[key].position.position + ofst = module_t.offset_of(key) return IRnode.from_list( add_ofst(parent, ofst), diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 51807d80ba..31e89762b9 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -80,8 +80,8 @@ def __init__(self, node, context): self.ir_node = node return - fn_name = f"parse_{type(node).__name__}" - with tag_exceptions(node, fallback_exception_type=CodegenPanic, note=fn_name): + with tag_exceptions(node, fallback_exception_type=CodegenPanic): + fn_name = f"parse_{type(node).__name__}" fn = getattr(self, fn_name) self.ir_node = fn() assert isinstance(self.ir_node, IRnode), self.ir_node @@ -705,7 +705,7 @@ def parse_Call(self): return pop_dyn_array(darray, return_popped_item=True) if isinstance(func_type, ContractFunctionT): - if func_type.is_internal: + if func_type.is_internal or func_type.is_constructor: return self_call.ir_for_self_call(self.expr, self.context) else: return external_call.ir_for_external_call(self.expr, self.context) diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index f8283c8539..1d2248d9aa 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -65,7 +65,13 @@ def external_function_base_entry_label(self) -> str: return self.ir_identifier + "_common" def internal_function_label(self, is_ctor_context: bool = False) -> str: - assert self.func_t.is_internal, "uh oh, should be internal" + f = self.func_t + assert f.is_internal or f.is_constructor, "uh oh, should be internal" + + if f.is_constructor: + # sanity check - imported init functions only callable from main init + assert is_ctor_context + suffix = "_deploy" if is_ctor_context else "_runtime" return self.ir_identifier + suffix @@ -140,7 +146,7 @@ def generate_ir_for_function( is_ctor_context=is_ctor_context, ) - if func_t.is_internal: + if func_t.is_internal or func_t.is_constructor: ret: FuncIR = InternalFuncIR(generate_ir_for_internal_function(code, func_t, context)) func_t._ir_info.gas_estimate = ret.func_ir.gas # type: ignore else: @@ -163,7 +169,7 @@ def generate_ir_for_function( else: assert frame_info == func_t._ir_info.frame_info - if not func_t.is_internal: + if func_t.is_external: # adjust gas estimate to include cost of mem expansion # frame_size of external function includes all private functions called # (note: internal functions do not need to adjust gas estimate since diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 8ce8a262f9..d397203986 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -431,7 +431,8 @@ def generate_ir_for_module(compilation_target: ModuleT) -> tuple[IRnode, IRnode] reachable = _globally_reachable_functions(compilation_target.function_defs) runtime_functions = [f for f in function_defs if not _is_constructor(f)] - init_function = next((f for f in function_defs if _is_constructor(f)), None) + + init_function = next((f for f in compilation_target.function_defs if _is_constructor(f)), None) internal_functions = [f for f in runtime_functions if _is_internal(f)] @@ -477,18 +478,15 @@ def generate_ir_for_module(compilation_target: ModuleT) -> tuple[IRnode, IRnode] deploy_code: List[Any] = ["seq"] immutables_len = compilation_target.immutable_bytes_required - if init_function: + if init_function is not None: # cleanly rerun codegen for internal functions with `is_ctor_ctx=True` init_func_t = init_function._metadata["func_type"] ctor_internal_func_irs = [] - internal_functions = [f for f in runtime_functions if _is_internal(f)] - for f in internal_functions: - func_t = f._metadata["func_type"] - if func_t not in init_func_t.reachable_internal_functions: - # unreachable code, delete it - continue - - func_ir = _ir_for_internal_function(f, compilation_target, is_ctor_context=True) + reachable_from_ctor = init_func_t.reachable_internal_functions + for func_t in reachable_from_ctor: + fn_ast = func_t.ast_def + + func_ir = _ir_for_internal_function(fn_ast, compilation_target, is_ctor_context=True) ctor_internal_func_irs.append(func_ir) # generate init_func_ir after callees to ensure they have analyzed diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index bc29a79734..34d44f64b8 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -41,8 +41,8 @@ def __init__(self, node: vy_ast.VyperNode, context: Context) -> None: self.stmt = node self.context = context - fn_name = f"parse_{type(node).__name__}" - with tag_exceptions(node, fallback_exception_type=CodegenPanic, note=fn_name): + with tag_exceptions(node, fallback_exception_type=CodegenPanic): + fn_name = f"parse_{type(node).__name__}" fn = getattr(self, fn_name) with context.internal_memory_scope(): self.ir_node = fn() @@ -145,7 +145,7 @@ def parse_Call(self): return pop_dyn_array(darray, return_popped_item=False) if isinstance(func_type, ContractFunctionT): - if func_type.is_internal: + if func_type.is_internal or func_type.is_constructor: return self_call.ir_for_self_call(self.stmt, self.context) else: return external_call.ir_for_external_call(self.stmt, self.context) diff --git a/vyper/exceptions.py b/vyper/exceptions.py index f216069eab..6afa829d24 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -369,7 +369,7 @@ def tag_exceptions(node, fallback_exception_type=CompilerPanic, note=None): raise e from None except Exception as e: tb = e.__traceback__ - fallback_message = "unhandled exception" + fallback_message = f"unhandled exception {e}" if note: fallback_message += f", {note}" raise fallback_exception_type(fallback_message, node).with_traceback(tb) diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index c2f79318d8..0a619c49bc 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,6 +1,6 @@ import enum from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union, ClassVar from vyper import ast as vy_ast from vyper.compiler.input_bundle import InputBundle @@ -46,14 +46,14 @@ def values(cls) -> List[str]: # Comparison operations def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): - raise CompilerPanic("Can only compare like types.") + raise CompilerPanic("bad comparison") return self is other # Python normally does __ne__(other) ==> not self.__eq__(other) def __lt__(self, other: object) -> bool: if not isinstance(other, self.__class__): - raise CompilerPanic("Can only compare like types.") + raise CompilerPanic("bad comparison") options = self.__class__.options() return options.index(self) < options.index(other) # type: ignore @@ -104,25 +104,6 @@ def from_abi(cls, abi_dict: Dict) -> "StateMutability": # specifying a state mutability modifier at all. Do the same here. -@dataclass -class DataPosition: - offset: int - - -class StorageSlot(DataPosition): - - @property - def _location(self): - return DataLocation.STORAGE - - - -class CodeOffset(DataPosition): - - @property - def _location(self): - return DataLocation.CODE - # base class for things that are the "result" of analysis class AnalysisResult: pass @@ -168,15 +149,56 @@ def __hash__(self): def __post_init__(self): self._reads = [] self._writes = [] - self.position = None # the location provided by the allocator + self._position = None # the location provided by the allocator - def set_position(self, position: DataPosition) -> None: - if self.position is not None: - raise CompilerPanic("Position was already assigned") + def set_position_in(self, position: DataPosition) -> None: + assert self.position is None if self.location != position._location: raise CompilerPanic(f"Incompatible locations: {self.location}, {position._location}") - self.position = position + self._position = position + + def get_position(self) -> int: + return self._position + + def get_size_in(self, location) -> int: + """ + Get the amount of space this variable occupies in a given location + """ + if location == self.location: + return self.typ.size_in_location(location) + return 0 + +class ModuleInfo(VarInfo): + """ + A special VarInfo for modules + """ + def __post_init__(self): + super.__post_init__() + assert isinstance(self.typ, ModuleT) + + self.code_offset = None + self.storage_offset = None + + def set_code_offset(ofst): + assert self.code_offset is None + self.code_offset = ofst + + def set_storage_offset(ofst): + assert self.storage_offset is None + self.storage_offset = ofst + + def get_position(self): + raise CompilerPanic("use get_offset_in for ModuleInfo!") + + def get_offset_in(self, location): + if location == DataLocation.STORAGE: + return self.storage_offset + if location == DataLocation.CODE: + return self.code_offset + raise CompilerPanic("unreachable") # pragma: nocover + def get_size_in(self, location): + return self.typ.size_in_location(location) @dataclass class ExprInfo: diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 291d9b7af2..97fe7e58bd 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -8,9 +8,7 @@ from vyper.utils import ceil32 -def set_data_positions( - vyper_module: vy_ast.Module, storage_layout_overrides: StorageLayout = None -) -> StorageLayout: +def allocate_variables(vyper_module: vy_ast.Module) -> StorageLayout: """ Parse the annotated Vyper AST, determine data positions for all variables, and annotate the AST nodes with the position data. @@ -20,123 +18,12 @@ def set_data_positions( vyper_module : vy_ast.Module Top-level Vyper AST node that has already been annotated with type data. """ - code_offsets = set_code_offsets(vyper_module) - storage_slots = ( - set_storage_slots_with_overrides(vyper_module, storage_layout_overrides) - if storage_layout_overrides is not None - else set_storage_slots(vyper_module) - ) + code_offsets = _set_code_offsets(vyper_module) + storage_slots = _set_storage_slots(vyper_module) return {"storage_layout": storage_slots, "code_layout": code_offsets} -class StorageAllocator: - """ - Keep track of which storage slots have been used. If there is a collision of - storage slots, this will raise an error and fail to compile - """ - - def __init__(self): - self.occupied_slots: Dict[int, str] = {} - - def reserve_slot_range(self, first_slot: int, n_slots: int, var_name: str) -> None: - """ - Reserves `n_slots` storage slots, starting at slot `first_slot` - This will raise an error if a storage slot has already been allocated. - It is responsibility of calling function to ensure first_slot is an int - """ - list_to_check = [x + first_slot for x in range(n_slots)] - self._reserve_slots(list_to_check, var_name) - - def _reserve_slots(self, slots: List[int], var_name: str) -> None: - for slot in slots: - self._reserve_slot(slot, var_name) - - def _reserve_slot(self, slot: int, var_name: str) -> None: - if slot < 0 or slot >= 2**256: - raise StorageLayoutException( - f"Invalid storage slot for var {var_name}, out of bounds: {slot}" - ) - if slot in self.occupied_slots: - collided_var = self.occupied_slots[slot] - raise StorageLayoutException( - f"Storage collision! Tried to assign '{var_name}' to slot {slot} but it has " - f"already been reserved by '{collided_var}'" - ) - self.occupied_slots[slot] = var_name - - -def set_storage_slots_with_overrides( - vyper_module: vy_ast.Module, storage_layout_overrides: StorageLayout -) -> StorageLayout: - """ - Parse module-level Vyper AST to calculate the layout of storage variables. - Returns the layout as a dict of variable name -> variable info - """ - - ret: Dict[str, Dict] = {} - reserved_slots = StorageAllocator() - - # Search through function definitions to find non-reentrant functions - for node in vyper_module.get_children(vy_ast.FunctionDef): - type_ = node._metadata["func_type"] - - # Ignore functions without non-reentrant - if type_.nonreentrant is None: - continue - - variable_name = f"nonreentrant.{type_.nonreentrant}" - - # re-entrant key was already identified - if variable_name in ret: - _slot = ret[variable_name]["slot"] - type_.set_reentrancy_key_position(StorageSlot(_slot)) - continue - - # Expect to find this variable within the storage layout override - if variable_name in storage_layout_overrides: - reentrant_slot = storage_layout_overrides[variable_name]["slot"] - # Ensure that this slot has not been used, and prevents other storage variables - # from using the same slot - reserved_slots.reserve_slot_range(reentrant_slot, 1, variable_name) - - type_.set_reentrancy_key_position(StorageSlot(reentrant_slot)) - - ret[variable_name] = {"type": "nonreentrant lock", "slot": reentrant_slot} - else: - raise StorageLayoutException( - f"Could not find storage_slot for {variable_name}. " - "Have you used the correct storage layout file?", - node, - ) - - # Iterate through variables - for varinfo in vyper_module._metadata["type"].variables.values(): - if varinfo.is_immutable: - continue - - assert isinstance((node := varinfo.decl_node), vy_ast.VariableDecl) - - # Expect to find this variable within the storage layout overrides - if node.target.id in storage_layout_overrides: - var_slot = storage_layout_overrides[node.target.id]["slot"] - storage_length = varinfo.typ.storage_slots_required - # Ensure that all required storage slots are reserved, and prevents other variables - # from using these slots - reserved_slots.reserve_slot_range(var_slot, storage_length, node.target.id) - varinfo.set_position(StorageSlot(var_slot)) - - # TODO: FIXME! - ret[node.target.id] = {"type": str(varinfo.typ), "slot": var_slot} - else: - raise StorageLayoutException( - f"Could not find storage_slot for {node.target.id}. " - "Have you used the correct storage layout file?", - node, - ) - - return ret - class SimpleStorageAllocator: def __init__(self, starting_slot: int = 0): @@ -153,7 +40,7 @@ def allocate_slot(self, n, var_name): return ret -def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: +def _set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: """ Parse module-level Vyper AST to calculate the layout of storage variables. Returns the layout as a dict of variable name -> variable info @@ -218,24 +105,16 @@ def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: return ret -def set_calldata_offsets(fn_node: vy_ast.FunctionDef) -> None: - pass - - -def set_memory_offsets(fn_node: vy_ast.FunctionDef) -> None: - pass - - def set_code_offsets(vyper_module: vy_ast.Module) -> Dict: ret = {} offset = 0 for varinfo in vyper_module._metadata["type"].variables.values(): - if not varinfo.is_immutable: - continue - type_ = varinfo.typ + if not varinfo.is_immutable and not isinstance(type_, ModuleT): + continue + len_ = ceil32(type_.immutable_bytes_required) varinfo.set_position(CodeOffset(offset)) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index a9bd3a3c6c..2d2447b8b5 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -206,7 +206,7 @@ def analyze_call_graph(self): # we just want to be able to construct the call graph. continue - if isinstance(call_t, ContractFunctionT) and call_t.is_internal: + if isinstance(call_t, ContractFunctionT) and (call_t.is_internal or call_t.is_constructor): fn_t.called_functions.add(call_t) for func in function_defs: diff --git a/vyper/semantics/data_locations.py b/vyper/semantics/data_locations.py index 07e8435686..294c03077b 100644 --- a/vyper/semantics/data_locations.py +++ b/vyper/semantics/data_locations.py @@ -2,9 +2,13 @@ class DataLocation(enum.Enum): - UNSET = 0 - MEMORY = 1 - STORAGE = 2 - CALLDATA = 3 - CODE = 4 - # TRANSIENT = 5 + # TODO: rename me to something like VarLocation, or StorageRegion + """ + Possible locations for variables in vyper + """ + UNSET = enum.auto() # like constants and stack variables + MEMORY = enum.auto() # local variables + STORAGE = enum.auto() # storage variables + CALLDATA = enum.auto() # arguments to external functions + IMMUTABLE = enum.auto() # immutable variables + TRANSIENT = enum.auto() # transient storage variables diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 8a2c24c259..4043be5292 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -123,6 +123,21 @@ def abi_type(self) -> ABIType: """ raise CompilerPanic("Method must be implemented by the inherited class") + # return the size in bytes or slots that this type + # needs to allocate in the provided location + def size_in_location(self, location): + if location in self._invalid_locations: + raise CompilerPanic(f"{self} cannot be instantiated in {location}!") + + if location == DataLocation.MEMORY: + return self.memory_bytes_required + if location == DataLocation.CODE: + return self.immutable_bytes_required + if location == DataLocation.STORAGE: + return self.storage_slots_required + + raise CompilerPanic("invalid location: {location}") # pragma: nocover + @property def memory_bytes_required(self) -> int: if DataLocation.MEMORY in self._invalid_locations: diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 55f2365dc4..d9fa4d6e3e 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -127,7 +127,7 @@ def touches_location(self, location): @property def touched_locations(self): - # return the DataLocations of touched module variables + # return the DataLocation of touched module variables ret = [] possible_locations = [DataLocation.STORAGE, DataLocation.CODE] for location in possible_locations: diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index 318a6393f6..fbb54806d0 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -310,6 +310,9 @@ def __repr__(self): else: return f"module {self._id} (loaded from '{self._module.resolved_path}')" + def offset_of(self, attr: str, location: DataLocation): + pass + def get_type_member(self, attr: str, node: vy_ast.VyperNode) -> VyperType: return self._helper.get_type_member(attr, node) From 3de7bb2f4f281073e0ae6868871edbb88cc4a1d1 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 29 Dec 2023 13:00:27 -0500 Subject: [PATCH 24/27] rename DataLocations.CODE to IMMUTABLES and wip on ModuleVarInfo --- vyper/codegen/core.py | 3 ++- vyper/semantics/analysis/base.py | 15 ++++++++++++--- vyper/semantics/analysis/module.py | 2 +- vyper/semantics/types/base.py | 4 ++-- vyper/semantics/types/function.py | 2 +- vyper/semantics/types/module.py | 2 +- vyper/semantics/types/subscriptable.py | 2 +- 7 files changed, 20 insertions(+), 10 deletions(-) diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 10968d6a40..d77d1c7042 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -71,7 +71,8 @@ def data_location_to_addr_space(s: DataLocation): return STORAGE if s == DataLocation.MEMORY: return MEMORY - if s == DataLocation.CODE: + if s == DataLocation.IMMUTABLES: + # note: this is confusing in ctor context! return IMMUTABLES raise CompilerPanic("unreachable") # pragma: nocover diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 0a619c49bc..7fe37346e8 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -135,7 +135,7 @@ class VarInfo: """ typ: VyperType - location: DataLocation = DataLocation.UNSET + _location: DataLocation = DataLocation.UNSET is_constant: bool = False is_public: bool = False is_immutable: bool = False @@ -146,6 +146,10 @@ class VarInfo: def __hash__(self): return hash(id(self)) + @property + def location(self): + return self._location + def __post_init__(self): self._reads = [] self._writes = [] @@ -168,7 +172,7 @@ def get_size_in(self, location) -> int: return self.typ.size_in_location(location) return 0 -class ModuleInfo(VarInfo): +class ModuleVarInfo(VarInfo): """ A special VarInfo for modules """ @@ -179,6 +183,11 @@ def __post_init__(self): self.code_offset = None self.storage_offset = None + @property + def location(self): + # location does not make sense for module vars + raise CompilerPanic("unreachable") + def set_code_offset(ofst): assert self.code_offset is None self.code_offset = ofst @@ -193,7 +202,7 @@ def get_position(self): def get_offset_in(self, location): if location == DataLocation.STORAGE: return self.storage_offset - if location == DataLocation.CODE: + if location == DataLocation.IMMUTABLES: return self.code_offset raise CompilerPanic("unreachable") # pragma: nocover diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 2d2447b8b5..d405830f0b 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -264,7 +264,7 @@ def visit_VariableDecl(self, node): raise SyntaxException(message, node.node_source_code, node.lineno, node.col_offset) data_loc = ( - DataLocation.CODE + DataLocation.IMMUTABLES if node.is_immutable else DataLocation.UNSET if node.is_constant diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 4043be5292..75134245f7 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -131,7 +131,7 @@ def size_in_location(self, location): if location == DataLocation.MEMORY: return self.memory_bytes_required - if location == DataLocation.CODE: + if location == DataLocation.IMMUTABLES: return self.immutable_bytes_required if location == DataLocation.STORAGE: return self.storage_slots_required @@ -168,7 +168,7 @@ def immutable_bytes_required(self) -> int: in the immutables section """ # sanity check the type can actually be instantiated as an immutable - if DataLocation.CODE in self._invalid_locations: + if DataLocation.IMMUTABLES in self._invalid_locations: raise CompilerPanic(f"{self} cannot be an immutable!") return self._size_in_bytes diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index d9fa4d6e3e..1d5d822e7e 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -129,7 +129,7 @@ def touches_location(self, location): def touched_locations(self): # return the DataLocation of touched module variables ret = [] - possible_locations = [DataLocation.STORAGE, DataLocation.CODE] + possible_locations = (DataLocation.STORAGE, DataLocation.IMMUTABLES) for location in possible_locations: if self.touches_location(location): ret.append(location) diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index fbb54806d0..a1d985e271 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -260,7 +260,7 @@ class ModuleT(VyperType): _invalid_locations = ( DataLocation.UNSET, DataLocation.CALLDATA, - DataLocation.CODE, + DataLocation.IMMUTABLES, DataLocation.MEMORY, ) diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index a6ec335b5e..9a272be309 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -49,7 +49,7 @@ class HashMapT(_SubscriptableT): _invalid_locations = ( DataLocation.UNSET, DataLocation.CALLDATA, - DataLocation.CODE, + DataLocation.IMMUTABLES, DataLocation.MEMORY, ) From ee91a5227eb08dcec8938853c153f214121e2bb7 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 29 Dec 2023 18:22:50 -0500 Subject: [PATCH 25/27] refactor set_data_positions and rework VarInfo positions API - split VarInfo positions API from get_position/set_position to get_storage_position/set_storage_position and ditto for immutables - only one of these is valid for regular VarInfos, but modules get allocated in both storage and immutables. - rewrite set_code_offsets to share allocator code - rename set_data_positions to allocate_module_variables --- vyper/codegen/core.py | 3 +- vyper/semantics/__init__.py | 1 - vyper/semantics/analysis/base.py | 50 +++++++++++++--- vyper/semantics/analysis/data_positions.py | 68 +++++++++++++--------- vyper/semantics/analysis/local.py | 2 +- vyper/semantics/analysis/module.py | 33 +++++------ vyper/semantics/data_locations.py | 12 ++-- 7 files changed, 106 insertions(+), 63 deletions(-) diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index d77d1c7042..5a91bc6d53 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -466,8 +466,7 @@ def _get_element_ptr_module(parent, key): assert isinstance(key, str) typ = module_t.variables[key].typ - attrs = list(module_t.variables.keys()) - index = attrs.index(key) + # attrs = list(module_t.variables.keys()) annotation = key ofst = 0 # offset from parent start diff --git a/vyper/semantics/__init__.py b/vyper/semantics/__init__.py index bb40c266a4..48a51d3917 100644 --- a/vyper/semantics/__init__.py +++ b/vyper/semantics/__init__.py @@ -1,2 +1 @@ from .analysis import validate_semantics -from .analysis.data_positions import set_data_positions diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 7fe37346e8..201a1ef303 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -1,6 +1,6 @@ import enum from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Union, ClassVar +from typing import TYPE_CHECKING, Dict, List, Optional, Union from vyper import ast as vy_ast from vyper.compiler.input_bundle import InputBundle @@ -122,6 +122,27 @@ def __eq__(self, other): return self is other +@dataclass +class DataPosition: + offset: int + + @property + def location(self): + raise CompilerPanic("unreachable!") + + +class StorageSlot(DataPosition): + @property + def location(self): + return DataLocation.STORAGE + + +class CodeOffset(DataPosition): + @property + def location(self): + return DataLocation.IMMUTABLES + + @dataclass class VarInfo: """ @@ -155,14 +176,26 @@ def __post_init__(self): self._writes = [] self._position = None # the location provided by the allocator - def set_position_in(self, position: DataPosition) -> None: - assert self.position is None - if self.location != position._location: + def _set_position_in(self, position: DataPosition) -> None: + assert self._position is None + if self.location != position.location: raise CompilerPanic(f"Incompatible locations: {self.location}, {position._location}") self._position = position + def set_storage_position(self, position: DataPosition): + assert self.location == DataLocation.STORAGE + self._set_position_in(position) + + def set_immutables_position(self, position: DataPosition): + assert self.location == DataLocation.IMMUTABLES + self._set_position_in(position) + def get_position(self) -> int: - return self._position + return self._position.offset + + def get_offset_in(self, location): + assert location == self.location + return self._position.offset def get_size_in(self, location) -> int: """ @@ -172,10 +205,12 @@ def get_size_in(self, location) -> int: return self.typ.size_in_location(location) return 0 + class ModuleVarInfo(VarInfo): """ A special VarInfo for modules """ + def __post_init__(self): super.__post_init__() assert isinstance(self.typ, ModuleT) @@ -188,11 +223,11 @@ def location(self): # location does not make sense for module vars raise CompilerPanic("unreachable") - def set_code_offset(ofst): + def set_immutables_position(self, ofst): assert self.code_offset is None self.code_offset = ofst - def set_storage_offset(ofst): + def set_storage_position(self, ofst): assert self.storage_offset is None self.storage_offset = ofst @@ -209,6 +244,7 @@ def get_offset_in(self, location): def get_size_in(self, location): return self.typ.size_in_location(location) + @dataclass class ExprInfo: """ diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 97fe7e58bd..975667856c 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -1,11 +1,7 @@ -# TODO this module doesn't really belong in "validation" -from typing import Dict, List - from vyper import ast as vy_ast -from vyper.exceptions import StorageLayoutException -from vyper.semantics.analysis.base import CodeOffset, StorageSlot +from vyper.exceptions import CompilerPanic, StorageLayoutException +from vyper.semantics.analysis.base import CodeOffset, ModuleVarInfo, StorageSlot from vyper.typing import StorageLayout -from vyper.utils import ceil32 def allocate_variables(vyper_module: vy_ast.Module) -> StorageLayout: @@ -24,14 +20,15 @@ def allocate_variables(vyper_module: vy_ast.Module) -> StorageLayout: return {"storage_layout": storage_slots, "code_layout": code_offsets} +class SimpleAllocator: + _max_slots = None -class SimpleStorageAllocator: def __init__(self, starting_slot: int = 0): self._slot = starting_slot - def allocate_slot(self, n, var_name): + def allocate(self, n, var_name=""): ret = self._slot - if self._slot + n >= 2**256: + if self._slot + n >= self._max_slots: raise StorageLayoutException( f"Invalid storage slot for var {var_name}, tried to allocate" f" slots {self._slot} through {self._slot + n}" @@ -40,6 +37,14 @@ def allocate_slot(self, n, var_name): return ret +class SimpleStorageAllocator(SimpleAllocator): + _max_slots = 2**256 + + +class SimpleImmutablesAllocator(SimpleAllocator): + _max_slots = 0x6000 # eip-170 + + def _set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: """ Parse module-level Vyper AST to calculate the layout of storage variables. @@ -49,7 +54,7 @@ def _set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: # note storage is word-addressable, not byte-addressable allocator = SimpleStorageAllocator() - ret: Dict[str, Dict] = {} + ret: dict[str, dict] = {} for funcdef in vyper_module.get_children(vy_ast.FunctionDef): type_ = funcdef._metadata["func_type"] @@ -69,7 +74,7 @@ def _set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: # TODO use one byte - or bit - per reentrancy key # requires either an extra SLOAD or caching the value of the # location in memory at entrance - slot = allocator.allocate_slot(1, keyname) + slot = allocator.allocate(1, keyname) type_.set_reentrancy_key_position(StorageSlot(slot)) @@ -84,7 +89,8 @@ def _set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: type_ = varinfo.typ - assert isinstance((vardecl := varinfo.decl_node), vy_ast.VariableDecl) + vardecl = varinfo.decl_node + assert isinstance(vardecl, vy_ast.VariableDecl) varname = vardecl.target.id @@ -93,9 +99,9 @@ def _set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: # for HashMaps because downstream code might use the slot # ID as a salt. n_slots = type_.storage_slots_required - slot = allocator.allocate_slot(n_slots, varname) + slot = allocator.allocate(n_slots, varname) - varinfo.set_position(StorageSlot(slot)) + varinfo.set_storage_position(StorageSlot(slot)) assert varname not in ret # this could have better typing but leave it untyped until @@ -105,34 +111,38 @@ def _set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: return ret -def set_code_offsets(vyper_module: vy_ast.Module) -> Dict: +def _set_code_offsets(vyper_module: vy_ast.Module) -> dict[str, dict]: ret = {} - offset = 0 + allocator = SimpleImmutablesAllocator() for varinfo in vyper_module._metadata["type"].variables.values(): type_ = varinfo.typ - if not varinfo.is_immutable and not isinstance(type_, ModuleT): + if not varinfo.is_immutable and not isinstance(varinfo, ModuleVarInfo): continue - len_ = ceil32(type_.immutable_bytes_required) + len_ = type_.immutable_bytes_required + + # sanity check. there are ways to construct varinfo with no + # decl_node but they shouldn't make it to here + vardecl = varinfo.decl_node + assert isinstance(vardecl, vy_ast.VariableDecl) + varname = vardecl.target.id + + if len_ % 32 != 0: + # sanity check length is a multiple of 32, it's an invariant + # that is used a lot in downstream code. + raise CompilerPanic("bad invariant") - varinfo.set_position(CodeOffset(offset)) + offset = allocator.allocate(len_, varname) + varinfo.set_immutables_position(CodeOffset(offset)) # this could have better typing but leave it untyped until # we understand the use case better output_dict = {"type": str(type_), "offset": offset, "length": len_} # put it into the storage layout - - # sanity check. there are ways to construct varinfo with no - # decl_node but they shouldn't make it here - assert isinstance(varinfo.decl_node, vy_ast.VariableDecl) - name = varinfo.decl_node.target.id - - assert name not in ret - ret[name] = output_dict - - offset += len_ + assert varname not in ret + ret[varname] = output_dict return ret diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index bdf6680024..1cded2d4fb 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -196,7 +196,7 @@ def analyze(self): ) for arg in self.func.arguments: self.namespace[arg.name] = VarInfo( - arg.typ, location=location, is_immutable=is_immutable + arg.typ, _location=location, is_immutable=is_immutable ) for node in self.fn_node.body: diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index d405830f0b..7d5c9023e1 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -20,9 +20,9 @@ VariableDeclarationException, VyperException, ) -from vyper.semantics.analysis.base import ImportInfo, VarInfo +from vyper.semantics.analysis.base import ImportInfo, ModuleVarInfo, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase -from vyper.semantics.analysis.data_positions import set_data_positions +from vyper.semantics.analysis.data_positions import allocate_variables from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import validate_functions from vyper.semantics.analysis.utils import ( @@ -39,16 +39,8 @@ # TODO: rename to `analyze_vyper` -def validate_semantics( - module_ast, input_bundle, storage_layout_overrides=None, is_interface=False -) -> ModuleT: - return validate_semantics_r( - module_ast, - input_bundle, - ImportGraph(), - is_interface, - storage_layout_overrides=storage_layout_overrides, - ) +def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT: + return validate_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface) def validate_semantics_r( @@ -56,7 +48,6 @@ def validate_semantics_r( input_bundle: InputBundle, import_graph: ImportGraph, is_interface: bool, - storage_layout_overrides: Any = None, ) -> ModuleT: """ Analyze a Vyper module AST node, add all module-level objects to the @@ -76,7 +67,7 @@ def validate_semantics_r( if not is_interface: validate_functions(module_ast) - layout = set_data_positions(module_ast, storage_layout_overrides) + layout = allocate_variables(module_ast) module_ast._metadata["variables_layout"] = layout return ret @@ -206,7 +197,9 @@ def analyze_call_graph(self): # we just want to be able to construct the call graph. continue - if isinstance(call_t, ContractFunctionT) and (call_t.is_internal or call_t.is_constructor): + if isinstance(call_t, ContractFunctionT) and ( + call_t.is_internal or call_t.is_constructor + ): fn_t.called_functions.add(call_t) for func in function_defs: @@ -279,15 +272,21 @@ def visit_VariableDecl(self, node): if node.is_transient and not version_check(begin="cancun"): raise StructureException("`transient` is not available pre-cancun", node.annotation) - var_info = VarInfo( + if isinstance(type_, ModuleT): + var_type = ModuleVarInfo + else: + var_type = VarInfo + + var_info = var_type( type_, decl_node=node, - location=data_loc, is_constant=node.is_constant, is_public=node.is_public, is_immutable=node.is_immutable, is_transient=node.is_transient, + _location=data_loc, ) + node.target._metadata["varinfo"] = var_info # TODO maybe put this in the global namespace node._metadata["type"] = type_ diff --git a/vyper/semantics/data_locations.py b/vyper/semantics/data_locations.py index 294c03077b..39a3fa346a 100644 --- a/vyper/semantics/data_locations.py +++ b/vyper/semantics/data_locations.py @@ -6,9 +6,9 @@ class DataLocation(enum.Enum): """ Possible locations for variables in vyper """ - UNSET = enum.auto() # like constants and stack variables - MEMORY = enum.auto() # local variables - STORAGE = enum.auto() # storage variables - CALLDATA = enum.auto() # arguments to external functions - IMMUTABLE = enum.auto() # immutable variables - TRANSIENT = enum.auto() # transient storage variables + UNSET = enum.auto() # like constants and stack variables + MEMORY = enum.auto() # local variables + STORAGE = enum.auto() # storage variables + CALLDATA = enum.auto() # arguments to external functions + IMMUTABLES = enum.auto() # immutable variables + TRANSIENT = enum.auto() # transient storage variables From 5e08300a36f57675be6f9225b0ea6f66a1f7b42b Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 29 Dec 2023 18:40:44 -0500 Subject: [PATCH 26/27] thread new offset through codegen --- vyper/codegen/core.py | 28 +++++++++++++----- vyper/codegen/function_definitions/common.py | 11 ++++--- vyper/semantics/analysis/base.py | 31 ++++++++++++-------- vyper/semantics/analysis/data_positions.py | 2 +- vyper/semantics/types/function.py | 2 +- 5 files changed, 48 insertions(+), 26 deletions(-) diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index 5a91bc6d53..f106f9228f 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -4,7 +4,15 @@ from vyper import ast as vy_ast from vyper.codegen.ir_node import Encoding, IRnode from vyper.compiler.settings import OptimizationLevel -from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT +from vyper.evm.address_space import ( + CALLDATA, + DATA, + IMMUTABLES, + MEMORY, + STORAGE, + TRANSIENT, + AddrSpace, +) from vyper.evm.opcodes import version_check from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch from vyper.semantics.data_locations import DataLocation @@ -78,6 +86,15 @@ def data_location_to_addr_space(s: DataLocation): raise CompilerPanic("unreachable") # pragma: nocover +def addr_space_to_data_location(s: AddrSpace): + if s == STORAGE: + return DataLocation.STORAGE + if s in (IMMUTABLES, DATA): + return DataLocation.IMMUTABLES + + raise CompilerPanic("unreachable") # pragma: nocover + + def get_type_for_exact_size(n_bytes): """Create a type which will take up exactly n_bytes. Used for allocating internal buffers. @@ -465,19 +482,16 @@ def _get_element_ptr_module(parent, key): assert isinstance(module_t, ModuleT) assert isinstance(key, str) - typ = module_t.variables[key].typ - # attrs = list(module_t.variables.keys()) + varinfo = module_t.variables[key] annotation = key - ofst = 0 # offset from parent start - assert parent.location in (STORAGE, IMMUTABLES, DATA), parent.location - ofst = module_t.offset_of(key) + ofst = varinfo.get_offset_in(addr_space_to_data_location(parent.location)) return IRnode.from_list( add_ofst(parent, ofst), - typ=typ, + typ=varinfo.typ, location=parent.location, encoding=parent.encoding, annotation=annotation, diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 1d2248d9aa..4b3e665bf5 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -117,6 +117,7 @@ def generate_ir_for_function( - Function body """ func_t = code._metadata["func_type"] + module_t = code._parent._metadata["type"] # type: ignore # generate _FuncIRInfo func_t._ir_info = _FuncIRInfo(func_t) @@ -146,7 +147,9 @@ def generate_ir_for_function( is_ctor_context=is_ctor_context, ) - if func_t.is_internal or func_t.is_constructor: + is_internal_init = func_t.is_constructor and compilation_target != module_t + + if func_t.is_internal or is_internal_init: ret: FuncIR = InternalFuncIR(generate_ir_for_internal_function(code, func_t, context)) func_t._ir_info.gas_estimate = ret.func_ir.gas # type: ignore else: @@ -169,7 +172,9 @@ def generate_ir_for_function( else: assert frame_info == func_t._ir_info.frame_info - if func_t.is_external: + if func_t.is_internal or is_internal_init: + ret.func_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore + else: # adjust gas estimate to include cost of mem expansion # frame_size of external function includes all private functions called # (note: internal functions do not need to adjust gas estimate since @@ -177,7 +182,5 @@ def generate_ir_for_function( ret.common_ir.add_gas_estimate += mem_expansion_cost # type: ignore ret.common_ir.passthrough_metadata["func_t"] = func_t # type: ignore ret.common_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore - else: - ret.func_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore return ret diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index 201a1ef303..b2285af2ef 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -179,7 +179,7 @@ def __post_init__(self): def _set_position_in(self, position: DataPosition) -> None: assert self._position is None if self.location != position.location: - raise CompilerPanic(f"Incompatible locations: {self.location}, {position._location}") + raise CompilerPanic(f"Incompatible locations: {self.location}, {position.location}") self._position = position def set_storage_position(self, position: DataPosition): @@ -212,33 +212,38 @@ class ModuleVarInfo(VarInfo): """ def __post_init__(self): - super.__post_init__() + super().__post_init__() + # hmm + from vyper.semantics.types.module import ModuleT + assert isinstance(self.typ, ModuleT) - self.code_offset = None - self.storage_offset = None + self._immutables_offset = None + self._storage_offset = None @property def location(self): - # location does not make sense for module vars - raise CompilerPanic("unreachable") + # location does not make sense for module vars but make the API work + return DataLocation.STORAGE def set_immutables_position(self, ofst): - assert self.code_offset is None - self.code_offset = ofst + assert self._immutables_offset is None + assert ofst.location == DataLocation.IMMUTABLES + self._immutables_offset = ofst def set_storage_position(self, ofst): - assert self.storage_offset is None - self.storage_offset = ofst + assert self._storage_offset is None + assert ofst.location == DataLocation.STORAGE + self._storage_offset = ofst def get_position(self): - raise CompilerPanic("use get_offset_in for ModuleInfo!") + raise CompilerPanic("use get_offset_in for ModuleVarInfo!") def get_offset_in(self, location): if location == DataLocation.STORAGE: - return self.storage_offset + return self._storage_offset.offset if location == DataLocation.IMMUTABLES: - return self.code_offset + return self._immutables_offset.offset raise CompilerPanic("unreachable") # pragma: nocover def get_size_in(self, location): diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 975667856c..5acc209321 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -21,7 +21,7 @@ def allocate_variables(vyper_module: vy_ast.Module) -> StorageLayout: class SimpleAllocator: - _max_slots = None + _max_slots: int = None # type: ignore def __init__(self, starting_slot: int = 0): self._slot = starting_slot diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 1d5d822e7e..87968bfcdc 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -376,7 +376,7 @@ def set_reentrancy_key_position(self, position: StorageSlot) -> None: if self.nonreentrant is None: raise CompilerPanic(f"No reentrant key {self}") # sanity check even though implied by the type - if position._location != DataLocation.STORAGE: + if position.location != DataLocation.STORAGE: raise CompilerPanic("Non-storage reentrant key") self.reentrancy_key_position = position From 1e393fa4f97788ff0bfd95be554dabf465ad82f7 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Fri, 29 Dec 2023 18:50:39 -0500 Subject: [PATCH 27/27] mark storage layout override tests as xfail --- tests/unit/cli/storage_layout/test_storage_layout_overrides.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py index f4c11b7ae6..29fbb384e3 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py +++ b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py @@ -4,6 +4,7 @@ from vyper.exceptions import StorageLayoutException +@pytest.mark.xfail(reason="storage layout overrides disabled") def test_storage_layout_overrides(): code = """ a: uint256