diff --git a/tests/builtins/folding/test_keccak_sha.py b/tests/builtins/folding/test_keccak_sha.py index 0ea293ea29..c6ac7067ed 100644 --- a/tests/builtins/folding/test_keccak_sha.py +++ b/tests/builtins/folding/test_keccak_sha.py @@ -47,13 +47,13 @@ def foo(a: Bytes[100]) -> bytes32: @pytest.mark.fuzzing -@given(value=st.binary(min_size=32, max_size=32)) +@given(value=st.binary(min_size=1, max_size=32)) @settings(max_examples=50) @pytest.mark.parametrize("fn_name", ["keccak256", "sha256"]) def test_hex(get_contract, value, fn_name): source = f""" @external -def foo(a: Bytes[100]) -> bytes32: +def foo(a: bytes{len(value)}) -> bytes32: return {fn_name}(a) """ contract = get_contract(source) diff --git a/tests/parser/functions/test_keccak256.py b/tests/parser/functions/test_keccak256.py index a0af16539e..6a36aa78ee 100644 --- a/tests/parser/functions/test_keccak256.py +++ b/tests/parser/functions/test_keccak256.py @@ -38,10 +38,13 @@ def foo(inp: Bytes[100]) -> bool: def test_hash_code3(get_contract_with_gas_estimation): hash_code3 = """ test: Bytes[100] +test1: Bytes[1] @external def set_test(inp: Bytes[100]): self.test = inp + if len(inp) > 0: + self.test1 = slice(inp, 0, 1) @external def tryy(inp: Bytes[100]) -> bool: @@ -60,6 +63,9 @@ def trymem(inp: Bytes[100]) -> bool: def try32(inp: bytes32) -> bool: return keccak256(inp) == keccak256(self.test) +@external +def try1(inp: bytes1) -> bool: + return keccak256(inp) == keccak256(self.test1) """ c = get_contract_with_gas_estimation(hash_code3) c.set_test(b"", transact={}) @@ -71,10 +77,12 @@ def try32(inp: bytes32) -> bool: assert c.tryy(b"") is False assert c.tryy(b"cow") is True assert c.tryy_str("cow") is True + assert c.try1(b"c") is True c.set_test(b"\x35" * 32, transact={}) assert c.tryy(b"\x35" * 32) is True assert c.trymem(b"\x35" * 32) is True assert c.try32(b"\x35" * 32) is True + assert c.try1(b"\x35") is True assert c.tryy(b"\x35" * 33) is False c.set_test(b"\x35" * 33, transact={}) assert c.tryy(b"\x35" * 32) is False diff --git a/tests/parser/functions/test_sha256.py b/tests/parser/functions/test_sha256.py index 442389d0fc..5a03075cc9 100644 --- a/tests/parser/functions/test_sha256.py +++ b/tests/parser/functions/test_sha256.py @@ -32,6 +32,19 @@ def bar() -> (bytes32 , bytes32): assert c.bar() == [h, h] +def test_sha256_bytes1(get_contract_with_gas_estimation): + code = """ +@external +def bar(a: bytes1) -> bytes32: + return sha256(a) + """ + + c = get_contract_with_gas_estimation(code) + + test_val = b"b" + assert c.bar(test_val) == hashlib.sha256(test_val).digest() + + def test_sha256_bytes32(get_contract_with_gas_estimation): code = """ @external diff --git a/vyper/ast/folding.py b/vyper/ast/folding.py index 1e6b88720e..c74dba8861 100644 --- a/vyper/ast/folding.py +++ b/vyper/ast/folding.py @@ -123,7 +123,6 @@ def replace_builtin_functions(vyper_module: vy_ast.Module) -> int: except UnfoldableNode: continue - # print(node._metadata["type"]) new_node._metadata["type"] = node._metadata["type"] changed_nodes += 1 diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 3d7bbb4c02..c870478f2e 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -597,7 +597,7 @@ def build_IR(self, expr, context): class Keccak256(BuiltinFunction): _id = "keccak256" # TODO allow any BytesM_T - _inputs = [("value", (BytesT.any(), BYTES32_T, StringT.any()))] + _inputs = [("value", (BytesT.any(), BytesM_T.any(), StringT.any()))] _return_type = BYTES32_T def evaluate(self, node): @@ -612,7 +612,7 @@ def evaluate(self, node): arg_typ = self.infer_arg_types(node).pop() if isinstance(arg_typ, StringT): value = value.encode() - elif arg_typ == BYTES32_T: + elif isinstance(arg_typ, BytesM_T): length = len(value) // 2 - 1 value = int(value, 16).to_bytes(length, "big") @@ -648,7 +648,7 @@ def _make_sha256_call(inp_start, inp_len, out_start, out_len): class Sha256(BuiltinFunction): _id = "sha256" - _inputs = [("value", (BYTES32_T, BytesT.any(), StringT.any()))] + _inputs = [("value", (BytesM_T.any(), BytesT.any(), StringT.any()))] _return_type = BYTES32_T def evaluate(self, node): @@ -663,7 +663,7 @@ def evaluate(self, node): arg_typ = self.infer_arg_types(node).pop() if isinstance(arg_typ, StringT): value = value.encode() - elif arg_typ == BYTES32_T: + elif isinstance(arg_typ, BytesM_T): length = len(value) // 2 - 1 value = int(value, 16).to_bytes(length, "big") @@ -679,15 +679,15 @@ def infer_arg_types(self, node, *args, **kwargs): @process_inputs def build_IR(self, expr, args, kwargs, context): sub = args[0] - # bytes32 input - if sub.typ == BYTES32_T: + # bytesM_T input + if isinstance(sub.typ, BytesM_T): return IRnode.from_list( [ "seq", ["mstore", MemoryPositions.FREE_VAR_SPACE, sub], _make_sha256_call( inp_start=MemoryPositions.FREE_VAR_SPACE, - inp_len=32, + inp_len=sub.typ.length, out_start=MemoryPositions.FREE_VAR_SPACE, out_len=32, ), diff --git a/vyper/codegen/keccak256_helper.py b/vyper/codegen/keccak256_helper.py index 9c5f5eb1d0..88ad4f7ef3 100644 --- a/vyper/codegen/keccak256_helper.py +++ b/vyper/codegen/keccak256_helper.py @@ -4,12 +4,13 @@ from vyper.codegen.ir_node import IRnode from vyper.exceptions import CompilerPanic from vyper.semantics.types.bytestrings import _BytestringT +from vyper.semantics.types.primitives import BytesM_T from vyper.semantics.types.shortcuts import BYTES32_T from vyper.utils import SHA3_BASE, SHA3_PER_WORD, MemoryPositions, bytes_to_int, keccak256 def _check_byteslike(typ): - if not isinstance(typ, _BytestringT) and typ != BYTES32_T: + if not isinstance(typ, (_BytestringT, BytesM_T)): # NOTE this may be checked at a higher level, but just be safe raise CompilerPanic("keccak256 only accepts bytes-like objects") @@ -26,14 +27,13 @@ def keccak256_helper(to_hash, context): if isinstance(to_hash, bytes): return IRnode.from_list(bytes_to_int(keccak256(to_hash)), typ=BYTES32_T) - # Can hash bytes32 objects - # TODO: Want to generalize to all bytes_M - if to_hash.typ == BYTES32_T: + # Can hash bytesM_T objects + if isinstance(to_hash.typ, BytesM_T): return IRnode.from_list( [ "seq", ["mstore", MemoryPositions.FREE_VAR_SPACE, to_hash], - ["sha3", MemoryPositions.FREE_VAR_SPACE, 32], + ["sha3", MemoryPositions.FREE_VAR_SPACE, to_hash.typ.length], ], typ=BYTES32_T, add_gas_estimate=_gas_bound(1),