diff --git a/tests/functional/codegen/types/test_node_types.py b/tests/functional/codegen/types/test_node_types.py index 8a2b1681d7..abc23068ee 100644 --- a/tests/functional/codegen/types/test_node_types.py +++ b/tests/functional/codegen/types/test_node_types.py @@ -12,6 +12,7 @@ ) # TODO: this module should be merged in with other tests/functional/semantics/types/ tests. +# and moved to tests/unit/! def test_bytearray_node_type(): @@ -51,17 +52,17 @@ def test_canonicalize_type(): def test_type_storage_sizes(): - assert IntegerT(True, 128).storage_size_in_words == 1 - assert BytesT(12).storage_size_in_words == 2 - assert BytesT(33).storage_size_in_words == 3 - assert SArrayT(IntegerT(True, 128), 10).storage_size_in_words == 10 + assert IntegerT(True, 128).storage_slots_required == 1 + assert BytesT(12).storage_slots_required == 2 + assert BytesT(33).storage_slots_required == 3 + assert SArrayT(IntegerT(True, 128), 10).storage_slots_required == 10 tuple_ = TupleT([IntegerT(True, 128), DecimalT()]) - assert tuple_.storage_size_in_words == 2 + assert tuple_.storage_slots_required == 2 struct_ = StructT("Foo", {"a": IntegerT(True, 128), "b": DecimalT()}) - assert struct_.storage_size_in_words == 2 + assert struct_.storage_slots_required == 2 # Don't allow unknown types. 
with raises(AttributeError): - _ = int.storage_size_in_words + _ = int.storage_slots_required diff --git a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py index f4c11b7ae6..29fbb384e3 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py +++ b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py @@ -4,6 +4,7 @@ from vyper.exceptions import StorageLayoutException +@pytest.mark.xfail(reason="storage layout overrides disabled") def test_storage_layout_overrides(): code = """ a: uint256 diff --git a/tests/unit/semantics/types/test_size_in_bytes.py b/tests/unit/semantics/types/test_size_in_bytes.py index 69250fdfdf..244c52b3e0 100644 --- a/tests/unit/semantics/types/test_size_in_bytes.py +++ b/tests/unit/semantics/types/test_size_in_bytes.py @@ -11,7 +11,7 @@ def test_base_types(build_node, type_str): node = build_node(type_str) type_definition = type_from_annotation(node) - assert type_definition.size_in_bytes == 32 + assert type_definition._size_in_bytes == 32 @pytest.mark.parametrize("type_str", BYTESTRING_TYPES) @@ -20,7 +20,7 @@ def test_array_value_types(build_node, type_str, length, size): node = build_node(f"{type_str}[{length}]") type_definition = type_from_annotation(node) - assert type_definition.size_in_bytes == size + assert type_definition._size_in_bytes == size @pytest.mark.parametrize("type_str", BASE_TYPES) @@ -29,7 +29,7 @@ def test_dynamic_array_lengths(build_node, type_str, length): node = build_node(f"DynArray[{type_str}, {length}]") type_definition = type_from_annotation(node) - assert type_definition.size_in_bytes == 32 + length * 32 + assert type_definition._size_in_bytes == 32 + length * 32 @pytest.mark.parametrize("type_str", BASE_TYPES) @@ -38,7 +38,7 @@ def test_base_types_as_arrays(build_node, type_str, length): node = build_node(f"{type_str}[{length}]") type_definition = type_from_annotation(node) - assert 
type_definition.size_in_bytes == length * 32 + assert type_definition._size_in_bytes == length * 32 @pytest.mark.parametrize("type_str", BASE_TYPES) @@ -49,4 +49,4 @@ def test_base_types_as_multidimensional_arrays(build_node, type_str, first, seco type_definition = type_from_annotation(node) - assert type_definition.size_in_bytes == first * second * 32 + assert type_definition._size_in_bytes == first * second * 32 diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index dea30faabc..0ee3a0baef 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -3,10 +3,11 @@ from dataclasses import dataclass from typing import Any, Optional -from vyper.codegen.ir_node import Encoding -from vyper.evm.address_space import MEMORY, AddrSpace +import vyper.ast as vy_ast +from vyper.codegen.ir_node import Encoding, IRnode +from vyper.evm.address_space import IMMUTABLES, MEMORY, STORAGE, AddrSpace from vyper.exceptions import CompilerPanic, StateAccessViolation -from vyper.semantics.types import VyperType +from vyper.semantics.types import ModuleT, VyperType class Constancy(enum.Enum): @@ -48,7 +49,7 @@ def __repr__(self): class Context: def __init__( self, - module_ctx, + compilation_target, memory_allocator, vars_=None, forvars=None, @@ -59,9 +60,6 @@ def __init__( # In-memory variables, in the form (name, memory location, type) self.vars = vars_ or {} - # Global variables, in the form (name, storage location, type) - self.globals = module_ctx.variables - # Variables defined in for loops, e.g. for i in range(6): ... 
self.forvars = forvars or {} @@ -75,8 +73,8 @@ def __init__( # Whether we are currently parsing a range expression self.in_range_expr = False - # store module context - self.module_ctx = module_ctx + # the type information for the current compilation target + self.compilation_target: ModuleT = compilation_target # full function type self.func_t = func_t @@ -94,6 +92,24 @@ def __init__( # either the constructor, or called from the constructor self.is_ctor_context = is_ctor_context + def self_ptr(self, location): + func_module = self.func_t.ast_def._parent + assert isinstance(func_module, vy_ast.Module) + + module_t = func_module._metadata["type"] + module_is_compilation_target = module_t == self.compilation_target + + if module_is_compilation_target: + # return 0 for the special case where compilation target is self + return IRnode.from_list(0, typ=module_t, location=location) + + # otherwise, the function compilation context takes a `self_ptr` + # argument in the calling convention + if location == STORAGE: + return IRnode.from_list("self_ptr_storage", typ=module_t, location=location) + if location == IMMUTABLES: + return IRnode.from_list("self_ptr_immutables", typ=module_t, location=location) + def is_constant(self): return self.constancy is Constancy.Constant or self.in_assertion or self.in_range_expr diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py index c16de3c55a..f106f9228f 100644 --- a/vyper/codegen/core.py +++ b/vyper/codegen/core.py @@ -4,9 +4,18 @@ from vyper import ast as vy_ast from vyper.codegen.ir_node import Encoding, IRnode from vyper.compiler.settings import OptimizationLevel -from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT +from vyper.evm.address_space import ( + CALLDATA, + DATA, + IMMUTABLES, + MEMORY, + STORAGE, + TRANSIENT, + AddrSpace, +) from vyper.evm.opcodes import version_check from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch +from 
vyper.semantics.data_locations import DataLocation from vyper.semantics.types import ( AddressT, BoolT, @@ -17,6 +26,7 @@ HashMapT, IntegerT, InterfaceT, + ModuleT, StructT, TupleT, _BytestringT, @@ -64,6 +74,27 @@ def is_array_like(typ): return ret +def data_location_to_addr_space(s: DataLocation): + if s == DataLocation.STORAGE: + return STORAGE + if s == DataLocation.MEMORY: + return MEMORY + if s == DataLocation.IMMUTABLES: + # note: this is confusing in ctor context! + return IMMUTABLES + + raise CompilerPanic("unreachable") # pragma: nocover + + +def addr_space_to_data_location(s: AddrSpace): + if s == STORAGE: + return DataLocation.STORAGE + if s in (IMMUTABLES, DATA): + return DataLocation.IMMUTABLES + + raise CompilerPanic("unreachable") # pragma: nocover + + def get_type_for_exact_size(n_bytes): """Create a type which will take up exactly n_bytes. Used for allocating internal buffers. @@ -442,6 +473,31 @@ def _getelemptr_abi_helper(parent, member_t, ofst, clamp=True): ) +# get a variable out of a module +def _get_element_ptr_module(parent, key): + # note that this implementation is substantially similar to + # the StructT pathway through get_element_ptr_tuplelike and + # has potential to be refactored. 
+ module_t = parent.typ + assert isinstance(module_t, ModuleT) + + assert isinstance(key, str) + varinfo = module_t.variables[key] + annotation = key + + assert parent.location in (STORAGE, IMMUTABLES, DATA), parent.location + + ofst = varinfo.get_offset_in(addr_space_to_data_location(parent.location)) + + return IRnode.from_list( + add_ofst(parent, ofst), + typ=varinfo.typ, + location=parent.location, + encoding=parent.encoding, + annotation=annotation, + ) + + # TODO simplify this code, especially the ABI decoding def _get_element_ptr_tuplelike(parent, key): typ = parent.typ @@ -485,7 +541,7 @@ def _get_element_ptr_tuplelike(parent, key): if parent.location.word_addressable: for i in range(index): - ofst += typ.member_types[attrs[i]].storage_size_in_words + ofst += typ.member_types[attrs[i]].storage_slots_required elif parent.location.byte_addressable: for i in range(index): ofst += typ.member_types[attrs[i]].memory_bytes_required @@ -552,7 +608,7 @@ def _get_element_ptr_array(parent, key, array_bounds_check): return _getelemptr_abi_helper(parent, subtype, ofst) if parent.location.word_addressable: - element_size = subtype.storage_size_in_words + element_size = subtype.storage_slots_required elif parent.location.byte_addressable: element_size = subtype.memory_bytes_required else: @@ -590,6 +646,9 @@ def get_element_ptr(parent, key, array_bounds_check=True): if is_tuple_like(typ): ret = _get_element_ptr_tuplelike(parent, key) + elif isinstance(typ, ModuleT): + ret = _get_element_ptr_module(parent, key) + elif isinstance(typ, HashMapT): ret = _get_element_ptr_mapping(parent, key) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 577660b883..dbef5ae474 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -36,6 +36,7 @@ VyperException, tag_exceptions, ) +from vyper.semantics.analysis.base import VarInfo from vyper.semantics.types import ( AddressT, BoolT, @@ -77,14 +78,18 @@ def __init__(self, node, context): return assert isinstance(node, 
vy_ast.VyperNode) + + # keep the original ast node for exception-handling purposes + og_node = node + if node.has_folded_value: node = node.get_folded_value() self.expr = node self.context = context - fn_name = f"parse_{type(node).__name__}" - with tag_exceptions(node, fallback_exception_type=CodegenPanic, note=fn_name): + with tag_exceptions(og_node, fallback_exception_type=CodegenPanic): + fn_name = f"parse_{type(node).__name__}" fn = getattr(self, fn_name) self.ir_node = fn() assert isinstance(self.ir_node, IRnode), self.ir_node @@ -171,7 +176,8 @@ def parse_NameConstant(self): # Variable names def parse_Name(self): if self.expr.id == "self": - return IRnode.from_list(["address"], typ=AddressT()) + # TODO: have `self` return a module type + return IRnode.from_list(["self"], typ=AddressT()) elif self.expr.id in self.context.vars: var = self.context.vars[self.expr.id] ret = IRnode.from_list( @@ -185,8 +191,7 @@ def parse_Name(self): ret._referenced_variables = {var} return ret - # TODO: use self.expr._expr_info - elif self.expr.id in self.context.globals: + elif (varinfo := self.expr._metadata.get("variable_access")) is not None: varinfo = self.context.globals[self.expr.id] if varinfo.is_constant: @@ -197,8 +202,6 @@ def parse_Name(self): assert varinfo.is_immutable, "not an immutable!" 
- ofst = varinfo.position.offset - if self.context.is_ctor_context: mutable = True location = IMMUTABLES @@ -206,10 +209,16 @@ def parse_Name(self): mutable = False location = DATA - ret = IRnode.from_list( - ofst, typ=varinfo.typ, location=location, annotation=self.expr.id, mutable=mutable - ) + module_ptr = self.context.self_ptr(location) + + ret = get_element_ptr(module_ptr, self.expr.id) + + assert ret.typ == varinfo.typ + assert ret.location == location + + ret.mutable = mutable ret._referenced_variables = {varinfo} + return ret # x.y or x[5] @@ -228,7 +237,7 @@ def parse_Attribute(self): return IRnode.from_list(value, typ=typ) # x.balance: balance of address x - if self.expr.attr == "balance": + elif self.expr.attr == "balance": addr = Expr.parse_value_expr(self.expr.value, self.context) if addr.typ == AddressT(): if ( @@ -240,6 +249,7 @@ def parse_Attribute(self): else: seq = ["balance", addr] return IRnode.from_list(seq, typ=UINT256_T) + # x.codesize: codesize of address x elif self.expr.attr == "codesize" or self.expr.attr == "is_contract": addr = Expr.parse_value_expr(self.expr.value, self.context) @@ -251,14 +261,17 @@ def parse_Attribute(self): eval_code = ["extcodesize", addr] output_type = UINT256_T else: + assert self.expr.attr == "is_contract" eval_code = ["gt", ["extcodesize", addr], 0] output_type = BoolT() return IRnode.from_list(eval_code, typ=output_type) + # x.codehash: keccak of address x elif self.expr.attr == "codehash": addr = Expr.parse_value_expr(self.expr.value, self.context) if addr.typ == AddressT(): return IRnode.from_list(["extcodehash", addr], typ=BYTES32_T) + # x.code: codecopy/extcodecopy of address x elif self.expr.attr == "code": addr = Expr.parse_value_expr(self.expr.value, self.context) @@ -267,24 +280,12 @@ def parse_Attribute(self): if addr.value == "address": # for `self.code` return IRnode.from_list(["~selfcode"], typ=BytesT(0)) return IRnode.from_list(["~extcode", addr], typ=BytesT(0)) - # self.x: global attribute - elif 
isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id == "self": - varinfo = self.context.globals[self.expr.attr] - location = TRANSIENT if varinfo.is_transient else STORAGE - - ret = IRnode.from_list( - varinfo.position.position, - typ=varinfo.typ, - location=location, - annotation="self." + self.expr.attr, - ) - ret._referenced_variables = {varinfo} - - return ret # Reserved keywords elif ( - isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id in ENVIRONMENT_VARIABLES + # TODO: use type information here + isinstance(self.expr.value, vy_ast.Name) + and self.expr.value.id in ENVIRONMENT_VARIABLES ): key = f"{self.expr.value.id}.{self.expr.attr}" if key == "msg.sender": @@ -336,17 +337,34 @@ def parse_Attribute(self): "chain.id is unavailable prior to istanbul ruleset", self.expr ) return IRnode.from_list(["chainid"], typ=UINT256_T) - # Other variables - else: - sub = Expr(self.expr.value, self.context).ir_node - # contract type - if isinstance(sub.typ, InterfaceT): - # MyInterface.address - assert self.expr.attr == "address" - sub.typ = typ - return sub - if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: - return get_element_ptr(sub, self.expr.attr) + + # self.x: global storage variable or immutable + if (varinfo := self.expr._metadata.get("variable_access")) is not None: + assert isinstance(varinfo, VarInfo) + + # TODO: handle immutables + location = TRANSIENT if varinfo.is_transient else STORAGE + + module_ptr = Expr(self.expr.value, self.context).ir_node + if module_ptr.value == "self": + module_ptr = self.context.self_ptr(location) + + ret = get_element_ptr(module_ptr, self.expr.attr) + # TODO: take referenced variables info from analysis + ret._referenced_variables = {varinfo} + return ret + + # if we have gotten here, it's an instance of an interface or struct + sub = Expr(self.expr.value, self.context).ir_node + + if isinstance(sub.typ, InterfaceT): + # MyInterface.address + assert self.expr.attr == "address" + 
sub.typ = typ + return sub + + if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types: + return get_element_ptr(sub, self.expr.attr) def parse_Subscript(self): sub = Expr(self.expr.value, self.context).ir_node @@ -703,7 +721,7 @@ def parse_Call(self): return pop_dyn_array(darray, return_popped_item=True) if isinstance(func_type, ContractFunctionT): - if func_type.is_internal: + if func_type.is_internal or func_type.is_constructor: return self_call.ir_for_self_call(self.expr, self.context) else: return external_call.ir_for_external_call(self.expr, self.context) diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 454ba9c8cd..4b3e665bf5 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -65,7 +65,13 @@ def external_function_base_entry_label(self) -> str: return self.ir_identifier + "_common" def internal_function_label(self, is_ctor_context: bool = False) -> str: - assert self.func_t.is_internal, "uh oh, should be internal" + f = self.func_t + assert f.is_internal or f.is_constructor, "uh oh, should be internal" + + if f.is_constructor: + # sanity check - imported init functions only callable from main init + assert is_ctor_context + suffix = "_deploy" if is_ctor_context else "_runtime" return self.ir_identifier + suffix @@ -101,7 +107,7 @@ class InternalFuncIR(FuncIR): # TODO: should split this into external and internal ir generation? 
def generate_ir_for_function( - code: vy_ast.FunctionDef, module_ctx: ModuleT, is_ctor_context: bool = False + code: vy_ast.FunctionDef, compilation_target: ModuleT, is_ctor_context: bool = False ) -> FuncIR: """ Parse a function and produce IR code for the function, includes: @@ -111,6 +117,7 @@ def generate_ir_for_function( - Function body """ func_t = code._metadata["func_type"] + module_t = code._parent._metadata["type"] # type: ignore # generate _FuncIRInfo func_t._ir_info = _FuncIRInfo(func_t) @@ -133,14 +140,16 @@ def generate_ir_for_function( context = Context( vars_=None, - module_ctx=module_ctx, + compilation_target=compilation_target, memory_allocator=memory_allocator, constancy=Constancy.Mutable if func_t.is_mutable else Constancy.Constant, func_t=func_t, is_ctor_context=is_ctor_context, ) - if func_t.is_internal: + is_internal_init = func_t.is_constructor and compilation_target != module_t + + if func_t.is_internal or is_internal_init: ret: FuncIR = InternalFuncIR(generate_ir_for_internal_function(code, func_t, context)) func_t._ir_info.gas_estimate = ret.func_ir.gas # type: ignore else: @@ -163,7 +172,9 @@ def generate_ir_for_function( else: assert frame_info == func_t._ir_info.frame_info - if not func_t.is_internal: + if func_t.is_internal or is_internal_init: + ret.func_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore + else: # adjust gas estimate to include cost of mem expansion # frame_size of external function includes all private functions called # (note: internal functions do not need to adjust gas estimate since @@ -171,7 +182,5 @@ def generate_ir_for_function( ret.common_ir.add_gas_estimate += mem_expansion_cost # type: ignore ret.common_ir.passthrough_metadata["func_t"] = func_t # type: ignore ret.common_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore - else: - ret.func_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore return ret diff --git 
a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py index cf01dbdab4..043b01f4ed 100644 --- a/vyper/codegen/function_definitions/internal_function.py +++ b/vyper/codegen/function_definitions/internal_function.py @@ -50,6 +50,11 @@ def generate_ir_for_internal_function( cleanup_label = func_t._ir_info.exit_sequence_label stack_args = ["var_list"] + + for location in func_t.touched_locations: + location_name = location.name.lower() + stack_args.append(f"self_ptr_{location_name}") + if func_t.return_type: stack_args += ["return_buffer"] stack_args += ["return_pc"] diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index 98395a6a0c..d397203986 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -15,6 +15,8 @@ def _topsort(functions): # single pass to get a global topological sort of functions (so that each # function comes after each of its callees). + # note this function can return functions from other modules if + # they are reachable! ret = OrderedSet() for func_ast in functions: fn_t = func_ast._metadata["func_type"] @@ -104,12 +106,12 @@ def _ir_for_internal_function(func_ast, *args, **kwargs): return generate_ir_for_function(func_ast, *args, **kwargs).func_ir -def _generate_external_entry_points(external_functions, module_ctx): +def _generate_external_entry_points(external_functions, compilation_target): entry_points = {} # map from ABI sigs to ir code sig_of = {} # reverse map from method ids to abi sig for code in external_functions: - func_ir = generate_ir_for_function(code, module_ctx) + func_ir = generate_ir_for_function(code, compilation_target) for abi_sig, entry_point in func_ir.entry_points.items(): method_id = method_id_int(abi_sig) assert abi_sig not in entry_points @@ -131,13 +133,13 @@ def _generate_external_entry_points(external_functions, module_ctx): # into a bucket (of about 8-10 items), and then uses perfect hash # to select the final function. 
# costs about 212 gas for typical function and 8 bytes of code (+ ~87 bytes of global overhead) -def _selector_section_dense(external_functions, module_ctx): +def _selector_section_dense(external_functions, compilation_target): function_irs = [] if len(external_functions) == 0: return IRnode.from_list(["seq"]) - entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) + entry_points, sig_of = _generate_external_entry_points(external_functions, compilation_target) # generate the label so the jumptable works for abi_sig, entry_point in entry_points.items(): @@ -282,13 +284,13 @@ def _selector_section_dense(external_functions, module_ctx): # a bucket, and then descends into linear search from there. # costs about 126 gas for typical (nonpayable, >0 args, avg bucket size 1.5) # function and 24 bytes of code (+ ~23 bytes of global overhead) -def _selector_section_sparse(external_functions, module_ctx): +def _selector_section_sparse(external_functions, compilation_target): ret = ["seq"] if len(external_functions) == 0: return ret - entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) + entry_points, sig_of = _generate_external_entry_points(external_functions, compilation_target) n_buckets, buckets = jumptable_utils.generate_sparse_jumptable_buckets(entry_points.keys()) @@ -387,14 +389,14 @@ def _selector_section_sparse(external_functions, module_ctx): # O(n) linear search for the method id # mainly keep this in for backends which cannot handle the indirect jump # in selector_section_dense and selector_section_sparse -def _selector_section_linear(external_functions, module_ctx): +def _selector_section_linear(external_functions, compilation_target): ret = ["seq"] if len(external_functions) == 0: return ret ret.append(["if", ["lt", "calldatasize", 4], ["goto", "fallback"]]) - entry_points, sig_of = _generate_external_entry_points(external_functions, module_ctx) + entry_points, sig_of = 
_generate_external_entry_points(external_functions, compilation_target) dispatcher = ["seq"] @@ -423,13 +425,14 @@ def _selector_section_linear(external_functions, module_ctx): # take a ModuleT, and generate the runtime and deploy IR -def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: +def generate_ir_for_module(compilation_target: ModuleT) -> tuple[IRnode, IRnode]: # order functions so that each function comes after all of its callees - function_defs = _topsort(module_ctx.function_defs) - reachable = _globally_reachable_functions(module_ctx.function_defs) + function_defs = _topsort(compilation_target.function_defs) + reachable = _globally_reachable_functions(compilation_target.function_defs) runtime_functions = [f for f in function_defs if not _is_constructor(f)] - init_function = next((f for f in function_defs if _is_constructor(f)), None) + + init_function = next((f for f in compilation_target.function_defs if _is_constructor(f)), None) internal_functions = [f for f in runtime_functions if _is_internal(f)] @@ -444,7 +447,7 @@ def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: for func_ast in internal_functions: # compile it so that _ir_info is populated (whether or not it makes # it into the final IR artifact) - func_ir = _ir_for_internal_function(func_ast, module_ctx, False) + func_ir = _ir_for_internal_function(func_ast, compilation_target, False) # only include it in the IR if it is reachable from an external # function. @@ -452,16 +455,16 @@ def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: internal_functions_ir.append(IRnode.from_list(func_ir)) if core._opt_none(): - selector_section = _selector_section_linear(external_functions, module_ctx) + selector_section = _selector_section_linear(external_functions, compilation_target) # dense vs sparse global overhead is amortized after about 4 methods. # (--debug will force dense selector table anyway if _opt_codesize is selected.) 
elif core._opt_codesize() and (len(external_functions) > 4 or _is_debug_mode()): - selector_section = _selector_section_dense(external_functions, module_ctx) + selector_section = _selector_section_dense(external_functions, compilation_target) else: - selector_section = _selector_section_sparse(external_functions, module_ctx) + selector_section = _selector_section_sparse(external_functions, compilation_target) if default_function: - fallback_ir = _ir_for_fallback_or_ctor(default_function, module_ctx) + fallback_ir = _ir_for_fallback_or_ctor(default_function, compilation_target) else: fallback_ir = IRnode.from_list( ["revert", 0, 0], annotation="Default function", error_msg="fallback function" @@ -474,25 +477,24 @@ def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: runtime.extend(internal_functions_ir) deploy_code: List[Any] = ["seq"] - immutables_len = module_ctx.immutable_section_bytes - if init_function: + immutables_len = compilation_target.immutable_bytes_required + if init_function is not None: # cleanly rerun codegen for internal functions with `is_ctor_ctx=True` init_func_t = init_function._metadata["func_type"] ctor_internal_func_irs = [] - internal_functions = [f for f in runtime_functions if _is_internal(f)] - for f in internal_functions: - func_t = f._metadata["func_type"] - if func_t not in init_func_t.reachable_internal_functions: - # unreachable code, delete it - continue - - func_ir = _ir_for_internal_function(f, module_ctx, is_ctor_context=True) + reachable_from_ctor = init_func_t.reachable_internal_functions + for func_t in reachable_from_ctor: + fn_ast = func_t.ast_def + + func_ir = _ir_for_internal_function(fn_ast, compilation_target, is_ctor_context=True) ctor_internal_func_irs.append(func_ir) # generate init_func_ir after callees to ensure they have analyzed # memory usage. 
# TODO might be cleaner to separate this into an _init_ir helper func - init_func_ir = _ir_for_fallback_or_ctor(init_function, module_ctx, is_ctor_context=True) + init_func_ir = _ir_for_fallback_or_ctor( + init_function, compilation_target, is_ctor_context=True + ) # pass the amount of memory allocated for the init function # so that deployment does not clobber while preparing immutables diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py index f53e4a81b4..6eb60ded6c 100644 --- a/vyper/codegen/self_call.py +++ b/vyper/codegen/self_call.py @@ -1,4 +1,11 @@ -from vyper.codegen.core import _freshname, eval_once_check, make_setter +from vyper import ast as vy_ast +from vyper.codegen.core import ( + _freshname, + data_location_to_addr_space, + eval_once_check, + get_element_ptr, + make_setter, +) from vyper.codegen.ir_node import IRnode from vyper.evm.address_space import MEMORY from vyper.exceptions import StateAccessViolation @@ -20,7 +27,41 @@ def _align_kwargs(func_t, args_ir): return [i.default_value for i in unprovided_kwargs] -def ir_for_self_call(stmt_expr, context): +def _get_self_ptr_for_location(node: vy_ast.Attribute, context, location): + # resolve something like self.x.y.z to a pointer + if isinstance(node.value, vy_ast.Name): + # base case - we should always end up at self. sanity check this! 
+ assert node.value.id == "self" + ptr = context.self_ptr(location) + else: + assert isinstance(node.value, vy_ast.Attribute) # mypy hint + # recurse + ptr = _get_self_ptr_for_location(node.value, context, location) + + return get_element_ptr(ptr, node.attr) + + +def _calculate_self_ptr_requirements(call_expr, func_t, context): + ret = [] + + module_t = func_t.ast_def._parent._metadata["type"] + if module_t == context.compilation_target: + # we don't need to pass a pointer + return ret + + func_expr = call_expr.func + assert isinstance(func_expr, vy_ast.Attribute) + pointer_expr = func_expr.value + for location in func_t.touched_locations: + codegen_location = data_location_to_addr_space(location) + + # self.foo.bar.baz() => pointer_expr == `self.foo.bar` + ret.append(_get_self_ptr_for_location(pointer_expr, context, codegen_location)) + + return ret + + +def ir_for_self_call(call_expr, context): from vyper.codegen.expr import Expr # TODO rethink this circular import # ** Internal Call ** @@ -30,10 +71,10 @@ def ir_for_self_call(stmt_expr, context): # - push jumpdest (callback ptr) and return buffer location # - jump to label # - (private function will fill return buffer and jump back) - method_name = stmt_expr.func.attr - func_t = stmt_expr.func._metadata["type"] + method_name = call_expr.func.attr + func_t = call_expr.func._metadata["type"] - pos_args_ir = [Expr(x, context).ir_node for x in stmt_expr.args] + pos_args_ir = [Expr(x, context).ir_node for x in call_expr.args] default_vals = _align_kwargs(func_t, pos_args_ir) default_vals_ir = [Expr(x, context).ir_node for x in default_vals] @@ -49,7 +90,7 @@ def ir_for_self_call(stmt_expr, context): raise StateAccessViolation( f"May not call state modifying function " f"'{method_name}' within {context.pp_constancy()}.", - stmt_expr, + call_expr, ) # note: internal_function_label asserts `func_t.is_internal`. 
@@ -91,14 +132,20 @@ def ir_for_self_call(stmt_expr, context): copy_args = make_setter(args_dst, args_as_tuple) goto_op = ["goto", func_t._ir_info.internal_function_label(context.is_ctor_context)] + + # if needed, pass pointers to the callee + for self_ptr in _calculate_self_ptr_requirements(call_expr, func_t, context): + goto_op.append(self_ptr) + # pass return buffer to subroutine if return_buffer is not None: - goto_op += [return_buffer] + goto_op.append(return_buffer) + # pass return label to subroutine goto_op.append(["symbol", return_label]) call_sequence = ["seq"] - call_sequence.append(eval_once_check(_freshname(stmt_expr.node_source_code))) + call_sequence.append(eval_once_check(_freshname(call_expr.node_source_code))) call_sequence.extend([copy_args, goto_op, ["label", return_label, ["var_list"], "pass"]]) if return_buffer is not None: # push return buffer location to stack @@ -108,7 +155,7 @@ def ir_for_self_call(stmt_expr, context): call_sequence, typ=func_t.return_type, location=MEMORY, - annotation=stmt_expr.get("node_source_code"), + annotation=call_expr.get("node_source_code"), add_gas_estimate=func_t._ir_info.gas_estimate, ) o.is_self_call = True diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index bc29a79734..34d44f64b8 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -41,8 +41,8 @@ def __init__(self, node: vy_ast.VyperNode, context: Context) -> None: self.stmt = node self.context = context - fn_name = f"parse_{type(node).__name__}" - with tag_exceptions(node, fallback_exception_type=CodegenPanic, note=fn_name): + with tag_exceptions(node, fallback_exception_type=CodegenPanic): + fn_name = f"parse_{type(node).__name__}" fn = getattr(self, fn_name) with context.internal_memory_scope(): self.ir_node = fn() @@ -145,7 +145,7 @@ def parse_Call(self): return pop_dyn_array(darray, return_popped_item=False) if isinstance(func_type, ContractFunctionT): - if func_type.is_internal: + if func_type.is_internal or 
func_type.is_constructor: return self_call.ir_for_self_call(self.stmt, self.context) else: return external_call.ir_for_external_call(self.stmt, self.context) diff --git a/vyper/compiler/input_bundle.py b/vyper/compiler/input_bundle.py index 27170f0a56..8ec82bd918 100644 --- a/vyper/compiler/input_bundle.py +++ b/vyper/compiler/input_bundle.py @@ -70,6 +70,11 @@ def __init__(self, search_paths): # share the same lifetime as this input bundle. self._cache = lambda: None + # very strict equality! we don't want to accidentally compare + # two input bundles as being equal when they aren't the same. + def __eq__(self, other): + return self is other + def _normalize_path(self, path): raise NotImplementedError(f"not implemented! {self.__class__}._normalize_path()") diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py index 850adcfea3..d78baf71ab 100644 --- a/vyper/compiler/phases.py +++ b/vyper/compiler/phases.py @@ -12,7 +12,7 @@ from vyper.compiler.settings import OptimizationLevel, Settings from vyper.exceptions import StructureException from vyper.ir import compile_ir, optimizer -from vyper.semantics import set_data_positions, validate_semantics +from vyper.semantics import validate_semantics from vyper.semantics.types.function import ContractFunctionT from vyper.semantics.types.module import ModuleT from vyper.typing import StorageLayout @@ -154,18 +154,17 @@ def vyper_module(self): @cached_property def _annotated_module(self): return generate_annotated_ast( - self.vyper_module, self.input_bundle, self.storage_layout_override + self.vyper_module, self.input_bundle ) @property def annotated_vyper_module(self) -> vy_ast.Module: - module, storage_layout = self._annotated_module - return module + return self._annotated_module @property def storage_layout(self) -> StorageLayout: - module, storage_layout = self._annotated_module - return storage_layout + module = self.vyper_module_folded + return module._metadata["variables_layout"] @property def 
global_ctx(self) -> ModuleT: @@ -246,7 +245,7 @@ def generate_annotated_ast( vyper_module: vy_ast.Module, input_bundle: InputBundle, storage_layout_overrides: StorageLayout = None, -) -> tuple[vy_ast.Module, StorageLayout]: +) -> vy_ast.Module: """ Validates and annotates the Vyper AST. @@ -259,17 +258,13 @@ def generate_annotated_ast( ------- vy_ast.Module Annotated Vyper AST - StorageLayout - Layout of variables in storage """ vyper_module = copy.deepcopy(vyper_module) with input_bundle.search_path(Path(vyper_module.resolved_path).parent): - # note: validate_semantics does type inference on the AST - validate_semantics(vyper_module, input_bundle) + # note: validate_semantics does type checking on the AST + validate_semantics(vyper_module, input_bundle, storage_layout_overrides) - symbol_tables = set_data_positions(vyper_module, storage_layout_overrides) - - return vyper_module, symbol_tables + return vyper_module def generate_ir_nodes( diff --git a/vyper/exceptions.py b/vyper/exceptions.py index f216069eab..6afa829d24 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -369,7 +369,7 @@ def tag_exceptions(node, fallback_exception_type=CompilerPanic, note=None): raise e from None except Exception as e: tb = e.__traceback__ - fallback_message = "unhandled exception" + fallback_message = f"unhandled exception {e}" if note: fallback_message += f", {note}" raise fallback_exception_type(fallback_message, node).with_traceback(tb) diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py index 8ce8c887f1..7e94d57b55 100644 --- a/vyper/ir/compile_ir.py +++ b/vyper/ir/compile_ir.py @@ -112,6 +112,9 @@ def calc_mem_ofst_size(ctor_mem_size): def _rewrite_return_sequences(ir_node, label_params=None): args = ir_node.args + # special values which should be popped at the end of function execution + POPPABLE_PARAMS = ("return_buffer", "self_ptr_storage", "self_ptr_immutables") + if ir_node.value == "return": if args[0].value == "ret_ofst" and args[1].value == 
"ret_len": ir_node.args[0].value = "pass" @@ -126,8 +129,10 @@ def _rewrite_return_sequences(ir_node, label_params=None): ir_node.value = "seq" _t = ["seq"] - if "return_buffer" in label_params: - _t.append(["pop", "pass"]) + + for s in POPPABLE_PARAMS: + if s in label_params: + _t.append(["pop", "pass"]) dest = args[0].value # works for both internal and external exit_to diff --git a/vyper/semantics/__init__.py b/vyper/semantics/__init__.py index bb40c266a4..48a51d3917 100644 --- a/vyper/semantics/__init__.py +++ b/vyper/semantics/__init__.py @@ -1,2 +1 @@ from .analysis import validate_semantics -from .analysis.data_positions import set_data_positions diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py index bb6d9ad9f7..7d44fc8321 100644 --- a/vyper/semantics/analysis/base.py +++ b/vyper/semantics/analysis/base.py @@ -46,14 +46,14 @@ def values(cls) -> List[str]: # Comparison operations def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): - raise CompilerPanic("Can only compare like types.") + raise CompilerPanic("bad comparison") return self is other # Python normally does __ne__(other) ==> not self.__eq__(other) def __lt__(self, other: object) -> bool: if not isinstance(other, self.__class__): - raise CompilerPanic("Can only compare like types.") + raise CompilerPanic("bad comparison") options = self.__class__.options() return options.index(self) < options.index(other) # type: ignore @@ -68,9 +68,16 @@ def __ge__(self, other: object) -> bool: class FunctionVisibility(_StringEnum): - # TODO: these can just be enum.auto() right? 
- EXTERNAL = _StringEnum.auto() - INTERNAL = _StringEnum.auto() + EXTERNAL = enum.auto() + INTERNAL = enum.auto() + CONSTRUCTOR = enum.auto() + + @classmethod + def is_valid_value(cls, value: str) -> bool: + # make CONSTRUCTOR visibility not available to the user + # (although as a design note - maybe `@constructor` should + # indeed be available) + return super().is_valid_value(value) and value != "constructor" class StateMutability(_StringEnum): @@ -117,57 +124,25 @@ class Modifiability(enum.IntEnum): # compile-time / always constant CONSTANT = enum.auto() - +@dataclass class DataPosition: - _location: DataLocation - - -class CalldataOffset(DataPosition): - __slots__ = ("dynamic_offset", "static_offset") - _location = DataLocation.CALLDATA - - def __init__(self, static_offset, dynamic_offset=None): - self.static_offset = static_offset - self.dynamic_offset = dynamic_offset + offset: int - def __repr__(self): - if self.dynamic_offset is not None: - return f"" - else: - return f"" - - -class MemoryOffset(DataPosition): - __slots__ = ("offset",) - _location = DataLocation.MEMORY - - def __init__(self, offset): - self.offset = offset - - def __repr__(self): - return f"" + @property + def location(self): + raise CompilerPanic("unreachable!") class StorageSlot(DataPosition): - __slots__ = ("position",) - _location = DataLocation.STORAGE - - def __init__(self, position): - self.position = position - - def __repr__(self): - return f"" + @property + def location(self): + return DataLocation.STORAGE class CodeOffset(DataPosition): - __slots__ = ("offset",) - _location = DataLocation.CODE - - def __init__(self, offset): - self.offset = offset - - def __repr__(self): - return f"" + @property + def location(self): + return DataLocation.IMMUTABLES # base class for things that are the "result" of analysis @@ -175,29 +150,18 @@ class AnalysisResult: pass -@dataclass -class ModuleInfo(AnalysisResult): - module_t: "ModuleT" - - @property - def module_node(self): - return 
self.module_t._module - - # duck type, conform to interface of VarInfo and ExprInfo - @property - def typ(self): - return self.module_t - - @dataclass class ImportInfo(AnalysisResult): - typ: Union[ModuleInfo, "InterfaceT"] + typ: Union["ModuleT", "InterfaceT"] alias: str # the name in the namespace qualified_module_name: str # for error messages # source_id: int input_bundle: InputBundle node: vy_ast.VyperNode + def __eq__(self, other): + return self is other + @dataclass class VarInfo: @@ -212,7 +176,7 @@ class VarInfo: """ typ: VyperType - location: DataLocation = DataLocation.UNSET + _location: DataLocation = DataLocation.UNSET modifiability: Modifiability = Modifiability.MODIFIABLE is_public: bool = False decl_node: Optional[vy_ast.VyperNode] = None @@ -220,10 +184,22 @@ class VarInfo: def __hash__(self): return hash(id(self)) + @property + def location(self): + return self._location + def __post_init__(self): - self._modification_count = 0 + self._reads = [] + self._writes = [] + self._position = None # the location provided by the allocator - def set_position(self, position: DataPosition) -> None: + def _set_position_in(self, position: DataPosition) -> None: + assert self._position is None + if self.location != position.location: + raise CompilerPanic(f"Incompatible locations: {self.location}, {position.location}") + self._position = position + + def _DEAD_set_position(self, position: DataPosition) -> None: if hasattr(self, "position"): raise CompilerPanic("Position was already assigned") if self.location != position._location: @@ -237,13 +213,81 @@ def set_position(self, position: DataPosition) -> None: raise CompilerPanic("Incompatible locations") self.position = position + + def set_storage_position(self, position: DataPosition): + assert self.location == DataLocation.STORAGE + self._set_position_in(position) + + def set_immutables_position(self, position: DataPosition): + assert self.location == DataLocation.IMMUTABLES + self._set_position_in(position) + 
+ def get_position(self) -> int: + return self._position.offset + + def get_offset_in(self, location): + assert location == self.location + return self._position.offset + + def get_size_in(self, location) -> int: + """ + Get the amount of space this variable occupies in a given location + """ + if location == self.location: + return self.typ.size_in_location(location) + return 0 + + +class ModuleVarInfo(VarInfo): + """ + A special VarInfo for modules + """ + + def __post_init__(self): + super().__post_init__() + # hmm + from vyper.semantics.types.module import ModuleT + + assert isinstance(self.typ, ModuleT) + + self._immutables_offset = None + self._storage_offset = None + + @property + def location(self): + # location does not make sense for module vars but make the API work + return DataLocation.STORAGE + + def set_immutables_position(self, ofst): + assert self._immutables_offset is None + assert ofst.location == DataLocation.IMMUTABLES + self._immutables_offset = ofst + + def set_storage_position(self, ofst): + assert self._storage_offset is None + assert ofst.location == DataLocation.STORAGE + self._storage_offset = ofst + + def get_position(self): + raise CompilerPanic("use get_offset_in for ModuleVarInfo!") + + def get_offset_in(self, location): + if location == DataLocation.STORAGE: + return self._storage_offset.offset + if location == DataLocation.IMMUTABLES: + return self._immutables_offset.offset + raise CompilerPanic("unreachable") # pragma: nocover + + def get_size_in(self, location): + return self.typ.size_in_location(location) + @property def is_transient(self): return self.location == DataLocation.TRANSIENT @property def is_immutable(self): - return self.location == DataLocation.CODE + return self.location == DataLocation.IMMUTABLES @property def is_constant(self): @@ -259,30 +303,26 @@ class ExprInfo: """ typ: VyperType - var_info: Optional[VarInfo] = None location: DataLocation = DataLocation.UNSET + _var_info: Optional[VarInfo] = None 
modifiability: Modifiability = Modifiability.MODIFIABLE def __post_init__(self): should_match = ("typ", "location", "modifiability") - if self.var_info is not None: + if self._var_info is not None: for attr in should_match: - if getattr(self.var_info, attr) != getattr(self, attr): + if getattr(self._var_info, attr) != getattr(self, attr): raise CompilerPanic("Bad analysis: non-matching {attr}: {self}") @classmethod def from_varinfo(cls, var_info: VarInfo) -> "ExprInfo": return cls( var_info.typ, - var_info=var_info, - location=var_info.location, modifiability=var_info.modifiability, + location=var_info.location, + _var_info=var_info, ) - @classmethod - def from_moduleinfo(cls, module_info: ModuleInfo) -> "ExprInfo": - return cls(module_info.module_t) - def copy_with_type(self, typ: VyperType) -> "ExprInfo": """ Return a copy of the ExprInfo but with the type set to something else @@ -291,6 +331,7 @@ def copy_with_type(self, typ: VyperType) -> "ExprInfo": fields = {k: getattr(self, k) for k in to_copy} return self.__class__(typ=typ, **fields) + # TODO: move to analysis/local.py def validate_modification(self, node: vy_ast.VyperNode, mutability: StateMutability) -> None: """ Validate an attempt to modify this value. 
@@ -312,18 +353,24 @@ def validate_modification(self, node: vy_ast.VyperNode, mutability: StateMutabil if self.location == DataLocation.CALLDATA: raise ImmutableViolation("Cannot write to calldata", node) + func_node = node.get_ancestor(vy_ast.FunctionDef) + + assert self._var_info is not None # mypy hint + assert isinstance(func_node, vy_ast.FunctionDef) # mypy hint + + func_t = func_node._metadata["func_type"] + if self.modifiability == Modifiability.RUNTIME_CONSTANT: - if self.location == DataLocation.CODE: - if node.get_ancestor(vy_ast.FunctionDef).get("name") != "__init__": + # special handling for immutable variables in the ctor + if self.location == DataLocation.IMMUTABLES: + if func_node.name != "__init__": raise ImmutableViolation("Immutable value cannot be written to", node) - # special handling for immutable variables in the ctor - # TODO: we probably want to remove this restriction. - if self.var_info._modification_count: # type: ignore + # we may consider removing this restriction. 
+ if len(self._var_info._writes) > 0: raise ImmutableViolation( "Immutable value cannot be modified after assignment", node ) - self.var_info._modification_count += 1 # type: ignore else: raise ImmutableViolation("Environment variable cannot be written to", node) @@ -332,3 +379,8 @@ def validate_modification(self, node: vy_ast.VyperNode, mutability: StateMutabil if isinstance(node, vy_ast.AugAssign): self.typ.validate_numeric_op(node) + + # tag it in the metadata + node._metadata["variable_write"] = self._var_info + self._var_info._writes.append(node) + func_t._variable_writes.append(node) diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 88679a4b09..5acc209321 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -1,16 +1,10 @@ -# TODO this module doesn't really belong in "validation" -from typing import Dict, List - from vyper import ast as vy_ast -from vyper.exceptions import StorageLayoutException -from vyper.semantics.analysis.base import CodeOffset, StorageSlot +from vyper.exceptions import CompilerPanic, StorageLayoutException +from vyper.semantics.analysis.base import CodeOffset, ModuleVarInfo, StorageSlot from vyper.typing import StorageLayout -from vyper.utils import ceil32 -def set_data_positions( - vyper_module: vy_ast.Module, storage_layout_overrides: StorageLayout = None -) -> StorageLayout: +def allocate_variables(vyper_module: vy_ast.Module) -> StorageLayout: """ Parse the annotated Vyper AST, determine data positions for all variables, and annotate the AST nodes with the position data. @@ -20,131 +14,21 @@ def set_data_positions( vyper_module : vy_ast.Module Top-level Vyper AST node that has already been annotated with type data. 
""" - code_offsets = set_code_offsets(vyper_module) - storage_slots = ( - set_storage_slots_with_overrides(vyper_module, storage_layout_overrides) - if storage_layout_overrides is not None - else set_storage_slots(vyper_module) - ) + code_offsets = _set_code_offsets(vyper_module) + storage_slots = _set_storage_slots(vyper_module) return {"storage_layout": storage_slots, "code_layout": code_offsets} -class StorageAllocator: - """ - Keep track of which storage slots have been used. If there is a collision of - storage slots, this will raise an error and fail to compile - """ - - def __init__(self): - self.occupied_slots: Dict[int, str] = {} - - def reserve_slot_range(self, first_slot: int, n_slots: int, var_name: str) -> None: - """ - Reserves `n_slots` storage slots, starting at slot `first_slot` - This will raise an error if a storage slot has already been allocated. - It is responsibility of calling function to ensure first_slot is an int - """ - list_to_check = [x + first_slot for x in range(n_slots)] - self._reserve_slots(list_to_check, var_name) - - def _reserve_slots(self, slots: List[int], var_name: str) -> None: - for slot in slots: - self._reserve_slot(slot, var_name) - - def _reserve_slot(self, slot: int, var_name: str) -> None: - if slot < 0 or slot >= 2**256: - raise StorageLayoutException( - f"Invalid storage slot for var {var_name}, out of bounds: {slot}" - ) - if slot in self.occupied_slots: - collided_var = self.occupied_slots[slot] - raise StorageLayoutException( - f"Storage collision! Tried to assign '{var_name}' to slot {slot} but it has " - f"already been reserved by '{collided_var}'" - ) - self.occupied_slots[slot] = var_name - - -def set_storage_slots_with_overrides( - vyper_module: vy_ast.Module, storage_layout_overrides: StorageLayout -) -> StorageLayout: - """ - Parse module-level Vyper AST to calculate the layout of storage variables. 
- Returns the layout as a dict of variable name -> variable info - """ - - ret: Dict[str, Dict] = {} - reserved_slots = StorageAllocator() - - # Search through function definitions to find non-reentrant functions - for node in vyper_module.get_children(vy_ast.FunctionDef): - type_ = node._metadata["func_type"] - - # Ignore functions without non-reentrant - if type_.nonreentrant is None: - continue - - variable_name = f"nonreentrant.{type_.nonreentrant}" - - # re-entrant key was already identified - if variable_name in ret: - _slot = ret[variable_name]["slot"] - type_.set_reentrancy_key_position(StorageSlot(_slot)) - continue - - # Expect to find this variable within the storage layout override - if variable_name in storage_layout_overrides: - reentrant_slot = storage_layout_overrides[variable_name]["slot"] - # Ensure that this slot has not been used, and prevents other storage variables - # from using the same slot - reserved_slots.reserve_slot_range(reentrant_slot, 1, variable_name) - - type_.set_reentrancy_key_position(StorageSlot(reentrant_slot)) - - ret[variable_name] = {"type": "nonreentrant lock", "slot": reentrant_slot} - else: - raise StorageLayoutException( - f"Could not find storage_slot for {variable_name}. 
" - "Have you used the correct storage layout file?", - node, - ) - - # Iterate through variables - for node in vyper_module.get_children(vy_ast.VariableDecl): - # Ignore immutable parameters - if node.get("annotation.func.id") == "immutable": - continue - - varinfo = node.target._metadata["varinfo"] - - # Expect to find this variable within the storage layout overrides - if node.target.id in storage_layout_overrides: - var_slot = storage_layout_overrides[node.target.id]["slot"] - storage_length = varinfo.typ.storage_size_in_words - # Ensure that all required storage slots are reserved, and prevents other variables - # from using these slots - reserved_slots.reserve_slot_range(var_slot, storage_length, node.target.id) - varinfo.set_position(StorageSlot(var_slot)) - - ret[node.target.id] = {"type": str(varinfo.typ), "slot": var_slot} - else: - raise StorageLayoutException( - f"Could not find storage_slot for {node.target.id}. " - "Have you used the correct storage layout file?", - node, - ) - - return ret - +class SimpleAllocator: + _max_slots: int = None # type: ignore -class SimpleStorageAllocator: def __init__(self, starting_slot: int = 0): self._slot = starting_slot - def allocate_slot(self, n, var_name): + def allocate(self, n, var_name=""): ret = self._slot - if self._slot + n >= 2**256: + if self._slot + n >= self._max_slots: raise StorageLayoutException( f"Invalid storage slot for var {var_name}, tried to allocate" f" slots {self._slot} through {self._slot + n}" @@ -153,7 +37,15 @@ def allocate_slot(self, n, var_name): return ret -def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: +class SimpleStorageAllocator(SimpleAllocator): + _max_slots = 2**256 + + +class SimpleImmutablesAllocator(SimpleAllocator): + _max_slots = 0x6000 # eip-170 + + +def _set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: """ Parse module-level Vyper AST to calculate the layout of storage variables. 
Returns the layout as a dict of variable name -> variable info @@ -162,81 +54,95 @@ def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout: # note storage is word-addressable, not byte-addressable allocator = SimpleStorageAllocator() - ret: Dict[str, Dict] = {} + ret: dict[str, dict] = {} - for node in vyper_module.get_children(vy_ast.FunctionDef): - type_ = node._metadata["func_type"] + for funcdef in vyper_module.get_children(vy_ast.FunctionDef): + type_ = funcdef._metadata["func_type"] if type_.nonreentrant is None: continue - variable_name = f"nonreentrant.{type_.nonreentrant}" + keyname = f"nonreentrant.{type_.nonreentrant}" # a nonreentrant key can appear many times in a module but it # only takes one slot. after the first time we see it, do not # increment the storage slot. - if variable_name in ret: - _slot = ret[variable_name]["slot"] + if keyname in ret: + _slot = ret[keyname]["slot"] type_.set_reentrancy_key_position(StorageSlot(_slot)) continue # TODO use one byte - or bit - per reentrancy key # requires either an extra SLOAD or caching the value of the # location in memory at entrance - slot = allocator.allocate_slot(1, variable_name) + slot = allocator.allocate(1, keyname) type_.set_reentrancy_key_position(StorageSlot(slot)) # TODO this could have better typing but leave it untyped until # we nail down the format better - ret[variable_name] = {"type": "nonreentrant lock", "slot": slot} + ret[keyname] = {"type": "nonreentrant lock", "slot": slot} - for node in vyper_module.get_children(vy_ast.VariableDecl): + for varinfo in vyper_module._metadata["type"].variables.values(): # skip non-storage variables - if node.is_constant or node.is_immutable: + if varinfo.is_constant or varinfo.is_immutable: continue - varinfo = node.target._metadata["varinfo"] type_ = varinfo.typ + vardecl = varinfo.decl_node + assert isinstance(vardecl, vy_ast.VariableDecl) + + varname = vardecl.target.id + # CMC 2021-07-23 note that HashMaps get assigned a slot here. 
# I'm not sure if it's safe to avoid allocating that slot # for HashMaps because downstream code might use the slot # ID as a salt. - n_slots = type_.storage_size_in_words - slot = allocator.allocate_slot(n_slots, node.target.id) + n_slots = type_.storage_slots_required + slot = allocator.allocate(n_slots, varname) - varinfo.set_position(StorageSlot(slot)) + varinfo.set_storage_position(StorageSlot(slot)) + assert varname not in ret # this could have better typing but leave it untyped until # we understand the use case better - ret[node.target.id] = {"type": str(type_), "slot": slot} + ret[varname] = {"type": str(type_), "slot": slot} return ret -def set_calldata_offsets(fn_node: vy_ast.FunctionDef) -> None: - pass +def _set_code_offsets(vyper_module: vy_ast.Module) -> dict[str, dict]: + ret = {} + allocator = SimpleImmutablesAllocator() + for varinfo in vyper_module._metadata["type"].variables.values(): + type_ = varinfo.typ -def set_memory_offsets(fn_node: vy_ast.FunctionDef) -> None: - pass + if not varinfo.is_immutable and not isinstance(varinfo, ModuleVarInfo): + continue + len_ = type_.immutable_bytes_required -def set_code_offsets(vyper_module: vy_ast.Module) -> Dict: - ret = {} - offset = 0 + # sanity check. there are ways to construct varinfo with no + # decl_node but they shouldn't make it to here + vardecl = varinfo.decl_node + assert isinstance(vardecl, vy_ast.VariableDecl) + varname = vardecl.target.id - for node in vyper_module.get_children(vy_ast.VariableDecl, filters={"is_immutable": True}): - varinfo = node.target._metadata["varinfo"] - type_ = varinfo.typ - varinfo.set_position(CodeOffset(offset)) + if len_ % 32 != 0: + # sanity check length is a multiple of 32, it's an invariant + # that is used a lot in downstream code. 
+ raise CompilerPanic("bad invariant") - len_ = ceil32(type_.size_in_bytes) + offset = allocator.allocate(len_, varname) + varinfo.set_immutables_position(CodeOffset(offset)) # this could have better typing but leave it untyped until # we understand the use case better - ret[node.target.id] = {"type": str(type_), "offset": offset, "length": len_} + output_dict = {"type": str(type_), "offset": offset, "length": len_} - offset += len_ + # put it into the storage layout + assert varname not in ret + ret[varname] = output_dict return ret diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 91fb2c21f0..df302c9645 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -171,6 +171,7 @@ def _validate_pure_access(node: vy_ast.Attribute, typ: VyperType) -> None: def _validate_self_reference(node: vy_ast.Name) -> None: # CMC 2023-10-19 this detector seems sus, things like `a.b(self)` could slip through + # TODO: this is now wrong, we can have things like `self.module.foo` if node.id == "self" and not isinstance(node.get_ancestor(), vy_ast.Attribute): raise StateAccessViolation("not allowed to query self in pure functions", node) @@ -197,7 +198,7 @@ def analyze(self): for arg in self.func.arguments: self.namespace[arg.name] = VarInfo( - arg.typ, location=location, modifiability=modifiability + arg.typ, _location=location, modifiability=modifiability ) for node in self.fn_node.body: @@ -264,17 +265,17 @@ def _assign_helper(self, node): if isinstance(node.value, vy_ast.Tuple): raise StructureException("Right-hand side of assignment cannot be a tuple", node.value) - target = get_expr_info(node.target) - if isinstance(target.typ, HashMapT): + target_info = get_expr_info(node.target) + if isinstance(target_info.typ, HashMapT): raise StructureException( "Left-hand side of assignment cannot be a HashMap without a key", node ) - validate_expected_type(node.value, target.typ) - target.validate_modification(node, 
self.func.mutability) + validate_expected_type(node.value, target_info.typ) + target_info.validate_modification(node, self.func.mutability) - self.expr_visitor.visit(node.value, target.typ) - self.expr_visitor.visit(node.target, target.typ) + self.expr_visitor.visit(node.value, target_info.typ) + self.expr_visitor.visit(node.target, target_info.typ) def visit_Assign(self, node): self._assign_helper(node) @@ -549,6 +550,13 @@ def visit(self, node, typ): # annotate node._metadata["type"] = typ + # tag variable accesses + info = get_expr_info(node) + if (var_info := info._var_info) is not None: + node._metadata["variable_access"] = var_info + var_info._reads.append(node) + self.func._variable_reads.append(node) + # validate and annotate folded value if node.has_folded_value: folded_node = node.get_folded_value() diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 8e435f870f..232a771c4d 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -23,6 +23,7 @@ ) from vyper.semantics.analysis.base import ImportInfo, Modifiability, ModuleInfo, VarInfo from vyper.semantics.analysis.common import VyperNodeVisitorBase +from vyper.semantics.analysis.data_positions import allocate_variables from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import ExprVisitor, validate_functions from vyper.semantics.analysis.pre_typecheck import pre_typecheck @@ -39,6 +40,7 @@ from vyper.semantics.types.utils import type_from_annotation +# TODO: rename to `analyze_vyper` def validate_semantics(module_ast, input_bundle, is_interface=False) -> ModuleT: return validate_semantics_r(module_ast, input_bundle, ImportGraph(), is_interface) @@ -71,6 +73,9 @@ def validate_semantics_r( if not is_interface: validate_functions(module_ast) + layout = allocate_variables(module_ast) + module_ast._metadata["variables_layout"] = layout + return ret @@ -198,7 +203,9 @@ def 
analyze_call_graph(self): # we just want to be able to construct the call graph. continue - if isinstance(call_t, ContractFunctionT) and call_t.is_internal: + if isinstance(call_t, ContractFunctionT) and ( + call_t.is_internal or call_t.is_constructor + ): fn_t.called_functions.add(call_t) for func in function_defs: @@ -256,7 +263,7 @@ def visit_VariableDecl(self, node): raise SyntaxException(message, node.node_source_code, node.lineno, node.col_offset) data_loc = ( - DataLocation.CODE + DataLocation.IMMUTABLES if node.is_immutable else DataLocation.UNSET if node.is_constant @@ -278,16 +285,23 @@ def visit_VariableDecl(self, node): if node.is_transient and not version_check(begin="cancun"): raise StructureException("`transient` is not available pre-cancun", node.annotation) - var_info = VarInfo( + if isinstance(type_, ModuleT): + var_type = ModuleVarInfo + else: + var_type = VarInfo + + var_info = var_type( type_, decl_node=node, - location=data_loc, + _location=data_loc, modifiability=modifiability, is_public=node.is_public, ) + node.target._metadata["varinfo"] = var_info # TODO maybe put this in the global namespace node._metadata["type"] = type_ + # TODO: maybe this code can be removed def _finalize(): # add the variable name to `self` namespace if the variable is either # 1. a public constant or immutable; or @@ -405,7 +419,7 @@ def _add_import( ) self.namespace[alias] = module_info - # load an InterfaceT or ModuleInfo from an import. + # load an InterfaceT or ModuleT from an import. 
# raises FileNotFoundError def _load_import(self, node: vy_ast.VyperNode, level: int, module_str: str, alias: str) -> Any: # the directory this (currently being analyzed) module is in @@ -446,7 +460,7 @@ def _load_import_helper( is_interface=False, ) - return ModuleInfo(module_t) + return module_t except FileNotFoundError as e: # escape `e` from the block scope, it can make things diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index ba1b02b8d6..68c16243c2 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -17,7 +17,7 @@ ZeroDivisionException, ) from vyper.semantics import types -from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleVarInfo, VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType @@ -71,9 +71,6 @@ def get_expr_info(self, node: vy_ast.VyperNode) -> ExprInfo: if isinstance(info, VarInfo): return ExprInfo.from_varinfo(info) - if isinstance(info, ModuleInfo): - return ExprInfo.from_moduleinfo(info) - raise CompilerPanic("unreachable!", node) if isinstance(node, vy_ast.Attribute): diff --git a/vyper/semantics/data_locations.py b/vyper/semantics/data_locations.py index cecea35a60..39a3fa346a 100644 --- a/vyper/semantics/data_locations.py +++ b/vyper/semantics/data_locations.py @@ -2,9 +2,13 @@ class DataLocation(enum.Enum): - UNSET = 0 - MEMORY = 1 - STORAGE = 2 - CALLDATA = 3 - CODE = 4 - TRANSIENT = 5 + # TODO: rename me to something like VarLocation, or StorageRegion + """ + Possible locations for variables in vyper + """ + UNSET = enum.auto() # like constants and stack variables + MEMORY = enum.auto() # local variables + STORAGE = enum.auto() # storage variables + CALLDATA = enum.auto() # arguments to external functions + 
IMMUTABLES = enum.auto() # immutable variables + TRANSIENT = enum.auto() # transient storage variables diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py index 880857ccb8..a8586a62dd 100644 --- a/vyper/semantics/types/__init__.py +++ b/vyper/semantics/types/__init__.py @@ -2,7 +2,7 @@ from .base import TYPE_T, KwargSettings, VyperType, is_type_t from .bytestrings import BytesT, StringT, _BytestringT from .function import MemberFunctionT -from .module import InterfaceT +from .module import InterfaceT, ModuleT from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT from .subscriptable import DArrayT, HashMapT, SArrayT, TupleT from .user import EventT, FlagT, StructT diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 429ba807e1..10a27fb2e8 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -13,6 +13,7 @@ UnknownAttribute, ) from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions +from vyper.semantics.data_locations import DataLocation # Some fake type with an overridden `compare_type` which accepts any RHS @@ -68,7 +69,11 @@ class VyperType: _supports_external_calls: bool = False _attribute_in_annotation: bool = False - size_in_bytes = 32 # default; override for larger types + # _size_in_bytes is an internal property that is used + # to calculate sizes required in various locations. it can + # be used by subclasses, but does not have to be. it should + # *not* be used by external consumers of VyperType!
+ _size_in_bytes = 32 # default; override for larger types def __init__(self, members: Optional[Dict] = None) -> None: self.members: Dict = {} @@ -118,24 +123,56 @@ def abi_type(self) -> ABIType: """ raise CompilerPanic("Method must be implemented by the inherited class") + # return the size in bytes or slots that this type + # needs to allocate in the provided location + def size_in_location(self, location): + if location in self._invalid_locations: + raise CompilerPanic(f"{self} cannot be instantiated in {location}!") + + if location == DataLocation.MEMORY: + return self.memory_bytes_required + if location == DataLocation.IMMUTABLES: + return self.immutable_bytes_required + if location == DataLocation.STORAGE: + return self.storage_slots_required + + raise CompilerPanic(f"invalid location: {location}") # pragma: nocover + @property def memory_bytes_required(self) -> int: + if DataLocation.MEMORY in self._invalid_locations: + raise CompilerPanic(f"{self} cannot be instantiated in memory!") # alias for API compatibility with codegen - return self.size_in_bytes + return self._size_in_bytes @property - def storage_size_in_words(self) -> int: + def storage_slots_required(self) -> int: # consider renaming if other word-addressable address spaces are # added to EVM or exist in other arches """ Returns the number of words required to allocate in storage for this type """ - r = self.memory_bytes_required + if DataLocation.STORAGE in self._invalid_locations: + raise CompilerPanic(f"{self} cannot be instantiated in storage!") + + r = self._size_in_bytes if r % 32 != 0: raise CompilerPanic("Memory bytes must be multiple of 32") return r // 32 + @property + def immutable_bytes_required(self) -> int: + """ + Returns the number of bytes required when instantiating this type + in the immutables section + """ + # sanity check the type can actually be instantiated as an immutable + if DataLocation.IMMUTABLES in self._invalid_locations: + raise CompilerPanic(f"{self} cannot be an 
immutable!") + + return self._size_in_bytes + @property def canonical_abi_type(self) -> str: """ diff --git a/vyper/semantics/types/bytestrings.py b/vyper/semantics/types/bytestrings.py index e3c381ac69..dee7a8e9be 100644 --- a/vyper/semantics/types/bytestrings.py +++ b/vyper/semantics/types/bytestrings.py @@ -67,7 +67,7 @@ def validate_literal(self, node: vy_ast.Constant) -> None: raise CompilerPanic("unreachable") @property - def size_in_bytes(self): + def _size_in_bytes(self): # the first slot (32 bytes) stores the actual length, and then we reserve # enough additional slots to store the data if it uses the max available length # because this data type is single-bytes, we make it so it takes the max 32 byte diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 7c77560e49..46989602e2 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -112,10 +112,34 @@ def __init__( # recursively reachable from this function self.reachable_internal_functions: OrderedSet[ContractFunctionT] = OrderedSet() + # list of variables read in this function + self._variable_reads: list[vy_ast.VyperNode] = [] + # list of variables written in this function + self._variable_writes: list[vy_ast.VyperNode] = [] + # to be populated during codegen self._ir_info: Any = None self._function_id: Optional[int] = None + def touches_location(self, location): + for r in self._variable_reads: + if r._metadata["variable_access"].location == location: + return True + for w in self._variable_writes: + if w._metadata["variable_write"].location == location: + return True + return False + + @property + def touched_locations(self): + # return the DataLocation of touched module variables + ret = [] + possible_locations = (DataLocation.STORAGE, DataLocation.IMMUTABLES) + for location in possible_locations: + if self.touches_location(location): + ret.append(location) + return ret + @cached_property def call_site_kwargs(self): # special kwargs 
that are allowed in call site @@ -272,6 +296,8 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": "function body in an interface can only be ...!", funcdef ) + assert function_visibility is not None # mypy hint + return cls( funcdef.name, positional_args, @@ -315,13 +341,15 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": ) if funcdef.name == "__init__": - if ( - state_mutability in (StateMutability.PURE, StateMutability.VIEW) - or function_visibility == FunctionVisibility.INTERNAL - ): + if state_mutability in (StateMutability.PURE, StateMutability.VIEW): + raise FunctionDeclarationException( + "Constructor cannot be marked as `@pure` or `@view`", funcdef + ) + if function_visibility is not None: raise FunctionDeclarationException( - "Constructor cannot be marked as `@pure`, `@view` or `@internal`", funcdef + "Constructor cannot be marked as `@internal` or `@external`", funcdef ) + function_visibility = FunctionVisibility.CONSTRUCTOR if return_type is not None: raise FunctionDeclarationException( "Constructor may not have a return type", funcdef.returns @@ -333,6 +361,9 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": "Constructor may not use default arguments", funcdef.args.defaults[0] ) + # sanity check + assert function_visibility is not None + return cls( funcdef.name, positional_args, @@ -350,7 +381,7 @@ def set_reentrancy_key_position(self, position: StorageSlot) -> None: if self.nonreentrant is None: raise CompilerPanic(f"No reentrant key {self}") # sanity check even though implied by the type - if position._location != DataLocation.STORAGE: + if position.location != DataLocation.STORAGE: raise CompilerPanic("Non-storage reentrant key") self.reentrancy_key_position = position @@ -456,6 +487,10 @@ def is_external(self) -> bool: def is_internal(self) -> bool: return self.visibility == FunctionVisibility.INTERNAL + @property + def is_constructor(self) -> bool: + return 
self.visibility == FunctionVisibility.CONSTRUCTOR + @property def is_mutable(self) -> bool: return self.mutability > StateMutability.VIEW @@ -464,10 +499,6 @@ def is_mutable(self) -> bool: def is_payable(self) -> bool: return self.mutability == StateMutability.PAYABLE - @property - def is_constructor(self) -> bool: - return self.name == "__init__" - @property def is_fallback(self) -> bool: return self.name == "__default__" @@ -601,7 +632,7 @@ def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]: def _parse_decorators( funcdef: vy_ast.FunctionDef, -) -> tuple[FunctionVisibility, StateMutability, Optional[str]]: +) -> tuple[Optional[FunctionVisibility], StateMutability, Optional[str]]: function_visibility = None state_mutability = None nonreentrant_key = None @@ -656,7 +687,7 @@ def _parse_decorators( else: raise StructureException("Bad decorator syntax", decorator) - if function_visibility is None: + if function_visibility is None and funcdef.name != "__init__": raise FunctionDeclarationException( f"Visibility must be set to one of: {', '.join(FunctionVisibility.values())}", funcdef ) diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index b0d7800011..a1d985e271 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -7,6 +7,7 @@ from vyper.exceptions import InterfaceViolation, NamespaceCollision, StructureException from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids +from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType from vyper.semantics.types.function import ContractFunctionT @@ -253,6 +254,16 @@ def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT": # Datatype to store all module information. 
class ModuleT(VyperType): + _attribute_in_annotation = True + + # disallow everything but storage + _invalid_locations = ( + DataLocation.UNSET, + DataLocation.CALLDATA, + DataLocation.IMMUTABLES, + DataLocation.MEMORY, + ) + def __init__(self, module: vy_ast.Module, name: Optional[str] = None): super().__init__() @@ -292,8 +303,18 @@ def __eq__(self, other): def __hash__(self): return hash(id(self)) - def get_type_member(self, key: str, node: vy_ast.VyperNode) -> "VyperType": - return self._helper.get_member(key, node) + def __repr__(self): + resolved_path = self._module.resolved_path + if self._id == resolved_path: + return f"module {self._id}" + else: + return f"module {self._id} (loaded from '{self._module.resolved_path}')" + + def offset_of(self, attr: str, location: DataLocation): + pass + + def get_type_member(self, attr: str, node: vy_ast.VyperNode) -> VyperType: + return self._helper.get_type_member(attr, node) # this is a property, because the function set changes after AST expansion @property @@ -322,13 +343,21 @@ def variables(self): # `x: uint256` is a private storage variable named x return {s.target.id: s.target._metadata["varinfo"] for s in self.variable_decls} + @cached_property + def storage_slots_required(self): + return sum(v.typ.storage_slots_required for v in self.variables.values()) + @cached_property def immutables(self): return [t for t in self.variables.values() if t.is_immutable] @cached_property - def immutable_section_bytes(self): - return sum([imm.typ.memory_bytes_required for imm in self.immutables]) + def immutable_bytes_required(self): + # note: super().immutable_bytes_required checks that + # `DataLocation.IMMUTABLES not in self._invalid_locations`; this is ok because + # ModuleT is a bit of a hybrid - it can't be declared as an immutable, but + # it can have immutable members. 
+ return sum(imm.typ.immutable_bytes_required for imm in self.immutables) @cached_property def interface(self): diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 55ffc23b2f..01a88cdf8c 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -49,7 +49,7 @@ class HashMapT(_SubscriptableT): _invalid_locations = ( DataLocation.UNSET, DataLocation.CALLDATA, - DataLocation.CODE, + DataLocation.IMMUTABLES, DataLocation.MEMORY, ) @@ -174,10 +174,9 @@ def to_abi_arg(self, name: str = "") -> Dict[str, Any]: ret["type"] += f"[{self.length}]" return _set_first_key(ret, "name", name) - # TODO rename to `memory_bytes_required` @property - def size_in_bytes(self): - return self.value_type.size_in_bytes * self.length + def _size_in_bytes(self): + return self.value_type._size_in_bytes * self.length @property def subtype(self): @@ -257,11 +256,10 @@ def to_abi_arg(self, name: str = "") -> Dict[str, Any]: ret["type"] += "[]" return _set_first_key(ret, "name", name) - # TODO rename me to memory_bytes_required @property - def size_in_bytes(self): + def _size_in_bytes(self): # one length word + size of the array items - return 32 + self.value_type.size_in_bytes * self.length + return 32 + self.value_type._size_in_bytes * self.length def compare_type(self, other): # TODO allow static array to be assigned to dyn array? 
@@ -357,8 +355,8 @@ def to_abi_arg(self, name: str = "") -> dict: return {"name": name, "type": "tuple", "components": components} @property - def size_in_bytes(self): - return sum(i.size_in_bytes for i in self.member_types) + def _size_in_bytes(self): + return sum(i._size_in_bytes for i in self.member_types) def validate_index_type(self, node): if not isinstance(node, vy_ast.Int): diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index a4e782349d..45fe8bbcc6 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -355,11 +355,12 @@ def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT": return cls(struct_name, members, ast_def=base_node) def __repr__(self): - return f"{self._id} declaration object" + arg_types = ",".join(repr(t) for t in self.members.values()) + return f"struct {self._id}({arg_types})" @property - def size_in_bytes(self): - return sum(i.size_in_bytes for i in self.member_types.values()) + def _size_in_bytes(self): + return sum(i._size_in_bytes for i in self.member_types.values()) @property def abi_type(self) -> ABIType: diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py index eb96375404..428557d322 100644 --- a/vyper/semantics/types/utils.py +++ b/vyper/semantics/types/utils.py @@ -123,18 +123,14 @@ def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType: raise InvalidType(err_msg, node) try: - module_or_interface = namespace[node.value.id] # type: ignore + module_or_interface = namespace[node.value.id] except UndeclaredDefinition: raise InvalidType(err_msg, node) from None - interface = module_or_interface - if hasattr(module_or_interface, "module_t"): # i.e., it's a ModuleInfo - interface = module_or_interface.module_t.interface - - if not interface._attribute_in_annotation: + if not module_or_interface._attribute_in_annotation: raise InvalidType(err_msg, node) - type_t = interface.get_type_member(node.attr, node) + type_t = 
module_or_interface.get_type_member(node.attr, node) assert isinstance(type_t, TYPE_T) # sanity check return type_t.typedef