Skip to content

Commit

Permalink
refactor: move arc4_util functions into ir.builder.arc4
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-makerx committed Oct 24, 2024
1 parent 4b854a3 commit a256d75
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 94 deletions.
67 changes: 0 additions & 67 deletions src/puya/arc4_util.py

This file was deleted.

84 changes: 61 additions & 23 deletions src/puya/ir/builder/arc4.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
from collections.abc import Sequence
from itertools import zip_longest

import attrs

from puya.arc4_util import (
determine_arc4_tuple_head_size,
get_arc4_fixed_bit_size,
is_arc4_dynamic_size,
is_arc4_static_size,
)
from puya.avm_type import AVMType
from puya.awst import (
nodes as awst_nodes,
Expand Down Expand Up @@ -38,7 +33,7 @@
from puya.ir.types_ import AVMBytesEncoding, IRType, get_wtype_arity
from puya.ir.utils import format_tuple_index
from puya.parse import SourceLocation, sequential_source_locations_merge
from puya.utils import bits_to_bytes
from puya.utils import bits_to_bytes, round_bits_to_nearest_bytes


@attrs.frozen(kw_only=True)
Expand Down Expand Up @@ -296,7 +291,7 @@ def arc4_array_index(
source_location=source_location,
)
else:
item_bit_size = get_arc4_fixed_bit_size(item_wtype)
item_bit_size = _get_arc4_fixed_bit_size(item_wtype)
# no _assert_index_in_bounds here as static items will error on read if past end of array
return _read_static_item_from_arc4_container(
data=array_head_and_tail,
Expand Down Expand Up @@ -499,7 +494,7 @@ def concat_values(
"concat_result",
)
if is_arc4_static_size(left_element_type):
element_size = get_arc4_fixed_bit_size(left_element_type)
element_size = _get_arc4_fixed_bit_size(left_element_type)
return _concat_dynamic_array_fixed_size(
context,
left=left_expr,
Expand Down Expand Up @@ -573,7 +568,7 @@ def pop_arc4_array(
elif is_arc4_dynamic_size(array_wtype.element_type):
method_name = "dynamic_array_pop_dynamic_element"
else:
fixed_size = get_arc4_fixed_bit_size(array_wtype.element_type)
fixed_size = _get_arc4_fixed_bit_size(array_wtype.element_type)
method_name = "dynamic_array_pop_fixed_size"
args.append(fixed_size // 8)

Expand Down Expand Up @@ -645,7 +640,7 @@ def _is_byte_length_header(wtype: wtypes.ARC4Type) -> bool:
return (
isinstance(wtype, wtypes.ARC4DynamicArray)
and is_arc4_static_size(wtype.element_type)
and get_arc4_fixed_bit_size(wtype.element_type) == 8
and _get_arc4_fixed_bit_size(wtype.element_type) == 8
)


Expand All @@ -657,7 +652,7 @@ def _maybe_get_inner_element_size(item_wtype: wtypes.ARC4Type) -> int | None:
pass
case _:
return None
return get_arc4_fixed_bit_size(inner_static_element_type) // 8
return _get_arc4_fixed_bit_size(inner_static_element_type) // 8


def _read_dynamic_item_using_length_from_arc4_container(
Expand Down Expand Up @@ -740,15 +735,15 @@ def _visit_arc4_tuple_encode(
tuple_items: Sequence[wtypes.ARC4Type],
expr_loc: SourceLocation,
) -> ValueProvider:
header_size = determine_arc4_tuple_head_size(tuple_items, round_end_result=True)
header_size = _determine_arc4_tuple_head_size(tuple_items, round_end_result=True)
factory = _OpFactory(context, expr_loc)
current_tail_offset = factory.assign(factory.constant(header_size // 8), "current_tail_offset")
encoded_tuple_buffer = factory.assign(factory.constant(b""), "encoded_tuple_buffer")

for index, (element, el_wtype) in enumerate(zip(elements, tuple_items, strict=True)):
if el_wtype == wtypes.arc4_bool_wtype:
# Pack boolean
before_header = determine_arc4_tuple_head_size(
before_header = _determine_arc4_tuple_head_size(
tuple_items[0:index], round_end_result=False
)
if before_header % 8 == 0:
Expand All @@ -763,7 +758,7 @@ def _visit_arc4_tuple_encode(
bit=is_true,
temp_desc="encoded_tuple_buffer",
)
elif not is_arc4_dynamic_size(el_wtype):
elif is_arc4_static_size(el_wtype):
# Append value
encoded_tuple_buffer = factory.concat(
encoded_tuple_buffer, element, "encoded_tuple_buffer"
Expand Down Expand Up @@ -817,7 +812,7 @@ def _arc4_replace_tuple_item(
base = context.visitor.visit_and_materialise_single(base_expr)
value = factory.assign(value, "assigned_value")
element_type = wtype.types[index_int]
header_up_to_item = determine_arc4_tuple_head_size(
header_up_to_item = _determine_arc4_tuple_head_size(
wtype.types[0:index_int],
round_end_result=element_type != wtypes.arc4_bool_wtype,
)
Expand Down Expand Up @@ -848,7 +843,7 @@ def _arc4_replace_tuple_item(
# This is the last dynamic type in the tuple
# No need to update headers - just replace the data
return factory.concat(data_up_to_item, value, "updated_data")
header_up_to_next_dynamic_item = determine_arc4_tuple_head_size(
header_up_to_next_dynamic_item = _determine_arc4_tuple_head_size(
types=wtype.types[0 : dynamic_indices_after_item[0]],
round_end_result=True,
)
Expand All @@ -873,7 +868,7 @@ def _arc4_replace_tuple_item(
item_length = factory.sub(next_item_offset, item_offset, "item_length")
new_value_length = factory.len(value, "new_value_length")
for dynamic_index in dynamic_indices_after_item:
header_up_to_dynamic_item = determine_arc4_tuple_head_size(
header_up_to_dynamic_item = _determine_arc4_tuple_head_size(
types=wtype.types[0:dynamic_index],
round_end_result=True,
)
Expand Down Expand Up @@ -903,7 +898,7 @@ def _read_nth_item_of_arc4_heterogeneous_container(
tuple_item_types = tuple_type.types

item_wtype = tuple_item_types[index]
head_up_to_item = determine_arc4_tuple_head_size(
head_up_to_item = _determine_arc4_tuple_head_size(
tuple_item_types[:index], round_end_result=False
)
if item_wtype == wtypes.arc4_bool_wtype:
Expand Down Expand Up @@ -933,7 +928,7 @@ def _read_nth_item_of_arc4_heterogeneous_container(
tuple_item_types[next_index:], start=next_index
):
if is_arc4_dynamic_size(tuple_item_type):
head_up_to_next_dynamic_item = determine_arc4_tuple_head_size(
head_up_to_next_dynamic_item = _determine_arc4_tuple_head_size(
tuple_item_types[:tuple_item_index], round_end_result=False
)
next_dynamic_head_offset = UInt64Constant(
Expand Down Expand Up @@ -994,7 +989,7 @@ def _read_static_item_from_arc4_container(
item_wtype: wtypes.ARC4Type,
source_location: SourceLocation,
) -> ValueProvider:
item_bit_size = get_arc4_fixed_bit_size(item_wtype)
item_bit_size = _get_arc4_fixed_bit_size(item_wtype)
item_length = UInt64Constant(value=item_bit_size // 8, source_location=source_location)
return Intrinsic(
op=AVMOp.extract3,
Expand Down Expand Up @@ -1173,7 +1168,7 @@ def updated_result(method_name: str, args: list[Value | int | bytes]) -> Registe
)
_assert_index_in_bounds(context, index, array_length, source_location)

element_size = get_arc4_fixed_bit_size(wtype.element_type)
element_size = _get_arc4_fixed_bit_size(wtype.element_type)
dynamic_offset = 0 if isinstance(wtype, wtypes.ARC4StaticArray) else 2
if element_size == 1:
dynamic_offset *= 8
Expand Down Expand Up @@ -1384,7 +1379,7 @@ def _get_arc4_array_tail(
array_head_and_tail: Value,
source_location: SourceLocation,
) -> Value:
if not is_arc4_dynamic_size(element_wtype):
if is_arc4_static_size(element_wtype):
# no header for static sized elements
return array_head_and_tail

Expand Down Expand Up @@ -1606,3 +1601,46 @@ def extract3(
source_location=self.source_location,
)
return result


def is_arc4_dynamic_size(wtype: wtypes.ARC4Type) -> bool:
match wtype:
case wtypes.ARC4DynamicArray():
return True
case wtypes.ARC4StaticArray(element_type=element_type):
return is_arc4_dynamic_size(element_type)
case wtypes.ARC4Tuple(types=types) | wtypes.ARC4Struct(types=types):
return any(map(is_arc4_dynamic_size, types))
return False


def is_arc4_static_size(wtype: wtypes.ARC4Type) -> bool:
return not is_arc4_dynamic_size(wtype)


def _get_arc4_fixed_bit_size(wtype: wtypes.ARC4Type) -> int:
if is_arc4_dynamic_size(wtype):
raise InternalError(f"Cannot get fixed bit size for a dynamic ABI type: {wtype}")
match wtype:
case wtypes.arc4_bool_wtype:
return 1
case wtypes.ARC4UIntN(n=n) | wtypes.ARC4UFixedNxM(n=n):
return n
case wtypes.ARC4StaticArray(element_type=element_type, array_size=array_size):
el_size = _get_arc4_fixed_bit_size(element_type)
return round_bits_to_nearest_bytes(array_size * el_size)
case wtypes.ARC4Tuple(types=types) | wtypes.ARC4Struct(types=types):
return _determine_arc4_tuple_head_size(types, round_end_result=True)
raise InternalError(f"Unexpected ABI wtype: {wtype}")


def _determine_arc4_tuple_head_size(
types: Sequence[wtypes.ARC4Type], *, round_end_result: bool
) -> int:
bit_size = 0
for t, next_t in zip_longest(types, types[1:]):
size = 16 if is_arc4_dynamic_size(t) else _get_arc4_fixed_bit_size(t)
bit_size += size
if t == wtypes.arc4_bool_wtype and next_t != t and (round_end_result or next_t):
bit_size = round_bits_to_nearest_bytes(bit_size)
return bit_size
6 changes: 2 additions & 4 deletions src/puya/ir/builder/assignment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typing
from collections.abc import Sequence

from puya import arc4_util, log
from puya import log
from puya.avm_type import AVMType
from puya.awst import (
nodes as awst_nodes,
Expand Down Expand Up @@ -154,9 +154,7 @@ def _handle_assignment(
)
if scalar_type == AVMType.bytes:
serialized_value = mat_value
if not (
isinstance(wtype, wtypes.ARC4Type) and arc4_util.is_arc4_static_size(wtype)
):
if not (isinstance(wtype, wtypes.ARC4Type) and arc4.is_arc4_static_size(wtype)):
context.block_builder.add(
Intrinsic(
op=AVMOp.box_del, args=[key_value], source_location=assignment_location
Expand Down

0 comments on commit a256d75

Please sign in to comment.