Skip to content

Commit

Permalink
feat: require type annotations for loop variables (#3596)
Browse files Browse the repository at this point in the history
this commit changes the vyper language to require type annotations for
loop variables. that is, before, the following was allowed:
```vyper
for i in [1, 2, 3]:
    pass
```

now, `i` is required to have a type annotation:
```vyper
for i: uint256 in [1, 2, 3]:
    pass
```

this makes the annotation of loop variables consistent with the rest of
vyper (it was previously a special case, that loop variables did not
need to be annotated).

the approach taken in this commit is to add a pre-parsing step which
lifts out the type annotation into a separate data structure, and then
splices it back in during the post-processing steps in
`vyper/ast/parse.py`.

this commit also simplifies a lot of analysis regarding for loops.
notably, the possible types for the loop variable no longer needs to be
iterated over, we can just propagate the type provided by the user. for
this reason we also no longer need to use the typechecker speculation
machinery for inferring the type of the loop variable. however, the
NodeMetadata code is not removed because it might come in handy at a
later date.

---------

Co-authored-by: Charles Cooper <[email protected]>
  • Loading branch information
tserg and charles-cooper authored Jan 7, 2024
1 parent 0c82d0b commit ddfce52
Show file tree
Hide file tree
Showing 40 changed files with 432 additions and 330 deletions.
2 changes: 1 addition & 1 deletion examples/auctions/blind_auction.vy
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def reveal(_numBids: int128, _values: uint256[128], _fakes: bool[128], _secrets:

# Calculate refund for sender
refund: uint256 = 0
for i in range(MAX_BIDS):
for i: int128 in range(MAX_BIDS):
# Note that loop may break sooner than 128 iterations if i >= _numBids
if (i >= _numBids):
break
Expand Down
8 changes: 4 additions & 4 deletions examples/tokens/ERC1155ownable.vy
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def balanceOfBatch(accounts: DynArray[address, BATCH_SIZE], ids: DynArray[uint25
assert len(accounts) == len(ids), "ERC1155: accounts and ids length mismatch"
batchBalances: DynArray[uint256, BATCH_SIZE] = []
j: uint256 = 0
for i in ids:
for i: uint256 in ids:
batchBalances.append(self.balanceOf[accounts[j]][i])
j += 1
return batchBalances
Expand Down Expand Up @@ -243,7 +243,7 @@ def mintBatch(receiver: address, ids: DynArray[uint256, BATCH_SIZE], amounts: Dy
assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch"
operator: address = msg.sender

for i in range(BATCH_SIZE):
for i: uint256 in range(BATCH_SIZE):
if i >= len(ids):
break
self.balanceOf[receiver][ids[i]] += amounts[i]
Expand Down Expand Up @@ -277,7 +277,7 @@ def burnBatch(ids: DynArray[uint256, BATCH_SIZE], amounts: DynArray[uint256, BAT
assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch"
operator: address = msg.sender

for i in range(BATCH_SIZE):
for i: uint256 in range(BATCH_SIZE):
if i >= len(ids):
break
self.balanceOf[msg.sender][ids[i]] -= amounts[i]
Expand Down Expand Up @@ -333,7 +333,7 @@ def safeBatchTransferFrom(sender: address, receiver: address, ids: DynArray[uint
assert sender == msg.sender or self.isApprovedForAll[sender][msg.sender], "Caller is neither owner nor approved operator for this ID"
assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch"
operator: address = msg.sender
for i in range(BATCH_SIZE):
for i: uint256 in range(BATCH_SIZE):
if i >= len(ids):
break
id: uint256 = ids[i]
Expand Down
6 changes: 3 additions & 3 deletions examples/voting/ballot.vy
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def directlyVoted(addr: address) -> bool:
def __init__(_proposalNames: bytes32[2]):
self.chairperson = msg.sender
self.voterCount = 0
for i in range(2):
for i: int128 in range(2):
self.proposals[i] = Proposal({
name: _proposalNames[i],
voteCount: 0
Expand Down Expand Up @@ -82,7 +82,7 @@ def _forwardWeight(delegate_with_weight_to_forward: address):
assert self.voters[delegate_with_weight_to_forward].weight > 0

target: address = self.voters[delegate_with_weight_to_forward].delegate
for i in range(4):
for i: int128 in range(4):
if self._delegated(target):
target = self.voters[target].delegate
# The following effectively detects cycles of length <= 5,
Expand Down Expand Up @@ -157,7 +157,7 @@ def vote(proposal: int128):
def _winningProposal() -> int128:
winning_vote_count: int128 = 0
winning_proposal: int128 = 0
for i in range(2):
for i: int128 in range(2):
if self.proposals[i].voteCount > winning_vote_count:
winning_vote_count = self.proposals[i].voteCount
winning_proposal = i
Expand Down
4 changes: 2 additions & 2 deletions examples/wallet/wallet.vy
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ seq: public(int128)

@external
def __init__(_owners: address[5], _threshold: int128):
for i in range(5):
for i: uint256 in range(5):
if _owners[i] != empty(address):
self.owners[i] = _owners[i]
self.threshold = _threshold
Expand Down Expand Up @@ -47,7 +47,7 @@ def approve(_seq: int128, to: address, _value: uint256, data: Bytes[4096], sigda
assert self.seq == _seq
# # Iterates through all the owners and verifies that there signatures,
# # given as the sigdata argument are correct
for i in range(5):
for i: uint256 in range(5):
if sigdata[i][0] != 0:
# If an invalid signature is given for an owner then the contract throws
assert ecrecover(h2, sigdata[i][0], sigdata[i][1], sigdata[i][2]) == self.owners[i]
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/builtins/codegen/test_empty.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def test_empty(xs: int128[111], ys: Bytes[1024], zs: Bytes[31]) -> bool: view
@internal
def write_junk_to_memory():
xs: int128[1024] = empty(int128[1024])
for i in range(1024):
for i: uint256 in range(1024):
xs[i] = -(i + 1)
@internal
def priv(xs: int128[111], ys: Bytes[1024], zs: Bytes[31]) -> bool:
Expand Down Expand Up @@ -469,7 +469,7 @@ def test_return_empty(get_contract_with_gas_estimation):
@internal
def write_junk_to_memory():
xs: int128[1024] = empty(int128[1024])
for i in range(1024):
for i: uint256 in range(1024):
xs[i] = -(i + 1)
@external
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/builtins/codegen/test_mulmod.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_uint256_mulmod_complex(get_contract_with_gas_estimation):
@external
def exponential(base: uint256, exponent: uint256, modulus: uint256) -> uint256:
o: uint256 = 1
for i in range(256):
for i: uint256 in range(256):
o = uint256_mulmod(o, o, modulus)
if exponent & shift(1, 255 - i) != 0:
o = uint256_mulmod(o, base, modulus)
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/builtins/codegen/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_basic_slice(get_contract_with_gas_estimation):
@external
def slice_tower_test(inp1: Bytes[50]) -> Bytes[50]:
inp: Bytes[50] = inp1
for i in range(1, 11):
for i: uint256 in range(1, 11):
inp = slice(inp, 1, 30 - i * 2)
return inp
"""
Expand Down
12 changes: 6 additions & 6 deletions tests/functional/codegen/features/iteration/test_break.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_break_test(get_contract_with_gas_estimation):
def foo(n: decimal) -> int128:
c: decimal = n * 1.0
output: int128 = 0
for i in range(400):
for i: int128 in range(400):
c = c / 1.2589
if c < 1.0:
output = i
Expand All @@ -35,12 +35,12 @@ def test_break_test_2(get_contract_with_gas_estimation):
def foo(n: decimal) -> int128:
c: decimal = n * 1.0
output: int128 = 0
for i in range(40):
for i: int128 in range(40):
if c < 10.0:
output = i * 10
break
c = c / 10.0
for i in range(10):
for i: int128 in range(10):
c = c / 1.2589
if c < 1.0:
output = output + i
Expand All @@ -63,12 +63,12 @@ def test_break_test_3(get_contract_with_gas_estimation):
def foo(n: int128) -> int128:
c: decimal = convert(n, decimal)
output: int128 = 0
for i in range(40):
for i: int128 in range(40):
if c < 10.0:
output = i * 10
break
c /= 10.0
for i in range(10):
for i: int128 in range(10):
c /= 1.2589
if c < 1.0:
output = output + i
Expand Down Expand Up @@ -108,7 +108,7 @@ def foo():
"""
@external
def foo():
for i in [1, 2, 3]:
for i: uint256 in [1, 2, 3]:
b: uint256 = i
if True:
break
Expand Down
10 changes: 5 additions & 5 deletions tests/functional/codegen/features/iteration/test_continue.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def test_continue1(get_contract_with_gas_estimation):
code = """
@external
def foo() -> bool:
for i in range(2):
for i: uint256 in range(2):
continue
return False
return True
Expand All @@ -21,7 +21,7 @@ def test_continue2(get_contract_with_gas_estimation):
@external
def foo() -> int128:
x: int128 = 0
for i in range(3):
for i: int128 in range(3):
x += 1
continue
x -= 1
Expand All @@ -36,7 +36,7 @@ def test_continue3(get_contract_with_gas_estimation):
@external
def foo() -> int128:
x: int128 = 0
for i in range(3):
for i: int128 in range(3):
x += i
continue
return x
Expand All @@ -50,7 +50,7 @@ def test_continue4(get_contract_with_gas_estimation):
@external
def foo() -> int128:
x: int128 = 0
for i in range(6):
for i: int128 in range(6):
if i % 2 == 0:
continue
x += 1
Expand Down Expand Up @@ -83,7 +83,7 @@ def foo():
"""
@external
def foo():
for i in [1, 2, 3]:
for i: uint256 in [1, 2, 3]:
b: uint256 = i
if True:
continue
Expand Down
Loading

0 comments on commit ddfce52

Please sign in to comment.