Skip to content

Commit

Permalink
allow bytesm_t
Browse files Browse the repository at this point in the history
  • Loading branch information
tserg committed Oct 26, 2023
1 parent fd95436 commit 652199e
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 15 deletions.
4 changes: 2 additions & 2 deletions tests/builtins/folding/test_keccak_sha.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def foo(a: Bytes[100]) -> bytes32:


@pytest.mark.fuzzing
@given(value=st.binary(min_size=32, max_size=32))
@given(value=st.binary(min_size=1, max_size=32))
@settings(max_examples=50)
@pytest.mark.parametrize("fn_name", ["keccak256", "sha256"])
def test_hex(get_contract, value, fn_name):
source = f"""
@external
def foo(a: Bytes[100]) -> bytes32:
def foo(a: bytes{len(value)}) -> bytes32:
return {fn_name}(a)
"""
contract = get_contract(source)
Expand Down
8 changes: 8 additions & 0 deletions tests/parser/functions/test_keccak256.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,13 @@ def foo(inp: Bytes[100]) -> bool:
def test_hash_code3(get_contract_with_gas_estimation):
hash_code3 = """
test: Bytes[100]
test1: Bytes[1]
@external
def set_test(inp: Bytes[100]):
self.test = inp
if len(inp) > 0:
self.test1 = slice(inp, 0, 1)
@external
def tryy(inp: Bytes[100]) -> bool:
Expand All @@ -60,6 +63,9 @@ def trymem(inp: Bytes[100]) -> bool:
def try32(inp: bytes32) -> bool:
return keccak256(inp) == keccak256(self.test)
@external
def try1(inp: bytes1) -> bool:
return keccak256(inp) == keccak256(self.test1)
"""
c = get_contract_with_gas_estimation(hash_code3)
c.set_test(b"", transact={})
Expand All @@ -71,10 +77,12 @@ def try32(inp: bytes32) -> bool:
assert c.tryy(b"") is False
assert c.tryy(b"cow") is True
assert c.tryy_str("cow") is True
assert c.try1(b"c") is True
c.set_test(b"\x35" * 32, transact={})
assert c.tryy(b"\x35" * 32) is True
assert c.trymem(b"\x35" * 32) is True
assert c.try32(b"\x35" * 32) is True
assert c.try1(b"\x35") is True
assert c.tryy(b"\x35" * 33) is False
c.set_test(b"\x35" * 33, transact={})
assert c.tryy(b"\x35" * 32) is False
Expand Down
13 changes: 13 additions & 0 deletions tests/parser/functions/test_sha256.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@ def bar() -> (bytes32 , bytes32):
assert c.bar() == [h, h]


def test_sha256_bytes1(get_contract_with_gas_estimation):
code = """
@external
def bar(a: bytes1) -> bytes32:
return sha256(a)
"""

c = get_contract_with_gas_estimation(code)

test_val = b"b"
assert c.bar(test_val) == hashlib.sha256(test_val).digest()


def test_sha256_bytes32(get_contract_with_gas_estimation):
code = """
@external
Expand Down
1 change: 0 additions & 1 deletion vyper/ast/folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def replace_builtin_functions(vyper_module: vy_ast.Module) -> int:
except UnfoldableNode:
continue

# print(node._metadata["type"])
new_node._metadata["type"] = node._metadata["type"]

changed_nodes += 1
Expand Down
14 changes: 7 additions & 7 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def build_IR(self, expr, context):
class Keccak256(BuiltinFunction):
_id = "keccak256"
# TODO allow any BytesM_T
_inputs = [("value", (BytesT.any(), BYTES32_T, StringT.any()))]
_inputs = [("value", (BytesT.any(), BytesM_T.any(), StringT.any()))]
_return_type = BYTES32_T

def evaluate(self, node):
Expand All @@ -612,7 +612,7 @@ def evaluate(self, node):
arg_typ = self.infer_arg_types(node).pop()
if isinstance(arg_typ, StringT):
value = value.encode()
elif arg_typ == BYTES32_T:
elif isinstance(arg_typ, BytesM_T):
length = len(value) // 2 - 1
value = int(value, 16).to_bytes(length, "big")

Expand Down Expand Up @@ -648,7 +648,7 @@ def _make_sha256_call(inp_start, inp_len, out_start, out_len):

class Sha256(BuiltinFunction):
_id = "sha256"
_inputs = [("value", (BYTES32_T, BytesT.any(), StringT.any()))]
_inputs = [("value", (BytesM_T.any(), BytesT.any(), StringT.any()))]
_return_type = BYTES32_T

def evaluate(self, node):
Expand All @@ -663,7 +663,7 @@ def evaluate(self, node):
arg_typ = self.infer_arg_types(node).pop()
if isinstance(arg_typ, StringT):
value = value.encode()
elif arg_typ == BYTES32_T:
elif isinstance(arg_typ, BytesM_T):
length = len(value) // 2 - 1
value = int(value, 16).to_bytes(length, "big")

Expand All @@ -679,15 +679,15 @@ def infer_arg_types(self, node, *args, **kwargs):
@process_inputs
def build_IR(self, expr, args, kwargs, context):
sub = args[0]
# bytes32 input
if sub.typ == BYTES32_T:
# bytesM_T input
if isinstance(sub.typ, BytesM_T):
return IRnode.from_list(
[
"seq",
["mstore", MemoryPositions.FREE_VAR_SPACE, sub],
_make_sha256_call(
inp_start=MemoryPositions.FREE_VAR_SPACE,
inp_len=32,
inp_len=sub.typ.length,
out_start=MemoryPositions.FREE_VAR_SPACE,
out_len=32,
),
Expand Down
10 changes: 5 additions & 5 deletions vyper/codegen/keccak256_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
from vyper.codegen.ir_node import IRnode
from vyper.exceptions import CompilerPanic
from vyper.semantics.types.bytestrings import _BytestringT
from vyper.semantics.types.primitives import BytesM_T
from vyper.semantics.types.shortcuts import BYTES32_T
from vyper.utils import SHA3_BASE, SHA3_PER_WORD, MemoryPositions, bytes_to_int, keccak256


def _check_byteslike(typ):
if not isinstance(typ, _BytestringT) and typ != BYTES32_T:
if not isinstance(typ, (_BytestringT, BytesM_T)):
# NOTE this may be checked at a higher level, but just be safe
raise CompilerPanic("keccak256 only accepts bytes-like objects")

Expand All @@ -26,14 +27,13 @@ def keccak256_helper(to_hash, context):
if isinstance(to_hash, bytes):
return IRnode.from_list(bytes_to_int(keccak256(to_hash)), typ=BYTES32_T)

# Can hash bytes32 objects
# TODO: Want to generalize to all bytes_M
if to_hash.typ == BYTES32_T:
# Can hash bytesM_T objects
if isinstance(to_hash.typ, BytesM_T):
return IRnode.from_list(
[
"seq",
["mstore", MemoryPositions.FREE_VAR_SPACE, to_hash],
["sha3", MemoryPositions.FREE_VAR_SPACE, 32],
["sha3", MemoryPositions.FREE_VAR_SPACE, to_hash.typ.length],
],
typ=BYTES32_T,
add_gas_estimate=_gas_bound(1),
Expand Down

0 comments on commit 652199e

Please sign in to comment.