Skip to content

Commit

Permalink
feat[lang]: introduce floordiv operator (#2937)
Browse files Browse the repository at this point in the history
introduce floordiv operator to increase type safety of numeric
operations (and to look more like python). floordiv is banned for
decimals; "regular" div is banned for integers. we could maybe loosen
this restriction in the future, e.g. int1 / int2 -> decimal, but for
now just segregate into decimal and integer division operations.

---------

Co-authored-by: tserg <[email protected]>
  • Loading branch information
charles-cooper and tserg authored Feb 21, 2024
1 parent bc57775 commit 1ca243b
Show file tree
Hide file tree
Showing 21 changed files with 173 additions and 49 deletions.
6 changes: 3 additions & 3 deletions examples/market_maker/on_chain_market_maker.vy
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def initiate(token_addr: address, token_quantity: uint256):
@external
@payable
def ethToTokens():
fee: uint256 = msg.value / 500
fee: uint256 = msg.value // 500
eth_in_purchase: uint256 = msg.value - fee
new_total_eth: uint256 = self.totalEthQty + eth_in_purchase
new_total_tokens: uint256 = self.invariant / new_total_eth
new_total_tokens: uint256 = self.invariant // new_total_eth
self.token_address.transfer(msg.sender, self.totalTokenQty - new_total_tokens)
self.totalEthQty = new_total_eth
self.totalTokenQty = new_total_tokens
Expand All @@ -42,7 +42,7 @@ def ethToTokens():
def tokensToEth(sell_quantity: uint256):
self.token_address.transferFrom(msg.sender, self, sell_quantity)
new_total_tokens: uint256 = self.totalTokenQty + sell_quantity
new_total_eth: uint256 = self.invariant / new_total_tokens
new_total_eth: uint256 = self.invariant // new_total_tokens
eth_to_send: uint256 = self.totalEthQty - new_total_eth
send(msg.sender, eth_to_send)
self.totalEthQty = new_total_eth
Expand Down
2 changes: 1 addition & 1 deletion examples/safe_remote_purchase/safe_remote_purchase.vy
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ended: public(bool)
@payable
def __init__():
assert (msg.value % 2) == 0
self.value = msg.value / 2 # The seller initializes the contract by
self.value = msg.value // 2 # The seller initializes the contract by
# posting a safety deposit of 2*value of the item up for sale.
self.seller = msg.sender
self.unlocked = True
Expand Down
2 changes: 1 addition & 1 deletion examples/stock/company.vy
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def stockAvailable() -> uint256:
def buyStock():
# Note: full amount is given to company (no fractional shares),
# so be sure to send exact amount to buy shares
buy_order: uint256 = msg.value / self.price # rounds down
buy_order: uint256 = msg.value // self.price # rounds down

# Check that there are enough shares to buy.
assert self._stockAvailable() >= buy_order
Expand Down
4 changes: 2 additions & 2 deletions examples/tokens/ERC4626.vy
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _convertToAssets(shareAmount: uint256) -> uint256:

# NOTE: `shareAmount = 0` is extremely rare case, not optimizing for it
# NOTE: `totalAssets = 0` is extremely rare case, not optimizing for it
return shareAmount * self.asset.balanceOf(self) / totalSupply
return shareAmount * self.asset.balanceOf(self) // totalSupply


@view
Expand All @@ -132,7 +132,7 @@ def _convertToShares(assetAmount: uint256) -> uint256:
return assetAmount # 1:1 price

# NOTE: `assetAmount = 0` is extremely rare case, not optimizing for it
return assetAmount * totalSupply / totalAssets
return assetAmount * totalSupply // totalAssets


@view
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def reverse_digits(x: int128) -> int128:
z: int128 = x
for i: uint256 in range(6):
dig[i] = z % 10
z = z / 10
z = z // 10
o: int128 = 0
for i: uint256 in range(6):
o = o * 10 + dig[i]
Expand Down
6 changes: 3 additions & 3 deletions tests/functional/codegen/features/test_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def foo3(y: uint256) -> uint256:
assert c.foo3(11) == 12


def test_invalid_uin256_assignment(assert_compile_failed, get_contract_with_gas_estimation):
def test_invalid_uint256_assignment(assert_compile_failed, get_contract_with_gas_estimation):
code = """
storx: uint256
Expand All @@ -210,14 +210,14 @@ def foo2() -> uint256:
assert_compile_failed(lambda: get_contract_with_gas_estimation(code), TypeMismatch)


def test_invalid_uin256_assignment_calculate_literals(get_contract_with_gas_estimation):
def test_invalid_uint256_assignment_calculate_literals(get_contract_with_gas_estimation):
code = """
storx: uint256
@external
def foo2() -> uint256:
x: uint256 = 0
x = 3 * 4 / 2 + 1 - 2
x = 3 * 4 // 2 + 1 - 2
return x
"""
c = get_contract_with_gas_estimation(code)
Expand Down
14 changes: 13 additions & 1 deletion tests/functional/codegen/types/numbers/test_decimals.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_decimal_override():
)


@pytest.mark.parametrize("op", ["**", "&", "|", "^"])
@pytest.mark.parametrize("op", ["//", "**", "&", "|", "^"])
def test_invalid_ops(op):
code = f"""
@external
Expand Down Expand Up @@ -300,3 +300,15 @@ def foo():
"""
with pytest.raises(OverflowException):
compile_code(code)


def test_invalid_floordiv():
code = """
@external
def foo():
a: decimal = 5.0 // 9.0
"""
with pytest.raises(InvalidOperation) as e:
compile_code(code)

assert e.value._hint == "did you mean `5.0 / 9.0`?"
29 changes: 26 additions & 3 deletions tests/functional/codegen/types/numbers/test_signed_ints.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def num_sub() -> {typ}:
"+": operator.add,
"-": operator.sub,
"*": operator.mul,
"/": evm_div,
"//": evm_div,
"%": evm_mod,
}

Expand Down Expand Up @@ -263,7 +263,7 @@ def foo() -> {typ}:
"""
lo, hi = typ.ast_bounds

fns = {"+": operator.add, "-": operator.sub, "*": operator.mul, "/": evm_div, "%": evm_mod}
fns = {"+": operator.add, "-": operator.sub, "*": operator.mul, "//": evm_div, "%": evm_mod}
fn = fns[op]

c = get_contract(code_1)
Expand Down Expand Up @@ -307,7 +307,7 @@ def foo() -> {typ}:
in_bounds = lo <= expected <= hi

# safediv and safemod disallow divisor == 0
div_by_zero = y == 0 and op in ("/", "%")
div_by_zero = y == 0 and op in ("//", "%")

ok = in_bounds and not div_by_zero

Expand Down Expand Up @@ -417,6 +417,17 @@ def foo(a: {typ}) -> {typ}:
c.foo(lo)


@pytest.mark.parametrize("typ", types)
@pytest.mark.parametrize("op", ["/"])
def test_invalid_ops(get_contract, assert_compile_failed, typ, op):
code = f"""
@external
def foo(x: {typ}, y: {typ}) -> {typ}:
return x {op} y
"""
assert_compile_failed(lambda: get_contract(code), InvalidOperation)


@pytest.mark.parametrize("typ", types)
@pytest.mark.parametrize("op", ["not"])
def test_invalid_unary_ops(typ, op):
Expand All @@ -437,3 +448,15 @@ def foo():
"""
with pytest.raises(TypeMismatch):
compile_code(code)


def test_invalid_div():
code = """
@external
def foo():
a: int256 = -5 / 9
"""
with pytest.raises(InvalidOperation) as e:
compile_code(code)

assert e.value._hint == "did you mean `-5 // 9`?"
29 changes: 26 additions & 3 deletions tests/functional/codegen/types/numbers/test_unsigned_ints.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def foo(x: {typ}) -> {typ}:
"+": operator.add,
"-": operator.sub,
"*": operator.mul,
"/": evm_div,
"//": evm_div,
"%": evm_mod,
}

Expand Down Expand Up @@ -140,7 +140,7 @@ def foo() -> {typ}:

in_bounds = lo <= expected <= hi
# safediv and safemod disallow divisor == 0
div_by_zero = y == 0 and op in ("/", "%")
div_by_zero = y == 0 and op in ("//", "%")

ok = in_bounds and not div_by_zero

Expand Down Expand Up @@ -236,6 +236,17 @@ def test() -> {typ}:
compile_code(code_template.format(typ=typ, val=val))


@pytest.mark.parametrize("typ", types)
@pytest.mark.parametrize("op", ["/"])
def test_invalid_ops(get_contract, assert_compile_failed, typ, op):
code = f"""
@external
def foo(x: {typ}, y: {typ}) -> {typ}:
return x {op} y
"""
assert_compile_failed(lambda: get_contract(code), InvalidOperation)


@pytest.mark.parametrize("typ", types)
@pytest.mark.parametrize("op", ["not", "-"])
def test_invalid_unary_ops(get_contract, assert_compile_failed, typ, op):
Expand All @@ -252,7 +263,19 @@ def test_binop_nested_intermediate_overflow():
code = """
@external
def foo():
a: uint256 = 2**255 * 2 / 10
a: uint256 = 2**255 * 2 // 10
"""
with pytest.raises(OverflowException):
compile_code(code)


def test_invalid_div():
code = """
@external
def foo():
a: uint256 = 5 / 9
"""
with pytest.raises(InvalidOperation) as e:
compile_code(code)

assert e.value._hint == "did you mean `5 // 9`?"
4 changes: 2 additions & 2 deletions tests/functional/syntax/test_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def foo() -> int128[2]:
def foo() -> decimal:
x: int128 = as_wei_value(5, "finney")
y: int128 = block.timestamp + 50
return x / y
return x // y
""",
(
"""
Expand Down Expand Up @@ -106,7 +106,7 @@ def add_record():
def foo() -> uint256:
x: uint256 = as_wei_value(5, "finney")
y: uint256 = block.timestamp + 50 - block.timestamp
return x / y
return x // y
""",
"""
@external
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/syntax/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def test1():
@external
@view
def test():
for i: uint256 in range(CONST / 4):
for i: uint256 in range(CONST // 4):
pass
""",
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/syntax/test_public.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__():
@external
def foo() -> int128:
return self.x / self.y / self.z
return self.x // self.y // self.z
""",
# expansion of public user-defined struct
"""
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/ast/nodes/test_fold_binop_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
@example(left=1, right=-1)
@example(left=-1, right=1)
@example(left=-1, right=-1)
@pytest.mark.parametrize("op", "+-*/%")
@pytest.mark.parametrize("op", ["+", "-", "*", "//", "%"])
def test_binop_int128(get_contract, tx_failed, op, left, right):
source = f"""
@external
Expand Down Expand Up @@ -45,7 +45,7 @@ def foo(a: int128, b: int128) -> int128:
@pytest.mark.fuzzing
@settings(max_examples=50)
@given(left=st_uint64, right=st_uint64)
@pytest.mark.parametrize("op", "+-*/%")
@pytest.mark.parametrize("op", ["+", "-", "*", "//", "%"])
def test_binop_uint256(get_contract, tx_failed, op, left, right):
source = f"""
@external
Expand Down Expand Up @@ -94,7 +94,7 @@ def foo(a: uint256, b: uint256) -> uint256:
@settings(max_examples=50)
@given(
values=st.lists(st.integers(min_value=-256, max_value=256), min_size=2, max_size=10),
ops=st.lists(st.sampled_from("+-*/%"), min_size=11, max_size=11),
ops=st.lists(st.sampled_from(["+", "-", "*", "//", "%"]), min_size=11, max_size=11),
)
def test_binop_nested(get_contract, tx_failed, values, ops):
variables = "abcdefghij"
Expand Down
12 changes: 10 additions & 2 deletions tests/unit/semantics/analysis/test_potential_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,17 @@ def test_attribute_not_member_type(build_node, namespace):
get_possible_types_from_node(node)


@pytest.mark.parametrize("op", ["+", "-", "*", "//", "%"])
@pytest.mark.parametrize("left,right", INTEGER_LITERALS)
def test_binop_ints(build_node, namespace, op, left, right):
node = build_node(f"{left}{op}{right}")
with namespace.enter_scope():
get_possible_types_from_node(node)


@pytest.mark.parametrize("op", "+-*/%")
@pytest.mark.parametrize("left,right", INTEGER_LITERALS + DECIMAL_LITERALS)
def test_binop(build_node, namespace, op, left, right):
@pytest.mark.parametrize("left,right", DECIMAL_LITERALS)
def test_binop_decimal(build_node, namespace, op, left, right):
node = build_node(f"{left}{op}{right}")
with namespace.enter_scope():
get_possible_types_from_node(node)
Expand Down
2 changes: 2 additions & 0 deletions vyper/ast/grammar.lark
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ assign: (variable_access | multiple_assign | "(" multiple_assign ")" ) "=" _expr
| "-" -> sub
| "*" -> mul
| "/" -> div
| "//" -> floordiv
| "%" -> mod
| "**" -> pow
| "<<" -> shl
Expand Down Expand Up @@ -274,6 +275,7 @@ _IN: "in"
?product: unary
| product "*" unary -> mul
| product "/" unary -> div
| product "//" unary -> floordiv
| product "%" unary -> mod
?unary: power
| "+" power -> uadd
Expand Down
44 changes: 28 additions & 16 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
VyperException,
ZeroDivisionException,
)
from vyper.utils import MAX_DECIMAL_PLACES, SizeLimits, annotate_source_code
from vyper.utils import MAX_DECIMAL_PLACES, SizeLimits, annotate_source_code, evm_div

NODE_BASE_ATTRIBUTES = (
"_children",
Expand Down Expand Up @@ -1056,7 +1056,7 @@ def _op(self, left, right):

class Div(Operator):
__slots__ = ()
_description = "division"
_description = "decimal division"
_pretty = "/"

def _op(self, left, right):
Expand All @@ -1065,20 +1065,32 @@ def _op(self, left, right):
if not right:
raise ZeroDivisionException("Division by zero")

if isinstance(left, decimal.Decimal):
value = left / right
if value < 0:
# the EVM always truncates toward zero
value = -(-left / right)
# ensure that the result is truncated to MAX_DECIMAL_PLACES
return value.quantize(
decimal.Decimal(f"{1:0.{MAX_DECIMAL_PLACES}f}"), decimal.ROUND_DOWN
)
else:
value = left // right
if value < 0:
return -(-left // right)
return value
if not isinstance(left, decimal.Decimal):
raise UnfoldableNode("Cannot use `/` on non-decimals (did you mean `//`?)")

value = left / right
if value < 0:
# the EVM always truncates toward zero
value = -(-left / right)
# ensure that the result is truncated to MAX_DECIMAL_PLACES
return value.quantize(decimal.Decimal(f"{1:0.{MAX_DECIMAL_PLACES}f}"), decimal.ROUND_DOWN)


class FloorDiv(VyperNode):
__slots__ = ()
_description = "integer division"
_pretty = "//"

def _op(self, left, right):
# evaluate the operation using true division or floor division
assert type(left) is type(right)
if not right:
raise ZeroDivisionException("Division by zero")

if not isinstance(left, int):
raise UnfoldableNode("Cannot use `//` on non-integers (did you mean `/`?)")

return evm_div(left, right)


class Mod(Operator):
Expand Down
Loading

0 comments on commit 1ca243b

Please sign in to comment.