diff --git a/tests/parser/test_selector_table.py b/tests/parser/test_selector_table.py index 9609e76ba0..180c0266bf 100644 --- a/tests/parser/test_selector_table.py +++ b/tests/parser/test_selector_table.py @@ -478,69 +478,72 @@ def test_dense_jumptable_bucket_size(n_methods, seed): assert n_buckets / n < 0.4 or n < 10 +@st.composite +def generate_methods(draw, max_calldata_bytes): + max_default_args = draw(st.integers(min_value=0, max_value=4)) + default_fn_mutability = draw(st.sampled_from(["", "@pure", "@view", "@nonpayable", "@payable"])) + + return ( + max_default_args, + default_fn_mutability, + draw( + st.lists( + st.tuples( + # function id: + st.integers(min_value=0), + # mutability: + st.sampled_from(["@pure", "@view", "@nonpayable", "@payable"]), + # n calldata words: + st.integers(min_value=0, max_value=max_calldata_bytes // 32), + # n bytes to strip from calldata + st.integers(min_value=1, max_value=4), + # n default args + st.integers(min_value=0, max_value=max_default_args), + ), + unique_by=lambda x: x[0], + min_size=1, + max_size=100, + ) + ), + ) + + @pytest.mark.parametrize("opt_level", list(OptimizationLevel)) # dense selector table packing boundaries at 256 and 65336 @pytest.mark.parametrize("max_calldata_bytes", [255, 256, 65336]) -@settings(max_examples=5, deadline=None) -@given( - max_default_args=st.integers(min_value=0, max_value=4), - default_fn_mutability=st.sampled_from(["", "@pure", "@view", "@nonpayable", "@payable"]), -) @pytest.mark.fuzzing def test_selector_table_fuzz( - max_calldata_bytes, - max_default_args, - opt_level, - default_fn_mutability, - w3, - get_contract, - assert_tx_failed, - get_logs, + max_calldata_bytes, opt_level, w3, get_contract, assert_tx_failed, get_logs ): - def abi_sig(seed, calldata_words, i, n_default_args): - args = [] if not calldata_words else [f"uint256[{calldata_words}]"] - args.extend(["uint256"] * n_default_args) - argstr = ",".join(args) - return f"foo{seed + i}({argstr})" + def abi_sig(func_id, calldata_words, n_default_args): + params = [] if not calldata_words else [f"uint256[{calldata_words}]"] + params.extend(["uint256"] * n_default_args) + paramstr = ",".join(params) + return f"foo{func_id}({paramstr})" - def generate_func_def(seed, mutability, calldata_words, i, n_default_args): + def generate_func_def(func_id, mutability, calldata_words, n_default_args): arglist = [] if not calldata_words else [f"x: uint256[{calldata_words}]"] for j in range(n_default_args): arglist.append(f"x{j}: uint256 = 0") args = ", ".join(arglist) - _log_return = f"log _Return({i})" if mutability == "@payable" else "" + _log_return = f"log _Return({func_id})" if mutability == "@payable" else "" return f""" @external {mutability} -def foo{seed + i}({args}) -> uint256: +def foo{func_id}({args}) -> uint256: {_log_return} - return {i} + return {func_id} """ - @given( - methods=st.lists( - st.tuples( - # seed: - st.integers(min_value=0, max_value=2**64 - 1), - # mutability: - st.sampled_from(["@pure", "@view", "@nonpayable", "@payable"]), - # n calldata words: - st.integers(min_value=0, max_value=max_calldata_bytes // 32), - # n bytes to strip from calldata - st.integers(min_value=1, max_value=4), - # n default args - st.integers(min_value=0, max_value=max_default_args), - ), - min_size=1, - max_size=100, - ) - ) - @settings(max_examples=25) - def _test(methods): + @given(_input=generate_methods(max_calldata_bytes)) + @settings(max_examples=125, deadline=None) + def _test(_input): + max_default_args, default_fn_mutability, methods = _input + func_defs = "\n".join( - generate_func_def(seed, mutability, calldata_words, i, n_default_args) - for i, (seed, mutability, calldata_words, _, n_default_args) in enumerate(methods) + generate_func_def(func_id, mutability, calldata_words, n_default_args) + for (func_id, mutability, calldata_words, _, n_default_args) in (methods) ) if default_fn_mutability == "": @@ -574,10 +577,8 @@ def __default__(): c = get_contract(code, override_opt_level=opt_level) - for i, (seed, mutability, n_calldata_words, n_strip_bytes, n_default_args) in enumerate( - methods - ): - funcname = f"foo{seed + i}" + for func_id, mutability, n_calldata_words, n_strip_bytes, n_default_args in methods: + funcname = f"foo{func_id}" func = getattr(c, funcname) for j in range(n_default_args + 1): @@ -585,9 +586,9 @@ def __default__(): args.extend([1] * j) # check the function returns as expected - assert func(*args) == i + assert func(*args) == func_id - method_id = utils.method_id(abi_sig(seed, n_calldata_words, i, j)) + method_id = utils.method_id(abi_sig(func_id, n_calldata_words, j)) argsdata = b"\x00" * (n_calldata_words * 32 + j * 32) @@ -595,7 +596,7 @@ def __default__(): if mutability == "@payable": tx = func(*args, transact={"value": 1}) (event,) = get_logs(tx, c, "_Return") - assert event.args.val == i + assert event.args.val == func_id else: hexstr = (method_id + argsdata).hex() txdata = {"to": c.address, "data": hexstr, "value": 1}