diff --git a/examples/auctions/blind_auction.vy b/examples/auctions/blind_auction.vy index 04f908f6d0..597aed57c7 100644 --- a/examples/auctions/blind_auction.vy +++ b/examples/auctions/blind_auction.vy @@ -107,7 +107,7 @@ def reveal(_numBids: int128, _values: uint256[128], _fakes: bool[128], _secrets: # Calculate refund for sender refund: uint256 = 0 - for i in range(MAX_BIDS): + for i: int128 in range(MAX_BIDS): # Note that loop may break sooner than 128 iterations if i >= _numBids if (i >= _numBids): break diff --git a/examples/tokens/ERC1155ownable.vy b/examples/tokens/ERC1155ownable.vy index 30057582e8..e105a79133 100644 --- a/examples/tokens/ERC1155ownable.vy +++ b/examples/tokens/ERC1155ownable.vy @@ -205,7 +205,7 @@ def balanceOfBatch(accounts: DynArray[address, BATCH_SIZE], ids: DynArray[uint25 assert len(accounts) == len(ids), "ERC1155: accounts and ids length mismatch" batchBalances: DynArray[uint256, BATCH_SIZE] = [] j: uint256 = 0 - for i in ids: + for i: uint256 in ids: batchBalances.append(self.balanceOf[accounts[j]][i]) j += 1 return batchBalances @@ -243,7 +243,7 @@ def mintBatch(receiver: address, ids: DynArray[uint256, BATCH_SIZE], amounts: Dy assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch" operator: address = msg.sender - for i in range(BATCH_SIZE): + for i: uint256 in range(BATCH_SIZE): if i >= len(ids): break self.balanceOf[receiver][ids[i]] += amounts[i] @@ -277,7 +277,7 @@ def burnBatch(ids: DynArray[uint256, BATCH_SIZE], amounts: DynArray[uint256, BAT assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch" operator: address = msg.sender - for i in range(BATCH_SIZE): + for i: uint256 in range(BATCH_SIZE): if i >= len(ids): break self.balanceOf[msg.sender][ids[i]] -= amounts[i] @@ -333,7 +333,7 @@ def safeBatchTransferFrom(sender: address, receiver: address, ids: DynArray[uint assert sender == msg.sender or self.isApprovedForAll[sender][msg.sender], "Caller is neither owner nor approved operator for this ID" assert len(ids) == len(amounts), "ERC1155: ids and amounts length mismatch" operator: address = msg.sender - for i in range(BATCH_SIZE): + for i: uint256 in range(BATCH_SIZE): if i >= len(ids): break id: uint256 = ids[i] diff --git a/examples/voting/ballot.vy b/examples/voting/ballot.vy index 0b568784a9..107716accf 100644 --- a/examples/voting/ballot.vy +++ b/examples/voting/ballot.vy @@ -54,7 +54,7 @@ def directlyVoted(addr: address) -> bool: def __init__(_proposalNames: bytes32[2]): self.chairperson = msg.sender self.voterCount = 0 - for i in range(2): + for i: int128 in range(2): self.proposals[i] = Proposal({ name: _proposalNames[i], voteCount: 0 @@ -82,7 +82,7 @@ def _forwardWeight(delegate_with_weight_to_forward: address): assert self.voters[delegate_with_weight_to_forward].weight > 0 target: address = self.voters[delegate_with_weight_to_forward].delegate - for i in range(4): + for i: int128 in range(4): if self._delegated(target): target = self.voters[target].delegate # The following effectively detects cycles of length <= 5, @@ -157,7 +157,7 @@ def vote(proposal: int128): def _winningProposal() -> int128: winning_vote_count: int128 = 0 winning_proposal: int128 = 0 - for i in range(2): + for i: int128 in range(2): if self.proposals[i].voteCount > winning_vote_count: winning_vote_count = self.proposals[i].voteCount winning_proposal = i diff --git a/examples/wallet/wallet.vy b/examples/wallet/wallet.vy index e2515d9e62..231f538ecf 100644 --- a/examples/wallet/wallet.vy +++ b/examples/wallet/wallet.vy @@ -14,7 +14,7 @@ seq: public(int128) @external def __init__(_owners: address[5], _threshold: int128): - for i in range(5): + for i: uint256 in range(5): if _owners[i] != empty(address): self.owners[i] = _owners[i] self.threshold = _threshold @@ -47,7 +47,7 @@ def approve(_seq: int128, to: address, _value: uint256, data: Bytes[4096], sigda assert self.seq == _seq # # Iterates through all the owners and verifies that there signatures, # # given as the sigdata argument are correct - for i in range(5): + for i: uint256 in range(5): if sigdata[i][0] != 0: # If an invalid signature is given for an owner then the contract throws assert ecrecover(h2, sigdata[i][0], sigdata[i][1], sigdata[i][2]) == self.owners[i] diff --git a/tests/functional/builtins/codegen/test_empty.py b/tests/functional/builtins/codegen/test_empty.py index c3627785dc..896c845da2 100644 --- a/tests/functional/builtins/codegen/test_empty.py +++ b/tests/functional/builtins/codegen/test_empty.py @@ -423,7 +423,7 @@ def test_empty(xs: int128[111], ys: Bytes[1024], zs: Bytes[31]) -> bool: view @internal def write_junk_to_memory(): xs: int128[1024] = empty(int128[1024]) - for i in range(1024): + for i: uint256 in range(1024): xs[i] = -(i + 1) @internal def priv(xs: int128[111], ys: Bytes[1024], zs: Bytes[31]) -> bool: @@ -469,7 +469,7 @@ def test_return_empty(get_contract_with_gas_estimation): @internal def write_junk_to_memory(): xs: int128[1024] = empty(int128[1024]) - for i in range(1024): + for i: uint256 in range(1024): xs[i] = -(i + 1) @external diff --git a/tests/functional/builtins/codegen/test_mulmod.py b/tests/functional/builtins/codegen/test_mulmod.py index ba82ebd5b8..31de1d9f22 100644 --- a/tests/functional/builtins/codegen/test_mulmod.py +++ b/tests/functional/builtins/codegen/test_mulmod.py @@ -20,7 +20,7 @@ def test_uint256_mulmod_complex(get_contract_with_gas_estimation): @external def exponential(base: uint256, exponent: uint256, modulus: uint256) -> uint256: o: uint256 = 1 - for i in range(256): + for i: uint256 in range(256): o = uint256_mulmod(o, o, modulus) if exponent & shift(1, 255 - i) != 0: o = uint256_mulmod(o, base, modulus) diff --git a/tests/functional/builtins/codegen/test_slice.py b/tests/functional/builtins/codegen/test_slice.py index a15a3eeb35..80936bbf82 100644 --- a/tests/functional/builtins/codegen/test_slice.py +++ b/tests/functional/builtins/codegen/test_slice.py @@ -17,7 +17,7 @@ def test_basic_slice(get_contract_with_gas_estimation): @external def slice_tower_test(inp1: Bytes[50]) -> Bytes[50]: inp: Bytes[50] = inp1 - for i in range(1, 11): + for i: uint256 in range(1, 11): inp = slice(inp, 1, 30 - i * 2) return inp """ diff --git a/tests/functional/codegen/features/iteration/test_break.py b/tests/functional/codegen/features/iteration/test_break.py index 8a08a11cc2..4abde9c617 100644 --- a/tests/functional/codegen/features/iteration/test_break.py +++ b/tests/functional/codegen/features/iteration/test_break.py @@ -11,7 +11,7 @@ def test_break_test(get_contract_with_gas_estimation): def foo(n: decimal) -> int128: c: decimal = n * 1.0 output: int128 = 0 - for i in range(400): + for i: int128 in range(400): c = c / 1.2589 if c < 1.0: output = i @@ -35,12 +35,12 @@ def test_break_test_2(get_contract_with_gas_estimation): def foo(n: decimal) -> int128: c: decimal = n * 1.0 output: int128 = 0 - for i in range(40): + for i: int128 in range(40): if c < 10.0: output = i * 10 break c = c / 10.0 - for i in range(10): + for i: int128 in range(10): c = c / 1.2589 if c < 1.0: output = output + i @@ -63,12 +63,12 @@ def test_break_test_3(get_contract_with_gas_estimation): def foo(n: int128) -> int128: c: decimal = convert(n, decimal) output: int128 = 0 - for i in range(40): + for i: int128 in range(40): if c < 10.0: output = i * 10 break c /= 10.0 - for i in range(10): + for i: int128 in range(10): c /= 1.2589 if c < 1.0: output = output + i @@ -108,7 +108,7 @@ def foo(): """ @external def foo(): - for i in [1, 2, 3]: + for i: uint256 in [1, 2, 3]: b: uint256 = i if True: break diff --git a/tests/functional/codegen/features/iteration/test_continue.py b/tests/functional/codegen/features/iteration/test_continue.py index 5f4f82a2de..1b2fcab460 100644 --- a/tests/functional/codegen/features/iteration/test_continue.py +++ b/tests/functional/codegen/features/iteration/test_continue.py @@ -7,7 +7,7 @@ def test_continue1(get_contract_with_gas_estimation): code = """ @external def foo() -> bool: - for i in range(2): + for i: uint256 in range(2): continue return False return True @@ -21,7 +21,7 @@ def test_continue2(get_contract_with_gas_estimation): @external def foo() -> int128: x: int128 = 0 - for i in range(3): + for i: int128 in range(3): x += 1 continue x -= 1 @@ -36,7 +36,7 @@ def test_continue3(get_contract_with_gas_estimation): @external def foo() -> int128: x: int128 = 0 - for i in range(3): + for i: int128 in range(3): x += i continue return x @@ -50,7 +50,7 @@ def test_continue4(get_contract_with_gas_estimation): @external def foo() -> int128: x: int128 = 0 - for i in range(6): + for i: int128 in range(6): if i % 2 == 0: continue x += 1 @@ -83,7 +83,7 @@ def foo(): """ @external def foo(): - for i in [1, 2, 3]: + for i: uint256 in [1, 2, 3]: b: uint256 = i if True: continue diff --git a/tests/functional/codegen/features/iteration/test_for_in_list.py b/tests/functional/codegen/features/iteration/test_for_in_list.py index bc1a12ae9e..5c7b5c6b1b 100644 --- a/tests/functional/codegen/features/iteration/test_for_in_list.py +++ b/tests/functional/codegen/features/iteration/test_for_in_list.py @@ -21,7 +21,7 @@ @external def data() -> int128: s: int128[5] = [1, 2, 3, 4, 5] - for i in s: + for i: int128 in s: if i >= 3: return i return -1""", @@ -33,7 +33,7 @@ def data() -> int128: @external def data() -> int128: s: DynArray[int128, 10] = [1, 2, 3, 4, 5] - for i in s: + for i: int128 in s: if i >= 3: return i return -1""", @@ -53,8 +53,8 @@ def data() -> int128: [S({x:3, y:4}), S({x:5, y:6}), S({x:7, y:8}), S({x:9, y:10})] ] ret: int128 = 0 - for ss in sss: - for s in ss: + for ss: DynArray[S, 10] in sss: + for s: S in ss: ret += s.x + s.y return ret""", sum(range(1, 11)), @@ -64,7 +64,7 @@ def data() -> int128: """ @external def data() -> int128: - for i in [3, 5, 7, 9]: + for i: int128 in [3, 5, 7, 9]: if i > 5: return i return -1""", @@ -76,7 +76,7 @@ def data() -> int128: @external def data() -> String[33]: xs: DynArray[String[33], 3] = ["hello", ",", "world"] - for x in xs: + for x: String[33] in xs: if x == ",": return x return "" @@ -88,7 +88,7 @@ def data() -> String[33]: """ @external def data() -> String[33]: - for x in ["hello", ",", "world"]: + for x: String[33] in ["hello", ",", "world"]: if x == ",": return x return "" @@ -100,7 +100,7 @@ def data() -> String[33]: """ @external def data() -> DynArray[String[33], 2]: - for x in [["hello", "world"], ["goodbye", "world!"]]: + for x: DynArray[String[33], 2] in [["hello", "world"], ["goodbye", "world!"]]: if x[1] == "world": return x return [] @@ -114,8 +114,8 @@ def data() -> DynArray[String[33], 2]: def data() -> int128: ret: int128 = 0 xss: int128[3][3] = [[1,2,3],[4,5,6],[7,8,9]] - for xs in xss: - for x in xs: + for xs: int128[3] in xss: + for x: int128 in xs: ret += x return ret""", sum(range(1, 10)), @@ -130,8 +130,8 @@ def data() -> int128: @external def data() -> int128: ret: int128 = 0 - for ss in [[S({x:1, y:2})]]: - for s in ss: + for ss: S[1] in [[S({x:1, y:2})]]: + for s: S in ss: ret += s.x + s.y return ret""", 1 + 2, @@ -147,7 +147,7 @@ def data() -> address: 0xDCEceAF3fc5C0a63d195d69b1A90011B7B19650D ] count: int128 = 0 - for i in addresses: + for i: address in addresses: count += 1 if count == 2: return i @@ -174,7 +174,7 @@ def set(): @external def data() -> int128: - for i in self.x: + for i: int128 in self.x: if i > 5: return i return -1 @@ -198,7 +198,7 @@ def set(xs: DynArray[int128, 4]): @external def data() -> int128: t: int128 = 0 - for i in self.x: + for i: int128 in self.x: t += i return t """ @@ -227,7 +227,7 @@ def ret(i: int128) -> address: @external def iterate_return_second() -> address: count: int128 = 0 - for i in self.addresses: + for i: address in self.addresses: count += 1 if count == 2: return i @@ -258,7 +258,7 @@ def ret(i: int128) -> decimal: @external def i_return(break_count: int128) -> decimal: count: int128 = 0 - for i in self.readings: + for i: decimal in self.readings: if count == break_count: return i count += 1 @@ -284,7 +284,7 @@ def func(amounts: uint256[3]) -> uint256: total: uint256 = as_wei_value(0, "wei") # calculate total - for amount in amounts: + for amount: uint256 in amounts: total += amount return total @@ -303,7 +303,7 @@ def func(amounts: DynArray[uint256, 3]) -> uint256: total: uint256 = 0 # calculate total - for amount in amounts: + for amount: uint256 in amounts: total += amount return total @@ -321,42 +321,42 @@ def func(amounts: DynArray[uint256, 3]) -> uint256: @external def foo(x: int128): p: int128 = 0 - for i in range(3): + for i: int128 in range(3): p += i - for i in range(4): + for i: int128 in range(4): p += i """, """ @external def foo(x: int128): p: int128 = 0 - for i in range(3): + for i: int128 in range(3): p += i - for i in [1, 2, 3, 4]: + for i: int128 in [1, 2, 3, 4]: p += i """, """ @external def foo(x: int128): p: int128 = 0 - for i in [1, 2, 3, 4]: + for i: int128 in [1, 2, 3, 4]: p += i - for i in [1, 2, 3, 4]: + for i: int128 in [1, 2, 3, 4]: p += i """, """ @external def foo(): - for i in range(10): + for i: uint256 in range(10): pass - for i in range(20): + for i: uint256 in range(20): pass """, # using index variable after loop """ @external def foo(): - for i in range(10): + for i: uint256 in range(10): pass i: int128 = 100 # create new variable i i = 200 # look up the variable i and check whether it is in forvars @@ -372,25 +372,25 @@ def test_good_code(code, get_contract): RANGE_CONSTANT_CODE = [ ( """ -TREE_FIDDY: constant(int128) = 350 +TREE_FIDDY: constant(uint256) = 350 @external def a() -> uint256: x: uint256 = 0 - for i in range(TREE_FIDDY): + for i: uint256 in range(TREE_FIDDY): x += 1 return x""", 350, ), ( """ -ONE_HUNDRED: constant(int128) = 100 +ONE_HUNDRED: constant(uint256) = 100 @external def a() -> uint256: x: uint256 = 0 - for i in range(1, 1 + ONE_HUNDRED): + for i: uint256 in range(1, 1 + ONE_HUNDRED): x += 1 return x""", 100, @@ -401,9 +401,9 @@ def a() -> uint256: END: constant(int128) = 199 @external -def a() -> uint256: - x: uint256 = 0 - for i in range(START, END): +def a() -> int128: + x: int128 = 0 + for i: int128 in range(START, END): x += 1 return x""", 99, @@ -413,11 +413,23 @@ def a() -> uint256: @external def a() -> int128: x: int128 = 0 - for i in range(-5, -1): + for i: int128 in range(-5, -1): x += i return x""", -14, ), + ( + """ +@external +def a() -> uint256: + a: DynArray[DynArray[uint256, 2], 3] = [[0, 1], [2, 3], [4, 5]] + x: uint256 = 0 + for i: uint256 in a[2]: + x += i + return x + """, + 9, + ), ] @@ -436,7 +448,7 @@ def test_range_constant(get_contract, code, result): def data() -> int128: s: int128[6] = [1, 2, 3, 4, 5, 6] count: int128 = 0 - for i in s: + for i: int128 in s: s[count] = 1 # this should not be allowed. if i >= 3: return i @@ -451,7 +463,7 @@ def data() -> int128: def foo(): s: int128[6] = [1, 2, 3, 4, 5, 6] count: int128 = 0 - for i in s: + for i: int128 in s: s[count] += 1 """, ImmutableViolation, @@ -468,7 +480,7 @@ def set(): @external def data() -> int128: count: int128 = 0 - for i in self.s: + for i: int128 in self.s: self.s[count] = 1 # this should not be allowed. if i >= 3: return i @@ -493,7 +505,7 @@ def doStuff(i: uint256) -> uint256: @internal def _helper(): i: uint256 = 0 - for item in self.my_array2.foo: + for item: uint256 in self.my_array2.foo: self.doStuff(i) i += 1 """, @@ -519,7 +531,7 @@ def doStuff(i: uint256) -> uint256: @internal def _helper(): i: uint256 = 0 - for item in self.my_array2.bar.foo: + for item: uint256 in self.my_array2.bar.foo: self.doStuff(i) i += 1 """, @@ -545,7 +557,7 @@ def doStuff(): @internal def _helper(): i: uint256 = 0 - for item in self.my_array2.foo: + for item: uint256 in self.my_array2.foo: self.doStuff() i += 1 """, @@ -556,8 +568,8 @@ def _helper(): """ @external def foo(x: int128): - for i in range(4): - for i in range(5): + for i: int128 in range(4): + for i: int128 in range(5): pass """, NamespaceCollision, @@ -566,8 +578,8 @@ def foo(x: int128): """ @external def foo(x: int128): - for i in [1,2]: - for i in [1,2]: + for i: int128 in [1,2]: + for i: int128 in [1,2]: pass """, NamespaceCollision, @@ -577,7 +589,7 @@ def foo(x: int128): """ @external def foo(x: int128): - for i in [1,2]: + for i: int128 in [1,2]: i = 2 """, ImmutableViolation, @@ -588,7 +600,7 @@ def foo(x: int128): @external def foo(): xs: DynArray[uint256, 5] = [1,2,3] - for x in xs: + for x: uint256 in xs: xs.pop() """, ImmutableViolation, @@ -599,7 +611,7 @@ def foo(): @external def foo(): xs: DynArray[uint256, 5] = [1,2,3] - for x in xs: + for x: uint256 in xs: xs.append(x) """, ImmutableViolation, @@ -610,7 +622,7 @@ def foo(): @external def foo(): xs: DynArray[DynArray[uint256, 5], 5] = [[1,2,3]] - for x in xs: + for x: DynArray[uint256, 5] in xs: x.pop() """, ImmutableViolation, @@ -629,7 +641,7 @@ def b(): @external def foo(): - for x in self.array: + for x: uint256 in self.array: self.a() """, ImmutableViolation, @@ -638,7 +650,7 @@ def foo(): """ @external def foo(x: int128): - for i in [1,2]: + for i: int128 in [1,2]: i += 2 """, ImmutableViolation, @@ -648,7 +660,7 @@ def foo(x: int128): """ @external def foo(): - for i in range(-3): + for i: int128 in range(-3): pass """, StructureException, @@ -656,13 +668,13 @@ def foo(): """ @external def foo(): - for i in range(0): + for i: uint256 in range(0): pass """, """ @external def foo(): - for i in []: + for i: uint256 in []: pass """, """ @@ -670,14 +682,14 @@ def foo(): @external def foo(): - for i in FOO: + for i: uint256 in FOO: pass """, ( """ @external def foo(): - for i in range(5,3): + for i: uint256 in range(5,3): pass """, StructureException, @@ -686,7 +698,7 @@ def foo(): """ @external def foo(): - for i in range(5,3,-1): + for i: int128 in range(5,3,-1): pass """, ArgumentException, @@ -696,7 +708,7 @@ def foo(): @external def foo(): a: uint256 = 2 - for i in range(a): + for i: uint256 in range(a): pass """, StateAccessViolation, @@ -706,7 +718,7 @@ def foo(): @external def foo(): a: int128 = 6 - for i in range(a,a-3): + for i: int128 in range(a,a-3): pass """, StateAccessViolation, @@ -716,7 +728,7 @@ def foo(): """ @external def foo(): - for i in range(): + for i: uint256 in range(): pass """, ArgumentException, @@ -725,7 +737,7 @@ def foo(): """ @external def foo(): - for i in range(0,1,2): + for i: uint256 in range(0,1,2): pass """, ArgumentException, @@ -735,7 +747,7 @@ def foo(): """ @external def foo(): - for i in b"asdf": + for i: Bytes[1] in b"asdf": pass """, InvalidType, @@ -744,7 +756,7 @@ def foo(): """ @external def foo(): - for i in 31337: + for i: uint256 in 31337: pass """, InvalidType, @@ -753,7 +765,7 @@ def foo(): """ @external def foo(): - for i in bar(): + for i: uint256 in bar(): pass """, IteratorException, @@ -762,7 +774,7 @@ def foo(): """ @external def foo(): - for i in self.bar(): + for i: uint256 in self.bar(): pass """, IteratorException, @@ -772,11 +784,11 @@ def foo(): @external def test_for() -> int128: a: int128 = 0 - for i in range(max_value(int128), max_value(int128)+2): + for i: int128 in range(max_value(int128), max_value(int128)+2): a = i return a """, - TypeMismatch, + InvalidType, ), ( """ @@ -784,7 +796,7 @@ def test_for() -> int128: def test_for() -> int128: a: int128 = 0 b: uint256 = 0 - for i in range(5): + for i: int128 in range(5): a = i b = i return a diff --git a/tests/functional/codegen/features/iteration/test_for_range.py b/tests/functional/codegen/features/iteration/test_for_range.py index e946447285..c661c46553 100644 --- a/tests/functional/codegen/features/iteration/test_for_range.py +++ b/tests/functional/codegen/features/iteration/test_for_range.py @@ -6,7 +6,7 @@ def test_basic_repeater(get_contract_with_gas_estimation): @external def repeat(z: int128) -> int128: x: int128 = 0 - for i in range(6): + for i: int128 in range(6): x = x + z return(x) """ @@ -19,7 +19,7 @@ def test_range_bound(get_contract, tx_failed): @external def repeat(n: uint256) -> uint256: x: uint256 = 0 - for i in range(n, bound=6): + for i: uint256 in range(n, bound=6): x += i + 1 return x """ @@ -37,7 +37,7 @@ def test_range_bound_constant_end(get_contract, tx_failed): @external def repeat(n: uint256) -> uint256: x: uint256 = 0 - for i in range(n, 7, bound=6): + for i: uint256 in range(n, 7, bound=6): x += i + 1 return x """ @@ -58,7 +58,7 @@ def test_range_bound_two_args(get_contract, tx_failed): @external def repeat(n: uint256) -> uint256: x: uint256 = 0 - for i in range(1, n, bound=6): + for i: uint256 in range(1, n, bound=6): x += i + 1 return x """ @@ -80,7 +80,7 @@ def test_range_bound_two_runtime_args(get_contract, tx_failed): @external def repeat(start: uint256, end: uint256) -> uint256: x: uint256 = 0 - for i in range(start, end, bound=6): + for i: uint256 in range(start, end, bound=6): x += i return x """ @@ -109,7 +109,7 @@ def test_range_overflow(get_contract, tx_failed): @external def get_last(start: uint256, end: uint256) -> uint256: x: uint256 = 0 - for i in range(start, end, bound=6): + for i: uint256 in range(start, end, bound=6): x = i return x """ @@ -134,11 +134,11 @@ def test_digit_reverser(get_contract_with_gas_estimation): def reverse_digits(x: int128) -> int128: dig: int128[6] = [0, 0, 0, 0, 0, 0] z: int128 = x - for i in range(6): + for i: uint256 in range(6): dig[i] = z % 10 z = z / 10 o: int128 = 0 - for i in range(6): + for i: uint256 in range(6): o = o * 10 + dig[i] return o @@ -153,9 +153,9 @@ def test_more_complex_repeater(get_contract_with_gas_estimation): @external def repeat() -> int128: out: int128 = 0 - for i in range(6): + for i: uint256 in range(6): out = out * 10 - for j in range(4): + for j: int128 in range(4): out = out + j return(out) """ @@ -170,7 +170,7 @@ def test_offset_repeater(get_contract_with_gas_estimation, typ): @external def sum() -> {typ}: out: {typ} = 0 - for i in range(80, 121): + for i: {typ} in range(80, 121): out = out + i return out """ @@ -185,7 +185,7 @@ def test_offset_repeater_2(get_contract_with_gas_estimation, typ): @external def sum(frm: {typ}, to: {typ}) -> {typ}: out: {typ} = 0 - for i in range(frm, frm + 101, bound=101): + for i: {typ} in range(frm, frm + 101, bound=101): if i == to: break out = out + i @@ -205,7 +205,7 @@ def _bar() -> bool: @external def foo() -> bool: - for i in range(3): + for i: uint256 in range(3): self._bar() return True """ @@ -219,8 +219,8 @@ def test_return_inside_repeater(get_contract, typ): code = f""" @internal def _final(a: {typ}) -> {typ}: - for i in range(10): - for j in range(10): + for i: {typ} in range(10): + for j: {typ} in range(10): if j > 5: if i > a: return i @@ -254,14 +254,14 @@ def test_for_range_edge(get_contract, typ): def test(): found: bool = False x: {typ} = max_value({typ}) - for i in range(x - 1, x, bound=1): + for i: {typ} in range(x - 1, x, bound=1): if i + 1 == max_value({typ}): found = True assert found found = False x = max_value({typ}) - 1 - for i in range(x - 1, x + 1, bound=2): + for i: {typ} in range(x - 1, x + 1, bound=2): if i + 1 == max_value({typ}): found = True assert found @@ -276,7 +276,7 @@ def test_for_range_oob_check(get_contract, tx_failed, typ): @external def test(): x: {typ} = max_value({typ}) - for i in range(x, x + 2, bound=2): + for i: {typ} in range(x, x + 2, bound=2): pass """ c = get_contract(code) @@ -289,8 +289,8 @@ def test_return_inside_nested_repeater(get_contract, typ): code = f""" @internal def _final(a: {typ}) -> {typ}: - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: return i + x return 31337 @@ -318,8 +318,8 @@ def test_return_void_nested_repeater(get_contract, typ, val): result: {typ} @internal def _final(a: {typ}): - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: self.result = i + x return @@ -347,8 +347,8 @@ def test_external_nested_repeater(get_contract, typ, val): code = f""" @external def foo(a: {typ}) -> {typ}: - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: return i + x return 31337 @@ -368,8 +368,8 @@ def test_external_void_nested_repeater(get_contract, typ, val): result: public({typ}) @external def foo(a: {typ}): - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if i + x > a: self.result = i + x return @@ -388,8 +388,8 @@ def test_breaks_and_returns_inside_nested_repeater(get_contract, typ): code = f""" @internal def _final(a: {typ}) -> {typ}: - for i in range(10): - for x in range(10): + for i: {typ} in range(10): + for x: {typ} in range(10): if a < 2: break return 6 diff --git a/tests/functional/codegen/features/test_assert.py b/tests/functional/codegen/features/test_assert.py index af189e6dca..df379d3f16 100644 --- a/tests/functional/codegen/features/test_assert.py +++ b/tests/functional/codegen/features/test_assert.py @@ -159,7 +159,7 @@ def test_assert_in_for_loop(get_contract, tx_failed, memory_mocker): code = """ @external def test(x: uint256[3]) -> bool: - for i in range(3): + for i: uint256 in range(3): assert x[i] < 5 return True """ @@ -179,7 +179,7 @@ def test_assert_with_reason_in_for_loop(get_contract, tx_failed, memory_mocker): code = """ @external def test(x: uint256[3]) -> bool: - for i in range(3): + for i: uint256 in range(3): assert x[i] < 5, "because reasons" return True """ diff --git a/tests/functional/codegen/features/test_internal_call.py b/tests/functional/codegen/features/test_internal_call.py index f10d22ec99..422f53fdeb 100644 --- a/tests/functional/codegen/features/test_internal_call.py +++ b/tests/functional/codegen/features/test_internal_call.py @@ -152,7 +152,7 @@ def _increment(): @external def returnten() -> int128: - for i in range(10): + for i: uint256 in range(10): self._increment() return self.counter """ diff --git a/tests/functional/codegen/integration/test_crowdfund.py b/tests/functional/codegen/integration/test_crowdfund.py index 671d424d60..891ed5aebe 100644 --- a/tests/functional/codegen/integration/test_crowdfund.py +++ b/tests/functional/codegen/integration/test_crowdfund.py @@ -52,7 +52,7 @@ def finalize(): @external def refund(): ind: int128 = self.refundIndex - for i in range(ind, ind + 30, bound=30): + for i: int128 in range(ind, ind + 30, bound=30): if i >= self.nextFunderIndex: self.refundIndex = self.nextFunderIndex return @@ -147,7 +147,7 @@ def finalize(): @external def refund(): ind: int128 = self.refundIndex - for i in range(ind, ind + 30, bound=30): + for i: int128 in range(ind, ind + 30, bound=30): if i >= self.nextFunderIndex: self.refundIndex = self.nextFunderIndex return diff --git a/tests/functional/codegen/types/numbers/test_decimals.py b/tests/functional/codegen/types/numbers/test_decimals.py index fcf71f12f0..72171dd4b5 100644 --- a/tests/functional/codegen/types/numbers/test_decimals.py +++ b/tests/functional/codegen/types/numbers/test_decimals.py @@ -125,7 +125,7 @@ def test_harder_decimal_test(get_contract_with_gas_estimation): @external def phooey(inp: decimal) -> decimal: x: decimal = 10000.0 - for i in range(4): + for i: uint256 in range(4): x = x * inp return x diff --git a/tests/functional/codegen/types/test_bytes.py b/tests/functional/codegen/types/test_bytes.py index 1ee9b8d835..882629de65 100644 --- a/tests/functional/codegen/types/test_bytes.py +++ b/tests/functional/codegen/types/test_bytes.py @@ -268,7 +268,7 @@ def test_zero_padding_with_private(get_contract): def to_little_endian_64(_value: uint256) -> Bytes[8]: y: uint256 = 0 x: uint256 = _value - for _ in range(8): + for _: uint256 in range(8): y = (y << 8) | (x & 255) x >>= 8 return slice(convert(y, bytes32), 24, 8) diff --git a/tests/functional/codegen/types/test_bytes_zero_padding.py b/tests/functional/codegen/types/test_bytes_zero_padding.py index f9fcf37b25..6597facd1b 100644 --- a/tests/functional/codegen/types/test_bytes_zero_padding.py +++ b/tests/functional/codegen/types/test_bytes_zero_padding.py @@ -10,7 +10,7 @@ def little_endian_contract(get_contract_module): def to_little_endian_64(_value: uint256) -> Bytes[8]: y: uint256 = 0 x: uint256 = _value - for _ in range(8): + for _: uint256 in range(8): y = (y << 8) | (x & 255) x >>= 8 return slice(convert(y, bytes32), 24, 8) diff --git a/tests/functional/codegen/types/test_dynamic_array.py b/tests/functional/codegen/types/test_dynamic_array.py index 70a68e3206..e47eda6042 100644 --- a/tests/functional/codegen/types/test_dynamic_array.py +++ b/tests/functional/codegen/types/test_dynamic_array.py @@ -969,7 +969,7 @@ def foo() -> (uint256, uint256, uint256, uint256, uint256): my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array """, @@ -981,7 +981,7 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: some_var: uint256 @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.some_var = x # test that typechecker for append args works self.my_array.append(self.some_var) @@ -994,9 +994,9 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.my_array.append(x) - for x in xs: + for x: uint256 in xs: self.my_array.pop() return self.my_array """, @@ -1008,7 +1008,7 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> (DynArray[uint256, 5], uint256): - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array, self.my_array.pop() """, @@ -1020,7 +1020,7 @@ def foo(xs: DynArray[uint256, 5]) -> (DynArray[uint256, 5], uint256): my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 5]) -> (uint256, DynArray[uint256, 5]): - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array.pop(), self.my_array """, @@ -1033,7 +1033,7 @@ def foo(xs: DynArray[uint256, 5]) -> (uint256, DynArray[uint256, 5]): def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: ys: DynArray[uint256, 5] = [] i: uint256 = 0 - for x in xs: + for x: uint256 in xs: if i >= len(xs) - 1: break ys.append(x) @@ -1049,7 +1049,7 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: my_array: DynArray[uint256, 5] @external def foo(xs: DynArray[uint256, 6]) -> DynArray[uint256, 5]: - for x in xs: + for x: uint256 in xs: self.my_array.append(x) return self.my_array """, @@ -1061,9 +1061,9 @@ def foo(xs: DynArray[uint256, 6]) -> DynArray[uint256, 5]: @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: ys: DynArray[uint256, 5] = [] - for x in xs: + for x: uint256 in xs: ys.append(x) - for x in xs: + for x: uint256 in xs: ys.pop() return ys """, @@ -1075,9 +1075,9 @@ def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: @external def foo(xs: DynArray[uint256, 5]) -> DynArray[uint256, 5]: ys: DynArray[uint256, 5] = [] - for x in xs: + for x: uint256 in xs: ys.append(x) - for x in xs: + for x: uint256 in xs: ys.pop() ys.pop() # fail return ys @@ -1328,7 +1328,7 @@ def test_list_of_structs_arg(get_contract): @external def bar(_baz: DynArray[Foo, 3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): e: Foobar = _baz[i].z f: uint256 = convert(e, uint256) sum += _baz[i].x * _baz[i].y + f @@ -1397,7 +1397,7 @@ def test_list_of_nested_struct_arrays(get_contract): @external def bar(_bar: DynArray[Bar, 3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): sum += _bar[i].f[0].e.a[0] * _bar[i].f[1].e.a[1] return sum """ diff --git a/tests/functional/codegen/types/test_lists.py b/tests/functional/codegen/types/test_lists.py index b5b9538c20..ee287064e8 100644 --- a/tests/functional/codegen/types/test_lists.py +++ b/tests/functional/codegen/types/test_lists.py @@ -566,7 +566,7 @@ def test_list_of_structs_arg(get_contract): @external def bar(_baz: Foo[3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): sum += _baz[i].x * _baz[i].y return sum """ @@ -608,7 +608,7 @@ def test_list_of_nested_struct_arrays(get_contract): @external def bar(_bar: Bar[3]) -> uint256: sum: uint256 = 0 - for i in range(3): + for i: uint256 in range(3): sum += _bar[i].f[0].e.a[0] * _bar[i].f[1].e.a[1] return sum """ diff --git a/tests/functional/grammar/test_grammar.py b/tests/functional/grammar/test_grammar.py index 7dd8c35929..652102c376 100644 --- a/tests/functional/grammar/test_grammar.py +++ b/tests/functional/grammar/test_grammar.py @@ -106,6 +106,6 @@ def has_no_docstrings(c): @hypothesis.settings(max_examples=500) def test_grammar_bruteforce(code): if utf8_encodable(code): - _, _, reformatted_code = pre_parse(code + "\n") + _, _, _, reformatted_code = pre_parse(code + "\n") tree = parse_to_ast(reformatted_code) assert isinstance(tree, Module) diff --git a/tests/functional/syntax/exceptions/test_argument_exception.py b/tests/functional/syntax/exceptions/test_argument_exception.py index 0b7ec21bdb..4240aec8d2 100644 --- a/tests/functional/syntax/exceptions/test_argument_exception.py +++ b/tests/functional/syntax/exceptions/test_argument_exception.py @@ -80,13 +80,13 @@ def foo(): """ @external def foo(): - for i in range(): + for i: uint256 in range(): pass """, """ @external def foo(): - for i in range(1, 2, 3, 4): + for i: uint256 in range(1, 2, 3, 4): pass """, ] diff --git a/tests/functional/syntax/exceptions/test_constancy_exception.py b/tests/functional/syntax/exceptions/test_constancy_exception.py index 4bd0b4fcb9..7adf9538c7 100644 --- a/tests/functional/syntax/exceptions/test_constancy_exception.py +++ b/tests/functional/syntax/exceptions/test_constancy_exception.py @@ -57,7 +57,7 @@ def foo() -> int128: return 5 @external def bar(): - for i in range(self.foo(), self.foo() + 1): + for i: int128 in range(self.foo(), self.foo() + 1): pass""", """ glob: int128 @@ -67,13 +67,13 @@ def foo() -> int128: return 5 @external def bar(): - for i in [1,2,3,4,self.foo()]: + for i: int128 in [1,2,3,4,self.foo()]: pass""", """ @external def foo(): x: int128 = 5 - for i in range(x): + for i: int128 in range(x): pass""", """ f:int128 diff --git a/tests/functional/syntax/test_blockscope.py b/tests/functional/syntax/test_blockscope.py index 942aa3fa68..466b5509ca 100644 --- a/tests/functional/syntax/test_blockscope.py +++ b/tests/functional/syntax/test_blockscope.py @@ -33,7 +33,7 @@ def foo(choice: bool): @external def foo(choice: bool): - for i in range(4): + for i: int128 in range(4): a: int128 = 0 a = 1 """, @@ -41,7 +41,7 @@ def foo(choice: bool): @external def foo(choice: bool): - for i in range(4): + for i: int128 in range(4): a: int128 = 0 a += 1 """, diff --git a/tests/functional/syntax/test_constants.py b/tests/functional/syntax/test_constants.py index ffd2f1faa0..7089dee3bb 100644 --- a/tests/functional/syntax/test_constants.py +++ b/tests/functional/syntax/test_constants.py @@ -240,7 +240,7 @@ def test1(): @external @view def test(): - for i in range(CONST / 4): + for i: uint256 in range(CONST / 4): pass """, """ diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index a9c3ad5cab..66981a90de 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -15,7 +15,7 @@ """ @external def foo(): - for a[1] in range(10): + for a[1]: uint256 in range(10): pass """, StructureException, @@ -26,7 +26,7 @@ def foo(): """ @external def bar(): - for i in range(1,2,bound=0): + for i: uint256 in range(1,2,bound=0): pass """, StructureException, @@ -38,7 +38,7 @@ def bar(): @external def foo(): x: uint256 = 100 - for _ in range(10, bound=x): + for _: uint256 in range(10, bound=x): pass """, StateAccessViolation, @@ -49,7 +49,7 @@ def foo(): """ @external def foo(): - for _ in range(10, 20, bound=5): + for _: uint256 in range(10, 20, bound=5): pass """, StructureException, @@ -60,7 +60,7 @@ def foo(): """ @external def foo(): - for _ in range(10, 20, bound=0): + for _: uint256 in range(10, 20, bound=0): pass """, StructureException, @@ -72,7 +72,7 @@ def foo(): @external def bar(): x:uint256 = 1 - for i in range(x,x+1,bound=2,extra=3): + for i: uint256 in range(x,x+1,bound=2,extra=3): pass """, ArgumentException, @@ -83,7 +83,7 @@ def bar(): """ @external def bar(): - for i in range(0): + for i: uint256 in range(0): pass """, StructureException, @@ -95,7 +95,7 @@ def bar(): @external def bar(): x:uint256 = 1 - for i in range(x): + for i: uint256 in range(x): pass """, StateAccessViolation, @@ -107,7 +107,7 @@ def bar(): @external def bar(): x:uint256 = 1 - for i in range(0, x): + for i: uint256 in range(0, x): pass """, StateAccessViolation, @@ -118,7 +118,7 @@ def bar(): """ @external def repeat(n: uint256) -> uint256: - for i in range(0, n * 10): + for i: uint256 in range(0, n * 10): pass return n """, @@ -131,7 +131,7 @@ def repeat(n: uint256) -> uint256: @external def bar(): x:uint256 = 1 - for i in range(0, x + 1): + for i: uint256 in range(0, x + 1): pass """, StateAccessViolation, @@ -142,7 +142,7 @@ def bar(): """ @external def bar(): - for i in range(2, 1): + for i: uint256 in range(2, 1): pass """, StructureException, @@ -154,7 +154,7 @@ def bar(): @external def bar(): x:uint256 = 1 - for i in range(x, x): + for i: uint256 in range(x, x): pass """, StateAccessViolation, @@ -166,7 +166,7 @@ def bar(): @external def foo(): x: int128 = 5 - for i in range(x, x + 10): + for i: int128 in range(x, x + 10): pass """, StateAccessViolation, @@ -177,7 +177,7 @@ def foo(): """ @external def repeat(n: uint256) -> uint256: - for i in range(n, 6): + for i: uint256 in range(n, 6): pass return x """, @@ -190,7 +190,7 @@ def repeat(n: uint256) -> uint256: @external def foo(x: int128): y: int128 = 7 - for i in range(x, x + y): + for i: int128 in range(x, x + y): pass """, StateAccessViolation, @@ -201,7 +201,7 @@ def foo(x: int128): """ @external def bar(x: uint256): - for i in range(3, x): + for i: uint256 in range(3, x): pass """, StateAccessViolation, @@ -215,12 +215,12 @@ def bar(x: uint256): @external def foo(): - for i in range(FOO, BAR): + for i: uint256 in range(FOO, BAR): pass """, TypeMismatch, - "Iterator values are of different types", - "range(FOO, BAR)", + "Given reference has type int128, expected uint256", + "FOO", ), ( """ @@ -228,12 +228,12 @@ def foo(): @external def foo(): - for i in range(10, bound=FOO): + for i: int128 in range(10, bound=FOO): pass """, StructureException, "Bound must be at least 1", - "-1", + "FOO", ), ] @@ -252,41 +252,41 @@ def test_range_fail(bad_code, error_type, message, source_code): with pytest.raises(error_type) as exc_info: compiler.compile_code(bad_code) assert message == exc_info.value.message - assert source_code == exc_info.value.args[1].node_source_code + assert source_code == exc_info.value.args[1].get_original_node().node_source_code valid_list = [ """ @external def foo(): - for i in range(10): + for i: uint256 in range(10): pass """, """ @external def foo(): - for i in range(10, 20): + for i: uint256 in range(10, 20): pass """, """ @external def foo(): x: int128 = 5 - for i in range(1, x, bound=4): + for i: int128 in range(1, x, bound=4): pass """, """ @external def foo(): x: int128 = 5 - for i in range(x, bound=4): + for i: int128 in range(x, bound=4): pass """, """ @external def foo(): x: int128 = 5 - for i in range(0, x, bound=4): + for i: int128 in range(0, x, bound=4): pass """, """ @@ -295,7 +295,7 @@ def kick(): nonpayable foos: Foo[3] @external def kick_foos(): - for foo in self.foos: + for foo: Foo in self.foos: foo.kick() """, ] diff --git a/tests/functional/syntax/test_list.py b/tests/functional/syntax/test_list.py index db41de5526..3936f8c220 100644 --- a/tests/functional/syntax/test_list.py +++ b/tests/functional/syntax/test_list.py @@ -306,7 +306,7 @@ def foo(): @external def foo(): x: DynArray[uint256, 3] = [1, 2, 3] - for i in [[], []]: + for i: DynArray[uint256, 3] in [[], []]: x = i """, ] diff --git a/tests/unit/ast/nodes/test_hex.py b/tests/unit/ast/nodes/test_hex.py index d413340083..a6bc3147e6 100644 --- a/tests/unit/ast/nodes/test_hex.py +++ b/tests/unit/ast/nodes/test_hex.py @@ -24,7 +24,7 @@ def foo(): """ @external def foo(): - for i in [0x6b175474e89094c44da98b954eedeac495271d0F]: + for i: address in [0x6b175474e89094c44da98b954eedeac495271d0F]: pass """, """ diff --git a/tests/unit/ast/test_annotate_and_optimize_ast.py b/tests/unit/ast/test_annotate_and_optimize_ast.py index 16ce6fe631..b202f6d8a3 100644 --- a/tests/unit/ast/test_annotate_and_optimize_ast.py +++ b/tests/unit/ast/test_annotate_and_optimize_ast.py @@ -28,10 +28,10 @@ def foo() -> int128: def get_contract_info(source_code): - _, class_types, reformatted_code = pre_parse(source_code) + _, loop_var_annotations, class_types, reformatted_code = pre_parse(source_code) py_ast = python_ast.parse(reformatted_code) - annotate_python_ast(py_ast, reformatted_code, class_types) + annotate_python_ast(py_ast, reformatted_code, loop_var_annotations, class_types) return py_ast, reformatted_code diff --git a/tests/unit/ast/test_pre_parser.py b/tests/unit/ast/test_pre_parser.py index 682c13ca84..020e83627c 100644 --- a/tests/unit/ast/test_pre_parser.py +++ b/tests/unit/ast/test_pre_parser.py @@ -173,7 +173,7 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version): @pytest.mark.parametrize("code, pre_parse_settings, compiler_data_settings", pragma_examples) def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_version): mock_version("0.3.10") - settings, _, _ = pre_parse(code) + settings, _, _, _ = pre_parse(code) assert settings == pre_parse_settings diff --git a/tests/unit/compiler/asm/test_asm_optimizer.py b/tests/unit/compiler/asm/test_asm_optimizer.py index 44b823757c..b2851e908a 100644 --- a/tests/unit/compiler/asm/test_asm_optimizer.py +++ b/tests/unit/compiler/asm/test_asm_optimizer.py @@ -58,7 +58,7 @@ def ctor_only(): @internal def runtime_only(): - for i in range(10): + for i: uint256 in range(10): self.s += 1 @external diff --git a/tests/unit/compiler/test_source_map.py b/tests/unit/compiler/test_source_map.py index c9a152b09c..5b478dd2aa 100644 --- a/tests/unit/compiler/test_source_map.py +++ b/tests/unit/compiler/test_source_map.py @@ -6,7 +6,7 @@ @internal def _baz(a: int128) -> int128: b: int128 = a - for i in range(2, 5): + for i: int128 in range(2, 5): b *= i if b > 31337: break diff --git a/tests/unit/semantics/analysis/test_for_loop.py b/tests/unit/semantics/analysis/test_for_loop.py index e2c0f555af..607587cc28 100644 --- a/tests/unit/semantics/analysis/test_for_loop.py +++ b/tests/unit/semantics/analysis/test_for_loop.py @@ -22,7 +22,7 @@ def foo(): @internal def bar(): self.foo() - for i in self.a: + for i: uint256 in self.a: pass """ vyper_module = parse_to_ast(code) @@ -42,7 +42,7 @@ def foo(a: uint256[3]) -> uint256[3]: @internal def bar(): a: uint256[3] = [1,2,3] - for i in a: + for i: uint256 in a: self.foo(a) """ vyper_module = parse_to_ast(code) @@ -56,7 +56,7 @@ def test_modify_iterator(dummy_input_bundle): @internal def bar(): - for i in self.a: + for i: uint256 in self.a: self.a[0] = 1 """ vyper_module = parse_to_ast(code) @@ -70,7 +70,7 @@ def test_bad_keywords(dummy_input_bundle): @internal def bar(n: uint256): x: uint256 = 0 - for i in range(n, boundddd=10): + for i: uint256 in range(n, boundddd=10): x += i """ vyper_module = parse_to_ast(code) @@ -84,7 +84,7 @@ def test_bad_bound(dummy_input_bundle): @internal def bar(n: uint256): x: uint256 = 0 - for i in range(n, bound=n): + for i: uint256 in range(n, bound=n): x += i """ vyper_module = parse_to_ast(code) @@ -103,7 +103,7 @@ def foo(): @internal def bar(): - for i in self.a: + for i: uint256 in self.a: self.foo() """ vyper_module = parse_to_ast(code) @@ -126,7 +126,7 @@ def bar(): @internal def baz(): - for i in self.a: + for i: uint256 in self.a: self.bar() """ vyper_module = parse_to_ast(code) @@ -138,32 +138,32 @@ def baz(): """ @external def main(): - for j in range(3): + for j: uint256 in range(3): x: uint256 = j y: uint16 = j """, # GH issue 3212 """ @external def foo(): - for i in [1]: - a:uint256 = i - b:uint16 = i + for i: uint256 in [1]: + a: uint256 = i + b: uint16 = i """, # GH issue 3374 """ @external def foo(): - for i in [1]: - for j in [1]: - a:uint256 = i - b:uint16 = i + for i: uint256 in [1]: + for j: uint256 in [1]: + a: uint256 = i + b: uint16 = i """, # GH issue 3374 """ @external def foo(): - for i in [1,2,3]: - for j in [1,2,3]: - b:uint256 = j + i - c:uint16 = i + for i: uint256 in [1,2,3]: + for j: uint256 in [1,2,3]: + b: uint256 = j + i + c: uint16 = i """, # GH issue 3374 ] diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark index 7889473b19..234e96e552 100644 --- a/vyper/ast/grammar.lark +++ b/vyper/ast/grammar.lark @@ -178,8 +178,7 @@ body: _NEWLINE _INDENT ([COMMENT] _NEWLINE | _stmt)+ _DEDENT cond_exec: _expr ":" body default_exec: body if_stmt: "if" cond_exec ("elif" cond_exec)* ["else" ":" default_exec] -// TODO: make this into a variable definition e.g. `for i: uint256 in range(0, 5): ...` -loop_variable: NAME [":" NAME] +loop_variable: NAME ":" type loop_iterator: _expr for_stmt: "for" loop_variable "in" loop_iterator ":" body diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index efab5117d4..7a8c7443b7 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -24,7 +24,15 @@ ) from vyper.utils import MAX_DECIMAL_PLACES, SizeLimits, annotate_source_code -NODE_BASE_ATTRIBUTES = ("_children", "_depth", "_parent", "ast_type", "node_id", "_metadata") +NODE_BASE_ATTRIBUTES = ( + "_children", + "_depth", + "_parent", + "ast_type", + "node_id", + "_metadata", + "_original_node", +) NODE_SRC_ATTRIBUTES = ( "col_offset", "end_col_offset", @@ -257,6 +265,7 @@ def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict): self.set_parent(parent) self._children: set = set() self._metadata: NodeMetadata = NodeMetadata() + self._original_node = None for field_name in NODE_SRC_ATTRIBUTES: # when a source offset is not available, use the parent's source offset @@ -411,12 +420,16 @@ def _set_folded_value(self, node: "VyperNode") -> None: # sanity check this is only called once assert "folded_value" not in self._metadata - # set the folded node's parent so that get_ancestor works - # this is mainly important for error messages. - node._parent = self._parent + # set the "original node" so that exceptions can point to the original + # node and not the folded node + node = copy.copy(node) + node._original_node = self self._metadata["folded_value"] = node + def get_original_node(self) -> "VyperNode": + return self._original_node or self + def _try_fold(self) -> "VyperNode": """ Attempt to constant-fold the content of a node, returning the result of @@ -1546,7 +1559,7 @@ class IfExp(ExprNode): class For(Stmt): - __slots__ = ("iter", "target", "body") + __slots__ = ("target", "iter", "body") _only_empty_fields = ("orelse",) diff --git a/vyper/ast/parse.py b/vyper/ast/parse.py index 38a9d31695..b657cf2245 100644 --- a/vyper/ast/parse.py +++ b/vyper/ast/parse.py @@ -54,7 +54,7 @@ def parse_to_ast_with_settings( """ if "\x00" in source_code: raise ParserException("No null bytes (\\x00) allowed in the source code.") - settings, class_types, reformatted_code = pre_parse(source_code) + settings, class_types, for_loop_annotations, reformatted_code = pre_parse(source_code) try: py_ast = python_ast.parse(reformatted_code) except SyntaxError as e: @@ -73,11 +73,15 @@ def parse_to_ast_with_settings( py_ast, source_code, class_types, + for_loop_annotations, source_id, module_path=module_path, resolved_path=resolved_path, ) + # postcondition: consumed all the for loop annotations + assert len(for_loop_annotations) == 0 + # Convert to Vyper AST. module = vy_ast.get_node(py_ast) assert isinstance(module, vy_ast.Module) # mypy hint @@ -113,11 +117,13 @@ def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]: class AnnotatingVisitor(python_ast.NodeTransformer): _source_code: str _modification_offsets: ModificationOffsets + _loop_var_annotations: dict[int, dict[str, Any]] def __init__( self, source_code: str, - modification_offsets: Optional[ModificationOffsets], + modification_offsets: ModificationOffsets, + for_loop_annotations: dict, tokens: asttokens.ASTTokens, source_id: int, module_path: Optional[str] = None, @@ -127,11 +133,11 @@ def __init__( self._source_id = source_id self._module_path = module_path self._resolved_path = resolved_path - self._source_code: str = source_code + self._source_code = source_code + self._modification_offsets = modification_offsets + self._for_loop_annotations = for_loop_annotations + self.counter: int = 0 - self._modification_offsets = {} - if modification_offsets is not None: - self._modification_offsets = modification_offsets def generic_visit(self, node): """ @@ -213,6 +219,47 @@ def visit_ClassDef(self, node): node.ast_type = self._modification_offsets[(node.lineno, node.col_offset)] return node + def visit_For(self, node): + """ + Visit a For node, splicing in the loop variable annotation provided by + the pre-parser + """ + raw_annotation = self._for_loop_annotations.pop((node.lineno, node.col_offset)) + + if not raw_annotation: + # a common case for people migrating to 0.4.0, provide a more + # specific error message than "invalid type annotation" + raise SyntaxException( + "missing type annotation\n\n" + "(hint: did you mean something like " + f"`for {node.target.id}: uint256 in ...`?)\n", + self._source_code, + node.lineno, + node.col_offset, + ) + + try: + annotation = python_ast.parse(raw_annotation, mode="eval") + # annotate with token and source code information. `first_token` + # and `last_token` attributes are accessed in `generic_visit`. + tokens = asttokens.ASTTokens(raw_annotation) + tokens.mark_tokens(annotation) + except SyntaxError as e: + raise SyntaxException( + "invalid type annotation", self._source_code, node.lineno, node.col_offset + ) from e + + assert isinstance(annotation, python_ast.Expression) + annotation = annotation.body + + old_target = node.target + new_target = python_ast.AnnAssign(target=old_target, annotation=annotation, simple=1) + node.target = new_target + + self.generic_visit(node) + + return node + def visit_Expr(self, node): """ Convert the `Yield` node into a Vyper-specific node type. @@ -355,7 +402,8 @@ def visit_UnaryOp(self, node): def annotate_python_ast( parsed_ast: python_ast.AST, source_code: str, - modification_offsets: Optional[ModificationOffsets] = None, + modification_offsets: ModificationOffsets, + for_loop_annotations: dict, source_id: int = 0, module_path: Optional[str] = None, resolved_path: Optional[str] = None, @@ -369,6 +417,9 @@ def annotate_python_ast( The AST to be annotated and optimized. source_code : str The originating source code of the AST. + loop_var_annotations: dict, optional + A mapping of line numbers of `For` nodes to the type annotation of the iterator + extracted during pre-parsing. modification_offsets : dict, optional A mapping of class names to their original class types. @@ -381,6 +432,7 @@ def annotate_python_ast( visitor = AnnotatingVisitor( source_code, modification_offsets, + for_loop_annotations, tokens, source_id, module_path=module_path, diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py index b949a242bb..c7e6f3698f 100644 --- a/vyper/ast/pre_parser.py +++ b/vyper/ast/pre_parser.py @@ -1,3 +1,4 @@ +import enum import io import re from tokenize import COMMENT, NAME, OP, TokenError, TokenInfo, tokenize, untokenize @@ -43,6 +44,64 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: ) +class ForParserState(enum.Enum): + NOT_RUNNING = enum.auto() + START_SOON = enum.auto() + RUNNING = enum.auto() + + +# a simple state machine which allows us to handle loop variable annotations +# (which are rejected by the python parser due to pep-526, so we scoop up the +# tokens between `:` and `in` and parse them and add them back in later). +class ForParser: + def __init__(self, code): + self._code = code + self.annotations = {} + self._current_annotation = None + + self._state = ForParserState.NOT_RUNNING + self._current_for_loop = None + + def consume(self, token): + # state machine: we can start slurping tokens soon + if token.type == NAME and token.string == "for": + # note: self._state should be NOT_RUNNING here, but we don't sanity + # check here as that should be an error the parser will handle. + self._state = ForParserState.START_SOON + self._current_for_loop = token.start + + if self._state == ForParserState.NOT_RUNNING: + return False + + # state machine: start slurping tokens + if token.type == OP and token.string == ":": + self._state = ForParserState.RUNNING + + # sanity check -- this should never really happen, but if it does, + # try to raise an exception which pinpoints the source. + if self._current_annotation is not None: + raise SyntaxException( + "for loop parse error", self._code, token.start[0], token.start[1] + ) + + self._current_annotation = [] + return True # do not add ":" to tokens. + + # state machine: end slurping tokens + if token.type == NAME and token.string == "in": + self._state = ForParserState.NOT_RUNNING + self.annotations[self._current_for_loop] = self._current_annotation or [] + self._current_annotation = None + return False + + if self._state != ForParserState.RUNNING: + return False + + # slurp the token + self._current_annotation.append(token) + return True + + # compound statements that are replaced with `class` # TODO remove enum in favor of flag VYPER_CLASS_TYPES = {"flag", "enum", "event", "interface", "struct"} @@ -51,7 +110,7 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None: VYPER_EXPRESSION_TYPES = {"log"} -def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: +def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]: """ Re-formats a vyper source string into a python source string and performs some validation. More specifically, @@ -60,9 +119,11 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: * Validates "@version" pragma against current compiler version * Prevents direct use of python "class" keyword * Prevents use of python semi-colon statement separator + * Extracts type annotation of for loop iterators into a separate dictionary Also returns a mapping of detected interface and struct names to their - respective vyper class types ("interface" or "struct"). + respective vyper class types ("interface" or "struct"), and a mapping of line numbers + of for loops to the type annotation of their iterators. Parameters ---------- @@ -71,21 +132,25 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: Returns ------- - dict - Mapping of offsets where source was modified. + Settings + Compilation settings based on the directives in the source code + ModificationOffsets + A mapping of class names to their original class types. + dict[tuple[int, int], str] + A mapping of line/column offsets of `For` nodes to the annotation of the for loop target str Reformatted python source string. """ result = [] modification_offsets: ModificationOffsets = {} settings = Settings() + for_parser = ForParser(code) try: code_bytes = code.encode("utf-8") token_list = list(tokenize(io.BytesIO(code_bytes).readline)) - for i in range(len(token_list)): - token = token_list[i] + for token in token_list: toks = [token] typ = token.type @@ -146,8 +211,18 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]: if (typ, string) == (OP, ";"): raise SyntaxException("Semi-colon statements not allowed", code, start[0], start[1]) - result.extend(toks) + + if not for_parser.consume(token): + result.extend(toks) + except TokenError as e: raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e - return settings, modification_offsets, untokenize(result).decode("utf-8") + for_loop_annotations = {} + for k, v in for_parser.annotations.items(): + v_source = untokenize(v) + # untokenize adds backslashes and whitespace, strip them. + v_source = v_source.replace("\\", "").strip() + for_loop_annotations[k] = v_source + + return settings, modification_offsets, for_loop_annotations, untokenize(result).decode("utf-8") diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index c896fc7ef6..39d97c4abe 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -2157,7 +2157,7 @@ def build_IR(self, expr, args, kwargs, context): z = x / 2.0 + 0.5 y: decimal = x - for i in range(256): + for i: uint256 in range(256): if z == y: break y = z diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py index bc29a79734..a47faefeb1 100644 --- a/vyper/codegen/stmt.py +++ b/vyper/codegen/stmt.py @@ -33,7 +33,7 @@ ) from vyper.semantics.types import DArrayT, MemberFunctionT from vyper.semantics.types.function import ContractFunctionT -from vyper.semantics.types.shortcuts import INT256_T, UINT256_T +from vyper.semantics.types.shortcuts import UINT256_T class Stmt: @@ -231,19 +231,17 @@ def parse_For(self): return self._parse_For_list() def _parse_For_range(self): - # TODO make sure type always gets annotated - if "type" in self.stmt.target._metadata: - iter_typ = self.stmt.target._metadata["type"] - else: - iter_typ = INT256_T + assert "type" in self.stmt.target.target._metadata + target_type = self.stmt.target.target._metadata["type"] # Get arg0 - for_iter: vy_ast.Call = self.stmt.iter - args_len = len(for_iter.args) + range_call: vy_ast.Call = self.stmt.iter + assert isinstance(range_call, vy_ast.Call) + args_len = len(range_call.args) if args_len == 1: - arg0, arg1 = (IRnode.from_list(0, typ=iter_typ), for_iter.args[0]) + arg0, arg1 = (IRnode.from_list(0, typ=target_type), range_call.args[0]) elif args_len == 2: - arg0, arg1 = for_iter.args + arg0, arg1 = range_call.args else: # pragma: nocover raise TypeCheckFailure("unreachable: bad # of arguments to range()") @@ -251,7 +249,7 @@ def _parse_For_range(self): start = Expr.parse_value_expr(arg0, self.context) end = Expr.parse_value_expr(arg1, self.context) kwargs = { - s.arg: Expr.parse_value_expr(s.value, self.context) for s in for_iter.keywords + s.arg: Expr.parse_value_expr(s.value, self.context) for s in range_call.keywords } if "bound" in kwargs: @@ -270,9 +268,9 @@ def _parse_For_range(self): if rounds_bound < 1: # pragma: nocover raise TypeCheckFailure("unreachable: unchecked 0 bound") - varname = self.stmt.target.id - i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=UINT256_T) - iptr = self.context.new_variable(varname, iter_typ) + varname = self.stmt.target.target.id + i = IRnode.from_list(self.context.fresh_varname("range_ix"), typ=target_type) + iptr = self.context.new_variable(varname, target_type) self.context.forvars[varname] = True @@ -297,11 +295,11 @@ def _parse_For_list(self): with self.context.range_scope(): iter_list = Expr(self.stmt.iter, self.context).ir_node - target_type = self.stmt.target._metadata["type"] + target_type = self.stmt.target.target._metadata["type"] assert target_type == iter_list.typ.value_type # user-supplied name for loop variable - varname = self.stmt.target.id + varname = self.stmt.target.target.id loop_var = IRnode.from_list( self.context.new_variable(varname, target_type), typ=target_type, location=MEMORY ) diff --git a/vyper/exceptions.py b/vyper/exceptions.py index f216069eab..51f3fea14c 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -92,6 +92,10 @@ def __str__(self): node = value[1] if isinstance(value, tuple) else value node_msg = "" + if isinstance(node, vy_ast.VyperNode): + # folded AST nodes contain pointers to the original source + node = node.get_original_node() + try: source_annotation = annotate_source_code( # add trailing space because EOF exceptions point one char beyond the length diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 91fb2c21f0..169c71269d 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,13 +1,11 @@ from typing import Optional from vyper import ast as vy_ast -from vyper.ast.metadata import NodeMetadata from vyper.ast.validation import validate_call_args from vyper.exceptions import ( ExceptionList, FunctionDeclarationException, ImmutableViolation, - InvalidOperation, InvalidType, IteratorException, NonPayableViolation, @@ -40,7 +38,6 @@ EventT, FlagT, HashMapT, - IntegerT, SArrayT, StringT, StructT, @@ -347,8 +344,10 @@ def visit_Expr(self, node): self.expr_visitor.visit(node.value, fn_type) def visit_For(self, node): - if isinstance(node.iter, vy_ast.Subscript): - raise StructureException("Cannot iterate over a nested list", node.iter) + if not isinstance(node.target.target, vy_ast.Name): + raise StructureException("Invalid syntax for loop iterator", node.target.target) + + target_type = type_from_annotation(node.target.annotation, DataLocation.MEMORY) if isinstance(node.iter, vy_ast.Call): # iteration via range() @@ -356,7 +355,7 @@ def visit_For(self, node): raise IteratorException( "Cannot iterate over the result of a function call", node.iter ) - type_list = _analyse_range_call(node.iter) + _validate_range_call(node.iter) else: # iteration over a variable or literal list @@ -364,14 +363,10 @@ def visit_For(self, node): if isinstance(iter_val, vy_ast.List) and len(iter_val.elements) == 0: raise StructureException("For loop must have at least 1 iteration", node.iter) - type_list = [ - i.value_type - for i in get_possible_types_from_node(node.iter) - if isinstance(i, (DArrayT, SArrayT)) - ] - - if not type_list: - raise InvalidType("Not an iterable type", node.iter) + if not any( + isinstance(i, (DArrayT, SArrayT)) for i in get_possible_types_from_node(node.iter) + ): + raise InvalidType("Not an iterable type", node.iter) if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): # check for references to the iterated value within the body of the loop @@ -415,65 +410,28 @@ def visit_For(self, node): call_node, ) - if not isinstance(node.target, vy_ast.Name): - raise StructureException("Invalid syntax for loop iterator", node.target) + target_name = node.target.target.id + with self.namespace.enter_scope(): + self.namespace[target_name] = VarInfo( + target_type, modifiability=Modifiability.RUNTIME_CONSTANT + ) - for_loop_exceptions = [] - iter_name = node.target.id - for possible_target_type in type_list: - # type check the for loop body using each possible type for iterator value + for stmt in node.body: + self.visit(stmt) - with self.namespace.enter_scope(): - self.namespace[iter_name] = VarInfo( - possible_target_type, modifiability=Modifiability.RUNTIME_CONSTANT - ) + self.expr_visitor.visit(node.target.target, target_type) - try: - with NodeMetadata.enter_typechecker_speculation(): - for stmt in node.body: - self.visit(stmt) - - self.expr_visitor.visit(node.target, possible_target_type) - - if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)): - iter_type = get_exact_type_from_node(node.iter) - # note CMC 2023-10-23: slightly redundant with how type_list is computed - validate_expected_type(node.target, iter_type.value_type) - self.expr_visitor.visit(node.iter, iter_type) - if isinstance(node.iter, vy_ast.List): - len_ = len(node.iter.elements) - self.expr_visitor.visit(node.iter, SArrayT(possible_target_type, len_)) - if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": - for a in node.iter.args: - self.expr_visitor.visit(a, possible_target_type) - for a in node.iter.keywords: - if a.arg == "bound": - self.expr_visitor.visit(a.value, possible_target_type) - - except (TypeMismatch, InvalidOperation) as exc: - for_loop_exceptions.append(exc) - else: - # success -- do not enter error handling section - return - - # failed to find a good type. bail out - if len(set(str(i) for i in for_loop_exceptions)) == 1: - # if every attempt at type checking raised the same exception - raise for_loop_exceptions[0] - - # return an aggregate TypeMismatch that shows all possible exceptions - # depending on which type is used - types_str = [str(i) for i in type_list] - given_str = f"{', '.join(types_str[:1])} or {types_str[-1]}" - raise TypeMismatch( - f"Iterator value '{iter_name}' may be cast as {given_str}, " - "but type checking fails with all possible types:", - node, - *( - (f"Casting '{iter_name}' as {typ}: {exc.message}", exc.annotations[0]) - for typ, exc in zip(type_list, for_loop_exceptions) - ), - ) + if isinstance(node.iter, vy_ast.List): + len_ = len(node.iter.elements) + self.expr_visitor.visit(node.iter, SArrayT(target_type, len_)) + elif isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range": + args = node.iter.args + kwargs = [s.value for s in node.iter.keywords] + for arg in (*args, *kwargs): + self.expr_visitor.visit(arg, target_type) + else: + iter_type = get_exact_type_from_node(node.iter) + self.expr_visitor.visit(node.iter, iter_type) def visit_If(self, node): validate_expected_type(node.test, BoolT()) @@ -750,25 +708,18 @@ def visit_IfExp(self, node: vy_ast.IfExp, typ: VyperType) -> None: self.visit(node.orelse, typ) -def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: +def _validate_range_call(node: vy_ast.Call): """ Check that the arguments to a range() call are valid. :param node: call to range() :return: None """ + assert node.func.get("id") == "range" validate_call_args(node, (1, 2), kwargs=["bound"]) kwargs = {s.arg: s.value for s in node.keywords or []} start, end = (vy_ast.Int(value=0), node.args[0]) if len(node.args) == 1 else node.args start, end = [i.get_folded_value() if i.has_folded_value else i for i in (start, end)] - all_args = (start, end, *kwargs.values()) - for arg1 in all_args: - validate_expected_type(arg1, IntegerT.any()) - - type_list = get_common_types(*all_args) - if not type_list: - raise TypeMismatch("Iterator values are of different types", node) - if "bound" in kwargs: bound = kwargs["bound"] if bound.has_folded_value: @@ -787,5 +738,3 @@ def _analyse_range_call(node: vy_ast.Call) -> list[VyperType]: raise StateAccessViolation(error, arg) if end.value <= start.value: raise StructureException("End must be greater than start", end) - - return type_list