Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for modules with variables #3707

Closed
Closed
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
a59d575
wip - module variables
charles-cooper Dec 20, 2023
e353a15
add size_in_bytes to module
charles-cooper Dec 20, 2023
33dffaf
wip - storage allocator
charles-cooper Dec 15, 2023
654256b
remove ImportedVariable thing
charles-cooper Dec 20, 2023
bcc03d6
wip - add get_element_ptr for module, fix some logic in Expr.parse_At…
charles-cooper Dec 20, 2023
e9b867a
call set_data_positions recursively
charles-cooper Dec 21, 2023
037d5a6
add a sanity check
charles-cooper Dec 21, 2023
3512e3f
rename some size calculators and add immutable_bytes_required to Vype…
charles-cooper Dec 21, 2023
86c299a
add a comment
charles-cooper Dec 21, 2023
9e689b9
add Context.self_ptr helper
charles-cooper Dec 21, 2023
8becda2
add a note
charles-cooper Dec 21, 2023
a0d0bd1
Merge branch 'master' into feat/module_variables
charles-cooper Dec 21, 2023
9ac072b
improve Context.self_ptr
charles-cooper Dec 21, 2023
4e102bf
add function variable read/writes analysis
charles-cooper Dec 21, 2023
9f100d9
calculate pointer things
charles-cooper Dec 22, 2023
2d05699
quash mypy
charles-cooper Dec 22, 2023
1585bdc
wip - handle immutables
charles-cooper Dec 22, 2023
bf6e99c
feat: replace `enum` with `flag` keyword (#3697)
AlbertoCentonze Dec 23, 2023
1824321
refactor: make `assert_tx_failed` a contextmanager (#3706)
DanielSchiavini Dec 23, 2023
7489e34
feat: allow `range(x, y, bound=N)` (#3679)
DanielSchiavini Dec 24, 2023
1040f3e
feat: improve panics in IR generation (#3708)
charles-cooper Dec 25, 2023
977851a
add special visibility for the __init__ function
charles-cooper Dec 27, 2023
8af611b
remove unused MemoryOffset, CalldataOffset classes
charles-cooper Dec 28, 2023
c241e91
wip - allow init functions to be called from init func
charles-cooper Dec 29, 2023
3de7bb2
rename DataLocations.CODE to IMMUTABLES
charles-cooper Dec 29, 2023
ee91a52
refactor set_data_positions and rework VarInfo positions API
charles-cooper Dec 29, 2023
5e08300
thread new offset through codegen
charles-cooper Dec 29, 2023
1e393fa
mark storage layout override tests as xfail
charles-cooper Dec 29, 2023
6cf9ff0
Merge branch 'master' into feat/module_variables
charles-cooper Dec 30, 2023
7a82180
Merge branch 'master' into feat/module_variables
charles-cooper Jan 2, 2024
eff76e0
Merge branch 'master' into feat/module_variables
charles-cooper Jan 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions tests/functional/codegen/types/test_node_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)

# TODO: this module should be merged with the other tests/functional/semantics/types/ tests
# and moved to tests/unit/.


def test_bytearray_node_type():
Expand Down Expand Up @@ -51,17 +52,17 @@ def test_canonicalize_type():


def test_type_storage_sizes():
assert IntegerT(True, 128).storage_size_in_words == 1
assert BytesT(12).storage_size_in_words == 2
assert BytesT(33).storage_size_in_words == 3
assert SArrayT(IntegerT(True, 128), 10).storage_size_in_words == 10
assert IntegerT(True, 128).storage_slots_required == 1
assert BytesT(12).storage_slots_required == 2
assert BytesT(33).storage_slots_required == 3
assert SArrayT(IntegerT(True, 128), 10).storage_slots_required == 10

tuple_ = TupleT([IntegerT(True, 128), DecimalT()])
assert tuple_.storage_size_in_words == 2
assert tuple_.storage_slots_required == 2

struct_ = StructT("Foo", {"a": IntegerT(True, 128), "b": DecimalT()})
assert struct_.storage_size_in_words == 2
assert struct_.storage_slots_required == 2

# Don't allow unknown types.
with raises(AttributeError):
_ = int.storage_size_in_words
_ = int.storage_slots_required
10 changes: 5 additions & 5 deletions tests/unit/semantics/types/test_size_in_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_base_types(build_node, type_str):
node = build_node(type_str)
type_definition = type_from_annotation(node)

assert type_definition.size_in_bytes == 32
assert type_definition._size_in_bytes == 32


@pytest.mark.parametrize("type_str", BYTESTRING_TYPES)
Expand All @@ -20,7 +20,7 @@ def test_array_value_types(build_node, type_str, length, size):
node = build_node(f"{type_str}[{length}]")
type_definition = type_from_annotation(node)

assert type_definition.size_in_bytes == size
assert type_definition._size_in_bytes == size


@pytest.mark.parametrize("type_str", BASE_TYPES)
Expand All @@ -29,7 +29,7 @@ def test_dynamic_array_lengths(build_node, type_str, length):
node = build_node(f"DynArray[{type_str}, {length}]")
type_definition = type_from_annotation(node)

assert type_definition.size_in_bytes == 32 + length * 32
assert type_definition._size_in_bytes == 32 + length * 32


@pytest.mark.parametrize("type_str", BASE_TYPES)
Expand All @@ -38,7 +38,7 @@ def test_base_types_as_arrays(build_node, type_str, length):
node = build_node(f"{type_str}[{length}]")
type_definition = type_from_annotation(node)

assert type_definition.size_in_bytes == length * 32
assert type_definition._size_in_bytes == length * 32


@pytest.mark.parametrize("type_str", BASE_TYPES)
Expand All @@ -49,4 +49,4 @@ def test_base_types_as_multidimensional_arrays(build_node, type_str, first, seco

type_definition = type_from_annotation(node)

assert type_definition.size_in_bytes == first * second * 32
assert type_definition._size_in_bytes == first * second * 32
34 changes: 25 additions & 9 deletions vyper/codegen/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from dataclasses import dataclass
from typing import Any, Optional

from vyper.codegen.ir_node import Encoding
from vyper.evm.address_space import MEMORY, AddrSpace
import vyper.ast as vy_ast
from vyper.codegen.ir_node import Encoding, IRnode
from vyper.evm.address_space import IMMUTABLES, MEMORY, STORAGE, AddrSpace
from vyper.exceptions import CompilerPanic, StateAccessViolation
from vyper.semantics.types import VyperType
from vyper.semantics.types import ModuleT, VyperType


class Constancy(enum.Enum):
Expand Down Expand Up @@ -48,7 +49,7 @@
class Context:
def __init__(
self,
module_ctx,
compilation_target,
memory_allocator,
vars_=None,
forvars=None,
Expand All @@ -59,9 +60,6 @@
# In-memory variables, in the form (name, memory location, type)
self.vars = vars_ or {}

# Global variables, in the form (name, storage location, type)
self.globals = module_ctx.variables

# Variables defined in for loops, e.g. for i in range(6): ...
self.forvars = forvars or {}

Expand All @@ -75,8 +73,8 @@
# Whether we are currently parsing a range expression
self.in_range_expr = False

# store module context
self.module_ctx = module_ctx
# the type information for the current compilation target
self.compilation_target: ModuleT = compilation_target

# full function type
self.func_t = func_t
Expand All @@ -94,6 +92,24 @@
# either the constructor, or called from the constructor
self.is_ctor_context = is_ctor_context

def self_ptr(self, location):
    """
    Return an IRnode (typed as the owning module) to use as the base
    pointer for module variables of the function being compiled.

    Arguments
    ---------
    location
        The address space the pointer refers to; it is attached to the
        returned IRnode as its `location`.

    Returns
    -------
    IRnode
        The literal `0` when the module which defines the current function
        is the compilation target; otherwise the symbolic
        `self_ptr_storage` (for STORAGE) or `self_ptr_code` (for
        IMMUTABLES) variable.

    Raises
    ------
    CompilerPanic
        If the module is not the compilation target and `location` is
        neither STORAGE nor IMMUTABLES.
    """
    func_module = self.func_t.ast_def._parent
    assert isinstance(func_module, vy_ast.Module)

    module_t = func_module._metadata["type"]
    module_is_compilation_target = module_t == self.compilation_target

    if module_is_compilation_target:
        # return 0 for the special case where compilation target is self
        return IRnode.from_list(0, typ=module_t, location=location)

    # otherwise, the function compilation context takes a `self_ptr`
    # argument in the calling convention
    if location == STORAGE:
        return IRnode.from_list("self_ptr_storage", typ=module_t, location=location)
    if location == IMMUTABLES:
        return IRnode.from_list("self_ptr_code", typ=module_t, location=location)

    # Code scanning (CodeQL) flagged this method for mixing explicit
    # returns with an implicit fall-through `return None`. Fail loudly
    # here instead of handing `None` to the caller.
    raise CompilerPanic(f"unsupported location for self_ptr: {location}")  # pragma: nocover

def is_constant(self):
    """
    Tell whether this context is to be treated as constant: either its
    constancy is `Constancy.Constant`, or an assertion or a range
    expression is currently being parsed.
    """
    if self.constancy is Constancy.Constant:
        return True
    return self.in_assertion or self.in_range_expr

Expand Down
53 changes: 51 additions & 2 deletions vyper/codegen/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT
from vyper.evm.opcodes import version_check
from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch
from vyper.semantics.data_locations import DataLocation
from vyper.semantics.types import (
AddressT,
BoolT,
Expand All @@ -17,6 +18,7 @@
HashMapT,
IntegerT,
InterfaceT,
ModuleT,
StructT,
TupleT,
_BytestringT,
Expand Down Expand Up @@ -64,6 +66,17 @@ def is_array_like(typ):
return ret


def data_location_to_addr_space(s: DataLocation):
    """
    Translate a `DataLocation` into the matching address space constant.

    STORAGE maps to STORAGE, MEMORY to MEMORY and CODE to IMMUTABLES.
    Any other value raises CompilerPanic.
    """
    translations = (
        (DataLocation.STORAGE, STORAGE),
        (DataLocation.MEMORY, MEMORY),
        (DataLocation.CODE, IMMUTABLES),
    )
    for data_location, addr_space in translations:
        if s == data_location:
            return addr_space

    raise CompilerPanic("unreachable")  # pragma: nocover
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if s == DataLocation.STORAGE:
return STORAGE
if s == DataLocation.MEMORY:
return MEMORY
if s == DataLocation.CODE:
return IMMUTABLES
raise CompilerPanic("unreachable") # pragma: nocover
return {
DataLocation.STORAGE: STORAGE,
DataLocation.MEMORY: MEMORY,
DataLocation.CODE: IMMUTABLES,
}[s]



def get_type_for_exact_size(n_bytes):
"""Create a type which will take up exactly n_bytes. Used for allocating internal buffers.

Expand Down Expand Up @@ -442,6 +455,39 @@ def _getelemptr_abi_helper(parent, member_t, ofst, clamp=True):
)


# get a variable out of a module
def _get_element_ptr_module(parent, key):
    """
    Build an IRnode pointing at the variable named `key` inside the
    module that `parent` points to.

    `parent` must be typed as a ModuleT and located in STORAGE (both are
    asserted), and `key` must be a str. The result carries the variable's
    type, and the parent's location and encoding.
    """
    # note that this implementation is substantially similar to
    # the StructT pathway through get_element_ptr_tuplelike and
    # has potential to be refactored.
    module_t = parent.typ
    assert isinstance(module_t, ModuleT)
    assert isinstance(key, str)

    variables = module_t.variables
    member_t = variables[key].typ
    names = list(variables.keys())
    n_preceding = names.index(key)

    assert parent.location == STORAGE, parent.location

    # offset from parent start: the storage slots taken up by every
    # variable which comes before `key` in the module
    ofst = sum(variables[name].typ.storage_slots_required for name in names[:n_preceding])

    # sanity check: the offset computed here must agree with the
    # position already recorded on the variable
    assert ofst == variables[key].position.position

    return IRnode.from_list(
        add_ofst(parent, ofst),
        typ=member_t,
        location=parent.location,
        encoding=parent.encoding,
        annotation=key,
    )


# TODO simplify this code, especially the ABI decoding
def _get_element_ptr_tuplelike(parent, key):
typ = parent.typ
Expand Down Expand Up @@ -485,7 +531,7 @@ def _get_element_ptr_tuplelike(parent, key):

if parent.location.word_addressable:
for i in range(index):
ofst += typ.member_types[attrs[i]].storage_size_in_words
ofst += typ.member_types[attrs[i]].storage_slots_required
elif parent.location.byte_addressable:
for i in range(index):
ofst += typ.member_types[attrs[i]].memory_bytes_required
Expand Down Expand Up @@ -552,7 +598,7 @@ def _get_element_ptr_array(parent, key, array_bounds_check):
return _getelemptr_abi_helper(parent, subtype, ofst)

if parent.location.word_addressable:
element_size = subtype.storage_size_in_words
element_size = subtype.storage_slots_required
elif parent.location.byte_addressable:
element_size = subtype.memory_bytes_required
else:
Expand Down Expand Up @@ -590,6 +636,9 @@ def get_element_ptr(parent, key, array_bounds_check=True):
if is_tuple_like(typ):
ret = _get_element_ptr_tuplelike(parent, key)

elif isinstance(typ, ModuleT):
ret = _get_element_ptr_module(parent, key)

elif isinstance(typ, HashMapT):
ret = _get_element_ptr_mapping(parent, key)

Expand Down
67 changes: 39 additions & 28 deletions vyper/codegen/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
VyperException,
tag_exceptions,
)
from vyper.semantics.analysis.base import VarInfo

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.analysis.base
begins an import cycle.
from vyper.semantics.types import (
AddressT,
BoolT,
Expand Down Expand Up @@ -169,7 +170,8 @@ def parse_NameConstant(self):
# Variable names
def parse_Name(self):
if self.expr.id == "self":
return IRnode.from_list(["address"], typ=AddressT())
# TODO: have `self` return a module type
return IRnode.from_list(["self"], typ=AddressT())
elif self.expr.id in self.context.vars:
var = self.context.vars[self.expr.id]
ret = IRnode.from_list(
Expand Down Expand Up @@ -219,7 +221,7 @@ def parse_Attribute(self):
return IRnode.from_list(value, typ=typ)

# x.balance: balance of address x
if self.expr.attr == "balance":
elif self.expr.attr == "balance":
addr = Expr.parse_value_expr(self.expr.value, self.context)
if addr.typ == AddressT():
if (
Expand All @@ -231,6 +233,7 @@ def parse_Attribute(self):
else:
seq = ["balance", addr]
return IRnode.from_list(seq, typ=UINT256_T)

# x.codesize: codesize of address x
elif self.expr.attr == "codesize" or self.expr.attr == "is_contract":
addr = Expr.parse_value_expr(self.expr.value, self.context)
Expand All @@ -242,14 +245,17 @@ def parse_Attribute(self):
eval_code = ["extcodesize", addr]
output_type = UINT256_T
else:
assert self.expr.attr == "is_contract"
eval_code = ["gt", ["extcodesize", addr], 0]
output_type = BoolT()
return IRnode.from_list(eval_code, typ=output_type)

# x.codehash: keccak of address x
elif self.expr.attr == "codehash":
addr = Expr.parse_value_expr(self.expr.value, self.context)
if addr.typ == AddressT():
return IRnode.from_list(["extcodehash", addr], typ=BYTES32_T)

# x.code: codecopy/extcodecopy of address x
elif self.expr.attr == "code":
addr = Expr.parse_value_expr(self.expr.value, self.context)
Expand All @@ -258,24 +264,12 @@ def parse_Attribute(self):
if addr.value == "address": # for `self.code`
return IRnode.from_list(["~selfcode"], typ=BytesT(0))
return IRnode.from_list(["~extcode", addr], typ=BytesT(0))
# self.x: global attribute
elif isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id == "self":
varinfo = self.context.globals[self.expr.attr]
location = TRANSIENT if varinfo.is_transient else STORAGE

ret = IRnode.from_list(
varinfo.position.position,
typ=varinfo.typ,
location=location,
annotation="self." + self.expr.attr,
)
ret._referenced_variables = {varinfo}

return ret

# Reserved keywords
elif (
isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id in ENVIRONMENT_VARIABLES
# TODO: use type information here
isinstance(self.expr.value, vy_ast.Name)
and self.expr.value.id in ENVIRONMENT_VARIABLES
):
key = f"{self.expr.value.id}.{self.expr.attr}"
if key == "msg.sender":
Expand Down Expand Up @@ -327,17 +321,34 @@ def parse_Attribute(self):
"chain.id is unavailable prior to istanbul ruleset", self.expr
)
return IRnode.from_list(["chainid"], typ=UINT256_T)
# Other variables
else:
sub = Expr(self.expr.value, self.context).ir_node
# contract type
if isinstance(sub.typ, InterfaceT):
# MyInterface.address
assert self.expr.attr == "address"
sub.typ = typ
return sub
if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types:
return get_element_ptr(sub, self.expr.attr)

# self.x: global storage variable or immutable
if (varinfo := self.expr._metadata.get("variable_access")) is not None:
assert isinstance(varinfo, VarInfo)

# TODO: handle immutables
location = TRANSIENT if varinfo.is_transient else STORAGE

module_ptr = Expr(self.expr.value, self.context).ir_node
if module_ptr.value == "self":
module_ptr = self.context.self_ptr(location)

ret = get_element_ptr(module_ptr, self.expr.attr)
# TODO: take referenced variables info from analysis
ret._referenced_variables = {varinfo}
return ret

# if we have gotten here, it's an instance of an interface or struct
sub = Expr(self.expr.value, self.context).ir_node

if isinstance(sub.typ, InterfaceT):
# MyInterface.address
assert self.expr.attr == "address"
sub.typ = typ
return sub

if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types:
return get_element_ptr(sub, self.expr.attr)

def parse_Subscript(self):
sub = Expr(self.expr.value, self.context).ir_node
Expand Down
4 changes: 2 additions & 2 deletions vyper/codegen/function_definitions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class InternalFuncIR(FuncIR):

# TODO: should split this into external and internal ir generation?
def generate_ir_for_function(
code: vy_ast.FunctionDef, module_ctx: ModuleT, is_ctor_context: bool = False
code: vy_ast.FunctionDef, compilation_target: ModuleT, is_ctor_context: bool = False
) -> FuncIR:
"""
Parse a function and produce IR code for the function, includes:
Expand Down Expand Up @@ -133,7 +133,7 @@ def generate_ir_for_function(

context = Context(
vars_=None,
module_ctx=module_ctx,
compilation_target=compilation_target,
memory_allocator=memory_allocator,
constancy=Constancy.Mutable if func_t.is_mutable else Constancy.Constant,
func_t=func_t,
Expand Down
Loading
Loading