Skip to content

Commit

Permalink
feat[test]: add more coverage to abi_decode fuzzer tests (#4153)
Browse files Browse the repository at this point in the history
fuzz with `unwrap_tuple=False`
add fuzzing for structs

follow up to 69e5c05
  • Loading branch information
charles-cooper authored Jun 17, 2024
1 parent 69e5c05 commit 2d82a74
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 20 deletions.
124 changes: 105 additions & 19 deletions tests/functional/builtins/codegen/test_abi_decode_fuzz.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
IntegerT,
SArrayT,
StringT,
StructT,
TupleT,
VyperType,
_get_primitive_types,
_get_sequence_types,
)
from vyper.semantics.types.shortcuts import UINT256_T

from .abi_decode import DecodeError, spec_decode

Expand All @@ -39,7 +39,7 @@
continue
type_ctors.append(t)

complex_static_ctors = [SArrayT, TupleT]
complex_static_ctors = [SArrayT, TupleT, StructT]
complex_dynamic_ctors = [DArrayT]
leaf_ctors = [t for t in type_ctors if t not in _get_sequence_types().values()]
static_leaf_ctors = [t for t in leaf_ctors if t._is_prim_word]
Expand All @@ -50,10 +50,12 @@

@st.composite
# max type nesting
def vyper_type(draw, nesting=3, skip=None):
def vyper_type(draw, nesting=3, skip=None, source_fragments=None):
assert nesting >= 0

skip = skip or []
if source_fragments is None:
source_fragments = []

st_leaves = st.one_of(st.sampled_from(dynamic_leaf_ctors), st.sampled_from(static_leaf_ctors))
st_complex = st.one_of(
Expand All @@ -71,39 +73,52 @@ def vyper_type(draw, nesting=3, skip=None):
# note: maybe st.deferred is good here, we could define it with
# mutual recursion
def _go(skip=skip):
return draw(vyper_type(nesting=nesting - 1, skip=skip))
_, typ = draw(vyper_type(nesting=nesting - 1, skip=skip, source_fragments=source_fragments))
return typ

def finalize(typ):
return source_fragments, typ

if t in (BytesT, StringT):
# arbitrary max_value
bound = draw(st.integers(min_value=1, max_value=1024))
return t(bound)
return finalize(t(bound))

if t == SArrayT:
subtype = _go(skip=[TupleT, BytesT, StringT])
bound = draw(st.integers(min_value=1, max_value=6))
return t(subtype, bound)
return finalize(t(subtype, bound))
if t == DArrayT:
subtype = _go(skip=[TupleT])
bound = draw(st.integers(min_value=1, max_value=16))
return t(subtype, bound)
return finalize(t(subtype, bound))

if t == TupleT:
# zero-length tuples are not allowed in vyper
n = draw(st.integers(min_value=1, max_value=6))
subtypes = [_go() for _ in range(n)]
return TupleT(subtypes)
return finalize(TupleT(subtypes))

if t == StructT:
n = draw(st.integers(min_value=1, max_value=6))
subtypes = {f"x{i}": _go() for i in range(n)}
_id = len(source_fragments) # poor man's unique id
name = f"MyStruct{_id}"
typ = StructT(name, subtypes)
source_fragments.append(typ.def_source_str())
return finalize(StructT(name, subtypes))

if t in (BoolT, AddressT):
return t()
return finalize(t())

if t == IntegerT:
signed = draw(st.booleans())
bits = 8 * draw(st.integers(min_value=1, max_value=32))
return t(signed, bits)
return finalize(t(signed, bits))

if t == BytesM_T:
m = draw(st.integers(min_value=1, max_value=32))
return t(m)
return finalize(t(m))

raise RuntimeError("unreachable")

Expand All @@ -116,6 +131,9 @@ def _go(t):
if isinstance(typ, TupleT):
return tuple(_go(item_t) for item_t in typ.member_types)

if isinstance(typ, StructT):
return tuple(_go(item_t) for item_t in typ.tuple_members())

if isinstance(typ, SArrayT):
return [_go(typ.value_type) for _ in range(typ.length)]

Expand Down Expand Up @@ -294,6 +312,13 @@ def _finalize(): # little trick to save re-typing the arguments
num_dynamic_types = sum(s.num_dynamic_types for s in substats)
return _finalize()

if isinstance(typ, StructT):
substats = [_type_stats(t) for t in typ.tuple_members()]
nesting = 1 + max(s.nesting for s in substats)
breadth = max(len(typ.member_types), *[s.breadth for s in substats])
num_dynamic_types = sum(s.num_dynamic_types for s in substats)
return _finalize()

if isinstance(typ, DArrayT):
substat = _type_stats(typ.value_type)
nesting = 1 + substat.nesting
Expand Down Expand Up @@ -332,8 +357,8 @@ def payload_copier(get_contract_from_ir):
@pytest.mark.parametrize("_n", list(range(PARALLELISM)))
@hp.given(typ=vyper_type())
@hp.settings(max_examples=100, **_settings)
@hp.example(typ=DArrayT(DArrayT(UINT256_T, 2), 2))
def test_abi_decode_fuzz(_n, typ, get_contract, tx_failed, payload_copier):
def test_abi_decode_fuzz(_n, typ, get_contract, tx_failed, payload_copier, env):
source_fragments, typ = typ
# import time
# t0 = time.time()
# print("ENTER", typ)
Expand All @@ -350,12 +375,13 @@ def test_abi_decode_fuzz(_n, typ, get_contract, tx_failed, payload_copier):
# by bytes length check at function entry
type_bound = wrapped_type.abi_type.size_bound()
buffer_bound = type_bound + MAX_MUTATIONS
type_str = repr(typ) # annotation in vyper code
# TODO: intrinsic decode from staticcall/extcall
# TODO: _abi_decode from other sources (staticcall/extcall?)
# TODO: dirty the buffer
# TODO: check unwrap_tuple=False

preamble = "\n\n".join(source_fragments)
type_str = str(typ) # annotation in vyper code

code = f"""
{preamble}
@external
def run(xs: Bytes[{buffer_bound}]) -> {type_str}:
ret: {type_str} = abi_decode(xs, {type_str})
Expand All @@ -375,14 +401,20 @@ def run3(xs: Bytes[{buffer_bound}], copier: Foo) -> {type_str}:
assert len(xs) <= {type_bound}
return (extcall copier.bar(xs))
"""
try:
c = get_contract(code)
except EvmError as e:
if env.contract_size_limit_error in str(e):
hp.assume(False)
# print(code)
hp.note(code)
c = get_contract(code)

@hp.given(data=payload_from(wrapped_type))
@hp.settings(max_examples=100, **_settings)
def _fuzz(data):
hp.note(f"type: {typ}")
hp.note(f"abi_t: {wrapped_type.abi_type.selector_name()}")
hp.note(code)
hp.note(data.hex())

try:
Expand Down Expand Up @@ -414,3 +446,57 @@ def _fuzz(data):

# t1 = time.time()
# print(f"elapsed {t1 - t0}s")


@pytest.mark.parametrize("_n", list(range(PARALLELISM)))
@hp.given(typ=vyper_type())
@hp.settings(max_examples=100, **_settings)
def test_abi_decode_no_wrap_fuzz(_n, typ, get_contract, tx_failed, env):
source_fragments, typ = typ
# import time
# t0 = time.time()
# print("ENTER", typ)

stats = _type_stats(typ)
hp.target(stats.num_dynamic_types)

# add max_mutations bytes worth of padding so we don't just get caught
# by bytes length check at function entry
type_bound = typ.abi_type.size_bound()
buffer_bound = type_bound + MAX_MUTATIONS

type_str = str(typ) # annotation in vyper code
preamble = "\n\n".join(source_fragments)

code = f"""
{preamble}
@external
def run(xs: Bytes[{buffer_bound}]) -> {type_str}:
ret: {type_str} = abi_decode(xs, {type_str}, unwrap_tuple=False)
return ret
"""
try:
c = get_contract(code)
except EvmError as e:
if env.contract_size_limit_error in str(e):
hp.assume(False)

@hp.given(data=payload_from(typ))
@hp.settings(max_examples=100, **_settings)
def _fuzz(data):
hp.note(code)
hp.note(data.hex())
try:
expected = spec_decode(typ, data)
hp.note(f"expected {expected}")
assert expected == c.run(data)
except DecodeError:
hp.note("expect failure")
with tx_failed(EvmError):
c.run(data)

_fuzz()

# t1 = time.time()
# print(f"elapsed {t1 - t0}s")
11 changes: 10 additions & 1 deletion vyper/semantics/types/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,11 @@ def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT":

return cls(struct_name, members, ast_def=base_node)

def __str__(self):
return f"{self._id}"

def __repr__(self):
return f"{self._id} declaration object"
return f"{self._id} {self.members}"

def _try_fold(self, node):
if len(node.args) != 1:
Expand All @@ -384,6 +387,12 @@ def _try_fold(self, node):
# it can't be reduced, but this lets upstream code know it's constant
return node

def def_source_str(self):
ret = f"struct {self._id}:\n"
for k, v in self.member_types.items():
ret += f" {k}: {v}\n"
return ret

@property
def size_in_bytes(self):
return sum(i.size_in_bytes for i in self.member_types.values())
Expand Down

0 comments on commit 2d82a74

Please sign in to comment.