Skip to content

Commit

Permalink
get __init__() working both as entry point and as internal function
Browse files Browse the repository at this point in the history
refactor:
- refactor generate_ir_for_function into generate_ir_for_external_function
  and generate_ir_for_internal_function
- move get_nonreentrant_lock to function-definitions/common.py
  • Loading branch information
charles-cooper committed Jan 15, 2024
1 parent d6d7de7 commit b67d361
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 117 deletions.
6 changes: 5 additions & 1 deletion vyper/codegen/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __repr__(self):
return f"VariableRecord({ret})"


# Contains arguments, variables, etc
# compilation context for a function
class Context:
def __init__(
self,
Expand Down Expand Up @@ -83,6 +83,10 @@ def __init__(
# Not intended to be accessed directly
self.memory_allocator = memory_allocator

# save the starting memory location so we can find out (later)
# how much memory this function uses.
self.starting_memory = memory_allocator.next_mem

# Incremented values, used for internal IDs
self._internal_var_iter = 0
self._scope_id_iter = 0
Expand Down
5 changes: 4 additions & 1 deletion vyper/codegen/function_definitions/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from .common import FuncIR, generate_ir_for_function # noqa
from .external_function import generate_ir_for_external_function
from .internal_function import generate_ir_for_internal_function

__all__ = [generate_ir_for_internal_function, generate_ir_for_external_function]
114 changes: 53 additions & 61 deletions vyper/codegen/function_definitions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,14 @@
from functools import cached_property
from typing import Optional

import vyper.ast as vy_ast
from vyper.codegen.context import Constancy, Context
from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function
from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function
from vyper.codegen.ir_node import IRnode
from vyper.codegen.memory_allocator import MemoryAllocator
from vyper.exceptions import CompilerPanic
from vyper.evm.opcodes import version_check
from vyper.semantics.types import VyperType
from vyper.semantics.types.function import ContractFunctionT
from vyper.semantics.types.function import ContractFunctionT, StateMutability
from vyper.semantics.types.module import ModuleT
from vyper.utils import MemoryPositions, calc_mem_gas
from vyper.utils import MemoryPositions


@dataclass
Expand Down Expand Up @@ -53,14 +50,16 @@ def ir_identifier(self) -> str:
return f"{self.visibility} {function_id} {name}({argz})"

def set_frame_info(self, frame_info: FrameInfo) -> None:
# XXX: when can this happen?
if self.frame_info is not None:
raise CompilerPanic(f"frame_info already set for {self.func_t}!")
self.frame_info = frame_info
assert frame_info == self.frame_info
else:
self.frame_info = frame_info

@property
# common entry point for external function with kwargs
def external_function_base_entry_label(self) -> str:
assert self.func_t.is_external, "uh oh, should be external"
assert not self.func_t.is_internal, "uh oh, should be external"
return self.ir_identifier + "_common"

def internal_function_label(self, is_ctor_context: bool = False) -> str:
Expand All @@ -75,53 +74,43 @@ def internal_function_label(self, is_ctor_context: bool = False) -> str:
return self.ir_identifier + suffix


class FuncIR:
pass


@dataclass
class EntryPointInfo:
func_t: ContractFunctionT
min_calldatasize: int # the min calldata required for this entry point
ir_node: IRnode # the ir for this entry point

def __post_init__(self):
# ABI v2 property guaranteed by the spec.
# sanity check ABI v2 properties guaranteed by the spec.
# https://docs.soliditylang.org/en/v0.8.21/abi-spec.html#formal-specification-of-the-encoding states: # noqa: E501
# > Note that for any X, len(enc(X)) is a multiple of 32.
assert self.min_calldatasize >= 4
assert (self.min_calldatasize - 4) % 32 == 0


@dataclass
class ExternalFuncIR(FuncIR):
class ExternalFuncIR:
entry_points: dict[str, EntryPointInfo] # map from abi sigs to entry points
common_ir: IRnode # the "common" code for the function


@dataclass
class InternalFuncIR(FuncIR):
class InternalFuncIR:
func_ir: IRnode # the code for the function


# TODO: should split this into external and internal ir generation?
def generate_ir_for_function(
code: vy_ast.FunctionDef, module_ctx: ModuleT, is_ctor_context: bool = False
) -> FuncIR:
"""
Parse a function and produce IR code for the function, includes:
- Signature method if statement
- Argument handling
- Clamping and copying of arguments
- Function body
"""
func_t = code._metadata["func_type"]

# generate _FuncIRInfo
def init_ir_info(func_t: ContractFunctionT):
# initialize IRInfo on the function
func_t._ir_info = _FuncIRInfo(func_t)

callees = func_t.called_functions

def initialize_context(
func_t: ContractFunctionT, module_ctx: ModuleT, is_ctor_context: bool = False
):
init_ir_info(func_t)

# calculate starting frame
callees = func_t.called_functions
# we start our function frame from the largest callee frame
max_callee_frame_size = 0
for c_func_t in callees:
Expand All @@ -132,7 +121,7 @@ def generate_ir_for_function(

memory_allocator = MemoryAllocator(allocate_start)

context = Context(
return Context(
vars_=None,
module_ctx=module_ctx,
memory_allocator=memory_allocator,
Expand All @@ -141,38 +130,41 @@ def generate_ir_for_function(
is_ctor_context=is_ctor_context,
)

if func_t.is_internal or func_t.is_constructor:
ret: FuncIR = InternalFuncIR(generate_ir_for_internal_function(code, func_t, context))
func_t._ir_info.gas_estimate = ret.func_ir.gas # type: ignore
else:
kwarg_handlers, common = generate_ir_for_external_function(code, func_t, context)
entry_points = {
k: EntryPointInfo(func_t, mincalldatasize, ir_node)
for k, (mincalldatasize, ir_node) in kwarg_handlers.items()
}
ret = ExternalFuncIR(entry_points, common)
# note: this ignores the cost of traversing selector table
func_t._ir_info.gas_estimate = ret.common_ir.gas

def tag_frame_info(func_t, context):
frame_size = context.memory_allocator.size_of_mem - MemoryPositions.RESERVED_MEMORY
frame_start = context.starting_memory

frame_info = FrameInfo(allocate_start, frame_size, context.vars)
frame_info = FrameInfo(frame_start, frame_size, context.vars)
func_t._ir_info.set_frame_info(frame_info)

# XXX: when can this happen?
if func_t._ir_info.frame_info is None:
func_t._ir_info.set_frame_info(frame_info)
else:
assert frame_info == func_t._ir_info.frame_info

if func_t.is_external:
# adjust gas estimate to include cost of mem expansion
# frame_size of external function includes all private functions called
# (note: internal functions do not need to adjust gas estimate since
mem_expansion_cost = calc_mem_gas(func_t._ir_info.frame_info.mem_used) # type: ignore
ret.common_ir.add_gas_estimate += mem_expansion_cost # type: ignore
ret.common_ir.passthrough_metadata["func_t"] = func_t # type: ignore
ret.common_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore
return frame_info


def get_nonreentrant_lock(func_t):
if not func_t.nonreentrant:
return ["pass"], ["pass"]

nkey = func_t.reentrancy_key_position.position

LOAD, STORE = "sload", "sstore"
if version_check(begin="cancun"):
LOAD, STORE = "tload", "tstore"

if version_check(begin="berlin"):
# any nonzero values would work here (see pricing as of net gas
# metering); these values are chosen so that downgrading to the
# 0,1 scheme (if it is somehow necessary) is safe.
final_value, temp_value = 3, 2
else:
ret.func_ir.passthrough_metadata["frame_info"] = frame_info # type: ignore
final_value, temp_value = 0, 1

return ret
check_notset = ["assert", ["ne", temp_value, [LOAD, nkey]]]

if func_t.mutability == StateMutability.VIEW:
return [check_notset], [["seq"]]

else:
pre = ["seq", check_notset, [STORE, nkey, temp_value]]
post = [STORE, nkey, final_value]
return [pre], [post]
45 changes: 37 additions & 8 deletions vyper/codegen/function_definitions/external_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@
from vyper.codegen.context import Context, VariableRecord
from vyper.codegen.core import get_element_ptr, getpos, make_setter, needs_clamp
from vyper.codegen.expr import Expr
from vyper.codegen.function_definitions.utils import get_nonreentrant_lock
from vyper.codegen.function_definitions.common import (
EntryPointInfo,
ExternalFuncIR,
get_nonreentrant_lock,
initialize_context,
tag_frame_info,
)
from vyper.codegen.ir_node import Encoding, IRnode
from vyper.codegen.stmt import parse_body
from vyper.evm.address_space import CALLDATA, DATA, MEMORY
from vyper.semantics.types import TupleT
from vyper.semantics.types.function import ContractFunctionT
from vyper.utils import calc_mem_gas


# register function args with the local calling context.
Expand Down Expand Up @@ -126,34 +133,52 @@ def handler_for(calldata_kwargs, default_kwargs):
default_kwargs = keyword_args[i:]

sig, calldata_min_size, ir_node = handler_for(calldata_kwargs, default_kwargs)
ret[sig] = calldata_min_size, ir_node
assert sig not in ret
ret[sig] = EntryPointInfo(func_t, calldata_min_size, ir_node)

sig, calldata_min_size, ir_node = handler_for(keyword_args, [])

ret[sig] = calldata_min_size, ir_node
assert sig not in ret
ret[sig] = EntryPointInfo(func_t, calldata_min_size, ir_node)

return ret


def generate_ir_for_external_function(code, func_t, context):
def _adjust_gas_estimate(func_t, common_ir):
# adjust gas estimate to include cost of mem expansion
# frame_size of external function includes all private functions called
# (note: internal functions do not need to adjust gas estimate since
frame_info = func_t._ir_info.frame_info

mem_expansion_cost = calc_mem_gas(frame_info.mem_used)
common_ir.add_gas_estimate += mem_expansion_cost
func_t._ir_info.gas_estimate = common_ir.gas

# pass metadata through for venom pipeline:
common_ir.passthrough_metadata["func_t"] = func_t
common_ir.passthrough_metadata["frame_info"] = frame_info


def generate_ir_for_external_function(code, compilation_target):
# TODO type hints:
# def generate_ir_for_external_function(
# code: vy_ast.FunctionDef,
# func_t: ContractFunctionT,
# context: Context,
# compilation_target: ModuleT,
# ) -> IRnode:
"""
Return the IR for an external function. Returns IR for the body
of the function, handle kwargs and exit the function. Also returns
metadata required for `module.py` to construct the selector table.
"""
func_t = code._metadata["func_type"]
context = initialize_context(func_t, compilation_target, func_t.is_constructor)
nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t)

# generate handlers for base args and register the variable records
handle_base_args = _register_function_args(func_t, context)

# generate handlers for kwargs and register the variable records
kwarg_handlers = _generate_kwarg_handlers(func_t, context)
entry_points = _generate_kwarg_handlers(func_t, context)

body = ["seq"]
# once optional args have been handled,
Expand Down Expand Up @@ -185,4 +210,8 @@ def generate_ir_for_external_function(code, func_t, context):
# besides any kwarg handling
func_common_ir = IRnode.from_list(["seq", body, exit_], source_pos=getpos(code))

return kwarg_handlers, func_common_ir
tag_frame_info(func_t, context)

_adjust_gas_estimate(func_t, func_common_ir)

return ExternalFuncIR(entry_points, func_common_ir)
29 changes: 21 additions & 8 deletions vyper/codegen/function_definitions/internal_function.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
from vyper import ast as vy_ast
from vyper.codegen.context import Context
from vyper.codegen.function_definitions.utils import get_nonreentrant_lock
from vyper.codegen.function_definitions.common import (
InternalFuncIR,
get_nonreentrant_lock,
initialize_context,
tag_frame_info,
)
from vyper.codegen.ir_node import IRnode
from vyper.codegen.stmt import parse_body
from vyper.semantics.types.function import ContractFunctionT


def generate_ir_for_internal_function(
code: vy_ast.FunctionDef, func_t: ContractFunctionT, context: Context
code: vy_ast.FunctionDef, module_ctx, is_ctor_context: bool
) -> IRnode:
"""
Parse a internal function (FuncDef), and produce full function body.
:param func_t: the ContractFunctionT
:param code: ast of function
:param context: current calling context
:param compilation_target: current calling context
:return: function body in IR
"""

# The calling convention is:
# Caller fills in argument buffer
# Caller provides return address, return buffer on the stack
Expand All @@ -37,13 +39,16 @@ def generate_ir_for_internal_function(
# situation like the following is easy to bork:
# x: T[2] = [self.generate_T(), self.generate_T()]

# Get nonreentrant lock
func_t = code._metadata["func_type"]

context = initialize_context(func_t, module_ctx, is_ctor_context)

for arg in func_t.arguments:
# allocate a variable for every arg, setting mutability
# to True to allow internal function arguments to be mutable
context.new_variable(arg.name, arg.typ, is_mutable=True)

# Get nonreentrant lock
nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t)

function_entry_label = func_t._ir_info.internal_function_label(context.is_ctor_context)
Expand All @@ -69,5 +74,13 @@ def generate_ir_for_internal_function(
]

ir_node = IRnode.from_list(["seq", body, cleanup_routine])

# tag gas estimate and frame info
func_t._ir_info.gas_estimate = ir_node.gas
frame_info = tag_frame_info(func_t, context)

# pass metadata through for venom pipeline:
ir_node.passthrough_metadata["frame_info"] = frame_info
ir_node.passthrough_metadata["func_t"] = func_t
return ir_node

return InternalFuncIR(ir_node)
31 changes: 0 additions & 31 deletions vyper/codegen/function_definitions/utils.py

This file was deleted.

Loading

0 comments on commit b67d361

Please sign in to comment.