Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support _replace method on ARC4Struct #379

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
arc4_types/Arc4RefTypes 85 46 - | 32 27 -
arc4_types/Arc4StringTypes 349 35 - | 149 13 -
arc4_types/Arc4StructsFromAnotherModule 67 12 - | 49 6 -
arc4_types/Arc4StructsType 386 48 - | 258 16 -
arc4_types/Arc4StructsType 424 48 - | 284 16 -
arc4_types/Arc4TuplesType 938 8 - | 644 4 -
arc4_types/MutableParams2 318 193 48 | 185 92 23
arc_28/EventEmitter 172 124 102 | 92 58 48
Expand Down Expand Up @@ -138,4 +138,4 @@
unssa/UnSSA 420 266 - | 237 153 -
voting/VotingRoundApp 1584 1426 1415 | 725 624 625
with_reentrancy/WithReentrancy 245 214 - | 126 108 -
Total 70482 39117 36225 | 33270 18249 16976
Total 70520 39117 36225 | 33296 18249 16976
54 changes: 52 additions & 2 deletions src/puyapy/awst_build/eb/arc4/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

from puya import log
from puya.awst import wtypes
from puya.awst.nodes import Expression, FieldExpression, NewStruct
from puya.awst.nodes import Copy, Expression, FieldExpression, NewStruct
from puya.parse import SourceLocation
from puyapy.awst_build import pytypes
from puyapy.awst_build.eb import _expect as expect
from puyapy.awst_build.eb._base import NotIterableInstanceExpressionBuilder
from puyapy.awst_build.eb._base import FunctionBuilder, NotIterableInstanceExpressionBuilder
from puyapy.awst_build.eb._bytes_backed import (
BytesBackedInstanceExpressionBuilder,
BytesBackedTypeBuilder,
Expand Down Expand Up @@ -78,6 +78,8 @@ def member_access(self, name: str, location: SourceLocation) -> NodeBuilder:
return builder_for_instance(field, result_expr)
case "copy":
return CopyBuilder(self.resolve(), location, self.pytype)
case "_replace":
return _Replace(self, self.pytype, location)
case _:
return super().member_access(name, location)

Expand All @@ -88,3 +90,51 @@ def compare(

def bool_eval(self, location: SourceLocation, *, negate: bool = False) -> InstanceBuilder:
return constant_bool_and_error(value=True, location=location, negate=negate)


class _Replace(FunctionBuilder):
def __init__(
self,
instance: ARC4StructExpressionBuilder,
struct_type: pytypes.StructType,
location: SourceLocation,
):
super().__init__(location)
self.instance = instance
self.struct_type = struct_type

@typing.override
def call(
self,
args: Sequence[NodeBuilder],
arg_kinds: list[mypy.nodes.ArgKind],
arg_names: list[str | None],
location: SourceLocation,
) -> InstanceBuilder:
pytype = self.struct_type
field_mapping, _ = get_arg_mapping(
optional_kw_only=list(pytype.fields),
args=args,
arg_names=arg_names,
call_location=location,
raise_on_missing=False,
)
base_expr = self.instance.single_eval().resolve()
values = dict[str, Expression]()
for field_name, field_pytype in pytype.fields.items():
new_value = field_mapping.get(field_name)
if new_value is not None:
item_builder = expect.argument_of_type_else_dummy(new_value, field_pytype)
item = item_builder.resolve()
else:
field_wtype = field_pytype.checked_wtype(location)
item = FieldExpression(base=base_expr, name=field_name, source_location=location)
if not field_wtype.immutable:
logger.error(
f"mutable field {field_name!r} requires explicit copy", location=location
)
# implicitly create a copy node so that there is only one error
item = Copy(value=item, source_location=location)
values[field_name] = item
new_tuple = NewStruct(values=values, wtype=pytype.wtype, source_location=location)
return ARC4StructExpressionBuilder(new_tuple, pytype)
2 changes: 1 addition & 1 deletion src/puyapy/awst_build/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ class StructType(RuntimeType):
converter=immutabledict, validator=[attrs.validators.min_len(1)]
)
frozen: bool
wtype: wtypes.WType
wtype: wtypes.ARC4Struct | wtypes.WStructType
source_location: SourceLocation | None
generic: None = None
desc: str | None = None
Expand Down
5 changes: 5 additions & 0 deletions stubs/algopy-stubs/arc4.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,11 @@ class Struct(metaclass=_StructMeta):
def copy(self) -> typing.Self:
"""Create a copy of this struct"""

def _replace(self, **kwargs: typing.Any) -> typing.Self: # type: ignore[misc]
"""Return a new instance of the struct replacing specified fields with new values.

Note that any mutable fields must be explicitly copied to avoid aliasing."""

class ARC4Client(typing.Protocol):
"""Used to provide typed method signatures for ARC4 contracts"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"sources": [
"../structs.py"
],
"mappings": ";AAsC+B;;;;;;;;;;;;AAAX;;;;;;;;;;AACR;AADZ;AAAA;;;;;;;;;;;;AAGgB;;;AAER;AAcO;;AAAP",
"mappings": ";AAsC+B;;;;;;;;;;;;AAAX;;;;;;;;;;AACR;AADZ;AAAA;;;;;;;;;;;;AAGgB;;;AAER;AAkBO;;AAAP",
"op_pc_offset": 0,
"pc_events": {
"1": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ main_after_for@4:
// arc4_types/structs.py:44
// log(flags.bytes)
log
// arc4_types/structs.py:58
// arc4_types/structs.py:62
// return True
pushint 1 // 1
return
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"sources": [
"../structs.py"
],
"mappings": ";AA4De;;AAAP",
"mappings": ";AAgEe;;AAAP",
"op_pc_offset": 0,
"pc_events": {
"1": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

// test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program() -> uint64:
main:
// arc4_types/structs.py:61
// arc4_types/structs.py:65
// return True
pushint 1 // 1
return
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,24 @@ main test_cases.arc4_types.structs.Arc4StructsTypeContract.approval_program:
let no_copy#0: bytes = immutable#0
let tmp%5#0: bool = (== no_copy#0 immutable#0)
(assert tmp%5#0)
let tmp%6#0: bytes = (extract3 immutable#0 0u 8u) // on error: Index access is out of bounds
let current_tail_offset%7#0: uint64 = 16u
let encoded_tuple_buffer%20#0: bytes = 0x
let encoded_tuple_buffer%21#0: bytes = (concat encoded_tuple_buffer%20#0 tmp%6#0)
let encoded_tuple_buffer%22#0: bytes = (concat encoded_tuple_buffer%21#0 0x000000000000007b)
let immutable2#0: bytes = encoded_tuple_buffer%22#0
let reinterpret_biguint%0#0: biguint = (extract3 immutable2#0 8u 8u) // on error: Index access is out of bounds
let reinterpret_biguint%1#0: biguint = 0x000000000000007b
let tmp%7#0: bool = (b== reinterpret_biguint%0#0 reinterpret_biguint%1#0)
(assert tmp%7#0)
let reinterpret_biguint%2#0: biguint = (extract3 immutable2#0 0u 8u) // on error: Index access is out of bounds
let reinterpret_biguint%3#0: biguint = (extract3 immutable#0 0u 8u) // on error: Index access is out of bounds
let tmp%8#0: bool = (b== reinterpret_biguint%2#0 reinterpret_biguint%3#0)
(assert tmp%8#0)
return 1u

subroutine test_cases.arc4_types.structs.add(v1: bytes, v2: bytes) -> <bytes, bytes, bytes>:
block@0: // L64
block@0: // L68
let v1%is_original#0: bool = 1u
let v1%out#0: bytes = v1#0
let v2%is_original#0: bool = 1u
Expand All @@ -122,15 +136,15 @@ subroutine test_cases.arc4_types.structs.add(v1: bytes, v2: bytes) -> <bytes, by
return encoded_tuple_buffer%2#0 v1#0 v2#0

subroutine test_cases.arc4_types.structs.add_decimal(x: bytes, y: bytes) -> bytes:
block@0: // L86
block@0: // L90
let tmp%0#0: uint64 = (btoi x#0)
let tmp%1#0: uint64 = (btoi y#0)
let tmp%2#0: uint64 = (+ tmp%0#0 tmp%1#0)
let tmp%3#0: bytes = (itob tmp%2#0)
return tmp%3#0

subroutine test_cases.arc4_types.structs.check(flags: bytes) -> bytes:
block@0: // L72
block@0: // L76
let flags%is_original#0: bool = 1u
let flags%out#0: bytes = flags#0
let is_true%0#0: uint64 = (getbit flags#0 0u)
Expand All @@ -154,7 +168,7 @@ subroutine test_cases.arc4_types.structs.check(flags: bytes) -> bytes:
return flags%out#0

subroutine test_cases.arc4_types.structs.nested_decode(vector_flags: bytes) -> bytes:
block@0: // L80
block@0: // L84
let vector_flags%is_original#0: bool = 1u
let vector_flags%out#0: bytes = vector_flags#0
let tmp%0#0: bytes = (extract3 vector_flags#0 0u 16u) // on error: Index access is out of bounds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ main test_cases.arc4_types.structs.Arc4StructsTypeContract.approval_program:
return 1u

subroutine test_cases.arc4_types.structs.add_decimal(x: bytes, y: bytes) -> bytes:
block@0: // L86
block@0: // L90
let tmp%0#0: uint64 = (btoi x#0)
let tmp%1#0: uint64 = (btoi y#0)
let tmp%2#0: uint64 = (+ tmp%0#0 tmp%1#0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ main_after_for@4:
// arc4_types/structs.py:44
// log(flags.bytes)
log (𝕗) val#2,loop_counter%0#0 |
// arc4_types/structs.py:58
// arc4_types/structs.py:62
// return True
int 1 (𝕗) val#2,loop_counter%0#0 | 1
return (𝕗) val#2,loop_counter%0#0 |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ main_after_for@4:
// arc4_types/structs.py:44
// log(flags.bytes)
log
// arc4_types/structs.py:58
// arc4_types/structs.py:62
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ main_after_for@4:
// arc4_types/structs.py:44
// log(flags.bytes)
log
// arc4_types/structs.py:58
// arc4_types/structs.py:62
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ main_after_for@4:
// arc4_types/structs.py:44
// log(flags.bytes)
log
// arc4_types/structs.py:58
// arc4_types/structs.py:62
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ main_after_for@4:
// arc4_types/structs.py:44
// log(flags.bytes)
log
// arc4_types/structs.py:58
// arc4_types/structs.py:62
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ main_after_for@4:
// arc4_types/structs.py:44
// log(flags.bytes)
log (𝕗) val#2,loop_counter%0#0 |
// arc4_types/structs.py:58
// arc4_types/structs.py:62
// return True
int 1 (𝕗) val#2,loop_counter%0#0 | 1
return (𝕗) val#2,loop_counter%0#0 |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ main_after_for@4:
// arc4_types/structs.py:44
// log(flags.bytes)
log (𝕗) val#2,loop_counter%0#0 |
// arc4_types/structs.py:58
// arc4_types/structs.py:62
// return True
int 1 (𝕗) val#2,loop_counter%0#0 | 1
return (𝕗) val#2,loop_counter%0#0 |
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
main test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program:
block@0: // L60
block@0: // L64
return 1u
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
main test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program:
block@0: // L60
block@0: // L64
return 1u
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Op Stack (out)
// test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program() -> uint64:
main:
// arc4_types/structs.py:61
// arc4_types/structs.py:65
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Op Stack (out)
// test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program() -> uint64:
main:
// arc4_types/structs.py:61
// arc4_types/structs.py:65
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Op Stack (out)
// test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program() -> uint64:
main:
// arc4_types/structs.py:61
// arc4_types/structs.py:65
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Op Stack (out)
// test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program() -> uint64:
main:
// arc4_types/structs.py:61
// arc4_types/structs.py:65
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Op Stack (out)
// test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program() -> uint64:
main:
// arc4_types/structs.py:61
// arc4_types/structs.py:65
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Op Stack (out)
// test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program() -> uint64:
main:
// arc4_types/structs.py:61
// arc4_types/structs.py:65
// return True
int 1 1
return
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Op Stack (out)
// test_cases.arc4_types.structs.Arc4StructsTypeContract.clear_state_program() -> uint64:
main:
// arc4_types/structs.py:61
// arc4_types/structs.py:65
// return True
int 1 1
return
Expand Down
3 changes: 3 additions & 0 deletions test_cases/arc4_types/out/module.awst
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ contract Arc4StructsTypeContract
immutable: test_cases.arc4_types.structs.FrozenAndImmutable = new test_cases.arc4_types.structs.FrozenAndImmutable(one=12_arc4u64, two=34_arc4u64)
no_copy: test_cases.arc4_types.structs.FrozenAndImmutable = immutable
assert(no_copy == immutable)
immutable2: test_cases.arc4_types.structs.FrozenAndImmutable = new test_cases.arc4_types.structs.FrozenAndImmutable(one=immutable.one, two=123_arc4u64)
assert(reinterpret_cast<biguint>(immutable2.two) == reinterpret_cast<biguint>(123_arc4u64))
assert(reinterpret_cast<biguint>(immutable2.one) == reinterpret_cast<biguint>(immutable.one))
return true
}

Expand Down
Loading
Loading