From b67d361e3b437ffc8808c9d87cf23953c8d5b3a3 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Sun, 14 Jan 2024 14:53:15 -0500 Subject: [PATCH] get __init__() working both as entry point and as internal function refactor: - refactor generate_ir_for_function into generate_ir_for_external_function and generate_ir_for_internal_function - move get_nonreentrant_lock to function_definitions/common.py --- vyper/codegen/context.py | 6 +- .../codegen/function_definitions/__init__.py | 5 +- vyper/codegen/function_definitions/common.py | 114 ++++++++---------- .../function_definitions/external_function.py | 45 +++++-- .../function_definitions/internal_function.py | 29 +++-- vyper/codegen/function_definitions/utils.py | 31 ----- vyper/codegen/module.py | 13 +- vyper/semantics/analysis/module.py | 5 +- 8 files changed, 131 insertions(+), 117 deletions(-) delete mode 100644 vyper/codegen/function_definitions/utils.py diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py index 316ebecc8f..d404f8d8b5 100644 --- a/vyper/codegen/context.py +++ b/vyper/codegen/context.py @@ -44,7 +44,7 @@ def __repr__(self): return f"VariableRecord({ret})" -# Contains arguments, variables, etc +# compilation context for a function class Context: def __init__( self, @@ -83,6 +83,10 @@ def __init__( # Not intended to be accessed directly self.memory_allocator = memory_allocator + # save the starting memory location so we can find out (later) + # how much memory this function uses. 
+ self.starting_memory = memory_allocator.next_mem + # Incremented values, used for internal IDs self._internal_var_iter = 0 self._scope_id_iter = 0 diff --git a/vyper/codegen/function_definitions/__init__.py b/vyper/codegen/function_definitions/__init__.py index 94617bef35..6bc1254fed 100644 --- a/vyper/codegen/function_definitions/__init__.py +++ b/vyper/codegen/function_definitions/__init__.py @@ -1 +1,4 @@ -from .common import FuncIR, generate_ir_for_function # noqa +from .external_function import generate_ir_for_external_function +from .internal_function import generate_ir_for_internal_function + +__all__ = ["generate_ir_for_internal_function", "generate_ir_for_external_function"] diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py index 09f969305f..d017ba7b81 100644 --- a/vyper/codegen/function_definitions/common.py +++ b/vyper/codegen/function_definitions/common.py @@ -2,17 +2,14 @@ from functools import cached_property from typing import Optional -import vyper.ast as vy_ast from vyper.codegen.context import Constancy, Context -from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function -from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function from vyper.codegen.ir_node import IRnode from vyper.codegen.memory_allocator import MemoryAllocator -from vyper.exceptions import CompilerPanic +from vyper.evm.opcodes import version_check from vyper.semantics.types import VyperType -from vyper.semantics.types.function import ContractFunctionT +from vyper.semantics.types.function import ContractFunctionT, StateMutability from vyper.semantics.types.module import ModuleT -from vyper.utils import MemoryPositions, calc_mem_gas +from vyper.utils import MemoryPositions @dataclass @@ -53,14 +50,16 @@ def ir_identifier(self) -> str: return f"{self.visibility} {function_id} {name}({argz})" def set_frame_info(self, frame_info: FrameInfo) -> None: + # 
XXX: when can this happen? if self.frame_info is not None: - raise CompilerPanic(f"frame_info already set for {self.func_t}!") - self.frame_info = frame_info + assert frame_info == self.frame_info + else: + self.frame_info = frame_info @property # common entry point for external function with kwargs def external_function_base_entry_label(self) -> str: - assert self.func_t.is_external, "uh oh, should be external" + assert not self.func_t.is_internal, "uh oh, should be external" return self.ir_identifier + "_common" def internal_function_label(self, is_ctor_context: bool = False) -> str: @@ -75,10 +74,6 @@ def internal_function_label(self, is_ctor_context: bool = False) -> str: return self.ir_identifier + suffix -class FuncIR: - pass - - @dataclass class EntryPointInfo: func_t: ContractFunctionT @@ -86,7 +81,7 @@ class EntryPointInfo: ir_node: IRnode # the ir for this entry point def __post_init__(self): - # ABI v2 property guaranteed by the spec. + # sanity check ABI v2 properties guaranteed by the spec. # https://docs.soliditylang.org/en/v0.8.21/abi-spec.html#formal-specification-of-the-encoding states: # noqa: E501 # > Note that for any X, len(enc(X)) is a multiple of 32. assert self.min_calldatasize >= 4 @@ -94,34 +89,28 @@ def __post_init__(self): @dataclass -class ExternalFuncIR(FuncIR): +class ExternalFuncIR: entry_points: dict[str, EntryPointInfo] # map from abi sigs to entry points common_ir: IRnode # the "common" code for the function @dataclass -class InternalFuncIR(FuncIR): +class InternalFuncIR: func_ir: IRnode # the code for the function -# TODO: should split this into external and internal ir generation? 
-def generate_ir_for_function( - code: vy_ast.FunctionDef, module_ctx: ModuleT, is_ctor_context: bool = False -) -> FuncIR: - """ - Parse a function and produce IR code for the function, includes: - - Signature method if statement - - Argument handling - - Clamping and copying of arguments - - Function body - """ - func_t = code._metadata["func_type"] - - # generate _FuncIRInfo +def init_ir_info(func_t: ContractFunctionT): + # initialize IRInfo on the function func_t._ir_info = _FuncIRInfo(func_t) - callees = func_t.called_functions +def initialize_context( + func_t: ContractFunctionT, module_ctx: ModuleT, is_ctor_context: bool = False +): + init_ir_info(func_t) + + # calculate starting frame + callees = func_t.called_functions # we start our function frame from the largest callee frame max_callee_frame_size = 0 for c_func_t in callees: @@ -132,7 +121,7 @@ def generate_ir_for_function( memory_allocator = MemoryAllocator(allocate_start) - context = Context( + return Context( vars_=None, module_ctx=module_ctx, memory_allocator=memory_allocator, @@ -141,38 +130,41 @@ def generate_ir_for_function( is_ctor_context=is_ctor_context, ) - if func_t.is_internal or func_t.is_constructor: - ret: FuncIR = InternalFuncIR(generate_ir_for_internal_function(code, func_t, context)) - func_t._ir_info.gas_estimate = ret.func_ir.gas # type: ignore - else: - kwarg_handlers, common = generate_ir_for_external_function(code, func_t, context) - entry_points = { - k: EntryPointInfo(func_t, mincalldatasize, ir_node) - for k, (mincalldatasize, ir_node) in kwarg_handlers.items() - } - ret = ExternalFuncIR(entry_points, common) - # note: this ignores the cost of traversing selector table - func_t._ir_info.gas_estimate = ret.common_ir.gas +def tag_frame_info(func_t, context): frame_size = context.memory_allocator.size_of_mem - MemoryPositions.RESERVED_MEMORY + frame_start = context.starting_memory - frame_info = FrameInfo(allocate_start, frame_size, context.vars) + frame_info = 
FrameInfo(frame_start, frame_size, context.vars) + func_t._ir_info.set_frame_info(frame_info) - # XXX: when can this happen? - if func_t._ir_info.frame_info is None: - func_t._ir_info.set_frame_info(frame_info) - else: - assert frame_info == func_t._ir_info.frame_info - - if func_t.is_external: - # adjust gas estimate to include cost of mem expansion - # frame_size of external function includes all private functions called - # (note: internal functions do not need to adjust gas estimate since - mem_expansion_cost = calc_mem_gas(func_t._ir_info.frame_info.mem_used) # type: ignore - ret.common_ir.add_gas_estimate += mem_expansion_cost # type: ignore - ret.common_ir.passthrough_metadata["func_t"] = func_t # type: ignore - ret.common_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore + return frame_info + + +def get_nonreentrant_lock(func_t): + if not func_t.nonreentrant: + return ["pass"], ["pass"] + + nkey = func_t.reentrancy_key_position.position + + LOAD, STORE = "sload", "sstore" + if version_check(begin="cancun"): + LOAD, STORE = "tload", "tstore" + + if version_check(begin="berlin"): + # any nonzero values would work here (see pricing as of net gas + # metering); these values are chosen so that downgrading to the + # 0,1 scheme (if it is somehow necessary) is safe. 
+ final_value, temp_value = 3, 2 else: - ret.func_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore + final_value, temp_value = 0, 1 - return ret + check_notset = ["assert", ["ne", temp_value, [LOAD, nkey]]] + + if func_t.mutability == StateMutability.VIEW: + return [check_notset], [["seq"]] + + else: + pre = ["seq", check_notset, [STORE, nkey, temp_value]] + post = [STORE, nkey, final_value] + return [pre], [post] diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py index 65276469e7..5539c78e96 100644 --- a/vyper/codegen/function_definitions/external_function.py +++ b/vyper/codegen/function_definitions/external_function.py @@ -2,12 +2,19 @@ from vyper.codegen.context import Context, VariableRecord from vyper.codegen.core import get_element_ptr, getpos, make_setter, needs_clamp from vyper.codegen.expr import Expr -from vyper.codegen.function_definitions.utils import get_nonreentrant_lock +from vyper.codegen.function_definitions.common import ( + EntryPointInfo, + ExternalFuncIR, + get_nonreentrant_lock, + initialize_context, + tag_frame_info, +) from vyper.codegen.ir_node import Encoding, IRnode from vyper.codegen.stmt import parse_body from vyper.evm.address_space import CALLDATA, DATA, MEMORY from vyper.semantics.types import TupleT from vyper.semantics.types.function import ContractFunctionT +from vyper.utils import calc_mem_gas # register function args with the local calling context. 
@@ -126,34 +133,52 @@ def handler_for(calldata_kwargs, default_kwargs): default_kwargs = keyword_args[i:] sig, calldata_min_size, ir_node = handler_for(calldata_kwargs, default_kwargs) - ret[sig] = calldata_min_size, ir_node + assert sig not in ret + ret[sig] = EntryPointInfo(func_t, calldata_min_size, ir_node) sig, calldata_min_size, ir_node = handler_for(keyword_args, []) - ret[sig] = calldata_min_size, ir_node + assert sig not in ret + ret[sig] = EntryPointInfo(func_t, calldata_min_size, ir_node) return ret -def generate_ir_for_external_function(code, func_t, context): +def _adjust_gas_estimate(func_t, common_ir): + # adjust gas estimate to include cost of mem expansion + # frame_size of external function includes all private functions called + # (note: internal functions do not need to adjust gas estimate since + frame_info = func_t._ir_info.frame_info + + mem_expansion_cost = calc_mem_gas(frame_info.mem_used) + common_ir.add_gas_estimate += mem_expansion_cost + func_t._ir_info.gas_estimate = common_ir.gas + + # pass metadata through for venom pipeline: + common_ir.passthrough_metadata["func_t"] = func_t + common_ir.passthrough_metadata["frame_info"] = frame_info + + +def generate_ir_for_external_function(code, compilation_target): # TODO type hints: # def generate_ir_for_external_function( # code: vy_ast.FunctionDef, - # func_t: ContractFunctionT, - # context: Context, + # compilation_target: ModuleT, # ) -> IRnode: """ Return the IR for an external function. Returns IR for the body of the function, handle kwargs and exit the function. Also returns metadata required for `module.py` to construct the selector table. 
""" + func_t = code._metadata["func_type"] + context = initialize_context(func_t, compilation_target, func_t.is_constructor) nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t) # generate handlers for base args and register the variable records handle_base_args = _register_function_args(func_t, context) # generate handlers for kwargs and register the variable records - kwarg_handlers = _generate_kwarg_handlers(func_t, context) + entry_points = _generate_kwarg_handlers(func_t, context) body = ["seq"] # once optional args have been handled, @@ -185,4 +210,8 @@ def generate_ir_for_external_function(code, func_t, context): # besides any kwarg handling func_common_ir = IRnode.from_list(["seq", body, exit_], source_pos=getpos(code)) - return kwarg_handlers, func_common_ir + tag_frame_info(func_t, context) + + _adjust_gas_estimate(func_t, func_common_ir) + + return ExternalFuncIR(entry_points, func_common_ir) diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py index cf01dbdab4..3bb540980c 100644 --- a/vyper/codegen/function_definitions/internal_function.py +++ b/vyper/codegen/function_definitions/internal_function.py @@ -1,23 +1,25 @@ from vyper import ast as vy_ast -from vyper.codegen.context import Context -from vyper.codegen.function_definitions.utils import get_nonreentrant_lock +from vyper.codegen.function_definitions.common import ( + InternalFuncIR, + get_nonreentrant_lock, + initialize_context, + tag_frame_info, +) from vyper.codegen.ir_node import IRnode from vyper.codegen.stmt import parse_body -from vyper.semantics.types.function import ContractFunctionT def generate_ir_for_internal_function( - code: vy_ast.FunctionDef, func_t: ContractFunctionT, context: Context + code: vy_ast.FunctionDef, module_ctx, is_ctor_context: bool ) -> IRnode: """ Parse a internal function (FuncDef), and produce full function body. 
:param func_t: the ContractFunctionT :param code: ast of function - :param context: current calling context + :param module_ctx: ModuleT used to initialize the function's Context :return: function body in IR """ - # The calling convention is: # Caller fills in argument buffer # Caller provides return address, return buffer on the stack @@ -37,13 +39,16 @@ def generate_ir_for_internal_function( # situation like the following is easy to bork: # x: T[2] = [self.generate_T(), self.generate_T()] - # Get nonreentrant lock + func_t = code._metadata["func_type"] + + context = initialize_context(func_t, module_ctx, is_ctor_context) for arg in func_t.arguments: # allocate a variable for every arg, setting mutability # to True to allow internal function arguments to be mutable context.new_variable(arg.name, arg.typ, is_mutable=True) + # Get nonreentrant lock nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t) function_entry_label = func_t._ir_info.internal_function_label(context.is_ctor_context) @@ -69,5 +74,13 @@ def generate_ir_for_internal_function( ] ir_node = IRnode.from_list(["seq", body, cleanup_routine]) + + # tag gas estimate and frame info + func_t._ir_info.gas_estimate = ir_node.gas + frame_info = tag_frame_info(func_t, context) + + # pass metadata through for venom pipeline: + ir_node.passthrough_metadata["frame_info"] = frame_info ir_node.passthrough_metadata["func_t"] = func_t - return ir_node + + return InternalFuncIR(ir_node) diff --git a/vyper/codegen/function_definitions/utils.py b/vyper/codegen/function_definitions/utils.py deleted file mode 100644 index f524ec6e88..0000000000 --- a/vyper/codegen/function_definitions/utils.py +++ /dev/null @@ -1,31 +0,0 @@ -from vyper.evm.opcodes import version_check -from vyper.semantics.types.function import StateMutability - - -def get_nonreentrant_lock(func_type): - if not func_type.nonreentrant: - return ["pass"], ["pass"] - - nkey = func_type.reentrancy_key_position.position - - LOAD, STORE = "sload", "sstore" - if 
version_check(begin="cancun"): - LOAD, STORE = "tload", "tstore" - - if version_check(begin="berlin"): - # any nonzero values would work here (see pricing as of net gas - # metering); these values are chosen so that downgrading to the - # 0,1 scheme (if it is somehow necessary) is safe. - final_value, temp_value = 3, 2 - else: - final_value, temp_value = 0, 1 - - check_notset = ["assert", ["ne", temp_value, [LOAD, nkey]]] - - if func_type.mutability == StateMutability.VIEW: - return [check_notset], [["seq"]] - - else: - pre = ["seq", check_notset, [STORE, nkey, temp_value]] - post = [STORE, nkey, final_value] - return [pre], [post] diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py index ad0ab34f0f..fef4f23949 100644 --- a/vyper/codegen/module.py +++ b/vyper/codegen/module.py @@ -4,7 +4,10 @@ from vyper.codegen import core, jumptable_utils from vyper.codegen.core import shr -from vyper.codegen.function_definitions import generate_ir_for_function +from vyper.codegen.function_definitions import ( + generate_ir_for_external_function, + generate_ir_for_internal_function, +) from vyper.codegen.ir_node import IRnode from vyper.compiler.settings import _is_debug_mode from vyper.exceptions import CompilerPanic @@ -89,7 +92,7 @@ def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs): callvalue_check = ["assert", ["iszero", "callvalue"]] ret.append(IRnode.from_list(callvalue_check, error_msg="nonpayable check")) - func_ir = generate_ir_for_function(func_ast, *args, **kwargs) + func_ir = generate_ir_for_external_function(func_ast, *args, **kwargs) assert len(func_ir.entry_points) == 1 # add a goto to make the function entry look like other functions @@ -101,7 +104,7 @@ def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs): def _ir_for_internal_function(func_ast, *args, **kwargs): - return generate_ir_for_function(func_ast, *args, **kwargs).func_ir + return generate_ir_for_internal_function(func_ast, *args, **kwargs).func_ir def 
_generate_external_entry_points(external_functions, module_ctx): @@ -109,7 +112,7 @@ def _generate_external_entry_points(external_functions, module_ctx): sig_of = {} # reverse map from method ids to abi sig for code in external_functions: - func_ir = generate_ir_for_function(code, module_ctx) + func_ir = generate_ir_for_external_function(code, module_ctx) for abi_sig, entry_point in func_ir.entry_points.items(): method_id = method_id_int(abi_sig) assert abi_sig not in entry_points @@ -490,7 +493,7 @@ def generate_ir_for_module(module_ctx: ModuleT) -> tuple[IRnode, IRnode]: # generate init_func_ir after callees to ensure they have analyzed # memory usage. # TODO might be cleaner to separate this into an _init_ir helper func - init_func_ir = _ir_for_fallback_or_ctor(init_function, module_ctx, is_ctor_context=True) + init_func_ir = _ir_for_fallback_or_ctor(init_function, module_ctx) # pass the amount of memory allocated for the init function # so that deployment does not clobber while preparing immutables diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 8bf9b6eae6..1e47d4766b 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -197,8 +197,9 @@ def analyze_call_graph(self): # we just want to be able to construct the call graph. continue - if isinstance(call_t, ContractFunctionT) and (call_t.is_internal or call_t.is_constructor): - + if isinstance(call_t, ContractFunctionT) and ( + call_t.is_internal or call_t.is_constructor + ): fn_t.called_functions.add(call_t) for func in function_defs: