Skip to content

Commit

Permalink
fix PyTypes handling of inheritance
Browse files Browse the repository at this point in the history
  • Loading branch information
achidlow committed Nov 7, 2024
1 parent f6502b9 commit 0ec9fce
Show file tree
Hide file tree
Showing 65 changed files with 672 additions and 590 deletions.
40 changes: 17 additions & 23 deletions src/puyapy/awst_build/arc4_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _allowed_oca(name: object) -> OnCompletionAction | None:


def _is_arc4_struct(typ: pytypes.PyType) -> typing.TypeGuard[pytypes.StructType]:
if pytypes.ARC4StructBaseType not in typ.mro:
if not (pytypes.ARC4StructBaseType < typ):
return False
if not isinstance(typ, pytypes.StructType):
raise InternalError(
Expand Down Expand Up @@ -357,14 +357,6 @@ def pytype_to_arc4_pytype(
match pytype:
case pytypes.BoolType:
return pytypes.ARC4BoolType
case pytypes.UInt64Type:
return pytypes.ARC4UIntN_Aliases[64]
case pytypes.BigUIntType:
return pytypes.ARC4UIntN_Aliases[512]
case pytypes.BytesType:
return pytypes.ARC4DynamicBytesType
case pytypes.StringType:
return pytypes.ARC4StringType
case pytypes.NamedTupleType():
return pytypes.StructType(
base=pytypes.ARC4StructBaseType,
Expand All @@ -379,21 +371,23 @@ def pytype_to_arc4_pytype(
return pytypes.GenericARC4TupleType.parameterise(
[pytype_to_arc4_pytype(t, on_error) for t in pytype.items], pytype.source_location
)

case (
pytypes.NoneType
| pytypes.ApplicationType
| pytypes.AssetType
| pytypes.AccountType
| pytypes.GroupTransactionBaseType
):
return pytype
case maybe_gtxn if maybe_gtxn in pytypes.GroupTransactionTypes.values():
case pytypes.NoneType | pytypes.GroupTransactionType():
return pytype
case pytypes.PyType(wtype=wtypes.ARC4Type()):
return pytype
case unsupported:
return on_error(unsupported)

if pytypes.UInt64Type <= pytype:
return pytypes.ARC4UIntN_Aliases[64]
elif pytypes.BigUIntType <= pytype:
return pytypes.ARC4UIntN_Aliases[512]
elif pytypes.BytesType <= pytype:
return pytypes.ARC4DynamicBytesType
elif pytypes.StringType <= pytype:
return pytypes.ARC4StringType
elif pytype.is_type_or_subtype(
pytypes.ApplicationType, pytypes.AssetType, pytypes.AccountType
) or isinstance(pytype.wtype, wtypes.ARC4Type):
return pytype
else:
return on_error(pytype)


_UINT_REGEX = re.compile(r"^uint(?P<n>[0-9]+)$")
Expand Down
13 changes: 7 additions & 6 deletions src/puyapy/awst_build/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def add_program_method(
name: str,
body: Sequence[awst_nodes.Statement],
*,
return_type: pytypes.PyType = pytypes.BoolType,
return_type: pytypes.RuntimeType = pytypes.BoolType,
) -> None:
result.symbols[name] = pytypes.FuncType(
name=".".join((_ARC4_CONTRACT_BASE_CREF, name)),
Expand Down Expand Up @@ -670,7 +670,7 @@ def _build_symbols_and_state(
if pytyp and not isinstance(pytyp, pytypes.FuncType):
definition = None
if isinstance(pytyp, pytypes.StorageProxyType):
wtypes.validate_persistable(pytyp.content.wtype, node_loc)
wtypes.validate_persistable(pytyp.content_wtype, node_loc)
match pytyp.generic:
case pytypes.GenericLocalStateType:
kind = AppStorageKind.account_local
Expand All @@ -683,13 +683,14 @@ def _build_symbols_and_state(
case _:
raise InternalError(f"unhandled StorageProxyType: {pytyp}", node_loc)
elif isinstance(pytyp, pytypes.StorageMapProxyType):
wtypes.validate_persistable(pytyp.key.wtype, node_loc)
wtypes.validate_persistable(pytyp.content.wtype, node_loc)
wtypes.validate_persistable(pytyp.key_wtype, node_loc)
wtypes.validate_persistable(pytyp.content_wtype, node_loc)
if pytyp.generic != pytypes.GenericBoxMapType:
raise InternalError(f"unhandled StorageMapProxyType: {pytyp}", node_loc)
kind = AppStorageKind.box
else: # global state, direct
wtypes.validate_persistable(pytyp.wtype, node_loc)
wtype = pytyp.checked_wtype(node_loc)
wtypes.validate_persistable(wtype, node_loc)
key = awst_nodes.BytesConstant(
value=name.encode("utf8"),
encoding=BytesEncoding.utf8,
Expand All @@ -701,7 +702,7 @@ def _build_symbols_and_state(
source_location=node_loc,
member_name=name,
kind=kind,
storage_wtype=pytyp.wtype,
storage_wtype=wtype,
key_wtype=None,
key=key,
description=None,
Expand Down
4 changes: 3 additions & 1 deletion src/puyapy/awst_build/eb/_bytes_backed.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def call(
args, pytypes.BytesType, location, resolve_literal=True
)
result_expr = ReinterpretCast(
expr=arg.resolve(), wtype=self.result_type.wtype, source_location=location
expr=arg.resolve(),
wtype=self.result_type.checked_wtype(location),
source_location=location,
)
return builder_for_instance(self.result_type, result_expr)

Expand Down
14 changes: 7 additions & 7 deletions src/puyapy/awst_build/eb/_expect.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from puyapy.awst_build import pytypes
from puyapy.awst_build.eb._utils import dummy_value
from puyapy.awst_build.eb.interface import InstanceBuilder, NodeBuilder
from puyapy.awst_build.utils import is_type_or_subtype, maybe_resolve_literal
from puyapy.awst_build.utils import maybe_resolve_literal

_T = typing.TypeVar("_T")
_TBuilder = typing.TypeVar("_TBuilder", bound=NodeBuilder)
Expand All @@ -35,7 +35,7 @@ def at_most_one_arg_of_type(
first, *rest = args
if rest:
logger.error(f"expected at most 1 argument, got {len(args)}", location=location)
if isinstance(first, InstanceBuilder) and is_type_or_subtype(first.pytype, of_any=valid_types):
if isinstance(first, InstanceBuilder) and first.pytype.is_type_or_subtype(*valid_types):
return first
return not_this_type(first, default=default_none)

Expand Down Expand Up @@ -111,7 +111,7 @@ def exactly_one_arg(

def exactly_one_arg_of_type(
args: Sequence[NodeBuilder],
pytype: pytypes.PyType,
expected: pytypes.PyType,
location: SourceLocation,
*,
default: Callable[[str, SourceLocation], _T],
Expand All @@ -126,8 +126,8 @@ def exactly_one_arg_of_type(
if rest:
logger.error(f"expected 1 argument, got {len(args)}", location=location)
if resolve_literal:
first = maybe_resolve_literal(first, pytype)
if isinstance(first, InstanceBuilder) and is_type_or_subtype(first.pytype, of=pytype):
first = maybe_resolve_literal(first, expected)
if isinstance(first, InstanceBuilder) and expected <= first.pytype:
return first
return not_this_type(first, default=default)

Expand Down Expand Up @@ -183,8 +183,8 @@ def argument_of_type(
if resolve_literal:
builder = maybe_resolve_literal(builder, target_type)

if isinstance(builder, InstanceBuilder) and is_type_or_subtype(
builder.pytype, of_any=(target_type, *additional_types)
if isinstance(builder, InstanceBuilder) and builder.pytype.is_type_or_subtype(
target_type, *additional_types
):
return builder
return not_this_type(builder, default=default)
Expand Down
2 changes: 1 addition & 1 deletion src/puyapy/awst_build/eb/_literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, value: ConstantValue, source_location: SourceLocation):
self._value = value
match value:
case bool():
typ = pytypes.BoolType
typ: pytypes.PyType = pytypes.BoolType
case int():
typ = pytypes.IntLiteralType
case str():
Expand Down
11 changes: 6 additions & 5 deletions src/puyapy/awst_build/eb/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def dummy_value(pytype: pytypes.PyType, location: SourceLocation) -> InstanceBui
from puyapy.awst_build.eb._literals import LiteralBuilderImpl

return LiteralBuilderImpl(pytype.python_type(), location)
expr = VarExpression(name="", wtype=pytype.wtype, source_location=location)
expr = VarExpression(name="", wtype=pytype.checked_wtype(location), source_location=location)
return builder_for_instance(pytype, expr)


Expand Down Expand Up @@ -79,14 +79,15 @@ def constant_bool_and_error(

def compare_bytes(
*,
lhs: InstanceBuilder,
self: InstanceBuilder,
op: BuilderComparisonOp,
rhs: InstanceBuilder,
other: InstanceBuilder,
source_location: SourceLocation,
) -> InstanceBuilder:
if rhs.pytype != lhs.pytype:
# defer to most derived type if not equal
if not (other.pytype <= self.pytype):
return NotImplemented
return _compare_expr_bytes_unchecked(lhs.resolve(), op, rhs.resolve(), source_location)
return _compare_expr_bytes_unchecked(self.resolve(), op, other.resolve(), source_location)


def compare_expr_bytes(
Expand Down
19 changes: 8 additions & 11 deletions src/puyapy/awst_build/eb/arc4/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,7 @@
resolve_negative_literal_index,
)
from puyapy.awst_build.eb.factories import builder_for_instance
from puyapy.awst_build.eb.interface import (
BuilderComparisonOp,
InstanceBuilder,
NodeBuilder,
)
from puyapy.awst_build.eb.interface import BuilderComparisonOp, InstanceBuilder, NodeBuilder

logger = log.get_logger(__name__)

Expand Down Expand Up @@ -65,7 +61,7 @@ def abi_expr_from_log(
) -> Expression:
tmp_value = value.single_eval().resolve()
arc4_value = intrinsic_factory.extract(
tmp_value, start=4, loc=location, result_type=typ.wtype
tmp_value, start=4, loc=location, result_type=typ.checked_wtype(location)
)
arc4_prefix = intrinsic_factory.extract(tmp_value, start=0, length=4, loc=location)
arc4_prefix_is_valid = compare_expr_bytes(
Expand Down Expand Up @@ -120,15 +116,16 @@ def call(
def arc4_bool_bytes(
builder: InstanceBuilder, false_bytes: bytes, location: SourceLocation, *, negate: bool
) -> InstanceBuilder:
lhs = builder.resolve()
false_value = BytesConstant(
value=false_bytes,
encoding=BytesEncoding.base16,
wtype=builder.pytype.wtype,
wtype=lhs.wtype,
source_location=location,
)
return compare_expr_bytes(
op=BuilderComparisonOp.eq if negate else BuilderComparisonOp.ne,
lhs=builder.resolve(),
lhs=lhs,
rhs=false_value,
source_location=location,
)
Expand All @@ -137,7 +134,7 @@ def arc4_bool_bytes(
class _ARC4ArrayExpressionBuilder(BytesBackedInstanceExpressionBuilder[pytypes.ArrayType], ABC):
@typing.override
def iterate(self) -> Expression:
if not self.pytype.items.wtype.immutable:
if not self.pytype.items_wtype.immutable:
# this case is an error raised during AWST validation
# adding a front end specific message here to compliment the error message
# raise across all front ends
Expand All @@ -158,7 +155,7 @@ def index(self, index: InstanceBuilder, location: SourceLocation) -> InstanceBui
result_expr = IndexExpression(
base=self.resolve(),
index=index.resolve(),
wtype=self.pytype.items.wtype,
wtype=self.pytype.items_wtype,
source_location=location,
)
return builder_for_instance(self.pytype.items, result_expr)
Expand All @@ -180,7 +177,7 @@ def member_access(self, name: str, location: SourceLocation) -> NodeBuilder:
def compare(
self, other: InstanceBuilder, op: BuilderComparisonOp, location: SourceLocation
) -> InstanceBuilder:
return compare_bytes(lhs=self, op=op, rhs=other, source_location=location)
return compare_bytes(self=self, op=op, other=other, source_location=location)

@typing.override
@typing.final
Expand Down
38 changes: 10 additions & 28 deletions src/puyapy/awst_build/eb/arc4/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from puyapy.awst_build.eb.factories import builder_for_type
from puyapy.awst_build.eb.interface import (
InstanceBuilder,
LiteralBuilder,
NodeBuilder,
StaticSizedCollectionBuilder,
)
Expand Down Expand Up @@ -83,25 +82,15 @@ def convert_args(


def _gtxn_to_itxn(pytype: pytypes.PyType) -> pytypes.PyType:
if (
isinstance(pytype, pytypes.TransactionRelatedType)
and pytypes.GroupTransactionBaseType in pytype.mro
):
txn_type = pytype.transaction_type
return pytypes.InnerTransactionFieldsetTypes[txn_type]
if isinstance(pytype, pytypes.GroupTransactionType):
return pytypes.InnerTransactionFieldsetTypes[pytype.transaction_type]
return pytype


def get_arc4_signature(
method: NodeBuilder, native_args: Sequence[NodeBuilder], loc: SourceLocation
) -> tuple[str, ARC4Signature]:
method = expect.argument_of_type(method, pytypes.StrLiteralType, default=expect.default_raise)
match method:
case LiteralBuilder(value=str(method_sig)):
pass
case _:
raise CodeError("method selector must be a simple str literal", method.source_location)

method_sig = expect.simple_string_literal(method, default=expect.default_raise)
method_name, maybe_args, maybe_returns = _split_signature(method_sig, method.source_location)
if maybe_args is None:
arg_types = [
Expand Down Expand Up @@ -145,16 +134,14 @@ def on_error(invalid_pytype: pytypes.PyType) -> typing.Never:


def _inner_transaction_type_matches(instance: pytypes.PyType, target: pytypes.PyType) -> bool:
from puya.awst.wtypes import WInnerTransactionFields

if not isinstance(instance.wtype, WInnerTransactionFields):
if not isinstance(instance, pytypes.InnerTransactionFieldsetType):
return False
if not isinstance(target.wtype, WInnerTransactionFields):
if not isinstance(target, pytypes.InnerTransactionFieldsetType):
return False
return (
instance.wtype == target.wtype
or instance.wtype.transaction_type is None
or target.wtype.transaction_type is None
instance.transaction_type == target.transaction_type
or instance.transaction_type is None
or target.transaction_type is None
)


Expand All @@ -165,7 +152,7 @@ def _implicit_arc4_conversion(

instance = expect.instance_builder(operand, default=expect.default_dummy_value(target_type))
instance = _maybe_resolve_arc4_literal(instance, target_type)
if instance.pytype == target_type:
if target_type <= instance.pytype:
return instance
target_wtype = target_type.wtype
if isinstance(target_type, pytypes.TransactionRelatedType):
Expand All @@ -189,7 +176,7 @@ def _implicit_arc4_conversion(
location=instance.source_location,
)
return dummy_value(target_type, instance.source_location)
if not target_wtype.can_encode_type(instance.pytype.wtype):
if not target_wtype.can_encode_type(instance.pytype.checked_wtype(instance.source_location)):
logger.error(
f"cannot encode {instance.pytype} to {target_type}", location=instance.source_location
)
Expand Down Expand Up @@ -278,8 +265,3 @@ def _split_signature(
if not name or not _VALID_NAME_PATTERN.match(name):
logger.error(f"invalid signature: {name=}", location=location)
return name, args, returns


def no_literal_items(array_type: pytypes.ArrayType, location: SourceLocation) -> None:
if isinstance(array_type.items, pytypes.LiteralOnlyType):
raise CodeError("arrays of literals are not supported", location)
Loading

0 comments on commit 0ec9fce

Please sign in to comment.