diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 872d8366e1..c8d7f7d6c4 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -1,4 +1,4 @@
-name: Artifacts
+name: Build artifacts
on:
workflow_dispatch:
@@ -6,10 +6,11 @@ on:
tag:
default: ''
push:
- tags:
- - '*'
branches:
- master
+ pull_request:
+ release:
+ types: [published] # releases and pre-releases (release candidates)
defaults:
run:
@@ -20,10 +21,10 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
- os: [ubuntu-latest, macos-latest]
+ os: [ubuntu-20.04, macos-latest]
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
with:
# grab the commit passed in via `tag`, if any
ref: ${{ github.event.inputs.tag }}
@@ -31,18 +32,20 @@ jobs:
fetch-depth: 0
- name: Python
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v4
with:
- python-version: "3.10"
+ python-version: "3.11"
+ cache: "pip"
- name: Generate Binary
run: >-
- pip install . &&
+ pip install --no-binary pycryptodome --no-binary cbor2 . &&
pip install pyinstaller &&
make freeze
+
- name: Upload Artifact
- uses: actions/upload-artifact@v2
+ uses: actions/upload-artifact@v3
with:
path: dist/vyper.*
@@ -50,7 +53,7 @@ jobs:
runs-on: windows-latest
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v3
with:
# grab the commit passed in via `tag`, if any
ref: ${{ github.event.inputs.tag }}
@@ -58,9 +61,10 @@ jobs:
fetch-depth: 0
- name: Python
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v4
with:
- python-version: "3.10"
+ python-version: "3.11"
+ cache: "pip"
- name: Generate Binary
run: >-
@@ -69,6 +73,43 @@ jobs:
./make.cmd freeze
- name: Upload Artifact
- uses: actions/upload-artifact@v2
+ uses: actions/upload-artifact@v3
with:
path: dist/vyper.*
+
+ publish-release-assets:
+ needs: [windows-build, unix-build]
+ if: ${{ github.event_name == 'release' }}
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v3
+ - uses: actions/download-artifact@v3
+ with:
+ path: artifacts/
+
+ - name: Upload assets
+ # fun - artifacts are downloaded into "artifact/".
+ working-directory: artifacts/artifact
+ run: |
+ set -Eeuxo pipefail
+ for BIN_NAME in $(ls)
+ do
+ curl -L \
+ --no-progress-meter \
+ -X POST \
+ -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}"\
+ -H "Content-Type: application/octet-stream" \
+ "https://uploads.github.com/repos/${{ github.repository }}/releases/${{ github.event.release.id }}/assets?name=${BIN_NAME/+/%2B}" \
+ --data-binary "@${BIN_NAME}"
+ done
+
+ # check build success for pull requests
+ build-success:
+ if: always()
+ runs-on: ubuntu-latest
+ needs: [windows-build, unix-build]
+ steps:
+ - name: check that all builds succeeded
+ if: ${{ contains(needs.*.result, 'failure') }}
+ run: exit 1
diff --git a/.github/workflows/era-tester.yml b/.github/workflows/era-tester.yml
new file mode 100644
index 0000000000..3e0bb3e941
--- /dev/null
+++ b/.github/workflows/era-tester.yml
@@ -0,0 +1,117 @@
+name: Era compiler tester
+
+# run the matter labs compiler test to integrate their test cases
+# this is intended as a diagnostic / spot check to check that we
+# haven't seriously broken the compiler. but, it is not intended as
+# a requirement for merging since we may make changes to our IR
+# which break the downstream backend (at which point, downstream needs
+# to update, which we do not want to be blocked on).
+
+on: [push, pull_request]
+
+concurrency:
+ # cancel older, in-progress jobs from the same PR, same workflow.
+ # use run_id if the job is triggered by a push to ensure
+ # push-triggered jobs to not get canceled.
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+ era-compiler-tester:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Get latest commit hash
+ run: |
+ echo "ERA_HASH=$( curl -u "u:${{ github.token }}" https://api.github.com/repos/matter-labs/era-compiler-tester/git/ref/heads/main | jq .object.sha | tr -d '"' )" >> $GITHUB_ENV
+ echo "ERA_VYPER_HASH=$( curl -u "u:${{ github.token }}" https://api.github.com/repos/matter-labs/era-compiler-vyper/git/ref/heads/main | jq .object.sha | tr -d '"' )" >> $GITHUB_ENV
+
+ - name: Checkout
+ uses: actions/checkout@v1
+
+ - name: Rust setup
+ uses: actions-rs/toolchain@v1
+ with:
+ toolchain: nightly-2022-11-03
+
+ - name: Set up Python ${{ matrix.python-version[0] }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version[0] }}
+ cache: "pip"
+
+ - name: Get cache
+ id: get-cache
+ uses: actions/cache@v3
+ with:
+ path: |
+ ~/.cargo/bin/
+ ~/.cargo/registry/index/
+ ~/.cargo/registry/cache/
+ ~/.cargo/git/db/
+ **/target
+ **/target-llvm
+ **/compiler_tester
+ **/llvm
+ **/era-compiler-tester
+ key: ${{ runner.os }}-${{ env.ERA_HASH }}-${{ env.ERA_VYPER_HASH }}
+
+ - name: Initialize repository and install dependencies
+ if: steps.get-cache.outputs.cache-hit != 'true'
+ run: |
+ git clone --depth 1 https://github.com/matter-labs/era-compiler-tester.git
+ cd era-compiler-tester
+ sed -i 's/ssh:\/\/git@/https:\/\//g' .gitmodules
+ git submodule init
+ git submodule update
+ sudo apt install cmake ninja-build clang-13 lld-13 parallel pkg-config lld
+ cargo install compiler-llvm-builder
+ zkevm-llvm clone && zkevm-llvm build
+ cargo build --release
+
+ - name: Save cache
+ uses: actions/cache/save@v3
+ if: steps.get-cache.outputs.cache-hit != 'true'
+ with:
+ path: |
+ ~/.cargo/bin/
+ ~/.cargo/registry/index/
+ ~/.cargo/registry/cache/
+ ~/.cargo/git/db/
+ **/target
+ **/target-llvm
+ **/compiler_tester
+ **/llvm
+ **/era-compiler-tester
+ key: ${{ runner.os }}-${{ env.ERA_HASH }}-${{ env.ERA_VYPER_HASH }}
+
+ - name: Build Vyper
+ run: |
+ set -Eeuxo pipefail
+ pip install .
+ echo "VYPER_VERSION=$(vyper --version | cut -f1 -d'+')" >> $GITHUB_ENV
+
+ - name: Install Vyper
+ run: |
+ mkdir era-compiler-tester/vyper-bin
+ cp $(which vyper) era-compiler-tester/vyper-bin/vyper-${{ env.VYPER_VERSION }}
+
+ - name: Run tester (fast)
+ # Run era tester with no LLVM optimizations
+ continue-on-error: true
+ if: ${{ github.ref != 'refs/heads/master' }}
+ run: |
+ cd era-compiler-tester
+ cargo run --release --bin compiler-tester -- --path=tests/vyper/ --mode="M0B0 ${{ env.VYPER_VERSION }}"
+
+ - name: Run tester (slow)
+ # Run era tester across the LLVM optimization matrix
+ continue-on-error: true
+ if: ${{ github.ref == 'refs/heads/master' }}
+ run: |
+ cd era-compiler-tester
+ cargo run --release --bin compiler-tester -- --path=tests/vyper/ --mode="M*B* ${{ env.VYPER_VERSION }}"
+
+ - name: Mark as success
+ run: |
+ exit 0
diff --git a/.github/workflows/ghcr.yml b/.github/workflows/ghcr.yml
new file mode 100644
index 0000000000..a35a22e278
--- /dev/null
+++ b/.github/workflows/ghcr.yml
@@ -0,0 +1,75 @@
+name: Deploy docker image to ghcr
+
+# Deploy docker image to ghcr on pushes to master and all releases/tags.
+# Note releases to docker hub are managed separately in another process
+# (github sends webhooks to docker hub which triggers the build there).
+# This workflow is an alternative form of retention for docker images
+# which also allows us to tag and retain every single commit to master.
+
+on:
+ push:
+ branches:
+ - master
+ release:
+ types: [released]
+
+env:
+ REGISTRY: ghcr.io
+ IMAGE_NAME: ${{ github.repository }}
+
+jobs:
+ deploy-ghcr:
+
+ runs-on: ubuntu-latest
+ permissions:
+ contents: read
+ packages: write
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v3
+ with:
+ # need to fetch unshallow so that setuptools_scm can infer the version
+ fetch-depth: 0
+
+ - uses: actions/setup-python@v4
+ name: Install python
+ with:
+ python-version: "3.11"
+ cache: "pip"
+
+ - name: Generate vyper/version.py
+ run: |
+ pip install .
+ echo "VYPER_VERSION=$(PYTHONPATH=. python vyper/cli/vyper_compile.py --version)" >> "$GITHUB_ENV"
+
+ - name: generate tag suffix
+ if: ${{ github.event_name != 'release' }}
+ run: echo "VERSION_SUFFIX=-dev" >> "$GITHUB_ENV"
+
+ - name: Docker meta
+ id: meta
+ uses: docker/metadata-action@v4
+ with:
+ images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
+ tags: |
+ type=ref,event=tag
+ type=raw,value=${{ env.VYPER_VERSION }}${{ env.VERSION_SUFFIX }}
+ type=raw,value=dev,enable=${{ github.ref == 'refs/heads/master' }}
+ type=raw,value=latest,enable=${{ github.event_name == 'release' }}
+
+
+ - name: Login to ghcr.io
+ uses: docker/login-action@v2
+ with:
+ registry: ${{ env.REGISTRY }}
+ username: ${{ github.actor }}
+ password: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Build and push
+ uses: docker/build-push-action@v4
+ with:
+ context: .
+ push: true
+ tags: ${{ steps.meta.outputs.tags }}
+ labels: ${{ steps.meta.outputs.labels }}
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index ecc2fa5e1f..f268942e7d 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -1,11 +1,11 @@
-# This workflows will upload a Python Package using Twine when a release is created
+# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
-name: Publish
+name: Publish to PyPI
on:
release:
- types: [released]
+ types: [published] # releases and pre-releases (release candidates)
jobs:
@@ -16,9 +16,9 @@ jobs:
- uses: actions/checkout@v2
- name: Python
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v4
with:
- python-version: '3.x'
+ python-version: "3.11"
- name: Install dependencies
run: |
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index f8fcb54f7d..fd78e2fff8 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -1,4 +1,4 @@
-name: Test
+name: Run test suite
on: [push, pull_request]
@@ -17,10 +17,11 @@ jobs:
steps:
- uses: actions/checkout@v1
- - name: Set up Python 3.10
- uses: actions/setup-python@v1
+ - name: Set up Python 3.11
+ uses: actions/setup-python@v4
with:
- python-version: "3.10"
+ python-version: "3.11"
+ cache: "pip"
- name: Install Dependencies
run: pip install .[lint]
@@ -42,10 +43,11 @@ jobs:
steps:
- uses: actions/checkout@v1
- - name: Set up Python 3.10
- uses: actions/setup-python@v1
+ - name: Set up Python 3.11
+ uses: actions/setup-python@v4
with:
- python-version: "3.10"
+ python-version: "3.11"
+ cache: "pip"
- name: Install Tox
run: pip install tox
@@ -59,10 +61,11 @@ jobs:
steps:
- uses: actions/checkout@v1
- - name: Set up Python 3.10
- uses: actions/setup-python@v1
+ - name: Set up Python 3.11
+ uses: actions/setup-python@v4
with:
- python-version: "3.10"
+ python-version: "3.11"
+ cache: "pip"
- name: Install Tox
run: pip install tox
@@ -75,11 +78,18 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: [["3.10", "310", ["3.11", "311"]]]
- # run in default (optimized) and --no-optimize mode
- flag: ["core", "no-opt"]
-
- name: py${{ matrix.python-version[1] }}-${{ matrix.flag }}
+ python-version: [["3.11", "311"]]
+ # run in modes: --optimize [gas, none, codesize]
+ opt-mode: ["gas", "none", "codesize"]
+ debug: [true, false]
+ # run across other python versions.# we don't really need to run all
+ # modes across all python versions - one is enough
+ include:
+ - python-version: ["3.10", "310"]
+ opt-mode: gas
+ debug: false
+
+ name: py${{ matrix.python-version[1] }}-opt-${{ matrix.opt-mode }}${{ matrix.debug && '-debug' || '' }}
steps:
- uses: actions/checkout@v1
@@ -88,12 +98,13 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version[0] }}
+ cache: "pip"
- name: Install Tox
run: pip install tox
- name: Run Tox
- run: TOXENV=py${{ matrix.python-version[1] }}-${{ matrix.flag }} tox -r -- --reruns 10 --reruns-delay 1 -r aR tests/
+ run: TOXENV=py${{ matrix.python-version[1] }} tox -r -- --optimize ${{ matrix.opt-mode }} ${{ matrix.debug && '--enable-compiler-debug-mode' || '' }} --reruns 10 --reruns-delay 1 -r aR tests/
- name: Upload Coverage
uses: codecov/codecov-action@v1
@@ -126,10 +137,11 @@ jobs:
steps:
- uses: actions/checkout@v1
- - name: Set up Python 3.10
- uses: actions/setup-python@v1
+ - name: Set up Python 3.11
+ uses: actions/setup-python@v4
with:
- python-version: "3.10"
+ python-version: "3.11"
+ cache: "pip"
- name: Install Tox
run: pip install tox
@@ -138,7 +150,7 @@ jobs:
# NOTE: if the tests get poorly distributed, run this and commit the resulting `.test_durations` file to the `vyper-test-durations` repo.
# `TOXENV=fuzzing tox -r -- --store-durations --reruns 10 --reruns-delay 1 -r aR tests/`
- name: Fetch test-durations
- run: curl --location "https://raw.githubusercontent.com/vyperlang/vyper-test-durations/ac71e77863d7f4e7e7cd19a93cf50a8c39de4845/test_durations" -o .test_durations
+ run: curl --location "https://raw.githubusercontent.com/vyperlang/vyper-test-durations/5982755ee8459f771f2e8622427c36494646e1dd/test_durations" -o .test_durations
- name: Run Tox
run: TOXENV=fuzzing tox -r -- --splits 60 --group ${{ matrix.group }} --splitting-algorithm least_duration --reruns 10 --reruns-delay 1 -r aR tests/
@@ -167,10 +179,11 @@ jobs:
steps:
- uses: actions/checkout@v1
- - name: Set up Python 3.10
- uses: actions/setup-python@v1
+ - name: Set up Python 3.11
+ uses: actions/setup-python@v4
with:
- python-version: "3.10"
+ python-version: "3.11"
+ cache: "pip"
- name: Install Tox
run: pip install tox
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 739e977c96..4b416a4414 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:
- id: black
name: black
-- repo: https://gitlab.com/pycqa/flake8
+- repo: https://github.com/PyCQA/flake8
rev: 3.9.2
hooks:
- id: flake8
diff --git a/Dockerfile b/Dockerfile
index c2245ee981..bc5bb607d6 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,4 @@
-FROM python:3.10-slim
+FROM python:3.11-slim
# Specify label-schema specific arguments and labels.
ARG BUILD_DATE
diff --git a/Makefile b/Makefile
index daa1c2bfc9..645b800e79 100644
--- a/Makefile
+++ b/Makefile
@@ -43,7 +43,7 @@ freeze: clean init
echo Generating binary...
export OS="$$(uname -s | tr A-Z a-z)" && \
export VERSION="$$(PYTHONPATH=. python vyper/cli/vyper_compile.py --version)" && \
- pyinstaller --clean --onefile vyper/cli/vyper_compile.py --name "vyper.$${VERSION}.$${OS}" --add-data vyper:vyper
+ pyinstaller --target-architecture=universal2 --clean --onefile vyper/cli/vyper_compile.py --name "vyper.$${VERSION}.$${OS}" --add-data vyper:vyper
clean: clean-build clean-docs clean-pyc clean-test
diff --git a/README.md b/README.md
index f17e693bf5..bad929956d 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,16 @@
+**Vyper compiler security audit competition starts 14th September with $150k worth of bounties.** [See the competition on CodeHawks](https://www.codehawks.com/contests/cll5rujmw0001js08menkj7hc) and find [more details in this blog post](https://mirror.xyz/0xBA41A04A14aeaEec79e2D694B21ba5Ab610982f1/WTZ3l3MLhTz9P4avq6JqipN5d4HJNiUY-d8zT0pfmXg).
-[](https://github.com/vyperlang/vyper/actions)
+[](https://github.com/vyperlang/vyper/actions/workflows/test.yml)
[](http://vyper.readthedocs.io/en/latest/?badge=latest "ReadTheDocs")
[](https://discord.gg/6tw7PTM7C2)
[](https://pypi.org/project/vyper "PyPI")
-[](https://hub.docker.com/r/vyperlang/vyper "DockerHub")
+[](https://hub.docker.com/r/vyperlang/vyper "DockerHub")
[](https://codecov.io/gh/vyperlang/vyper "Codecov")
-[](https://lgtm.com/projects/g/vyperlang/vyper/context:python)
+[](https://github.com/vyperlang/vyper/actions/workflows/codeql.yml)
# Getting Started
See [Installing Vyper](http://vyper.readthedocs.io/en/latest/installing-vyper.html) to install vyper.
diff --git a/docs/built-in-functions.rst b/docs/built-in-functions.rst
index ab3ebb3fff..bfaa8fdd5e 100644
--- a/docs/built-in-functions.rst
+++ b/docs/built-in-functions.rst
@@ -106,6 +106,11 @@ Bitwise Operations
>>> ExampleContract.foo(2, 8)
512
+.. note::
+
+ This function has been deprecated from version 0.3.8 onwards. Please use the ``<<`` and ``>>`` operators instead.
+
+
Chain Interaction
=================
@@ -201,11 +206,11 @@ Vyper has three built-ins for contract creation; all three contract creation bui
.. note::
- To properly deploy a blueprint contract, special deploy bytecode must be used. Deploying blueprint contracts is generally out of scope of this article, but the following preamble, prepended to regular deploy bytecode (output of ``vyper -f bytecode``), should deploy the blueprint in an ordinary contract creation transaction: ``deploy_preamble = "61" + + "3d81600a3d39f3"``. To see an example of this, please see `the setup code for testing create_from_blueprint `_.
+ To properly deploy a blueprint contract, special deploy bytecode must be used. The output of ``vyper -f blueprint_bytecode`` will produce bytecode which deploys an ERC-5202 compatible blueprint.
.. warning::
- It is recommended to deploy blueprints with the ERC5202 preamble ``0xfe7100`` to guard them from being called as regular contracts. This is particularly important for factories where the constructor has side effects (including ``SELFDESTRUCT``!), as those could get executed by *anybody* calling the blueprint contract directly. The ``code_offset=`` kwarg is provided to enable this pattern:
+ It is recommended to deploy blueprints with the ERC-5202 preamble ``0xFE7100`` to guard them from being called as regular contracts. This is particularly important for factories where the constructor has side effects (including ``SELFDESTRUCT``!), as those could get executed by *anybody* calling the blueprint contract directly. The ``code_offset=`` kwarg is provided to enable this pattern:
.. code-block:: python
@@ -214,13 +219,13 @@ Vyper has three built-ins for contract creation; all three contract creation bui
# `blueprint` is a blueprint contract with some known preamble b"abcd..."
return create_from_blueprint(blueprint, code_offset=)
-.. py:function:: raw_call(to: address, data: Bytes, max_outsize: int = 0, gas: uint256 = gasLeft, value: uint256 = 0, is_delegate_call: bool = False, is_static_call: bool = False, revert_on_failure: bool = True) -> Bytes[max_outsize]
+.. py:function:: raw_call(to: address, data: Bytes, max_outsize: uint256 = 0, gas: uint256 = gasLeft, value: uint256 = 0, is_delegate_call: bool = False, is_static_call: bool = False, revert_on_failure: bool = True) -> Bytes[max_outsize]
Call to the specified Ethereum address.
* ``to``: Destination address to call to
* ``data``: Data to send to the destination address
- * ``max_outsize``: Maximum length of the bytes array returned from the call. If the returned call data exceeds this length, only this number of bytes is returned.
+ * ``max_outsize``: Maximum length of the bytes array returned from the call. If the returned call data exceeds this length, only this number of bytes is returned. (Optional, default ``0``)
* ``gas``: The amount of gas to attach to the call. If not set, all remaining gas is forwarded.
* ``value``: The wei value to send to the address (Optional, default ``0``)
* ``is_delegate_call``: If ``True``, the call will be sent as ``DELEGATECALL`` (Optional, default ``False``)
@@ -294,6 +299,10 @@ Vyper has three built-ins for contract creation; all three contract creation bui
This method deletes the contract from the blockchain. All non-ether assets associated with this contract are "burned" and the contract is no longer accessible.
+ .. note::
+
+ This function has been deprecated from version 0.3.8 onwards. The underlying opcode will eventually undergo breaking changes, and its use is not recommended.
+
.. code-block:: python
@external
@@ -370,7 +379,11 @@ Cryptography
* ``s``: second 32 bytes of signature
* ``v``: final 1 byte of signature
- Returns the associated address, or ``0`` on error.
+ Returns the associated address, or ``empty(address)`` on error.
+
+ .. note::
+
+ Prior to Vyper ``0.3.10``, the ``ecrecover`` function could return an undefined (possibly nonzero) value for invalid inputs to ``ecrecover``. For more information, please see `GHSA-f5x6-7qgp-jhf3 `_.
.. code-block:: python
@@ -564,6 +577,24 @@ Math
>>> ExampleContract.foo(3.1337)
4
+.. py:function:: epsilon(typename) -> Any
+
+ Returns the smallest non-zero value for a decimal type.
+
+ * ``typename``: Name of the decimal type (currently only ``decimal``)
+
+ .. code-block:: python
+
+ @external
+ @view
+ def foo() -> decimal:
+ return epsilon(decimal)
+
+ .. code-block:: python
+
+ >>> ExampleContract.foo()
+ Decimal('1E-10')
+
.. py:function:: floor(value: decimal) -> int256
Round a decimal down to the nearest integer.
diff --git a/docs/compiling-a-contract.rst b/docs/compiling-a-contract.rst
index 4a03347536..6d1cdf98d7 100644
--- a/docs/compiling-a-contract.rst
+++ b/docs/compiling-a-contract.rst
@@ -99,6 +99,11 @@ See :ref:`searching_for_imports` for more information on Vyper's import system.
Online Compilers
================
+Try VyperLang!
+-----------------
+
+`Try VyperLang! `_ is a JupterHub instance hosted by the Vyper team as a sandbox for developing and testing contracts in Vyper. It requires github for login, and supports deployment via the browser.
+
Remix IDE
---------
@@ -108,23 +113,51 @@ Remix IDE
While the Vyper version of the Remix IDE compiler is updated on a regular basis, it might be a bit behind the latest version found in the master branch of the repository. Make sure the byte code matches the output from your local compiler.
+.. _optimization-mode:
+
+Compiler Optimization Modes
+===========================
+
+The vyper CLI tool accepts an optimization mode ``"none"``, ``"codesize"``, or ``"gas"`` (default). It can be set using the ``--optimize`` flag. For example, invoking ``vyper --optimize codesize MyContract.vy`` will compile the contract, optimizing for code size. As a rough summary of the differences between gas and codesize mode, in gas optimized mode, the compiler will try to generate bytecode which minimizes gas (up to a point), including:
+
+* using a sparse selector table which optimizes for gas over codesize
+* inlining some constants, and
+* trying to unroll some loops, especially for data copies.
+
+In codesize optimized mode, the compiler will try hard to minimize codesize by
+
+* using a dense selector table
+* out-lining code, and
+* using more loops for data copies.
+
+
+.. _evm-version:
Setting the Target EVM Version
==============================
-When you compile your contract code, you can specify the Ethereum Virtual Machine version to compile for, to avoid particular features or behaviours.
+When you compile your contract code, you can specify the target Ethereum Virtual Machine version to compile for, to access or avoid particular features. You can specify the version either with a source code pragma or as a compiler option. It is recommended to use the compiler option when you want flexibility (for instance, ease of deploying across different chains), and the source code pragma when you want bytecode reproducibility (for instance, when verifying code on a block explorer).
+
+.. note::
+ If the evm version specified by the compiler options conflicts with the source code pragma, an exception will be raised and compilation will not continue.
+
+For instance, the adding the following pragma to a contract indicates that it should be compiled for the "shanghai" fork of the EVM.
+
+.. code-block:: python
+
+ #pragma evm-version shanghai
.. warning::
- Compiling for the wrong EVM version can result in wrong, strange and failing behaviour. Please ensure, especially if running a private chain, that you use matching EVM versions.
+ Compiling for the wrong EVM version can result in wrong, strange, or failing behavior. Please ensure, especially if running a private chain, that you use matching EVM versions.
-When compiling via ``vyper``, include the ``--evm-version`` flag:
+When compiling via the ``vyper`` CLI, you can specify the EVM version option using the ``--evm-version`` flag:
::
$ vyper --evm-version [VERSION]
-When using the JSON interface, include the ``"evmVersion"`` key within the ``"settings"`` field:
+When using the JSON interface, you can include the ``"evmVersion"`` key within the ``"settings"`` field:
.. code-block:: javascript
@@ -140,24 +173,32 @@ Target Options
The following is a list of supported EVM versions, and changes in the compiler introduced with each version. Backward compatibility is not guaranteed between each version.
-.. py:attribute:: byzantium
+.. py:attribute:: istanbul
- - The oldest EVM version supported by Vyper.
+ - The ``CHAINID`` opcode is accessible via ``chain.id``
+ - The ``SELFBALANCE`` opcode is used for calls to ``self.balance``
+ - Gas estimates changed for ``SLOAD`` and ``BALANCE``
-.. py:attribute:: constantinople
+.. py:attribute:: berlin
- - The ``EXTCODEHASH`` opcode is accessible via ``address.codehash``
- - ``shift`` makes use of ``SHL``/``SHR`` opcodes.
+ - Gas estimates changed for ``EXTCODESIZE``, ``EXTCODECOPY``, ``EXTCODEHASH``, ``SLOAD``, ``SSTORE``, ``CALL``, ``CALLCODE``, ``DELEGATECALL`` and ``STATICCALL``
+ - Functions marked with ``@nonreentrant`` are protected with different values (3 and 2) than contracts targeting pre-berlin.
+ - ``BASEFEE`` is accessible via ``block.basefee``
-.. py:attribute:: petersburg
+.. py:attribute:: paris
- - The compiler behaves the same way as with constantinople.
+ - ``block.difficulty`` is deprecated in favor of its new alias, ``block.prevrandao``.
+
+.. py:attribute:: shanghai (default)
+
+ - The ``PUSH0`` opcode is automatically generated by the compiler instead of ``PUSH1 0``
+
+.. py:attribute:: cancun (experimental)
+
+ - The ``transient`` keyword allows declaration of variables which live in transient storage
+ - Functions marked with ``@nonreentrant`` are protected with TLOAD/TSTORE instead of SLOAD/SSTORE
-.. py:attribute:: istanbul (default)
- - The ``CHAINID`` opcode is accessible via ``chain.id``
- - The ``SELFBALANCE`` opcode is used for calls to ``self.balance``
- - Gas estimates changed for ``SLOAD`` and ``BALANCE``
Compiler Input and Output JSON Description
@@ -204,10 +245,11 @@ The following example describes the expected input format of ``vyper-json``. Com
},
// Optional
"settings": {
- "evmVersion": "istanbul", // EVM version to compile for. Can be byzantium, constantinople, petersburg or istanbul.
- // optional, whether or not optimizations are turned on
- // defaults to true
- "optimize": true,
+ "evmVersion": "shanghai", // EVM version to compile for. Can be istanbul, berlin, paris, shanghai (default) or cancun (experimental!).
+ // optional, optimization mode
+ // defaults to "gas". can be one of "gas", "codesize", "none",
+ // false and true (the last two are for backwards compatibility).
+ "optimize": "gas",
// optional, whether or not the bytecode should include Vyper's signature
// defaults to true
"bytecodeMetadata": true,
diff --git a/docs/control-structures.rst b/docs/control-structures.rst
index a89f36f7cc..fc8a472ff6 100644
--- a/docs/control-structures.rst
+++ b/docs/control-structures.rst
@@ -36,8 +36,15 @@ External functions (marked with the ``@external`` decorator) are a part of the c
def add_seven(a: int128) -> int128:
return a + 7
+ @external
+ def add_seven_with_overloading(a: uint256, b: uint256 = 3):
+ return a + b
+
A Vyper contract cannot call directly between two external functions. If you must do this, you can use an :ref:`interface `.
+.. note::
+ For external functions with default arguments like ``def my_function(x: uint256, b: uint256 = 1)`` the Vyper compiler will generate ``N+1`` overloaded function selectors based on ``N`` default arguments.
+
.. _structure-functions-internal:
Internal Functions
@@ -48,13 +55,15 @@ Internal functions (marked with the ``@internal`` decorator) are only accessible
.. code-block:: python
@internal
- def _times_two(amount: uint256) -> uint256:
- return amount * 2
+ def _times_two(amount: uint256, two: uint256 = 2) -> uint256:
+ return amount * two
@external
def calculate(amount: uint256) -> uint256:
return self._times_two(amount)
+.. note::
+ Since calling an ``internal`` function is realized by jumping to its entry label, the internal function dispatcher ensures the correctness of the jumps. Please note that for ``internal`` functions which use more than one default parameter, Vyper versions ``>=0.3.8`` are strongly recommended due to the security advisory `GHSA-ph9x-4vc9-m39g `_.
Mutability
----------
diff --git a/docs/installing-vyper.rst b/docs/installing-vyper.rst
index 2e2d51bd6e..fb2849708d 100644
--- a/docs/installing-vyper.rst
+++ b/docs/installing-vyper.rst
@@ -76,9 +76,13 @@ Each tagged version of vyper is uploaded to `pypi s/`/``/g
+ to convert links to nice rst links:
+ :'<,'>s/\v(https:\/\/github.com\/vyperlang\/vyper\/pull\/)(\d+)/(`#\2 <\1\2>`_)/g
+ ex. in: https://github.com/vyperlang/vyper/pull/3373
+ ex. out: (`#3373 `_)
+ for advisory links:
+ :'<,'>s/\v(https:\/\/github.com\/vyperlang\/vyper\/security\/advisories\/)([-A-Za-z0-9]+)/(`\2 <\1\2>`_)/g
+
+..
+ v0.3.10 ("Black Adder")
+ ***********************
+
+v0.3.10rc1
+**********
+
+Date released: 2023-09-06
+=========================
+
+v0.3.10 is a performance focused release. It adds a ``codesize`` optimization mode (`#3493 `_), adds new vyper-specific ``#pragma`` directives (`#3493 `_), uses Cancun's ``MCOPY`` opcode for some compiler generated code (`#3483 `_), and generates selector tables which now feature O(1) performance (`#3496 `_).
+
+Breaking changes:
+-----------------
+
+- add runtime code layout to initcode (`#3584 `_)
+- drop evm versions through istanbul (`#3470 `_)
+- remove vyper signature from runtime (`#3471 `_)
+
+Non-breaking changes and improvements:
+--------------------------------------
+
+- O(1) selector tables (`#3496 `_)
+- implement bound= in ranges (`#3537 `_, `#3551 `_)
+- add optimization mode to vyper compiler (`#3493 `_)
+- improve batch copy performance (`#3483 `_, `#3499 `_, `#3525 `_)
+
+Notable fixes:
+--------------
+
+- fix ``ecrecover()`` behavior when signature is invalid (`GHSA-f5x6-7qgp-jhf3 `_, `#3586 `_)
+- fix: order of evaluation for some builtins (`#3583 `_, `#3587 `_)
+- fix: pycryptodome for arm builds (`#3485 `_)
+- let params of internal functions be mutable (`#3473 `_)
+- typechecking of folded builtins in (`#3490 `_)
+- update tload/tstore opcodes per latest 1153 EIP spec (`#3484 `_)
+- fix: raw_call type when max_outsize=0 is set (`#3572 `_)
+- fix: implements check for indexed event arguments (`#3570 `_)
+
+Other docs updates, chores and fixes:
+-------------------------------------
+
+- relax restrictions on internal function signatures (`#3573 `_)
+- note on security advisory in release notes for versions ``0.2.15``, ``0.2.16``, and ``0.3.0`` (`#3553 `_)
+- fix: yanked version in release notes (`#3545 `_)
+- update release notes on yanked versions (`#3547 `_)
+- improve error message for conflicting methods IDs (`#3491 `_)
+- document epsilon builtin (`#3552 `_)
+- relax version pragma parsing (`#3511 `_)
+- fix: issue with finding installed packages in editable mode (`#3510 `_)
+- add note on security advisory for ``ecrecover`` in docs (`#3539 `_)
+- add ``asm`` option to cli help (`#3585 `_)
+- add message to error map for repeat range check (`#3542 `_)
+- fix: public constant arrays (`#3536 `_)
+
+
+v0.3.9 ("Common Adder")
+***********************
+
+Date released: 2023-05-29
+
+This is a patch release fix for v0.3.8. @bout3fiddy discovered a codesize regression for blueprint contracts in v0.3.8 which is fixed in this release. @bout3fiddy also discovered a runtime performance (gas) regression for default functions in v0.3.8 which is fixed in this release.
+
+Fixes:
+
+- initcode codesize blowup (`#3450 `_)
+- add back global calldatasize check for contracts with default fn (`#3463 `_)
+
+
+v0.3.8
+******
+
+Date released: 2023-05-23
+
+Non-breaking changes and improvements:
+
+- ``transient`` storage keyword (`#3373 `_)
+- ternary operators (`#3398 `_)
+- ``raw_revert()`` builtin (`#3136 `_)
+- shift operators (`#3019 `_)
+- make ``send()`` gas stipend configurable (`#3158 `_)
+- use new ``push0`` opcode (`#3361 `_)
+- python 3.11 support (`#3129 `_)
+- drop support for python 3.8 and 3.9 (`#3325 `_)
+- build for ``aarch64`` (`#2687 `_)
+
+Note that with the addition of ``push0`` opcode, ``shanghai`` is now the default compilation target for vyper. When deploying to a chain which does not support ``shanghai``, it is recommended to set ``--evm-version`` to ``paris``, otherwise it could result in hard-to-debug errors.
+
+Major refactoring PRs:
+
+- refactor front-end type system (`#2974 `_)
+- merge front-end and codegen type systems (`#3182 `_)
+- simplify ``GlobalContext`` (`#3209 `_)
+- remove ``FunctionSignature`` (`#3390 `_)
+
+Notable fixes:
+
+- assignment when rhs is complex type and references lhs (`#3410 `_)
+- uninitialized immutable values (`#3409 `_)
+- success value when mixing ``max_outsize=0`` and ``revert_on_failure=False`` (`GHSA-w9g2-3w7p-72g9 `_)
+- block certain kinds of storage allocator overflows (`GHSA-mgv8-gggw-mrg6 `_)
+- store-before-load when a dynarray appears on both sides of an assignment (`GHSA-3p37-3636-q8wv `_)
+- bounds check for loops of the form ``for i in range(x, x+N)`` (`GHSA-6r8q-pfpv-7cgj `_)
+- alignment of call-site posargs and kwargs for internal functions (`GHSA-ph9x-4vc9-m39g `_)
+- batch nonpayable check for default functions calldatasize < 4 (`#3104 `_, `#3408 `_, cf. `GHSA-vxmm-cwh2-q762 `_)
+
+Other docs updates, chores and fixes:
+
+- call graph stability (`#3370 `_)
+- fix ``vyper-serve`` output (`#3338 `_)
+- add ``custom:`` natspec tags (`#3403 `_)
+- add missing pc maps to ``vyper_json`` output (`#3333 `_)
+- fix constructor context for internal functions (`#3388 `_)
+- add deprecation warning for ``selfdestruct`` usage (`#3372 `_)
+- add bytecode metadata option to vyper-json (`#3117 `_)
+- fix compiler panic when a ``break`` is outside of a loop (`#3177 `_)
+- fix complex arguments to builtin functions (`#3167 `_)
+- add support for all types in ABI imports (`#3154 `_)
+- disable uadd operator (`#3174 `_)
+- block bitwise ops on decimals (`#3219 `_)
+- raise ``UNREACHABLE`` (`#3194 `_)
+- allow enum as mapping key (`#3256 `_)
+- block boolean ``not`` operator on numeric types (`#3231 `_)
+- enforce that loop's iterators are valid names (`#3242 `_)
+- fix typechecker hotspot (`#3318 `_)
+- rewrite typechecker journal to handle nested commits (`#3375 `_)
+- fix missing pc map for empty functions (`#3202 `_)
+- guard against iterating over empty list in for loop (`#3197 `_)
+- skip enum members during constant folding (`#3235 `_)
+- bitwise ``not`` constant folding (`#3222 `_)
+- allow accessing members of constant address (`#3261 `_)
+- guard against decorators in interface (`#3266 `_)
+- fix bounds for decimals in some builtins (`#3283 `_)
+- length of literal empty bytestrings (`#3276 `_)
+- block ``empty()`` for HashMaps (`#3303 `_)
+- fix type inference for empty lists (`#3377 `_)
+- disallow logging from ``pure``, ``view`` functions (`#3424 `_)
+- improve optimizer rules for comparison operators (`#3412 `_)
+- deploy to ghcr on push (`#3435 `_)
+- add note on return value bounds in interfaces (`#3205 `_)
+- index ``id`` param in ``URI`` event of ``ERC1155ownable`` (`#3203 `_)
+- add missing ``asset`` function to ``ERC4626`` built-in interface (`#3295 `_)
+- clarify ``skip_contract_check=True`` can result in undefined behavior (`#3386 `_)
+- add ``custom`` NatSpec tag to docs (`#3404 `_)
+- fix ``uint256_addmod`` doc (`#3300 `_)
+- document optional kwargs for external calls (`#3122 `_)
+- remove ``slice()`` length documentation caveats (`#3152 `_)
+- fix docs of ``blockhash`` to reflect revert behaviour (`#3168 `_)
+- improvements to compiler error messages (`#3121 `_, `#3134 `_, `#3312 `_, `#3304 `_, `#3240 `_, `#3264 `_, `#3343 `_, `#3307 `_, `#3313 `_ and `#3215 `_)
+
+These are really just the highlights, as many other bugfixes, docs updates and refactoring (over 150 pull requests!) made it into this release! For the full list, please see the `changelog `_. Special thanks to contributions from @tserg, @trocher, @z80dev, @emc415 and @benber86 in this release!
+
+New Contributors:
+
+- @omahs made their first contribution in (`#3128 `_)
+- @ObiajuluM made their first contribution in (`#3124 `_)
+- @trocher made their first contribution in (`#3134 `_)
+- @ozmium22 made their first contribution in (`#3149 `_)
+- @ToonVanHove made their first contribution in (`#3168 `_)
+- @emc415 made their first contribution in (`#3158 `_)
+- @lgtm-com made their first contribution in (`#3147 `_)
+- @tdurieux made their first contribution in (`#3224 `_)
+- @victor-ego made their first contribution in (`#3263 `_)
+- @miohtama made their first contribution in (`#3257 `_)
+- @kelvinfan001 made their first contribution in (`#2687 `_)
+
+
v0.3.7
******
Date released: 2022-09-26
-## Breaking changes:
+Breaking changes:
- chore: drop python 3.7 support (`#3071 `_)
- fix: relax check for statically sized calldata (`#3090 `_)
-## Non-breaking changes and improvements:
+Non-breaking changes and improvements:
-- fix: assert description in Crowdfund.finalize() (`#3058 `_)
+- fix: assert description in ``Crowdfund.finalize()`` (`#3058 `_)
- fix: change mutability of example ERC721 interface (`#3076 `_)
- chore: improve error message for non-checksummed address literal (`#3065 `_)
-- feat: isqrt built-in (`#3074 `_) (`#3069 `_)
-- feat: add `block.prevrandao` as alias for `block.difficulty` (`#3085 `_)
-- feat: epsilon builtin (`#3057 `_)
+- feat: ``isqrt()`` builtin (`#3074 `_) (`#3069 `_)
+- feat: add ``block.prevrandao`` as alias for ``block.difficulty`` (`#3085 `_)
+- feat: ``epsilon()`` builtin (`#3057 `_)
- feat: extend ecrecover signature to accept additional parameter types (`#3084 `_)
- feat: allow constant and immutable variables to be declared public (`#3024 `_)
- feat: optionally disable metadata in bytecode (`#3107 `_)
-## Bugfixes:
+Bugfixes:
- fix: empty nested dynamic arrays (`#3061 `_)
- fix: foldable builtin default args in imports (`#3079 `_) (`#3077 `_)
-## Additional changes and improvements:
+Additional changes and improvements:
- doc: update broken links in SECURITY.md (`#3095 `_)
- chore: update discord link in docs (`#3031 `_)
@@ -40,7 +218,7 @@ Date released: 2022-09-26
- chore: migrate lark grammar (`#3082 `_)
- chore: loosen and upgrade semantic version (`#3106 `_)
-# New Contributors
+New Contributors
- @emilianobonassi made their first contribution in `#3107 `_
- @unparalleled-js made their first contribution in `#3106 `_
@@ -65,6 +243,7 @@ Bugfixes:
v0.3.5
******
+**THIS RELEASE HAS BEEN PULLED**
Date released: 2022-08-05
@@ -213,6 +392,7 @@ Special thanks to @skellet0r for some major features in this release!
v0.3.0
*******
+⚠️ A critical security vulnerability has been discovered in this version and we strongly recommend using version `0.3.1 `_ or higher. For more information, please see the Security Advisory `GHSA-5824-cm3x-3c38 `_.
Date released: 2021-10-04
@@ -245,6 +425,7 @@ Special thanks to contributions from @skellet0r and @benjyz for this release!
v0.2.16
*******
+⚠️ A critical security vulnerability has been discovered in this version and we strongly recommend using version `0.3.1 `_ or higher. For more information, please see the Security Advisory `GHSA-5824-cm3x-3c38 `_.
Date released: 2021-08-27
@@ -269,6 +450,7 @@ Special thanks to contributions from @skellet0r, @sambacha and @milancermak for
v0.2.15
*******
+⚠️ A critical security vulnerability has been discovered in this version and we strongly recommend using version `0.3.1 `_ or higher. For more information, please see the Security Advisory `GHSA-5824-cm3x-3c38 `_.
Date released: 23-07-2021
@@ -281,6 +463,7 @@ Fixes:
v0.2.14
*******
+**THIS RELEASE HAS BEEN PULLED**
Date released: 20-07-2021
@@ -399,6 +582,7 @@ Fixes:
v0.2.6
******
+**THIS RELEASE HAS BEEN PULLED**
Date released: 10-10-2020
@@ -652,7 +836,7 @@ The following VIPs were implemented for Beta 13:
- Add ``vyper-json`` compilation mode (VIP `#1520 `_)
- Environment variables and constants can now be used as default parameters (VIP `#1525 `_)
-- Require unitialized memory be set on creation (VIP `#1493 `_)
+- Require uninitialized memory be set on creation (VIP `#1493 `_)
Some of the bug and stability fixes:
diff --git a/docs/resources.rst b/docs/resources.rst
index 295a104fcf..7f0d0600a9 100644
--- a/docs/resources.rst
+++ b/docs/resources.rst
@@ -9,8 +9,8 @@ examples, courses and other learning material.
General
-------
-- `Ape Academy - Learn how to build vyper projects by ApeWorX`__
-- `More Vyper by Example by Smart Contract Engineer`__
+- `Ape Academy - Learn how to build vyper projects `__ by ApeWorX
+- `More Vyper by Example `__ by Smart Contract Engineer
- `Vyper cheat Sheet `__
- `Vyper Hub for development `__
- `Vyper greatest hits smart contract examples `__
@@ -44,4 +44,4 @@ Unmaintained
These resources have not been updated for a while, but may still offer interesting content.
- `Awesome Vyper curated resources `__
-- `Brownie - Python framework for developing smart contracts (deprecated) `__
\ No newline at end of file
+- `Brownie - Python framework for developing smart contracts (deprecated) `__
diff --git a/docs/structure-of-a-contract.rst b/docs/structure-of-a-contract.rst
index 8eb2c1da78..d2c5d48d96 100644
--- a/docs/structure-of-a-contract.rst
+++ b/docs/structure-of-a-contract.rst
@@ -9,16 +9,47 @@ This section provides a quick overview of the types of data present within a con
.. _structure-versions:
-Version Pragma
+Pragmas
==============
-Vyper supports a version pragma to ensure that a contract is only compiled by the intended compiler version, or range of versions. Version strings use `NPM `_ style syntax.
+Vyper supports several source code directives to control compiler modes and help with build reproducibility.
+
+Version Pragma
+--------------
+
+The version pragma ensures that a contract is only compiled by the intended compiler version, or range of versions. Version strings use `NPM `_ style syntax. Starting from v0.4.0 and up, version strings will use `PEP440 version specifiers _`.
+
+As of 0.3.10, the recommended way to specify the version pragma is as follows:
.. code-block:: python
- # @version ^0.2.0
+ #pragma version ^0.3.0
+
+The following declaration is equivalent, and, prior to 0.3.10, was the only supported method to specify the compiler version:
+
+.. code-block:: python
+
+ # @version ^0.3.0
+
+
+In the above examples, the contract will only compile with Vyper versions ``0.3.x``.
+
+Optimization Mode
+-----------------
+
+The optimization mode can be one of ``"none"``, ``"codesize"``, or ``"gas"`` (default). For example, adding the following line to a contract will cause it to try to optimize for codesize:
+
+.. code-block:: python
+
+ #pragma optimize codesize
+
+The optimization mode can also be set as a compiler option, which is documented in :ref:`optimization-mode`. If the compiler option conflicts with the source code pragma, an exception will be raised and compilation will not continue.
+
+EVM Version
+-----------------
+
+The EVM version can be set with the ``evm-version`` pragma, which is documented in :ref:`evm-version`.
-In the above example, the contract only compiles with Vyper versions ``0.2.x``.
.. _structure-state-variables:
diff --git a/docs/types.rst b/docs/types.rst
index 5e82c6ecba..d669e6946d 100644
--- a/docs/types.rst
+++ b/docs/types.rst
@@ -115,6 +115,23 @@ Operator Description
``x`` and ``y`` must be of the same type.
+Shifts
+^^^^^^^^^^^^^^^^^
+
+============= ======================
+Operator Description
+============= ======================
+``x << y`` Left shift
+``x >> y`` Right shift
+============= ======================
+
+Shifting is only available for 256-bit wide types. That is, ``x`` must be ``int256``, and ``y`` can be any unsigned integer. The right shift for ``int256`` compiles to a signed right shift (EVM ``SAR`` instruction).
+
+
+.. note::
+ While at runtime shifts are unchecked (that is, they can be for any number of bits), to prevent common mistakes, the compiler is stricter at compile-time and will prevent out of bounds shifts. For instance, at runtime, ``1 << 257`` will evaluate to ``0``, while that expression at compile-time will raise an ``OverflowException``.
+
+
.. index:: ! uint, ! uintN, ! unsigned integer
Unsigned Integer (N bit)
@@ -188,6 +205,24 @@ Operator Description
.. note::
The Bitwise ``not`` operator is currently only available for ``uint256`` type.
+Shifts
+^^^^^^^^^^^^^^^^^
+
+============= ======================
+Operator Description
+============= ======================
+``x << y`` Left shift
+``x >> y`` Right shift
+============= ======================
+
+Shifting is only available for 256-bit wide types. That is, ``x`` must be ``uint256``, and ``y`` can be any unsigned integer. The right shift for ``uint256`` compiles to a signed right shift (EVM ``SHR`` instruction).
+
+
+.. note::
+ While at runtime shifts are unchecked (that is, they can be for any number of bits), to prevent common mistakes, the compiler is stricter at compile-time and will prevent out of bounds shifts. For instance, at runtime, ``1 << 257`` will evaluate to ``0``, while that expression at compile-time will raise an ``OverflowException``.
+
+
+
Decimals
--------
@@ -485,6 +520,9 @@ A two dimensional list can be declared with ``_name: _ValueType[inner_size][oute
# Returning the value in row 0 column 4 (in this case 14)
return exampleList2D[0][4]
+.. note::
+ Defining an array in storage whose size is significantly larger than ``2**64`` can result in security vulnerabilities due to risk of overflow.
+
.. index:: !dynarrays
Dynamic Arrays
@@ -526,6 +564,10 @@ Dynamic arrays represent bounded arrays whose length can be modified at runtime,
In the ABI, they are represented as ``_Type[]``. For instance, ``DynArray[int128, 3]`` gets represented as ``int128[]``, and ``DynArray[DynArray[int128, 3], 3]`` gets represented as ``int128[][]``.
+.. note::
+ Defining a dynamic array in storage whose size is significantly larger than ``2**64`` can result in security vulnerabilities due to risk of overflow.
+
+
.. _types-struct:
Structs
diff --git a/setup.cfg b/setup.cfg
index d18ffe2ac7..dd4a32a3ac 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -31,7 +31,6 @@ addopts = -n auto
--cov-report html
--cov-report xml
--cov=vyper
- --hypothesis-show-statistics
python_files = test_*.py
testpaths = tests
markers =
diff --git a/setup.py b/setup.py
index 0966a8e31a..40efb436c5 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
import re
import subprocess
-from setuptools import find_packages, setup
+from setuptools import setup
extras_require = {
"test": [
@@ -14,8 +14,8 @@
"pytest-xdist>=2.5,<3.0",
"pytest-split>=0.7.0,<1.0",
"pytest-rerunfailures>=10.2,<11",
- "eth-tester[py-evm]>=0.8.0b3,<0.9",
- "py-evm>=0.6.1a2,<0.7",
+ "eth-tester[py-evm]>=0.9.0b1,<0.10",
+ "py-evm>=0.7.0a1,<0.8",
"web3==6.0.0",
"tox>=3.15,<4.0",
"lark==1.1.2",
@@ -28,7 +28,7 @@
"flake8-bugbear==20.1.4",
"flake8-use-fstring==1.1",
"isort==5.9.3",
- "mypy==0.910",
+ "mypy==0.982",
],
"docs": ["recommonmark", "sphinx>=6.0,<7.0", "sphinx_rtd_theme>=1.2,<1.3"],
"dev": ["ipython", "pre-commit", "pyinstaller", "twine"],
@@ -88,17 +88,18 @@ def _global_version(version):
license="Apache License 2.0",
keywords="ethereum evm smart contract language",
include_package_data=True,
- packages=find_packages(exclude=("tests", "docs")),
+ packages=["vyper"],
python_requires=">=3.10,<4",
py_modules=["vyper"],
install_requires=[
+ "cbor2>=5.4.6,<6",
"asttokens>=2.0.5,<3",
"pycryptodome>=3.5.1,<4",
- "semantic-version>=2.10,<3",
+ "packaging>=23.1,<24",
"importlib-metadata",
"wheel",
],
- setup_requires=["pytest-runner", "setuptools_scm"],
+ setup_requires=["pytest-runner", "setuptools_scm>=7.1.0,<8.0.0"],
tests_require=extras_require["test"],
extras_require=extras_require,
entry_points={
diff --git a/tests/ast/nodes/test_evaluate_binop_decimal.py b/tests/ast/nodes/test_evaluate_binop_decimal.py
index c6c69626b8..3c8ba0888c 100644
--- a/tests/ast/nodes/test_evaluate_binop_decimal.py
+++ b/tests/ast/nodes/test_evaluate_binop_decimal.py
@@ -13,7 +13,7 @@
@pytest.mark.fuzzing
-@settings(max_examples=50, deadline=1000)
+@settings(max_examples=50, deadline=None)
@given(left=st_decimals, right=st_decimals)
@example(left=Decimal("0.9999999999"), right=Decimal("0.0000000001"))
@example(left=Decimal("0.0000000001"), right=Decimal("0.9999999999"))
@@ -52,7 +52,7 @@ def test_binop_pow():
@pytest.mark.fuzzing
-@settings(max_examples=50, deadline=1000)
+@settings(max_examples=50, deadline=None)
@given(
values=st.lists(st_decimals, min_size=2, max_size=10),
ops=st.lists(st.sampled_from("+-*/%"), min_size=11, max_size=11),
diff --git a/tests/ast/test_metadata_journal.py b/tests/ast/test_metadata_journal.py
new file mode 100644
index 0000000000..34830409fc
--- /dev/null
+++ b/tests/ast/test_metadata_journal.py
@@ -0,0 +1,82 @@
+from vyper.ast.metadata import NodeMetadata
+from vyper.exceptions import VyperException
+
+
+def test_metadata_journal_basic():
+ m = NodeMetadata()
+
+ m["x"] = 1
+ assert m["x"] == 1
+
+
+def test_metadata_journal_commit():
+ m = NodeMetadata()
+
+ with m.enter_typechecker_speculation():
+ m["x"] = 1
+
+ assert m["x"] == 1
+
+
+def test_metadata_journal_exception():
+ m = NodeMetadata()
+
+ m["x"] = 1
+ try:
+ with m.enter_typechecker_speculation():
+ m["x"] = 2
+ m["x"] = 3
+
+ assert m["x"] == 3
+ raise VyperException("dummy exception")
+
+ except VyperException:
+ pass
+
+ # rollback upon exception
+ assert m["x"] == 1
+
+
+def test_metadata_journal_rollback_inner():
+ m = NodeMetadata()
+
+ m["x"] = 1
+ with m.enter_typechecker_speculation():
+ m["x"] = 2
+
+ try:
+ with m.enter_typechecker_speculation():
+ m["x"] = 3
+ m["x"] = 4 # test multiple writes
+
+ assert m["x"] == 4
+ raise VyperException("dummy exception")
+
+ except VyperException:
+ pass
+
+ assert m["x"] == 2
+
+
+def test_metadata_journal_rollback_outer():
+ m = NodeMetadata()
+
+ m["x"] = 1
+ try:
+ with m.enter_typechecker_speculation():
+ m["x"] = 2
+
+ with m.enter_typechecker_speculation():
+ m["x"] = 3
+ m["x"] = 4 # test multiple writes
+
+ assert m["x"] == 4
+
+ m["x"] = 5
+
+ raise VyperException("dummy exception")
+
+ except VyperException:
+ pass
+
+ assert m["x"] == 1
diff --git a/tests/ast/test_natspec.py b/tests/ast/test_natspec.py
index 2e9980b8d7..c2133468aa 100644
--- a/tests/ast/test_natspec.py
+++ b/tests/ast/test_natspec.py
@@ -24,6 +24,7 @@ def doesEat(food: String[30], qty: uint256) -> bool:
@param food The name of a food to evaluate (in English)
@param qty The number of food items to evaluate
@return True if Bugs will eat it, False otherwise
+ @custom:my-custom-tag hello, world!
'''
return True
"""
@@ -51,6 +52,7 @@ def doesEat(food: String[30], qty: uint256) -> bool:
"qty": "The number of food items to evaluate",
},
"returns": {"_0": "True if Bugs will eat it, False otherwise"},
+ "custom:my-custom-tag": "hello, world!",
}
},
"title": "A simulator for Bug Bunny, the most famous Rabbit",
diff --git a/tests/ast/test_pre_parser.py b/tests/ast/test_pre_parser.py
index 8501bb8749..5427532c16 100644
--- a/tests/ast/test_pre_parser.py
+++ b/tests/ast/test_pre_parser.py
@@ -1,6 +1,7 @@
import pytest
-from vyper.ast.pre_parser import validate_version_pragma
+from vyper.ast.pre_parser import pre_parse, validate_version_pragma
+from vyper.compiler.settings import OptimizationLevel, Settings
from vyper.exceptions import VersionException
SRC_LINE = (1, 0) # Dummy source line
@@ -20,16 +21,9 @@ def set_version(version):
"0.1.1",
">0.0.1",
"^0.1.0",
- "<=1.0.0 >=0.1.0",
- "0.1.0 - 1.0.0",
- "~0.1.0",
- "0.1",
- "0",
- "*",
- "x",
- "0.x",
- "0.1.x",
- "0.2.0 || 0.1.1",
+ "<=1.0.0,>=0.1.0",
+ # "0.1.0 - 1.0.0",
+ "~=0.1.0",
]
invalid_versions = [
"0.1.0",
@@ -43,7 +37,6 @@ def set_version(version):
"1.x",
"0.2.x",
"0.2.0 || 0.1.3",
- "==0.1.1",
"abc",
]
@@ -51,14 +44,14 @@ def set_version(version):
@pytest.mark.parametrize("file_version", valid_versions)
def test_valid_version_pragma(file_version, mock_version):
mock_version(COMPILER_VERSION)
- validate_version_pragma(f" @version {file_version}", (SRC_LINE))
+ validate_version_pragma(f"{file_version}", (SRC_LINE))
@pytest.mark.parametrize("file_version", invalid_versions)
def test_invalid_version_pragma(file_version, mock_version):
mock_version(COMPILER_VERSION)
with pytest.raises(VersionException):
- validate_version_pragma(f" @version {file_version}", (SRC_LINE))
+ validate_version_pragma(f"{file_version}", (SRC_LINE))
prerelease_valid_versions = [
@@ -69,9 +62,10 @@ def test_invalid_version_pragma(file_version, mock_version):
"<0.1.1-rc.1",
">0.1.1a1",
">0.1.1-alpha.1",
- "0.1.1a9 - 0.1.1-rc.10",
+ ">=0.1.1a9,<=0.1.1-rc.10",
"<0.1.1b8",
"<0.1.1rc1",
+ "<0.2.0",
]
prerelease_invalid_versions = [
">0.1.1-beta.9",
@@ -79,30 +73,93 @@ def test_invalid_version_pragma(file_version, mock_version):
"0.1.1b8",
"0.1.1rc2",
"0.1.1-rc.9 - 0.1.1-rc.10",
- "<0.2.0",
- pytest.param(
- "<0.1.1b1",
- marks=pytest.mark.xfail(
- reason="https://github.com/rbarrois/python-semanticversion/issues/100"
- ),
- ),
- pytest.param(
- "<0.1.1a9",
- marks=pytest.mark.xfail(
- reason="https://github.com/rbarrois/python-semanticversion/issues/100"
- ),
- ),
+ "<0.1.1b1",
+ "<0.1.1a9",
]
@pytest.mark.parametrize("file_version", prerelease_valid_versions)
def test_prerelease_valid_version_pragma(file_version, mock_version):
mock_version(PRERELEASE_COMPILER_VERSION)
- validate_version_pragma(f" @version {file_version}", (SRC_LINE))
+ validate_version_pragma(file_version, (SRC_LINE))
@pytest.mark.parametrize("file_version", prerelease_invalid_versions)
def test_prerelease_invalid_version_pragma(file_version, mock_version):
mock_version(PRERELEASE_COMPILER_VERSION)
with pytest.raises(VersionException):
- validate_version_pragma(f" @version {file_version}", (SRC_LINE))
+ validate_version_pragma(file_version, (SRC_LINE))
+
+
+pragma_examples = [
+ (
+ """
+ """,
+ Settings(),
+ ),
+ (
+ """
+ #pragma optimize codesize
+ """,
+ Settings(optimize=OptimizationLevel.CODESIZE),
+ ),
+ (
+ """
+ #pragma optimize none
+ """,
+ Settings(optimize=OptimizationLevel.NONE),
+ ),
+ (
+ """
+ #pragma optimize gas
+ """,
+ Settings(optimize=OptimizationLevel.GAS),
+ ),
+ (
+ """
+ #pragma version 0.3.10
+ """,
+ Settings(compiler_version="0.3.10"),
+ ),
+ (
+ """
+ #pragma evm-version shanghai
+ """,
+ Settings(evm_version="shanghai"),
+ ),
+ (
+ """
+ #pragma optimize codesize
+ #pragma evm-version shanghai
+ """,
+ Settings(evm_version="shanghai", optimize=OptimizationLevel.GAS),
+ ),
+ (
+ """
+ #pragma version 0.3.10
+ #pragma evm-version shanghai
+ """,
+ Settings(evm_version="shanghai", compiler_version="0.3.10"),
+ ),
+ (
+ """
+ #pragma version 0.3.10
+ #pragma optimize gas
+ """,
+ Settings(compiler_version="0.3.10", optimize=OptimizationLevel.GAS),
+ ),
+ (
+ """
+ #pragma version 0.3.10
+ #pragma evm-version shanghai
+ #pragma optimize gas
+ """,
+ Settings(compiler_version="0.3.10", optimize=OptimizationLevel.GAS, evm_version="shanghai"),
+ ),
+]
+
+
+@pytest.mark.parametrize("code, expected_pragmas", pragma_examples)
+def parse_pragmas(code, expected_pragmas):
+ pragmas, _, _ = pre_parse(code)
+ assert pragmas == expected_pragmas
diff --git a/tests/base_conftest.py b/tests/base_conftest.py
index 4c3d0136bb..81e8dedc36 100644
--- a/tests/base_conftest.py
+++ b/tests/base_conftest.py
@@ -1,3 +1,5 @@
+import json
+
import pytest
import web3.exceptions
from eth_tester import EthereumTester, PyEVMBackend
@@ -10,6 +12,7 @@
from vyper import compiler
from vyper.ast.grammar import parse_vyper_source
+from vyper.compiler.settings import Settings
class VyperMethod:
@@ -31,7 +34,7 @@ def __prepared_function(self, *args, **kwargs):
if x.get("name") == self._function.function_identifier
].pop()
# To make tests faster just supply some high gas value.
- modifier_dict.update({"gas": fn_abi.get("gas", 0) + 50000})
+ modifier_dict.update({"gas": fn_abi.get("gas", 0) + 500000})
elif len(kwargs) == 1:
modifier, modifier_dict = kwargs.popitem()
if modifier not in self.ALLOWED_MODIFIERS:
@@ -109,16 +112,20 @@ def w3(tester):
return w3
-def _get_contract(w3, source_code, no_optimize, *args, **kwargs):
+def _get_contract(w3, source_code, optimize, *args, override_opt_level=None, **kwargs):
+ settings = Settings()
+ settings.evm_version = kwargs.pop("evm_version", None)
+ settings.optimize = override_opt_level or optimize
out = compiler.compile_code(
source_code,
- ["abi", "bytecode"],
+ # test that metadata gets generated
+ ["abi", "bytecode", "metadata"],
+ settings=settings,
interface_codes=kwargs.pop("interface_codes", None),
- no_optimize=no_optimize,
- evm_version=kwargs.pop("evm_version", None),
show_gas_estimates=True, # Enable gas estimates for testing
)
parse_vyper_source(source_code) # Test grammar.
+ json.dumps(out["metadata"]) # test metadata is json serializable
abi = out["abi"]
bytecode = out["bytecode"]
value = kwargs.pop("value_in_eth", 0) * 10**18 # Handle deploying with an eth value.
@@ -131,13 +138,15 @@ def _get_contract(w3, source_code, no_optimize, *args, **kwargs):
return w3.eth.contract(address, abi=abi, bytecode=bytecode, ContractFactoryClass=VyperContract)
-def _deploy_blueprint_for(w3, source_code, no_optimize, initcode_prefix=b"", **kwargs):
+def _deploy_blueprint_for(w3, source_code, optimize, initcode_prefix=b"", **kwargs):
+ settings = Settings()
+ settings.evm_version = kwargs.pop("evm_version", None)
+ settings.optimize = optimize
out = compiler.compile_code(
source_code,
["abi", "bytecode"],
interface_codes=kwargs.pop("interface_codes", None),
- no_optimize=no_optimize,
- evm_version=kwargs.pop("evm_version", None),
+ settings=settings,
show_gas_estimates=True, # Enable gas estimates for testing
)
parse_vyper_source(source_code) # Test grammar.
@@ -169,17 +178,17 @@ def factory(address):
@pytest.fixture(scope="module")
-def deploy_blueprint_for(w3, no_optimize):
+def deploy_blueprint_for(w3, optimize):
def deploy_blueprint_for(source_code, *args, **kwargs):
- return _deploy_blueprint_for(w3, source_code, no_optimize, *args, **kwargs)
+ return _deploy_blueprint_for(w3, source_code, optimize, *args, **kwargs)
return deploy_blueprint_for
@pytest.fixture(scope="module")
-def get_contract(w3, no_optimize):
+def get_contract(w3, optimize):
def get_contract(source_code, *args, **kwargs):
- return _get_contract(w3, source_code, no_optimize, *args, **kwargs)
+ return _get_contract(w3, source_code, optimize, *args, **kwargs)
return get_contract
diff --git a/tests/builtins/folding/test_bitwise.py b/tests/builtins/folding/test_bitwise.py
index 9be0b6817d..d28e482589 100644
--- a/tests/builtins/folding/test_bitwise.py
+++ b/tests/builtins/folding/test_bitwise.py
@@ -3,21 +3,27 @@
from hypothesis import strategies as st
from vyper import ast as vy_ast
-from vyper.builtins import functions as vy_fn
+from vyper.exceptions import InvalidType, OverflowException
+from vyper.semantics.analysis.utils import validate_expected_type
+from vyper.semantics.types.shortcuts import INT256_T, UINT256_T
+from vyper.utils import unsigned_to_signed
st_uint256 = st.integers(min_value=0, max_value=2**256 - 1)
+st_sint256 = st.integers(min_value=-(2**255), max_value=2**255 - 1)
+
@pytest.mark.fuzzing
@settings(max_examples=50, deadline=1000)
-@given(a=st_uint256, b=st_uint256)
@pytest.mark.parametrize("op", ["&", "|", "^"])
-def test_bitwise_and_or(get_contract, a, b, op):
+@given(a=st_uint256, b=st_uint256)
+def test_bitwise_ops(get_contract, a, b, op):
source = f"""
@external
def foo(a: uint256, b: uint256) -> uint256:
return a {op} b
"""
+
contract = get_contract(source)
vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}")
@@ -29,35 +35,75 @@ def foo(a: uint256, b: uint256) -> uint256:
@pytest.mark.fuzzing
@settings(max_examples=50, deadline=1000)
-@given(value=st_uint256)
-def test_bitwise_not(get_contract, value):
- source = """
+@pytest.mark.parametrize("op", ["<<", ">>"])
+@given(a=st_uint256, b=st.integers(min_value=0, max_value=256))
+def test_bitwise_shift_unsigned(get_contract, a, b, op):
+ source = f"""
@external
-def foo(a: uint256) -> uint256:
- return ~a
+def foo(a: uint256, b: uint256) -> uint256:
+ return a {op} b
"""
contract = get_contract(source)
- vyper_ast = vy_ast.parse_to_ast(f"~{value}")
+ vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}")
old_node = vyper_ast.body[0].value
- new_node = old_node.evaluate()
- assert contract.foo(value) == new_node.value
+ try:
+ new_node = old_node.evaluate()
+ # force bounds check, no-op because validate_numeric_bounds
+ # already does this, but leave in for hygiene (in case
+ # more types are added).
+ validate_expected_type(new_node, UINT256_T)
+ # compile time behavior does not match runtime behavior.
+ # compile-time will throw on OOB, runtime will wrap.
+ except OverflowException: # here: check the wrapped value matches runtime
+ assert op == "<<"
+ assert contract.foo(a, b) == (a << b) % (2**256)
+ else:
+ assert contract.foo(a, b) == new_node.value
+
+
+@pytest.mark.fuzzing
+@settings(max_examples=50, deadline=1000)
+@pytest.mark.parametrize("op", ["<<", ">>"])
+@given(a=st_sint256, b=st.integers(min_value=0, max_value=256))
+def test_bitwise_shift_signed(get_contract, a, b, op):
+ source = f"""
+@external
+def foo(a: int256, b: uint256) -> int256:
+ return a {op} b
+ """
+ contract = get_contract(source)
+
+ vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}")
+ old_node = vyper_ast.body[0].value
+
+ try:
+ new_node = old_node.evaluate()
+ validate_expected_type(new_node, INT256_T) # force bounds check
+ # compile time behavior does not match runtime behavior.
+ # compile-time will throw on OOB, runtime will wrap.
+ except (InvalidType, OverflowException):
+ # check the wrapped value matches runtime
+ assert op == "<<"
+ assert contract.foo(a, b) == unsigned_to_signed((a << b) % (2**256), 256)
+ else:
+ assert contract.foo(a, b) == new_node.value
@pytest.mark.fuzzing
@settings(max_examples=50, deadline=1000)
-@given(value=st_uint256, steps=st.integers(min_value=-256, max_value=256))
-def test_shift(get_contract, value, steps):
+@given(value=st_uint256)
+def test_bitwise_not(get_contract, value):
source = """
@external
-def foo(a: uint256, b: int128) -> uint256:
- return shift(a, b)
+def foo(a: uint256) -> uint256:
+ return ~a
"""
contract = get_contract(source)
- vyper_ast = vy_ast.parse_to_ast(f"shift({value}, {steps})")
+ vyper_ast = vy_ast.parse_to_ast(f"~{value}")
old_node = vyper_ast.body[0].value
- new_node = vy_fn.Shift().evaluate(old_node)
+ new_node = old_node.evaluate()
- assert contract.foo(value, steps) == new_node.value
+ assert contract.foo(value) == new_node.value
diff --git a/tests/cli/outputs/test_storage_layout_overrides.py b/tests/cli/outputs/test_storage_layout_overrides.py
index ff86b0cdf1..94e0faeb37 100644
--- a/tests/cli/outputs/test_storage_layout_overrides.py
+++ b/tests/cli/outputs/test_storage_layout_overrides.py
@@ -95,6 +95,21 @@ def test_simple_collision():
)
+def test_overflow():
+ code = """
+x: uint256[2]
+ """
+
+ storage_layout_override = {"x": {"slot": 2**256 - 1, "type": "uint256[2]"}}
+
+ with pytest.raises(
+ StorageLayoutException, match=f"Invalid storage slot for var x, out of bounds: {2**256}\n"
+ ):
+ compile_code(
+ code, output_formats=["layout"], storage_layout_override=storage_layout_override
+ )
+
+
def test_incomplete_overrides():
code = """
name: public(String[64])
diff --git a/tests/cli/vyper_compile/test_compile_files.py b/tests/cli/vyper_compile/test_compile_files.py
index 796976ae0e..31cf622658 100644
--- a/tests/cli/vyper_compile/test_compile_files.py
+++ b/tests/cli/vyper_compile/test_compile_files.py
@@ -28,29 +28,3 @@ def test_combined_json_keys(tmp_path):
def test_invalid_root_path():
with pytest.raises(FileNotFoundError):
compile_files([], [], root_folder="path/that/does/not/exist")
-
-
-def test_evm_versions(tmp_path):
- # should compile differently because of SELFBALANCE
- code = """
-@external
-def foo() -> uint256:
- return self.balance
-"""
-
- bar_path = tmp_path.joinpath("bar.vy")
- with bar_path.open("w") as fp:
- fp.write(code)
-
- byzantium_bytecode = compile_files(
- [bar_path], output_formats=["bytecode"], evm_version="byzantium"
- )[str(bar_path)]["bytecode"]
- istanbul_bytecode = compile_files(
- [bar_path], output_formats=["bytecode"], evm_version="istanbul"
- )[str(bar_path)]["bytecode"]
-
- assert byzantium_bytecode != istanbul_bytecode
-
- # SELFBALANCE opcode is 0x47
- assert "47" not in byzantium_bytecode
- assert "47" in istanbul_bytecode
diff --git a/tests/cli/vyper_json/test_compile_from_input_dict.py b/tests/cli/vyper_json/test_compile_from_input_dict.py
index a5a31a522b..a6d0a23100 100644
--- a/tests/cli/vyper_json/test_compile_from_input_dict.py
+++ b/tests/cli/vyper_json/test_compile_from_input_dict.py
@@ -130,12 +130,3 @@ def test_relative_import_paths():
input_json["sources"]["contracts/potato/baz/potato.vy"] = {"content": """from . import baz"""}
input_json["sources"]["contracts/potato/footato.vy"] = {"content": """from baz import baz"""}
compile_from_input_dict(input_json)
-
-
-def test_evm_version():
- # should compile differently because of SELFBALANCE
- input_json = deepcopy(INPUT_JSON)
- input_json["settings"]["evmVersion"] = "byzantium"
- compiled = compile_from_input_dict(input_json)
- input_json["settings"]["evmVersion"] = "istanbul"
- assert compiled != compile_from_input_dict(input_json)
diff --git a/tests/cli/vyper_json/test_get_settings.py b/tests/cli/vyper_json/test_get_settings.py
index ca60d2cf5a..bbe5dab113 100644
--- a/tests/cli/vyper_json/test_get_settings.py
+++ b/tests/cli/vyper_json/test_get_settings.py
@@ -3,7 +3,6 @@
import pytest
from vyper.cli.vyper_json import get_evm_version
-from vyper.evm.opcodes import DEFAULT_EVM_VERSION
from vyper.exceptions import JSONError
@@ -12,16 +11,22 @@ def test_unknown_evm():
get_evm_version({"settings": {"evmVersion": "foo"}})
-@pytest.mark.parametrize("evm_version", ["homestead", "tangerineWhistle", "spuriousDragon"])
+@pytest.mark.parametrize(
+ "evm_version",
+ [
+ "homestead",
+ "tangerineWhistle",
+ "spuriousDragon",
+ "byzantium",
+ "constantinople",
+ "petersburg",
+ ],
+)
def test_early_evm(evm_version):
with pytest.raises(JSONError):
get_evm_version({"settings": {"evmVersion": evm_version}})
-@pytest.mark.parametrize("evm_version", ["byzantium", "constantinople", "petersburg"])
+@pytest.mark.parametrize("evm_version", ["istanbul", "berlin", "paris", "shanghai", "cancun"])
def test_valid_evm(evm_version):
assert evm_version == get_evm_version({"settings": {"evmVersion": evm_version}})
-
-
-def test_default_evm():
- assert get_evm_version({}) == DEFAULT_EVM_VERSION
diff --git a/tests/cli/vyper_json/test_parse_args_vyperjson.py b/tests/cli/vyper_json/test_parse_args_vyperjson.py
index 08da5f1888..11e527843a 100644
--- a/tests/cli/vyper_json/test_parse_args_vyperjson.py
+++ b/tests/cli/vyper_json/test_parse_args_vyperjson.py
@@ -57,7 +57,7 @@ def test_to_stdout(tmp_path, capfd):
_parse_args([path.absolute().as_posix()])
out, _ = capfd.readouterr()
output_json = json.loads(out)
- assert _no_errors(output_json)
+ assert _no_errors(output_json), (INPUT_JSON, output_json)
assert "contracts/foo.vy" in output_json["sources"]
assert "contracts/bar.vy" in output_json["sources"]
@@ -71,7 +71,7 @@ def test_to_file(tmp_path):
assert output_path.exists()
with output_path.open() as fp:
output_json = json.load(fp)
- assert _no_errors(output_json)
+ assert _no_errors(output_json), (INPUT_JSON, output_json)
assert "contracts/foo.vy" in output_json["sources"]
assert "contracts/bar.vy" in output_json["sources"]
diff --git a/tests/compiler/__init__.py b/tests/compiler/__init__.py
index e69de29bb2..35a11f851b 100644
--- a/tests/compiler/__init__.py
+++ b/tests/compiler/__init__.py
@@ -0,0 +1,2 @@
+# prevent module name collision between tests/compiler/test_pre_parser.py
+# and tests/ast/test_pre_parser.py
diff --git a/tests/compiler/asm/test_asm_optimizer.py b/tests/compiler/asm/test_asm_optimizer.py
new file mode 100644
index 0000000000..47b70a8c70
--- /dev/null
+++ b/tests/compiler/asm/test_asm_optimizer.py
@@ -0,0 +1,103 @@
+import pytest
+
+from vyper.compiler.phases import CompilerData
+from vyper.compiler.settings import OptimizationLevel, Settings
+
+codes = [
+ """
+s: uint256
+
+@internal
+def ctor_only():
+ self.s = 1
+
+@internal
+def runtime_only():
+ self.s = 2
+
+@external
+def bar():
+ self.runtime_only()
+
+@external
+def __init__():
+ self.ctor_only()
+ """,
+ # code with nested function in it
+ """
+s: uint256
+
+@internal
+def runtime_only():
+ self.s = 1
+
+@internal
+def foo():
+ self.runtime_only()
+
+@internal
+def ctor_only():
+ self.s += 1
+
+@external
+def bar():
+ self.foo()
+
+@external
+def __init__():
+ self.ctor_only()
+ """,
+ # code with loop in it, these are harder for dead code eliminator
+ """
+s: uint256
+
+@internal
+def ctor_only():
+ self.s = 1
+
+@internal
+def runtime_only():
+ for i in range(10):
+ self.s += 1
+
+@external
+def bar():
+ self.runtime_only()
+
+@external
+def __init__():
+ self.ctor_only()
+ """,
+]
+
+
+@pytest.mark.parametrize("code", codes)
+def test_dead_code_eliminator(code):
+ c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.NONE))
+ initcode_asm = [i for i in c.assembly if not isinstance(i, list)]
+ runtime_asm = c.assembly_runtime
+
+ ctor_only_label = "_sym_internal_ctor_only___"
+ runtime_only_label = "_sym_internal_runtime_only___"
+
+ # qux reachable from unoptimized initcode, foo not reachable.
+ assert ctor_only_label + "_deploy" in initcode_asm
+ assert runtime_only_label + "_deploy" not in initcode_asm
+
+ # all labels should be in unoptimized runtime asm
+ for s in (ctor_only_label, runtime_only_label):
+ assert s + "_runtime" in runtime_asm
+
+ c = CompilerData(code, settings=Settings(optimize=OptimizationLevel.GAS))
+ initcode_asm = [i for i in c.assembly if not isinstance(i, list)]
+ runtime_asm = c.assembly_runtime
+
+ # ctor only label should not be in runtime code
+ for instr in runtime_asm:
+ if isinstance(instr, str):
+ assert not instr.startswith(ctor_only_label), instr
+
+ # runtime only label should not be in initcode asm
+ for instr in initcode_asm:
+ if isinstance(instr, str):
+ assert not instr.startswith(runtime_only_label), instr
diff --git a/tests/compiler/ir/test_compile_ir.py b/tests/compiler/ir/test_compile_ir.py
index 91007da33a..706c31e0f2 100644
--- a/tests/compiler/ir/test_compile_ir.py
+++ b/tests/compiler/ir/test_compile_ir.py
@@ -68,4 +68,4 @@ def test_pc_debugger():
debugger_ir = ["seq", ["mstore", 0, 32], ["pc_debugger"]]
ir_nodes = IRnode.from_list(debugger_ir)
_, line_number_map = compile_ir.assembly_to_evm(compile_ir.compile_to_assembly(ir_nodes))
- assert line_number_map["pc_breakpoints"][0] == 5
+ assert line_number_map["pc_breakpoints"][0] == 4
diff --git a/tests/compiler/ir/test_optimize_ir.py b/tests/compiler/ir/test_optimize_ir.py
index bac0e18d65..1466166501 100644
--- a/tests/compiler/ir/test_optimize_ir.py
+++ b/tests/compiler/ir/test_optimize_ir.py
@@ -57,6 +57,14 @@
(["le", 0, "x"], [1]),
(["le", 0, ["sload", 0]], None), # no-op
(["ge", "x", 0], [1]),
+ (["le", "x", "x"], [1]),
+ (["ge", "x", "x"], [1]),
+ (["sle", "x", "x"], [1]),
+ (["sge", "x", "x"], [1]),
+ (["lt", "x", "x"], [0]),
+ (["gt", "x", "x"], [0]),
+ (["slt", "x", "x"], [0]),
+ (["sgt", "x", "x"], [0]),
# boundary conditions
(["slt", "x", -(2**255)], [0]),
(["sle", "x", -(2**255)], ["eq", "x", -(2**255)]),
@@ -135,7 +143,9 @@
(["sub", "x", 0], ["x"]),
(["sub", "x", "x"], [0]),
(["sub", ["sload", 0], ["sload", 0]], None),
- (["sub", ["callvalue"], ["callvalue"]], None),
+ (["sub", ["callvalue"], ["callvalue"]], [0]),
+ (["sub", ["msize"], ["msize"]], None),
+ (["sub", ["gas"], ["gas"]], None),
(["sub", -1, ["sload", 0]], ["not", ["sload", 0]]),
(["mul", "x", 1], ["x"]),
(["div", "x", 1], ["x"]),
@@ -202,7 +212,9 @@
(["eq", -1, ["add", -(2**255), 2**255 - 1]], [1]), # test compile-time wrapping
(["eq", -2, ["add", 2**256 - 1, 2**256 - 1]], [1]), # test compile-time wrapping
(["eq", "x", "x"], [1]),
- (["eq", "callvalue", "callvalue"], None),
+ (["eq", "gas", "gas"], None),
+ (["eq", "msize", "msize"], None),
+ (["eq", "callvalue", "callvalue"], [1]),
(["ne", "x", "x"], [0]),
]
@@ -253,3 +265,10 @@ def test_ir_optimizer(ir):
def test_static_assertions(ir, assert_compile_failed):
ir = IRnode.from_list(ir)
assert_compile_failed(lambda: optimizer.optimize(ir), StaticAssertionException)
+
+
+def test_operator_set_values():
+ # some sanity checks
+ assert optimizer.COMPARISON_OPS == {"lt", "gt", "le", "ge", "slt", "sgt", "sle", "sge"}
+ assert optimizer.STRICT_COMPARISON_OPS == {"lt", "gt", "slt", "sgt"}
+ assert optimizer.UNSTRICT_COMPARISON_OPS == {"le", "ge", "sle", "sge"}
diff --git a/tests/compiler/test_bytecode_runtime.py b/tests/compiler/test_bytecode_runtime.py
index 86eff70a50..9519b03772 100644
--- a/tests/compiler/test_bytecode_runtime.py
+++ b/tests/compiler/test_bytecode_runtime.py
@@ -1,14 +1,135 @@
-import vyper
+import cbor2
+import pytest
+import vyper
+from vyper.compiler.settings import OptimizationLevel, Settings
-def test_bytecode_runtime():
- code = """
+simple_contract_code = """
@external
def a() -> bool:
return True
- """
+"""
+
+many_functions = """
+@external
+def foo1():
+ pass
+
+@external
+def foo2():
+ pass
+
+@external
+def foo3():
+ pass
+
+@external
+def foo4():
+ pass
+
+@external
+def foo5():
+ pass
+"""
+
+has_immutables = """
+A_GOOD_PRIME: public(immutable(uint256))
+
+@external
+def __init__():
+ A_GOOD_PRIME = 967
+"""
+
+
+def _parse_cbor_metadata(initcode):
+ metadata_ofst = int.from_bytes(initcode[-2:], "big")
+ metadata = cbor2.loads(initcode[-metadata_ofst:-2])
+ return metadata
- out = vyper.compile_code(code, ["bytecode_runtime", "bytecode"])
+
+def test_bytecode_runtime():
+ out = vyper.compile_code(simple_contract_code, ["bytecode_runtime", "bytecode"])
assert len(out["bytecode"]) > len(out["bytecode_runtime"])
- assert out["bytecode_runtime"][2:] in out["bytecode"][2:]
+ assert out["bytecode_runtime"].removeprefix("0x") in out["bytecode"].removeprefix("0x")
+
+
+def test_bytecode_signature():
+ out = vyper.compile_code(simple_contract_code, ["bytecode_runtime", "bytecode"])
+
+ runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x"))
+ initcode = bytes.fromhex(out["bytecode"].removeprefix("0x"))
+
+ metadata = _parse_cbor_metadata(initcode)
+ runtime_len, data_section_lengths, immutables_len, compiler = metadata
+
+ assert runtime_len == len(runtime_code)
+ assert data_section_lengths == []
+ assert immutables_len == 0
+ assert compiler == {"vyper": list(vyper.version.version_tuple)}
+
+
+def test_bytecode_signature_dense_jumptable():
+ settings = Settings(optimize=OptimizationLevel.CODESIZE)
+
+ out = vyper.compile_code(many_functions, ["bytecode_runtime", "bytecode"], settings=settings)
+
+ runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x"))
+ initcode = bytes.fromhex(out["bytecode"].removeprefix("0x"))
+
+ metadata = _parse_cbor_metadata(initcode)
+ runtime_len, data_section_lengths, immutables_len, compiler = metadata
+
+ assert runtime_len == len(runtime_code)
+ assert data_section_lengths == [5, 35]
+ assert immutables_len == 0
+ assert compiler == {"vyper": list(vyper.version.version_tuple)}
+
+
+def test_bytecode_signature_sparse_jumptable():
+ settings = Settings(optimize=OptimizationLevel.GAS)
+
+ out = vyper.compile_code(many_functions, ["bytecode_runtime", "bytecode"], settings=settings)
+
+ runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x"))
+ initcode = bytes.fromhex(out["bytecode"].removeprefix("0x"))
+
+ metadata = _parse_cbor_metadata(initcode)
+ runtime_len, data_section_lengths, immutables_len, compiler = metadata
+
+ assert runtime_len == len(runtime_code)
+ assert data_section_lengths == [8]
+ assert immutables_len == 0
+ assert compiler == {"vyper": list(vyper.version.version_tuple)}
+
+
+def test_bytecode_signature_immutables():
+ out = vyper.compile_code(has_immutables, ["bytecode_runtime", "bytecode"])
+
+ runtime_code = bytes.fromhex(out["bytecode_runtime"].removeprefix("0x"))
+ initcode = bytes.fromhex(out["bytecode"].removeprefix("0x"))
+
+ metadata = _parse_cbor_metadata(initcode)
+ runtime_len, data_section_lengths, immutables_len, compiler = metadata
+
+ assert runtime_len == len(runtime_code)
+ assert data_section_lengths == []
+ assert immutables_len == 32
+ assert compiler == {"vyper": list(vyper.version.version_tuple)}
+
+
+# check that deployed bytecode actually matches the cbor metadata
+@pytest.mark.parametrize("code", [simple_contract_code, has_immutables, many_functions])
+def test_bytecode_signature_deployed(code, get_contract, w3):
+ c = get_contract(code)
+ deployed_code = w3.eth.get_code(c.address)
+
+ initcode = c._classic_contract.bytecode
+
+ metadata = _parse_cbor_metadata(initcode)
+ runtime_len, data_section_lengths, immutables_len, compiler = metadata
+
+ assert compiler == {"vyper": list(vyper.version.version_tuple)}
+
+ # runtime_len includes data sections but not immutables
+ assert len(deployed_code) == runtime_len + immutables_len
diff --git a/tests/compiler/test_default_settings.py b/tests/compiler/test_default_settings.py
new file mode 100644
index 0000000000..ca05170b61
--- /dev/null
+++ b/tests/compiler/test_default_settings.py
@@ -0,0 +1,27 @@
+from vyper.codegen import core
+from vyper.compiler.phases import CompilerData
+from vyper.compiler.settings import OptimizationLevel, _is_debug_mode
+
+
+def test_default_settings():
+ source_code = ""
+ compiler_data = CompilerData(source_code)
+ _ = compiler_data.vyper_module # force settings to be computed
+
+ assert compiler_data.settings.optimize == OptimizationLevel.GAS
+
+
+def test_default_opt_level():
+ assert OptimizationLevel.default() == OptimizationLevel.GAS
+
+
+def test_codegen_opt_level():
+ assert core._opt_level == OptimizationLevel.GAS
+ assert core._opt_gas() is True
+ assert core._opt_none() is False
+ assert core._opt_codesize() is False
+
+
+def test_debug_mode(pytestconfig):
+ debug_mode = pytestconfig.getoption("enable_compiler_debug_mode")
+ assert _is_debug_mode() == debug_mode
diff --git a/tests/compiler/test_opcodes.py b/tests/compiler/test_opcodes.py
index 67ea10c311..20f45ced6b 100644
--- a/tests/compiler/test_opcodes.py
+++ b/tests/compiler/test_opcodes.py
@@ -8,9 +8,11 @@
@pytest.fixture(params=list(opcodes.EVM_VERSIONS))
def evm_version(request):
default = opcodes.active_evm_version
- opcodes.active_evm_version = opcodes.EVM_VERSIONS[request.param]
- yield request.param
- opcodes.active_evm_version = default
+ try:
+ opcodes.active_evm_version = opcodes.EVM_VERSIONS[request.param]
+ yield request.param
+ finally:
+ opcodes.active_evm_version = default
def test_opcodes():
@@ -35,24 +37,30 @@ def test_version_check(evm_version):
assert opcodes.version_check(begin=evm_version)
assert opcodes.version_check(end=evm_version)
assert opcodes.version_check(begin=evm_version, end=evm_version)
- if evm_version not in ("byzantium", "atlantis"):
- assert not opcodes.version_check(end="byzantium")
+ if evm_version not in ("istanbul"):
+ assert not opcodes.version_check(end="istanbul")
istanbul_check = opcodes.version_check(begin="istanbul")
assert istanbul_check == (opcodes.EVM_VERSIONS[evm_version] >= opcodes.EVM_VERSIONS["istanbul"])
def test_get_opcodes(evm_version):
- op = opcodes.get_opcodes()
- if evm_version in ("paris", "berlin"):
- assert "CHAINID" in op
- assert op["SLOAD"][-1] == 2100
- elif evm_version == "istanbul":
- assert "CHAINID" in op
- assert op["SLOAD"][-1] == 800
+ ops = opcodes.get_opcodes()
+
+ assert "CHAINID" in ops
+ assert ops["CREATE2"][-1] == 32000
+
+ if evm_version in ("london", "berlin", "paris", "shanghai", "cancun"):
+ assert ops["SLOAD"][-1] == 2100
else:
- assert "CHAINID" not in op
- assert op["SLOAD"][-1] == 200
- if evm_version in ("byzantium", "atlantis"):
- assert "CREATE2" not in op
+ assert evm_version == "istanbul"
+ assert ops["SLOAD"][-1] == 800
+
+ if evm_version in ("shanghai", "cancun"):
+ assert "PUSH0" in ops
+
+ if evm_version in ("cancun",):
+ for op in ("TLOAD", "TSTORE", "MCOPY"):
+ assert op in ops
else:
- assert op["CREATE2"][-1] == 32000
+ for op in ("TLOAD", "TSTORE", "MCOPY"):
+ assert op not in ops
diff --git a/tests/compiler/test_pre_parser.py b/tests/compiler/test_pre_parser.py
index 4b747bb7d1..1761e74bad 100644
--- a/tests/compiler/test_pre_parser.py
+++ b/tests/compiler/test_pre_parser.py
@@ -1,6 +1,8 @@
-from pytest import raises
+import pytest
-from vyper.exceptions import SyntaxException
+from vyper.compiler import compile_code
+from vyper.compiler.settings import OptimizationLevel, Settings
+from vyper.exceptions import StructureException, SyntaxException
def test_semicolon_prohibited(get_contract):
@@ -10,7 +12,7 @@ def test() -> int128:
return a + b
"""
- with raises(SyntaxException):
+ with pytest.raises(SyntaxException):
get_contract(code)
@@ -70,6 +72,57 @@ def test():
assert get_contract(code)
+def test_version_pragma2(get_contract):
+ # new, `#pragma` way of doing things
+ from vyper import __version__
+
+ installed_version = ".".join(__version__.split(".")[:3])
+
+ code = f"""
+#pragma version {installed_version}
+
+@external
+def test():
+ pass
+ """
+ assert get_contract(code)
+
+
+def test_evm_version_check(assert_compile_failed):
+ code = """
+#pragma evm-version berlin
+ """
+ assert compile_code(code, settings=Settings(evm_version=None)) is not None
+ assert compile_code(code, settings=Settings(evm_version="berlin")) is not None
+ # should fail if compile options indicate different evm version
+ # from source pragma
+ with pytest.raises(StructureException):
+ compile_code(code, settings=Settings(evm_version="shanghai"))
+
+
+def test_optimization_mode_check():
+ code = """
+#pragma optimize codesize
+ """
+ assert compile_code(code, settings=Settings(optimize=None))
+ # should fail if compile options indicate different optimization mode
+ # from source pragma
+ with pytest.raises(StructureException):
+ compile_code(code, settings=Settings(optimize=OptimizationLevel.GAS))
+ with pytest.raises(StructureException):
+ compile_code(code, settings=Settings(optimize=OptimizationLevel.NONE))
+
+
+def test_optimization_mode_check_none():
+ code = """
+#pragma optimize none
+ """
+ assert compile_code(code, settings=Settings(optimize=None))
+ # "none" conflicts with "gas"
+ with pytest.raises(StructureException):
+ compile_code(code, settings=Settings(optimize=OptimizationLevel.GAS))
+
+
def test_version_empty_version(assert_compile_failed, get_contract):
code = """
#@version
@@ -110,5 +163,5 @@ def foo():
convert(
"""
- with raises(SyntaxException):
+ with pytest.raises(SyntaxException):
get_contract(code)
diff --git a/tests/compiler/test_sha3_32.py b/tests/compiler/test_sha3_32.py
index 9fbdf6f000..e1cbf9c843 100644
--- a/tests/compiler/test_sha3_32.py
+++ b/tests/compiler/test_sha3_32.py
@@ -1,9 +1,12 @@
from vyper.codegen.ir_node import IRnode
+from vyper.evm.opcodes import version_check
from vyper.ir import compile_ir, optimizer
def test_sha3_32():
ir = ["sha3_32", 0]
evm = ["PUSH1", 0, "PUSH1", 0, "MSTORE", "PUSH1", 32, "PUSH1", 0, "SHA3"]
+ if version_check(begin="shanghai"):
+ evm = ["PUSH0", "PUSH0", "MSTORE", "PUSH1", 32, "PUSH0", "SHA3"]
assert compile_ir.compile_to_assembly(IRnode.from_list(ir)) == evm
assert compile_ir.compile_to_assembly(optimizer.optimize(IRnode.from_list(ir))) == evm
diff --git a/tests/conftest.py b/tests/conftest.py
index e1d0996767..d519ca3100 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -10,6 +10,7 @@
from vyper import compiler
from vyper.codegen.ir_node import IRnode
+from vyper.compiler.settings import OptimizationLevel, _set_debug_mode
from vyper.ir import compile_ir, optimizer
from .base_conftest import VyperContract, _get_contract, zero_gas_price_strategy
@@ -36,12 +37,26 @@ def set_evm_verbose_logging():
def pytest_addoption(parser):
- parser.addoption("--no-optimize", action="store_true", help="disable asm and IR optimizations")
+ parser.addoption(
+ "--optimize",
+ choices=["codesize", "gas", "none"],
+ default="gas",
+ help="change optimization mode",
+ )
+ parser.addoption("--enable-compiler-debug-mode", action="store_true")
@pytest.fixture(scope="module")
-def no_optimize(pytestconfig):
- return pytestconfig.getoption("no_optimize")
+def optimize(pytestconfig):
+ flag = pytestconfig.getoption("optimize")
+ return OptimizationLevel.from_string(flag)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def debug(pytestconfig):
+ debug = pytestconfig.getoption("enable_compiler_debug_mode")
+ assert isinstance(debug, bool)
+ _set_debug_mode(debug)
@pytest.fixture
@@ -58,13 +73,13 @@ def bytes_helper(str, length):
@pytest.fixture
-def get_contract_from_ir(w3, no_optimize):
+def get_contract_from_ir(w3, optimize):
def ir_compiler(ir, *args, **kwargs):
ir = IRnode.from_list(ir)
- if not no_optimize:
+ if optimize != OptimizationLevel.NONE:
ir = optimizer.optimize(ir)
bytecode, _ = compile_ir.assembly_to_evm(
- compile_ir.compile_to_assembly(ir, no_optimize=no_optimize)
+ compile_ir.compile_to_assembly(ir, optimize=optimize)
)
abi = kwargs.get("abi") or []
c = w3.eth.contract(abi=abi, bytecode=bytecode)
@@ -80,7 +95,7 @@ def ir_compiler(ir, *args, **kwargs):
@pytest.fixture(scope="module")
-def get_contract_module(no_optimize):
+def get_contract_module(optimize):
"""
This fixture is used for Hypothesis tests to ensure that
the same contract is called over multiple runs of the test.
@@ -93,7 +108,7 @@ def get_contract_module(no_optimize):
w3.eth.set_gas_price_strategy(zero_gas_price_strategy)
def get_contract_module(source_code, *args, **kwargs):
- return _get_contract(w3, source_code, no_optimize, *args, **kwargs)
+ return _get_contract(w3, source_code, optimize, *args, **kwargs)
return get_contract_module
@@ -138,9 +153,9 @@ def set_decorator_to_contract_function(w3, tester, contract, source_code, func):
@pytest.fixture
-def get_contract_with_gas_estimation(tester, w3, no_optimize):
+def get_contract_with_gas_estimation(tester, w3, optimize):
def get_contract_with_gas_estimation(source_code, *args, **kwargs):
- contract = _get_contract(w3, source_code, no_optimize, *args, **kwargs)
+ contract = _get_contract(w3, source_code, optimize, *args, **kwargs)
for abi_ in contract._classic_contract.functions.abi:
if abi_["type"] == "function":
set_decorator_to_contract_function(w3, tester, contract, source_code, abi_["name"])
@@ -150,9 +165,9 @@ def get_contract_with_gas_estimation(source_code, *args, **kwargs):
@pytest.fixture
-def get_contract_with_gas_estimation_for_constants(w3, no_optimize):
+def get_contract_with_gas_estimation_for_constants(w3, optimize):
def get_contract_with_gas_estimation_for_constants(source_code, *args, **kwargs):
- return _get_contract(w3, source_code, no_optimize, *args, **kwargs)
+ return _get_contract(w3, source_code, optimize, *args, **kwargs)
return get_contract_with_gas_estimation_for_constants
@@ -192,3 +207,38 @@ def _f(_addr, _salt, _initcode):
return keccak(prefix + addr + salt + keccak(initcode))[12:]
return _f
+
+
+@pytest.fixture
+def side_effects_contract(get_contract):
+ def generate(ret_type):
+ """
+ Generates a Vyper contract with an external `foo()` function, which
+ returns the specified return value of the specified return type, for
+ testing side effects using the `assert_side_effects_invoked` fixture.
+ """
+ code = f"""
+counter: public(uint256)
+
+@external
+def foo(s: {ret_type}) -> {ret_type}:
+ self.counter += 1
+ return s
+ """
+ contract = get_contract(code)
+ return contract
+
+ return generate
+
+
+@pytest.fixture
+def assert_side_effects_invoked():
+ def assert_side_effects_invoked(side_effects_contract, side_effects_trigger, n=1):
+ start_value = side_effects_contract.counter()
+
+ side_effects_trigger()
+
+ end_value = side_effects_contract.counter()
+ assert end_value == start_value + n
+
+ return assert_side_effects_invoked
diff --git a/tests/examples/factory/test_factory.py b/tests/examples/factory/test_factory.py
index 15becc05f1..0c5cf61b04 100644
--- a/tests/examples/factory/test_factory.py
+++ b/tests/examples/factory/test_factory.py
@@ -2,6 +2,7 @@
from eth_utils import keccak
import vyper
+from vyper.compiler.settings import Settings
@pytest.fixture
@@ -30,12 +31,12 @@ def create_exchange(token, factory):
@pytest.fixture
-def factory(get_contract, no_optimize):
+def factory(get_contract, optimize):
with open("examples/factory/Exchange.vy") as f:
code = f.read()
exchange_interface = vyper.compile_code(
- code, output_formats=["bytecode_runtime"], no_optimize=no_optimize
+ code, output_formats=["bytecode_runtime"], settings=Settings(optimize=optimize)
)
exchange_deployed_bytecode = exchange_interface["bytecode_runtime"]
diff --git a/tests/functional/semantics/analysis/test_cyclic_function_calls.py b/tests/functional/semantics/analysis/test_cyclic_function_calls.py
index 086f8ed08c..2a09bd5ed5 100644
--- a/tests/functional/semantics/analysis/test_cyclic_function_calls.py
+++ b/tests/functional/semantics/analysis/test_cyclic_function_calls.py
@@ -6,6 +6,18 @@
from vyper.semantics.analysis.module import ModuleAnalyzer
+def test_self_function_call(namespace):
+ code = """
+@internal
+def foo():
+ self.foo()
+ """
+ vyper_module = parse_to_ast(code)
+ with namespace.enter_scope():
+ with pytest.raises(CallViolation):
+ ModuleAnalyzer(vyper_module, {}, namespace)
+
+
def test_cyclic_function_call(namespace):
code = """
@internal
diff --git a/tests/functional/semantics/analysis/test_for_loop.py b/tests/functional/semantics/analysis/test_for_loop.py
index 13f309181f..0d61a8f8f8 100644
--- a/tests/functional/semantics/analysis/test_for_loop.py
+++ b/tests/functional/semantics/analysis/test_for_loop.py
@@ -1,7 +1,12 @@
import pytest
from vyper.ast import parse_to_ast
-from vyper.exceptions import ImmutableViolation
+from vyper.exceptions import (
+ ArgumentException,
+ ImmutableViolation,
+ StateAccessViolation,
+ TypeMismatch,
+)
from vyper.semantics.analysis import validate_semantics
@@ -59,6 +64,34 @@ def bar():
validate_semantics(vyper_module, {})
+def test_bad_keywords(namespace):
+ code = """
+
+@internal
+def bar(n: uint256):
+ x: uint256 = 0
+ for i in range(n, boundddd=10):
+ x += i
+ """
+ vyper_module = parse_to_ast(code)
+ with pytest.raises(ArgumentException):
+ validate_semantics(vyper_module, {})
+
+
+def test_bad_bound(namespace):
+ code = """
+
+@internal
+def bar(n: uint256):
+ x: uint256 = 0
+ for i in range(n, bound=n):
+ x += i
+ """
+ vyper_module = parse_to_ast(code)
+ with pytest.raises(StateAccessViolation):
+ validate_semantics(vyper_module, {})
+
+
def test_modify_iterator_function_call(namespace):
code = """
@@ -99,3 +132,44 @@ def baz():
vyper_module = parse_to_ast(code)
with pytest.raises(ImmutableViolation):
validate_semantics(vyper_module, {})
+
+
+iterator_inference_codes = [
+ """
+@external
+def main():
+ for j in range(3):
+ x: uint256 = j
+ y: uint16 = j
+ """, # GH issue 3212
+ """
+@external
+def foo():
+ for i in [1]:
+ a:uint256 = i
+ b:uint16 = i
+ """, # GH issue 3374
+ """
+@external
+def foo():
+ for i in [1]:
+ for j in [1]:
+ a:uint256 = i
+ b:uint16 = i
+ """, # GH issue 3374
+ """
+@external
+def foo():
+ for i in [1,2,3]:
+ for j in [1,2,3]:
+ b:uint256 = j + i
+ c:uint16 = i
+ """, # GH issue 3374
+]
+
+
+@pytest.mark.parametrize("code", iterator_inference_codes)
+def test_iterator_type_inference_checker(namespace, code):
+ vyper_module = parse_to_ast(code)
+ with pytest.raises(TypeMismatch):
+ validate_semantics(vyper_module, {})
diff --git a/tests/functional/semantics/test_namespace.py b/tests/functional/semantics/test_namespace.py
index 5cb942ca0d..10fca3a090 100644
--- a/tests/functional/semantics/test_namespace.py
+++ b/tests/functional/semantics/test_namespace.py
@@ -3,7 +3,7 @@
from vyper.exceptions import CompilerPanic, NamespaceCollision, UndeclaredDefinition
from vyper.semantics import environment
from vyper.semantics.namespace import get_namespace
-from vyper.semantics.types import get_types
+from vyper.semantics.types import PRIMITIVE_TYPES
def test_get_namespace():
@@ -22,13 +22,13 @@ def test_builtin_context_manager(namespace):
def test_builtin_types(namespace):
- for key, value in get_types().items():
+ for key, value in PRIMITIVE_TYPES.items():
assert namespace[key] == value
def test_builtin_types_persist_after_clear(namespace):
namespace.clear()
- for key, value in get_types().items():
+ for key, value in PRIMITIVE_TYPES.items():
assert namespace[key] == value
diff --git a/tests/functional/semantics/types/test_type_from_abi.py b/tests/functional/semantics/types/test_type_from_abi.py
index c1606d237e..be021f6904 100644
--- a/tests/functional/semantics/types/test_type_from_abi.py
+++ b/tests/functional/semantics/types/test_type_from_abi.py
@@ -1,7 +1,7 @@
import pytest
from vyper.exceptions import UnknownType
-from vyper.semantics.types import SArrayT, get_primitive_types
+from vyper.semantics.types import PRIMITIVE_TYPES, SArrayT
from vyper.semantics.types.utils import type_from_abi
BASE_TYPES = ["int128", "uint256", "bool", "address", "bytes32"]
@@ -9,7 +9,7 @@
@pytest.mark.parametrize("type_str", BASE_TYPES)
def test_base_types(type_str):
- base_t = get_primitive_types()[type_str]
+ base_t = PRIMITIVE_TYPES[type_str]
type_t = type_from_abi({"type": type_str})
assert base_t == type_t
@@ -17,7 +17,7 @@ def test_base_types(type_str):
@pytest.mark.parametrize("type_str", BASE_TYPES)
def test_base_types_as_arrays(type_str):
- base_t = get_primitive_types()[type_str]
+ base_t = PRIMITIVE_TYPES[type_str]
type_t = type_from_abi({"type": f"{type_str}[3]"})
assert type_t == SArrayT(base_t, 3)
@@ -25,7 +25,7 @@ def test_base_types_as_arrays(type_str):
@pytest.mark.parametrize("type_str", BASE_TYPES)
def test_base_types_as_multidimensional_arrays(type_str):
- base_t = get_primitive_types()[type_str]
+ base_t = PRIMITIVE_TYPES[type_str]
type_t = type_from_abi({"type": f"{type_str}[3][5]"})
diff --git a/tests/functional/semantics/types/test_type_from_annotation.py b/tests/functional/semantics/types/test_type_from_annotation.py
index c308c1264b..16a31cc651 100644
--- a/tests/functional/semantics/types/test_type_from_annotation.py
+++ b/tests/functional/semantics/types/test_type_from_annotation.py
@@ -6,7 +6,8 @@
StructureException,
UndeclaredDefinition,
)
-from vyper.semantics.types import HashMapT, SArrayT, get_primitive_types
+from vyper.semantics.data_locations import DataLocation
+from vyper.semantics.types import PRIMITIVE_TYPES, HashMapT, SArrayT
from vyper.semantics.types.utils import type_from_annotation
BASE_TYPES = ["int128", "uint256", "bool", "address", "bytes32"]
@@ -14,9 +15,10 @@
@pytest.mark.parametrize("type_str", BASE_TYPES)
-def test_base_types(build_node, type_str):
+@pytest.mark.parametrize("location", iter(DataLocation))
+def test_base_types(build_node, type_str, location):
node = build_node(type_str)
- base_t = get_primitive_types()[type_str]
+ base_t = PRIMITIVE_TYPES[type_str]
ann_t = type_from_annotation(node)
@@ -24,9 +26,10 @@ def test_base_types(build_node, type_str):
@pytest.mark.parametrize("type_str", BYTESTRING_TYPES)
-def test_array_value_types(build_node, type_str):
+@pytest.mark.parametrize("location", iter(DataLocation))
+def test_array_value_types(build_node, type_str, location):
node = build_node(f"{type_str}[1]")
- base_t = get_primitive_types()[type_str](1)
+ base_t = PRIMITIVE_TYPES[type_str](1)
ann_t = type_from_annotation(node)
@@ -34,9 +37,10 @@ def test_array_value_types(build_node, type_str):
@pytest.mark.parametrize("type_str", BASE_TYPES)
-def test_base_types_as_arrays(build_node, type_str):
+@pytest.mark.parametrize("location", iter(DataLocation))
+def test_base_types_as_arrays(build_node, type_str, location):
node = build_node(f"{type_str}[3]")
- base_t = get_primitive_types()[type_str]
+ base_t = PRIMITIVE_TYPES[type_str]
ann_t = type_from_annotation(node)
@@ -44,7 +48,8 @@ def test_base_types_as_arrays(build_node, type_str):
@pytest.mark.parametrize("type_str", BYTESTRING_TYPES)
-def test_array_value_types_as_arrays(build_node, type_str):
+@pytest.mark.parametrize("location", iter(DataLocation))
+def test_array_value_types_as_arrays(build_node, type_str, location):
node = build_node(f"{type_str}[1][1]")
with pytest.raises(StructureException):
@@ -52,9 +57,10 @@ def test_array_value_types_as_arrays(build_node, type_str):
@pytest.mark.parametrize("type_str", BASE_TYPES)
-def test_base_types_as_multidimensional_arrays(build_node, namespace, type_str):
+@pytest.mark.parametrize("location", iter(DataLocation))
+def test_base_types_as_multidimensional_arrays(build_node, namespace, type_str, location):
node = build_node(f"{type_str}[3][5]")
- base_t = get_primitive_types()[type_str]
+ base_t = PRIMITIVE_TYPES[type_str]
ann_t = type_from_annotation(node)
@@ -63,7 +69,8 @@ def test_base_types_as_multidimensional_arrays(build_node, namespace, type_str):
@pytest.mark.parametrize("type_str", ["int128", "String"])
@pytest.mark.parametrize("idx", ["0", "-1", "0x00", "'1'", "foo", "[1]", "(1,)"])
-def test_invalid_index(build_node, idx, type_str):
+@pytest.mark.parametrize("location", iter(DataLocation))
+def test_invalid_index(build_node, idx, type_str, location):
node = build_node(f"{type_str}[{idx}]")
with pytest.raises(
(ArrayIndexException, InvalidType, StructureException, UndeclaredDefinition)
@@ -75,9 +82,9 @@ def test_invalid_index(build_node, idx, type_str):
@pytest.mark.parametrize("type_str2", BASE_TYPES)
def test_mapping(build_node, type_str, type_str2):
node = build_node(f"HashMap[{type_str}, {type_str2}]")
- types = get_primitive_types()
+ types = PRIMITIVE_TYPES
- ann_t = type_from_annotation(node)
+ ann_t = type_from_annotation(node, DataLocation.STORAGE)
k_t = types[type_str]
v_t = types[type_str2]
@@ -89,9 +96,9 @@ def test_mapping(build_node, type_str, type_str2):
@pytest.mark.parametrize("type_str2", BASE_TYPES)
def test_multidimensional_mapping(build_node, type_str, type_str2):
node = build_node(f"HashMap[{type_str}, HashMap[{type_str}, {type_str2}]]")
- types = get_primitive_types()
+ types = PRIMITIVE_TYPES
- ann_t = type_from_annotation(node)
+ ann_t = type_from_annotation(node, DataLocation.STORAGE)
k_t = types[type_str]
v_t = types[type_str2]
diff --git a/tests/functional/test_storage_slots.py b/tests/functional/test_storage_slots.py
index 4653625468..d390fe9a39 100644
--- a/tests/functional/test_storage_slots.py
+++ b/tests/functional/test_storage_slots.py
@@ -1,3 +1,7 @@
+import pytest
+
+from vyper.exceptions import StorageLayoutException
+
code = """
struct StructOne:
@@ -97,3 +101,15 @@ def test_reentrancy_lock(get_contract):
assert [c.foo(0, i) for i in range(3)] == [987, 654, 321]
assert [c.foo(1, i) for i in range(3)] == [123, 456, 789]
assert c.h(0) == 123456789
+
+
+def test_allocator_overflow(get_contract):
+ code = """
+x: uint256
+y: uint256[max_value(uint256)]
+ """
+ with pytest.raises(
+ StorageLayoutException,
+ match=f"Invalid storage slot for var y, tried to allocate slots 1 through {2**256}\n",
+ ):
+ get_contract(code)
diff --git a/tests/grammar/test_grammar.py b/tests/grammar/test_grammar.py
index 7e220b58ae..d665ca2544 100644
--- a/tests/grammar/test_grammar.py
+++ b/tests/grammar/test_grammar.py
@@ -106,5 +106,6 @@ def has_no_docstrings(c):
@hypothesis.settings(deadline=400, max_examples=500, suppress_health_check=(HealthCheck.too_slow,))
def test_grammar_bruteforce(code):
if utf8_encodable(code):
- tree = parse_to_ast(pre_parse(code + "\n")[1])
+ _, _, reformatted_code = pre_parse(code + "\n")
+ tree = parse_to_ast(reformatted_code)
assert isinstance(tree, Module)
diff --git a/tests/parser/ast_utils/test_ast_dict.py b/tests/parser/ast_utils/test_ast_dict.py
index 214af50f9f..f483d0cbe8 100644
--- a/tests/parser/ast_utils/test_ast_dict.py
+++ b/tests/parser/ast_utils/test_ast_dict.py
@@ -73,6 +73,7 @@ def test_basic_ast():
"is_constant": False,
"is_immutable": False,
"is_public": False,
+ "is_transient": False,
}
diff --git a/tests/parser/exceptions/test_instantiation_exception.py b/tests/parser/exceptions/test_instantiation_exception.py
new file mode 100644
index 0000000000..0d641f154a
--- /dev/null
+++ b/tests/parser/exceptions/test_instantiation_exception.py
@@ -0,0 +1,81 @@
+import pytest
+
+from vyper.exceptions import InstantiationException
+
+invalid_list = [
+ """
+event Foo:
+ a: uint256
+
+@external
+def foo() -> Foo:
+ return Foo(2)
+ """,
+ """
+event Foo:
+ a: uint256
+
+@external
+def foo() -> (uint256, Foo):
+ return 1, Foo(2)
+ """,
+ """
+a: HashMap[uint256, uint256]
+
+@external
+def foo() -> HashMap[uint256, uint256]:
+ return self.a
+ """,
+ """
+event Foo:
+ a: uint256
+
+@external
+def foo(x: Foo):
+ pass
+ """,
+ """
+@external
+def foo(x: HashMap[uint256, uint256]):
+ pass
+ """,
+ """
+event Foo:
+ a: uint256
+
+foo: Foo
+ """,
+ """
+event Foo:
+ a: uint256
+
+@external
+def foo():
+ f: Foo = Foo(1)
+ pass
+ """,
+ """
+event Foo:
+ a: uint256
+
+b: HashMap[uint256, Foo]
+ """,
+ """
+event Foo:
+ a: uint256
+
+b: HashMap[Foo, uint256]
+ """,
+ """
+b: immutable(HashMap[uint256, uint256])
+
+@external
+def __init__():
+ b = empty(HashMap[uint256, uint256])
+ """,
+]
+
+
+@pytest.mark.parametrize("bad_code", invalid_list)
+def test_instantiation_exception(bad_code, get_contract, assert_compile_failed):
+ assert_compile_failed(lambda: get_contract(bad_code), InstantiationException)
diff --git a/tests/parser/exceptions/test_invalid_reference.py b/tests/parser/exceptions/test_invalid_reference.py
index 3aec6028e4..fe315e5cbf 100644
--- a/tests/parser/exceptions/test_invalid_reference.py
+++ b/tests/parser/exceptions/test_invalid_reference.py
@@ -37,6 +37,24 @@ def foo():
def foo():
int128 = 5
""",
+ """
+a: public(constant(uint256)) = 1
+
+@external
+def foo():
+ b: uint256 = self.a
+ """,
+ """
+a: public(immutable(uint256))
+
+@external
+def __init__():
+ a = 123
+
+@external
+def foo():
+ b: uint256 = self.a
+ """,
]
diff --git a/tests/parser/exceptions/test_structure_exception.py b/tests/parser/exceptions/test_structure_exception.py
index 08794b75f2..97ac2b139d 100644
--- a/tests/parser/exceptions/test_structure_exception.py
+++ b/tests/parser/exceptions/test_structure_exception.py
@@ -56,9 +56,26 @@ def double_nonreentrant():
""",
"""
@external
-@nonreentrant("B")
-@nonreentrant("C")
-def double_nonreentrant():
+@nonreentrant(" ")
+def invalid_nonreentrant_key():
+ pass
+ """,
+ """
+@external
+@nonreentrant("")
+def invalid_nonreentrant_key():
+ pass
+ """,
+ """
+@external
+@nonreentrant("123")
+def invalid_nonreentrant_key():
+ pass
+ """,
+ """
+@external
+@nonreentrant("!123abcd")
+def invalid_nonreentrant_key():
pass
""",
"""
diff --git a/tests/parser/exceptions/test_syntax_exception.py b/tests/parser/exceptions/test_syntax_exception.py
index 46c98670e3..9ab9b6c677 100644
--- a/tests/parser/exceptions/test_syntax_exception.py
+++ b/tests/parser/exceptions/test_syntax_exception.py
@@ -1,7 +1,5 @@
import pytest
-from pytest import raises
-from vyper import compiler
from vyper.exceptions import SyntaxException
fail_list = [
@@ -79,10 +77,18 @@ def foo():
def foo():
x: uint256 = +1 # test UAdd ast blocked
""",
+ """
+@internal
+def f(a:uint256,/): # test posonlyargs blocked
+ return
+
+@external
+def g():
+ self.f()
+ """,
]
@pytest.mark.parametrize("bad_code", fail_list)
-def test_syntax_exception(bad_code):
- with raises(SyntaxException):
- compiler.compile_code(bad_code)
+def test_syntax_exception(assert_compile_failed, get_contract, bad_code):
+ assert_compile_failed(lambda: get_contract(bad_code), SyntaxException)
diff --git a/tests/parser/features/decorators/test_nonreentrant.py b/tests/parser/features/decorators/test_nonreentrant.py
index 0577313b88..9e74019250 100644
--- a/tests/parser/features/decorators/test_nonreentrant.py
+++ b/tests/parser/features/decorators/test_nonreentrant.py
@@ -3,6 +3,8 @@
from vyper.exceptions import FunctionDeclarationException
+# TODO test functions in this module across all evm versions
+# once we have cancun support.
def test_nonreentrant_decorator(get_contract, assert_tx_failed):
calling_contract_code = """
interface SpecialContract:
@@ -140,7 +142,7 @@ def set_callback(c: address):
@external
@payable
-@nonreentrant('default')
+@nonreentrant("lock")
def protected_function(val: String[100], do_callback: bool) -> uint256:
self.special_value = val
_amount: uint256 = msg.value
@@ -164,7 +166,7 @@ def unprotected_function(val: String[100], do_callback: bool):
@external
@payable
-@nonreentrant('default')
+@nonreentrant("lock")
def __default__():
pass
"""
diff --git a/tests/parser/features/decorators/test_payable.py b/tests/parser/features/decorators/test_payable.py
index 906ae330c0..55c60236f4 100644
--- a/tests/parser/features/decorators/test_payable.py
+++ b/tests/parser/features/decorators/test_payable.py
@@ -372,3 +372,24 @@ def __default__():
assert_tx_failed(
lambda: w3.eth.send_transaction({"to": c.address, "value": 100, "data": "0x12345678"})
)
+
+
+def test_batch_nonpayable(get_contract, w3, assert_tx_failed):
+ code = """
+@external
+def foo() -> bool:
+ return True
+
+@external
+def __default__():
+ pass
+ """
+
+ c = get_contract(code)
+ w3.eth.send_transaction({"to": c.address, "value": 0, "data": "0x12345678"})
+ data = bytes([1, 2, 3, 4])
+ for i in range(5):
+ calldata = "0x" + data[:i].hex()
+ assert_tx_failed(
+ lambda: w3.eth.send_transaction({"to": c.address, "value": 100, "data": calldata})
+ )
diff --git a/tests/parser/features/decorators/test_private.py b/tests/parser/features/decorators/test_private.py
index 31822492bd..51e6d90ee1 100644
--- a/tests/parser/features/decorators/test_private.py
+++ b/tests/parser/features/decorators/test_private.py
@@ -433,7 +433,7 @@ def i_am_me() -> bool:
return msg.sender == self._whoami()
@external
-@view
+@nonpayable
def whoami() -> address:
log Addr(self._whoami())
return self._whoami()
diff --git a/tests/parser/features/iteration/test_repeater.py b/tests/parser/features/iteration/test_for_range.py
similarity index 81%
rename from tests/parser/features/iteration/test_repeater.py
rename to tests/parser/features/iteration/test_for_range.py
index 3c95882d2d..395dd28231 100644
--- a/tests/parser/features/iteration/test_repeater.py
+++ b/tests/parser/features/iteration/test_for_range.py
@@ -14,6 +14,23 @@ def repeat(z: int128) -> int128:
assert c.repeat(9) == 54
+def test_range_bound(get_contract, assert_tx_failed):
+ code = """
+@external
+def repeat(n: uint256) -> uint256:
+ x: uint256 = 0
+ for i in range(n, bound=6):
+ x += i
+ return x
+ """
+ c = get_contract(code)
+ for n in range(7):
+ assert c.repeat(n) == sum(range(n))
+
+ # check codegen inserts assertion for n greater than bound
+ assert_tx_failed(lambda: c.repeat(7))
+
+
def test_digit_reverser(get_contract_with_gas_estimation):
digit_reverser = """
@external
@@ -128,6 +145,45 @@ def foo(a: {typ}) -> {typ}:
assert c.foo(100) == 31337
+# test that we can get to the upper range of an integer
+@pytest.mark.parametrize("typ", ["uint8", "int128", "uint256"])
+def test_for_range_edge(get_contract, typ):
+ code = f"""
+@external
+def test():
+ found: bool = False
+ x: {typ} = max_value({typ})
+ for i in range(x, x + 1):
+ if i == max_value({typ}):
+ found = True
+
+ assert found
+
+ found = False
+ x = max_value({typ}) - 1
+ for i in range(x, x + 2):
+ if i == max_value({typ}):
+ found = True
+
+ assert found
+ """
+ c = get_contract(code)
+ c.test()
+
+
+@pytest.mark.parametrize("typ", ["uint8", "int128", "uint256"])
+def test_for_range_oob_check(get_contract, assert_tx_failed, typ):
+ code = f"""
+@external
+def test():
+ x: {typ} = max_value({typ})
+ for i in range(x, x+2):
+ pass
+ """
+ c = get_contract(code)
+ assert_tx_failed(lambda: c.test())
+
+
@pytest.mark.parametrize("typ", ["int128", "uint256"])
def test_return_inside_nested_repeater(get_contract, typ):
code = f"""
diff --git a/tests/parser/features/test_assignment.py b/tests/parser/features/test_assignment.py
index 0093c9b361..583a927b41 100644
--- a/tests/parser/features/test_assignment.py
+++ b/tests/parser/features/test_assignment.py
@@ -39,7 +39,118 @@ def augmod(x: int128, y: int128) -> int128:
print("Passed aug-assignment test")
-def test_invalid_assign(assert_compile_failed, get_contract_with_gas_estimation):
+@pytest.mark.parametrize(
+ "typ,in_val,out_val",
+ [
+ ("uint256", 77, 123),
+ ("uint256[3]", [1, 2, 3], [4, 5, 6]),
+ ("DynArray[uint256, 3]", [1, 2, 3], [4, 5, 6]),
+ ("Bytes[5]", b"vyper", b"conda"),
+ ],
+)
+def test_internal_assign(get_contract_with_gas_estimation, typ, in_val, out_val):
+ code = f"""
+@internal
+def foo(x: {typ}) -> {typ}:
+ x = {out_val}
+ return x
+
+@external
+def bar(x: {typ}) -> {typ}:
+ return self.foo(x)
+ """
+ c = get_contract_with_gas_estimation(code)
+
+ assert c.bar(in_val) == out_val
+
+
+def test_internal_assign_struct(get_contract_with_gas_estimation):
+ code = """
+enum Bar:
+ BAD
+ BAK
+ BAZ
+
+struct Foo:
+ a: uint256
+ b: DynArray[Bar, 3]
+ c: String[5]
+
+@internal
+def foo(x: Foo) -> Foo:
+ x = Foo({a: 789, b: [Bar.BAZ, Bar.BAK, Bar.BAD], c: \"conda\"})
+ return x
+
+@external
+def bar(x: Foo) -> Foo:
+ return self.foo(x)
+ """
+ c = get_contract_with_gas_estimation(code)
+
+ assert c.bar((123, [1, 2, 4], "vyper")) == (789, [4, 2, 1], "conda")
+
+
+def test_internal_assign_struct_member(get_contract_with_gas_estimation):
+ code = """
+enum Bar:
+ BAD
+ BAK
+ BAZ
+
+struct Foo:
+ a: uint256
+ b: DynArray[Bar, 3]
+ c: String[5]
+
+@internal
+def foo(x: Foo) -> Foo:
+ x.a = 789
+ x.b.pop()
+ return x
+
+@external
+def bar(x: Foo) -> Foo:
+ return self.foo(x)
+ """
+ c = get_contract_with_gas_estimation(code)
+
+ assert c.bar((123, [1, 2, 4], "vyper")) == (789, [1, 2], "vyper")
+
+
+def test_internal_augassign(get_contract_with_gas_estimation):
+ code = """
+@internal
+def foo(x: int128) -> int128:
+ x += 77
+ return x
+
+@external
+def bar(x: int128) -> int128:
+ return self.foo(x)
+ """
+ c = get_contract_with_gas_estimation(code)
+
+ assert c.bar(123) == 200
+
+
+@pytest.mark.parametrize("typ", ["DynArray[uint256, 3]", "uint256[3]"])
+def test_internal_augassign_arrays(get_contract_with_gas_estimation, typ):
+ code = f"""
+@internal
+def foo(x: {typ}) -> {typ}:
+ x[1] += 77
+ return x
+
+@external
+def bar(x: {typ}) -> {typ}:
+ return self.foo(x)
+ """
+ c = get_contract_with_gas_estimation(code)
+
+ assert c.bar([1, 2, 3]) == [1, 79, 3]
+
+
+def test_invalid_external_assign(assert_compile_failed, get_contract_with_gas_estimation):
code = """
@external
def foo(x: int128):
@@ -48,7 +159,7 @@ def foo(x: int128):
assert_compile_failed(lambda: get_contract_with_gas_estimation(code), ImmutableViolation)
-def test_invalid_augassign(assert_compile_failed, get_contract_with_gas_estimation):
+def test_invalid_external_augassign(assert_compile_failed, get_contract_with_gas_estimation):
code = """
@external
def foo(x: int128):
@@ -255,3 +366,79 @@ def foo():
ret : bool = self.bar()
"""
assert_compile_failed(lambda: get_contract_with_gas_estimation(code), InvalidType)
+
+ # GH issue 2418
+
+
+overlap_codes = [
+ """
+@external
+def bug(xs: uint256[2]) -> uint256[2]:
+ # Initial value
+ ys: uint256[2] = xs
+ ys = [ys[1], ys[0]]
+ return ys
+ """,
+ """
+foo: uint256[2]
+@external
+def bug(xs: uint256[2]) -> uint256[2]:
+ # Initial value
+ self.foo = xs
+ self.foo = [self.foo[1], self.foo[0]]
+ return self.foo
+ """,
+ # TODO add transient tests when it's available
+]
+
+
+@pytest.mark.parametrize("code", overlap_codes)
+def test_assign_rhs_lhs_overlap(get_contract, code):
+ c = get_contract(code)
+
+ assert c.bug([1, 2]) == [2, 1]
+
+
+def test_assign_rhs_lhs_partial_overlap(get_contract):
+ # GH issue 2418, generalize when lhs is not only dependency of rhs.
+ code = """
+@external
+def bug(xs: uint256[2]) -> uint256[2]:
+ # Initial value
+ ys: uint256[2] = xs
+ ys = [xs[1], ys[0]]
+ return ys
+ """
+ c = get_contract(code)
+
+ assert c.bug([1, 2]) == [2, 1]
+
+
+def test_assign_rhs_lhs_overlap_dynarray(get_contract):
+ # GH issue 2418, generalize to dynarrays
+ code = """
+@external
+def bug(xs: DynArray[uint256, 2]) -> DynArray[uint256, 2]:
+ ys: DynArray[uint256, 2] = xs
+ ys = [ys[1], ys[0]]
+ return ys
+ """
+ c = get_contract(code)
+ assert c.bug([1, 2]) == [2, 1]
+
+
+def test_assign_rhs_lhs_overlap_struct(get_contract):
+ # GH issue 2418, generalize to structs
+ code = """
+struct Point:
+ x: uint256
+ y: uint256
+
+@external
+def bug(p: Point) -> Point:
+ t: Point = p
+ t = Point({x: t.y, y: t.x})
+ return t
+ """
+ c = get_contract(code)
+ assert c.bug((1, 2)) == (2, 1)
diff --git a/tests/parser/features/test_comparison.py b/tests/parser/features/test_comparison.py
index 1c2f287c10..5a86ffb4b8 100644
--- a/tests/parser/features/test_comparison.py
+++ b/tests/parser/features/test_comparison.py
@@ -4,7 +4,7 @@
def test_3034_verbatim(get_contract):
- # test issue #3034 exactly
+ # test GH issue 3034 exactly
code = """
@view
@external
diff --git a/tests/parser/features/test_immutable.py b/tests/parser/features/test_immutable.py
index bb01b3fc07..47f7fc748e 100644
--- a/tests/parser/features/test_immutable.py
+++ b/tests/parser/features/test_immutable.py
@@ -1,5 +1,7 @@
import pytest
+from vyper.compiler.settings import OptimizationLevel
+
@pytest.mark.parametrize(
"typ,value",
@@ -239,3 +241,140 @@ def get_immutable() -> uint256:
c = get_contract(code, n)
assert c.get_immutable() == n + 2
+
+
+# GH issue 3101
+def test_immutables_initialized(get_contract):
+ dummy_code = """
+@external
+def foo() -> uint256:
+ return 1
+ """
+ dummy_contract = get_contract(dummy_code)
+
+ code = """
+a: public(immutable(uint256))
+b: public(uint256)
+
+@payable
+@external
+def __init__(to_copy: address):
+ c: address = create_copy_of(to_copy)
+ self.b = a
+ a = 12
+ """
+ c = get_contract(code, dummy_contract.address)
+
+ assert c.b() == 0
+
+
+# GH issue 3101, take 2
+def test_immutables_initialized2(get_contract, get_contract_from_ir):
+ dummy_contract = get_contract_from_ir(
+ ["deploy", 0, ["seq"] + ["invalid"] * 600, 0], optimize=OptimizationLevel.NONE
+ )
+
+ # rekt because immutables section extends past allocated memory
+ code = """
+a0: immutable(uint256[10])
+a: public(immutable(uint256))
+b: public(uint256)
+
+@payable
+@external
+def __init__(to_copy: address):
+ c: address = create_copy_of(to_copy)
+ self.b = a
+ a = 12
+ a0 = empty(uint256[10])
+ """
+ c = get_contract(code, dummy_contract.address)
+
+ assert c.b() == 0
+
+
+# GH issue 3292
+def test_internal_functions_called_by_ctor_location(get_contract):
+ code = """
+d: uint256
+x: immutable(uint256)
+
+@external
+def __init__():
+ self.d = 1
+ x = 2
+ self.a()
+
+@external
+def test() -> uint256:
+ return self.d
+
+@internal
+def a():
+ self.d = x
+ """
+ c = get_contract(code)
+ assert c.test() == 2
+
+
+# GH issue 3292, extended to nested internal functions
+def test_nested_internal_function_immutables(get_contract):
+ code = """
+d: public(uint256)
+x: public(immutable(uint256))
+
+@external
+def __init__():
+ self.d = 1
+ x = 2
+ self.a()
+
+@internal
+def a():
+ self.b()
+
+@internal
+def b():
+ self.d = x
+ """
+ c = get_contract(code)
+ assert c.x() == 2
+ assert c.d() == 2
+
+
+# GH issue 3292, test immutable read from both ctor and runtime
+def test_immutable_read_ctor_and_runtime(get_contract):
+ code = """
+d: public(uint256)
+x: public(immutable(uint256))
+
+@external
+def __init__():
+ self.d = 1
+ x = 2
+ self.a()
+
+@internal
+def a():
+ self.d = x
+
+@external
+def thrash():
+ self.d += 5
+
+@external
+def fix():
+ self.a()
+ """
+ c = get_contract(code)
+ assert c.x() == 2
+ assert c.d() == 2
+
+ c.thrash(transact={})
+
+ assert c.x() == 2
+ assert c.d() == 2 + 5
+
+ c.fix(transact={})
+ assert c.x() == 2
+ assert c.d() == 2
diff --git a/tests/parser/features/test_init.py b/tests/parser/features/test_init.py
index feeabe311a..83bcbc95ea 100644
--- a/tests/parser/features/test_init.py
+++ b/tests/parser/features/test_init.py
@@ -53,3 +53,29 @@ def baz() -> uint8:
n = 256
assert_compile_failed(lambda: get_contract(code, n))
+
+
+# GH issue 3206
+def test_nested_internal_call_from_ctor(get_contract):
+ code = """
+x: uint256
+
+@external
+def __init__():
+ self.a()
+
+@internal
+def a():
+ self.x += 1
+ self.b()
+
+@internal
+def b():
+ self.x += 2
+
+@external
+def test() -> uint256:
+ return self.x
+ """
+ c = get_contract(code)
+ assert c.test() == 3
diff --git a/tests/parser/features/test_internal_call.py b/tests/parser/features/test_internal_call.py
index f576dc5ee5..d7a41acbc0 100644
--- a/tests/parser/features/test_internal_call.py
+++ b/tests/parser/features/test_internal_call.py
@@ -1,6 +1,9 @@
+import string
from decimal import Decimal
+import hypothesis.strategies as st
import pytest
+from hypothesis import given, settings
from vyper.compiler import compile_code
from vyper.exceptions import ArgumentException, CallViolation
@@ -642,3 +645,62 @@ def bar() -> String[6]:
c = get_contract_with_gas_estimation(contract)
assert c.bar() == "hello"
+
+
+# TODO probably want to refactor these into general test utils
+st_uint256 = st.integers(min_value=0, max_value=2**256 - 1)
+st_string65 = st.text(max_size=65, alphabet=string.printable)
+st_bytes65 = st.binary(max_size=65)
+st_sarray3 = st.lists(st_uint256, min_size=3, max_size=3)
+st_darray3 = st.lists(st_uint256, max_size=3)
+
+internal_call_kwargs_cases = [
+ ("uint256", st_uint256),
+ ("String[65]", st_string65),
+ ("Bytes[65]", st_bytes65),
+ ("uint256[3]", st_sarray3),
+ ("DynArray[uint256, 3]", st_darray3),
+]
+
+
+@pytest.mark.parametrize("typ1,strategy1", internal_call_kwargs_cases)
+@pytest.mark.parametrize("typ2,strategy2", internal_call_kwargs_cases)
+def test_internal_call_kwargs(get_contract, typ1, strategy1, typ2, strategy2):
+ # GHSA-ph9x-4vc9-m39g
+
+ @given(kwarg1=strategy1, default1=strategy1, kwarg2=strategy2, default2=strategy2)
+ @settings(deadline=None, max_examples=5) # len(cases) * len(cases) * 5 * 5
+ def fuzz(kwarg1, kwarg2, default1, default2):
+ code = f"""
+@internal
+def foo(a: {typ1} = {repr(default1)}, b: {typ2} = {repr(default2)}) -> ({typ1}, {typ2}):
+ return a, b
+
+@external
+def test0() -> ({typ1}, {typ2}):
+ return self.foo()
+
+@external
+def test1() -> ({typ1}, {typ2}):
+ return self.foo({repr(kwarg1)})
+
+@external
+def test2() -> ({typ1}, {typ2}):
+ return self.foo({repr(kwarg1)}, {repr(kwarg2)})
+
+@external
+def test3(x1: {typ1}) -> ({typ1}, {typ2}):
+ return self.foo(x1)
+
+@external
+def test4(x1: {typ1}, x2: {typ2}) -> ({typ1}, {typ2}):
+ return self.foo(x1, x2)
+ """
+ c = get_contract(code)
+ assert c.test0() == [default1, default2]
+ assert c.test1() == [kwarg1, default2]
+ assert c.test2() == [kwarg1, kwarg2]
+ assert c.test3(kwarg1) == [kwarg1, default2]
+ assert c.test4(kwarg1, kwarg2) == [kwarg1, kwarg2]
+
+ fuzz()
diff --git a/tests/parser/features/test_string_map_keys.py b/tests/parser/features/test_string_map_keys.py
new file mode 100644
index 0000000000..c52bd72821
--- /dev/null
+++ b/tests/parser/features/test_string_map_keys.py
@@ -0,0 +1,25 @@
+def test_string_map_keys(get_contract):
+ code = """
+f:HashMap[String[1], bool]
+@external
+def test() -> bool:
+ a:String[1] = "a"
+ b:String[1] = "b"
+ self.f[a] = True
+ return self.f[b] # should return False
+ """
+ c = get_contract(code)
+ c.test()
+ assert c.test() is False
+
+
+def test_string_map_keys_literals(get_contract):
+ code = """
+f:HashMap[String[1], bool]
+@external
+def test() -> bool:
+ self.f["a"] = True
+ return self.f["b"] # should return False
+ """
+ c = get_contract(code)
+ assert c.test() is False
diff --git a/tests/parser/features/test_ternary.py b/tests/parser/features/test_ternary.py
new file mode 100644
index 0000000000..c5480286c8
--- /dev/null
+++ b/tests/parser/features/test_ternary.py
@@ -0,0 +1,280 @@
+import pytest
+
+simple_cases = [
+ (
+ """
+@external
+def foo(t: bool, x: uint256, y: uint256) -> uint256:
+ return x if t else y
+ """,
+ (1, 2),
+ ),
+ ( # literal test
+ """
+@external
+def foo(_t: bool, x: uint256, y: uint256) -> uint256:
+ return x if {test} else y
+ """,
+ (1, 2),
+ ),
+ ( # literal body
+ """
+@external
+def foo(t: bool, _x: uint256, y: uint256) -> uint256:
+ return {x} if t else y
+ """,
+ (1, 2),
+ ),
+ ( # literal orelse
+ """
+@external
+def foo(t: bool, x: uint256, _y: uint256) -> uint256:
+ return x if t else {y}
+ """,
+ (1, 2),
+ ),
+ ( # literal body/orelse
+ """
+@external
+def foo(t: bool, _x: uint256, _y: uint256) -> uint256:
+ return {x} if t else {y}
+ """,
+ (1, 2),
+ ),
+ ( # literal everything
+ """
+@external
+def foo(_t: bool, _x: uint256, _y: uint256) -> uint256:
+ return {x} if {test} else {y}
+ """,
+ (1, 2),
+ ),
+ ( # body/orelse in storage and memory
+ """
+s: uint256
+@external
+def foo(t: bool, x: uint256, y: uint256) -> uint256:
+ self.s = x
+ return self.s if t else y
+ """,
+ (1, 2),
+ ),
+ ( # body/orelse in memory and storage
+ """
+s: uint256
+@external
+def foo(t: bool, x: uint256, y: uint256) -> uint256:
+ self.s = x
+ return self.s if t else y
+ """,
+ (1, 2),
+ ),
+ ( # body/orelse in memory and constant
+ """
+S: constant(uint256) = {y}
+@external
+def foo(t: bool, x: uint256, _y: uint256) -> uint256:
+ return x if t else S
+ """,
+ (1, 2),
+ ),
+ ( # dynarray
+ """
+@external
+def foo(t: bool, x: DynArray[uint256, 3], y: DynArray[uint256, 3]) -> DynArray[uint256, 3]:
+ return x if t else y
+ """,
+ ([], [1]),
+ ),
+ ( # variable + literal dynarray
+ """
+@external
+def foo(t: bool, x: DynArray[uint256, 3], _y: DynArray[uint256, 3]) -> DynArray[uint256, 3]:
+ return x if t else {y}
+ """,
+ ([], [1]),
+ ),
+ ( # literal + variable dynarray
+ """
+@external
+def foo(t: bool, _x: DynArray[uint256, 3], y: DynArray[uint256, 3]) -> DynArray[uint256, 3]:
+ return {x} if t else y
+ """,
+ ([], [1]),
+ ),
+ ( # storage dynarray
+ """
+s: DynArray[uint256, 3]
+@external
+def foo(t: bool, x: DynArray[uint256, 3], y: DynArray[uint256, 3]) -> DynArray[uint256, 3]:
+ self.s = y
+ return x if t else self.s
+ """,
+ ([], [1]),
+ ),
+ ( # static array
+ """
+@external
+def foo(t: bool, x: uint256[1], y: uint256[1]) -> uint256[1]:
+ return x if t else y
+ """,
+ ([2], [1]),
+ ),
+ ( # static array literal
+ """
+@external
+def foo(t: bool, x: uint256[1], _y: uint256[1]) -> uint256[1]:
+ return x if t else {y}
+ """,
+ ([2], [1]),
+ ),
+ ( # strings
+ """
+@external
+def foo(t: bool, x: String[10], y: String[10]) -> String[10]:
+ return x if t else y
+ """,
+ ("hello", "world"),
+ ),
+ ( # string literal
+ """
+@external
+def foo(t: bool, x: String[10], _y: String[10]) -> String[10]:
+ return x if t else {y}
+ """,
+ ("hello", "world"),
+ ),
+ ( # bytes
+ """
+@external
+def foo(t: bool, x: Bytes[10], y: Bytes[10]) -> Bytes[10]:
+ return x if t else y
+ """,
+ (b"hello", b"world"),
+ ),
+]
+
+
+@pytest.mark.parametrize("code,inputs", simple_cases)
+@pytest.mark.parametrize("test", [True, False])
+def test_ternary_simple(get_contract, code, test, inputs):
+ x, y = inputs
+ # note: repr to escape strings
+ code = code.format(test=test, x=repr(x), y=repr(y))
+ c = get_contract(code)
+ # careful with order of precedence of `assert` and `if/else` in python!
+ assert c.foo(test, x, y) == (x if test else y)
+
+
+tuple_codes = [
+ """
+@external
+def foo(t: bool, x: uint256, y: uint256) -> (uint256, uint256):
+ return (x, y) if t else (y, x)
+ """,
+ """
+s: uint256
+@external
+def foo(t: bool, x: uint256, y: uint256) -> (uint256, uint256):
+ self.s = x
+ return (self.s, y) if t else (y, self.s)
+ """,
+]
+
+
+@pytest.mark.parametrize("code", tuple_codes)
+@pytest.mark.parametrize("test", [True, False])
+def test_ternary_tuple(get_contract, code, test):
+ c = get_contract(code)
+
+ x, y = 1, 2
+ assert c.foo(test, x, y) == ([x, y] if test else [y, x])
+
+
+@pytest.mark.parametrize("test", [True, False])
+def test_ternary_immutable(get_contract, test):
+ code = """
+IMM: public(immutable(uint256))
+@external
+def __init__(test: bool):
+ IMM = 1 if test else 2
+ """
+ c = get_contract(code, test)
+
+ assert c.IMM() == (1 if test else 2)
+
+
+@pytest.mark.parametrize("test", [True, False])
+@pytest.mark.parametrize("x", list(range(8)))
+@pytest.mark.parametrize("y", list(range(8)))
+def test_complex_ternary_expression(get_contract, test, x, y):
+ code = """
+@external
+def foo(t: bool, x: uint256, y: uint256) -> uint256:
+ return (x * y) if (t and True) else (x + y + convert(t, uint256))
+ """
+ c = get_contract(code)
+
+ assert c.foo(test, x, y) == ((x * y) if (test and True) else (x + y + int(test)))
+
+
+@pytest.mark.parametrize("test", [True, False])
+@pytest.mark.parametrize("x", list(range(8)))
+@pytest.mark.parametrize("y", list(range(8)))
+def test_ternary_precedence(get_contract, test, x, y):
+ code = """
+@external
+def foo(t: bool, x: uint256, y: uint256) -> uint256:
+ return x * y if t else x + y + convert(t, uint256)
+ """
+ c = get_contract(code)
+
+ assert c.foo(test, x, y) == (x * y if test else x + y + int(test))
+
+
+@pytest.mark.parametrize("test1", [True, False])
+@pytest.mark.parametrize("test2", [True, False])
+def test_nested_ternary(get_contract, test1, test2):
+ code = """
+@external
+def foo(t1: bool, t2: bool, x: uint256, y: uint256, z: uint256) -> uint256:
+ return x if t1 else y if t2 else z
+ """
+ c = get_contract(code)
+
+ x, y, z = 1, 2, 3
+ assert c.foo(test1, test2, x, y, z) == (x if test1 else y if test2 else z)
+
+
+@pytest.mark.parametrize("test", [True, False])
+def test_ternary_side_effects(get_contract, test):
+ code = """
+track_taint_x: public(uint256)
+track_taint_y: public(uint256)
+foo_retval: public(uint256)
+
+@internal
+def x() -> uint256:
+ self.track_taint_x += 1
+ return 5
+
+@internal
+def y() -> uint256:
+ self.track_taint_y += 1
+ return 7
+
+@external
+def foo(t: bool):
+ self.foo_retval = self.x() if t else self.y()
+ """
+ c = get_contract(code)
+
+ c.foo(test, transact={})
+ assert c.foo_retval() == (5 if test else 7)
+
+ if test:
+ assert c.track_taint_x() == 1
+ assert c.track_taint_y() == 0
+ else:
+ assert c.track_taint_x() == 0
+ assert c.track_taint_y() == 1
diff --git a/tests/parser/features/test_transient.py b/tests/parser/features/test_transient.py
new file mode 100644
index 0000000000..718f5ae314
--- /dev/null
+++ b/tests/parser/features/test_transient.py
@@ -0,0 +1,62 @@
+import pytest
+
+from vyper.compiler import compile_code
+from vyper.compiler.settings import Settings
+from vyper.evm.opcodes import EVM_VERSIONS
+from vyper.exceptions import StructureException
+
+post_cancun = {k: v for k, v in EVM_VERSIONS.items() if v >= EVM_VERSIONS["cancun"]}
+
+
+@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS.keys()))
+def test_transient_blocked(evm_version):
+ # test transient is blocked on pre-cancun and compiles post-cancun
+ code = """
+my_map: transient(HashMap[address, uint256])
+ """
+ settings = Settings(evm_version=evm_version)
+ if EVM_VERSIONS[evm_version] >= EVM_VERSIONS["cancun"]:
+ assert compile_code(code, settings=settings) is not None
+ else:
+ with pytest.raises(StructureException):
+ compile_code(code, settings=settings)
+
+
+@pytest.mark.parametrize("evm_version", list(post_cancun.keys()))
+def test_transient_compiles(evm_version):
+ # test transient keyword at least generates TLOAD/TSTORE opcodes
+ settings = Settings(evm_version=evm_version)
+ getter_code = """
+my_map: public(transient(HashMap[address, uint256]))
+ """
+ t = compile_code(getter_code, settings=settings, output_formats=["opcodes_runtime"])
+ t = t["opcodes_runtime"].split(" ")
+
+ assert "TLOAD" in t
+ assert "TSTORE" not in t
+
+ setter_code = """
+my_map: transient(HashMap[address, uint256])
+
+@external
+def setter(k: address, v: uint256):
+ self.my_map[k] = v
+ """
+ t = compile_code(setter_code, settings=settings, output_formats=["opcodes_runtime"])
+ t = t["opcodes_runtime"].split(" ")
+
+ assert "TLOAD" not in t
+ assert "TSTORE" in t
+
+ getter_setter_code = """
+my_map: public(transient(HashMap[address, uint256]))
+
+@external
+def setter(k: address, v: uint256):
+ self.my_map[k] = v
+ """
+ t = compile_code(getter_setter_code, settings=settings, output_formats=["opcodes_runtime"])
+ t = t["opcodes_runtime"].split(" ")
+
+ assert "TLOAD" in t
+ assert "TSTORE" in t
diff --git a/tests/parser/functions/test_addmod.py b/tests/parser/functions/test_addmod.py
new file mode 100644
index 0000000000..b3135660bb
--- /dev/null
+++ b/tests/parser/functions/test_addmod.py
@@ -0,0 +1,89 @@
+def test_uint256_addmod(assert_tx_failed, get_contract_with_gas_estimation):
+ uint256_code = """
+@external
+def _uint256_addmod(x: uint256, y: uint256, z: uint256) -> uint256:
+ return uint256_addmod(x, y, z)
+ """
+
+ c = get_contract_with_gas_estimation(uint256_code)
+
+ assert c._uint256_addmod(1, 2, 2) == 1
+ assert c._uint256_addmod(32, 2, 32) == 2
+ assert c._uint256_addmod((2**256) - 1, 0, 2) == 1
+ assert c._uint256_addmod(2**255, 2**255, 6) == 4
+ assert_tx_failed(lambda: c._uint256_addmod(1, 2, 0))
+
+
+def test_uint256_addmod_ext_call(
+ w3, side_effects_contract, assert_side_effects_invoked, get_contract
+):
+ code = """
+@external
+def foo(f: Foo) -> uint256:
+ return uint256_addmod(32, 2, f.foo(32))
+
+interface Foo:
+ def foo(x: uint256) -> uint256: payable
+ """
+
+ c1 = side_effects_contract("uint256")
+ c2 = get_contract(code)
+
+ assert c2.foo(c1.address) == 2
+ assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={}))
+
+
+def test_uint256_addmod_internal_call(get_contract_with_gas_estimation):
+ code = """
+@external
+def foo() -> uint256:
+ return uint256_addmod(self.a(), self.b(), self.c())
+
+@internal
+def a() -> uint256:
+ return 32
+
+@internal
+def b() -> uint256:
+ return 2
+
+@internal
+def c() -> uint256:
+ return 32
+ """
+
+ c = get_contract_with_gas_estimation(code)
+
+ assert c.foo() == 2
+
+
+def test_uint256_addmod_evaluation_order(get_contract_with_gas_estimation):
+ code = """
+a: uint256
+
+@external
+def foo1() -> uint256:
+ self.a = 0
+ return uint256_addmod(self.a, 1, self.bar())
+
+@external
+def foo2() -> uint256:
+ self.a = 0
+ return uint256_addmod(self.a, self.bar(), 3)
+
+@external
+def foo3() -> uint256:
+ self.a = 0
+ return uint256_addmod(1, self.a, self.bar())
+
+@internal
+def bar() -> uint256:
+ self.a = 1
+ return 2
+ """
+
+ c = get_contract_with_gas_estimation(code)
+
+ assert c.foo1() == 1
+ assert c.foo2() == 2
+ assert c.foo3() == 1
diff --git a/tests/parser/functions/test_as_wei_value.py b/tests/parser/functions/test_as_wei_value.py
new file mode 100644
index 0000000000..bab0aed616
--- /dev/null
+++ b/tests/parser/functions/test_as_wei_value.py
@@ -0,0 +1,31 @@
+def test_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract):
+ code = """
+@external
+def foo(a: Foo) -> uint256:
+ return as_wei_value(a.foo(7), "ether")
+
+interface Foo:
+ def foo(x: uint8) -> uint8: nonpayable
+ """
+
+ c1 = side_effects_contract("uint8")
+ c2 = get_contract(code)
+
+ assert c2.foo(c1.address) == w3.to_wei(7, "ether")
+ assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={}))
+
+
+def test_internal_call(w3, get_contract_with_gas_estimation):
+ code = """
+@external
+def foo() -> uint256:
+ return as_wei_value(self.bar(), "ether")
+
+@internal
+def bar() -> uint8:
+ return 7
+ """
+
+ c = get_contract_with_gas_estimation(code)
+
+ assert c.foo() == w3.to_wei(7, "ether")
diff --git a/tests/parser/functions/test_bitwise.py b/tests/parser/functions/test_bitwise.py
index 171ba7daeb..3ba74034ac 100644
--- a/tests/parser/functions/test_bitwise.py
+++ b/tests/parser/functions/test_bitwise.py
@@ -1,8 +1,8 @@
import pytest
from vyper.compiler import compile_code
-from vyper.evm.opcodes import EVM_VERSIONS
-from vyper.exceptions import InvalidLiteral, TypeMismatch
+from vyper.exceptions import InvalidLiteral, InvalidOperation, TypeMismatch
+from vyper.utils import unsigned_to_signed
code = """
@external
@@ -22,73 +22,55 @@ def _bitwise_not(x: uint256) -> uint256:
return ~x
@external
-def _shift(x: uint256, y: int128) -> uint256:
- return shift(x, y)
+def _shl(x: uint256, y: uint256) -> uint256:
+ return x << y
@external
-def _negatedShift(x: uint256, y: int128) -> uint256:
- return shift(x, -y)
+def _shr(x: uint256, y: uint256) -> uint256:
+ return x >> y
"""
-@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS))
-def test_bitwise_opcodes(evm_version):
- opcodes = compile_code(code, ["opcodes"], evm_version=evm_version)["opcodes"]
- if evm_version in ("byzantium", "atlantis"):
- assert "SHL" not in opcodes
- assert "SHR" not in opcodes
- else:
- assert "SHL" in opcodes
- assert "SHR" in opcodes
+def test_bitwise_opcodes():
+ opcodes = compile_code(code, ["opcodes"])["opcodes"]
+ assert "SHL" in opcodes
+ assert "SHR" in opcodes
-@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS))
-def test_test_bitwise(get_contract_with_gas_estimation, evm_version):
- c = get_contract_with_gas_estimation(code, evm_version=evm_version)
+def test_test_bitwise(get_contract_with_gas_estimation):
+ c = get_contract_with_gas_estimation(code)
x = 126416208461208640982146408124
y = 7128468721412412459
assert c._bitwise_and(x, y) == (x & y)
assert c._bitwise_or(x, y) == (x | y)
assert c._bitwise_xor(x, y) == (x ^ y)
assert c._bitwise_not(x) == 2**256 - 1 - x
- assert c._shift(x, 3) == x * 8
- assert c._shift(x, 255) == 0
- assert c._shift(y, 255) == 2**255
- assert c._shift(x, 256) == 0
- assert c._shift(x, 0) == x
- assert c._shift(x, -1) == x // 2
- assert c._shift(x, -3) == x // 8
- assert c._shift(x, -256) == 0
- assert c._negatedShift(x, -3) == x * 8
- assert c._negatedShift(x, -255) == 0
- assert c._negatedShift(y, -255) == 2**255
- assert c._negatedShift(x, -256) == 0
- assert c._negatedShift(x, -0) == x
- assert c._negatedShift(x, 1) == x // 2
- assert c._negatedShift(x, 3) == x // 8
- assert c._negatedShift(x, 256) == 0
-
-
-POST_BYZANTIUM = [k for (k, v) in EVM_VERSIONS.items() if v > 0]
-
-
-@pytest.mark.parametrize("evm_version", POST_BYZANTIUM)
-def test_signed_shift(get_contract_with_gas_estimation, evm_version):
+
+ for t in (x, y):
+ for s in (0, 1, 3, 255, 256):
+ assert c._shr(t, s) == t >> s
+ assert c._shl(t, s) == (t << s) % (2**256)
+
+
+def test_signed_shift(get_contract_with_gas_estimation):
code = """
@external
-def _signedShift(x: int256, y: int128) -> int256:
- return shift(x, y)
+def _sar(x: int256, y: uint256) -> int256:
+ return x >> y
+
+@external
+def _shl(x: int256, y: uint256) -> int256:
+ return x << y
"""
- c = get_contract_with_gas_estimation(code, evm_version=evm_version)
+ c = get_contract_with_gas_estimation(code)
x = 126416208461208640982146408124
y = 7128468721412412459
cases = [x, y, -x, -y]
for t in cases:
- assert c._signedShift(t, 0) == t >> 0
- assert c._signedShift(t, -1) == t >> 1
- assert c._signedShift(t, -3) == t >> 3
- assert c._signedShift(t, -256) == t >> 256
+ for s in (0, 1, 3, 255, 256):
+ assert c._sar(t, s) == t >> s
+ assert c._shl(t, s) == unsigned_to_signed((t << s) % (2**256), 256)
def test_precedence(get_contract):
@@ -111,45 +93,74 @@ def baz(a: uint256, b: uint256, c: uint256) -> (uint256, uint256):
assert tuple(c.baz(1, 6, 14)) == (1 + 8 | ~6 & 14 * 2, (1 + 8 | ~6) & 14 * 2) == (25, 24)
-@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS))
-def test_literals(get_contract, evm_version):
+def test_literals(get_contract):
code = """
@external
-def left(x: uint256) -> uint256:
- return shift(x, -3)
+def _shr(x: uint256) -> uint256:
+ return x >> 3
@external
-def right(x: uint256) -> uint256:
- return shift(x, 3)
+def _shl(x: uint256) -> uint256:
+ return x << 3
"""
- c = get_contract(code, evm_version=evm_version)
- assert c.left(80) == 10
- assert c.right(80) == 640
+ c = get_contract(code)
+ assert c._shr(80) == 10
+ assert c._shl(80) == 640
fail_list = [
(
+ # cannot shift non-uint256/int256 argument
"""
@external
-def foo(x: uint8, y: int128) -> uint256:
- return shift(x, y)
+def foo(x: uint8, y: uint8) -> uint8:
+ return x << y
+ """,
+ InvalidOperation,
+ ),
+ (
+ # cannot shift non-uint256/int256 argument
+ """
+@external
+def foo(x: int8, y: uint8) -> int8:
+ return x << y
+ """,
+ InvalidOperation,
+ ),
+ (
+ # cannot shift by non-uint bits
+ """
+@external
+def foo(x: uint256, y: int128) -> uint256:
+ return x << y
""",
TypeMismatch,
),
(
+ # cannot left shift by more than 256 bits
+ """
+@external
+def foo() -> uint256:
+ return 2 << 257
+ """,
+ InvalidLiteral,
+ ),
+ (
+ # cannot shift by negative amount
"""
@external
def foo() -> uint256:
- return shift(2, 257)
+ return 2 << -1
""",
InvalidLiteral,
),
(
+ # cannot shift by negative amount
"""
@external
def foo() -> uint256:
- return shift(2, -257)
+ return 2 << -1
""",
InvalidLiteral,
),
diff --git a/tests/parser/functions/test_ceil.py b/tests/parser/functions/test_ceil.py
index a9bcf62da2..daa9cb7c1b 100644
--- a/tests/parser/functions/test_ceil.py
+++ b/tests/parser/functions/test_ceil.py
@@ -104,3 +104,37 @@ def ceil_param(p: decimal) -> int256:
assert c.fou() == -3
assert c.ceil_param(Decimal("-0.5")) == 0
assert c.ceil_param(Decimal("-7777777.7777777")) == -7777777
+
+
+def test_ceil_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract):
+ code = """
+@external
+def foo(a: Foo) -> int256:
+ return ceil(a.foo(2.5))
+
+interface Foo:
+ def foo(x: decimal) -> decimal: payable
+ """
+
+ c1 = side_effects_contract("decimal")
+ c2 = get_contract(code)
+
+ assert c2.foo(c1.address) == 3
+
+ assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={}))
+
+
+def test_ceil_internal_call(get_contract_with_gas_estimation):
+ code = """
+@external
+def foo() -> int256:
+ return ceil(self.bar())
+
+@internal
+def bar() -> decimal:
+ return 2.5
+ """
+
+ c = get_contract_with_gas_estimation(code)
+
+ assert c.foo() == 3
diff --git a/tests/parser/functions/test_convert.py b/tests/parser/functions/test_convert.py
index 89fe18de31..b5ce613235 100644
--- a/tests/parser/functions/test_convert.py
+++ b/tests/parser/functions/test_convert.py
@@ -22,7 +22,7 @@
BASE_TYPES = set(IntegerT.all()) | set(BytesM_T.all()) | {DecimalT(), AddressT(), BoolT()}
-TEST_TYPES = BASE_TYPES | {BytesT(32)}
+TEST_TYPES = BASE_TYPES | {BytesT(32)} | {StringT(32)}
ZERO_ADDRESS = "0x0000000000000000000000000000000000000000"
@@ -163,6 +163,17 @@ def _cases_for_Bytes(typ):
# would not need this if we tested all Bytes[1]...Bytes[32] types.
for i in range(32):
ret.extend(_cases_for_bytes(BytesM_T(i + 1)))
+
+ ret.append(b"")
+ return uniq(ret)
+
+
+def _cases_for_String(typ):
+ ret = []
+ # would not need this if we tested all Bytes[1]...Bytes[32] types.
+ for i in range(32):
+ ret.extend([str(c, "utf-8") for c in _cases_for_bytes(BytesM_T(i + 1))])
+ ret.append("")
return uniq(ret)
@@ -176,6 +187,8 @@ def interesting_cases_for_type(typ):
return _cases_for_bytes(typ)
if isinstance(typ, BytesT):
return _cases_for_Bytes(typ)
+ if isinstance(typ, StringT):
+ return _cases_for_String(typ)
if isinstance(typ, BoolT):
return _cases_for_bool(typ)
if isinstance(typ, AddressT):
diff --git a/tests/parser/functions/test_create_functions.py b/tests/parser/functions/test_create_functions.py
index 857173df7e..fa7729d98e 100644
--- a/tests/parser/functions/test_create_functions.py
+++ b/tests/parser/functions/test_create_functions.py
@@ -3,6 +3,9 @@
from eth.codecs import abi
from hexbytes import HexBytes
+import vyper.ir.compile_ir as compile_ir
+from vyper.codegen.ir_node import IRnode
+from vyper.compiler.settings import OptimizationLevel
from vyper.utils import EIP_170_LIMIT, checksum_encode, keccak256
@@ -224,15 +227,25 @@ def test(code_ofst: uint256) -> address:
return create_from_blueprint(BLUEPRINT, code_offset=code_ofst)
"""
- # use a bunch of JUMPDEST + STOP instructions as blueprint code
- # (as any STOP instruction returns valid code, split up with
- # jumpdests as optimization fence)
initcode_len = 100
- f = get_contract_from_ir(["deploy", 0, ["seq"] + ["jumpdest", "stop"] * (initcode_len // 2), 0])
- blueprint_code = w3.eth.get_code(f.address)
- print(blueprint_code)
- d = get_contract(deployer_code, f.address)
+ # deploy a blueprint contract whose contained initcode contains only
+ # zeroes (so no matter which offset, create_from_blueprint will
+ # return empty code)
+ ir = IRnode.from_list(["deploy", 0, ["seq"] + ["stop"] * initcode_len, 0])
+ bytecode, _ = compile_ir.assembly_to_evm(
+ compile_ir.compile_to_assembly(ir, optimize=OptimizationLevel.NONE)
+ )
+ # manually deploy the bytecode
+ c = w3.eth.contract(abi=[], bytecode=bytecode)
+ deploy_transaction = c.constructor()
+ tx_info = {"from": w3.eth.accounts[0], "value": 0, "gasPrice": 0}
+ tx_hash = deploy_transaction.transact(tx_info)
+ blueprint_address = w3.eth.get_transaction_receipt(tx_hash)["contractAddress"]
+ blueprint_code = w3.eth.get_code(blueprint_address)
+ print("BLUEPRINT CODE:", blueprint_code)
+
+ d = get_contract(deployer_code, blueprint_address)
# deploy with code_ofst=0 fine
d.test(0)
@@ -418,3 +431,212 @@ def test2(target: address, salt: bytes32) -> address:
# test2 = c.test2(b"\x01", salt)
# assert HexBytes(test2) == create2_address_of(c.address, salt, vyper_initcode(b"\x01"))
# assert_tx_failed(lambda: c.test2(bytecode, salt))
+
+
+# XXX: these various tests to check the msize allocator for
+# create_copy_of and create_from_blueprint depend on calling convention
+# and variables writing to memory. think of ways to make more robust to
+# changes in calling convention and memory layout
+@pytest.mark.parametrize("blueprint_prefix", [b"", b"\xfe", b"\xfe\71\x00"])
+def test_create_from_blueprint_complex_value(
+ get_contract, deploy_blueprint_for, w3, blueprint_prefix
+):
+ # check msize allocator does not get trampled by value= kwarg
+ code = """
+var: uint256
+
+@external
+@payable
+def __init__(x: uint256):
+ self.var = x
+
+@external
+def foo()-> uint256:
+ return self.var
+ """
+
+ prefix_len = len(blueprint_prefix)
+
+ some_constant = b"\00" * 31 + b"\x0c"
+
+ deployer_code = f"""
+created_address: public(address)
+x: constant(Bytes[32]) = {some_constant}
+
+@internal
+def foo() -> uint256:
+ g:uint256 = 42
+ return 3
+
+@external
+@payable
+def test(target: address):
+ self.created_address = create_from_blueprint(
+ target,
+ x,
+ code_offset={prefix_len},
+ value=self.foo(),
+ raw_args=True
+ )
+ """
+
+ foo_contract = get_contract(code, 12)
+ expected_runtime_code = w3.eth.get_code(foo_contract.address)
+
+ f, FooContract = deploy_blueprint_for(code, initcode_prefix=blueprint_prefix)
+
+ d = get_contract(deployer_code)
+
+ d.test(f.address, transact={"value": 3})
+
+ test = FooContract(d.created_address())
+ assert w3.eth.get_code(test.address) == expected_runtime_code
+ assert test.foo() == 12
+
+
+@pytest.mark.parametrize("blueprint_prefix", [b"", b"\xfe", b"\xfe\71\x00"])
+def test_create_from_blueprint_complex_salt_raw_args(
+ get_contract, deploy_blueprint_for, w3, blueprint_prefix
+):
+ # test msize allocator does not get trampled by salt= kwarg
+ code = """
+var: uint256
+
+@external
+@payable
+def __init__(x: uint256):
+ self.var = x
+
+@external
+def foo()-> uint256:
+ return self.var
+ """
+
+ some_constant = b"\00" * 31 + b"\x0c"
+ prefix_len = len(blueprint_prefix)
+
+ deployer_code = f"""
+created_address: public(address)
+
+x: constant(Bytes[32]) = {some_constant}
+salt: constant(bytes32) = keccak256("kebab")
+
+@internal
+def foo() -> bytes32:
+ g:uint256 = 42
+ return salt
+
+@external
+@payable
+def test(target: address):
+ self.created_address = create_from_blueprint(
+ target,
+ x,
+ code_offset={prefix_len},
+ salt=self.foo(),
+ raw_args= True
+ )
+ """
+
+ foo_contract = get_contract(code, 12)
+ expected_runtime_code = w3.eth.get_code(foo_contract.address)
+
+ f, FooContract = deploy_blueprint_for(code, initcode_prefix=blueprint_prefix)
+
+ d = get_contract(deployer_code)
+
+ d.test(f.address, transact={})
+
+ test = FooContract(d.created_address())
+ assert w3.eth.get_code(test.address) == expected_runtime_code
+ assert test.foo() == 12
+
+
+@pytest.mark.parametrize("blueprint_prefix", [b"", b"\xfe", b"\xfe\71\x00"])
+def test_create_from_blueprint_complex_salt_no_constructor_args(
+ get_contract, deploy_blueprint_for, w3, blueprint_prefix
+):
+ # test msize allocator does not get trampled by salt= kwarg
+ code = """
+var: uint256
+
+@external
+@payable
+def __init__():
+ self.var = 12
+
+@external
+def foo()-> uint256:
+ return self.var
+ """
+
+ prefix_len = len(blueprint_prefix)
+ deployer_code = f"""
+created_address: public(address)
+
+salt: constant(bytes32) = keccak256("kebab")
+
+@external
+@payable
+def test(target: address):
+ self.created_address = create_from_blueprint(
+ target,
+ code_offset={prefix_len},
+ salt=keccak256(_abi_encode(target))
+ )
+ """
+
+ foo_contract = get_contract(code)
+ expected_runtime_code = w3.eth.get_code(foo_contract.address)
+
+ f, FooContract = deploy_blueprint_for(code, initcode_prefix=blueprint_prefix)
+
+ d = get_contract(deployer_code)
+
+ d.test(f.address, transact={})
+
+ test = FooContract(d.created_address())
+ assert w3.eth.get_code(test.address) == expected_runtime_code
+ assert test.foo() == 12
+
+
+def test_create_copy_of_complex_kwargs(get_contract, w3):
+ # test msize allocator does not get trampled by salt= kwarg
+ complex_salt = """
+created_address: public(address)
+
+@external
+def test(target: address) -> address:
+ self.created_address = create_copy_of(
+ target,
+ salt=keccak256(_abi_encode(target))
+ )
+ return self.created_address
+
+ """
+
+ c = get_contract(complex_salt)
+ bytecode = w3.eth.get_code(c.address)
+ c.test(c.address, transact={})
+ test1 = c.created_address()
+ assert w3.eth.get_code(test1) == bytecode
+
+ # test msize allocator does not get trampled by value= kwarg
+ complex_value = """
+created_address: public(address)
+
+@external
+@payable
+def test(target: address) -> address:
+ value: uint256 = 2
+ self.created_address = create_copy_of(target, value = [2,2,2][value])
+ return self.created_address
+
+ """
+
+ c = get_contract(complex_value)
+ bytecode = w3.eth.get_code(c.address)
+
+ c.test(c.address, transact={"value": 2})
+ test1 = c.created_address()
+ assert w3.eth.get_code(test1) == bytecode
diff --git a/tests/parser/functions/test_default_function.py b/tests/parser/functions/test_default_function.py
index 13ddb7ee1b..4ad68697ac 100644
--- a/tests/parser/functions/test_default_function.py
+++ b/tests/parser/functions/test_default_function.py
@@ -100,7 +100,9 @@ def __default__():
assert_compile_failed(lambda: get_contract_with_gas_estimation(code))
-def test_zero_method_id(w3, get_logs, get_contract_with_gas_estimation):
+def test_zero_method_id(w3, get_logs, get_contract, assert_tx_failed):
+ # test a method with 0x00000000 selector,
+ # expects at least 36 bytes of calldata.
code = """
event Sent:
sig: uint256
@@ -116,18 +118,108 @@ def blockHashAskewLimitary(v: uint256) -> uint256:
def __default__():
log Sent(1)
"""
-
- c = get_contract_with_gas_estimation(code)
+ c = get_contract(code)
assert c.blockHashAskewLimitary(0) == 7
- logs = get_logs(w3.eth.send_transaction({"to": c.address, "value": 0}), c, "Sent")
- assert 1 == logs[0].args.sig
+ def _call_with_bytes(hexstr):
+ # call our special contract and return the logged value
+ logs = get_logs(
+ w3.eth.send_transaction({"to": c.address, "value": 0, "data": hexstr}), c, "Sent"
+ )
+ return logs[0].args.sig
- logs = get_logs(
- # call blockHashAskewLimitary
- w3.eth.send_transaction({"to": c.address, "value": 0, "data": "0x" + "00" * 36}),
- c,
- "Sent",
- )
- assert 2 == logs[0].args.sig
+ assert 1 == _call_with_bytes("0x")
+
+ # call blockHashAskewLimitary with proper calldata
+ assert 2 == _call_with_bytes("0x" + "00" * 36)
+
+ # call blockHashAskewLimitary with extra trailing bytes in calldata
+ assert 2 == _call_with_bytes("0x" + "00" * 37)
+
+ for i in range(4):
+ # less than 4 bytes of calldata doesn't match the 0 selector and goes to default
+ assert 1 == _call_with_bytes("0x" + "00" * i)
+
+ for i in range(4, 36):
+ # match the full 4 selector bytes, but revert due to malformed (short) calldata
+ assert_tx_failed(lambda: _call_with_bytes("0x" + "00" * i))
+
+
+def test_another_zero_method_id(w3, get_logs, get_contract, assert_tx_failed):
+ # test another zero method id but which only expects 4 bytes of calldata
+ code = """
+event Sent:
+ sig: uint256
+
+@external
+@payable
+# function selector: 0x00000000
+def wycpnbqcyf() -> uint256:
+ log Sent(2)
+ return 7
+
+@external
+def __default__():
+ log Sent(1)
+ """
+ c = get_contract(code)
+
+ assert c.wycpnbqcyf() == 7
+
+ def _call_with_bytes(hexstr):
+ # call our special contract and return the logged value
+ logs = get_logs(
+ w3.eth.send_transaction({"to": c.address, "value": 0, "data": hexstr}), c, "Sent"
+ )
+ return logs[0].args.sig
+
+ assert 1 == _call_with_bytes("0x")
+
+ # call wycpnbqcyf
+ assert 2 == _call_with_bytes("0x" + "00" * 4)
+
+ # too many bytes ok
+ assert 2 == _call_with_bytes("0x" + "00" * 5)
+
+ # "right" method id but by accident - not enough bytes.
+ for i in range(4):
+ assert 1 == _call_with_bytes("0x" + "00" * i)
+
+
+def test_partial_selector_match_trailing_zeroes(w3, get_logs, get_contract):
+ code = """
+event Sent:
+ sig: uint256
+
+@external
+@payable
+# function selector: 0xd88e0b00
+def fow() -> uint256:
+ log Sent(2)
+ return 7
+
+@external
+def __default__():
+ log Sent(1)
+ """
+ c = get_contract(code)
+
+ # sanity check - we can call c.fow()
+ assert c.fow() == 7
+
+ def _call_with_bytes(hexstr):
+ # call our special contract and return the logged value
+ logs = get_logs(
+ w3.eth.send_transaction({"to": c.address, "value": 0, "data": hexstr}), c, "Sent"
+ )
+ return logs[0].args.sig
+
+ # check we can call default function
+ assert 1 == _call_with_bytes("0x")
+
+ # check fow() selector is 0xd88e0b00
+ assert 2 == _call_with_bytes("0xd88e0b00")
+
+ # check calling d88e0b with no trailing zero goes to fallback instead of reverting
+ assert 1 == _call_with_bytes("0xd88e0b")
diff --git a/tests/parser/functions/test_ec.py b/tests/parser/functions/test_ec.py
index be0f6f7ed2..e1d9e3d2ee 100644
--- a/tests/parser/functions/test_ec.py
+++ b/tests/parser/functions/test_ec.py
@@ -45,6 +45,57 @@ def _ecadd3(x: uint256[2], y: uint256[2]) -> uint256[2]:
assert c._ecadd3(G1, negative_G1) == [0, 0]
+def test_ecadd_internal_call(get_contract_with_gas_estimation):
+ code = """
+@internal
+def a() -> uint256[2]:
+ return [1, 2]
+
+@external
+def foo() -> uint256[2]:
+ return ecadd([1, 2], self.a())
+ """
+ c = get_contract_with_gas_estimation(code)
+ assert c.foo() == G1_times_two
+
+
+def test_ecadd_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract):
+ code = """
+interface Foo:
+ def foo(x: uint256[2]) -> uint256[2]: payable
+
+@external
+def foo(a: Foo) -> uint256[2]:
+ return ecadd([1, 2], a.foo([1, 2]))
+ """
+ c1 = side_effects_contract("uint256[2]")
+ c2 = get_contract(code)
+
+ assert c2.foo(c1.address) == G1_times_two
+
+ assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={}))
+
+
+def test_ecadd_evaluation_order(get_contract_with_gas_estimation):
+ code = """
+x: uint256[2]
+
+@internal
+def bar() -> uint256[2]:
+ self.x = ecadd([1, 2], [1, 2])
+ return [1, 2]
+
+@external
+def foo() -> bool:
+ self.x = [1, 2]
+ a: uint256[2] = ecadd([1, 2], [1, 2])
+ b: uint256[2] = ecadd(self.x, self.bar())
+ return a[0] == b[0] and a[1] == b[1]
+ """
+ c = get_contract_with_gas_estimation(code)
+ assert c.foo() is True
+
+
def test_ecmul(get_contract_with_gas_estimation):
ecmuller = """
x3: uint256[2]
@@ -74,3 +125,54 @@ def _ecmul3(x: uint256[2], y: uint256) -> uint256[2]:
assert c._ecmul(G1, 3) == G1_times_three
assert c._ecmul(G1, curve_order - 1) == negative_G1
assert c._ecmul(G1, curve_order) == [0, 0]
+
+
+def test_ecmul_internal_call(get_contract_with_gas_estimation):
+ code = """
+@internal
+def a() -> uint256:
+ return 3
+
+@external
+def foo() -> uint256[2]:
+ return ecmul([1, 2], self.a())
+ """
+ c = get_contract_with_gas_estimation(code)
+ assert c.foo() == G1_times_three
+
+
+def test_ecmul_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract):
+ code = """
+interface Foo:
+ def foo(x: uint256) -> uint256: payable
+
+@external
+def foo(a: Foo) -> uint256[2]:
+ return ecmul([1, 2], a.foo(3))
+ """
+ c1 = side_effects_contract("uint256")
+ c2 = get_contract(code)
+
+ assert c2.foo(c1.address) == G1_times_three
+
+ assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={}))
+
+
+def test_ecmul_evaluation_order(get_contract_with_gas_estimation):
+ code = """
+x: uint256[2]
+
+@internal
+def bar() -> uint256:
+ self.x = ecmul([1, 2], 3)
+ return 3
+
+@external
+def foo() -> bool:
+ self.x = [1, 2]
+ a: uint256[2] = ecmul([1, 2], 3)
+ b: uint256[2] = ecmul(self.x, self.bar())
+ return a[0] == b[0] and a[1] == b[1]
+ """
+ c = get_contract_with_gas_estimation(code)
+ assert c.foo() is True
diff --git a/tests/parser/functions/test_ecrecover.py b/tests/parser/functions/test_ecrecover.py
index 77e9655b3e..8571948c3d 100644
--- a/tests/parser/functions/test_ecrecover.py
+++ b/tests/parser/functions/test_ecrecover.py
@@ -40,3 +40,48 @@ def test_ecrecover_uints2() -> address:
assert c.test_ecrecover_uints2() == local_account.address
print("Passed ecrecover test")
+
+
+def test_invalid_signature(get_contract):
+ code = """
+dummies: HashMap[address, HashMap[address, uint256]]
+
+@external
+def test_ecrecover(hash: bytes32, v: uint8, r: uint256) -> address:
+ # read from hashmap to put garbage in 0 memory location
+ s: uint256 = self.dummies[msg.sender][msg.sender]
+ return ecrecover(hash, v, r, s)
+ """
+ c = get_contract(code)
+ hash_ = bytes(i for i in range(32))
+ v = 0 # invalid v! ecrecover precompile will not write to output buffer
+ r = 0
+ # note web3.py decoding of 0x000..00 address is None.
+ assert c.test_ecrecover(hash_, v, r) is None
+
+
+# slightly more subtle example: get_v() stomps memory location 0,
+# so this tests that the output buffer stays clean during ecrecover()
+# builtin execution.
+def test_invalid_signature2(get_contract):
+ code = """
+
+owner: immutable(address)
+
+@external
+def __init__():
+ owner = 0x7E5F4552091A69125d5DfCb7b8C2659029395Bdf
+
+@internal
+def get_v() -> uint256:
+ assert owner == owner # force a dload to write at index 0 of memory
+ return 21
+
+@payable
+@external
+def test_ecrecover() -> bool:
+ assert ecrecover(empty(bytes32), self.get_v(), 0, 0) == empty(address)
+ return True
+ """
+ c = get_contract(code)
+ assert c.test_ecrecover() is True
diff --git a/tests/parser/functions/test_empty.py b/tests/parser/functions/test_empty.py
index 93f9dd7739..c3627785dc 100644
--- a/tests/parser/functions/test_empty.py
+++ b/tests/parser/functions/test_empty.py
@@ -1,6 +1,6 @@
import pytest
-from vyper.exceptions import TypeMismatch
+from vyper.exceptions import InstantiationException, TypeMismatch
@pytest.mark.parametrize(
@@ -711,4 +711,4 @@ def test():
],
)
def test_invalid_types(contract, get_contract, assert_compile_failed):
- assert_compile_failed(lambda: get_contract(contract), TypeMismatch)
+ assert_compile_failed(lambda: get_contract(contract), InstantiationException)
diff --git a/tests/parser/functions/test_floor.py b/tests/parser/functions/test_floor.py
index dc53545ac3..d2fd993785 100644
--- a/tests/parser/functions/test_floor.py
+++ b/tests/parser/functions/test_floor.py
@@ -108,3 +108,37 @@ def floor_param(p: decimal) -> int256:
assert c.fou() == -4
assert c.floor_param(Decimal("-5.6")) == -6
assert c.floor_param(Decimal("-0.0000000001")) == -1
+
+
+def test_floor_ext_call(w3, side_effects_contract, assert_side_effects_invoked, get_contract):
+ code = """
+@external
+def foo(a: Foo) -> int256:
+ return floor(a.foo(2.5))
+
+interface Foo:
+ def foo(x: decimal) -> decimal: nonpayable
+ """
+
+ c1 = side_effects_contract("decimal")
+ c2 = get_contract(code)
+
+ assert c2.foo(c1.address) == 2
+
+ assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={}))
+
+
+def test_floor_internal_call(get_contract_with_gas_estimation):
+ code = """
+@external
+def foo() -> int256:
+ return floor(self.bar())
+
+@internal
+def bar() -> decimal:
+ return 2.5
+ """
+
+ c = get_contract_with_gas_estimation(code)
+
+ assert c.foo() == 2
diff --git a/tests/parser/functions/test_interfaces.py b/tests/parser/functions/test_interfaces.py
index d07e2b13ae..c16e188cfd 100644
--- a/tests/parser/functions/test_interfaces.py
+++ b/tests/parser/functions/test_interfaces.py
@@ -67,7 +67,6 @@ def test_basic_interface_implements(assert_compile_failed):
implements: ERC20
-
@external
def test() -> bool:
return True
@@ -146,6 +145,7 @@ def bar() -> uint256:
)
+# check that event types match
def test_malformed_event(assert_compile_failed):
interface_code = """
event Foo:
@@ -173,6 +173,64 @@ def bar() -> uint256:
)
+# check that event non-indexed arg needs to match interface
+def test_malformed_events_indexed(assert_compile_failed):
+ interface_code = """
+event Foo:
+ a: uint256
+ """
+
+ interface_codes = {"FooBarInterface": {"type": "vyper", "code": interface_code}}
+
+ not_implemented_code = """
+import a as FooBarInterface
+
+implements: FooBarInterface
+
+# a should not be indexed
+event Foo:
+ a: indexed(uint256)
+
+@external
+def bar() -> uint256:
+ return 1
+ """
+
+ assert_compile_failed(
+ lambda: compile_code(not_implemented_code, interface_codes=interface_codes),
+ InterfaceViolation,
+ )
+
+
+# check that event indexed arg needs to match interface
+def test_malformed_events_indexed2(assert_compile_failed):
+ interface_code = """
+event Foo:
+ a: indexed(uint256)
+ """
+
+ interface_codes = {"FooBarInterface": {"type": "vyper", "code": interface_code}}
+
+ not_implemented_code = """
+import a as FooBarInterface
+
+implements: FooBarInterface
+
+# a should be indexed
+event Foo:
+ a: uint256
+
+@external
+def bar() -> uint256:
+ return 1
+ """
+
+ assert_compile_failed(
+ lambda: compile_code(not_implemented_code, interface_codes=interface_codes),
+ InterfaceViolation,
+ )
+
+
VALID_IMPORT_CODE = [
# import statement, import path without suffix
("import a as Foo", "a"),
@@ -321,6 +379,23 @@ def test():
assert erc20.balanceOf(sender) == 1000
+def test_address_member(w3, get_contract):
+ code = """
+interface Foo:
+ def foo(): payable
+
+f: Foo
+
+@external
+def test(addr: address):
+ self.f = Foo(addr)
+ assert self.f.address == addr
+ """
+ c = get_contract(code)
+ for address in w3.eth.accounts:
+ c.test(address)
+
+
# test data returned from external interface gets clamped
@pytest.mark.parametrize("typ", ("int128", "uint8"))
def test_external_interface_int_clampers(get_contract, assert_tx_failed, typ):
diff --git a/tests/parser/functions/test_mulmod.py b/tests/parser/functions/test_mulmod.py
new file mode 100644
index 0000000000..96477897b9
--- /dev/null
+++ b/tests/parser/functions/test_mulmod.py
@@ -0,0 +1,107 @@
+def test_uint256_mulmod(assert_tx_failed, get_contract_with_gas_estimation):
+ uint256_code = """
+@external
+def _uint256_mulmod(x: uint256, y: uint256, z: uint256) -> uint256:
+ return uint256_mulmod(x, y, z)
+ """
+
+ c = get_contract_with_gas_estimation(uint256_code)
+
+ assert c._uint256_mulmod(3, 1, 2) == 1
+ assert c._uint256_mulmod(200, 3, 601) == 600
+ assert c._uint256_mulmod(2**255, 1, 3) == 2
+ assert c._uint256_mulmod(2**255, 2, 6) == 4
+ assert_tx_failed(lambda: c._uint256_mulmod(2, 2, 0))
+
+
+def test_uint256_mulmod_complex(get_contract_with_gas_estimation):
+ modexper = """
+@external
+def exponential(base: uint256, exponent: uint256, modulus: uint256) -> uint256:
+ o: uint256 = 1
+ for i in range(256):
+ o = uint256_mulmod(o, o, modulus)
+ if exponent & shift(1, 255 - i) != 0:
+ o = uint256_mulmod(o, base, modulus)
+ return o
+ """
+
+ c = get_contract_with_gas_estimation(modexper)
+ assert c.exponential(3, 5, 100) == 43
+ assert c.exponential(2, 997, 997) == 2
+
+
+def test_uint256_mulmod_ext_call(
+ w3, side_effects_contract, assert_side_effects_invoked, get_contract
+):
+ code = """
+@external
+def foo(f: Foo) -> uint256:
+ return uint256_mulmod(200, 3, f.foo(601))
+
+interface Foo:
+ def foo(x: uint256) -> uint256: nonpayable
+ """
+
+ c1 = side_effects_contract("uint256")
+ c2 = get_contract(code)
+
+ assert c2.foo(c1.address) == 600
+
+ assert_side_effects_invoked(c1, lambda: c2.foo(c1.address, transact={}))
+
+
+def test_uint256_mulmod_internal_call(get_contract_with_gas_estimation):
+ code = """
+@external
+def foo() -> uint256:
+ return uint256_mulmod(self.a(), self.b(), self.c())
+
+@internal
+def a() -> uint256:
+ return 200
+
+@internal
+def b() -> uint256:
+ return 3
+
+@internal
+def c() -> uint256:
+ return 601
+ """
+
+ c = get_contract_with_gas_estimation(code)
+
+ assert c.foo() == 600
+
+
+def test_uint256_mulmod_evaluation_order(get_contract_with_gas_estimation):
+ code = """
+a: uint256
+
+@external
+def foo1() -> uint256:
+ self.a = 1
+ return uint256_mulmod(self.a, 2, self.bar())
+
+@external
+def foo2() -> uint256:
+ self.a = 1
+ return uint256_mulmod(self.bar(), self.a, 2)
+
+@external
+def foo3() -> uint256:
+ self.a = 1
+ return uint256_mulmod(2, self.a, self.bar())
+
+@internal
+def bar() -> uint256:
+ self.a = 7
+ return 5
+ """
+
+ c = get_contract_with_gas_estimation(code)
+
+ assert c.foo1() == 2
+ assert c.foo2() == 1
+ assert c.foo3() == 2
diff --git a/tests/parser/functions/test_raw_call.py b/tests/parser/functions/test_raw_call.py
index 4fbff331ed..81efe64a18 100644
--- a/tests/parser/functions/test_raw_call.py
+++ b/tests/parser/functions/test_raw_call.py
@@ -1,6 +1,7 @@
import pytest
from hexbytes import HexBytes
+from vyper import compile_code
from vyper.builtins.functions import eip1167_bytecode
from vyper.exceptions import ArgumentException, InvalidType, StateAccessViolation
@@ -260,6 +261,68 @@ def __default__():
w3.eth.send_transaction({"to": caller.address, "data": sig})
+# check max_outsize=0 does same thing as not setting max_outsize.
+# compile to bytecode and compare bytecode directly.
+def test_max_outsize_0():
+ code1 = """
+@external
+def test_raw_call(_target: address):
+ raw_call(_target, method_id("foo()"))
+ """
+ code2 = """
+@external
+def test_raw_call(_target: address):
+ raw_call(_target, method_id("foo()"), max_outsize=0)
+ """
+ output1 = compile_code(code1, ["bytecode", "bytecode_runtime"])
+ output2 = compile_code(code2, ["bytecode", "bytecode_runtime"])
+ assert output1 == output2
+
+
+# check max_outsize=0 does same thing as not setting max_outsize,
+# this time with revert_on_failure set to False
+def test_max_outsize_0_no_revert_on_failure():
+ code1 = """
+@external
+def test_raw_call(_target: address) -> bool:
+ # compile raw_call both ways, with revert_on_failure
+ a: bool = raw_call(_target, method_id("foo()"), revert_on_failure=False)
+ return a
+ """
+ # same code, but with max_outsize=0
+ code2 = """
+@external
+def test_raw_call(_target: address) -> bool:
+ a: bool = raw_call(_target, method_id("foo()"), max_outsize=0, revert_on_failure=False)
+ return a
+ """
+ output1 = compile_code(code1, ["bytecode", "bytecode_runtime"])
+ output2 = compile_code(code2, ["bytecode", "bytecode_runtime"])
+ assert output1 == output2
+
+
+# test functionality of max_outsize=0
+def test_max_outsize_0_call(get_contract):
+ target_source = """
+@external
+@payable
+def bar() -> uint256:
+ return 123
+ """
+
+ caller_source = """
+@external
+@payable
+def foo(_addr: address) -> bool:
+ success: bool = raw_call(_addr, method_id("bar()"), max_outsize=0, revert_on_failure=False)
+ return success
+ """
+
+ target = get_contract(target_source)
+ caller = get_contract(caller_source)
+ assert caller.foo(target.address) is True
+
+
def test_static_call_fails_nonpayable(get_contract, assert_tx_failed):
target_source = """
baz: int128
@@ -296,8 +359,10 @@ def test_checkable_raw_call(get_contract, assert_tx_failed):
def fail1(should_raise: bool):
if should_raise:
raise "fail"
+
# test both paths for raw_call -
# they are different depending if callee has or doesn't have returntype
+# (fail2 fails because of staticcall)
@external
def fail2(should_raise: bool) -> int128:
if should_raise:
@@ -320,6 +385,7 @@ def foo(_addr: address, should_raise: bool) -> uint256:
)
assert success == (not should_raise)
return 1
+
@external
@view
def bar(_addr: address, should_raise: bool) -> uint256:
@@ -334,6 +400,19 @@ def bar(_addr: address, should_raise: bool) -> uint256:
)
assert success == (not should_raise)
return 2
+
+# test max_outsize not set case
+@external
+@nonpayable
+def baz(_addr: address, should_raise: bool) -> uint256:
+ success: bool = True
+ success = raw_call(
+ _addr,
+ _abi_encode(should_raise, method_id=method_id("fail1(bool)")),
+ revert_on_failure=False,
+ )
+ assert success == (not should_raise)
+ return 3
"""
target = get_contract(target_source)
@@ -343,6 +422,166 @@ def bar(_addr: address, should_raise: bool) -> uint256:
assert caller.foo(target.address, False) == 1
assert caller.bar(target.address, True) == 2
assert caller.bar(target.address, False) == 2
+ assert caller.baz(target.address, True) == 3
+ assert caller.baz(target.address, False) == 3
+
+
+# XXX: these test_raw_call_clean_mem* tests depend on variables and
+# calling convention writing to memory. think of ways to make more
+# robust to changes to calling convention and memory layout.
+
+
+def test_raw_call_msg_data_clean_mem(get_contract):
+ # test msize uses clean memory and does not get overwritten by
+ # any raw_call() arguments
+ code = """
+identity: constant(address) = 0x0000000000000000000000000000000000000004
+
+@external
+def foo():
+ pass
+
+@internal
+@view
+def get_address()->address:
+ a:uint256 = 121 # 0x79
+ return identity
+@external
+def bar(f: uint256, u: uint256) -> Bytes[100]:
+ # embed an internal call in the calculation of address
+ a: Bytes[100] = raw_call(self.get_address(), msg.data, max_outsize=100)
+ return a
+ """
+
+ c = get_contract(code)
+ assert (
+ c.bar(1, 2).hex() == "ae42e951"
+ "0000000000000000000000000000000000000000000000000000000000000001"
+ "0000000000000000000000000000000000000000000000000000000000000002"
+ )
+
+
+def test_raw_call_clean_mem2(get_contract):
+ # test msize uses clean memory and does not get overwritten by
+ # any raw_call() arguments, another way
+ code = """
+buf: Bytes[100]
+
+@external
+def bar(f: uint256, g: uint256, h: uint256) -> Bytes[100]:
+ # embed a memory modifying expression in the calculation of address
+ self.buf = raw_call(
+ [0x0000000000000000000000000000000000000004,][f-1],
+ msg.data,
+ max_outsize=100
+ )
+ return self.buf
+ """
+ c = get_contract(code)
+
+ assert (
+ c.bar(1, 2, 3).hex() == "9309b76e"
+ "0000000000000000000000000000000000000000000000000000000000000001"
+ "0000000000000000000000000000000000000000000000000000000000000002"
+ "0000000000000000000000000000000000000000000000000000000000000003"
+ )
+
+
+def test_raw_call_clean_mem3(get_contract):
+ # test msize uses clean memory and does not get overwritten by
+ # any raw_call() arguments, and also test order of evaluation for
+ # scope_multi
+ code = """
+buf: Bytes[100]
+canary: String[32]
+
+@internal
+def bar() -> address:
+ self.canary = "bar"
+ return 0x0000000000000000000000000000000000000004
+
+@internal
+def goo() -> uint256:
+ self.canary = "goo"
+ return 0
+
+@external
+def foo() -> String[32]:
+ self.buf = raw_call(self.bar(), msg.data, value = self.goo(), max_outsize=100)
+ return self.canary
+ """
+ c = get_contract(code)
+ assert c.foo() == "goo"
+
+
+def test_raw_call_clean_mem_kwargs_value(get_contract):
+ # test msize uses clean memory and does not get overwritten by
+ # any raw_call() kwargs
+ code = """
+buf: Bytes[100]
+
+# add a dummy function to trigger memory expansion in the selector table routine
+@external
+def foo():
+ pass
+
+@internal
+def _value() -> uint256:
+ x: uint256 = 1
+ return x
+
+@external
+def bar(f: uint256) -> Bytes[100]:
+ # embed a memory modifying expression in the calculation of address
+ self.buf = raw_call(
+ 0x0000000000000000000000000000000000000004,
+ msg.data,
+ max_outsize=100,
+ value=self._value()
+ )
+ return self.buf
+ """
+ c = get_contract(code, value=1)
+
+ assert (
+ c.bar(13).hex() == "0423a132"
+ "000000000000000000000000000000000000000000000000000000000000000d"
+ )
+
+
+def test_raw_call_clean_mem_kwargs_gas(get_contract):
+ # test msize uses clean memory and does not get overwritten by
+ # any raw_call() kwargs
+ code = """
+buf: Bytes[100]
+
+# add a dummy function to trigger memory expansion in the selector table routine
+@external
+def foo():
+ pass
+
+@internal
+def _gas() -> uint256:
+ x: uint256 = msg.gas
+ return x
+
+@external
+def bar(f: uint256) -> Bytes[100]:
+ # embed a memory modifying expression in the calculation of address
+ self.buf = raw_call(
+ 0x0000000000000000000000000000000000000004,
+ msg.data,
+ max_outsize=100,
+ gas=self._gas()
+ )
+ return self.buf
+ """
+ c = get_contract(code, value=1)
+
+ assert (
+ c.bar(15).hex() == "0423a132"
+ "000000000000000000000000000000000000000000000000000000000000000f"
+ )
uncompilable_code = [
diff --git a/tests/parser/functions/test_slice.py b/tests/parser/functions/test_slice.py
index 11d834bf42..6229b47921 100644
--- a/tests/parser/functions/test_slice.py
+++ b/tests/parser/functions/test_slice.py
@@ -1,6 +1,9 @@
+import hypothesis.strategies as st
import pytest
+from hypothesis import given, settings
-from vyper.exceptions import ArgumentException
+from vyper.compiler.settings import OptimizationLevel
+from vyper.exceptions import ArgumentException, TypeMismatch
_fun_bytes32_bounds = [(0, 32), (3, 29), (27, 5), (0, 5), (5, 3), (30, 2)]
@@ -9,14 +12,6 @@ def _generate_bytes(length):
return bytes(list(range(length)))
-# good numbers to try
-_fun_numbers = [0, 1, 5, 31, 32, 33, 64, 99, 100, 101]
-
-
-# [b"", b"\x01", b"\x02"...]
-_bytes_examples = [_generate_bytes(i) for i in _fun_numbers if i <= 100]
-
-
def test_basic_slice(get_contract_with_gas_estimation):
code = """
@external
@@ -31,75 +26,87 @@ def slice_tower_test(inp1: Bytes[50]) -> Bytes[50]:
assert x == b"klmnopqrst", x
-@pytest.mark.parametrize("bytesdata", _bytes_examples)
-@pytest.mark.parametrize("start", _fun_numbers)
+# note: optimization boundaries at 32, 64 and 320 depending on mode
+_draw_1024 = st.integers(min_value=0, max_value=1024)
+_draw_1024_1 = st.integers(min_value=1, max_value=1024)
+_bytes_1024 = st.binary(min_size=0, max_size=1024)
+
+
@pytest.mark.parametrize("literal_start", (True, False))
-@pytest.mark.parametrize("length", _fun_numbers)
@pytest.mark.parametrize("literal_length", (True, False))
+@pytest.mark.parametrize("opt_level", list(OptimizationLevel))
+@given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024)
+@settings(max_examples=100, deadline=None)
@pytest.mark.fuzzing
def test_slice_immutable(
get_contract,
assert_compile_failed,
assert_tx_failed,
+ opt_level,
bytesdata,
start,
literal_start,
length,
literal_length,
+ length_bound,
):
_start = start if literal_start else "start"
_length = length if literal_length else "length"
code = f"""
-IMMUTABLE_BYTES: immutable(Bytes[100])
-IMMUTABLE_SLICE: immutable(Bytes[100])
+IMMUTABLE_BYTES: immutable(Bytes[{length_bound}])
+IMMUTABLE_SLICE: immutable(Bytes[{length_bound}])
@external
-def __init__(inp: Bytes[100], start: uint256, length: uint256):
+def __init__(inp: Bytes[{length_bound}], start: uint256, length: uint256):
IMMUTABLE_BYTES = inp
IMMUTABLE_SLICE = slice(IMMUTABLE_BYTES, {_start}, {_length})
@external
-def do_splice() -> Bytes[100]:
+def do_splice() -> Bytes[{length_bound}]:
return IMMUTABLE_SLICE
"""
+ def _get_contract():
+ return get_contract(code, bytesdata, start, length, override_opt_level=opt_level)
+
if (
- (start + length > 100 and literal_start and literal_length)
- or (literal_length and length > 100)
- or (literal_start and start > 100)
+ (start + length > length_bound and literal_start and literal_length)
+ or (literal_length and length > length_bound)
+ or (literal_start and start > length_bound)
or (literal_length and length < 1)
):
- assert_compile_failed(
- lambda: get_contract(code, bytesdata, start, length), ArgumentException
- )
- elif start + length > len(bytesdata):
- assert_tx_failed(lambda: get_contract(code, bytesdata, start, length))
+ assert_compile_failed(lambda: _get_contract(), ArgumentException)
+ elif start + length > len(bytesdata) or (len(bytesdata) > length_bound):
+ # deploy fail
+ assert_tx_failed(lambda: _get_contract())
else:
- c = get_contract(code, bytesdata, start, length)
+ c = _get_contract()
assert c.do_splice() == bytesdata[start : start + length]
@pytest.mark.parametrize("location", ("storage", "calldata", "memory", "literal", "code"))
-@pytest.mark.parametrize("bytesdata", _bytes_examples)
-@pytest.mark.parametrize("start", _fun_numbers)
@pytest.mark.parametrize("literal_start", (True, False))
-@pytest.mark.parametrize("length", _fun_numbers)
@pytest.mark.parametrize("literal_length", (True, False))
+@pytest.mark.parametrize("opt_level", list(OptimizationLevel))
+@given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024)
+@settings(max_examples=100, deadline=None)
@pytest.mark.fuzzing
def test_slice_bytes(
get_contract,
assert_compile_failed,
assert_tx_failed,
+ opt_level,
location,
bytesdata,
start,
literal_start,
length,
literal_length,
+ length_bound,
):
if location == "memory":
- spliced_code = "foo: Bytes[100] = inp"
+ spliced_code = f"foo: Bytes[{length_bound}] = inp"
foo = "foo"
elif location == "storage":
spliced_code = "self.foo = inp"
@@ -120,31 +127,38 @@ def test_slice_bytes(
_length = length if literal_length else "length"
code = f"""
-foo: Bytes[100]
-IMMUTABLE_BYTES: immutable(Bytes[100])
+foo: Bytes[{length_bound}]
+IMMUTABLE_BYTES: immutable(Bytes[{length_bound}])
@external
-def __init__(foo: Bytes[100]):
+def __init__(foo: Bytes[{length_bound}]):
IMMUTABLE_BYTES = foo
@external
-def do_slice(inp: Bytes[100], start: uint256, length: uint256) -> Bytes[100]:
+def do_slice(inp: Bytes[{length_bound}], start: uint256, length: uint256) -> Bytes[{length_bound}]:
{spliced_code}
return slice({foo}, {_start}, {_length})
"""
- length_bound = len(bytesdata) if location == "literal" else 100
+ def _get_contract():
+ return get_contract(code, bytesdata, override_opt_level=opt_level)
+
+ data_length = len(bytesdata) if location == "literal" else length_bound
if (
- (start + length > length_bound and literal_start and literal_length)
- or (literal_length and length > length_bound)
- or (literal_start and start > length_bound)
+ (start + length > data_length and literal_start and literal_length)
+ or (literal_length and length > data_length)
+ or (location == "literal" and len(bytesdata) > length_bound)
+ or (literal_start and start > data_length)
or (literal_length and length < 1)
):
- assert_compile_failed(lambda: get_contract(code, bytesdata), ArgumentException)
+ assert_compile_failed(lambda: _get_contract(), (ArgumentException, TypeMismatch))
+ elif len(bytesdata) > data_length:
+ # deploy fail
+ assert_tx_failed(lambda: _get_contract())
elif start + length > len(bytesdata):
- c = get_contract(code, bytesdata)
+ c = _get_contract()
assert_tx_failed(lambda: c.do_slice(bytesdata, start, length))
else:
- c = get_contract(code, bytesdata)
+ c = _get_contract()
assert c.do_slice(bytesdata, start, length) == bytesdata[start : start + length], code
diff --git a/tests/parser/globals/test_getters.py b/tests/parser/globals/test_getters.py
index 59c91cbeef..5eac074ef6 100644
--- a/tests/parser/globals/test_getters.py
+++ b/tests/parser/globals/test_getters.py
@@ -35,6 +35,7 @@ def test_getter_code(get_contract_with_gas_estimation_for_constants):
c: public(constant(uint256)) = 1
d: public(immutable(uint256))
e: public(immutable(uint256[2]))
+f: public(constant(uint256[2])) = [3, 7]
@external
def __init__():
@@ -68,6 +69,7 @@ def __init__():
assert c.c() == 1
assert c.d() == 1729
assert c.e(0) == 2
+ assert [c.f(i) for i in range(2)] == [3, 7]
def test_getter_mutability(get_contract):
diff --git a/tests/parser/parser_utils/test_annotate_and_optimize_ast.py b/tests/parser/parser_utils/test_annotate_and_optimize_ast.py
index 6f2246c6c0..68a07178bb 100644
--- a/tests/parser/parser_utils/test_annotate_and_optimize_ast.py
+++ b/tests/parser/parser_utils/test_annotate_and_optimize_ast.py
@@ -29,7 +29,7 @@ def foo() -> int128:
def get_contract_info(source_code):
- class_types, reformatted_code = pre_parse(source_code)
+ _, class_types, reformatted_code = pre_parse(source_code)
py_ast = python_ast.parse(reformatted_code)
annotate_python_ast(py_ast, reformatted_code, class_types)
diff --git a/tests/parser/syntax/test_addmulmod.py b/tests/parser/syntax/test_addmulmod.py
new file mode 100644
index 0000000000..ddff4d3e01
--- /dev/null
+++ b/tests/parser/syntax/test_addmulmod.py
@@ -0,0 +1,27 @@
+import pytest
+
+from vyper.exceptions import InvalidType
+
+fail_list = [
+ ( # bad AST nodes given as arguments
+ """
+@external
+def foo() -> uint256:
+ return uint256_addmod(1.1, 1.2, 3.0)
+ """,
+ InvalidType,
+ ),
+ ( # bad AST nodes given as arguments
+ """
+@external
+def foo() -> uint256:
+ return uint256_mulmod(1.1, 1.2, 3.0)
+ """,
+ InvalidType,
+ ),
+]
+
+
+@pytest.mark.parametrize("code,exc", fail_list)
+def test_add_mod_fail(assert_compile_failed, get_contract, code, exc):
+ assert_compile_failed(lambda: get_contract(code), exc)
diff --git a/tests/parser/syntax/test_address_code.py b/tests/parser/syntax/test_address_code.py
index 25fe1be0b4..70ba5cbbf7 100644
--- a/tests/parser/syntax/test_address_code.py
+++ b/tests/parser/syntax/test_address_code.py
@@ -5,6 +5,7 @@
from web3 import Web3
from vyper import compiler
+from vyper.compiler.settings import Settings
from vyper.exceptions import NamespaceCollision, StructureException, VyperException
# For reproducibility, use precompiled data of `hello: public(uint256)` using vyper 0.3.1
@@ -161,7 +162,7 @@ def test_address_code_compile_success(code: str):
compiler.compile_code(code)
-def test_address_code_self_success(get_contract, no_optimize: bool):
+def test_address_code_self_success(get_contract, optimize):
code = """
code_deployment: public(Bytes[32])
@@ -174,8 +175,9 @@ def code_runtime() -> Bytes[32]:
return slice(self.code, 0, 32)
"""
contract = get_contract(code)
+ settings = Settings(optimize=optimize)
code_compiled = compiler.compile_code(
- code, output_formats=["bytecode", "bytecode_runtime"], no_optimize=no_optimize
+ code, output_formats=["bytecode", "bytecode_runtime"], settings=settings
)
assert contract.code_deployment() == bytes.fromhex(code_compiled["bytecode"][2:])[:32]
assert contract.code_runtime() == bytes.fromhex(code_compiled["bytecode_runtime"][2:])[:32]
diff --git a/tests/parser/syntax/test_byte_string.py b/tests/parser/syntax/test_byte_string.py
deleted file mode 100644
index 90bbe197b0..0000000000
--- a/tests/parser/syntax/test_byte_string.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import pytest
-
-from vyper import compiler
-
-valid_list = [
- """
-@external
-def foo() -> String[10]:
- return "badminton"
- """,
- """
-@external
-def foo():
- x: String[11] = "¡très bien!"
- """,
- """
-@external
-def test() -> String[100]:
- return "hello world!"
- """,
-]
-
-
-@pytest.mark.parametrize("good_code", valid_list)
-def test_byte_string_success(good_code):
- assert compiler.compile_code(good_code) is not None
diff --git a/tests/parser/syntax/test_bytes.py b/tests/parser/syntax/test_bytes.py
index d732ec3ea5..a7fb7e77ce 100644
--- a/tests/parser/syntax/test_bytes.py
+++ b/tests/parser/syntax/test_bytes.py
@@ -1,7 +1,13 @@
import pytest
from vyper import compiler
-from vyper.exceptions import InvalidOperation, InvalidType, SyntaxException, TypeMismatch
+from vyper.exceptions import (
+ InvalidOperation,
+ InvalidType,
+ StructureException,
+ SyntaxException,
+ TypeMismatch,
+)
fail_list = [
(
@@ -77,6 +83,14 @@ def test() -> Bytes[1]:
""",
SyntaxException,
),
+ (
+ """
+@external
+def foo():
+ a: Bytes = b"abc"
+ """,
+ StructureException,
+ ),
]
diff --git a/tests/parser/syntax/test_chainid.py b/tests/parser/syntax/test_chainid.py
index eb2ed36325..2b6e08cbc4 100644
--- a/tests/parser/syntax/test_chainid.py
+++ b/tests/parser/syntax/test_chainid.py
@@ -1,8 +1,9 @@
import pytest
from vyper import compiler
+from vyper.compiler.settings import Settings
from vyper.evm.opcodes import EVM_VERSIONS
-from vyper.exceptions import EvmVersionException, InvalidType, TypeMismatch
+from vyper.exceptions import InvalidType, TypeMismatch
@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS))
@@ -12,12 +13,9 @@ def test_evm_version(evm_version):
def foo():
a: uint256 = chain.id
"""
+ settings = Settings(evm_version=evm_version)
- if EVM_VERSIONS[evm_version] < 2:
- with pytest.raises(EvmVersionException):
- compiler.compile_code(code, evm_version=evm_version)
- else:
- compiler.compile_code(code, evm_version=evm_version)
+ assert compiler.compile_code(code, settings=settings) is not None
fail_list = [
@@ -71,10 +69,10 @@ def foo(inp: Bytes[10]) -> Bytes[3]:
def test_chain_fail(bad_code):
if isinstance(bad_code, tuple):
with pytest.raises(bad_code[1]):
- compiler.compile_code(bad_code[0], evm_version="istanbul")
+ compiler.compile_code(bad_code[0])
else:
with pytest.raises(TypeMismatch):
- compiler.compile_code(bad_code, evm_version="istanbul")
+ compiler.compile_code(bad_code)
valid_list = [
@@ -95,7 +93,7 @@ def check_chain_id(c: uint256) -> bool:
@pytest.mark.parametrize("good_code", valid_list)
def test_chain_success(good_code):
- assert compiler.compile_code(good_code, evm_version="istanbul") is not None
+ assert compiler.compile_code(good_code) is not None
def test_chainid_operation(get_contract_with_gas_estimation):
@@ -105,5 +103,5 @@ def test_chainid_operation(get_contract_with_gas_estimation):
def get_chain_id() -> uint256:
return chain.id
"""
- c = get_contract_with_gas_estimation(code, evm_version="istanbul")
+ c = get_contract_with_gas_estimation(code)
assert c.get_chain_id() == 131277322940537 # Default value of py-evm
diff --git a/tests/parser/syntax/test_codehash.py b/tests/parser/syntax/test_codehash.py
index 8c1e430d32..5074d14636 100644
--- a/tests/parser/syntax/test_codehash.py
+++ b/tests/parser/syntax/test_codehash.py
@@ -1,13 +1,13 @@
import pytest
from vyper.compiler import compile_code
+from vyper.compiler.settings import Settings
from vyper.evm.opcodes import EVM_VERSIONS
-from vyper.exceptions import EvmVersionException
from vyper.utils import keccak256
@pytest.mark.parametrize("evm_version", list(EVM_VERSIONS))
-def test_get_extcodehash(get_contract, evm_version, no_optimize):
+def test_get_extcodehash(get_contract, evm_version, optimize):
code = """
a: address
@@ -32,15 +32,8 @@ def foo3() -> bytes32:
def foo4() -> bytes32:
return self.a.codehash
"""
-
- if evm_version in ("byzantium", "atlantis"):
- with pytest.raises(EvmVersionException):
- compile_code(code, evm_version=evm_version)
- return
-
- compiled = compile_code(
- code, ["bytecode_runtime"], evm_version=evm_version, no_optimize=no_optimize
- )
+ settings = Settings(evm_version=evm_version, optimize=optimize)
+ compiled = compile_code(code, ["bytecode_runtime"], settings=settings)
bytecode = bytes.fromhex(compiled["bytecode_runtime"][2:])
hash_ = keccak256(bytecode)
diff --git a/tests/parser/syntax/test_constants.py b/tests/parser/syntax/test_constants.py
index 15546672c6..ffd2f1faa0 100644
--- a/tests/parser/syntax/test_constants.py
+++ b/tests/parser/syntax/test_constants.py
@@ -69,6 +69,46 @@
"""
VAL: constant(Bytes[4]) = b"t"
VAL: uint256
+ """,
+ NamespaceCollision,
+ ),
+ # global with same type and name
+ (
+ """
+VAL: constant(uint256) = 1
+VAL: uint256
+ """,
+ NamespaceCollision,
+ ),
+ # global with same type and name, different order
+ (
+ """
+VAL: uint256
+VAL: constant(uint256) = 1
+ """,
+ NamespaceCollision,
+ ),
+ # global with same type and name
+ (
+ """
+VAL: immutable(uint256)
+VAL: uint256
+
+@external
+def __init__():
+ VAL = 1
+ """,
+ NamespaceCollision,
+ ),
+ # global with same type and name, different order
+ (
+ """
+VAL: uint256
+VAL: immutable(uint256)
+
+@external
+def __init__():
+ VAL = 1
""",
NamespaceCollision,
),
diff --git a/tests/parser/syntax/test_dynamic_array.py b/tests/parser/syntax/test_dynamic_array.py
index e7dc2d1183..0c23bf67da 100644
--- a/tests/parser/syntax/test_dynamic_array.py
+++ b/tests/parser/syntax/test_dynamic_array.py
@@ -16,6 +16,14 @@
""",
StructureException,
),
+ (
+ """
+@external
+def foo():
+ a: DynArray = [1, 2, 3]
+ """,
+ StructureException,
+ ),
]
diff --git a/tests/parser/syntax/test_enum.py b/tests/parser/syntax/test_enum.py
index c13f27f497..9bb74fb675 100644
--- a/tests/parser/syntax/test_enum.py
+++ b/tests/parser/syntax/test_enum.py
@@ -110,6 +110,20 @@ def foo() -> Roles:
""",
UnknownAttribute,
),
+ (
+ """
+enum A:
+ a
+enum B:
+ a
+ b
+
+@internal
+def foo():
+ a: A = B.b
+ """,
+ TypeMismatch,
+ ),
]
diff --git a/tests/parser/syntax/test_for_range.py b/tests/parser/syntax/test_for_range.py
index b2a9491058..e6f35c1d2d 100644
--- a/tests/parser/syntax/test_for_range.py
+++ b/tests/parser/syntax/test_for_range.py
@@ -12,7 +12,26 @@ def foo():
pass
""",
StructureException,
- )
+ ),
+ (
+ """
+@external
+def bar():
+ for i in range(1,2,bound=2):
+ pass
+ """,
+ StructureException,
+ ),
+ (
+ """
+@external
+def bar():
+ x:uint256 = 1
+ for i in range(x,x+1,bound=2):
+ pass
+ """,
+ StructureException,
+ ),
]
diff --git a/tests/parser/syntax/test_interfaces.py b/tests/parser/syntax/test_interfaces.py
index 8f49352876..498f1363d8 100644
--- a/tests/parser/syntax/test_interfaces.py
+++ b/tests/parser/syntax/test_interfaces.py
@@ -3,6 +3,7 @@
from vyper import compiler
from vyper.exceptions import (
ArgumentException,
+ InterfaceViolation,
InvalidReference,
InvalidType,
StructureException,
@@ -106,6 +107,109 @@ def foo(): nonpayable
""",
StructureException,
),
+ (
+ """
+from vyper.interfaces import ERC20
+
+interface A:
+ def f(): view
+
+@internal
+def foo():
+ a: ERC20 = A(empty(address))
+ """,
+ TypeMismatch,
+ ),
+ (
+ """
+interface A:
+ def f(a: uint256): view
+
+implements: A
+
+@external
+@nonpayable
+def f(a: uint256): # visibility is nonpayable instead of view
+ pass
+ """,
+ InterfaceViolation,
+ ),
+ (
+ # `receiver` of `Transfer` event should be indexed
+ """
+from vyper.interfaces import ERC20
+
+implements: ERC20
+
+event Transfer:
+ sender: indexed(address)
+ receiver: address
+ value: uint256
+
+event Approval:
+ owner: indexed(address)
+ spender: indexed(address)
+ value: uint256
+
+name: public(String[32])
+symbol: public(String[32])
+decimals: public(uint8)
+balanceOf: public(HashMap[address, uint256])
+allowance: public(HashMap[address, HashMap[address, uint256]])
+totalSupply: public(uint256)
+
+@external
+def transfer(_to : address, _value : uint256) -> bool:
+ return True
+
+@external
+def transferFrom(_from : address, _to : address, _value : uint256) -> bool:
+ return True
+
+@external
+def approve(_spender : address, _value : uint256) -> bool:
+ return True
+ """,
+ InterfaceViolation,
+ ),
+ (
+ # `value` of `Transfer` event should not be indexed
+ """
+from vyper.interfaces import ERC20
+
+implements: ERC20
+
+event Transfer:
+ sender: indexed(address)
+ receiver: indexed(address)
+ value: indexed(uint256)
+
+event Approval:
+ owner: indexed(address)
+ spender: indexed(address)
+ value: uint256
+
+name: public(String[32])
+symbol: public(String[32])
+decimals: public(uint8)
+balanceOf: public(HashMap[address, uint256])
+allowance: public(HashMap[address, HashMap[address, uint256]])
+totalSupply: public(uint256)
+
+@external
+def transfer(_to : address, _value : uint256) -> bool:
+ return True
+
+@external
+def transferFrom(_from : address, _to : address, _value : uint256) -> bool:
+ return True
+
+@external
+def approve(_spender : address, _value : uint256) -> bool:
+ return True
+ """,
+ InterfaceViolation,
+ ),
]
@@ -248,6 +352,20 @@ def foo() -> uint256: view
def __init__(x: uint256):
foo = x
""",
+ # no namespace collision of interface after storage variable
+ """
+a: constant(uint256) = 1
+
+interface A:
+ def f(a: uint128): view
+ """,
+ # no namespace collision of storage variable after interface
+ """
+interface A:
+ def f(a: uint256): view
+
+a: constant(uint128) = 1
+ """,
]
diff --git a/tests/parser/syntax/test_invalids.py b/tests/parser/syntax/test_invalids.py
index 3c51075e60..33478fcff1 100644
--- a/tests/parser/syntax/test_invalids.py
+++ b/tests/parser/syntax/test_invalids.py
@@ -363,6 +363,13 @@ def a():
UnknownAttribute,
)
+must_fail(
+ """
+a: HashMap
+""",
+ StructureException,
+)
+
@pytest.mark.parametrize("bad_code,exception_type", fail_list)
def test_compilation_fails_with_exception(bad_code, exception_type):
diff --git a/tests/parser/syntax/test_list.py b/tests/parser/syntax/test_list.py
index 6d941fa2df..3f81b911c8 100644
--- a/tests/parser/syntax/test_list.py
+++ b/tests/parser/syntax/test_list.py
@@ -302,6 +302,12 @@ def foo():
def foo():
self.b[0] = 7.0
""",
+ """
+@external
+def foo():
+ for i in [[], []]:
+ pass
+ """,
]
diff --git a/tests/parser/syntax/test_logging.py b/tests/parser/syntax/test_logging.py
index 39573642c0..2dd21e7a92 100644
--- a/tests/parser/syntax/test_logging.py
+++ b/tests/parser/syntax/test_logging.py
@@ -1,7 +1,7 @@
import pytest
from vyper import compiler
-from vyper.exceptions import InvalidType, TypeMismatch
+from vyper.exceptions import InvalidType, StructureException, TypeMismatch
fail_list = [
"""
@@ -45,3 +45,19 @@ def test_logging_fail(bad_code):
else:
with pytest.raises(TypeMismatch):
compiler.compile_code(bad_code)
+
+
+@pytest.mark.parametrize("mutability", ["@pure", "@view"])
+@pytest.mark.parametrize("visibility", ["@internal", "@external"])
+def test_logging_from_non_mutable(mutability, visibility):
+ code = f"""
+event Test:
+ n: uint256
+
+{visibility}
+{mutability}
+def test():
+ log Test(1)
+ """
+ with pytest.raises(StructureException):
+ compiler.compile_code(code)
diff --git a/tests/parser/syntax/test_public.py b/tests/parser/syntax/test_public.py
index fd0058cab8..68575ebd41 100644
--- a/tests/parser/syntax/test_public.py
+++ b/tests/parser/syntax/test_public.py
@@ -23,6 +23,27 @@ def __init__():
def foo() -> int128:
return self.x / self.y / self.z
""",
+ # expansion of public user-defined struct
+ """
+struct Foo:
+ a: uint256
+
+x: public(HashMap[uint256, Foo])
+ """,
+ # expansion of public user-defined enum
+ """
+enum Foo:
+ BAR
+
+x: public(HashMap[uint256, Foo])
+ """,
+ # expansion of public user-defined interface
+ """
+interface Foo:
+ def bar(): nonpayable
+
+x: public(HashMap[uint256, Foo])
+ """,
]
diff --git a/tests/parser/syntax/test_self_balance.py b/tests/parser/syntax/test_self_balance.py
index 949cdde324..63db58e347 100644
--- a/tests/parser/syntax/test_self_balance.py
+++ b/tests/parser/syntax/test_self_balance.py
@@ -1,6 +1,7 @@
import pytest
from vyper import compiler
+from vyper.compiler.settings import Settings
from vyper.evm.opcodes import EVM_VERSIONS
@@ -18,7 +19,8 @@ def get_balance() -> uint256:
def __default__():
pass
"""
- opcodes = compiler.compile_code(code, ["opcodes"], evm_version=evm_version)["opcodes"]
+ settings = Settings(evm_version=evm_version)
+ opcodes = compiler.compile_code(code, ["opcodes"], settings=settings)["opcodes"]
if EVM_VERSIONS[evm_version] >= EVM_VERSIONS["istanbul"]:
assert "SELFBALANCE" in opcodes
else:
diff --git a/tests/parser/syntax/test_string.py b/tests/parser/syntax/test_string.py
index 5b5b9e5bb5..6252011bd9 100644
--- a/tests/parser/syntax/test_string.py
+++ b/tests/parser/syntax/test_string.py
@@ -1,6 +1,7 @@
import pytest
from vyper import compiler
+from vyper.exceptions import StructureException
valid_list = [
"""
@@ -27,9 +28,31 @@ def foo() -> bool:
y: String[12] = "test"
return x != y
""",
+ """
+@external
+def test() -> String[100]:
+ return "hello world!"
+ """,
]
@pytest.mark.parametrize("good_code", valid_list)
def test_string_success(good_code):
assert compiler.compile_code(good_code) is not None
+
+
+invalid_list = [
+ (
+ """
+@external
+def foo():
+ a: String = "abc"
+ """,
+ StructureException,
+ )
+]
+
+
+@pytest.mark.parametrize("bad_code,exc", invalid_list)
+def test_string_fail(assert_compile_failed, get_contract_with_gas_estimation, bad_code, exc):
+ assert_compile_failed(lambda: get_contract_with_gas_estimation(bad_code), exc)
diff --git a/tests/parser/syntax/test_structs.py b/tests/parser/syntax/test_structs.py
index 757c46c4b3..b30f7e6098 100644
--- a/tests/parser/syntax/test_structs.py
+++ b/tests/parser/syntax/test_structs.py
@@ -2,6 +2,7 @@
from vyper import compiler
from vyper.exceptions import (
+ InstantiationException,
InvalidType,
StructureException,
TypeMismatch,
@@ -297,7 +298,7 @@ def foo():
a: HashMap[int128, C]
b: int128
""",
- StructureException,
+ InstantiationException,
),
"""
struct C1:
@@ -429,6 +430,16 @@ def foo():
""",
StructureException,
),
+ (
+ """
+event Foo:
+ a: uint256
+
+struct Bar:
+ a: Foo
+ """,
+ InstantiationException,
+ ),
]
diff --git a/tests/parser/syntax/test_ternary.py b/tests/parser/syntax/test_ternary.py
new file mode 100644
index 0000000000..325be3e43b
--- /dev/null
+++ b/tests/parser/syntax/test_ternary.py
@@ -0,0 +1,143 @@
+import pytest
+
+from vyper.compiler import compile_code
+from vyper.exceptions import InvalidType, TypeMismatch
+
+good_list = [
+ # basic test
+ """
+@external
+def foo(a: uint256, b: uint256) -> uint256:
+ return a if a > b else b
+ """,
+ """
+@external
+def foo():
+ a: bool = (True if True else True) or True
+ """,
+ # different locations:
+ """
+b: uint256
+
+@external
+def foo(x: uint256) -> uint256:
+ return x if x > self.b else self.b
+ """,
+ # different kinds of test exprs
+ """
+@external
+def foo(x: uint256, t: bool) -> uint256:
+ return x if t else 1
+ """,
+ """
+@external
+def foo(x: uint256) -> uint256:
+ return x if True else 1
+ """,
+ """
+@external
+def foo(x: uint256) -> uint256:
+ return x if False else 1
+ """,
+ # more complex types
+ """
+@external
+def foo(t: bool) -> DynArray[uint256, 1]:
+ return [2] if t else [1]
+ """,
+ # TODO: get this working, depends #3377
+ # """
+ # @external
+ # def foo(t: bool) -> DynArray[uint256, 1]:
+ # return [] if t else [1]
+ # """,
+ """
+@external
+def foo(t: bool) -> (uint256, uint256):
+ a: uint256 = 0
+ b: uint256 = 1
+ return (a, b) if t else (b, a)
+ """,
+]
+
+
+@pytest.mark.parametrize("code", good_list)
+def test_ternary_good(code):
+ assert compile_code(code) is not None
+
+
+fail_list = [
+ ( # bad test type
+ """
+@external
+def foo() -> uint256:
+ return 1 if 1 else 2
+ """,
+ InvalidType,
+ ),
+ ( # bad test type: constant
+ """
+TEST: constant(uint256) = 1
+@external
+def foo() -> uint256:
+ return 1 if TEST else 2
+ """,
+ InvalidType,
+ ),
+ ( # bad test type: variable
+ """
+TEST: constant(uint256) = 1
+@external
+def foo(t: uint256) -> uint256:
+ return 1 if t else 2
+ """,
+ TypeMismatch,
+ ),
+ ( # mismatched body and orelse: literal
+ """
+@external
+def foo() -> uint256:
+ return 1 if True else 2.0
+ """,
+ TypeMismatch,
+ ),
+ ( # mismatched body and orelse: literal and known type
+ """
+T: constant(uint256) = 1
+@external
+def foo() -> uint256:
+ return T if True else 2.0
+ """,
+ TypeMismatch,
+ ),
+ ( # mismatched body and orelse: both variable
+ """
+@external
+def foo(x: uint256, y: uint8) -> uint256:
+ return x if True else y
+ """,
+ TypeMismatch,
+ ),
+ ( # mismatched tuple types
+ """
+@external
+def foo(a: uint256, b: uint256, c: uint256) -> (uint256, uint256):
+ return (a, b) if True else (a, b, c)
+ """,
+ TypeMismatch,
+ ),
+ ( # mismatched tuple types - other direction
+ """
+@external
+def foo(a: uint256, b: uint256, c: uint256) -> (uint256, uint256):
+ return (a, b, c) if True else (a, b)
+ """,
+ TypeMismatch,
+ ),
+]
+
+
+@pytest.mark.parametrize("code,exc", fail_list)
+def test_functions_call_fail(code, exc):
+ with pytest.raises(exc):
+ compile_code(code)
diff --git a/tests/parser/syntax/utils/test_function_names.py b/tests/parser/syntax/utils/test_function_names.py
index 90e185558c..5489a4f6a0 100644
--- a/tests/parser/syntax/utils/test_function_names.py
+++ b/tests/parser/syntax/utils/test_function_names.py
@@ -23,6 +23,22 @@ def wei(i: int128) -> int128:
temp_var : int128 = i
return temp_var1
""",
+ # collision between getter and external function
+ """
+foo: public(uint256)
+
+@external
+def foo():
+ pass
+ """,
+ # collision between getter and external function, reverse order
+ """
+@external
+def foo():
+ pass
+
+foo: public(uint256)
+ """,
]
@@ -77,6 +93,30 @@ def append():
def foo():
self.append()
""",
+ # "method id" collisions between internal functions are allowed
+ """
+@internal
+@view
+def gfah():
+ pass
+
+@internal
+@view
+def eexo():
+ pass
+ """,
+ # "method id" collisions between internal+external functions are allowed
+ """
+@internal
+@view
+def gfah():
+ pass
+
+@external
+@view
+def eexo():
+ pass
+ """,
]
diff --git a/tests/parser/test_call_graph_stability.py b/tests/parser/test_call_graph_stability.py
new file mode 100644
index 0000000000..a6193610e2
--- /dev/null
+++ b/tests/parser/test_call_graph_stability.py
@@ -0,0 +1,75 @@
+import random
+import string
+
+import hypothesis.strategies as st
+import pytest
+from hypothesis import given, settings
+
+import vyper.ast as vy_ast
+from vyper.ast.identifiers import RESERVED_KEYWORDS
+from vyper.compiler.phases import CompilerData
+
+
+def _valid_identifier(attr):
+ return attr not in RESERVED_KEYWORDS
+
+
+# random names for functions
+@settings(max_examples=20, deadline=None)
+@given(
+ st.lists(
+ st.tuples(
+ st.sampled_from(["@pure", "@view", "@nonpayable", "@payable"]),
+ st.text(alphabet=string.ascii_lowercase, min_size=1).filter(_valid_identifier),
+ ),
+ unique_by=lambda x: x[1], # unique on function name
+ min_size=1,
+ max_size=10,
+ )
+)
+@pytest.mark.fuzzing
+def test_call_graph_stability_fuzz(funcs):
+ def generate_func_def(mutability, func_name, i):
+ return f"""
+@internal
+{mutability}
+def {func_name}() -> uint256:
+ return {i}
+ """
+
+ func_defs = "\n".join(generate_func_def(m, s, i) for i, (m, s) in enumerate(funcs))
+
+ for _ in range(10):
+ func_names = [f for (_, f) in funcs]
+ random.shuffle(func_names)
+
+ self_calls = "\n".join(f" self.{f}()" for f in func_names)
+ code = f"""
+{func_defs}
+
+@external
+def foo():
+{self_calls}
+ """
+ t = CompilerData(code)
+
+ # check the .called_functions data structure on foo() directly
+ foo = t.vyper_module_folded.get_children(vy_ast.FunctionDef, filters={"name": "foo"})[0]
+ foo_t = foo._metadata["type"]
+ assert [f.name for f in foo_t.called_functions] == func_names
+
+ # now for sanity, ensure the order that the function definitions appear
+ # in the IR is the same as the order of the calls
+ sigs = t.function_signatures
+ del sigs["foo"]
+ ir = t.ir_runtime
+ ir_funcs = []
+ # search for function labels
+ for d in ir.args: # currently: (seq ... (seq (label foo ...)) ...)
+ if d.value == "seq" and d.args[0].value == "label":
+ r = d.args[0].args[0].value
+ if isinstance(r, str) and r.startswith("internal"):
+ ir_funcs.append(r)
+ assert ir_funcs == [
+ f._ir_info.internal_function_label(is_ctor_context=False) for f in sigs.values()
+ ]
diff --git a/tests/parser/test_selector_table.py b/tests/parser/test_selector_table.py
new file mode 100644
index 0000000000..3ac50707c2
--- /dev/null
+++ b/tests/parser/test_selector_table.py
@@ -0,0 +1,629 @@
+import hypothesis.strategies as st
+import pytest
+from hypothesis import given, settings
+
+import vyper.utils as utils
+from vyper.codegen.jumptable_utils import (
+ generate_dense_jumptable_info,
+ generate_sparse_jumptable_buckets,
+)
+from vyper.compiler.settings import OptimizationLevel
+
+
+def test_dense_selector_table_empty_buckets(get_contract):
+ # some special combination of selectors which can result in
+ # some empty bucket being returned from _mk_buckets (that is,
+ # len(_mk_buckets(..., n_buckets)) != n_buckets
+ code = """
+@external
+def aX61QLPWF()->uint256:
+ return 1
+@external
+def aQHG0P2L1()->uint256:
+ return 2
+@external
+def a2G8ME94W()->uint256:
+ return 3
+@external
+def a0GNA21AY()->uint256:
+ return 4
+@external
+def a4U1XA4T5()->uint256:
+ return 5
+@external
+def aAYLMGOBZ()->uint256:
+ return 6
+@external
+def a0KXRLHKE()->uint256:
+ return 7
+@external
+def aDQS32HTR()->uint256:
+ return 8
+@external
+def aP4K6SA3S()->uint256:
+ return 9
+@external
+def aEB94ZP5S()->uint256:
+ return 10
+@external
+def aTOIMN0IM()->uint256:
+ return 11
+@external
+def aXV2N81OW()->uint256:
+ return 12
+@external
+def a66PP6Y5X()->uint256:
+ return 13
+@external
+def a5MWMTEWN()->uint256:
+ return 14
+@external
+def a5ZFST4Z8()->uint256:
+ return 15
+@external
+def aR13VXULX()->uint256:
+ return 16
+@external
+def aWITH917Y()->uint256:
+ return 17
+@external
+def a59NP6C5O()->uint256:
+ return 18
+@external
+def aJ02590EX()->uint256:
+ return 19
+@external
+def aUAXAAUQ8()->uint256:
+ return 20
+@external
+def aWR1XNC6J()->uint256:
+ return 21
+@external
+def aJABKZOKH()->uint256:
+ return 22
+@external
+def aO1TT0RJT()->uint256:
+ return 23
+@external
+def a41442IOK()->uint256:
+ return 24
+@external
+def aMVXV9FHQ()->uint256:
+ return 25
+@external
+def aNN0KJDZM()->uint256:
+ return 26
+@external
+def aOX965047()->uint256:
+ return 27
+@external
+def a575NX2J3()->uint256:
+ return 28
+@external
+def a16EN8O7W()->uint256:
+ return 29
+@external
+def aSZXLFF7O()->uint256:
+ return 30
+@external
+def aQKQCIPH9()->uint256:
+ return 31
+@external
+def aIP8021DL()->uint256:
+ return 32
+@external
+def aQAV0HSHX()->uint256:
+ return 33
+@external
+def aZVPAD745()->uint256:
+ return 34
+@external
+def aJYBSNST4()->uint256:
+ return 35
+@external
+def aQGWC4NYQ()->uint256:
+ return 36
+@external
+def aFMBB9CXJ()->uint256:
+ return 37
+@external
+def aYWM7ZUH1()->uint256:
+ return 38
+@external
+def aJAZONIX1()->uint256:
+ return 39
+@external
+def aQZ1HJK0H()->uint256:
+ return 40
+@external
+def aKIH9LOUB()->uint256:
+ return 41
+@external
+def aF4ZT80XL()->uint256:
+ return 42
+@external
+def aYQD8UKR5()->uint256:
+ return 43
+@external
+def aP6NCCAI4()->uint256:
+ return 44
+@external
+def aY92U2EAZ()->uint256:
+ return 45
+@external
+def aHMQ49D7P()->uint256:
+ return 46
+@external
+def aMC6YX8VF()->uint256:
+ return 47
+@external
+def a734X6YSI()->uint256:
+ return 48
+@external
+def aRXXPNSMU()->uint256:
+ return 49
+@external
+def aL5XKDTGT()->uint256:
+ return 50
+@external
+def a86V1Y18A()->uint256:
+ return 51
+@external
+def aAUM8PL5J()->uint256:
+ return 52
+@external
+def aBAEC1ERZ()->uint256:
+ return 53
+@external
+def a1U1VA3UE()->uint256:
+ return 54
+@external
+def aC9FGVAHC()->uint256:
+ return 55
+@external
+def aWN81WYJ3()->uint256:
+ return 56
+@external
+def a3KK1Y07J()->uint256:
+ return 57
+@external
+def aAZ6P6OSG()->uint256:
+ return 58
+@external
+def aWP5HCIB3()->uint256:
+ return 59
+@external
+def aVEK161C5()->uint256:
+ return 60
+@external
+def aY0Q3O519()->uint256:
+ return 61
+@external
+def aDHHHFIAE()->uint256:
+ return 62
+@external
+def aGSJBCZKQ()->uint256:
+ return 63
+@external
+def aZQQIUDHY()->uint256:
+ return 64
+@external
+def a12O9QDH5()->uint256:
+ return 65
+@external
+def aRQ1178XR()->uint256:
+ return 66
+@external
+def aDT25C832()->uint256:
+ return 67
+@external
+def aCSB01C4E()->uint256:
+ return 68
+@external
+def aYGBPKZSD()->uint256:
+ return 69
+@external
+def aP24N3EJ8()->uint256:
+ return 70
+@external
+def a531Y9X3C()->uint256:
+ return 71
+@external
+def a4727IKVS()->uint256:
+ return 72
+@external
+def a2EX1L2BS()->uint256:
+ return 73
+@external
+def a6145RN68()->uint256:
+ return 74
+@external
+def aDO1ZNX97()->uint256:
+ return 75
+@external
+def a3R28EU6M()->uint256:
+ return 76
+@external
+def a9BFC867L()->uint256:
+ return 77
+@external
+def aPL1MBGYC()->uint256:
+ return 78
+@external
+def aI6H11O48()->uint256:
+ return 79
+@external
+def aX0248DZY()->uint256:
+ return 80
+@external
+def aE4JBUJN4()->uint256:
+ return 81
+@external
+def aXBDB2ZBO()->uint256:
+ return 82
+@external
+def a7O7MYYHL()->uint256:
+ return 83
+@external
+def aERFF4PB6()->uint256:
+ return 84
+@external
+def aJCUBG6TJ()->uint256:
+ return 85
+@external
+def aQ5ELXM0F()->uint256:
+ return 86
+@external
+def aWDT9UQVV()->uint256:
+ return 87
+@external
+def a7UU40DJK()->uint256:
+ return 88
+@external
+def aH01IT5VS()->uint256:
+ return 89
+@external
+def aSKYTZ0FC()->uint256:
+ return 90
+@external
+def aNX5LYRAW()->uint256:
+ return 91
+@external
+def aUDKAOSGG()->uint256:
+ return 92
+@external
+def aZ86YGAAO()->uint256:
+ return 93
+@external
+def aIHWQGKLO()->uint256:
+ return 94
+@external
+def aKIKFLAR9()->uint256:
+ return 95
+@external
+def aCTPE0KRS()->uint256:
+ return 96
+@external
+def aAD75X00P()->uint256:
+ return 97
+@external
+def aDROUEF2F()->uint256:
+ return 98
+@external
+def a8CDIF6YN()->uint256:
+ return 99
+@external
+def aD2X7TM83()->uint256:
+ return 100
+@external
+def a3W5UUB4L()->uint256:
+ return 101
+@external
+def aG4MOBN4B()->uint256:
+ return 102
+@external
+def aPRS0MSG7()->uint256:
+ return 103
+@external
+def aKN3GHBUR()->uint256:
+ return 104
+@external
+def aGE435RHQ()->uint256:
+ return 105
+@external
+def a4E86BNFE()->uint256:
+ return 106
+@external
+def aYDG928YW()->uint256:
+ return 107
+@external
+def a2HFP5GQE()->uint256:
+ return 108
+@external
+def a5DPMVXKA()->uint256:
+ return 109
+@external
+def a3OFVC3DR()->uint256:
+ return 110
+@external
+def aK8F62DAN()->uint256:
+ return 111
+@external
+def aJS9EY3U6()->uint256:
+ return 112
+@external
+def aWW789JQH()->uint256:
+ return 113
+@external
+def a8AJJN3YR()->uint256:
+ return 114
+@external
+def a4D0MUIDU()->uint256:
+ return 115
+@external
+def a35W41JQR()->uint256:
+ return 116
+@external
+def a07DQOI1E()->uint256:
+ return 117
+@external
+def aFT43YNCT()->uint256:
+ return 118
+@external
+def a0E75I8X3()->uint256:
+ return 119
+@external
+def aT6NXIRO4()->uint256:
+ return 120
+@external
+def aXB2UBAKQ()->uint256:
+ return 121
+@external
+def aHWH55NW6()->uint256:
+ return 122
+@external
+def a7TCFE6C2()->uint256:
+ return 123
+@external
+def a8XYAM81I()->uint256:
+ return 124
+@external
+def aHQTQ4YBY()->uint256:
+ return 125
+@external
+def aGCZEHG6Y()->uint256:
+ return 126
+@external
+def a6LJTKIW0()->uint256:
+ return 127
+@external
+def aBDIXTD9S()->uint256:
+ return 128
+@external
+def aCB83G21P()->uint256:
+ return 129
+@external
+def aZC525N4K()->uint256:
+ return 130
+@external
+def a40LC94U6()->uint256:
+ return 131
+@external
+def a8X9TI93D()->uint256:
+ return 132
+@external
+def aGUG9CD8Y()->uint256:
+ return 133
+@external
+def a0LAERVAY()->uint256:
+ return 134
+@external
+def aXQ0UEX19()->uint256:
+ return 135
+@external
+def aKK9C7NE7()->uint256:
+ return 136
+@external
+def aS2APW8UE()->uint256:
+ return 137
+@external
+def a65NT07MM()->uint256:
+ return 138
+@external
+def aGRMT6ZW5()->uint256:
+ return 139
+@external
+def aILR4U1Z()->uint256:
+ return 140
+ """
+ c = get_contract(code)
+
+ assert c.aX61QLPWF() == 1 # will revert if the header section is misaligned
+
+
+@given(
+ n_methods=st.integers(min_value=1, max_value=100),
+ seed=st.integers(min_value=0, max_value=2**64 - 1),
+)
+@pytest.mark.fuzzing
+@settings(max_examples=10, deadline=None)
+def test_sparse_jumptable_probe_depth(n_methods, seed):
+ sigs = [f"foo{i + seed}()" for i in range(n_methods)]
+ _, buckets = generate_sparse_jumptable_buckets(sigs)
+ bucket_sizes = [len(bucket) for bucket in buckets.values()]
+
+ # generally bucket sizes should be bounded at around 4, but
+ # just test that they don't get really out of hand
+ assert max(bucket_sizes) <= 8
+
+ # generally mean bucket size should be around 1.6, here just
+ # test they don't get really out of hand
+ assert sum(bucket_sizes) / len(bucket_sizes) <= 4
+
+
+@given(
+ n_methods=st.integers(min_value=4, max_value=100),
+ seed=st.integers(min_value=0, max_value=2**64 - 1),
+)
+@pytest.mark.fuzzing
+@settings(max_examples=10, deadline=None)
+def test_dense_jumptable_bucket_size(n_methods, seed):
+ sigs = [f"foo{i + seed}()" for i in range(n_methods)]
+ n = len(sigs)
+ buckets = generate_dense_jumptable_info(sigs)
+ n_buckets = len(buckets)
+
+ # generally should be around 14 buckets per 100 methods, here
+ # we test they don't get really out of hand
+ assert n_buckets / n < 0.4 or n < 10
+
+
+@pytest.mark.parametrize("opt_level", list(OptimizationLevel))
+# dense selector table packing boundaries at 256 and 65336
+@pytest.mark.parametrize("max_calldata_bytes", [255, 256, 65336])
+@settings(max_examples=5, deadline=None)
+@given(
+ seed=st.integers(min_value=0, max_value=2**64 - 1),
+ max_default_args=st.integers(min_value=0, max_value=4),
+ default_fn_mutability=st.sampled_from(["", "@pure", "@view", "@nonpayable", "@payable"]),
+)
+@pytest.mark.fuzzing
+def test_selector_table_fuzz(
+ max_calldata_bytes,
+ seed,
+ max_default_args,
+ opt_level,
+ default_fn_mutability,
+ w3,
+ get_contract,
+ assert_tx_failed,
+ get_logs,
+):
+ def abi_sig(calldata_words, i, n_default_args):
+ args = [] if not calldata_words else [f"uint256[{calldata_words}]"]
+ args.extend(["uint256"] * n_default_args)
+ argstr = ",".join(args)
+ return f"foo{seed + i}({argstr})"
+
+ def generate_func_def(mutability, calldata_words, i, n_default_args):
+ arglist = [] if not calldata_words else [f"x: uint256[{calldata_words}]"]
+ for j in range(n_default_args):
+ arglist.append(f"x{j}: uint256 = 0")
+ args = ", ".join(arglist)
+ _log_return = f"log _Return({i})" if mutability == "@payable" else ""
+
+ return f"""
+@external
+{mutability}
+def foo{seed + i}({args}) -> uint256:
+ {_log_return}
+ return {i}
+ """
+
+ @given(
+ methods=st.lists(
+ st.tuples(
+ st.sampled_from(["@pure", "@view", "@nonpayable", "@payable"]),
+ st.integers(min_value=0, max_value=max_calldata_bytes // 32),
+ # n bytes to strip from calldata
+ st.integers(min_value=1, max_value=4),
+ # n default args
+ st.integers(min_value=0, max_value=max_default_args),
+ ),
+ min_size=1,
+ max_size=100,
+ )
+ )
+ @settings(max_examples=25)
+ def _test(methods):
+ func_defs = "\n".join(
+ generate_func_def(m, s, i, d) for i, (m, s, _, d) in enumerate(methods)
+ )
+
+ if default_fn_mutability == "":
+ default_fn_code = ""
+ elif default_fn_mutability in ("@nonpayable", "@payable"):
+ default_fn_code = f"""
+@external
+{default_fn_mutability}
+def __default__():
+ log CalledDefault()
+ """
+ else:
+ # can't log from pure/view functions, just test that it returns
+ default_fn_code = """
+@external
+def __default__():
+ pass
+ """
+
+ code = f"""
+event CalledDefault:
+ pass
+
+event _Return:
+ val: uint256
+
+{func_defs}
+
+{default_fn_code}
+ """
+
+ c = get_contract(code, override_opt_level=opt_level)
+
+ for i, (mutability, n_calldata_words, n_strip_bytes, n_default_args) in enumerate(methods):
+ funcname = f"foo{seed + i}"
+ func = getattr(c, funcname)
+
+ for j in range(n_default_args + 1):
+ args = [[1] * n_calldata_words] if n_calldata_words else []
+ args.extend([1] * j)
+
+ # check the function returns as expected
+ assert func(*args) == i
+
+ method_id = utils.method_id(abi_sig(n_calldata_words, i, j))
+
+ argsdata = b"\x00" * (n_calldata_words * 32 + j * 32)
+
+ # do payable check
+ if mutability == "@payable":
+ tx = func(*args, transact={"value": 1})
+ (event,) = get_logs(tx, c, "_Return")
+ assert event.args.val == i
+ else:
+ hexstr = (method_id + argsdata).hex()
+ txdata = {"to": c.address, "data": hexstr, "value": 1}
+ assert_tx_failed(lambda: w3.eth.send_transaction(txdata))
+
+ # now do calldatasize check
+ # strip some bytes
+ calldata = (method_id + argsdata)[:-n_strip_bytes]
+ hexstr = calldata.hex()
+ tx_params = {"to": c.address, "data": hexstr}
+ if n_calldata_words == 0 and j == 0:
+ # no args, hit default function
+ if default_fn_mutability == "":
+ assert_tx_failed(lambda: w3.eth.send_transaction(tx_params))
+ elif default_fn_mutability == "@payable":
+ # we should be able to send eth to it
+ tx_params["value"] = 1
+ tx = w3.eth.send_transaction(tx_params)
+ logs = get_logs(tx, c, "CalledDefault")
+ assert len(logs) == 1
+ else:
+ tx = w3.eth.send_transaction(tx_params)
+
+ # note: can't emit logs from view/pure functions,
+ # so the logging is not tested.
+ if default_fn_mutability == "@nonpayable":
+ logs = get_logs(tx, c, "CalledDefault")
+ assert len(logs) == 1
+
+ # check default function reverts
+ tx_params["value"] = 1
+ assert_tx_failed(lambda: w3.eth.send_transaction(tx_params))
+ else:
+ assert_tx_failed(lambda: w3.eth.send_transaction(tx_params))
+
+ _test()
diff --git a/tests/parser/test_selector_table_stability.py b/tests/parser/test_selector_table_stability.py
new file mode 100644
index 0000000000..abc2c17b8f
--- /dev/null
+++ b/tests/parser/test_selector_table_stability.py
@@ -0,0 +1,53 @@
+from vyper.codegen.jumptable_utils import generate_sparse_jumptable_buckets
+from vyper.compiler import compile_code
+from vyper.compiler.settings import OptimizationLevel, Settings
+
+
+def test_dense_jumptable_stability():
+ function_names = [f"foo{i}" for i in range(30)]
+
+ code = "\n".join(f"@external\ndef {name}():\n pass" for name in function_names)
+
+ output = compile_code(code, ["asm"], settings=Settings(optimize=OptimizationLevel.CODESIZE))
+
+ # test that the selector table data is stable across different runs
+ # (tox should provide different PYTHONHASHSEEDs).
+ expected_asm = """{ DATA _sym_BUCKET_HEADERS b'\\x0bB' _sym_bucket_0 b'\\n' b'+\\x8d' _sym_bucket_1 b'\\x0c' b'\\x00\\x85' _sym_bucket_2 b'\\x08' } { DATA _sym_bucket_1 b'\\xd8\\xee\\xa1\\xe8' _sym_external_foo6___3639517672 b'\\x05' b'\\xd2\\x9e\\xe0\\xf9' _sym_external_foo0___3533627641 b'\\x05' b'\\x05\\xf1\\xe0_' _sym_external_foo2___99737695 b'\\x05' b'\\x91\\t\\xb4{' _sym_external_foo23___2433332347 b'\\x05' b'np3\\x7f' _sym_external_foo11___1852846975 b'\\x05' b'&\\xf5\\x96\\xf9' _sym_external_foo13___653629177 b'\\x05' b'\\x04ga\\xeb' _sym_external_foo14___73884139 b'\\x05' b'\\x89\\x06\\xad\\xc6' _sym_external_foo17___2298916294 b'\\x05' b'\\xe4%\\xac\\xd1' _sym_external_foo4___3827674321 b'\\x05' b'yj\\x01\\xac' _sym_external_foo7___2036990380 b'\\x05' b'\\xf1\\xe6K\\xe5' _sym_external_foo29___4058401765 b'\\x05' b'\\xd2\\x89X\\xb8' _sym_external_foo3___3532216504 b'\\x05' } { DATA _sym_bucket_2 b'\\x06p\\xffj' _sym_external_foo25___108068714 b'\\x05' b'\\x964\\x99I' _sym_external_foo24___2520029513 b'\\x05' b's\\x81\\xe7\\xc1' _sym_external_foo10___1937893313 b'\\x05' b'\\x85\\xad\\xc11' _sym_external_foo28___2242756913 b'\\x05' b'\\xfa"\\xb1\\xed' _sym_external_foo5___4196577773 b'\\x05' b'A\\xe7[\\x05' _sym_external_foo22___1105681157 b'\\x05' b'\\xd3\\x89U\\xe8' _sym_external_foo1___3548993000 b'\\x05' b'hL\\xf8\\xf3' _sym_external_foo20___1749874931 b'\\x05' } { DATA _sym_bucket_0 b'\\xee\\xd9\\x1d\\xe3' _sym_external_foo9___4007206371 b'\\x05' b'a\\xbc\\x1ch' _sym_external_foo16___1639717992 b'\\x05' b'\\xd3*\\xa7\\x0c' _sym_external_foo21___3542787852 b'\\x05' b'\\x18iG\\xd9' _sym_external_foo19___409552857 b'\\x05' b'\\n\\xf1\\xf9\\x7f' _sym_external_foo18___183630207 b'\\x05' b')\\xda\\xd7`' _sym_external_foo27___702207840 b'\\x05' b'2\\xf6\\xaa\\xda' _sym_external_foo12___855026394 b'\\x05' b'\\xbe\\xb5\\x05\\xf5' _sym_external_foo15___3199534581 b'\\x05' b'\\xfc\\xa7_\\xe6' _sym_external_foo8___4238827494 b'\\x05' b'\\x1b\\x12C8' _sym_external_foo26___454181688 b'\\x05' } }""" # noqa: E501
+ assert expected_asm in output["asm"]
+
+
+def test_sparse_jumptable_stability():
+ function_names = [f"foo{i}()" for i in range(30)]
+
+ # sparse jumptable is not as complicated in assembly.
+ # here just test the data structure is stable
+
+ n_buckets, buckets = generate_sparse_jumptable_buckets(function_names)
+ assert n_buckets == 33
+
+ # the buckets sorted by id are what go into the IR, check equality against
+ # expected:
+ assert sorted(buckets.items()) == [
+ (0, [4238827494, 1639717992]),
+ (1, [1852846975]),
+ (2, [1749874931]),
+ (3, [4007206371]),
+ (4, [2298916294]),
+ (7, [2036990380]),
+ (10, [3639517672, 73884139]),
+ (12, [3199534581]),
+ (13, [99737695]),
+ (14, [3548993000, 4196577773]),
+ (15, [454181688, 702207840]),
+ (16, [3533627641]),
+ (17, [108068714]),
+ (20, [1105681157]),
+ (21, [409552857, 3542787852]),
+ (22, [4058401765]),
+ (23, [2520029513, 2242756913]),
+ (24, [855026394, 183630207]),
+ (25, [3532216504, 653629177]),
+ (26, [1937893313]),
+ (28, [2433332347]),
+ (31, [3827674321]),
+ ]
diff --git a/tests/parser/types/numbers/test_sqrt.py b/tests/parser/types/numbers/test_sqrt.py
index 025a3868e9..df1ed0539c 100644
--- a/tests/parser/types/numbers/test_sqrt.py
+++ b/tests/parser/types/numbers/test_sqrt.py
@@ -143,8 +143,8 @@ def test_sqrt_bounds(sqrt_contract, value):
min_value=Decimal(0), max_value=Decimal(SizeLimits.MAX_INT128), places=DECIMAL_PLACES
)
)
-@hypothesis.example(Decimal(SizeLimits.MAX_INT128))
-@hypothesis.example(Decimal(0))
+@hypothesis.example(value=Decimal(SizeLimits.MAX_INT128))
+@hypothesis.example(value=Decimal(0))
@hypothesis.settings(deadline=1000)
def test_sqrt_valid_range(sqrt_contract, value):
vyper_sqrt = sqrt_contract.test(value)
@@ -159,8 +159,8 @@ def test_sqrt_valid_range(sqrt_contract, value):
)
)
@hypothesis.settings(deadline=400)
-@hypothesis.example(Decimal(SizeLimits.MIN_INT128))
-@hypothesis.example(Decimal("-1E10"))
+@hypothesis.example(value=Decimal(SizeLimits.MIN_INT128))
+@hypothesis.example(value=Decimal("-1E10"))
def test_sqrt_invalid_range(sqrt_contract, value):
with pytest.raises(TransactionFailed):
sqrt_contract.test(value)
diff --git a/tests/parser/types/numbers/test_unsigned_ints.py b/tests/parser/types/numbers/test_unsigned_ints.py
index 97a4097923..683684e6be 100644
--- a/tests/parser/types/numbers/test_unsigned_ints.py
+++ b/tests/parser/types/numbers/test_unsigned_ints.py
@@ -195,49 +195,6 @@ def foo(x: {typ}, y: {typ}) -> bool:
assert c.foo(x, y) is expected
-# TODO move to tests/parser/functions/test_mulmod.py and test_addmod.py
-def test_uint256_mod(assert_tx_failed, get_contract_with_gas_estimation):
- uint256_code = """
-@external
-def _uint256_addmod(x: uint256, y: uint256, z: uint256) -> uint256:
- return uint256_addmod(x, y, z)
-
-@external
-def _uint256_mulmod(x: uint256, y: uint256, z: uint256) -> uint256:
- return uint256_mulmod(x, y, z)
- """
-
- c = get_contract_with_gas_estimation(uint256_code)
-
- assert c._uint256_addmod(1, 2, 2) == 1
- assert c._uint256_addmod(32, 2, 32) == 2
- assert c._uint256_addmod((2**256) - 1, 0, 2) == 1
- assert c._uint256_addmod(2**255, 2**255, 6) == 4
- assert_tx_failed(lambda: c._uint256_addmod(1, 2, 0))
- assert c._uint256_mulmod(3, 1, 2) == 1
- assert c._uint256_mulmod(200, 3, 601) == 600
- assert c._uint256_mulmod(2**255, 1, 3) == 2
- assert c._uint256_mulmod(2**255, 2, 6) == 4
- assert_tx_failed(lambda: c._uint256_mulmod(2, 2, 0))
-
-
-def test_uint256_modmul(get_contract_with_gas_estimation):
- modexper = """
-@external
-def exponential(base: uint256, exponent: uint256, modulus: uint256) -> uint256:
- o: uint256 = 1
- for i in range(256):
- o = uint256_mulmod(o, o, modulus)
- if exponent & shift(1, 255 - i) != 0:
- o = uint256_mulmod(o, base, modulus)
- return o
- """
-
- c = get_contract_with_gas_estimation(modexper)
- assert c.exponential(3, 5, 100) == 43
- assert c.exponential(2, 997, 997) == 2
-
-
@pytest.mark.parametrize("typ", types)
def test_uint_literal(get_contract, assert_compile_failed, typ):
lo, hi = typ.ast_bounds
diff --git a/tests/parser/types/test_bytes.py b/tests/parser/types/test_bytes.py
index 28602d61b1..01ec75d5c1 100644
--- a/tests/parser/types/test_bytes.py
+++ b/tests/parser/types/test_bytes.py
@@ -268,9 +268,8 @@ def to_little_endian_64(_value: uint256) -> Bytes[8]:
y: uint256 = 0
x: uint256 = _value
for _ in range(8):
- y = shift(y, 8)
- y = y + (x & 255)
- x = shift(x, -8)
+ y = (y << 8) | (x & 255)
+ x >>= 8
return slice(convert(y, bytes32), 24, 8)
@external
diff --git a/tests/parser/types/test_bytes_zero_padding.py b/tests/parser/types/test_bytes_zero_padding.py
index 9bc774f12f..ee938fdffb 100644
--- a/tests/parser/types/test_bytes_zero_padding.py
+++ b/tests/parser/types/test_bytes_zero_padding.py
@@ -11,9 +11,8 @@ def to_little_endian_64(_value: uint256) -> Bytes[8]:
y: uint256 = 0
x: uint256 = _value
for _ in range(8):
- y = shift(y, 8)
- y = y + (x & 255)
- x = shift(x, -8)
+ y = (y << 8) | (x & 255)
+ x >>= 8
return slice(convert(y, bytes32), 24, 8)
@external
diff --git a/tests/parser/types/test_dynamic_array.py b/tests/parser/types/test_dynamic_array.py
index 04c0688245..9231d1979f 100644
--- a/tests/parser/types/test_dynamic_array.py
+++ b/tests/parser/types/test_dynamic_array.py
@@ -1543,7 +1543,7 @@ def bar(x: int128) -> DynArray[int128, 3]:
assert c.bar(7) == [7, 14]
-def test_nested_struct_of_lists(get_contract, assert_compile_failed, no_optimize):
+def test_nested_struct_of_lists(get_contract, assert_compile_failed, optimize):
code = """
struct nestedFoo:
a1: DynArray[DynArray[DynArray[uint256, 2], 2], 2]
@@ -1584,14 +1584,9 @@ def bar2() -> uint256:
newFoo.b1[1][0][0].a1[0][1][1] + \\
newFoo.b1[0][1][0].a1[0][0][0]
"""
-
- if no_optimize:
- # fails at assembly stage with too many stack variables
- assert_compile_failed(lambda: get_contract(code), Exception)
- else:
- c = get_contract(code)
- assert c.bar() == [[[3, 7], [7, 3]], [[7, 3], [0, 0]]]
- assert c.bar2() == 0
+ c = get_contract(code)
+ assert c.bar() == [[[3, 7], [7, 3]], [[7, 3], [0, 0]]]
+ assert c.bar2() == 0
def test_tuple_of_lists(get_contract):
@@ -1748,3 +1743,95 @@ def foo(i: uint256) -> {return_type}:
return MY_CONSTANT[i]
"""
assert_compile_failed(lambda: get_contract(code), TypeMismatch)
+
+
+dynarray_length_no_clobber_cases = [
+ # GHSA-3p37-3636-q8wv cases
+ """
+a: DynArray[uint256,3]
+
+@external
+def should_revert() -> DynArray[uint256,3]:
+ self.a = [1,2,3]
+ self.a = empty(DynArray[uint256,3])
+ self.a = [self.a[0], self.a[1], self.a[2]]
+
+ return self.a # if bug: returns [1,2,3]
+ """,
+ """
+@external
+def should_revert() -> DynArray[uint256,3]:
+ self.a()
+ return self.b() # if bug: returns [1,2,3]
+
+@internal
+def a():
+ a: uint256 = 0
+ b: uint256 = 1
+ c: uint256 = 2
+ d: uint256 = 3
+
+@internal
+def b() -> DynArray[uint256,3]:
+ a: DynArray[uint256,3] = empty(DynArray[uint256,3])
+ a = [a[0],a[1],a[2]]
+ return a
+ """,
+ """
+a: DynArray[uint256,4]
+
+@external
+def should_revert() -> DynArray[uint256,4]:
+ self.a = [1,2,3]
+ self.a = empty(DynArray[uint256,4])
+ self.a = [4, self.a[0]]
+
+ return self.a # if bug: return [4, 4]
+ """,
+ """
+@external
+def should_revert() -> DynArray[uint256,4]:
+ a: DynArray[uint256, 4] = [1,2,3]
+ a = []
+
+ a = [a.pop()] # if bug: return [1]
+
+ return a
+ """,
+ """
+@external
+def should_revert():
+ c: DynArray[uint256, 1] = []
+ c.append(c[0])
+ """,
+ """
+@external
+def should_revert():
+ c: DynArray[uint256, 1] = [1]
+ c[0] = c.pop()
+ """,
+ """
+@external
+def should_revert():
+ c: DynArray[DynArray[uint256, 1], 2] = [[]]
+ c[0] = c.pop()
+ """,
+ """
+a: DynArray[String[65],2]
+
+@external
+def should_revert() -> DynArray[String[65], 2]:
+ self.a = ["hello", "world"]
+ self.a = []
+ self.a = [self.a[0], self.a[1]]
+
+ return self.a # if bug: return ["hello", "world"]
+ """,
+]
+
+
+@pytest.mark.parametrize("code", dynarray_length_no_clobber_cases)
+def test_dynarray_length_no_clobber(get_contract, assert_tx_failed, code):
+ # check that length is not clobbered before dynarray data copy happens
+ c = get_contract(code)
+ assert_tx_failed(lambda: c.should_revert())
diff --git a/tests/parser/types/test_identifier_naming.py b/tests/parser/types/test_identifier_naming.py
index cdda221fd1..0a93329848 100755
--- a/tests/parser/types/test_identifier_naming.py
+++ b/tests/parser/types/test_identifier_naming.py
@@ -1,9 +1,9 @@
import pytest
+from vyper.ast.identifiers import RESERVED_KEYWORDS
from vyper.builtins.functions import BUILTIN_FUNCTIONS
from vyper.codegen.expr import ENVIRONMENT_VARIABLES
from vyper.exceptions import NamespaceCollision, StructureException, SyntaxException
-from vyper.semantics.namespace import RESERVED_KEYWORDS
from vyper.semantics.types.primitives import AddressT
ALL_RESERVED_KEYWORDS = BUILTIN_FUNCTIONS | RESERVED_KEYWORDS | ENVIRONMENT_VARIABLES
@@ -41,10 +41,8 @@ def test({constant}: int128):
)
-PYTHON_KEYWORDS = {"if", "for", "while", "pass", "def", "assert", "continue", "raise"}
-
SELF_NAMESPACE_MEMBERS = set(AddressT._type_members.keys())
-DISALLOWED_FN_NAMES = SELF_NAMESPACE_MEMBERS | PYTHON_KEYWORDS | RESERVED_KEYWORDS
+DISALLOWED_FN_NAMES = SELF_NAMESPACE_MEMBERS | RESERVED_KEYWORDS
ALLOWED_FN_NAMES = ALL_RESERVED_KEYWORDS - DISALLOWED_FN_NAMES
diff --git a/tests/parser/types/test_lists.py b/tests/parser/types/test_lists.py
index 0715eb3870..832b679e5e 100644
--- a/tests/parser/types/test_lists.py
+++ b/tests/parser/types/test_lists.py
@@ -676,6 +676,18 @@ def ix(i: uint256) -> {type}:
assert_tx_failed(lambda: c.ix(len(value) + 1))
+def test_nested_constant_list_accessor(get_contract):
+ code = """
+@external
+def foo() -> bool:
+ f: uint256 = 1
+ a: bool = 1 == [1,2,4][f] + -1
+ return a
+ """
+ c = get_contract(code)
+ assert c.foo() is True
+
+
# Would be nice to put this somewhere accessible, like in vyper.types or something
integer_types = ["uint8", "int128", "int256", "uint256"]
@@ -745,6 +757,23 @@ def ix(i: uint256) -> address:
assert_tx_failed(lambda: c.ix(len(some_good_address) + 1))
+def test_list_index_complex_expr(get_contract, assert_tx_failed):
+ # test subscripts where the index is not a literal
+ code = """
+@external
+def foo(xs: uint256[257], i: uint8) -> uint256:
+ return xs[i + 1]
+ """
+ c = get_contract(code)
+ xs = [i + 1 for i in range(257)]
+
+ for ix in range(255):
+ assert c.foo(xs, ix) == xs[ix + 1]
+
+ # safemath should fail for uint8: 255 + 1.
+ assert_tx_failed(lambda: c.foo(xs, 255))
+
+
@pytest.mark.parametrize(
"type,value",
[
diff --git a/tests/signatures/test_method_id_conflicts.py b/tests/signatures/test_method_id_conflicts.py
index 262348c12a..f3312efeab 100644
--- a/tests/signatures/test_method_id_conflicts.py
+++ b/tests/signatures/test_method_id_conflicts.py
@@ -48,24 +48,11 @@ def OwnerTransferV7b711143(a: uint256):
pass
""",
"""
-# check collision between private method IDs
-@internal
-@view
-def gfah(): pass
-
-@internal
-@view
-def eexo(): pass
- """,
- """
-# check collision between private and public IDs
-@internal
-@view
-def gfah(): pass
+# check collision with ID = 0x00000000
+wycpnbqcyf:public(uint256)
@external
-@view
-def eexo(): pass
+def randallsRevenge_ilxaotc(): pass
""",
]
diff --git a/tox.ini b/tox.ini
index 7fd09ff3f8..c949354dfe 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,6 +1,6 @@
[tox]
envlist =
- py{310,311}-{core,no-opt}
+ py{310,311}
lint
mypy
docs
@@ -8,8 +8,7 @@ envlist =
[testenv]
usedevelop = True
commands =
- core: pytest -m "not fuzzing" --showlocals {posargs:tests/}
- no-opt: pytest -m "not fuzzing" --showlocals --no-optimize {posargs:tests/}
+ pytest -m "not fuzzing" --showlocals {posargs:tests/}
basepython =
py310: python3.10
py311: python3.11
@@ -46,7 +45,7 @@ whitelist_externals = make
basepython = python3
extras = lint
commands =
- black -C -t py310 {toxinidir}/vyper {toxinidir}/tests {toxinidir}/setup.py
+ black -C -t py311 {toxinidir}/vyper {toxinidir}/tests {toxinidir}/setup.py
flake8 {toxinidir}/vyper {toxinidir}/tests
isort {toxinidir}/vyper {toxinidir}/tests {toxinidir}/setup.py
diff --git a/vyper/ast/__init__.py b/vyper/ast/__init__.py
index 5695ceab7c..e5b81f1e7f 100644
--- a/vyper/ast/__init__.py
+++ b/vyper/ast/__init__.py
@@ -6,7 +6,7 @@
from . import nodes, validation
from .natspec import parse_natspec
from .nodes import compare_nodes
-from .utils import ast_to_dict, parse_to_ast
+from .utils import ast_to_dict, parse_to_ast, parse_to_ast_with_settings
# adds vyper.ast.nodes classes into the local namespace
for name, obj in (
diff --git a/vyper/ast/expansion.py b/vyper/ast/expansion.py
index c5518be405..5471b971a4 100644
--- a/vyper/ast/expansion.py
+++ b/vyper/ast/expansion.py
@@ -2,6 +2,7 @@
from vyper import ast as vy_ast
from vyper.exceptions import CompilerPanic
+from vyper.semantics.types.function import ContractFunctionT
def expand_annotated_ast(vyper_module: vy_ast.Module) -> None:
@@ -48,7 +49,6 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None:
# the base return statement is an `Attribute` node, e.g. `self.`
# for each input type we wrap it in a `Subscript` to access a specific member
return_stmt = vy_ast.Attribute(value=vy_ast.Name(id="self"), attr=func_type.name)
- return_stmt._metadata["type"] = node._metadata["type"]
for i, type_ in enumerate(input_types):
if not isinstance(annotation, vy_ast.Subscript):
@@ -85,6 +85,10 @@ def generate_public_variable_getters(vyper_module: vy_ast.Module) -> None:
decorator_list=[vy_ast.Name(id="external"), vy_ast.Name(id="view")],
returns=return_node,
)
+
+ with vyper_module.namespace():
+ func_type = ContractFunctionT.from_FunctionDef(expanded)
+
expanded._metadata["type"] = func_type
return_node.set_parent(expanded)
vyper_module.add_to_body(expanded)
diff --git a/vyper/ast/grammar.lark b/vyper/ast/grammar.lark
index cf15e13a8c..ca9979b2a3 100644
--- a/vyper/ast/grammar.lark
+++ b/vyper/ast/grammar.lark
@@ -72,8 +72,8 @@ function_def: [decorators] function_sig ":" body
_EVENT_DECL: "event"
event_member: NAME ":" type
indexed_event_arg: NAME ":" "indexed" "(" type ")"
-event_body: _NEWLINE _INDENT ((event_member | indexed_event_arg) _NEWLINE)+ _DEDENT
// Events which use no args use a pass statement instead
+event_body: _NEWLINE _INDENT (((event_member | indexed_event_arg ) _NEWLINE)+ | _PASS _NEWLINE) _DEDENT
event_def: _EVENT_DECL NAME ":" ( event_body | _PASS )
// Enums
@@ -174,12 +174,15 @@ loop_variable: NAME [":" NAME]
loop_iterator: _expr
for_stmt: "for" loop_variable "in" loop_iterator ":" body
+// ternary operator
+ternary: _expr "if" _expr "else" _expr
// Expressions
_expr: operation
| dict
+ | ternary
-get_item: variable_access "[" _expr "]"
+get_item: (variable_access | list) "[" _expr "]"
get_attr: variable_access "." NAME
call: variable_access "(" [arguments] ")"
?variable_access: NAME -> get_var
diff --git a/vyper/ast/identifiers.py b/vyper/ast/identifiers.py
new file mode 100644
index 0000000000..985b04e5cd
--- /dev/null
+++ b/vyper/ast/identifiers.py
@@ -0,0 +1,111 @@
+import re
+
+from vyper.exceptions import StructureException
+
+
+def validate_identifier(attr, ast_node=None):
+ if not re.match("^[_a-zA-Z][a-zA-Z0-9_]*$", attr):
+ raise StructureException(f"'{attr}' contains invalid character(s)", ast_node)
+ if attr.lower() in RESERVED_KEYWORDS:
+ raise StructureException(f"'{attr}' is a reserved keyword", ast_node)
+
+
+# https://docs.python.org/3/reference/lexical_analysis.html#keywords
+# note we don't technically need to block all python reserved keywords,
+# but do it for hygiene
+_PYTHON_RESERVED_KEYWORDS = {
+ "False",
+ "None",
+ "True",
+ "and",
+ "as",
+ "assert",
+ "async",
+ "await",
+ "break",
+ "class",
+ "continue",
+ "def",
+ "del",
+ "elif",
+ "else",
+ "except",
+ "finally",
+ "for",
+ "from",
+ "global",
+ "if",
+ "import",
+ "in",
+ "is",
+ "lambda",
+ "nonlocal",
+ "not",
+ "or",
+ "pass",
+ "raise",
+ "return",
+ "try",
+ "while",
+ "with",
+ "yield",
+}
+_PYTHON_RESERVED_KEYWORDS = {s.lower() for s in _PYTHON_RESERVED_KEYWORDS}
+
+# Cannot be used for variable or member naming
+RESERVED_KEYWORDS = _PYTHON_RESERVED_KEYWORDS | {
+ # decorators
+ "public",
+ "external",
+ "nonpayable",
+ "constant",
+ "immutable",
+ "transient",
+ "internal",
+ "payable",
+ "nonreentrant",
+ # "class" keywords
+ "interface",
+ "struct",
+ "event",
+ "enum",
+ # EVM operations
+ "unreachable",
+ # special functions (no name mangling)
+ "init",
+ "_init_",
+ "___init___",
+ "____init____",
+ "default",
+ "_default_",
+ "___default___",
+ "____default____",
+ # more control flow and special operations
+ "range",
+ # more special operations
+ "indexed",
+ # denominations
+ "ether",
+ "wei",
+ "finney",
+ "szabo",
+ "shannon",
+ "lovelace",
+ "ada",
+ "babbage",
+ "gwei",
+ "kwei",
+ "mwei",
+ "twei",
+ "pwei",
+ # sentinal constant values
+ # TODO remove when these are removed from the language
+ "zero_address",
+ "empty_bytes32",
+ "max_int128",
+ "min_int128",
+ "max_decimal",
+ "min_decimal",
+ "max_uint256",
+ "zero_wei",
+}
diff --git a/vyper/ast/metadata.py b/vyper/ast/metadata.py
new file mode 100644
index 0000000000..0a419c3732
--- /dev/null
+++ b/vyper/ast/metadata.py
@@ -0,0 +1,83 @@
+import contextlib
+from typing import Any
+
+from vyper.exceptions import VyperException
+
+
+# a commit/rollback scheme for metadata caching. in the case that an
+# exception is thrown and caught during type checking (currently, only
+# during for loop iterator variable type inference), we can roll back
+# any state updates due to type checking.
+# this is implemented as a stack of changesets, because we need to
+# handle nested rollbacks in the case of nested for loops
+class _NodeMetadataJournal:
+ _NOT_FOUND = object()
+
+ def __init__(self):
+ self._node_updates: list[dict[tuple[int, str, Any], NodeMetadata]] = []
+
+ def register_update(self, metadata, k):
+ KEY = (id(metadata), k)
+ if KEY in self._node_updates[-1]:
+ return
+ prev = metadata.get(k, self._NOT_FOUND)
+ self._node_updates[-1][KEY] = (metadata, prev)
+
+ @contextlib.contextmanager
+ def enter(self):
+ self._node_updates.append({})
+ try:
+ yield
+ except VyperException as e:
+ # note: would be better to only catch typechecker exceptions here.
+ self._rollback_inner()
+ raise e from e
+ else:
+ self._commit_inner()
+
+ def _rollback_inner(self):
+ for (_, k), (metadata, prev) in self._node_updates[-1].items():
+ if prev is self._NOT_FOUND:
+ metadata.pop(k, None)
+ else:
+ metadata[k] = prev
+ self._pop_inner()
+
+ def _commit_inner(self):
+ inner = self._pop_inner()
+
+ if len(self._node_updates) == 0:
+ return
+
+ outer = self._node_updates[-1]
+
+ # register with previous frame in case inner gets commited
+ # but outer needs to be rolled back
+ for (_, k), (metadata, prev) in inner.items():
+ if (id(metadata), k) not in outer:
+ outer[(id(metadata), k)] = (metadata, prev)
+
+ def _pop_inner(self):
+ return self._node_updates.pop()
+
+
+class NodeMetadata(dict):
+ """
+ A data structure which allows for journaling.
+ """
+
+ _JOURNAL: _NodeMetadataJournal = _NodeMetadataJournal()
+
+ def __setitem__(self, k, v):
+ # if we are in a context where we need to journal, add
+ # this to the changeset.
+ if len(self._JOURNAL._node_updates) != 0:
+ self._JOURNAL.register_update(self, k)
+
+ super().__setitem__(k, v)
+
+ @classmethod
+ @contextlib.contextmanager
+ def enter_typechecker_speculation(cls):
+ with cls._JOURNAL.enter():
+ yield
diff --git a/vyper/ast/natspec.py b/vyper/ast/natspec.py
index e6f0fcd00b..c25fc423f8 100644
--- a/vyper/ast/natspec.py
+++ b/vyper/ast/natspec.py
@@ -88,7 +88,7 @@ def _parse_docstring(
tag, value = match.groups()
err_args = (source, *line_no.offset_to_line(start + match.start(1)))
- if tag not in SINGLE_FIELDS + PARAM_FIELDS:
+ if tag not in SINGLE_FIELDS + PARAM_FIELDS and not tag.startswith("custom:"):
raise NatSpecSyntaxException(f"Unknown NatSpec field '@{tag}'", *err_args)
if tag in invalid_fields:
raise NatSpecSyntaxException(
diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py
index cbbd019337..2497928035 100644
--- a/vyper/ast/nodes.py
+++ b/vyper/ast/nodes.py
@@ -1,10 +1,12 @@
import ast as python_ast
+import contextlib
import copy
import decimal
import operator
import sys
from typing import Any, Optional, Union
+from vyper.ast.metadata import NodeMetadata
from vyper.compiler.settings import VYPER_ERROR_CONTEXT_LINES, VYPER_ERROR_LINE_NUMBERS
from vyper.exceptions import (
ArgumentException,
@@ -254,7 +256,7 @@ def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict):
"""
self.set_parent(parent)
self._children: set = set()
- self._metadata: dict = {}
+ self._metadata: NodeMetadata = NodeMetadata()
for field_name in NODE_SRC_ATTRIBUTES:
# when a source offset is not available, use the parent's source offset
@@ -337,7 +339,7 @@ def __hash__(self):
def __eq__(self, other):
if not isinstance(other, type(self)):
return False
- if other.node_id != self.node_id:
+ if getattr(other, "node_id", None) != getattr(self, "node_id", None):
return False
for field_name in (i for i in self.get_fields() if i not in VyperNode.__slots__):
if getattr(self, field_name, None) != getattr(other, field_name, None):
@@ -663,6 +665,19 @@ def remove_from_body(self, node: VyperNode) -> None:
self.body.remove(node)
self._children.remove(node)
+ @contextlib.contextmanager
+ def namespace(self):
+ from vyper.semantics.namespace import get_namespace, override_global_namespace
+
+ # kludge implementation for backwards compatibility.
+ # TODO: replace with type_from_ast
+ try:
+ ns = self._metadata["namespace"]
+ except AttributeError:
+ ns = get_namespace()
+ with override_global_namespace(ns):
+ yield
+
class FunctionDef(TopLevel):
__slots__ = ("args", "returns", "decorator_list", "pos")
@@ -684,7 +699,7 @@ class DocStr(VyperNode):
class arguments(VyperNode):
__slots__ = ("args", "defaults", "default")
- _only_empty_fields = ("vararg", "kwonlyargs", "kwarg", "kw_defaults")
+ _only_empty_fields = ("posonlyargs", "vararg", "kwonlyargs", "kwarg", "kw_defaults")
class arg(VyperNode):
@@ -960,6 +975,12 @@ def evaluate(self) -> ExprNode:
if not isinstance(left, (Int, Decimal)):
raise UnfoldableNode("Node contains invalid field(s) for evaluation")
+ # this validation is performed to prevent the compiler from hanging
+ # on very large shifts and improve the error message for negative
+ # values.
+ if isinstance(self.op, (LShift, RShift)) and not (0 <= right.value <= 256):
+ raise InvalidLiteral("Shift bits must be between 0 and 256", right)
+
value = self.op._op(left.value, right.value)
_validate_numeric_bounds(self, value)
return type(left).from_node(self, value=value)
@@ -1072,6 +1093,20 @@ class BitXor(Operator):
_op = operator.xor
+class LShift(Operator):
+ __slots__ = ()
+ _description = "bitwise left shift"
+ _pretty = "<<"
+ _op = operator.lshift
+
+
+class RShift(Operator):
+ __slots__ = ()
+ _description = "bitwise right shift"
+ _pretty = ">>"
+ _op = operator.rshift
+
+
class BoolOp(ExprNode):
__slots__ = ("op", "values")
@@ -1309,7 +1344,15 @@ class VariableDecl(VyperNode):
If true, indicates that the variable is an immutable variable.
"""
- __slots__ = ("target", "annotation", "value", "is_constant", "is_public", "is_immutable")
+ __slots__ = (
+ "target",
+ "annotation",
+ "value",
+ "is_constant",
+ "is_public",
+ "is_immutable",
+ "is_transient",
+ )
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -1317,6 +1360,7 @@ def __init__(self, *args, **kwargs):
self.is_constant = False
self.is_public = False
self.is_immutable = False
+ self.is_transient = False
def _check_args(annotation, call_name):
# do the same thing as `validate_call_args`
@@ -1334,9 +1378,10 @@ def _check_args(annotation, call_name):
# unwrap one layer
self.annotation = self.annotation.args[0]
- if self.annotation.get("func.id") in ("immutable", "constant"):
- _check_args(self.annotation, self.annotation.func.id)
- setattr(self, f"is_{self.annotation.func.id}", True)
+ func_id = self.annotation.get("func.id")
+ if func_id in ("immutable", "constant", "transient"):
+ _check_args(self.annotation, func_id)
+ setattr(self, f"is_{func_id}", True)
# unwrap one layer
self.annotation = self.annotation.args[0]
@@ -1409,6 +1454,10 @@ class If(Stmt):
__slots__ = ("test", "body", "orelse")
+class IfExp(ExprNode):
+ __slots__ = ("test", "body", "orelse")
+
+
class For(Stmt):
__slots__ = ("iter", "target", "body")
_only_empty_fields = ("orelse",)
diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi
index 942640b6e2..0d59a2fa63 100644
--- a/vyper/ast/nodes.pyi
+++ b/vyper/ast/nodes.pyi
@@ -4,6 +4,7 @@ from typing import Any, Optional, Sequence, Type, Union
from .natspec import parse_natspec as parse_natspec
from .utils import ast_to_dict as ast_to_dict
from .utils import parse_to_ast as parse_to_ast
+from .utils import parse_to_ast_with_settings as parse_to_ast_with_settings
NODE_BASE_ATTRIBUTES: Any
NODE_SRC_ATTRIBUTES: Any
@@ -61,6 +62,7 @@ class Module(TopLevel):
def replace_in_tree(self, old_node: VyperNode, new_node: VyperNode) -> None: ...
def add_to_body(self, node: VyperNode) -> None: ...
def remove_from_body(self, node: VyperNode) -> None: ...
+ def namespace(self) -> Any: ... # context manager
class FunctionDef(TopLevel):
args: arguments = ...
@@ -155,6 +157,8 @@ class Mult(VyperNode): ...
class Div(VyperNode): ...
class Mod(VyperNode): ...
class Pow(VyperNode): ...
+class LShift(VyperNode): ...
+class RShift(VyperNode): ...
class BitAnd(VyperNode): ...
class BitOr(VyperNode): ...
class BitXor(VyperNode): ...
@@ -236,6 +240,11 @@ class If(VyperNode):
body: list = ...
orelse: list = ...
+class IfExp(ExprNode):
+ test: ExprNode = ...
+ body: ExprNode = ...
+ orelse: ExprNode = ...
+
class For(VyperNode): ...
class Break(VyperNode): ...
class Continue(VyperNode): ...
diff --git a/vyper/ast/pre_parser.py b/vyper/ast/pre_parser.py
index f29150a5d3..0ead889787 100644
--- a/vyper/ast/pre_parser.py
+++ b/vyper/ast/pre_parser.py
@@ -1,27 +1,16 @@
import io
import re
from tokenize import COMMENT, NAME, OP, TokenError, TokenInfo, tokenize, untokenize
-from typing import Tuple
-from semantic_version import NpmSpec, Version
+from packaging.specifiers import InvalidSpecifier, SpecifierSet
-from vyper.exceptions import SyntaxException, VersionException
-from vyper.typing import ModificationOffsets, ParserPosition
-
-VERSION_ALPHA_RE = re.compile(r"(?<=\d)a(?=\d)") # 0.1.0a17
-VERSION_BETA_RE = re.compile(r"(?<=\d)b(?=\d)") # 0.1.0b17
-VERSION_RC_RE = re.compile(r"(?<=\d)rc(?=\d)") # 0.1.0rc17
-
-
-def _convert_version_str(version_str: str) -> str:
- """
- Convert loose version (0.1.0b17) to strict version (0.1.0-beta.17)
- """
- version_str = re.sub(VERSION_ALPHA_RE, "-alpha.", version_str) # 0.1.0-alpha.17
- version_str = re.sub(VERSION_BETA_RE, "-beta.", version_str) # 0.1.0-beta.17
- version_str = re.sub(VERSION_RC_RE, "-rc.", version_str) # 0.1.0-rc.17
+from vyper.compiler.settings import OptimizationLevel, Settings
- return version_str
+# seems a bit early to be importing this but we want it to validate the
+# evm-version pragma
+from vyper.evm.opcodes import EVM_VERSIONS
+from vyper.exceptions import StructureException, SyntaxException, VersionException
+from vyper.typing import ModificationOffsets, ParserPosition
def validate_version_pragma(version_str: str, start: ParserPosition) -> None:
@@ -30,31 +19,26 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None:
"""
from vyper import __version__
- # NOTE: should be `x.y.z.*`
- installed_version = ".".join(__version__.split(".")[:3])
-
- version_arr = version_str.split("@version")
-
- raw_file_version = version_arr[1].strip()
- strict_file_version = _convert_version_str(raw_file_version)
- strict_compiler_version = Version(_convert_version_str(installed_version))
-
- if len(strict_file_version) == 0:
+ if len(version_str) == 0:
raise VersionException("Version specification cannot be empty", start)
+ # X.Y.Z or vX.Y.Z => ==X.Y.Z, ==vX.Y.Z
+ if re.match("[v0-9]", version_str):
+ version_str = "==" + version_str
+ # convert npm to pep440
+ version_str = re.sub("^\\^", "~=", version_str)
+
try:
- npm_spec = NpmSpec(strict_file_version)
- except ValueError:
+ spec = SpecifierSet(version_str)
+ except InvalidSpecifier:
raise VersionException(
- f'Version specification "{raw_file_version}" is not a valid NPM semantic '
- f"version specification",
- start,
+ f'Version specification "{version_str}" is not a valid PEP440 specifier', start
)
- if not npm_spec.match(strict_compiler_version):
+ if not spec.contains(__version__, prereleases=True):
raise VersionException(
- f'Version specification "{raw_file_version}" is not compatible '
- f'with compiler version "{installed_version}"',
+ f'Version specification "{version_str}" is not compatible '
+ f'with compiler version "{__version__}"',
start,
)
@@ -66,12 +50,12 @@ def validate_version_pragma(version_str: str, start: ParserPosition) -> None:
VYPER_EXPRESSION_TYPES = {"log"}
-def pre_parse(code: str) -> Tuple[ModificationOffsets, str]:
+def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, str]:
"""
Re-formats a vyper source string into a python source string and performs
some validation. More specifically,
- * Translates "interface", "struct" and "event" keywords into python "class" keyword
+ * Translates "interface", "struct", "enum, and "event" keywords into python "class" keyword
* Validates "@version" pragma against current compiler version
* Prevents direct use of python "class" keyword
* Prevents use of python semi-colon statement separator
@@ -93,6 +77,7 @@ def pre_parse(code: str) -> Tuple[ModificationOffsets, str]:
"""
result = []
modification_offsets: ModificationOffsets = {}
+ settings = Settings()
try:
code_bytes = code.encode("utf-8")
@@ -108,8 +93,39 @@ def pre_parse(code: str) -> Tuple[ModificationOffsets, str]:
end = token.end
line = token.line
- if typ == COMMENT and "@version" in string:
- validate_version_pragma(string[1:], start)
+ if typ == COMMENT:
+ contents = string[1:].strip()
+ if contents.startswith("@version"):
+ if settings.compiler_version is not None:
+ raise StructureException("compiler version specified twice!", start)
+ compiler_version = contents.removeprefix("@version ").strip()
+ validate_version_pragma(compiler_version, start)
+ settings.compiler_version = compiler_version
+
+ if contents.startswith("pragma "):
+ pragma = contents.removeprefix("pragma ").strip()
+ if pragma.startswith("version "):
+ if settings.compiler_version is not None:
+ raise StructureException("pragma version specified twice!", start)
+ compiler_version = pragma.removeprefix("version ").strip()
+ validate_version_pragma(compiler_version, start)
+ settings.compiler_version = compiler_version
+
+ if pragma.startswith("optimize "):
+ if settings.optimize is not None:
+ raise StructureException("pragma optimize specified twice!", start)
+ try:
+ mode = pragma.removeprefix("optimize").strip()
+ settings.optimize = OptimizationLevel.from_string(mode)
+ except ValueError:
+ raise StructureException(f"Invalid optimization mode `{mode}`", start)
+ if pragma.startswith("evm-version "):
+ if settings.evm_version is not None:
+ raise StructureException("pragma evm-version specified twice!", start)
+ evm_version = pragma.removeprefix("evm-version").strip()
+ if evm_version not in EVM_VERSIONS:
+ raise StructureException("Invalid evm version: `{evm_version}`", start)
+ settings.evm_version = evm_version
if typ == NAME and string in ("class", "yield"):
raise SyntaxException(
@@ -130,4 +146,4 @@ def pre_parse(code: str) -> Tuple[ModificationOffsets, str]:
except TokenError as e:
raise SyntaxException(e.args[0], code, e.args[1][0], e.args[1][1]) from e
- return modification_offsets, untokenize(result).decode("utf-8")
+ return settings, modification_offsets, untokenize(result).decode("utf-8")
diff --git a/vyper/ast/signatures/__init__.py b/vyper/ast/signatures/__init__.py
deleted file mode 100644
index aa9a5c66c6..0000000000
--- a/vyper/ast/signatures/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .function_signature import FrameInfo, FunctionSignature
diff --git a/vyper/ast/signatures/function_signature.py b/vyper/ast/signatures/function_signature.py
deleted file mode 100644
index 3f59b95f97..0000000000
--- a/vyper/ast/signatures/function_signature.py
+++ /dev/null
@@ -1,194 +0,0 @@
-from dataclasses import dataclass
-from functools import cached_property
-from typing import Dict, Optional, Tuple
-
-from vyper import ast as vy_ast
-from vyper.exceptions import CompilerPanic, StructureException
-from vyper.semantics.types import VyperType
-from vyper.utils import MemoryPositions, mkalphanum
-
-# dict from function names to signatures
-FunctionSignatures = Dict[str, "FunctionSignature"]
-
-
-@dataclass
-class FunctionArg:
- name: str
- typ: VyperType
- ast_source: vy_ast.VyperNode
-
-
-@dataclass
-class FrameInfo:
- frame_start: int
- frame_size: int
- frame_vars: Dict[str, Tuple[int, VyperType]]
-
- @property
- def mem_used(self):
- return self.frame_size + MemoryPositions.RESERVED_MEMORY
-
-
-# Function signature object
-# TODO: merge with ContractFunction type
-class FunctionSignature:
- def __init__(
- self,
- name,
- args,
- return_type,
- mutability,
- internal,
- nonreentrant_key,
- func_ast_code,
- is_from_json,
- ):
- self.name = name
- self.args = args
- self.return_type = return_type
- self.mutability = mutability
- self.internal = internal
- self.gas_estimate = None
- self.nonreentrant_key = nonreentrant_key
- self.func_ast_code = func_ast_code
- self.is_from_json = is_from_json
-
- self.set_default_args()
-
- # frame info is metadata that will be generated during codegen.
- self.frame_info: Optional[FrameInfo] = None
-
- def __str__(self):
- input_name = "def " + self.name + "(" + ",".join([str(arg.typ) for arg in self.args]) + ")"
- if self.return_type:
- return input_name + " -> " + str(self.return_type) + ":"
- return input_name + ":"
-
- def set_frame_info(self, frame_info):
- if self.frame_info is not None:
- raise CompilerPanic("sig.frame_info already set!")
- self.frame_info = frame_info
-
- @cached_property
- def _ir_identifier(self) -> str:
- # we could do a bit better than this but it just needs to be unique
- visibility = "internal" if self.internal else "external"
- argz = ",".join([str(arg.typ) for arg in self.args])
- ret = f"{visibility} {self.name} ({argz})"
- return mkalphanum(ret)
-
- # calculate the abi signature for a given set of kwargs
- def abi_signature_for_kwargs(self, kwargs):
- args = self.base_args + kwargs
- return self.name + "(" + ",".join([arg.typ.abi_type.selector_name() for arg in args]) + ")"
-
- @cached_property
- def base_signature(self):
- return self.abi_signature_for_kwargs([])
-
- @property
- # common entry point for external function with kwargs
- def external_function_base_entry_label(self):
- assert not self.internal
-
- return self._ir_identifier + "_common"
-
- @property
- def internal_function_label(self):
- assert self.internal, "why are you doing this"
-
- return self._ir_identifier
-
- @property
- def exit_sequence_label(self):
- return self._ir_identifier + "_cleanup"
-
- def set_default_args(self):
- """Split base from kwargs and set member data structures"""
-
- args = self.func_ast_code.args
-
- defaults = getattr(args, "defaults", [])
- num_base_args = len(args.args) - len(defaults)
-
- self.base_args = self.args[:num_base_args]
- self.default_args = self.args[num_base_args:]
-
- # Keep all the value to assign to default parameters.
- self.default_values = dict(zip([arg.name for arg in self.default_args], defaults))
-
- # Get a signature from a function definition
- @classmethod
- def from_definition(
- cls,
- func_ast, # vy_ast.FunctionDef
- global_ctx,
- interface_def=False,
- constant_override=False, # CMC 20210907 what does this do?
- is_from_json=False,
- ):
- name = func_ast.name
-
- args = []
- for arg in func_ast.args.args:
- argname = arg.arg
- argtyp = global_ctx.parse_type(arg.annotation)
-
- args.append(FunctionArg(argname, argtyp, arg))
-
- mutability = "nonpayable" # Assume nonpayable by default
- nonreentrant_key = None
- is_internal = None
-
- # Update function properties from decorators
- # NOTE: Can't import enums here because of circular import
- for dec in func_ast.decorator_list:
- if isinstance(dec, vy_ast.Name) and dec.id in ("payable", "view", "pure"):
- mutability = dec.id
- elif isinstance(dec, vy_ast.Name) and dec.id == "internal":
- is_internal = True
- elif isinstance(dec, vy_ast.Name) and dec.id == "external":
- is_internal = False
- elif isinstance(dec, vy_ast.Call) and dec.func.id == "nonreentrant":
- nonreentrant_key = dec.args[0].s
-
- if constant_override:
- # In case this override is abused, match previous behavior
- if mutability == "payable":
- raise StructureException(f"Function {name} cannot be both constant and payable.")
- mutability = "view"
-
- # Determine the return type and whether or not it's constant. Expects something
- # of the form:
- # def foo(): ...
- # def foo() -> int128: ...
- # If there is no return type, ie. it's of the form def foo(): ...
- # and NOT def foo() -> type: ..., then it's null
- return_type = None
- if func_ast.returns:
- return_type = global_ctx.parse_type(func_ast.returns)
- # sanity check: Output type must be canonicalizable
- assert return_type.abi_type.selector_name()
-
- return cls(
- name,
- args,
- return_type,
- mutability,
- is_internal,
- nonreentrant_key,
- func_ast,
- is_from_json,
- )
-
- @property
- def is_default_func(self):
- return self.name == "__default__"
-
- @property
- def is_init_func(self):
- return self.name == "__init__"
-
- @property
- def is_regular_function(self):
- return not self.is_default_func and not self.is_init_func
diff --git a/vyper/ast/utils.py b/vyper/ast/utils.py
index fc8aad227c..4e669385ab 100644
--- a/vyper/ast/utils.py
+++ b/vyper/ast/utils.py
@@ -1,18 +1,23 @@
import ast as python_ast
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
from vyper.ast import nodes as vy_ast
from vyper.ast.annotation import annotate_python_ast
from vyper.ast.pre_parser import pre_parse
+from vyper.compiler.settings import Settings
from vyper.exceptions import CompilerPanic, ParserException, SyntaxException
-def parse_to_ast(
+def parse_to_ast(*args: Any, **kwargs: Any) -> vy_ast.Module:
+ return parse_to_ast_with_settings(*args, **kwargs)[1]
+
+
+def parse_to_ast_with_settings(
source_code: str,
source_id: int = 0,
contract_name: Optional[str] = None,
add_fn_node: Optional[str] = None,
-) -> vy_ast.Module:
+) -> tuple[Settings, vy_ast.Module]:
"""
Parses a Vyper source string and generates basic Vyper AST nodes.
@@ -34,7 +39,7 @@ def parse_to_ast(
"""
if "\x00" in source_code:
raise ParserException("No null bytes (\\x00) allowed in the source code.")
- class_types, reformatted_code = pre_parse(source_code)
+ settings, class_types, reformatted_code = pre_parse(source_code)
try:
py_ast = python_ast.parse(reformatted_code)
except SyntaxError as e:
@@ -51,7 +56,9 @@ def parse_to_ast(
annotate_python_ast(py_ast, source_code, class_types, source_id, contract_name)
# Convert to Vyper AST.
- return vy_ast.get_node(py_ast) # type: ignore
+ module = vy_ast.get_node(py_ast)
+ assert isinstance(module, vy_ast.Module) # mypy hint
+ return settings, module
def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]:
diff --git a/vyper/ast/validation.py b/vyper/ast/validation.py
index 7742d60c01..36a6a0484c 100644
--- a/vyper/ast/validation.py
+++ b/vyper/ast/validation.py
@@ -48,7 +48,7 @@ def validate_call_args(
arg_count = (arg_count[0], 2**64)
if arg_count[0] == arg_count[1]:
- arg_count == arg_count[0]
+ arg_count = arg_count[0]
if isinstance(node.func, vy_ast.Attribute):
msg = f" for call to '{node.func.attr}'"
diff --git a/vyper/builtins/_convert.py b/vyper/builtins/_convert.py
index 407a32f3e9..e09f5f3174 100644
--- a/vyper/builtins/_convert.py
+++ b/vyper/builtins/_convert.py
@@ -267,7 +267,7 @@ def _literal_decimal(expr, arg_typ, out_typ):
def to_bool(expr, arg, out_typ):
_check_bytes(expr, arg, out_typ, 32) # should we restrict to Bytes[1]?
- if isinstance(arg.typ, BytesT):
+ if isinstance(arg.typ, _BytestringT):
# no clamp. checks for any nonzero bytes.
arg = _bytes_to_num(arg, out_typ, signed=False)
@@ -453,13 +453,13 @@ def to_enum(expr, arg, out_typ):
def convert(expr, context):
- if len(expr.args) != 2:
- raise StructureException("The convert function expects two parameters.", expr)
+ assert len(expr.args) == 2, "bad typecheck: convert"
arg_ast = expr.args[0]
arg = Expr(arg_ast, context).ir_node
original_arg = arg
- out_typ = context.parse_type(expr.args[1])
+
+ out_typ = expr.args[1]._metadata["type"].typedef
if arg.typ._is_prim_word:
arg = unwrap_location(arg)
diff --git a/vyper/builtins/_utils.py b/vyper/builtins/_utils.py
index 77185f6e53..afc0987b6d 100644
--- a/vyper/builtins/_utils.py
+++ b/vyper/builtins/_utils.py
@@ -23,7 +23,7 @@ def generate_inline_function(code, variables, variables_2, memory_allocator):
# `ContractFunctionT` type to rely on the annotation visitors in semantics
# module.
ast_code.body[0]._metadata["type"] = ContractFunctionT(
- "sqrt_builtin", {}, 0, 0, None, FunctionVisibility.INTERNAL, StateMutability.NONPAYABLE
+ "sqrt_builtin", [], [], None, FunctionVisibility.INTERNAL, StateMutability.NONPAYABLE
)
# The FunctionNodeVisitor's constructor performs semantic checks
# annotate the AST as side effects.
diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py
index f81fb20a64..95759372a6 100644
--- a/vyper/builtins/functions.py
+++ b/vyper/builtins/functions.py
@@ -5,7 +5,6 @@
from vyper import ast as vy_ast
from vyper.abi_types import ABI_Tuple
-from vyper.address_space import MEMORY, STORAGE
from vyper.ast.validation import validate_call_args
from vyper.codegen.abi_encoder import abi_encode
from vyper.codegen.context import Context, VariableRecord
@@ -22,13 +21,14 @@
clamp_basetype,
clamp_nonzero,
copy_bytes,
+ dummy_node_for_type,
ensure_in_memory,
eval_once_check,
eval_seq,
get_bytearray_length,
- get_element_ptr,
get_type_for_exact_size,
ir_tuple_from_args,
+ make_setter,
needs_external_call_wrap,
promote_signed_int,
sar,
@@ -37,8 +37,9 @@
unwrap_location,
)
from vyper.codegen.expr import Expr
-from vyper.codegen.ir_node import Encoding
+from vyper.codegen.ir_node import Encoding, scope_multi
from vyper.codegen.keccak256_helper import keccak256_helper
+from vyper.evm.address_space import MEMORY, STORAGE
from vyper.exceptions import (
ArgumentException,
CompilerPanic,
@@ -148,15 +149,18 @@ def evaluate(self, node):
@process_inputs
def build_IR(self, expr, args, kwargs, context):
- return IRnode.from_list(
- [
- "if",
- ["slt", args[0], 0],
- ["sdiv", ["sub", args[0], DECIMAL_DIVISOR - 1], DECIMAL_DIVISOR],
- ["sdiv", args[0], DECIMAL_DIVISOR],
- ],
- typ=INT256_T,
- )
+ arg = args[0]
+ with arg.cache_when_complex("arg") as (b1, arg):
+ ret = IRnode.from_list(
+ [
+ "if",
+ ["slt", arg, 0],
+ ["sdiv", ["sub", arg, DECIMAL_DIVISOR - 1], DECIMAL_DIVISOR],
+ ["sdiv", arg, DECIMAL_DIVISOR],
+ ],
+ typ=INT256_T,
+ )
+ return b1.resolve(ret)
class Ceil(BuiltinFunction):
@@ -175,15 +179,18 @@ def evaluate(self, node):
@process_inputs
def build_IR(self, expr, args, kwargs, context):
- return IRnode.from_list(
- [
- "if",
- ["slt", args[0], 0],
- ["sdiv", args[0], DECIMAL_DIVISOR],
- ["sdiv", ["add", args[0], DECIMAL_DIVISOR - 1], DECIMAL_DIVISOR],
- ],
- typ=INT256_T,
- )
+ arg = args[0]
+ with arg.cache_when_complex("arg") as (b1, arg):
+ ret = IRnode.from_list(
+ [
+ "if",
+ ["slt", arg, 0],
+ ["sdiv", arg, DECIMAL_DIVISOR],
+ ["sdiv", ["add", arg, DECIMAL_DIVISOR - 1], DECIMAL_DIVISOR],
+ ],
+ typ=INT256_T,
+ )
+ return b1.resolve(ret)
class Convert(BuiltinFunction):
@@ -611,7 +618,7 @@ def infer_arg_types(self, node):
@process_inputs
def build_IR(self, expr, args, kwargs, context):
assert len(args) == 1
- return keccak256_helper(expr, args[0], context)
+ return keccak256_helper(args[0], context)
def _make_sha256_call(inp_start, inp_len, out_start, out_len):
@@ -758,87 +765,68 @@ def infer_arg_types(self, node):
@process_inputs
def build_IR(self, expr, args, kwargs, context):
- placeholder_node = IRnode.from_list(
- context.new_internal_variable(BytesT(128)), typ=BytesT(128), location=MEMORY
- )
+ input_buf = context.new_internal_variable(get_type_for_exact_size(128))
+ output_buf = context.new_internal_variable(get_type_for_exact_size(32))
return IRnode.from_list(
[
"seq",
- ["mstore", placeholder_node, args[0]],
- ["mstore", ["add", placeholder_node, 32], args[1]],
- ["mstore", ["add", placeholder_node, 64], args[2]],
- ["mstore", ["add", placeholder_node, 96], args[3]],
- [
- "pop",
- [
- "staticcall",
- ["gas"],
- 1,
- placeholder_node,
- 128,
- MemoryPositions.FREE_VAR_SPACE,
- 32,
- ],
- ],
- ["mload", MemoryPositions.FREE_VAR_SPACE],
+ # clear output memory first, ecrecover can return 0 bytes
+ ["mstore", output_buf, 0],
+ ["mstore", input_buf, args[0]],
+ ["mstore", input_buf + 32, args[1]],
+ ["mstore", input_buf + 64, args[2]],
+ ["mstore", input_buf + 96, args[3]],
+ ["staticcall", "gas", 1, input_buf, 128, output_buf, 32],
+ ["mload", output_buf],
],
typ=AddressT(),
)
-def _getelem(arg, ind):
- return unwrap_location(get_element_ptr(arg, IRnode.from_list(ind, typ=INT128_T)))
+class _ECArith(BuiltinFunction):
+ @process_inputs
+ def build_IR(self, expr, _args, kwargs, context):
+ args_tuple = ir_tuple_from_args(_args)
+ args_t = args_tuple.typ
+ input_buf = IRnode.from_list(
+ context.new_internal_variable(args_t), typ=args_t, location=MEMORY
+ )
+ ret_t = self._return_type
-class ECAdd(BuiltinFunction):
- _id = "ecadd"
- _inputs = [("a", SArrayT(UINT256_T, 2)), ("b", SArrayT(UINT256_T, 2))]
- _return_type = SArrayT(UINT256_T, 2)
+ ret = ["seq"]
+ ret.append(make_setter(input_buf, args_tuple))
- @process_inputs
- def build_IR(self, expr, args, kwargs, context):
- placeholder_node = IRnode.from_list(
- context.new_internal_variable(BytesT(128)), typ=BytesT(128), location=MEMORY
- )
- o = IRnode.from_list(
+ output_buf = context.new_internal_variable(ret_t)
+
+ args_ofst = input_buf
+ args_len = args_t.memory_bytes_required
+ out_ofst = output_buf
+ out_len = ret_t.memory_bytes_required
+
+ ret.append(
[
- "seq",
- ["mstore", placeholder_node, _getelem(args[0], 0)],
- ["mstore", ["add", placeholder_node, 32], _getelem(args[0], 1)],
- ["mstore", ["add", placeholder_node, 64], _getelem(args[1], 0)],
- ["mstore", ["add", placeholder_node, 96], _getelem(args[1], 1)],
- ["assert", ["staticcall", ["gas"], 6, placeholder_node, 128, placeholder_node, 64]],
- placeholder_node,
- ],
- typ=SArrayT(UINT256_T, 2),
- location=MEMORY,
+ "assert",
+ ["staticcall", ["gas"], self._precompile, args_ofst, args_len, out_ofst, out_len],
+ ]
)
- return o
+ ret.append(output_buf)
+
+ return IRnode.from_list(ret, typ=ret_t, location=MEMORY)
+
+
+class ECAdd(_ECArith):
+ _id = "ecadd"
+ _inputs = [("a", SArrayT(UINT256_T, 2)), ("b", SArrayT(UINT256_T, 2))]
+ _return_type = SArrayT(UINT256_T, 2)
+ _precompile = 0x6
-class ECMul(BuiltinFunction):
+class ECMul(_ECArith):
_id = "ecmul"
_inputs = [("point", SArrayT(UINT256_T, 2)), ("scalar", UINT256_T)]
_return_type = SArrayT(UINT256_T, 2)
-
- @process_inputs
- def build_IR(self, expr, args, kwargs, context):
- placeholder_node = IRnode.from_list(
- context.new_internal_variable(BytesT(128)), typ=BytesT(128), location=MEMORY
- )
- o = IRnode.from_list(
- [
- "seq",
- ["mstore", placeholder_node, _getelem(args[0], 0)],
- ["mstore", ["add", placeholder_node, 32], _getelem(args[0], 1)],
- ["mstore", ["add", placeholder_node, 64], args[1]],
- ["assert", ["staticcall", ["gas"], 7, placeholder_node, 96, placeholder_node, 64]],
- placeholder_node,
- ],
- typ=SArrayT(UINT256_T, 2),
- location=MEMORY,
- )
- return o
+ _precompile = 0x7
def _generic_element_getter(op):
@@ -1030,34 +1018,35 @@ def build_IR(self, expr, args, kwargs, context):
value = args[0]
denom_divisor = self.get_denomination(expr)
- if value.typ in (UINT256_T, UINT8_T):
- sub = [
- "with",
- "ans",
- ["mul", value, denom_divisor],
- [
- "seq",
+ with value.cache_when_complex("value") as (b1, value):
+ if value.typ in (UINT256_T, UINT8_T):
+ sub = [
+ "with",
+ "ans",
+ ["mul", value, denom_divisor],
[
- "assert",
- ["or", ["eq", ["div", "ans", value], denom_divisor], ["iszero", value]],
+ "seq",
+ [
+ "assert",
+ ["or", ["eq", ["div", "ans", value], denom_divisor], ["iszero", value]],
+ ],
+ "ans",
],
- "ans",
- ],
- ]
- elif value.typ == INT128_T:
- # signed types do not require bounds checks because the
- # largest possible converted value will not overflow 2**256
- sub = ["seq", ["assert", ["sgt", value, -1]], ["mul", value, denom_divisor]]
- elif value.typ == DecimalT():
- sub = [
- "seq",
- ["assert", ["sgt", value, -1]],
- ["div", ["mul", value, denom_divisor], DECIMAL_DIVISOR],
- ]
- else:
- raise CompilerPanic(f"Unexpected type: {value.typ}")
+ ]
+ elif value.typ == INT128_T:
+ # signed types do not require bounds checks because the
+ # largest possible converted value will not overflow 2**256
+ sub = ["seq", ["assert", ["sgt", value, -1]], ["mul", value, denom_divisor]]
+ elif value.typ == DecimalT():
+ sub = [
+ "seq",
+ ["assert", ["sgt", value, -1]],
+ ["div", ["mul", value, denom_divisor], DECIMAL_DIVISOR],
+ ]
+ else:
+ raise CompilerPanic(f"Unexpected type: {value.typ}")
- return IRnode.from_list(sub, typ=UINT256_T)
+ return IRnode.from_list(b1.resolve(sub), typ=UINT256_T)
zero_value = IRnode.from_list(0, typ=UINT256_T)
@@ -1086,7 +1075,7 @@ def fetch_call_return(self, node):
revert_on_failure = kwargz.get("revert_on_failure")
revert_on_failure = revert_on_failure.value if revert_on_failure is not None else True
- if outsize is None:
+ if outsize is None or outsize.value == 0:
if revert_on_failure:
return None
return BoolT()
@@ -1167,14 +1156,17 @@ def build_IR(self, expr, args, kwargs, context):
outsize,
]
- if delegate_call:
- call_op = ["delegatecall", gas, to, *common_call_args]
- elif static_call:
- call_op = ["staticcall", gas, to, *common_call_args]
- else:
- call_op = ["call", gas, to, value, *common_call_args]
+ gas, value = IRnode.from_list(gas), IRnode.from_list(value)
+ with scope_multi((to, value, gas), ("_to", "_value", "_gas")) as (b1, (to, value, gas)):
+ if delegate_call:
+ call_op = ["delegatecall", gas, to, *common_call_args]
+ elif static_call:
+ call_op = ["staticcall", gas, to, *common_call_args]
+ else:
+ call_op = ["call", gas, to, value, *common_call_args]
- call_ir += [call_op]
+ call_ir += [call_op]
+ call_ir = b1.resolve(call_ir)
# build sequence IR
if outsize:
@@ -1188,7 +1180,9 @@ def build_IR(self, expr, args, kwargs, context):
if revert_on_failure:
typ = bytes_ty
+ # check the call success flag, and store returndata in memory
ret_ir = ["seq", check_external_call(call_ir), store_output_size]
+ return IRnode.from_list(ret_ir, typ=typ, location=MEMORY)
else:
typ = TupleT([bool_ty, bytes_ty])
ret_ir = [
@@ -1198,16 +1192,22 @@ def build_IR(self, expr, args, kwargs, context):
IRnode.from_list(call_ir, typ=bool_ty),
IRnode.from_list(store_output_size, typ=bytes_ty, location=MEMORY),
]
+ # return an IR tuple of call success flag and returndata pointer
+ return IRnode.from_list(ret_ir, typ=typ)
+
+ # max_outsize is 0.
+
+ if not revert_on_failure:
+ # return call flag as stack item
+ typ = bool_ty
+ return IRnode.from_list(call_ir, typ=typ)
else:
- if revert_on_failure:
- typ = None
- ret_ir = check_external_call(call_ir)
- else:
- typ = bool_ty
- ret_ir = call_ir
+ # check the call success flag and don't return anything
+ ret_ir = check_external_call(call_ir)
+ return IRnode.from_list(ret_ir, typ=None)
- return IRnode.from_list(ret_ir, typ=typ, location=MEMORY)
+ raise CompilerPanic("unreachable!")
class Send(BuiltinFunction):
@@ -1222,7 +1222,9 @@ def build_IR(self, expr, args, kwargs, context):
to, value = args
gas = kwargs["gas"]
context.check_is_not_constant("send ether", expr)
- return IRnode.from_list(["assert", ["call", gas, to, value, 0, 0, 0, 0]])
+ return IRnode.from_list(
+ ["assert", ["call", gas, to, value, 0, 0, 0, 0]], error_msg="send failed"
+ )
class SelfDestruct(BuiltinFunction):
@@ -1230,9 +1232,14 @@ class SelfDestruct(BuiltinFunction):
_inputs = [("to", AddressT())]
_return_type = None
_is_terminus = True
+ _warned = False
@process_inputs
def build_IR(self, expr, args, kwargs, context):
+ if not self._warned:
+ vyper_warn("`selfdestruct` is deprecated! The opcode is no longer recommended for use.")
+ self._warned = True
+
context.check_is_not_constant("selfdestruct", expr)
return IRnode.from_list(
["seq", eval_once_check(_freshname("selfdestruct")), ["selfdestruct", args[0]]]
@@ -1338,7 +1345,7 @@ def evaluate(self, node):
validate_call_args(node, 2)
for arg in node.args:
- if not isinstance(arg, vy_ast.Num):
+ if not isinstance(arg, vy_ast.Int):
raise UnfoldableNode
if arg.value < 0 or arg.value >= 2**256:
raise InvalidLiteral("Value out of range for uint256", arg)
@@ -1364,7 +1371,7 @@ def evaluate(self, node):
validate_call_args(node, 2)
for arg in node.args:
- if not isinstance(arg, vy_ast.Num):
+ if not isinstance(arg, vy_ast.Int):
raise UnfoldableNode
if arg.value < 0 or arg.value >= 2**256:
raise InvalidLiteral("Value out of range for uint256", arg)
@@ -1390,7 +1397,7 @@ def evaluate(self, node):
validate_call_args(node, 2)
for arg in node.args:
- if not isinstance(arg, vy_ast.Num):
+ if not isinstance(arg, vy_ast.Int):
raise UnfoldableNode
if arg.value < 0 or arg.value >= 2**256:
raise InvalidLiteral("Value out of range for uint256", arg)
@@ -1415,7 +1422,7 @@ def evaluate(self, node):
self.__class__._warned = True
validate_call_args(node, 1)
- if not isinstance(node.args[0], vy_ast.Num):
+ if not isinstance(node.args[0], vy_ast.Int):
raise UnfoldableNode
value = node.args[0].value
@@ -1432,12 +1439,17 @@ def build_IR(self, expr, args, kwargs, context):
class Shift(BuiltinFunction):
_id = "shift"
- _inputs = [("x", (UINT256_T, INT256_T)), ("_shift", IntegerT.any())]
+ _inputs = [("x", (UINT256_T, INT256_T)), ("_shift_bits", IntegerT.any())]
_return_type = UINT256_T
+ _warned = False
def evaluate(self, node):
+ if not self.__class__._warned:
+ vyper_warn("`shift()` is deprecated! Please use the << or >> operator instead.")
+ self.__class__._warned = True
+
validate_call_args(node, 2)
- if [i for i in node.args if not isinstance(i, vy_ast.Num)]:
+ if [i for i in node.args if not isinstance(i, vy_ast.Int)]:
raise UnfoldableNode
value, shift = [i.value for i in node.args]
if value < 0 or value >= 2**256:
@@ -1485,10 +1497,10 @@ class _AddMulMod(BuiltinFunction):
def evaluate(self, node):
validate_call_args(node, 3)
- if isinstance(node.args[2], vy_ast.Num) and node.args[2].value == 0:
+ if isinstance(node.args[2], vy_ast.Int) and node.args[2].value == 0:
raise ZeroDivisionException("Modulo by 0", node.args[2])
for arg in node.args:
- if not isinstance(arg, vy_ast.Num):
+ if not isinstance(arg, vy_ast.Int):
raise UnfoldableNode
if arg.value < 0 or arg.value >= 2**256:
raise InvalidLiteral("Value out of range for uint256", arg)
@@ -1498,9 +1510,14 @@ def evaluate(self, node):
@process_inputs
def build_IR(self, expr, args, kwargs, context):
- return IRnode.from_list(
- ["seq", ["assert", args[2]], [self._opcode, args[0], args[1], args[2]]], typ=UINT256_T
- )
+ x, y, z = args
+ with x.cache_when_complex("x") as (b1, x):
+ with y.cache_when_complex("y") as (b2, y):
+ with z.cache_when_complex("z") as (b3, z):
+ ret = IRnode.from_list(
+ ["seq", ["assert", z], [self._opcode, x, y, z]], typ=UINT256_T
+ )
+ return b1.resolve(b2.resolve(b3.resolve(ret)))
class AddMod(_AddMulMod):
@@ -1576,13 +1593,15 @@ def build_IR(self, expr, context):
# CREATE* functions
+CREATE2_SENTINEL = dummy_node_for_type(BYTES32_T)
+
# create helper functions
# generates CREATE op sequence + zero check for result
-def _create_ir(value, buf, length, salt=None, checked=True):
+def _create_ir(value, buf, length, salt, checked=True):
args = [value, buf, length]
create_op = "create"
- if salt is not None:
+ if salt is not CREATE2_SENTINEL:
create_op = "create2"
args.append(salt)
@@ -1593,7 +1612,9 @@ def _create_ir(value, buf, length, salt=None, checked=True):
if not checked:
return ret
- return clamp_nonzero(ret)
+ ret = clamp_nonzero(ret)
+ ret.set_error_msg(f"{create_op} failed")
+ return ret
# calculate the gas used by create for a given number of bytes
@@ -1698,8 +1719,9 @@ def build_IR(self, expr, args, kwargs, context):
context.check_is_not_constant("use {self._id}", expr)
should_use_create2 = "salt" in [kwarg.arg for kwarg in expr.keywords]
+
if not should_use_create2:
- kwargs["salt"] = None
+ kwargs["salt"] = CREATE2_SENTINEL
ir_builder = self._build_create_IR(expr, args, context, **kwargs)
@@ -1779,17 +1801,23 @@ def _add_gas_estimate(self, args, should_use_create2):
def _build_create_IR(self, expr, args, context, value, salt):
target = args[0]
- with target.cache_when_complex("create_target") as (b1, target):
+ # something we can pass to scope_multi
+ with scope_multi(
+ (target, value, salt), ("create_target", "create_value", "create_salt")
+ ) as (b1, (target, value, salt)):
codesize = IRnode.from_list(["extcodesize", target])
msize = IRnode.from_list(["msize"])
- with codesize.cache_when_complex("target_codesize") as (
+ with scope_multi((codesize, msize), ("target_codesize", "mem_ofst")) as (
b2,
- codesize,
- ), msize.cache_when_complex("mem_ofst") as (b3, mem_ofst):
+ (codesize, mem_ofst),
+ ):
ir = ["seq"]
# make sure there is actually code at the target
- ir.append(["assert", codesize])
+ check_codesize = ["assert", codesize]
+ ir.append(
+ IRnode.from_list(check_codesize, error_msg="empty target (create_copy_of)")
+ )
# store the preamble at msize + 22 (zero padding)
preamble, preamble_len = _create_preamble(codesize)
@@ -1806,7 +1834,7 @@ def _build_create_IR(self, expr, args, context, value, salt):
ir.append(_create_ir(value, buf, buf_len, salt))
- return b1.resolve(b2.resolve(b3.resolve(ir)))
+ return b1.resolve(b2.resolve(ir))
class CreateFromBlueprint(_CreateBase):
@@ -1859,17 +1887,18 @@ def _build_create_IR(self, expr, args, context, value, salt, code_offset, raw_ar
# (since the abi encoder could write to fresh memory).
# it would be good to not require the memory copy, but need
# to evaluate memory safety.
- with target.cache_when_complex("create_target") as (b1, target), argslen.cache_when_complex(
- "encoded_args_len"
- ) as (b2, encoded_args_len), code_offset.cache_when_complex("code_ofst") as (b3, codeofst):
- codesize = IRnode.from_list(["sub", ["extcodesize", target], codeofst])
+ with scope_multi(
+ (target, value, salt, argslen, code_offset),
+ ("create_target", "create_value", "create_salt", "encoded_args_len", "code_offset"),
+ ) as (b1, (target, value, salt, encoded_args_len, code_offset)):
+ codesize = IRnode.from_list(["sub", ["extcodesize", target], code_offset])
# copy code to memory starting from msize. we are clobbering
# unused memory so it's safe.
msize = IRnode.from_list(["msize"], location=MEMORY)
- with codesize.cache_when_complex("target_codesize") as (
- b4,
- codesize,
- ), msize.cache_when_complex("mem_ofst") as (b5, mem_ofst):
+ with scope_multi((codesize, msize), ("target_codesize", "mem_ofst")) as (
+ b2,
+ (codesize, mem_ofst),
+ ):
ir = ["seq"]
# make sure there is code at the target, and that
@@ -1879,12 +1908,17 @@ def _build_create_IR(self, expr, args, context, value, salt, code_offset, raw_ar
# (code_ofst == (extcodesize target) would be empty
# initcode, which we disallow for hygiene reasons -
# same as `create_copy_of` on an empty target).
- ir.append(["assert", ["sgt", codesize, 0]])
+ check_codesize = ["assert", ["sgt", codesize, 0]]
+ ir.append(
+ IRnode.from_list(
+ check_codesize, error_msg="empty target (create_from_blueprint)"
+ )
+ )
# copy the target code into memory.
# layout starting from mem_ofst:
# 00...00 (22 0's) | preamble | bytecode
- ir.append(["extcodecopy", target, mem_ofst, codeofst, codesize])
+ ir.append(["extcodecopy", target, mem_ofst, code_offset, codesize])
ir.append(copy_bytes(add_ofst(mem_ofst, codesize), argbuf, encoded_args_len, bufsz))
@@ -1899,7 +1933,7 @@ def _build_create_IR(self, expr, args, context, value, salt, code_offset, raw_ar
ir.append(_create_ir(value, mem_ofst, length, salt))
- return b1.resolve(b2.resolve(b3.resolve(b4.resolve(b5.resolve(ir)))))
+ return b1.resolve(b2.resolve(ir))
class _UnsafeMath(BuiltinFunction):
diff --git a/vyper/builtins/interfaces/ERC721.py b/vyper/builtins/interfaces/ERC721.py
index 29ef5f4c26..8dea4e4976 100644
--- a/vyper/builtins/interfaces/ERC721.py
+++ b/vyper/builtins/interfaces/ERC721.py
@@ -2,18 +2,18 @@
# Events
event Transfer:
- _from: address
- _to: address
- _tokenId: uint256
+ _from: indexed(address)
+ _to: indexed(address)
+ _tokenId: indexed(uint256)
event Approval:
- _owner: address
- _approved: address
- _tokenId: uint256
+ _owner: indexed(address)
+ _approved: indexed(address)
+ _tokenId: indexed(uint256)
event ApprovalForAll:
- _owner: address
- _operator: address
+ _owner: indexed(address)
+ _operator: indexed(address)
_approved: bool
# Functions
diff --git a/vyper/cli/vyper_compile.py b/vyper/cli/vyper_compile.py
index 09ddd971ee..bdd01eebbe 100755
--- a/vyper/cli/vyper_compile.py
+++ b/vyper/cli/vyper_compile.py
@@ -5,13 +5,18 @@
import warnings
from collections import OrderedDict
from pathlib import Path
-from typing import Dict, Iterable, Iterator, Set, TypeVar
+from typing import Dict, Iterable, Iterator, Optional, Set, TypeVar
import vyper
import vyper.codegen.ir_node as ir_node
from vyper.cli import vyper_json
from vyper.cli.utils import extract_file_interface_imports, get_interface_file_path
-from vyper.compiler.settings import VYPER_TRACEBACK_LIMIT
+from vyper.compiler.settings import (
+ VYPER_TRACEBACK_LIMIT,
+ OptimizationLevel,
+ Settings,
+ _set_debug_mode,
+)
from vyper.evm.opcodes import DEFAULT_EVM_VERSION, EVM_VERSIONS
from vyper.typing import ContractCodes, ContractPath, OutputFormats
@@ -20,6 +25,7 @@
format_options_help = """Format to print, one or more of:
bytecode (default) - Deployable bytecode
bytecode_runtime - Bytecode at runtime
+blueprint_bytecode - Deployment bytecode for an ERC-5202 compatible blueprint
abi - ABI in JSON format
abi_python - ABI in python format
source_map - Vyper source map
@@ -35,9 +41,9 @@
opcodes_runtime - List of runtime opcodes as a string
ir - Intermediate representation in list format
ir_json - Intermediate representation in JSON format
+ir_runtime - Intermediate representation of runtime bytecode in list format
+asm - Output the EVM assembly of the deployable bytecode
hex-ir - Output IR and assembly constants in hex instead of decimal
-no-optimize - Do not optimize (don't use this for production code)
-no-bytecode-metadata - Do not add metadata to bytecode
"""
combined_json_outputs = [
@@ -100,12 +106,18 @@ def _parse_args(argv):
)
parser.add_argument(
"--evm-version",
- help=f"Select desired EVM version (default {DEFAULT_EVM_VERSION})",
+ help=f"Select desired EVM version (default {DEFAULT_EVM_VERSION}). "
+ "note: cancun support is EXPERIMENTAL",
choices=list(EVM_VERSIONS),
- default=DEFAULT_EVM_VERSION,
dest="evm_version",
)
parser.add_argument("--no-optimize", help="Do not optimize", action="store_true")
+ parser.add_argument(
+ "--optimize",
+ help="Optimization flag (defaults to 'gas')",
+ choices=["gas", "codesize", "none"],
+ )
+ parser.add_argument("--debug", help="Compile in debug mode", action="store_true")
parser.add_argument(
"--no-bytecode-metadata", help="Do not add metadata to bytecode", action="store_true"
)
@@ -151,13 +163,31 @@ def _parse_args(argv):
output_formats = tuple(uniq(args.format.split(",")))
+ if args.debug:
+ _set_debug_mode(True)
+
+ if args.no_optimize and args.optimize:
+ raise ValueError("Cannot use `--no-optimize` and `--optimize` at the same time!")
+
+ settings = Settings()
+
+ if args.no_optimize:
+ settings.optimize = OptimizationLevel.NONE
+ elif args.optimize is not None:
+ settings.optimize = OptimizationLevel.from_string(args.optimize)
+
+ if args.evm_version:
+ settings.evm_version = args.evm_version
+
+ if args.verbose:
+ print(f"cli specified: `{settings}`", file=sys.stderr)
+
compiled = compile_files(
args.input_files,
output_formats,
args.root_folder,
args.show_gas_estimates,
- args.evm_version,
- args.no_optimize,
+ settings,
args.storage_layout,
args.no_bytecode_metadata,
)
@@ -251,9 +281,8 @@ def compile_files(
output_formats: OutputFormats,
root_folder: str = ".",
show_gas_estimates: bool = False,
- evm_version: str = DEFAULT_EVM_VERSION,
- no_optimize: bool = False,
- storage_layout: Iterable[str] = None,
+ settings: Optional[Settings] = None,
+ storage_layout: Optional[Iterable[str]] = None,
no_bytecode_metadata: bool = False,
) -> OrderedDict:
root_path = Path(root_folder).resolve()
@@ -294,8 +323,7 @@ def compile_files(
final_formats,
exc_handler=exc_handler,
interface_codes=get_interface_codes(root_path, contract_sources),
- evm_version=evm_version,
- no_optimize=no_optimize,
+ settings=settings,
storage_layouts=storage_layouts,
show_gas_estimates=show_gas_estimates,
no_bytecode_metadata=no_bytecode_metadata,
diff --git a/vyper/cli/vyper_ir.py b/vyper/cli/vyper_ir.py
index 6831f39473..1f90badcaa 100755
--- a/vyper/cli/vyper_ir.py
+++ b/vyper/cli/vyper_ir.py
@@ -55,7 +55,7 @@ def compile_to_ir(input_file, output_formats, show_gas_estimates=False):
compiler_data["asm"] = asm
if "bytecode" in output_formats:
- (bytecode, _srcmap) = compile_ir.assembly_to_evm(asm)
+ bytecode, _ = compile_ir.assembly_to_evm(asm)
compiler_data["bytecode"] = "0x" + bytecode.hex()
return compiler_data
diff --git a/vyper/cli/vyper_json.py b/vyper/cli/vyper_json.py
index aa6cf1c2f5..4a1c91550e 100755
--- a/vyper/cli/vyper_json.py
+++ b/vyper/cli/vyper_json.py
@@ -5,11 +5,12 @@
import sys
import warnings
from pathlib import Path
-from typing import Any, Callable, Dict, Hashable, List, Tuple, Union
+from typing import Any, Callable, Dict, Hashable, List, Optional, Tuple, Union
import vyper
from vyper.cli.utils import extract_file_interface_imports, get_interface_file_path
-from vyper.evm.opcodes import DEFAULT_EVM_VERSION, EVM_VERSIONS
+from vyper.compiler.settings import OptimizationLevel, Settings
+from vyper.evm.opcodes import EVM_VERSIONS
from vyper.exceptions import JSONError
from vyper.typing import ContractCodes, ContractPath
from vyper.utils import keccak256
@@ -144,13 +145,23 @@ def _standardize_path(path_str: str) -> str:
return path.as_posix()
-def get_evm_version(input_dict: Dict) -> str:
+def get_evm_version(input_dict: Dict) -> Optional[str]:
if "settings" not in input_dict:
- return DEFAULT_EVM_VERSION
-
- evm_version = input_dict["settings"].get("evmVersion", DEFAULT_EVM_VERSION)
- if evm_version in ("homestead", "tangerineWhistle", "spuriousDragon"):
- raise JSONError("Vyper does not support pre-byzantium EVM versions")
+ return None
+
+ # TODO: move this validation somewhere it can be reused more easily
+ evm_version = input_dict["settings"].get("evmVersion")
+ if evm_version is None:
+ return None
+
+ if evm_version in (
+ "homestead",
+ "tangerineWhistle",
+ "spuriousDragon",
+ "byzantium",
+ "constantinople",
+ ):
+ raise JSONError("Vyper does not support pre-istanbul EVM versions")
if evm_version not in EVM_VERSIONS:
raise JSONError(f"Unknown EVM version - '{evm_version}'")
@@ -354,7 +365,21 @@ def compile_from_input_dict(
raise JSONError(f"Invalid language '{input_dict['language']}' - Only Vyper is supported.")
evm_version = get_evm_version(input_dict)
- no_optimize = not input_dict["settings"].get("optimize", True)
+
+ optimize = input_dict["settings"].get("optimize")
+ if isinstance(optimize, bool):
+ # bool optimization level for backwards compatibility
+ warnings.warn(
+ "optimize: is deprecated! please use one of 'gas', 'codesize', 'none'."
+ )
+ optimize = OptimizationLevel.default() if optimize else OptimizationLevel.NONE
+ elif isinstance(optimize, str):
+ optimize = OptimizationLevel.from_string(optimize)
+ else:
+ assert optimize is None
+
+ settings = Settings(evm_version=evm_version, optimize=optimize)
+
no_bytecode_metadata = not input_dict["settings"].get("bytecodeMetadata", True)
contract_sources: ContractCodes = get_input_dict_contracts(input_dict)
@@ -377,8 +402,7 @@ def compile_from_input_dict(
output_formats[contract_path],
interface_codes=interface_codes,
initial_id=id_,
- no_optimize=no_optimize,
- evm_version=evm_version,
+ settings=settings,
no_bytecode_metadata=no_bytecode_metadata,
)
except Exception as exc:
@@ -429,7 +453,7 @@ def format_to_output_dict(compiler_data: Dict) -> Dict:
if "source_map" in data:
evm["sourceMap"] = data["source_map"]["pc_pos_map_compressed"]
if "source_map_full" in data:
- evm["sourceMapFull"] = data["source_map"]
+ evm["sourceMapFull"] = data["source_map_full"]
return output_dict
diff --git a/vyper/codegen/abi_encoder.py b/vyper/codegen/abi_encoder.py
index e76fbf2f64..66d61a9c16 100644
--- a/vyper/codegen/abi_encoder.py
+++ b/vyper/codegen/abi_encoder.py
@@ -1,4 +1,3 @@
-from vyper.address_space import MEMORY
from vyper.codegen.core import (
STORE,
add_ofst,
@@ -9,6 +8,7 @@
zero_pad,
)
from vyper.codegen.ir_node import IRnode
+from vyper.evm.address_space import MEMORY
from vyper.exceptions import CompilerPanic
from vyper.semantics.types import DArrayT, SArrayT, _BytestringT
from vyper.semantics.types.shortcuts import UINT256_T
diff --git a/vyper/codegen/arithmetic.py b/vyper/codegen/arithmetic.py
index eb2df20922..f14069384a 100644
--- a/vyper/codegen/arithmetic.py
+++ b/vyper/codegen/arithmetic.py
@@ -10,7 +10,6 @@
is_numeric_type,
)
from vyper.codegen.ir_node import IRnode
-from vyper.evm.opcodes import version_check
from vyper.exceptions import CompilerPanic, TypeCheckFailure, UnimplementedException
@@ -243,10 +242,7 @@ def safe_mul(x, y):
# in the above sdiv check, if (r==-1 and l==-2**255),
# -2**255 / -1 will return -2**255.
# need to check: not (r == -1 and l == -2**255)
- if version_check(begin="constantinople"):
- upper_bound = ["shl", 255, 1]
- else:
- upper_bound = -(2**255)
+ upper_bound = ["shl", 255, 1]
check_x = ["ne", x, upper_bound]
check_y = ["ne", ["not", y], 0]
@@ -301,10 +297,7 @@ def safe_div(x, y):
with res.cache_when_complex("res") as (b1, res):
# TODO: refactor this condition / push some things into the optimizer
if typ.is_signed and typ.bits == 256:
- if version_check(begin="constantinople"):
- upper_bound = ["shl", 255, 1]
- else:
- upper_bound = -(2**255)
+ upper_bound = ["shl", 255, 1]
if not x.is_literal and not y.is_literal:
ok = ["or", ["ne", y, ["not", 0]], ["ne", x, upper_bound]]
diff --git a/vyper/codegen/context.py b/vyper/codegen/context.py
index 2efba960de..5b79f293bd 100644
--- a/vyper/codegen/context.py
+++ b/vyper/codegen/context.py
@@ -3,8 +3,8 @@
from dataclasses import dataclass
from typing import Any, Optional
-from vyper.address_space import MEMORY, AddrSpace
from vyper.codegen.ir_node import Encoding
+from vyper.evm.address_space import MEMORY, AddrSpace
from vyper.exceptions import CompilerPanic, StateAccessViolation
from vyper.semantics.types import VyperType
@@ -28,8 +28,12 @@ class VariableRecord:
defined_at: Any = None
is_internal: bool = False
is_immutable: bool = False
+ is_transient: bool = False
data_offset: Optional[int] = None
+ def __hash__(self):
+ return hash(id(self))
+
def __post_init__(self):
if self.blockscopes is None:
self.blockscopes = []
@@ -47,10 +51,10 @@ def __init__(
global_ctx,
memory_allocator,
vars_=None,
- sigs=None,
forvars=None,
constancy=Constancy.Mutable,
- sig=None,
+ func_t=None,
+ is_ctor_context=False,
):
# In-memory variables, in the form (name, memory location, type)
self.vars = vars_ or {}
@@ -58,9 +62,6 @@ def __init__(
# Global variables, in the form (name, storage location, type)
self.globals = global_ctx.variables
- # ABI objects, in the form {classname: ABI JSON}
- self.sigs = sigs or {"self": {}}
-
# Variables defined in for loops, e.g. for i in range(6): ...
self.forvars = forvars or {}
@@ -68,6 +69,7 @@ def __init__(
self.constancy = constancy
# Whether body is currently in an assert statement
+ # XXX: dead, never set to True
self.in_assertion = False
# Whether we are currently parsing a range expression
@@ -76,8 +78,8 @@ def __init__(
# store global context
self.global_ctx = global_ctx
- # full function signature
- self.sig = sig
+ # full function type
+ self.func_t = func_t
# Active scopes
self._scopes = set()
@@ -89,6 +91,9 @@ def __init__(
self._internal_var_iter = 0
self._scope_id_iter = 0
+ # either the constructor, or called from the constructor
+ self.is_ctor_context = is_ctor_context
+
def is_constant(self):
return self.constancy is Constancy.Constant or self.in_assertion or self.in_range_expr
@@ -99,15 +104,15 @@ def check_is_not_constant(self, err, expr):
# convenience propreties
@property
def is_payable(self):
- return self.sig.mutability == "payable"
+ return self.func_t.is_payable
@property
def is_internal(self):
- return self.sig.internal
+ return self.func_t.is_internal
@property
def return_type(self):
- return self.sig.return_type
+ return self.func_t.return_type
#
# Context Managers
@@ -240,40 +245,9 @@ def new_internal_variable(self, typ: VyperType) -> int:
var_size = typ.memory_bytes_required
return self._new_variable(name, typ, var_size, True)
- def parse_type(self, ast_node):
- return self.global_ctx.parse_type(ast_node)
-
def lookup_var(self, varname):
return self.vars[varname]
- def lookup_internal_function(self, method_name, args_ir, ast_source):
- # TODO is this the right module for me?
- """
- Using a list of args, find the internal method to use, and
- the kwargs which need to be filled in by the compiler
- """
-
- sig = self.sigs["self"].get(method_name, None)
-
- def _check(cond, s="Unreachable"):
- if not cond:
- raise CompilerPanic(s)
-
- # these should have been caught during type checking; sanity check
- _check(sig is not None)
- _check(sig.internal)
- _check(len(sig.base_args) <= len(args_ir) <= len(sig.args))
- # more sanity check, that the types match
- # _check(all(l.typ == r.typ for (l, r) in zip(args_ir, sig.args))
-
- num_provided_kwargs = len(args_ir) - len(sig.base_args)
- num_kwargs = len(sig.default_args)
- kwargs_needed = num_kwargs - num_provided_kwargs
-
- kw_vals = list(sig.default_values.values())[:kwargs_needed]
-
- return sig, kw_vals
-
# Pretty print constancy for error messages
def pp_constancy(self):
if self.in_assertion:
diff --git a/vyper/codegen/core.py b/vyper/codegen/core.py
index 1cc601413c..e1d3ea12b4 100644
--- a/vyper/codegen/core.py
+++ b/vyper/codegen/core.py
@@ -1,6 +1,10 @@
+import contextlib
+from typing import Generator
+
from vyper import ast as vy_ast
-from vyper.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE
from vyper.codegen.ir_node import Encoding, IRnode
+from vyper.compiler.settings import OptimizationLevel
+from vyper.evm.address_space import CALLDATA, DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT
from vyper.evm.opcodes import version_check
from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure, TypeMismatch
from vyper.semantics.types import (
@@ -20,13 +24,7 @@
from vyper.semantics.types.shortcuts import BYTES32_T, INT256_T, UINT256_T
from vyper.semantics.types.subscriptable import SArrayT
from vyper.semantics.types.user import EnumT
-from vyper.utils import (
- GAS_CALLDATACOPY_WORD,
- GAS_CODECOPY_WORD,
- GAS_IDENTITY,
- GAS_IDENTITYWORD,
- ceil32,
-)
+from vyper.utils import GAS_COPY_WORD, GAS_IDENTITY, GAS_IDENTITYWORD, ceil32
DYNAMIC_ARRAY_OVERHEAD = 1
@@ -91,12 +89,16 @@ def _identity_gas_bound(num_bytes):
return GAS_IDENTITY + GAS_IDENTITYWORD * (ceil32(num_bytes) // 32)
+def _mcopy_gas_bound(num_bytes):
+ return GAS_COPY_WORD * ceil32(num_bytes) // 32
+
+
def _calldatacopy_gas_bound(num_bytes):
- return GAS_CALLDATACOPY_WORD * ceil32(num_bytes) // 32
+ return GAS_COPY_WORD * ceil32(num_bytes) // 32
def _codecopy_gas_bound(num_bytes):
- return GAS_CODECOPY_WORD * ceil32(num_bytes) // 32
+ return GAS_COPY_WORD * ceil32(num_bytes) // 32
# Copy byte array word-for-word (including layout)
@@ -108,23 +110,33 @@ def make_byte_array_copier(dst, src):
_check_assign_bytes(dst, src)
# TODO: remove this branch, copy_bytes and get_bytearray_length should handle
- if src.value == "~empty":
+ if src.value == "~empty" or src.typ.maxlen == 0:
# set length word to 0.
return STORE(dst, 0)
with src.cache_when_complex("src") as (b1, src):
- with get_bytearray_length(src).cache_when_complex("len") as (b2, len_):
- max_bytes = src.typ.maxlen
+ has_storage = STORAGE in (src.location, dst.location)
+ is_memory_copy = dst.location == src.location == MEMORY
+ batch_uses_identity = is_memory_copy and not version_check(begin="cancun")
+ if src.typ.maxlen <= 32 and (has_storage or batch_uses_identity):
+ # it's cheaper to run two load/stores instead of copy_bytes
ret = ["seq"]
- # store length
+ # store length word
+ len_ = get_bytearray_length(src)
ret.append(STORE(dst, len_))
- dst = bytes_data_ptr(dst)
- src = bytes_data_ptr(src)
+ # store the single data word.
+ dst_data_ptr = bytes_data_ptr(dst)
+ src_data_ptr = bytes_data_ptr(src)
+ ret.append(STORE(dst_data_ptr, LOAD(src_data_ptr)))
+ return b1.resolve(ret)
- ret.append(copy_bytes(dst, src, len_, max_bytes))
- return b1.resolve(b2.resolve(ret))
+ # batch copy the bytearray (including length word) using copy_bytes
+ len_ = add_ofst(get_bytearray_length(src), 32)
+ max_bytes = src.typ.maxlen + 32
+ ret = copy_bytes(dst, src, len_, max_bytes)
+ return b1.resolve(ret)
def bytes_data_ptr(ptr):
@@ -148,25 +160,34 @@ def _dynarray_make_setter(dst, src):
if src.value == "~empty":
return IRnode.from_list(STORE(dst, 0))
+ # copy contents of src dynarray to dst.
+ # note that in case src and dst refer to the same dynarray,
+ # in order for get_element_ptr oob checks on the src dynarray
+ # to work, we need to wait until after the data is copied
+ # before we clobber the length word.
+
if src.value == "multi":
ret = ["seq"]
# handle literals
- # write the length word
- store_length = STORE(dst, len(src.args))
- ann = None
- if src.annotation is not None:
- ann = f"len({src.annotation})"
- store_length = IRnode.from_list(store_length, annotation=ann)
- ret.append(store_length)
-
+ # copy each item
n_items = len(src.args)
+
for i in range(n_items):
k = IRnode.from_list(i, typ=UINT256_T)
dst_i = get_element_ptr(dst, k, array_bounds_check=False)
src_i = get_element_ptr(src, k, array_bounds_check=False)
ret.append(make_setter(dst_i, src_i))
+ # write the length word after data is copied
+ store_length = STORE(dst, n_items)
+ ann = None
+ if src.annotation is not None:
+ ann = f"len({src.annotation})"
+ store_length = IRnode.from_list(store_length, annotation=ann)
+
+ ret.append(store_length)
+
return ret
with src.cache_when_complex("darray_src") as (b1, src):
@@ -190,8 +211,6 @@ def _dynarray_make_setter(dst, src):
with get_dyn_array_count(src).cache_when_complex("darray_count") as (b2, count):
ret = ["seq"]
- ret.append(STORE(dst, count))
-
if should_loop:
i = IRnode.from_list(_freshname("copy_darray_ix"), typ=UINT256_T)
@@ -202,16 +221,17 @@ def _dynarray_make_setter(dst, src):
loop_body.annotation = f"{dst}[i] = {src}[i]"
ret.append(["repeat", i, 0, count, src.typ.count, loop_body])
+ # write the length word after data is copied
+ ret.append(STORE(dst, count))
else:
element_size = src.typ.value_type.memory_bytes_required
- # number of elements * size of element in bytes
- n_bytes = _mul(count, element_size)
- max_bytes = src.typ.count * element_size
+ # number of elements * size of element in bytes + length word
+ n_bytes = add_ofst(_mul(count, element_size), 32)
+ max_bytes = 32 + src.typ.count * element_size
- src_ = dynarray_data_ptr(src)
- dst_ = dynarray_data_ptr(dst)
- ret.append(copy_bytes(dst_, src_, n_bytes, max_bytes))
+ # batch copy the entire dynarray, including length word
+ ret.append(copy_bytes(dst, src, n_bytes, max_bytes))
return b1.resolve(b2.resolve(ret))
@@ -247,7 +267,6 @@ def copy_bytes(dst, src, length, length_bound):
assert src.is_pointer and dst.is_pointer
# fast code for common case where num bytes is small
- # TODO expand this for more cases where num words is less than ~8
if length_bound <= 32:
copy_op = STORE(dst, LOAD(src))
ret = IRnode.from_list(copy_op, annotation=annotation)
@@ -257,8 +276,12 @@ def copy_bytes(dst, src, length, length_bound):
# special cases: batch copy to memory
# TODO: iloadbytes
if src.location == MEMORY:
- copy_op = ["staticcall", "gas", 4, src, length, dst, length]
- gas_bound = _identity_gas_bound(length_bound)
+ if version_check(begin="cancun"):
+ copy_op = ["mcopy", dst, src, length]
+ gas_bound = _mcopy_gas_bound(length_bound)
+ else:
+ copy_op = ["staticcall", "gas", 4, src, length, dst, length]
+ gas_bound = _identity_gas_bound(length_bound)
elif src.location == CALLDATA:
copy_op = ["calldatacopy", dst, src, length]
gas_bound = _calldatacopy_gas_bound(length_bound)
@@ -336,12 +359,14 @@ def append_dyn_array(darray_node, elem_node):
with len_.cache_when_complex("old_darray_len") as (b2, len_):
assertion = ["assert", ["lt", len_, darray_node.typ.count]]
ret.append(IRnode.from_list(assertion, error_msg=f"{darray_node.typ} bounds check"))
- ret.append(STORE(darray_node, ["add", len_, 1]))
# NOTE: typechecks elem_node
# NOTE skip array bounds check bc we already asserted len two lines up
ret.append(
make_setter(get_element_ptr(darray_node, len_, array_bounds_check=False), elem_node)
)
+
+ # store new length
+ ret.append(STORE(darray_node, ["add", len_, 1]))
return IRnode.from_list(b1.resolve(b2.resolve(ret)))
@@ -354,6 +379,7 @@ def pop_dyn_array(darray_node, return_popped_item):
new_len = IRnode.from_list(["sub", old_len, 1], typ=UINT256_T)
with new_len.cache_when_complex("new_len") as (b2, new_len):
+ # store new length
ret.append(STORE(darray_node, new_len))
# NOTE skip array bounds check bc we already asserted len two lines up
@@ -364,6 +390,7 @@ def pop_dyn_array(darray_node, return_popped_item):
location = popped_item.location
else:
typ, location = None, None
+
return IRnode.from_list(b1.resolve(b2.resolve(ret)), typ=typ, location=location)
@@ -512,6 +539,7 @@ def _get_element_ptr_array(parent, key, array_bounds_check):
# an array index, and the clamp will throw an error.
# NOTE: there are optimization rules for this when ix or bound is literal
ix = clamp("lt", ix, bound)
+ ix.set_error_msg(f"{parent.typ} bounds check")
if parent.encoding == Encoding.ABI:
if parent.location == STORAGE:
@@ -546,10 +574,10 @@ def _get_element_ptr_mapping(parent, key):
key = unwrap_location(key)
# TODO when is key None?
- if key is None or parent.location != STORAGE:
- raise TypeCheckFailure(f"bad dereference on mapping {parent}[{key}]")
+ if key is None or parent.location not in (STORAGE, TRANSIENT):
+ raise TypeCheckFailure("bad dereference on mapping {parent}[{key}]")
- return IRnode.from_list(["sha3_64", parent, key], typ=subtype, location=STORAGE)
+ return IRnode.from_list(["sha3_64", parent, key], typ=subtype, location=parent.location)
# Take a value representing a memory or storage location, and descend down to
@@ -861,6 +889,38 @@ def make_setter(left, right):
return _complex_make_setter(left, right)
+_opt_level = OptimizationLevel.GAS
+
+
+@contextlib.contextmanager
+def anchor_opt_level(new_level: OptimizationLevel) -> Generator:
+ """
+ Set the global optimization level variable for the duration of this
+ context manager.
+ """
+ assert isinstance(new_level, OptimizationLevel)
+
+ global _opt_level
+ try:
+ tmp = _opt_level
+ _opt_level = new_level
+ yield
+ finally:
+ _opt_level = tmp
+
+
+def _opt_codesize():
+ return _opt_level == OptimizationLevel.CODESIZE
+
+
+def _opt_gas():
+ return _opt_level == OptimizationLevel.GAS
+
+
+def _opt_none():
+ return _opt_level == OptimizationLevel.NONE
+
+
def _complex_make_setter(left, right):
if right.value == "~empty" and left.location == MEMORY:
# optimized memzero
@@ -876,11 +936,69 @@ def _complex_make_setter(left, right):
assert is_tuple_like(left.typ)
keys = left.typ.tuple_keys()
- # if len(keyz) == 0:
- # return IRnode.from_list(["pass"])
+ if left.is_pointer and right.is_pointer and right.encoding == Encoding.VYPER:
+ # both left and right are pointers, see if we want to batch copy
+ # instead of unrolling the loop.
+ assert left.encoding == Encoding.VYPER
+ len_ = left.typ.memory_bytes_required
+
+ has_storage = STORAGE in (left.location, right.location)
+ if has_storage:
+ if _opt_codesize():
+ # assuming PUSH2, a single sstore(dst (sload src)) is 8 bytes,
+ # sstore(add (dst ofst), (sload (add (src ofst)))) is 16 bytes,
+ # whereas loop overhead is 16-17 bytes.
+ base_cost = 3
+ if left._optimized.is_literal:
+ # code size is smaller since add is performed at compile-time
+ base_cost += 1
+ if right._optimized.is_literal:
+ base_cost += 1
+ # the formula is a heuristic, but it works.
+ # (CMC 2023-07-14 could get more detailed for PUSH1 vs
+ # PUSH2 etc but not worried about that too much now,
+ # it's probably better to add a proper unroll rule in the
+ # optimizer.)
+ should_batch_copy = len_ >= 32 * base_cost
+ elif _opt_gas():
+ # kind of arbitrary, but cut off when code used > ~160 bytes
+ should_batch_copy = len_ >= 32 * 10
+ else:
+ assert _opt_none()
+ # don't care, just generate the most readable version
+ should_batch_copy = True
+ else:
+ # find a cutoff for memory copy where identity is cheaper
+ # than unrolled mloads/mstores
+ # if MCOPY is available, mcopy is *always* better (except in
+ # the 1 word case, but that is already handled by copy_bytes).
+ if right.location == MEMORY and _opt_gas() and not version_check(begin="cancun"):
+ # cost for 0th word - (mstore dst (mload src))
+ base_unroll_cost = 12
+ nth_word_cost = base_unroll_cost
+ if not left._optimized.is_literal:
+ # (mstore (add N dst) (mload src))
+ nth_word_cost += 6
+ if not right._optimized.is_literal:
+ # (mstore dst (mload (add N src)))
+ nth_word_cost += 6
+
+ identity_base_cost = 115 # staticcall 4 gas dst len src len
+
+ n_words = ceil32(len_) // 32
+ should_batch_copy = (
+ base_unroll_cost + (nth_word_cost * (n_words - 1)) >= identity_base_cost
+ )
+
+ # calldata to memory, code to memory, cancun, or codesize -
+ # batch copy is always better.
+ else:
+ should_batch_copy = True
- # general case
- # TODO use copy_bytes when the generated code is above a certain size
+ if should_batch_copy:
+ return copy_bytes(left, right, len_, len_)
+
+ # general case, unroll
with left.cache_when_complex("_L") as (b1, left), right.cache_when_complex("_R") as (b2, right):
for k in keys:
l_i = get_element_ptr(left, k, array_bounds_check=False)
@@ -915,7 +1033,6 @@ def eval_seq(ir_node):
return None
-# TODO move return checks to vyper/semantics/validation
def is_return_from_function(node):
if isinstance(node, vy_ast.Expr) and node.get("value.func.id") in (
"raw_revert",
@@ -927,6 +1044,8 @@ def is_return_from_function(node):
return False
+# TODO this is almost certainly duplicated with check_terminus_node
+# in vyper/semantics/analysis/local.py
def check_single_exit(fn_node):
_check_return_body(fn_node, fn_node.body)
for node in fn_node.get_descendants(vy_ast.If):
@@ -981,23 +1100,16 @@ def zero_pad(bytez_placeholder):
# convenience rewrites for shr/sar/shl
def shr(bits, x):
- if version_check(begin="constantinople"):
- return ["shr", bits, x]
- return ["div", x, ["exp", 2, bits]]
+ return ["shr", bits, x]
# convenience rewrites for shr/sar/shl
def shl(bits, x):
- if version_check(begin="constantinople"):
- return ["shl", bits, x]
- return ["mul", x, ["exp", 2, bits]]
+ return ["shl", bits, x]
def sar(bits, x):
- if version_check(begin="constantinople"):
- return ["sar", bits, x]
-
- raise NotImplementedError("no SAR emulation for pre-constantinople EVM")
+ return ["sar", bits, x]
def clamp_bytestring(ir_node):
diff --git a/vyper/codegen/events.py b/vyper/codegen/events.py
index 9508a869ea..30a1b1e591 100644
--- a/vyper/codegen/events.py
+++ b/vyper/codegen/events.py
@@ -15,7 +15,7 @@ def _encode_log_topics(expr, event_id, arg_nodes, context):
value = unwrap_location(arg)
elif isinstance(arg.typ, _BytestringT):
- value = keccak256_helper(expr, arg, context=context)
+ value = keccak256_helper(arg, context=context)
else:
# TODO block at higher level
raise TypeMismatch("Event indexes may only be value types", expr)
diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py
index 33c400941e..dc0e98786f 100644
--- a/vyper/codegen/expr.py
+++ b/vyper/codegen/expr.py
@@ -3,7 +3,6 @@
import vyper.codegen.arithmetic as arithmetic
from vyper import ast as vy_ast
-from vyper.address_space import DATA, IMMUTABLES, MEMORY, STORAGE
from vyper.codegen import external_call, self_call
from vyper.codegen.core import (
clamp,
@@ -17,10 +16,14 @@
is_numeric_type,
is_tuple_like,
pop_dyn_array,
+ sar,
+ shl,
+ shr,
unwrap_location,
)
from vyper.codegen.ir_node import IRnode
from vyper.codegen.keccak256_helper import keccak256_helper
+from vyper.evm.address_space import DATA, IMMUTABLES, MEMORY, STORAGE, TRANSIENT
from vyper.evm.opcodes import version_check
from vyper.exceptions import (
CompilerPanic,
@@ -72,11 +75,11 @@ def __init__(self, node, context):
fn = getattr(self, f"parse_{type(node).__name__}", None)
if fn is None:
- raise TypeCheckFailure(f"Invalid statement node: {type(node).__name__}")
+ raise TypeCheckFailure(f"Invalid statement node: {type(node).__name__}", node)
self.ir_node = fn()
if self.ir_node is None:
- raise TypeCheckFailure(f"{type(node).__name__} node did not produce IR. {self.expr}")
+ raise TypeCheckFailure(f"{type(node).__name__} node did not produce IR.", node)
self.ir_node.annotation = self.expr.get("node_source_code")
self.ir_node.source_pos = getpos(self.expr)
@@ -163,7 +166,7 @@ def parse_Name(self):
return IRnode.from_list(["address"], typ=AddressT())
elif self.expr.id in self.context.vars:
var = self.context.vars[self.expr.id]
- return IRnode.from_list(
+ ret = IRnode.from_list(
var.pos,
typ=var.typ,
location=var.location, # either 'memory' or 'calldata' storage is handled above.
@@ -171,6 +174,8 @@ def parse_Name(self):
annotation=self.expr.id,
mutable=var.mutable,
)
+ ret._referenced_variables = {var}
+ return ret
# TODO: use self.expr._expr_info
elif self.expr.id in self.context.globals:
@@ -179,16 +184,18 @@ def parse_Name(self):
ofst = varinfo.position.offset
- if self.context.sig.is_init_func:
+ if self.context.is_ctor_context:
mutable = True
location = IMMUTABLES
else:
mutable = False
location = DATA
- return IRnode.from_list(
+ ret = IRnode.from_list(
ofst, typ=varinfo.typ, location=location, annotation=self.expr.id, mutable=mutable
)
+ ret._referenced_variables = {varinfo}
+ return ret
# x.y or x[5]
def parse_Attribute(self):
@@ -235,10 +242,6 @@ def parse_Attribute(self):
# x.codehash: keccak of address x
elif self.expr.attr == "codehash":
addr = Expr.parse_value_expr(self.expr.value, self.context)
- if not version_check(begin="constantinople"):
- raise EvmVersionException(
- "address.codehash is unavailable prior to constantinople ruleset", self.expr
- )
if addr.typ == AddressT():
return IRnode.from_list(["extcodehash", addr], typ=BYTES32_T)
# x.code: codecopy/extcodecopy of address x
@@ -252,12 +255,18 @@ def parse_Attribute(self):
# self.x: global attribute
elif isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id == "self":
varinfo = self.context.globals[self.expr.attr]
- return IRnode.from_list(
+ location = TRANSIENT if varinfo.is_transient else STORAGE
+
+ ret = IRnode.from_list(
varinfo.position.position,
typ=varinfo.typ,
- location=STORAGE,
+ location=location,
annotation="self." + self.expr.attr,
)
+ ret._referenced_variables = {varinfo}
+
+ return ret
+
# Reserved keywords
elif (
isinstance(self.expr.value, vy_ast.Name) and self.expr.value.id in ENVIRONMENT_VARIABLES
@@ -317,6 +326,9 @@ def parse_Attribute(self):
sub = Expr(self.expr.value, self.context).ir_node
# contract type
if isinstance(sub.typ, InterfaceT):
+ # MyInterface.address
+ assert self.expr.attr == "address"
+ sub.typ = typ
return sub
if isinstance(sub.typ, StructT) and self.expr.attr in sub.typ.member_types:
return get_element_ptr(sub, self.expr.attr)
@@ -332,11 +344,10 @@ def parse_Subscript(self):
if isinstance(sub.typ, HashMapT):
# TODO sanity check we are in a self.my_map[i] situation
- index = Expr.parse_value_expr(self.expr.slice.value, self.context)
- if isinstance(index.typ, BytesT):
+ index = Expr(self.expr.slice.value, self.context).ir_node
+ if isinstance(index.typ, _BytestringT):
# we have to hash the key to get a storage location
- assert len(index.args) == 1
- index = keccak256_helper(self.expr.slice.value, index.args[0], self.context)
+ index = keccak256_helper(index, self.context)
elif is_array_like(sub.typ):
index = Expr.parse_value_expr(self.expr.slice.value, self.context)
@@ -357,10 +368,17 @@ def parse_BinOp(self):
left = Expr.parse_value_expr(self.expr.left, self.context)
right = Expr.parse_value_expr(self.expr.right, self.context)
- # Sanity check - ensure that we aren't dealing with different types
- # This should be unreachable due to the type check pass
- assert left.typ == right.typ, f"unreachable, {left.typ}!={right.typ}"
- assert is_numeric_type(left.typ) or is_enum_type(left.typ)
+ is_shift_op = isinstance(self.expr.op, (vy_ast.LShift, vy_ast.RShift))
+
+ if is_shift_op:
+ assert is_numeric_type(left.typ)
+ assert is_numeric_type(right.typ)
+ else:
+ # Sanity check - ensure that we aren't dealing with different types
+ # This should be unreachable due to the type check pass
+ if left.typ != right.typ:
+ raise TypeCheckFailure(f"unreachable, {left.typ} != {right.typ}", self.expr)
+ assert is_numeric_type(left.typ) or is_enum_type(left.typ)
out_typ = left.typ
@@ -371,6 +389,20 @@ def parse_BinOp(self):
if isinstance(self.expr.op, vy_ast.BitXor):
return IRnode.from_list(["xor", left, right], typ=out_typ)
+ if isinstance(self.expr.op, vy_ast.LShift):
+ new_typ = left.typ
+ if new_typ.bits != 256:
+ # TODO implement me. ["and", 2**bits - 1, shl(right, left)]
+ return
+ return IRnode.from_list(shl(right, left), typ=new_typ)
+ if isinstance(self.expr.op, vy_ast.RShift):
+ new_typ = left.typ
+ if new_typ.bits != 256:
+ # TODO implement me. promote_signed_int(op(right, left), bits)
+ return
+ op = shr if not left.typ.is_signed else sar
+ return IRnode.from_list(op(right, left), typ=new_typ)
+
# enums can only do bit ops, not arithmetic.
assert is_numeric_type(left.typ)
@@ -505,8 +537,8 @@ def parse_Compare(self):
left = Expr(self.expr.left, self.context).ir_node
right = Expr(self.expr.right, self.context).ir_node
- left_keccak = keccak256_helper(self.expr, left, self.context)
- right_keccak = keccak256_helper(self.expr, right, self.context)
+ left_keccak = keccak256_helper(left, self.context)
+ right_keccak = keccak256_helper(right, self.context)
if op not in ("eq", "ne"):
return # raises
@@ -634,7 +666,7 @@ def parse_Call(self):
elif isinstance(self.expr._metadata["type"], StructT):
args = self.expr.args
if len(args) == 1 and isinstance(args[0], vy_ast.Dict):
- return Expr.struct_literals(args[0], function_name, self.context)
+ return Expr.struct_literals(args[0], self.context, self.expr._metadata["type"])
# Interface assignment. Bar().
elif isinstance(self.expr._metadata["type"], InterfaceT):
@@ -679,8 +711,31 @@ def parse_Tuple(self):
multi_ir = IRnode.from_list(["multi"] + tuple_elements, typ=typ)
return multi_ir
+ def parse_IfExp(self):
+ test = Expr.parse_value_expr(self.expr.test, self.context)
+ assert test.typ == BoolT() # sanity check
+
+ body = Expr(self.expr.body, self.context).ir_node
+ orelse = Expr(self.expr.orelse, self.context).ir_node
+
+ # if they are in the same location, we can skip copying
+ # into memory. also for the case where either body or orelse are
+ # literal `multi` values (ex. for tuple or arrays), copy to
+ # memory (to avoid crashing in make_setter, XXX fixme).
+ if body.location != orelse.location or body.value == "multi":
+ body = ensure_in_memory(body, self.context)
+ orelse = ensure_in_memory(orelse, self.context)
+
+ assert body.location == orelse.location
+ # check this once compare_type has no side effects:
+ # assert body.typ.compare_type(orelse.typ)
+
+ typ = self.expr._metadata["type"]
+ location = body.location
+ return IRnode.from_list(["if", test, body, orelse], typ=typ, location=location)
+
@staticmethod
- def struct_literals(expr, name, context):
+ def struct_literals(expr, context, typ):
member_subs = {}
member_typs = {}
for key, value in zip(expr.keys, expr.values):
@@ -691,10 +746,8 @@ def struct_literals(expr, name, context):
member_subs[key.id] = sub
member_typs[key.id] = sub.typ
- # TODO: get struct type from context.global_ctx.parse_type(name)
return IRnode.from_list(
- ["multi"] + [member_subs[key] for key in member_subs.keys()],
- typ=StructT(name, member_typs),
+ ["multi"] + [member_subs[key] for key in member_subs.keys()], typ=typ
)
# Parse an expression that results in a value
diff --git a/vyper/codegen/external_call.py b/vyper/codegen/external_call.py
index f99723f16d..ba89f3cace 100644
--- a/vyper/codegen/external_call.py
+++ b/vyper/codegen/external_call.py
@@ -1,7 +1,6 @@
from dataclasses import dataclass
import vyper.utils as util
-from vyper.address_space import MEMORY
from vyper.codegen.abi_encoder import abi_encode
from vyper.codegen.core import (
_freshname,
@@ -17,6 +16,7 @@
wrap_value_for_external_return,
)
from vyper.codegen.ir_node import Encoding, IRnode
+from vyper.evm.address_space import MEMORY
from vyper.exceptions import TypeCheckFailure
from vyper.semantics.types import InterfaceT, TupleT
from vyper.semantics.types.function import StateMutability
@@ -37,7 +37,7 @@ def _pack_arguments(fn_type, args, context):
args_abi_t = args_tuple_t.abi_type
# sanity typecheck - make sure the arguments can be assigned
- dst_tuple_t = TupleT(list(fn_type.arguments.values())[: len(args)])
+ dst_tuple_t = TupleT(fn_type.argument_types[: len(args)])
check_assign(dummy_node_for_type(dst_tuple_t), args_as_tuple)
if fn_type.return_type is not None:
@@ -63,8 +63,7 @@ def _pack_arguments(fn_type, args, context):
# 32 bytes | args
# 0x..00 | args
# the reason for the left padding is just so the alignment is easier.
- # if we were only targeting constantinople, we could align
- # to buf (and also keep code size small) by using
+ # XXX: we could align to buf (and also keep code size small) by using
# (mstore buf (shl signature.method_id 224))
pack_args = ["seq"]
pack_args.append(["mstore", buf, util.method_id_int(abi_signature)])
@@ -171,12 +170,10 @@ def _extcodesize_check(address):
def _external_call_helper(contract_address, args_ir, call_kwargs, call_expr, context):
- # expr.func._metadata["type"].return_type is more accurate
- # than fn_sig.return_type in the case of JSON interfaces.
fn_type = call_expr.func._metadata["type"]
# sanity check
- assert fn_type.min_arg_count <= len(args_ir) <= fn_type.max_arg_count
+ assert fn_type.n_positional_args <= len(args_ir) <= fn_type.n_total_args
ret = ["seq"]
diff --git a/vyper/codegen/function_definitions/__init__.py b/vyper/codegen/function_definitions/__init__.py
index 08bebbb4a5..94617bef35 100644
--- a/vyper/codegen/function_definitions/__init__.py
+++ b/vyper/codegen/function_definitions/__init__.py
@@ -1 +1 @@
-from .common import generate_ir_for_function # noqa
+from .common import FuncIR, generate_ir_for_function # noqa
diff --git a/vyper/codegen/function_definitions/common.py b/vyper/codegen/function_definitions/common.py
index 6dece865fa..3fd5ce0b29 100644
--- a/vyper/codegen/function_definitions/common.py
+++ b/vyper/codegen/function_definitions/common.py
@@ -1,24 +1,94 @@
-# can't use from [module] import [object] because it breaks mocks in testing
-from typing import Dict
+from dataclasses import dataclass
+from functools import cached_property
+from typing import Optional
import vyper.ast as vy_ast
-from vyper.ast.signatures import FrameInfo, FunctionSignature
from vyper.codegen.context import Constancy, Context
-from vyper.codegen.core import check_single_exit, getpos
+from vyper.codegen.core import check_single_exit
from vyper.codegen.function_definitions.external_function import generate_ir_for_external_function
from vyper.codegen.function_definitions.internal_function import generate_ir_for_internal_function
from vyper.codegen.global_context import GlobalContext
from vyper.codegen.ir_node import IRnode
from vyper.codegen.memory_allocator import MemoryAllocator
-from vyper.utils import MemoryPositions, calc_mem_gas
+from vyper.exceptions import CompilerPanic
+from vyper.semantics.types import VyperType
+from vyper.semantics.types.function import ContractFunctionT
+from vyper.utils import MemoryPositions, calc_mem_gas, mkalphanum
+@dataclass
+class FrameInfo:
+ frame_start: int
+ frame_size: int
+ frame_vars: dict[str, tuple[int, VyperType]]
+
+ @property
+ def mem_used(self):
+ return self.frame_size + MemoryPositions.RESERVED_MEMORY
+
+
+@dataclass
+class _FuncIRInfo:
+ func_t: ContractFunctionT
+ gas_estimate: Optional[int] = None
+ frame_info: Optional[FrameInfo] = None
+
+ @property
+ def visibility(self):
+ return "internal" if self.func_t.is_internal else "external"
+
+ @property
+ def exit_sequence_label(self) -> str:
+ return self.ir_identifier + "_cleanup"
+
+ @cached_property
+ def ir_identifier(self) -> str:
+ argz = ",".join([str(argtyp) for argtyp in self.func_t.argument_types])
+ return mkalphanum(f"{self.visibility} {self.func_t.name} ({argz})")
+
+ def set_frame_info(self, frame_info: FrameInfo) -> None:
+ if self.frame_info is not None:
+ raise CompilerPanic(f"frame_info already set for {self.func_t}!")
+ self.frame_info = frame_info
+
+ @property
+ # common entry point for external function with kwargs
+ def external_function_base_entry_label(self) -> str:
+ assert not self.func_t.is_internal, "uh oh, should be external"
+ return self.ir_identifier + "_common"
+
+ def internal_function_label(self, is_ctor_context: bool = False) -> str:
+ assert self.func_t.is_internal, "uh oh, should be internal"
+ suffix = "_deploy" if is_ctor_context else "_runtime"
+ return self.ir_identifier + suffix
+
+
+class FuncIR:
+ pass
+
+
+@dataclass
+class EntryPointInfo:
+ func_t: ContractFunctionT
+ min_calldatasize: int # the min calldata required for this entry point
+ ir_node: IRnode # the ir for this entry point
+
+
+@dataclass
+class ExternalFuncIR(FuncIR):
+ entry_points: dict[str, EntryPointInfo] # map from abi sigs to entry points
+ common_ir: IRnode # the "common" code for the function
+
+
+@dataclass
+class InternalFuncIR(FuncIR):
+ func_ir: IRnode # the code for the function
+
+
+# TODO: should split this into external and internal ir generation?
def generate_ir_for_function(
- code: vy_ast.FunctionDef,
- sigs: Dict[str, Dict[str, FunctionSignature]], # all signatures in all namespaces
- global_ctx: GlobalContext,
- skip_nonpayable_check: bool,
-) -> IRnode:
+ code: vy_ast.FunctionDef, global_ctx: GlobalContext, is_ctor_context: bool = False
+) -> FuncIR:
"""
Parse a function and produce IR code for the function, includes:
- Signature method if statement
@@ -26,18 +96,21 @@ def generate_ir_for_function(
- Clamping and copying of arguments
- Function body
"""
- sig = code._metadata["signature"]
+ func_t = code._metadata["type"]
+
+ # generate _FuncIRInfo
+ func_t._ir_info = _FuncIRInfo(func_t)
# Validate return statements.
+ # XXX: This should really be in semantics pass.
check_single_exit(code)
- callees = code._metadata["type"].called_functions
+ callees = func_t.called_functions
# we start our function frame from the largest callee frame
max_callee_frame_size = 0
- for c in callees:
- frame_info = sigs["self"][c.name].frame_info
- assert frame_info is not None # make mypy happy
+ for c_func_t in callees:
+ frame_info = c_func_t._ir_info.frame_info
max_callee_frame_size = max(max_callee_frame_size, frame_info.frame_size)
allocate_start = max_callee_frame_size + MemoryPositions.RESERVED_MEMORY
@@ -47,33 +120,40 @@ def generate_ir_for_function(
context = Context(
vars_=None,
global_ctx=global_ctx,
- sigs=sigs,
memory_allocator=memory_allocator,
- constancy=Constancy.Constant if sig.mutability in ("view", "pure") else Constancy.Mutable,
- sig=sig,
+ constancy=Constancy.Mutable if func_t.is_mutable else Constancy.Constant,
+ func_t=func_t,
+ is_ctor_context=is_ctor_context,
)
- if sig.internal:
- assert skip_nonpayable_check is False
- o = generate_ir_for_internal_function(code, sig, context)
+ if func_t.is_internal:
+ ret: FuncIR = InternalFuncIR(generate_ir_for_internal_function(code, func_t, context))
+ func_t._ir_info.gas_estimate = ret.func_ir.gas # type: ignore
else:
- if sig.mutability == "payable":
- assert skip_nonpayable_check is False # nonsense
- o = generate_ir_for_external_function(code, sig, context, skip_nonpayable_check)
-
- o.source_pos = getpos(code)
+ kwarg_handlers, common = generate_ir_for_external_function(code, func_t, context)
+ entry_points = {
+ k: EntryPointInfo(func_t, mincalldatasize, ir_node)
+ for k, (mincalldatasize, ir_node) in kwarg_handlers.items()
+ }
+ ret = ExternalFuncIR(entry_points, common)
+ # note: this ignores the cost of traversing selector table
+ func_t._ir_info.gas_estimate = ret.common_ir.gas
frame_size = context.memory_allocator.size_of_mem - MemoryPositions.RESERVED_MEMORY
- sig.set_frame_info(FrameInfo(allocate_start, frame_size, context.vars))
+ frame_info = FrameInfo(allocate_start, frame_size, context.vars)
- if not sig.internal:
+ # XXX: when can this happen?
+ if func_t._ir_info.frame_info is None:
+ func_t._ir_info.set_frame_info(frame_info)
+ else:
+ assert frame_info == func_t._ir_info.frame_info
+
+ if not func_t.is_internal:
# adjust gas estimate to include cost of mem expansion
# frame_size of external function includes all private functions called
# (note: internal functions do not need to adjust gas estimate since
- # it is already accounted for by the caller.)
- o.add_gas_estimate += calc_mem_gas(sig.frame_info.mem_used)
-
- sig.gas_estimate = o.gas
+ mem_expansion_cost = calc_mem_gas(func_t._ir_info.frame_info.mem_used) # type: ignore
+ ret.common_ir.add_gas_estimate += mem_expansion_cost # type: ignore
- return o
+ return ret
diff --git a/vyper/codegen/function_definitions/external_function.py b/vyper/codegen/function_definitions/external_function.py
index feb2973e2a..32236e9aad 100644
--- a/vyper/codegen/function_definitions/external_function.py
+++ b/vyper/codegen/function_definitions/external_function.py
@@ -1,8 +1,3 @@
-from typing import Any, List
-
-import vyper.utils as util
-from vyper.address_space import CALLDATA, DATA, MEMORY
-from vyper.ast.signatures.function_signature import FunctionSignature
from vyper.codegen.abi_encoder import abi_encoding_matches_vyper
from vyper.codegen.context import Context, VariableRecord
from vyper.codegen.core import get_element_ptr, getpos, make_setter, needs_clamp
@@ -10,24 +5,25 @@
from vyper.codegen.function_definitions.utils import get_nonreentrant_lock
from vyper.codegen.ir_node import Encoding, IRnode
from vyper.codegen.stmt import parse_body
+from vyper.evm.address_space import CALLDATA, DATA, MEMORY
from vyper.semantics.types import TupleT
+from vyper.semantics.types.function import ContractFunctionT
# register function args with the local calling context.
# also allocate the ones that live in memory (i.e. kwargs)
-def _register_function_args(context: Context, sig: FunctionSignature) -> List[IRnode]:
+def _register_function_args(func_t: ContractFunctionT, context: Context) -> list[IRnode]:
ret = []
-
# the type of the calldata
- base_args_t = TupleT(tuple(arg.typ for arg in sig.base_args))
+ base_args_t = TupleT(tuple(arg.typ for arg in func_t.positional_args))
# tuple with the abi_encoded args
- if sig.is_init_func:
+ if func_t.is_constructor:
base_args_ofst = IRnode(0, location=DATA, typ=base_args_t, encoding=Encoding.ABI)
else:
base_args_ofst = IRnode(4, location=CALLDATA, typ=base_args_t, encoding=Encoding.ABI)
- for i, arg in enumerate(sig.base_args):
+ for i, arg in enumerate(func_t.positional_args):
arg_ir = get_element_ptr(base_args_ofst, i)
if needs_clamp(arg.typ, Encoding.ABI):
@@ -53,17 +49,13 @@ def _register_function_args(context: Context, sig: FunctionSignature) -> List[IR
return ret
-def _annotated_method_id(abi_sig):
- method_id = util.method_id_int(abi_sig)
- annotation = f"{hex(method_id)}: {abi_sig}"
- return IRnode(method_id, annotation=annotation)
-
-
-def _generate_kwarg_handlers(context: Context, sig: FunctionSignature) -> List[Any]:
+def _generate_kwarg_handlers(
+ func_t: ContractFunctionT, context: Context
+) -> dict[str, tuple[int, IRnode]]:
# generate kwarg handlers.
# since they might come in thru calldata or be default,
# allocate them in memory and then fill it in based on calldata or default,
- # depending on the signature
+ # depending on the ContractFunctionT
# a kwarg handler looks like
# (if (eq _method_id )
# copy calldata args to memory
@@ -71,12 +63,11 @@ def _generate_kwarg_handlers(context: Context, sig: FunctionSignature) -> List[A
# goto external_function_common_ir
def handler_for(calldata_kwargs, default_kwargs):
- calldata_args = sig.base_args + calldata_kwargs
+ calldata_args = func_t.positional_args + calldata_kwargs
# create a fake type so that get_element_ptr works
calldata_args_t = TupleT(list(arg.typ for arg in calldata_args))
- abi_sig = sig.abi_signature_for_kwargs(calldata_kwargs)
- method_id = _annotated_method_id(abi_sig)
+ abi_sig = func_t.abi_signature_for_kwargs(calldata_kwargs)
calldata_kwargs_ofst = IRnode(
4, location=CALLDATA, typ=calldata_args_t, encoding=Encoding.ABI
@@ -88,16 +79,13 @@ def handler_for(calldata_kwargs, default_kwargs):
# ensure calldata is at least of minimum length
args_abi_t = calldata_args_t.abi_type
calldata_min_size = args_abi_t.min_size() + 4
- ret.append(["assert", ["ge", "calldatasize", calldata_min_size]])
# TODO optimize make_setter by using
# TupleT(list(arg.typ for arg in calldata_kwargs + default_kwargs))
# (must ensure memory area is contiguous)
- n_base_args = len(sig.base_args)
-
for i, arg_meta in enumerate(calldata_kwargs):
- k = n_base_args + i
+ k = func_t.n_positional_args + i
dst = context.lookup_var(arg_meta.name).pos
@@ -113,38 +101,21 @@ def handler_for(calldata_kwargs, default_kwargs):
dst = context.lookup_var(x.name).pos
lhs = IRnode(dst, location=MEMORY, typ=x.typ)
lhs.source_pos = getpos(x.ast_source)
- kw_ast_val = sig.default_values[x.name] # e.g. `3` in x: int = 3
+ kw_ast_val = func_t.default_values[x.name] # e.g. `3` in x: int = 3
rhs = Expr(kw_ast_val, context).ir_node
copy_arg = make_setter(lhs, rhs)
copy_arg.source_pos = getpos(x.ast_source)
ret.append(copy_arg)
- ret.append(["goto", sig.external_function_base_entry_label])
-
- method_id_check = ["eq", "_calldata_method_id", method_id]
-
- # if there is a function whose selector is 0, it won't be distinguished
- # from the case where nil calldata is supplied, b/c calldataload loads
- # 0s past the end of physical calldata (cf. yellow paper).
- # since supplying 0 calldata is expected to trigger the fallback fn,
- # we check that calldatasize > 0, which distinguishes the 0 selector
- # from the fallback function "selector"
- # (equiv. to "all selectors not in the selector table").
+ ret.append(["goto", func_t._ir_info.external_function_base_entry_label])
- # note: cases where not enough calldata is supplied (besides
- # calldatasize==0) are not addressed here b/c a calldatasize
- # well-formedness check is already present in the function body
- # as part of abi validation
- if method_id.value == 0:
- method_id_check = ["and", ["gt", "calldatasize", 0], method_id_check]
+ # return something we can turn into ExternalFuncIR
+ return abi_sig, calldata_min_size, ret
- ret = ["if", method_id_check, ret]
- return ret
+ ret = {}
- ret = ["seq"]
-
- keyword_args = sig.default_args
+ keyword_args = func_t.keyword_args
# allocate variable slots in memory
for arg in keyword_args:
@@ -154,9 +125,12 @@ def handler_for(calldata_kwargs, default_kwargs):
calldata_kwargs = keyword_args[:i]
default_kwargs = keyword_args[i:]
- ret.append(handler_for(calldata_kwargs, default_kwargs))
+ sig, calldata_min_size, ir_node = handler_for(calldata_kwargs, default_kwargs)
+ ret[sig] = calldata_min_size, ir_node
+
+ sig, calldata_min_size, ir_node = handler_for(keyword_args, [])
- ret.append(handler_for(keyword_args, []))
+ ret[sig] = calldata_min_size, ir_node
return ret
@@ -164,44 +138,40 @@ def handler_for(calldata_kwargs, default_kwargs):
# TODO it would be nice if this returned a data structure which were
# amenable to generating a jump table instead of the linear search for
# method_id we have now.
-def generate_ir_for_external_function(code, sig, context, skip_nonpayable_check):
+def generate_ir_for_external_function(code, func_t, context):
# TODO type hints:
# def generate_ir_for_external_function(
- # code: vy_ast.FunctionDef, sig: FunctionSignature, context: Context, check_nonpayable: bool,
+ # code: vy_ast.FunctionDef,
+ # func_t: ContractFunctionT,
+ # context: Context,
+ # check_nonpayable: bool,
# ) -> IRnode:
"""Return the IR for an external function. Includes code to inspect the method_id,
enter the function (nonpayable and reentrancy checks), handle kwargs and exit
the function (clean up reentrancy storage variables)
"""
- func_type = code._metadata["type"]
-
- nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_type)
+ nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t)
# generate handlers for base args and register the variable records
- handle_base_args = _register_function_args(context, sig)
+ handle_base_args = _register_function_args(func_t, context)
# generate handlers for kwargs and register the variable records
- kwarg_handlers = _generate_kwarg_handlers(context, sig)
+ kwarg_handlers = _generate_kwarg_handlers(func_t, context)
body = ["seq"]
# once optional args have been handled,
# generate the main body of the function
body += handle_base_args
- if sig.mutability != "payable" and not skip_nonpayable_check:
- # if the contract contains payable functions, but this is not one of them
- # add an assertion that the value of the call is zero
- body += [["assert", ["iszero", "callvalue"]]]
-
body += nonreentrant_pre
body += [parse_body(code.body, context, ensure_terminated=True)]
# wrap the body in labeled block
- body = ["label", sig.external_function_base_entry_label, ["var_list"], body]
+ body = ["label", func_t._ir_info.external_function_base_entry_label, ["var_list"], body]
exit_sequence = ["seq"] + nonreentrant_post
- if sig.is_init_func:
+ if func_t.is_constructor:
pass # init func has special exit sequence generated by module.py
elif context.return_type is None:
exit_sequence += [["stop"]]
@@ -212,22 +182,10 @@ def generate_ir_for_external_function(code, sig, context, skip_nonpayable_check)
if context.return_type is not None:
exit_sequence_args += ["ret_ofst", "ret_len"]
# wrap the exit in a labeled block
- exit = ["label", sig.exit_sequence_label, exit_sequence_args, exit_sequence]
+ exit_ = ["label", func_t._ir_info.exit_sequence_label, exit_sequence_args, exit_sequence]
# the ir which comprises the main body of the function,
# besides any kwarg handling
- func_common_ir = ["seq", body, exit]
-
- if sig.is_default_func or sig.is_init_func:
- ret = ["seq"]
- # add a goto to make the function entry look like other functions
- # (for zksync interpreter)
- ret.append(["goto", sig.external_function_base_entry_label])
- ret.append(func_common_ir)
- else:
- ret = kwarg_handlers
- # sneak the base code into the kwarg handler
- # TODO rethink this / make it clearer
- ret[-1][-1].append(func_common_ir)
+ func_common_ir = IRnode.from_list(["seq", body, exit_], source_pos=getpos(code))
- return IRnode.from_list(ret, source_pos=getpos(sig.func_ast_code))
+ return kwarg_handlers, func_common_ir
diff --git a/vyper/codegen/function_definitions/internal_function.py b/vyper/codegen/function_definitions/internal_function.py
index b0ca117b6b..228191e3ca 100644
--- a/vyper/codegen/function_definitions/internal_function.py
+++ b/vyper/codegen/function_definitions/internal_function.py
@@ -1,18 +1,18 @@
from vyper import ast as vy_ast
-from vyper.ast.signatures import FunctionSignature
from vyper.codegen.context import Context
from vyper.codegen.function_definitions.utils import get_nonreentrant_lock
from vyper.codegen.ir_node import IRnode
from vyper.codegen.stmt import parse_body
+from vyper.semantics.types.function import ContractFunctionT
def generate_ir_for_internal_function(
- code: vy_ast.FunctionDef, sig: FunctionSignature, context: Context
+ code: vy_ast.FunctionDef, func_t: ContractFunctionT, context: Context
) -> IRnode:
"""
Parse a internal function (FuncDef), and produce full function body.
- :param sig: the FuntionSignature
+ :param func_t: the ContractFunctionT
:param code: ast of function
:param context: current calling context
:return: function body in IR
@@ -37,22 +37,20 @@ def generate_ir_for_internal_function(
# situation like the following is easy to bork:
# x: T[2] = [self.generate_T(), self.generate_T()]
- func_type = code._metadata["type"]
-
# Get nonreentrant lock
- for arg in sig.args:
+ for arg in func_t.arguments:
# allocate a variable for every arg, setting mutability
- # to False to comply with vyper semantics, function arguments are immutable
- context.new_variable(arg.name, arg.typ, is_mutable=False)
+ # to True to allow internal function arguments to be mutable
+ context.new_variable(arg.name, arg.typ, is_mutable=True)
- nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_type)
+ nonreentrant_pre, nonreentrant_post = get_nonreentrant_lock(func_t)
- function_entry_label = sig.internal_function_label
- cleanup_label = sig.exit_sequence_label
+ function_entry_label = func_t._ir_info.internal_function_label(context.is_ctor_context)
+ cleanup_label = func_t._ir_info.exit_sequence_label
stack_args = ["var_list"]
- if func_type.return_type:
+ if func_t.return_type:
stack_args += ["return_buffer"]
stack_args += ["return_pc"]
diff --git a/vyper/codegen/function_definitions/utils.py b/vyper/codegen/function_definitions/utils.py
index 7129388c58..f524ec6e88 100644
--- a/vyper/codegen/function_definitions/utils.py
+++ b/vyper/codegen/function_definitions/utils.py
@@ -8,6 +8,10 @@ def get_nonreentrant_lock(func_type):
nkey = func_type.reentrancy_key_position.position
+ LOAD, STORE = "sload", "sstore"
+ if version_check(begin="cancun"):
+ LOAD, STORE = "tload", "tstore"
+
if version_check(begin="berlin"):
# any nonzero values would work here (see pricing as of net gas
# metering); these values are chosen so that downgrading to the
@@ -16,12 +20,12 @@ def get_nonreentrant_lock(func_type):
else:
final_value, temp_value = 0, 1
- check_notset = ["assert", ["ne", temp_value, ["sload", nkey]]]
+ check_notset = ["assert", ["ne", temp_value, [LOAD, nkey]]]
if func_type.mutability == StateMutability.VIEW:
return [check_notset], [["seq"]]
else:
- pre = ["seq", check_notset, ["sstore", nkey, temp_value]]
- post = ["sstore", nkey, final_value]
+ pre = ["seq", check_notset, [STORE, nkey, temp_value]]
+ post = [STORE, nkey, final_value]
return [pre], [post]
diff --git a/vyper/codegen/global_context.py b/vyper/codegen/global_context.py
index 5de04d5127..1f6783f6f8 100644
--- a/vyper/codegen/global_context.py
+++ b/vyper/codegen/global_context.py
@@ -2,8 +2,6 @@
from typing import Optional
from vyper import ast as vy_ast
-from vyper.semantics.namespace import get_namespace, override_global_namespace
-from vyper.semantics.types.utils import type_from_annotation
# Datatype to store all global context information.
@@ -25,16 +23,6 @@ def variables(self):
variable_decls = self._module.get_children(vy_ast.VariableDecl)
return {s.target.id: s.target._metadata["varinfo"] for s in variable_decls}
- def parse_type(self, ast_node):
- # kludge implementation for backwards compatibility.
- # TODO: replace with type_from_ast
- try:
- ns = self._module._metadata["namespace"]
- except AttributeError:
- ns = get_namespace()
- with override_global_namespace(ns):
- return type_from_annotation(ast_node)
-
@property
def immutables(self):
return [t for t in self.variables.values() if t.is_immutable]
diff --git a/vyper/codegen/ir_node.py b/vyper/codegen/ir_node.py
index c2c127b9d5..ad4aa76437 100644
--- a/vyper/codegen/ir_node.py
+++ b/vyper/codegen/ir_node.py
@@ -1,10 +1,11 @@
+import contextlib
import re
from enum import Enum, auto
from functools import cached_property
from typing import Any, List, Optional, Tuple, Union
-from vyper.address_space import AddrSpace
from vyper.compiler.settings import VYPER_COLOR_OUTPUT
+from vyper.evm.address_space import AddrSpace
from vyper.evm.opcodes import get_ir_opcodes
from vyper.exceptions import CodegenPanic, CompilerPanic
from vyper.semantics.types import VyperType
@@ -38,11 +39,6 @@ def __repr__(self) -> str:
__mul__ = __add__
-def push_label_to_stack(labelname: str) -> str:
- # items prefixed with `_sym_` are ignored until asm phase
- return "_sym_" + labelname
-
-
class Encoding(Enum):
# vyper encoding, default for memory variables
VYPER = auto()
@@ -51,13 +47,81 @@ class Encoding(Enum):
# future: packed
+# shortcut for chaining multiple cache_when_complex calls
+# CMC 2023-08-10 remove this and scope_together _as soon as_ we have
+# real variables in IR (that we can declare without explicit scoping -
+# needs liveness analysis).
+@contextlib.contextmanager
+def scope_multi(ir_nodes, names):
+ assert len(ir_nodes) == len(names)
+
+ builders = []
+ scoped_ir_nodes = []
+
+ class _MultiBuilder:
+ def resolve(self, body):
+ # sanity check that it's initialized properly
+ assert len(builders) == len(ir_nodes)
+ ret = body
+ for b in reversed(builders):
+ ret = b.resolve(ret)
+ return ret
+
+ mb = _MultiBuilder()
+
+ with contextlib.ExitStack() as stack:
+ for arg, name in zip(ir_nodes, names):
+ b, ir_node = stack.enter_context(arg.cache_when_complex(name))
+
+ builders.append(b)
+ scoped_ir_nodes.append(ir_node)
+
+ yield mb, scoped_ir_nodes
+
+
+# create multiple with scopes if any of the items are complex, to force
+# ordering of side effects.
+@contextlib.contextmanager
+def scope_together(ir_nodes, names):
+ assert len(ir_nodes) == len(names)
+
+ should_scope = any(s._optimized.is_complex_ir for s in ir_nodes)
+
+ class _Builder:
+ def resolve(self, body):
+ if not should_scope:
+ # uses of the variable have already been inlined
+ return body
+
+ ret = body
+ # build with scopes from inside-out (hence reversed)
+ for arg, name in reversed(list(zip(ir_nodes, names))):
+ ret = ["with", name, arg, ret]
+
+ if isinstance(body, IRnode):
+ return IRnode.from_list(
+ ret, typ=body.typ, location=body.location, encoding=body.encoding
+ )
+ else:
+ return ret
+
+ b = _Builder()
+
+ if should_scope:
+ ir_vars = tuple(
+ IRnode.from_list(name, typ=arg.typ, location=arg.location, encoding=arg.encoding)
+ for (arg, name) in zip(ir_nodes, names)
+ )
+ yield b, ir_vars
+ else:
+ # inline them
+ yield b, ir_nodes
+
+
# this creates a magical block which maps to IR `with`
class _WithBuilder:
def __init__(self, ir_node, name, should_inline=False):
- # TODO figure out how to fix this circular import
- from vyper.ir.optimizer import optimize
-
- if should_inline and optimize(ir_node).is_complex_ir:
+ if should_inline and ir_node._optimized.is_complex_ir:
# this can only mean trouble
raise CompilerPanic("trying to inline a complex IR node")
@@ -156,6 +220,13 @@ def _check(condition, err):
self.valency = 1
self._gas = 5
+ elif isinstance(self.value, bytes):
+ # a literal bytes value, probably inside a "data" node.
+ _check(len(self.args) == 0, "bytes can't have arguments")
+
+ self.valency = 0
+ self._gas = 0
+
elif isinstance(self.value, str):
# Opcodes and pseudo-opcodes (e.g. clamp)
if self.value.upper() in get_ir_opcodes():
@@ -272,8 +343,11 @@ def _check(condition, err):
self.valency = 0
self._gas = sum([arg.gas for arg in self.args])
elif self.value == "label":
- if not self.args[1].value == "var_list":
- raise CodegenPanic(f"2nd argument to label must be var_list, {self}")
+ _check(
+ self.args[1].value == "var_list",
+ f"2nd argument to label must be var_list, {self}",
+ )
+ _check(len(args) == 3, f"label should have 3 args but has {len(args)}, {self}")
self.valency = 0
self._gas = 1 + sum(t.gas for t in self.args)
elif self.value == "unique_symbol":
@@ -324,20 +398,29 @@ def _check(condition, err):
def gas(self):
return self._gas + self.add_gas_estimate
- # the IR should be cached.
- # TODO make this private. turns out usages are all for the caching
- # idiom that cache_when_complex addresses
+ # the IR should be cached and/or evaluated exactly once
@property
def is_complex_ir(self):
# list of items not to cache. note can add other env variables
# which do not change, e.g. calldatasize, coinbase, etc.
- do_not_cache = {"~empty", "calldatasize"}
+ # reads (from memory or storage) should not be cached because
+ # they can have or be affected by side effects.
+ do_not_cache = {"~empty", "calldatasize", "callvalue"}
+
return (
isinstance(self.value, str)
and (self.value.lower() in VALID_IR_MACROS or self.value.upper() in get_ir_opcodes())
and self.value.lower() not in do_not_cache
)
+ # set an error message and push down into all children.
+ # useful for overriding an error message generated by a helper
+ # function with a more specific error message.
+ def set_error_msg(self, error_msg: str) -> None:
+ self.error_msg = error_msg
+ for arg in self.args:
+ arg.set_error_msg(error_msg)
+
# get the unique symbols contained in this node, which provides
# sanity check invariants for the optimizer.
# cache because it's a perf hotspot. note that this (and other cached
@@ -371,6 +454,13 @@ def is_pointer(self):
# eventually
return self.location is not None
+ @property # probably could be cached_property but be paranoid
+ def _optimized(self):
+ # TODO figure out how to fix this circular import
+ from vyper.ir.optimizer import optimize
+
+ return optimize(self)
+
# This function is slightly confusing but abstracts a common pattern:
# when an IR value needs to be computed once and then cached as an
# IR value (if it is expensive, or more importantly if its computation
@@ -387,16 +477,24 @@ def is_pointer(self):
# return builder.resolve(ret)
# ```
def cache_when_complex(self, name):
- from vyper.ir.optimizer import optimize
-
# for caching purposes, see if the ir_node will be optimized
# because a non-literal expr could turn into a literal,
# (e.g. `(add 1 2)`)
# TODO this could really be moved into optimizer.py
- should_inline = not optimize(self).is_complex_ir
+ should_inline = not self._optimized.is_complex_ir
return _WithBuilder(self, name, should_inline)
+ @cached_property
+ def referenced_variables(self):
+ ret = set()
+ for arg in self.args:
+ ret |= arg.referenced_variables
+
+ ret |= getattr(self, "_referenced_variables", set())
+
+ return ret
+
@cached_property
def contains_self_call(self):
return getattr(self, "is_self_call", False) or any(x.contains_self_call for x in self.args)
diff --git a/vyper/codegen/jumptable_utils.py b/vyper/codegen/jumptable_utils.py
new file mode 100644
index 0000000000..6404b75532
--- /dev/null
+++ b/vyper/codegen/jumptable_utils.py
@@ -0,0 +1,212 @@
+# helper module which implements jumptable for function selection
+import math
+from dataclasses import dataclass
+
+from vyper.utils import method_id_int
+
+
+@dataclass
+class Signature:
+ method_id: int
+ payable: bool
+
+
+# bucket for dense function
+@dataclass
+class Bucket:
+ bucket_id: int
+ magic: int
+ method_ids: list[int]
+
+ @property
+ def image(self):
+ return _image_of([s for s in self.method_ids], self.magic)
+
+ @property
+ # return method ids, sorted by by their image
+ def method_ids_image_order(self):
+ return [x[1] for x in sorted(zip(self.image, self.method_ids))]
+
+ @property
+ def bucket_size(self):
+ return len(self.method_ids)
+
+
+BITS_MAGIC = 24 # a constant which produced good results, see _bench_dense()
+
+
+def _image_of(xs, magic):
+ bits_shift = BITS_MAGIC
+
+ # take the upper bits from the multiplication for more entropy
+ # can we do better using primes of some sort?
+ return [((x * magic) >> bits_shift) % len(xs) for x in xs]
+
+
+class _FindMagicFailure(Exception):
+ pass
+
+
+class _HasEmptyBuckets(Exception):
+ pass
+
+
+def find_magic_for(xs):
+ for m in range(2**16):
+ test = _image_of(xs, m)
+ if len(test) == len(set(test)):
+ return m
+
+ raise _FindMagicFailure(f"Could not find hash for {xs}")
+
+
+def _mk_buckets(method_ids, n_buckets):
+ buckets = {}
+ for x in method_ids:
+ t = x % n_buckets
+ buckets.setdefault(t, [])
+ buckets[t].append(x)
+ return buckets
+
+
+# two layer method for generating perfect hash
+# first get "reasonably good" distribution by using
+# method_id % len(method_ids)
+# second, get the magic for the bucket.
+def _dense_jumptable_info(method_ids, n_buckets):
+ buckets = _mk_buckets(method_ids, n_buckets)
+
+ # if there are somehow empty buckets, bail out as that can mess up
+ # the bucket header layout
+ if len(buckets) != n_buckets:
+ raise _HasEmptyBuckets()
+
+ ret = {}
+ for bucket_id, method_ids in buckets.items():
+ magic = find_magic_for(method_ids)
+ ret[bucket_id] = Bucket(bucket_id, magic, method_ids)
+
+ return ret
+
+
+START_BUCKET_SIZE = 5
+
+
+# this is expensive! for 80 methods, costs about 350ms and probably
+# linear in # of methods.
+# see _bench_perfect()
+# note the buckets are NOT in order!
+def generate_dense_jumptable_info(signatures):
+ method_ids = [method_id_int(sig) for sig in signatures]
+ n = len(signatures)
+ # start at bucket size of 5 and try to improve (generally
+ # speaking we want as few buckets as possible)
+ n_buckets = (n // START_BUCKET_SIZE) + 1
+ ret = None
+ tried_exhaustive = False
+ while n_buckets > 0:
+ try:
+ # print(f"trying {n_buckets} (bucket size {n // n_buckets})")
+ solution = _dense_jumptable_info(method_ids, n_buckets)
+ assert len(solution) == n_buckets
+ ret = n_buckets, solution
+
+ except _HasEmptyBuckets:
+ # found a solution which has empty buckets; skip it since
+ # it will break the bucket layout.
+ pass
+
+ except _FindMagicFailure:
+ if ret is not None:
+ break
+
+ # we have not tried exhaustive search. try really hard
+ # to find a valid jumptable at the cost of performance
+ if not tried_exhaustive:
+ # print("failed with guess! trying exhaustive search.")
+ n_buckets = n
+ tried_exhaustive = True
+ continue
+ else:
+ raise RuntimeError(f"Could not generate jumptable! {signatures}")
+ n_buckets -= 1
+
+ return ret
+
+
+# note the buckets are NOT in order!
+def generate_sparse_jumptable_buckets(signatures):
+ method_ids = [method_id_int(sig) for sig in signatures]
+ n = len(signatures)
+
+ # search a range of buckets to try to minimize bucket size
+ # (doing the range search improves worst worst bucket size from 9 to 4,
+ # see _bench_sparse)
+ lo = max(1, math.floor(n * 0.85))
+ hi = max(1, math.ceil(n * 1.15))
+ stats = {}
+ for i in range(lo, hi + 1):
+ buckets = _mk_buckets(method_ids, i)
+
+ stats[i] = buckets
+
+ min_max_bucket_size = hi + 1 # smallest max_bucket_size
+ # find the smallest i which gives us the smallest max_bucket_size
+ for i, buckets in stats.items():
+ max_bucket_size = max(len(bucket) for bucket in buckets.values())
+ if max_bucket_size < min_max_bucket_size:
+ min_max_bucket_size = max_bucket_size
+ ret = i, buckets
+
+ assert ret is not None
+ return ret
+
+
+# benchmark for quality of buckets
+def _bench_dense(N=1_000, n_methods=100):
+ import random
+
+ stats = []
+ for i in range(N):
+ seed = random.randint(0, 2**64 - 1)
+ # "large" contracts in prod hit about ~50 methods, test with
+ # double the limit
+ sigs = [f"foo{i + seed}()" for i in range(n_methods)]
+
+ xs = generate_dense_jumptable_info(sigs)
+ print(f"found. n buckets {len(xs)}")
+ stats.append(xs)
+
+ def mean(xs):
+ return sum(xs) / len(xs)
+
+ avg_n_buckets = mean([len(jt) for jt in stats])
+ # usually around ~14 buckets per 100 sigs
+ # N=10, time=3.6s
+ print(f"average N buckets: {avg_n_buckets}")
+
+
+def _bench_sparse(N=10_000, n_methods=80):
+ import random
+
+ stats = []
+ for _ in range(N):
+ seed = random.randint(0, 2**64 - 1)
+ sigs = [f"foo{i + seed}()" for i in range(n_methods)]
+ _, buckets = generate_sparse_jumptable_buckets(sigs)
+
+ bucket_sizes = [len(bucket) for bucket in buckets.values()]
+ worst_bucket_size = max(bucket_sizes)
+ mean_bucket_size = sum(bucket_sizes) / len(bucket_sizes)
+ stats.append((worst_bucket_size, mean_bucket_size))
+
+ # N=10_000, time=9s
+ # range 0.85*n - 1.15*n
+ # worst worst bucket size: 4
+ # avg worst bucket size: 3.0018
+ # worst mean bucket size: 2.0
+ # avg mean bucket size: 1.579112583664968
+ print("worst worst bucket size:", max(x[0] for x in stats))
+ print("avg worst bucket size:", sum(x[0] for x in stats) / len(stats))
+ print("worst mean bucket size:", max(x[1] for x in stats))
+ print("avg mean bucket size:", sum(x[1] for x in stats) / len(stats))
diff --git a/vyper/codegen/keccak256_helper.py b/vyper/codegen/keccak256_helper.py
index b22453761b..9c5f5eb1d0 100644
--- a/vyper/codegen/keccak256_helper.py
+++ b/vyper/codegen/keccak256_helper.py
@@ -8,7 +8,7 @@
from vyper.utils import SHA3_BASE, SHA3_PER_WORD, MemoryPositions, bytes_to_int, keccak256
-def _check_byteslike(typ, _expr):
+def _check_byteslike(typ):
if not isinstance(typ, _BytestringT) and typ != BYTES32_T:
# NOTE this may be checked at a higher level, but just be safe
raise CompilerPanic("keccak256 only accepts bytes-like objects")
@@ -18,8 +18,8 @@ def _gas_bound(num_words):
return SHA3_BASE + num_words * SHA3_PER_WORD
-def keccak256_helper(expr, to_hash, context):
- _check_byteslike(to_hash.typ, expr)
+def keccak256_helper(to_hash, context):
+ _check_byteslike(to_hash.typ)
# Can hash literals
# TODO this is dead code.
diff --git a/vyper/codegen/module.py b/vyper/codegen/module.py
index 71f9ed552d..6445a5e1e0 100644
--- a/vyper/codegen/module.py
+++ b/vyper/codegen/module.py
@@ -1,15 +1,15 @@
-# a contract.vy -- all functions and constructor
+# a compilation unit -- all functions and constructor
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, List
-from vyper import ast as vy_ast
-from vyper.ast.signatures.function_signature import FunctionSignature, FunctionSignatures
+from vyper.codegen import core, jumptable_utils
from vyper.codegen.core import shr
from vyper.codegen.function_definitions import generate_ir_for_function
from vyper.codegen.global_context import GlobalContext
from vyper.codegen.ir_node import IRnode
+from vyper.compiler.settings import _is_debug_mode
from vyper.exceptions import CompilerPanic
-from vyper.semantics.types.function import StateMutability
+from vyper.utils import method_id_int
def _topsort_helper(functions, lookup):
@@ -34,12 +34,12 @@ def _topsort(functions):
return list(dict.fromkeys(_topsort_helper(functions, lookup)))
-def _is_init_func(func_ast):
- return func_ast._metadata["signature"].is_init_func
+def _is_constructor(func_ast):
+ return func_ast._metadata["type"].is_constructor
-def _is_default_func(func_ast):
- return func_ast._metadata["signature"].is_default_func
+def _is_fallback(func_ast):
+ return func_ast._metadata["type"].is_fallback
def _is_internal(func_ast):
@@ -47,139 +47,452 @@ def _is_internal(func_ast):
def _is_payable(func_ast):
- return func_ast._metadata["type"].mutability == StateMutability.PAYABLE
+ return func_ast._metadata["type"].is_payable
-# codegen for all runtime functions + callvalue/calldata checks + method selector routines
-def _runtime_ir(runtime_functions, all_sigs, global_ctx):
- # categorize the runtime functions because we will organize the runtime
- # code into the following sections:
- # payable functions, nonpayable functions, fallback function, internal_functions
- internal_functions = [f for f in runtime_functions if _is_internal(f)]
+def _annotated_method_id(abi_sig):
+ method_id = method_id_int(abi_sig)
+ annotation = f"{hex(method_id)}: {abi_sig}"
+ return IRnode(method_id, annotation=annotation)
- external_functions = [f for f in runtime_functions if not _is_internal(f)]
- default_function = next((f for f in external_functions if _is_default_func(f)), None)
- # functions that need to go exposed in the selector section
- regular_functions = [f for f in external_functions if not _is_default_func(f)]
- payables = [f for f in regular_functions if _is_payable(f)]
- nonpayables = [f for f in regular_functions if not _is_payable(f)]
+def label_for_entry_point(abi_sig, entry_point):
+ method_id = method_id_int(abi_sig)
+ return f"{entry_point.func_t._ir_info.ir_identifier}{method_id}"
- # create a map of the IR functions since they might live in both
- # runtime and deploy code (if init function calls them)
- internal_functions_map: Dict[str, IRnode] = {}
- for func_ast in internal_functions:
- func_ir = generate_ir_for_function(func_ast, all_sigs, global_ctx, False)
- internal_functions_map[func_ast.name] = func_ir
+# adapt whatever generate_ir_for_function gives us into an IR node
+def _ir_for_fallback_or_ctor(func_ast, *args, **kwargs):
+ func_t = func_ast._metadata["type"]
+ assert func_t.is_fallback or func_t.is_constructor
+
+ ret = ["seq"]
+ if not func_t.is_payable:
+ callvalue_check = ["assert", ["iszero", "callvalue"]]
+ ret.append(IRnode.from_list(callvalue_check, error_msg="nonpayable check"))
+
+ func_ir = generate_ir_for_function(func_ast, *args, **kwargs)
+ assert len(func_ir.entry_points) == 1
+
+ # add a goto to make the function entry look like other functions
+ # (for zksync interpreter)
+ ret.append(["goto", func_t._ir_info.external_function_base_entry_label])
+ ret.append(func_ir.common_ir)
+
+ return IRnode.from_list(ret)
+
+
+def _ir_for_internal_function(func_ast, *args, **kwargs):
+ return generate_ir_for_function(func_ast, *args, **kwargs).func_ir
+
+
+def _generate_external_entry_points(external_functions, global_ctx):
+ entry_points = {} # map from ABI sigs to ir code
+ sig_of = {} # reverse map from method ids to abi sig
+
+ for code in external_functions:
+ func_ir = generate_ir_for_function(code, global_ctx)
+ for abi_sig, entry_point in func_ir.entry_points.items():
+ assert abi_sig not in entry_points
+ entry_points[abi_sig] = entry_point
+ sig_of[method_id_int(abi_sig)] = abi_sig
+
+ # stick function common body into final entry point to save a jump
+ ir_node = IRnode.from_list(["seq", entry_point.ir_node, func_ir.common_ir])
+ entry_point.ir_node = ir_node
+
+ return entry_points, sig_of
+
+
+# codegen for all runtime functions + callvalue/calldata checks,
+# with O(1) jumptable for selector table.
+# uses two level strategy: uses `method_id % n_buckets` to descend
+# into a bucket (of about 8-10 items), and then uses perfect hash
+# to select the final function.
+# costs about 212 gas for typical function and 8 bytes of code (+ ~87 bytes of global overhead)
+def _selector_section_dense(external_functions, global_ctx):
+ function_irs = []
- # for some reason, somebody may want to deploy a contract with no
- # external functions, or more likely, a "pure data" contract which
- # contains immutables
if len(external_functions) == 0:
- # TODO: prune internal functions in this case?
- runtime = ["seq"] + list(internal_functions_map.values())
- return runtime, internal_functions_map
+ return IRnode.from_list(["seq"])
+
+ entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx)
+
+ # generate the label so the jumptable works
+ for abi_sig, entry_point in entry_points.items():
+ label = label_for_entry_point(abi_sig, entry_point)
+ ir_node = ["label", label, ["var_list"], entry_point.ir_node]
+ function_irs.append(IRnode.from_list(ir_node))
- # note: if the user does not provide one, the default fallback function
- # reverts anyway. so it does not hurt to batch the payable check.
- default_is_nonpayable = default_function is None or not _is_payable(default_function)
+ n_buckets, jumptable_info = jumptable_utils.generate_dense_jumptable_info(entry_points.keys())
+ # note: we are guaranteed by jumptable_utils that there are no buckets
+ # which are empty. sanity check that the bucket ids are well-behaved:
+ assert n_buckets == len(jumptable_info)
+ for i, (bucket_id, _) in enumerate(sorted(jumptable_info.items())):
+ assert i == bucket_id
- # when a contract has a nonpayable default function,
- # we can do a single check for all nonpayable functions
- batch_payable_check = len(nonpayables) > 0 and default_is_nonpayable
- skip_nonpayable_check = batch_payable_check
+ # bucket magic <2 bytes> | bucket location <2 bytes> | bucket size <1 byte>
+ # TODO: can make it smaller if the largest bucket magic <= 255
+ SZ_BUCKET_HEADER = 5
selector_section = ["seq"]
- for func_ast in payables:
- func_ir = generate_ir_for_function(func_ast, all_sigs, global_ctx, False)
- selector_section.append(func_ir)
+ bucket_id = ["mod", "_calldata_method_id", n_buckets]
+ bucket_hdr_location = [
+ "add",
+ ["symbol", "BUCKET_HEADERS"],
+ ["mul", bucket_id, SZ_BUCKET_HEADER],
+ ]
+ # get bucket header
+ dst = 32 - SZ_BUCKET_HEADER
+ assert dst >= 0
+
+ if _is_debug_mode():
+ selector_section.append(["assert", ["eq", "msize", 0]])
+
+ selector_section.append(["codecopy", dst, bucket_hdr_location, SZ_BUCKET_HEADER])
+
+ # figure out the minimum number of bytes we can use to encode
+ # min_calldatasize in function info
+ largest_mincalldatasize = max(f.min_calldatasize for f in entry_points.values())
+ FN_METADATA_BYTES = (largest_mincalldatasize.bit_length() + 7) // 8
+
+ func_info_size = 4 + 2 + FN_METADATA_BYTES
+ # grab function info.
+ # method id <4 bytes> | label <2 bytes> | func info <1-3 bytes>
+ # func info (1-3 bytes, packed) for: expected calldatasize, is_nonpayable bit
+ # NOTE: might be able to improve codesize if we use variable # of bytes
+ # per bucket
+
+ hdr_info = IRnode.from_list(["mload", 0])
+ with hdr_info.cache_when_complex("hdr_info") as (b1, hdr_info):
+ bucket_location = ["and", 0xFFFF, shr(8, hdr_info)]
+ bucket_magic = shr(24, hdr_info)
+ bucket_size = ["and", 0xFF, hdr_info]
+ # ((method_id * bucket_magic) >> BITS_MAGIC) % bucket_size
+ func_id = [
+ "mod",
+ shr(jumptable_utils.BITS_MAGIC, ["mul", bucket_magic, "_calldata_method_id"]),
+ bucket_size,
+ ]
+ func_info_location = ["add", bucket_location, ["mul", func_id, func_info_size]]
+ dst = 32 - func_info_size
+ assert func_info_size >= SZ_BUCKET_HEADER # otherwise mload will have dirty bytes
+ assert dst >= 0
+ selector_section.append(b1.resolve(["codecopy", dst, func_info_location, func_info_size]))
+
+ func_info = IRnode.from_list(["mload", 0])
+ fn_metadata_mask = 2 ** (FN_METADATA_BYTES * 8) - 1
+ calldatasize_mask = fn_metadata_mask - 1 # ex. 0xFFFE
+ with func_info.cache_when_complex("func_info") as (b1, func_info):
+ x = ["seq"]
+
+ # expected calldatasize always satisfies (x - 4) % 32 == 0
+ # the lower 5 bits are always 0b00100, so we can use those
+ # bits for other purposes.
+ is_nonpayable = ["and", 1, func_info]
+ expected_calldatasize = ["and", calldatasize_mask, func_info]
+
+ label_bits_ofst = FN_METADATA_BYTES * 8
+ function_label = ["and", 0xFFFF, shr(label_bits_ofst, func_info)]
+ method_id_bits_ofst = (FN_METADATA_BYTES + 2) * 8
+ function_method_id = shr(method_id_bits_ofst, func_info)
+
+ # check method id is right, if not then fallback.
+ # need to check calldatasize >= 4 in case there are
+ # trailing 0s in the method id.
+ calldatasize_valid = ["gt", "calldatasize", 3]
+ method_id_correct = ["eq", function_method_id, "_calldata_method_id"]
+ should_fallback = ["iszero", ["and", calldatasize_valid, method_id_correct]]
+ x.append(["if", should_fallback, ["goto", "fallback"]])
+
+ # assert callvalue == 0 if nonpayable
+ bad_callvalue = ["mul", is_nonpayable, "callvalue"]
+ # assert calldatasize at least minimum for the abi type
+ bad_calldatasize = ["lt", "calldatasize", expected_calldatasize]
+ failed_entry_conditions = ["or", bad_callvalue, bad_calldatasize]
+ check_entry_conditions = IRnode.from_list(
+ ["assert", ["iszero", failed_entry_conditions]],
+ error_msg="bad calldatasize or callvalue",
+ )
+ x.append(check_entry_conditions)
+ x.append(["jump", function_label])
+ selector_section.append(b1.resolve(x))
- if batch_payable_check:
- selector_section.append(["assert", ["iszero", "callvalue"]])
+ bucket_headers = ["data", "BUCKET_HEADERS"]
- for func_ast in nonpayables:
- func_ir = generate_ir_for_function(func_ast, all_sigs, global_ctx, skip_nonpayable_check)
- selector_section.append(func_ir)
+ for bucket_id, bucket in sorted(jumptable_info.items()):
+ bucket_headers.append(bucket.magic.to_bytes(2, "big"))
+ bucket_headers.append(["symbol", f"bucket_{bucket_id}"])
+ # note: buckets are usually ~10 items. to_bytes would
+ # fail if the int is too big.
+ bucket_headers.append(bucket.bucket_size.to_bytes(1, "big"))
- if default_function:
- fallback_ir = generate_ir_for_function(
- default_function, all_sigs, global_ctx, skip_nonpayable_check
- )
- else:
- fallback_ir = IRnode.from_list(
- ["revert", 0, 0], annotation="Default function", error_msg="fallback function"
- )
+ selector_section.append(bucket_headers)
- # ensure the external jumptable section gets closed out
- # (for basic block hygiene and also for zksync interpreter)
- # NOTE: this jump gets optimized out in assembly since the
- # fallback label is the immediate next instruction,
- close_selector_section = ["goto", "fallback"]
-
- runtime = [
- "seq",
- ["with", "_calldata_method_id", shr(224, ["calldataload", 0]), selector_section],
- close_selector_section,
- ["label", "fallback", ["var_list"], fallback_ir],
- ]
+ for bucket_id, bucket in jumptable_info.items():
+ function_infos = ["data", f"bucket_{bucket_id}"]
+ # sort function infos by their image.
+ for method_id in bucket.method_ids_image_order:
+ abi_sig = sig_of[method_id]
+ entry_point = entry_points[abi_sig]
- # TODO: prune unreachable functions?
- runtime.extend(internal_functions_map.values())
+ method_id_bytes = method_id.to_bytes(4, "big")
+ symbol = ["symbol", label_for_entry_point(abi_sig, entry_point)]
+ func_metadata_int = entry_point.min_calldatasize | int(
+ not entry_point.func_t.is_payable
+ )
+ func_metadata = func_metadata_int.to_bytes(FN_METADATA_BYTES, "big")
- return runtime, internal_functions_map
+ function_infos.extend([method_id_bytes, symbol, func_metadata])
+ selector_section.append(function_infos)
+
+ ret = ["seq", ["with", "_calldata_method_id", shr(224, ["calldataload", 0]), selector_section]]
+
+ ret.extend(function_irs)
+
+ return ret
+
+
+# codegen for all runtime functions + callvalue/calldata checks,
+# with O(1) jumptable for selector table.
+# uses two level strategy: uses `method_id % n_methods` to calculate
+# a bucket, and then descends into linear search from there.
+# costs about 126 gas for typical (nonpayable, >0 args, avg bucket size 1.5)
+# function and 24 bytes of code (+ ~23 bytes of global overhead)
+def _selector_section_sparse(external_functions, global_ctx):
+ ret = ["seq"]
+
+ if len(external_functions) == 0:
+ return ret
+
+ entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx)
+
+ n_buckets, buckets = jumptable_utils.generate_sparse_jumptable_buckets(entry_points.keys())
+
+ # 2 bytes for bucket location
+ SZ_BUCKET_HEADER = 2
+
+ if n_buckets > 1:
+ bucket_id = ["mod", "_calldata_method_id", n_buckets]
+ bucket_hdr_location = [
+ "add",
+ ["symbol", "selector_buckets"],
+ ["mul", bucket_id, SZ_BUCKET_HEADER],
+ ]
+ # get bucket header
+ dst = 32 - SZ_BUCKET_HEADER
+ assert dst >= 0
+
+ if _is_debug_mode():
+ ret.append(["assert", ["eq", "msize", 0]])
+
+ ret.append(["codecopy", dst, bucket_hdr_location, SZ_BUCKET_HEADER])
+
+ jumpdest = IRnode.from_list(["mload", 0])
+ # don't particularly like using `jump` here since it can cause
+ # issues for other backends, consider changing `goto` to allow
+ # dynamic jumps, or adding some kind of jumptable instruction
+ ret.append(["jump", jumpdest])
+
+ jumptable_data = ["data", "selector_buckets"]
+ for i in range(n_buckets):
+ if i in buckets:
+ bucket_label = f"selector_bucket_{i}"
+ jumptable_data.append(["symbol", bucket_label])
+ else:
+ # empty bucket
+ jumptable_data.append(["symbol", "fallback"])
+
+ ret.append(jumptable_data)
+
+ for bucket_id, bucket in buckets.items():
+ bucket_label = f"selector_bucket_{bucket_id}"
+ ret.append(["label", bucket_label, ["var_list"], ["seq"]])
+
+ handle_bucket = ["seq"]
+
+ for method_id in bucket:
+ sig = sig_of[method_id]
+ entry_point = entry_points[sig]
+ func_t = entry_point.func_t
+ expected_calldatasize = entry_point.min_calldatasize
+
+ dispatch = ["seq"] # code to dispatch into the function
+ skip_callvalue_check = func_t.is_payable
+ skip_calldatasize_check = expected_calldatasize == 4
+ bad_callvalue = [0] if skip_callvalue_check else ["callvalue"]
+ bad_calldatasize = (
+ [0] if skip_calldatasize_check else ["lt", "calldatasize", expected_calldatasize]
+ )
+
+ dispatch.append(
+ IRnode.from_list(
+ ["assert", ["iszero", ["or", bad_callvalue, bad_calldatasize]]],
+ error_msg="bad calldatasize or callvalue",
+ )
+ )
+ # we could skip a jumpdest per method if we out-lined the entry point
+ # so the dispatcher looks just like -
+ # ```(if (eq method_id)
+ # (goto entry_point_label))```
+ # it would another optimization for patterns like
+ # `if ... (goto)` though.
+ dispatch.append(entry_point.ir_node)
+
+ method_id_check = ["eq", "_calldata_method_id", _annotated_method_id(sig)]
+ has_trailing_zeroes = method_id.to_bytes(4, "big").endswith(b"\x00")
+ if has_trailing_zeroes:
+ # if the method id check has trailing 0s, we need to include
+ # a calldatasize check to distinguish from when not enough
+ # bytes are provided for the method id in calldata.
+ method_id_check = ["and", ["ge", "calldatasize", 4], method_id_check]
+ handle_bucket.append(["if", method_id_check, dispatch])
+
+ # close out the bucket with a goto fallback so we don't keep searching
+ handle_bucket.append(["goto", "fallback"])
+
+ ret.append(handle_bucket)
+
+ ret = ["seq", ["with", "_calldata_method_id", shr(224, ["calldataload", 0]), ret]]
-# take a GlobalContext, which is basically
-# and generate the runtime and deploy IR, also return the dict of all signatures
-def generate_ir_for_module(global_ctx: GlobalContext) -> Tuple[IRnode, IRnode, FunctionSignatures]:
+ return ret
+
+
+# codegen for all runtime functions + callvalue/calldata checks,
+# O(n) linear search for the method id
+# mainly keep this in for backends which cannot handle the indirect jump
+# in selector_section_dense and selector_section_sparse
+def _selector_section_linear(external_functions, global_ctx):
+ ret = ["seq"]
+ if len(external_functions) == 0:
+ return ret
+
+ ret.append(["if", ["lt", "calldatasize", 4], ["goto", "fallback"]])
+
+ entry_points, sig_of = _generate_external_entry_points(external_functions, global_ctx)
+
+ dispatcher = ["seq"]
+
+ for sig, entry_point in entry_points.items():
+ func_t = entry_point.func_t
+ expected_calldatasize = entry_point.min_calldatasize
+
+ dispatch = ["seq"] # code to dispatch into the function
+
+ if not func_t.is_payable:
+ callvalue_check = ["assert", ["iszero", "callvalue"]]
+ dispatch.append(IRnode.from_list(callvalue_check, error_msg="nonpayable check"))
+
+ good_calldatasize = ["ge", "calldatasize", expected_calldatasize]
+ calldatasize_check = ["assert", good_calldatasize]
+ dispatch.append(IRnode.from_list(calldatasize_check, error_msg="calldatasize check"))
+
+ dispatch.append(entry_point.ir_node)
+
+ method_id_check = ["eq", "_calldata_method_id", _annotated_method_id(sig)]
+ dispatcher.append(["if", method_id_check, dispatch])
+
+ ret.append(["with", "_calldata_method_id", shr(224, ["calldataload", 0]), dispatcher])
+
+ return ret
+
+
+# take a GlobalContext, and generate the runtime and deploy IR
+def generate_ir_for_module(global_ctx: GlobalContext) -> tuple[IRnode, IRnode]:
# order functions so that each function comes after all of its callees
function_defs = _topsort(global_ctx.functions)
- # FunctionSignatures for all interfaces defined in this module
- all_sigs: Dict[str, FunctionSignatures] = {}
+ runtime_functions = [f for f in function_defs if not _is_constructor(f)]
+ init_function = next((f for f in function_defs if _is_constructor(f)), None)
+
+ internal_functions = [f for f in runtime_functions if _is_internal(f)]
+
+ external_functions = [
+ f for f in runtime_functions if not _is_internal(f) and not _is_fallback(f)
+ ]
+ default_function = next((f for f in runtime_functions if _is_fallback(f)), None)
- init_function: Optional[vy_ast.FunctionDef] = None
- local_sigs: FunctionSignatures = {} # internal/local functions
+ internal_functions_ir: list[IRnode] = []
- # generate all signatures
- # TODO really this should live in GlobalContext
- for f in function_defs:
- sig = FunctionSignature.from_definition(f, global_ctx)
- # add it to the global namespace.
- local_sigs[sig.name] = sig
- # a little hacky, eventually FunctionSignature should be
- # merged with ContractFunction and we can remove this.
- f._metadata["signature"] = sig
+ # compile internal functions first so we have the function info
+ for func_ast in internal_functions:
+ func_ir = _ir_for_internal_function(func_ast, global_ctx, False)
+ internal_functions_ir.append(IRnode.from_list(func_ir))
+
+ if core._opt_none():
+ selector_section = _selector_section_linear(external_functions, global_ctx)
+ # dense vs sparse global overhead is amortized after about 4 methods.
+ # (--debug will force dense selector table anyway if _opt_codesize is selected.)
+ elif core._opt_codesize() and (len(external_functions) > 4 or _is_debug_mode()):
+ selector_section = _selector_section_dense(external_functions, global_ctx)
+ else:
+ selector_section = _selector_section_sparse(external_functions, global_ctx)
- assert "self" not in all_sigs
- all_sigs["self"] = local_sigs
+ if default_function:
+ fallback_ir = _ir_for_fallback_or_ctor(default_function, global_ctx)
+ else:
+ fallback_ir = IRnode.from_list(
+ ["revert", 0, 0], annotation="Default function", error_msg="fallback function"
+ )
- runtime_functions = [f for f in function_defs if not _is_init_func(f)]
- init_function = next((f for f in function_defs if _is_init_func(f)), None)
+ runtime = ["seq", selector_section]
+ runtime.append(["goto", "fallback"])
+ runtime.append(["label", "fallback", ["var_list"], fallback_ir])
- runtime, internal_functions = _runtime_ir(runtime_functions, all_sigs, global_ctx)
+ runtime.extend(internal_functions_ir)
deploy_code: List[Any] = ["seq"]
immutables_len = global_ctx.immutable_section_bytes
if init_function:
- init_func_ir = generate_ir_for_function(init_function, all_sigs, global_ctx, False)
- deploy_code.append(init_func_ir)
+ # cleanly rerun codegen for internal functions with `is_ctor_ctx=True`
+ ctor_internal_func_irs = []
+ internal_functions = [f for f in runtime_functions if _is_internal(f)]
+ for f in internal_functions:
+ init_func_t = init_function._metadata["type"]
+ if f.name not in init_func_t.recursive_calls:
+ # unreachable code, delete it
+ continue
+
+ func_ir = _ir_for_internal_function(f, global_ctx, is_ctor_context=True)
+ ctor_internal_func_irs.append(func_ir)
+
+ # generate init_func_ir after callees to ensure they have analyzed
+ # memory usage.
+ # TODO might be cleaner to separate this into an _init_ir helper func
+ init_func_ir = _ir_for_fallback_or_ctor(init_function, global_ctx, is_ctor_context=True)
# pass the amount of memory allocated for the init function
# so that deployment does not clobber while preparing immutables
# note: (deploy mem_ofst, code, extra_padding)
- init_mem_used = init_function._metadata["signature"].frame_info.mem_used
- deploy_code.append(["deploy", init_mem_used, runtime, immutables_len])
+ init_mem_used = init_function._metadata["type"]._ir_info.frame_info.mem_used
+
+ # force msize to be initialized past the end of immutables section
+ # so that builtins which use `msize` for "dynamic" memory
+ # allocation do not clobber uninitialized immutables.
+ # cf. GH issue 3101.
+ # note mload/iload X touches bytes from X to X+32, and msize rounds up
+ # to the nearest 32, so `iload`ing `immutables_len - 32` guarantees
+ # that `msize` will refer to a memory location of at least
+ # ` + immutables_len` (where ==
+ # `_mem_deploy_end` as defined in the assembler).
+ # note:
+ # mload 32 => msize == 64
+ # mload 33 => msize == 96
+ # assumption in general: (mload X) => msize == ceil32(X + 32)
+ # see py-evm extend_memory: after_size = ceil32(start_position + size)
+ if immutables_len > 0:
+ deploy_code.append(["iload", max(0, immutables_len - 32)])
- # internal functions come after everything else
- for f in init_function._metadata["type"].called_functions:
- deploy_code.append(internal_functions[f.name])
+ deploy_code.append(init_func_ir)
+ deploy_code.append(["deploy", init_mem_used, runtime, immutables_len])
+ # internal functions come at end of initcode
+ deploy_code.extend(ctor_internal_func_irs)
else:
if immutables_len != 0:
raise CompilerPanic("unreachable")
deploy_code.append(["deploy", 0, runtime, 0])
- return IRnode.from_list(deploy_code), IRnode.from_list(runtime), local_sigs
+ return IRnode.from_list(deploy_code), IRnode.from_list(runtime)
diff --git a/vyper/codegen/return_.py b/vyper/codegen/return_.py
index 1cd50075a1..56bea2b8da 100644
--- a/vyper/codegen/return_.py
+++ b/vyper/codegen/return_.py
@@ -1,6 +1,5 @@
from typing import Any, Optional
-from vyper.address_space import MEMORY
from vyper.codegen.abi_encoder import abi_encode, abi_encoding_matches_vyper
from vyper.codegen.context import Context
from vyper.codegen.core import (
@@ -13,15 +12,16 @@
wrap_value_for_external_return,
)
from vyper.codegen.ir_node import IRnode
+from vyper.evm.address_space import MEMORY
Stmt = Any # mypy kludge
# Generate code for return stmt
def make_return_stmt(ir_val: IRnode, stmt: Any, context: Context) -> Optional[IRnode]:
- sig = context.sig
+ func_t = context.func_t
- jump_to_exit = ["exit_to", f"_sym_{sig.exit_sequence_label}"]
+ jump_to_exit = ["exit_to", func_t._ir_info.exit_sequence_label]
if context.return_type is None:
if stmt.value is not None:
@@ -35,7 +35,7 @@ def make_return_stmt(ir_val: IRnode, stmt: Any, context: Context) -> Optional[IR
# do NOT bypass this. jump_to_exit may do important function cleanup.
def finalize(fill_return_buffer):
fill_return_buffer = IRnode.from_list(
- fill_return_buffer, annotation=f"fill return buffer {sig._ir_identifier}"
+ fill_return_buffer, annotation=f"fill return buffer {func_t._ir_info.ir_identifier}"
)
cleanup_loops = "cleanup_repeat" if context.forvars else "seq"
# NOTE: because stack analysis is incomplete, cleanup_repeat must
@@ -43,7 +43,8 @@ def finalize(fill_return_buffer):
return IRnode.from_list(["seq", fill_return_buffer, cleanup_loops, jump_to_exit])
if context.return_type is None:
- jump_to_exit += ["return_pc"]
+ if context.is_internal:
+ jump_to_exit += ["return_pc"]
return finalize(["seq"])
if context.is_internal:
diff --git a/vyper/codegen/self_call.py b/vyper/codegen/self_call.py
index 2669f99192..c320e6889c 100644
--- a/vyper/codegen/self_call.py
+++ b/vyper/codegen/self_call.py
@@ -1,7 +1,7 @@
-from vyper.address_space import MEMORY
from vyper.codegen.core import _freshname, eval_once_check, make_setter
-from vyper.codegen.ir_node import IRnode, push_label_to_stack
-from vyper.exceptions import StateAccessViolation, StructureException
+from vyper.codegen.ir_node import IRnode
+from vyper.evm.address_space import MEMORY
+from vyper.exceptions import StateAccessViolation
from vyper.semantics.types.subscriptable import TupleT
_label_counter = 0
@@ -14,6 +14,21 @@ def _generate_label(name: str) -> str:
return f"label{_label_counter}"
+def _align_kwargs(func_t, args_ir):
+ """
+ Using a list of args, find the kwargs which need to be filled in by
+ the compiler
+ """
+
+ # sanity check
+ assert func_t.n_positional_args <= len(args_ir) <= func_t.n_total_args
+
+ num_provided_kwargs = len(args_ir) - func_t.n_positional_args
+
+ unprovided_kwargs = func_t.keyword_args[num_provided_kwargs:]
+ return [i.default_value for i in unprovided_kwargs]
+
+
def ir_for_self_call(stmt_expr, context):
from vyper.codegen.expr import Expr # TODO rethink this circular import
@@ -24,45 +39,45 @@ def ir_for_self_call(stmt_expr, context):
# - push jumpdest (callback ptr) and return buffer location
# - jump to label
# - (private function will fill return buffer and jump back)
-
method_name = stmt_expr.func.attr
+ func_t = stmt_expr.func._metadata["type"]
pos_args_ir = [Expr(x, context).ir_node for x in stmt_expr.args]
- sig, kw_vals = context.lookup_internal_function(method_name, pos_args_ir, stmt_expr)
+ default_vals = _align_kwargs(func_t, pos_args_ir)
+ default_vals_ir = [Expr(x, context).ir_node for x in default_vals]
- kw_args_ir = [Expr(x, context).ir_node for x in kw_vals]
-
- args_ir = pos_args_ir + kw_args_ir
+ args_ir = pos_args_ir + default_vals_ir
+ assert len(args_ir) == len(func_t.arguments)
args_tuple_t = TupleT([x.typ for x in args_ir])
args_as_tuple = IRnode.from_list(["multi"] + [x for x in args_ir], typ=args_tuple_t)
- if context.is_constant() and sig.mutability not in ("view", "pure"):
+ # CMC 2023-05-17 this seems like it is already caught in typechecker
+ if context.is_constant() and func_t.is_mutable:
raise StateAccessViolation(
f"May not call state modifying function "
f"'{method_name}' within {context.pp_constancy()}.",
stmt_expr,
)
- # TODO move me to type checker phase
- if not sig.internal:
- raise StructureException("Cannot call external functions via 'self'", stmt_expr)
-
- return_label = _generate_label(f"{sig.internal_function_label}_call")
+ # note: internal_function_label asserts `func_t.is_internal`.
+ _label = func_t._ir_info.internal_function_label(context.is_ctor_context)
+ return_label = _generate_label(f"{_label}_call")
# allocate space for the return buffer
# TODO allocate in stmt and/or expr.py
- if sig.return_type is not None:
+ if func_t.return_type is not None:
return_buffer = IRnode.from_list(
- context.new_internal_variable(sig.return_type), annotation=f"{return_label}_return_buf"
+ context.new_internal_variable(func_t.return_type),
+ annotation=f"{return_label}_return_buf",
)
else:
return_buffer = None
# note: dst_tuple_t != args_tuple_t
- dst_tuple_t = TupleT([arg.typ for arg in sig.args])
- args_dst = IRnode(sig.frame_info.frame_start, typ=dst_tuple_t, location=MEMORY)
+ dst_tuple_t = TupleT(tuple(func_t.argument_types))
+ args_dst = IRnode(func_t._ir_info.frame_info.frame_start, typ=dst_tuple_t, location=MEMORY)
# if one of the arguments is a self call, the argument
# buffer could get borked. to prevent against that,
@@ -84,12 +99,12 @@ def ir_for_self_call(stmt_expr, context):
else:
copy_args = make_setter(args_dst, args_as_tuple)
- goto_op = ["goto", sig.internal_function_label]
+ goto_op = ["goto", func_t._ir_info.internal_function_label(context.is_ctor_context)]
# pass return buffer to subroutine
if return_buffer is not None:
goto_op += [return_buffer]
# pass return label to subroutine
- goto_op += [push_label_to_stack(return_label)]
+ goto_op.append(["symbol", return_label])
call_sequence = ["seq"]
call_sequence.append(eval_once_check(_freshname(stmt_expr.node_source_code)))
@@ -100,10 +115,10 @@ def ir_for_self_call(stmt_expr, context):
o = IRnode.from_list(
call_sequence,
- typ=sig.return_type,
+ typ=func_t.return_type,
location=MEMORY,
annotation=stmt_expr.get("node_source_code"),
- add_gas_estimate=sig.gas_estimate,
+ add_gas_estimate=func_t._ir_info.gas_estimate,
)
o.is_self_call = True
return o
diff --git a/vyper/codegen/stmt.py b/vyper/codegen/stmt.py
index de9801a740..3ecb0afdc3 100644
--- a/vyper/codegen/stmt.py
+++ b/vyper/codegen/stmt.py
@@ -1,7 +1,6 @@
import vyper.codegen.events as events
import vyper.utils as util
from vyper import ast as vy_ast
-from vyper.address_space import MEMORY, STORAGE
from vyper.builtins.functions import STMT_DISPATCH_TABLE
from vyper.codegen import external_call, self_call
from vyper.codegen.context import Constancy, Context
@@ -11,6 +10,7 @@
IRnode,
append_dyn_array,
check_assign,
+ clamp,
dummy_node_for_type,
get_dyn_array_count,
get_element_ptr,
@@ -23,6 +23,7 @@
)
from vyper.codegen.expr import Expr
from vyper.codegen.return_ import make_return_stmt
+from vyper.evm.address_space import MEMORY, STORAGE
from vyper.exceptions import CompilerPanic, StructureException, TypeCheckFailure
from vyper.semantics.types import DArrayT, MemberFunctionT
from vyper.semantics.types.shortcuts import INT256_T, UINT256_T
@@ -59,7 +60,7 @@ def parse_Name(self):
raise StructureException(f"Unsupported statement type: {type(self.stmt)}", self.stmt)
def parse_AnnAssign(self):
- ltyp = self.context.parse_type(self.stmt.annotation)
+ ltyp = self.stmt.target._metadata["type"]
varname = self.stmt.target.id
alloced = self.context.new_variable(varname, ltyp)
@@ -72,11 +73,22 @@ def parse_AnnAssign(self):
def parse_Assign(self):
# Assignment (e.g. x[4] = y)
- sub = Expr(self.stmt.value, self.context).ir_node
- target = self._get_target(self.stmt.target)
+ src = Expr(self.stmt.value, self.context).ir_node
+ dst = self._get_target(self.stmt.target)
- ir_node = make_setter(target, sub)
- return ir_node
+ ret = ["seq"]
+ overlap = len(dst.referenced_variables & src.referenced_variables) > 0
+ if overlap and not dst.typ._is_prim_word:
+ # there is overlap between the lhs and rhs, and the type is
+ # complex - i.e., it spans multiple words. for safety, we
+ # copy to a temporary buffer before copying to the destination.
+ tmp = self.context.new_internal_variable(src.typ)
+ tmp = IRnode.from_list(tmp, typ=src.typ, location=MEMORY)
+ ret.append(make_setter(tmp, src))
+ src = tmp
+
+ ret.append(make_setter(dst, src))
+ return IRnode.from_list(ret)
def parse_If(self):
if self.stmt.orelse:
@@ -246,11 +258,17 @@ def _parse_For_range(self):
arg0 = self.stmt.iter.args[0]
num_of_args = len(self.stmt.iter.args)
+ kwargs = {
+ s.arg: Expr.parse_value_expr(s.value, self.context)
+ for s in self.stmt.iter.keywords or []
+ }
+
# Type 1 for, e.g. for i in range(10): ...
if num_of_args == 1:
- arg0_val = self._get_range_const_value(arg0)
+ n = Expr.parse_value_expr(arg0, self.context)
start = IRnode.from_list(0, typ=iter_typ)
- rounds = arg0_val
+ rounds = n
+ rounds_bound = kwargs.get("bound", rounds)
# Type 2 for, e.g. for i in range(100, 110): ...
elif self._check_valid_range_constant(self.stmt.iter.args[1]).is_literal:
@@ -258,15 +276,19 @@ def _parse_For_range(self):
arg1_val = self._get_range_const_value(self.stmt.iter.args[1])
start = IRnode.from_list(arg0_val, typ=iter_typ)
rounds = IRnode.from_list(arg1_val - arg0_val, typ=iter_typ)
+ rounds_bound = rounds
# Type 3 for, e.g. for i in range(x, x + 10): ...
else:
arg1 = self.stmt.iter.args[1]
rounds = self._get_range_const_value(arg1.right)
start = Expr.parse_value_expr(arg0, self.context)
+ _, hi = start.typ.int_bounds
+ start = clamp("le", start, hi + 1 - rounds)
+ rounds_bound = rounds
- r = rounds if isinstance(rounds, int) else rounds.value
- if r < 1:
+ bound = rounds_bound if isinstance(rounds_bound, int) else rounds_bound.value
+ if bound < 1:
return
varname = self.stmt.target.id
@@ -280,7 +302,12 @@ def _parse_For_range(self):
loop_body.append(["mstore", iptr, i])
loop_body.append(parse_body(self.stmt.body, self.context))
- ir_node = IRnode.from_list(["repeat", i, start, rounds, rounds, loop_body])
+ # NOTE: codegen for `repeat` inserts an assertion that rounds <= rounds_bound.
+ # if we ever want to remove that, we need to manually add the assertion
+ # where it makes sense.
+ ir_node = IRnode.from_list(
+ ["repeat", i, start, rounds, rounds_bound, loop_body], error_msg="range() bounds check"
+ )
del self.context.forvars[varname]
return ir_node
@@ -289,10 +316,8 @@ def _parse_For_list(self):
with self.context.range_scope():
iter_list = Expr(self.stmt.iter, self.context).ir_node
- # override with type inferred at typechecking time
- # TODO investigate why stmt.target.type != stmt.iter.type.value_type
target_type = self.stmt.target._metadata["type"]
- iter_list.typ.value_type = target_type
+ assert target_type == iter_list.typ.value_type
# user-supplied name for loop variable
varname = self.stmt.target.id
@@ -333,8 +358,12 @@ def _parse_For_list(self):
def parse_AugAssign(self):
target = self._get_target(self.stmt.target)
+
sub = Expr.parse_value_expr(self.stmt.value, self.context)
if not target.typ._is_prim_word:
+ # because of this check, we do not need to check for
+ # make_setter references lhs<->rhs as in parse_Assign -
+ # single word load/stores are atomic.
return
with target.cache_when_complex("_loc") as (b, target):
diff --git a/vyper/compiler/__init__.py b/vyper/compiler/__init__.py
index 7be45ce832..0b3c0d8191 100644
--- a/vyper/compiler/__init__.py
+++ b/vyper/compiler/__init__.py
@@ -5,7 +5,8 @@
import vyper.codegen.core as codegen
import vyper.compiler.output as output
from vyper.compiler.phases import CompilerData
-from vyper.evm.opcodes import DEFAULT_EVM_VERSION, evm_wrapper
+from vyper.compiler.settings import Settings
+from vyper.evm.opcodes import DEFAULT_EVM_VERSION, anchor_evm_version
from vyper.typing import (
ContractCodes,
ContractPath,
@@ -46,15 +47,14 @@
}
-@evm_wrapper
def compile_codes(
contract_sources: ContractCodes,
output_formats: Union[OutputDict, OutputFormats, None] = None,
exc_handler: Union[Callable, None] = None,
interface_codes: Union[InterfaceDict, InterfaceImports, None] = None,
initial_id: int = 0,
- no_optimize: bool = False,
- storage_layouts: Dict[ContractPath, StorageLayout] = None,
+ settings: Settings = None,
+ storage_layouts: Optional[dict[ContractPath, Optional[StorageLayout]]] = None,
show_gas_estimates: bool = False,
no_bytecode_metadata: bool = False,
) -> OrderedDict:
@@ -73,11 +73,8 @@ def compile_codes(
two arguments - the name of the contract, and the exception that was raised
initial_id: int, optional
The lowest source ID value to be used when generating the source map.
- evm_version: str, optional
- The target EVM ruleset to compile for. If not given, defaults to the latest
- implemented ruleset.
- no_optimize: bool, optional
- Turn off optimizations. Defaults to False
+ settings: Settings, optional
+ Compiler settings
show_gas_estimates: bool, optional
Show gas estimates for abi and ir output modes
interface_codes: Dict, optional
@@ -98,6 +95,7 @@ def compile_codes(
Dict
Compiler output as `{'contract name': {'output key': "output data"}}`
"""
+ settings = settings or Settings()
if output_formats is None:
output_formats = ("bytecode",)
@@ -121,27 +119,30 @@ def compile_codes(
# make IR output the same between runs
codegen.reset_names()
- compiler_data = CompilerData(
- source_code,
- contract_name,
- interfaces,
- source_id,
- no_optimize,
- storage_layout_override,
- show_gas_estimates,
- no_bytecode_metadata,
- )
- for output_format in output_formats[contract_name]:
- if output_format not in OUTPUT_FORMATS:
- raise ValueError(f"Unsupported format type {repr(output_format)}")
- try:
- out.setdefault(contract_name, {})
- out[contract_name][output_format] = OUTPUT_FORMATS[output_format](compiler_data)
- except Exception as exc:
- if exc_handler is not None:
- exc_handler(contract_name, exc)
- else:
- raise exc
+
+ with anchor_evm_version(settings.evm_version):
+ compiler_data = CompilerData(
+ source_code,
+ contract_name,
+ interfaces,
+ source_id,
+ settings,
+ storage_layout_override,
+ show_gas_estimates,
+ no_bytecode_metadata,
+ )
+ for output_format in output_formats[contract_name]:
+ if output_format not in OUTPUT_FORMATS:
+ raise ValueError(f"Unsupported format type {repr(output_format)}")
+ try:
+ out.setdefault(contract_name, {})
+ formatter = OUTPUT_FORMATS[output_format]
+ out[contract_name][output_format] = formatter(compiler_data)
+ except Exception as exc:
+ if exc_handler is not None:
+ exc_handler(contract_name, exc)
+ else:
+ raise exc
return out
@@ -153,9 +154,8 @@ def compile_code(
contract_source: str,
output_formats: Optional[OutputFormats] = None,
interface_codes: Optional[InterfaceImports] = None,
- evm_version: str = DEFAULT_EVM_VERSION,
- no_optimize: bool = False,
- storage_layout_override: StorageLayout = None,
+ settings: Settings = None,
+ storage_layout_override: Optional[StorageLayout] = None,
show_gas_estimates: bool = False,
) -> dict:
"""
@@ -171,8 +171,8 @@ def compile_code(
evm_version: str, optional
The target EVM ruleset to compile for. If not given, defaults to the latest
implemented ruleset.
- no_optimize: bool, optional
- Turn off optimizations. Defaults to False
+ settings: Settings, optional
+ Compiler settings.
show_gas_estimates: bool, optional
Show gas estimates for abi and ir output modes
interface_codes: Dict, optional
@@ -194,8 +194,7 @@ def compile_code(
contract_sources,
output_formats,
interface_codes=interface_codes,
- evm_version=evm_version,
- no_optimize=no_optimize,
+ settings=settings,
storage_layouts=storage_layouts,
show_gas_estimates=show_gas_estimates,
)[UNKNOWN_CONTRACT_NAME]
diff --git a/vyper/compiler/output.py b/vyper/compiler/output.py
index 82e583c160..334c5ba613 100644
--- a/vyper/compiler/output.py
+++ b/vyper/compiler/output.py
@@ -44,7 +44,7 @@ def build_external_interface_output(compiler_data: CompilerData) -> str:
for func in interface.functions.values():
if func.visibility == FunctionVisibility.INTERNAL or func.name == "__init__":
continue
- args = ", ".join([f"{name}: {typ}" for name, typ in func.arguments.items()])
+ args = ", ".join([f"{arg.name}: {arg.typ}" for arg in func.arguments])
return_value = f" -> {func.return_type}" if func.return_type is not None else ""
mutability = func.mutability.value
out = f"{out} def {func.name}({args}){return_value}: {mutability}\n"
@@ -69,7 +69,7 @@ def build_interface_output(compiler_data: CompilerData) -> str:
continue
if func.mutability != StateMutability.NONPAYABLE:
out = f"{out}@{func.mutability.value}\n"
- args = ", ".join([f"{name}: {typ}" for name, typ in func.arguments.items()])
+ args = ", ".join([f"{arg.name}: {arg.typ}" for arg in func.arguments])
return_value = f" -> {func.return_type}" if func.return_type is not None else ""
out = f"{out}@external\ndef {func.name}({args}){return_value}:\n pass\n\n"
@@ -117,21 +117,39 @@ def _var_rec_dict(variable_record):
ret["location"] = ret["location"].name
return ret
- def _to_dict(sig):
- ret = vars(sig)
+ def _to_dict(func_t):
+ ret = vars(func_t)
ret["return_type"] = str(ret["return_type"])
- ret["_ir_identifier"] = sig._ir_identifier
- for attr in ("gas_estimate", "func_ast_code"):
- del ret[attr]
- for attr in ("args", "base_args", "default_args"):
- if attr in ret:
- ret[attr] = {arg.name: str(arg.typ) for arg in ret[attr]}
- for k in ret["default_values"]:
- # e.g. {"x": vy_ast.Int(..)} -> {"x": 1}
- ret["default_values"][k] = ret["default_values"][k].node_source_code
- ret["frame_info"] = vars(ret["frame_info"])
- for k in ret["frame_info"]["frame_vars"].keys():
- ret["frame_info"]["frame_vars"][k] = _var_rec_dict(ret["frame_info"]["frame_vars"][k])
+ ret["_ir_identifier"] = func_t._ir_info.ir_identifier
+
+ for attr in ("mutability", "visibility"):
+ ret[attr] = ret[attr].name.lower()
+
+ # e.g. {"x": vy_ast.Int(..)} -> {"x": 1}
+ ret["default_values"] = {
+ k: val.node_source_code for k, val in func_t.default_values.items()
+ }
+
+ for attr in ("positional_args", "keyword_args"):
+ args = ret[attr]
+ ret[attr] = {arg.name: str(arg.typ) for arg in args}
+
+ ret["frame_info"] = vars(func_t._ir_info.frame_info)
+ del ret["frame_info"]["frame_vars"] # frame_var.pos might be IR, cannot serialize
+
+ keep_keys = {
+ "name",
+ "return_type",
+ "positional_args",
+ "keyword_args",
+ "default_values",
+ "frame_info",
+ "mutability",
+ "visibility",
+ "_ir_identifier",
+ "nonreentrant_key",
+ }
+ ret = {k: v for k, v in ret.items() if k in keep_keys}
return ret
return {"function_info": {name: _to_dict(sig) for (name, sig) in sigs.items()}}
@@ -190,7 +208,7 @@ def _build_asm(asm_list):
else:
output_string += str(node) + " "
- if isinstance(node, str) and node.startswith("PUSH"):
+ if isinstance(node, str) and node.startswith("PUSH") and node != "PUSH0":
assert in_push == 0
in_push = int(node[4:])
output_string += "0x"
@@ -200,9 +218,7 @@ def _build_asm(asm_list):
def build_source_map_output(compiler_data: CompilerData) -> OrderedDict:
_, line_number_map = compile_ir.assembly_to_evm(
- compiler_data.assembly_runtime,
- insert_vyper_signature=True,
- disable_bytecode_metadata=compiler_data.no_bytecode_metadata,
+ compiler_data.assembly_runtime, insert_compiler_metadata=False
)
# Sort line_number_map
out = OrderedDict()
@@ -284,9 +300,13 @@ def _build_opcodes(bytecode: bytes) -> str:
while bytecode_sequence:
op = bytecode_sequence.popleft()
- opcode_output.append(opcode_map[op])
- if "PUSH" in opcode_output[-1]:
+ opcode_output.append(opcode_map.get(op, f"VERBATIM_{hex(op)}"))
+ if "PUSH" in opcode_output[-1] and opcode_output[-1] != "PUSH0":
push_len = int(opcode_map[op][4:])
+ # we can have push_len > len(bytecode_sequence) when there is data
+ # (instead of code) at end of contract
+ # CMC 2023-07-13 maybe just strip known data segments?
+ push_len = min(push_len, len(bytecode_sequence))
push_values = [hex(bytecode_sequence.popleft())[2:] for i in range(push_len)]
opcode_output.append(f"0x{''.join(push_values).upper()}")
diff --git a/vyper/compiler/phases.py b/vyper/compiler/phases.py
index fb5a9da4f5..17c8f2788a 100644
--- a/vyper/compiler/phases.py
+++ b/vyper/compiler/phases.py
@@ -4,12 +4,15 @@
from typing import Optional, Tuple
from vyper import ast as vy_ast
-from vyper.ast.signatures.function_signature import FunctionSignatures
from vyper.codegen import module
+from vyper.codegen.core import anchor_opt_level
from vyper.codegen.global_context import GlobalContext
from vyper.codegen.ir_node import IRnode
+from vyper.compiler.settings import OptimizationLevel, Settings
+from vyper.exceptions import StructureException
from vyper.ir import compile_ir, optimizer
from vyper.semantics import set_data_positions, validate_semantics
+from vyper.semantics.types.function import ContractFunctionT
from vyper.typing import InterfaceImports, StorageLayout
@@ -49,7 +52,7 @@ def __init__(
contract_name: str = "VyperContract",
interface_codes: Optional[InterfaceImports] = None,
source_id: int = 0,
- no_optimize: bool = False,
+ settings: Settings = None,
storage_layout: StorageLayout = None,
show_gas_estimates: bool = False,
no_bytecode_metadata: bool = False,
@@ -69,8 +72,8 @@ def __init__(
* JSON interfaces are given as lists, vyper interfaces as strings
source_id : int, optional
ID number used to identify this contract in the source map.
- no_optimize: bool, optional
- Turn off optimizations. Defaults to False
+ settings: Settings
+ Set optimization mode.
show_gas_estimates: bool, optional
Show gas estimates for abi and ir output modes
no_bytecode_metadata: bool, optional
@@ -80,14 +83,45 @@ def __init__(
self.source_code = source_code
self.interface_codes = interface_codes
self.source_id = source_id
- self.no_optimize = no_optimize
self.storage_layout_override = storage_layout
self.show_gas_estimates = show_gas_estimates
self.no_bytecode_metadata = no_bytecode_metadata
+ self.settings = settings or Settings()
@cached_property
- def vyper_module(self) -> vy_ast.Module:
- return generate_ast(self.source_code, self.source_id, self.contract_name)
+ def _generate_ast(self):
+ settings, ast = generate_ast(self.source_code, self.source_id, self.contract_name)
+ # validate the compiler settings
+ # XXX: this is a bit ugly, clean up later
+ if settings.evm_version is not None:
+ if (
+ self.settings.evm_version is not None
+ and self.settings.evm_version != settings.evm_version
+ ):
+ raise StructureException(
+ f"compiler settings indicate evm version {self.settings.evm_version}, "
+ f"but source pragma indicates {settings.evm_version}."
+ )
+
+ self.settings.evm_version = settings.evm_version
+
+ if settings.optimize is not None:
+ if self.settings.optimize is not None and self.settings.optimize != settings.optimize:
+ raise StructureException(
+ f"compiler options indicate optimization mode {self.settings.optimize}, "
+ f"but source pragma indicates {settings.optimize}."
+ )
+ self.settings.optimize = settings.optimize
+
+ # ensure defaults
+ if self.settings.optimize is None:
+ self.settings.optimize = OptimizationLevel.default()
+
+ return ast
+
+ @cached_property
+ def vyper_module(self):
+ return self._generate_ast
@cached_property
def vyper_module_unfolded(self) -> vy_ast.Module:
@@ -119,42 +153,43 @@ def global_ctx(self) -> GlobalContext:
@cached_property
def _ir_output(self):
# fetch both deployment and runtime IR
- return generate_ir_nodes(self.global_ctx, self.no_optimize)
+ return generate_ir_nodes(self.global_ctx, self.settings.optimize)
@property
def ir_nodes(self) -> IRnode:
- ir, ir_runtime, sigs = self._ir_output
+ ir, ir_runtime = self._ir_output
return ir
@property
def ir_runtime(self) -> IRnode:
- ir, ir_runtime, sigs = self._ir_output
+ ir, ir_runtime = self._ir_output
return ir_runtime
@property
- def function_signatures(self) -> FunctionSignatures:
- ir, ir_runtime, sigs = self._ir_output
- return sigs
+ def function_signatures(self) -> dict[str, ContractFunctionT]:
+ # some metadata gets calculated during codegen, so
+ # ensure codegen is run:
+ _ = self._ir_output
+
+ fs = self.vyper_module_folded.get_children(vy_ast.FunctionDef)
+ return {f.name: f._metadata["type"] for f in fs}
@cached_property
def assembly(self) -> list:
- return generate_assembly(self.ir_nodes, self.no_optimize)
+ return generate_assembly(self.ir_nodes, self.settings.optimize)
@cached_property
def assembly_runtime(self) -> list:
- return generate_assembly(self.ir_runtime, self.no_optimize)
+ return generate_assembly(self.ir_runtime, self.settings.optimize)
@cached_property
def bytecode(self) -> bytes:
- return generate_bytecode(
- self.assembly, is_runtime=False, no_bytecode_metadata=self.no_bytecode_metadata
- )
+ insert_compiler_metadata = not self.no_bytecode_metadata
+ return generate_bytecode(self.assembly, insert_compiler_metadata=insert_compiler_metadata)
@cached_property
def bytecode_runtime(self) -> bytes:
- return generate_bytecode(
- self.assembly_runtime, is_runtime=True, no_bytecode_metadata=self.no_bytecode_metadata
- )
+ return generate_bytecode(self.assembly_runtime, insert_compiler_metadata=False)
@cached_property
def blueprint_bytecode(self) -> bytes:
@@ -168,7 +203,9 @@ def blueprint_bytecode(self) -> bytes:
return deploy_bytecode + blueprint_bytecode
-def generate_ast(source_code: str, source_id: int, contract_name: str) -> vy_ast.Module:
+def generate_ast(
+ source_code: str, source_id: int, contract_name: str
+) -> tuple[Settings, vy_ast.Module]:
"""
Generate a Vyper AST from source code.
@@ -186,7 +223,7 @@ def generate_ast(source_code: str, source_id: int, contract_name: str) -> vy_ast
vy_ast.Module
Top-level Vyper AST node
"""
- return vy_ast.parse_to_ast(source_code, source_id, contract_name)
+ return vy_ast.parse_to_ast_with_settings(source_code, source_id, contract_name)
def generate_unfolded_ast(
@@ -225,15 +262,14 @@ def generate_folded_ast(
vyper_module_folded = copy.deepcopy(vyper_module)
vy_ast.folding.fold(vyper_module_folded)
validate_semantics(vyper_module_folded, interface_codes)
- vy_ast.expansion.expand_annotated_ast(vyper_module_folded)
symbol_tables = set_data_positions(vyper_module_folded, storage_layout_overrides)
return vyper_module_folded, symbol_tables
def generate_ir_nodes(
- global_ctx: GlobalContext, no_optimize: bool
-) -> Tuple[IRnode, IRnode, FunctionSignatures]:
+ global_ctx: GlobalContext, optimize: OptimizationLevel
+) -> tuple[IRnode, IRnode]:
"""
Generate the intermediate representation (IR) from the contextualized AST.
@@ -253,14 +289,15 @@ def generate_ir_nodes(
IR to generate deployment bytecode
IR to generate runtime bytecode
"""
- ir_nodes, ir_runtime, function_sigs = module.generate_ir_for_module(global_ctx)
- if not no_optimize:
+ with anchor_opt_level(optimize):
+ ir_nodes, ir_runtime = module.generate_ir_for_module(global_ctx)
+ if optimize != OptimizationLevel.NONE:
ir_nodes = optimizer.optimize(ir_nodes)
ir_runtime = optimizer.optimize(ir_runtime)
- return ir_nodes, ir_runtime, function_sigs
+ return ir_nodes, ir_runtime
-def generate_assembly(ir_nodes: IRnode, no_optimize: bool = False) -> list:
+def generate_assembly(ir_nodes: IRnode, optimize: Optional[OptimizationLevel] = None) -> list:
"""
Generate assembly instructions from IR.
@@ -274,7 +311,8 @@ def generate_assembly(ir_nodes: IRnode, no_optimize: bool = False) -> list:
list
List of assembly instructions.
"""
- assembly = compile_ir.compile_to_assembly(ir_nodes, no_optimize=no_optimize)
+ optimize = optimize or OptimizationLevel.default()
+ assembly = compile_ir.compile_to_assembly(ir_nodes, optimize=optimize)
if _find_nested_opcode(assembly, "DEBUG"):
warnings.warn(
@@ -292,9 +330,7 @@ def _find_nested_opcode(assembly, key):
return any(_find_nested_opcode(x, key) for x in sublists)
-def generate_bytecode(
- assembly: list, is_runtime: bool = False, no_bytecode_metadata: bool = False
-) -> bytes:
+def generate_bytecode(assembly: list, insert_compiler_metadata: bool) -> bytes:
"""
Generate bytecode from assembly instructions.
@@ -308,6 +344,6 @@ def generate_bytecode(
bytes
Final compiled bytecode.
"""
- return compile_ir.assembly_to_evm(
- assembly, insert_vyper_signature=is_runtime, disable_bytecode_metadata=no_bytecode_metadata
- )[0]
+ return compile_ir.assembly_to_evm(assembly, insert_compiler_metadata=insert_compiler_metadata)[
+ 0
+ ]
diff --git a/vyper/compiler/settings.py b/vyper/compiler/settings.py
index 09ced0dcb8..d2c88a8592 100644
--- a/vyper/compiler/settings.py
+++ b/vyper/compiler/settings.py
@@ -1,4 +1,6 @@
import os
+from dataclasses import dataclass
+from enum import Enum
from typing import Optional
VYPER_COLOR_OUTPUT = os.environ.get("VYPER_COLOR_OUTPUT", "0") == "1"
@@ -12,3 +14,44 @@
VYPER_TRACEBACK_LIMIT = int(_tb_limit_str)
else:
VYPER_TRACEBACK_LIMIT = None
+
+
+class OptimizationLevel(Enum):
+ NONE = 1
+ GAS = 2
+ CODESIZE = 3
+
+ @classmethod
+ def from_string(cls, val):
+ match val:
+ case "none":
+ return cls.NONE
+ case "gas":
+ return cls.GAS
+ case "codesize":
+ return cls.CODESIZE
+ raise ValueError(f"unrecognized optimization level: {val}")
+
+ @classmethod
+ def default(cls):
+ return cls.GAS
+
+
+@dataclass
+class Settings:
+ compiler_version: Optional[str] = None
+ optimize: Optional[OptimizationLevel] = None
+ evm_version: Optional[str] = None
+
+
+_DEBUG = False
+
+
+def _is_debug_mode():
+ global _DEBUG
+ return _DEBUG
+
+
+def _set_debug_mode(dbg: bool = False) -> None:
+ global _DEBUG
+ _DEBUG = dbg
diff --git a/vyper/compiler/utils.py b/vyper/compiler/utils.py
index baa35ac93d..8de2589367 100644
--- a/vyper/compiler/utils.py
+++ b/vyper/compiler/utils.py
@@ -1,12 +1,15 @@
from typing import Dict
-from vyper.ast.signatures import FunctionSignature
+from vyper.semantics.types.function import ContractFunctionT
-def build_gas_estimates(function_sigs: Dict[str, FunctionSignature]) -> dict:
- # note: `.gas_estimate` is added to FunctionSignature
- # in vyper/codegen/function_definitions/common.py
- return {k: v.gas_estimate for (k, v) in function_sigs.items()}
+def build_gas_estimates(func_ts: Dict[str, ContractFunctionT]) -> dict:
+ # note: `.gas_estimate` is added to ContractFunctionT._ir_info
+ # in vyper/semantics/types/function.py
+ ret = {}
+ for k, v in func_ts.items():
+ ret[k] = v._ir_info.gas_estimate
+ return ret
def expand_source_map(compressed_map: str) -> list:
diff --git a/vyper/address_space.py b/vyper/evm/address_space.py
similarity index 97%
rename from vyper/address_space.py
rename to vyper/evm/address_space.py
index 855e98b5c8..85a75c3c23 100644
--- a/vyper/address_space.py
+++ b/vyper/evm/address_space.py
@@ -48,6 +48,7 @@ def byte_addressable(self) -> bool:
MEMORY = AddrSpace("memory", 32, "mload", "mstore")
STORAGE = AddrSpace("storage", 1, "sload", "sstore")
+TRANSIENT = AddrSpace("transient", 1, "tload", "tstore")
CALLDATA = AddrSpace("calldata", 32, "calldataload")
# immutables address space: "immutables" section of memory
# which is read-write in deploy code but then gets turned into
diff --git a/vyper/evm/opcodes.py b/vyper/evm/opcodes.py
index b9f1e77ca8..767d634c89 100644
--- a/vyper/evm/opcodes.py
+++ b/vyper/evm/opcodes.py
@@ -1,41 +1,28 @@
-from typing import Dict, Optional
+import contextlib
+from typing import Dict, Generator, Optional
from vyper.exceptions import CompilerPanic
from vyper.typing import OpcodeGasCost, OpcodeMap, OpcodeRulesetMap, OpcodeRulesetValue, OpcodeValue
-active_evm_version: int = 4
-
# EVM version rules work as follows:
# 1. Fork rules go from oldest (lowest value) to newest (highest value).
# 2. Fork versions aren't actually tied to anything. They are not a part of our
# official API. *DO NOT USE THE VALUES FOR ANYTHING IMPORTANT* besides versioning.
-# 3. When support for an older version is dropped, the numbers should *not* change to
-# reflect it (i.e. dropping support for version 0 removes version 0 entirely).
-# 4. There can be multiple aliases to the same version number (but not the reverse).
-# 5. When supporting multiple chains, if a chain gets a fix first, it increments the
-# number first.
-# 6. Yes, this will probably have to be rethought if there's ever conflicting support
-# between multiple chains for a specific feature. Let's hope not.
-# 7. We support at a maximum 3 hard forks (for any given chain).
-EVM_VERSIONS: Dict[str, int] = {
- # ETH Forks
- "byzantium": 0,
- "constantinople": 1,
- "petersburg": 1,
- "istanbul": 2,
- "berlin": 3,
- "paris": 4,
- # ETC Forks
- "atlantis": 0,
- "agharta": 1,
-}
-DEFAULT_EVM_VERSION: str = "paris"
+# 3. Per VIP-3365, we support mainnet fork choice rules up to 1 year old
+# (and may optionally have forward support for experimental/unreleased
+# fork choice rules)
+_evm_versions = ("istanbul", "berlin", "london", "paris", "shanghai", "cancun")
+EVM_VERSIONS: dict[str, int] = dict((v, i) for i, v in enumerate(_evm_versions))
+
+
+DEFAULT_EVM_VERSION: str = "shanghai"
+active_evm_version: int = EVM_VERSIONS[DEFAULT_EVM_VERSION]
# opcode as hex value
# number of values removed from stack
# number of values added to stack
-# gas cost (byzantium, constantinople, istanbul, berlin)
+# gas cost (istanbul, berlin, paris, shanghai, cancun)
OPCODES: OpcodeMap = {
"STOP": (0x00, 0, 0, 0),
"ADD": (0x01, 2, 1, 3),
@@ -60,12 +47,12 @@
"XOR": (0x18, 2, 1, 3),
"NOT": (0x19, 1, 1, 3),
"BYTE": (0x1A, 2, 1, 3),
- "SHL": (0x1B, 2, 1, (None, 3)),
- "SHR": (0x1C, 2, 1, (None, 3)),
- "SAR": (0x1D, 2, 1, (None, 3)),
+ "SHL": (0x1B, 2, 1, 3),
+ "SHR": (0x1C, 2, 1, 3),
+ "SAR": (0x1D, 2, 1, 3),
"SHA3": (0x20, 2, 1, 30),
"ADDRESS": (0x30, 0, 1, 2),
- "BALANCE": (0x31, 1, 1, (400, 400, 700)),
+ "BALANCE": (0x31, 1, 1, 700),
"ORIGIN": (0x32, 0, 1, 2),
"CALLER": (0x33, 0, 1, 2),
"CALLVALUE": (0x34, 0, 1, 2),
@@ -75,11 +62,11 @@
"CODESIZE": (0x38, 0, 1, 2),
"CODECOPY": (0x39, 3, 0, 3),
"GASPRICE": (0x3A, 0, 1, 2),
- "EXTCODESIZE": (0x3B, 1, 1, (700, 700, 700, 2600)),
- "EXTCODECOPY": (0x3C, 4, 0, (700, 700, 700, 2600)),
+ "EXTCODESIZE": (0x3B, 1, 1, (700, 2600)),
+ "EXTCODECOPY": (0x3C, 4, 0, (700, 2600)),
"RETURNDATASIZE": (0x3D, 0, 1, 2),
"RETURNDATACOPY": (0x3E, 3, 0, 3),
- "EXTCODEHASH": (0x3F, 1, 1, (None, 400, 700, 2600)),
+ "EXTCODEHASH": (0x3F, 1, 1, (700, 2600)),
"BLOCKHASH": (0x40, 1, 1, 20),
"COINBASE": (0x41, 0, 1, 2),
"TIMESTAMP": (0x42, 0, 1, 2),
@@ -87,14 +74,14 @@
"DIFFICULTY": (0x44, 0, 1, 2),
"PREVRANDAO": (0x44, 0, 1, 2),
"GASLIMIT": (0x45, 0, 1, 2),
- "CHAINID": (0x46, 0, 1, (None, None, 2)),
- "SELFBALANCE": (0x47, 0, 1, (None, None, 5)),
- "BASEFEE": (0x48, 0, 1, (None, None, None, 2)),
+ "CHAINID": (0x46, 0, 1, 2),
+ "SELFBALANCE": (0x47, 0, 1, 5),
+ "BASEFEE": (0x48, 0, 1, (None, 2)),
"POP": (0x50, 1, 0, 2),
"MLOAD": (0x51, 1, 1, 3),
"MSTORE": (0x52, 2, 0, 3),
"MSTORE8": (0x53, 2, 0, 3),
- "SLOAD": (0x54, 1, 1, (200, 200, 800, 2100)),
+ "SLOAD": (0x54, 1, 1, (800, 2100)),
"SSTORE": (0x55, 2, 0, 20000),
"JUMP": (0x56, 1, 0, 8),
"JUMPI": (0x57, 2, 0, 10),
@@ -102,6 +89,8 @@
"MSIZE": (0x59, 0, 1, 2),
"GAS": (0x5A, 0, 1, 2),
"JUMPDEST": (0x5B, 0, 0, 1),
+ "MCOPY": (0x5E, 3, 0, (None, None, None, None, None, 3)),
+ "PUSH0": (0x5F, 0, 1, 2),
"PUSH1": (0x60, 0, 1, 3),
"PUSH2": (0x61, 0, 1, 3),
"PUSH3": (0x62, 0, 1, 3),
@@ -172,17 +161,19 @@
"LOG3": (0xA3, 5, 0, 1500),
"LOG4": (0xA4, 6, 0, 1875),
"CREATE": (0xF0, 3, 1, 32000),
- "CALL": (0xF1, 7, 1, (700, 700, 700, 2100)),
- "CALLCODE": (0xF2, 7, 1, (700, 700, 700, 2100)),
+ "CALL": (0xF1, 7, 1, (700, 2100)),
+ "CALLCODE": (0xF2, 7, 1, (700, 2100)),
"RETURN": (0xF3, 2, 0, 0),
- "DELEGATECALL": (0xF4, 6, 1, (700, 700, 700, 2100)),
- "CREATE2": (0xF5, 4, 1, (None, 32000)),
+ "DELEGATECALL": (0xF4, 6, 1, (700, 2100)),
+ "CREATE2": (0xF5, 4, 1, 32000),
"SELFDESTRUCT": (0xFF, 1, 0, 25000),
- "STATICCALL": (0xFA, 6, 1, (700, 700, 700, 2100)),
+ "STATICCALL": (0xFA, 6, 1, (700, 2100)),
"REVERT": (0xFD, 2, 0, 0),
"INVALID": (0xFE, 0, 0, 0),
"DEBUG": (0xA5, 1, 0, 0),
"BREAKPOINT": (0xA6, 0, 0, 0),
+ "TLOAD": (0x5C, 1, 1, (None, None, None, None, None, 100)),
+ "TSTORE": (0x5D, 2, 0, (None, None, None, None, None, 100)),
}
PSEUDO_OPCODES: OpcodeMap = {
@@ -217,17 +208,16 @@
IR_OPCODES: OpcodeMap = {**OPCODES, **PSEUDO_OPCODES}
-def evm_wrapper(fn, *args, **kwargs):
- def _wrapper(*args, **kwargs):
- global active_evm_version
- evm_version = kwargs.pop("evm_version", None) or DEFAULT_EVM_VERSION
- active_evm_version = EVM_VERSIONS[evm_version]
- try:
- return fn(*args, **kwargs)
- finally:
- active_evm_version = EVM_VERSIONS[DEFAULT_EVM_VERSION]
-
- return _wrapper
+@contextlib.contextmanager
+def anchor_evm_version(evm_version: Optional[str] = None) -> Generator:
+ global active_evm_version
+ anchor = active_evm_version
+ evm_version = evm_version or DEFAULT_EVM_VERSION
+ active_evm_version = EVM_VERSIONS[evm_version]
+ try:
+ yield
+ finally:
+ active_evm_version = anchor
def _gas(value: OpcodeValue, idx: int) -> Optional[OpcodeRulesetValue]:
diff --git a/vyper/exceptions.py b/vyper/exceptions.py
index 9c4358a9ad..defca7cc53 100644
--- a/vyper/exceptions.py
+++ b/vyper/exceptions.py
@@ -22,7 +22,7 @@ def raise_if_not_empty(self):
raise VyperException("\n\n".join(err_msg))
-class VyperException(Exception):
+class _BaseVyperException(Exception):
"""
Base Vyper exception class.
@@ -54,7 +54,9 @@ def __init__(self, message="Error Message not found.", *items):
# support older exceptions that don't annotate - remove this in the future!
self.lineno, self.col_offset = items[0][:2]
else:
- self.annotations = items
+ # strip out None sources so that None can be passed as a valid
+ # annotation (in case it is only available optionally)
+ self.annotations = [k for k in items if k is not None]
def with_annotation(self, *annotations):
"""
@@ -125,6 +127,10 @@ def __str__(self):
return f"{self.message}\n{annotation_msg}"
+class VyperException(_BaseVyperException):
+ pass
+
+
class SyntaxException(VyperException):
"""Invalid syntax."""
@@ -151,6 +157,10 @@ class StructureException(VyperException):
"""Invalid structure for parsable syntax."""
+class InstantiationException(StructureException):
+ """Variable or expression cannot be instantiated"""
+
+
class VersionException(VyperException):
"""Version string is malformed or incompatible with this compiler version."""
@@ -281,7 +291,7 @@ class StaticAssertionException(VyperException):
"""An assertion is proven to fail at compile-time."""
-class VyperInternalException(Exception):
+class VyperInternalException(_BaseVyperException):
"""
Base Vyper internal exception class.
@@ -292,9 +302,6 @@ class VyperInternalException(Exception):
compiler has panicked, and that filing a bug report would be appropriate.
"""
- def __init__(self, message=""):
- self.message = message
-
def __str__(self):
return (
f"{self.message}\n\nThis is an unhandled internal compiler error. "
diff --git a/vyper/ir/README.md b/vyper/ir/README.md
index 50e61ce81f..ebcc381590 100644
--- a/vyper/ir/README.md
+++ b/vyper/ir/README.md
@@ -194,6 +194,10 @@ Could compile to:
_sym_foo JUMPDEST
```
+### UNIQUE\_SYMBOL
+
+`(unique_symbol l)` defines a "unique symbol". These are generated to help catch front-end bugs involving multiple execution of side-effects (which should be only executed once). They can be ignored, or to be strict, any backend should enforce that each `unique_symbol` only appears one time in the code.
+
### IF\_STMT
Branching statements. There are two forms, if with a single branch, and if with two branches.
diff --git a/vyper/ir/compile_ir.py b/vyper/ir/compile_ir.py
index c24b3a67a2..7a3e97155b 100644
--- a/vyper/ir/compile_ir.py
+++ b/vyper/ir/compile_ir.py
@@ -1,9 +1,13 @@
import copy
import functools
import math
+from dataclasses import dataclass
+
+import cbor2
from vyper.codegen.ir_node import IRnode
-from vyper.evm.opcodes import get_opcodes
+from vyper.compiler.settings import OptimizationLevel
+from vyper.evm.opcodes import get_opcodes, version_check
from vyper.exceptions import CodegenPanic, CompilerPanic
from vyper.utils import MemoryPositions
from vyper.version import version_tuple
@@ -23,7 +27,8 @@ def num_to_bytearray(x):
def PUSH(x):
bs = num_to_bytearray(x)
- if len(bs) == 0:
+ # starting in shanghai, can do push0 directly with no immediates
+ if len(bs) == 0 and not version_check(begin="shanghai"):
bs = [0]
return [f"PUSH{len(bs)}"] + bs
@@ -117,14 +122,14 @@ def _rewrite_return_sequences(ir_node, label_params=None):
args[0].value = "pass"
else:
# handle jump to cleanup
- assert is_symbol(args[0].value)
ir_node.value = "seq"
_t = ["seq"]
if "return_buffer" in label_params:
_t.append(["pop", "pass"])
- dest = args[0].value[5:] # `_sym_foo` -> `foo`
+ dest = args[0].value
+ # works for both internal and external exit_to
more_args = ["pass" if t.value == "return_pc" else t for t in args[1:]]
_t.append(["goto", dest] + more_args)
ir_node.args = IRnode.from_list(_t, source_pos=ir_node.source_pos).args
@@ -149,18 +154,27 @@ def _add_postambles(asm_ops):
global _revert_label
- _revert_string = [_revert_label, "JUMPDEST", "PUSH1", 0, "DUP1", "REVERT"]
+ _revert_string = [_revert_label, "JUMPDEST", *PUSH(0), "DUP1", "REVERT"]
if _revert_label in asm_ops:
# shared failure block
to_append.extend(_revert_string)
if len(to_append) > 0:
+ # insert the postambles *before* runtime code
+ # so the data section of the runtime code can't bork the postambles.
+ runtime = None
+ if isinstance(asm_ops[-1], list) and isinstance(asm_ops[-1][0], _RuntimeHeader):
+ runtime = asm_ops.pop()
+
# for some reason there might not be a STOP at the end of asm_ops.
# (generally vyper programs will have it but raw IR might not).
asm_ops.append("STOP")
asm_ops.extend(to_append)
+ if runtime:
+ asm_ops.append(runtime)
+
# need to do this recursively since every sublist is basically
# treated as its own program (there are no global labels.)
for t in asm_ops:
@@ -200,7 +214,7 @@ def apply_line_no_wrapper(*args, **kwargs):
@apply_line_numbers
-def compile_to_assembly(code, no_optimize=False):
+def compile_to_assembly(code, optimize=OptimizationLevel.GAS):
global _revert_label
_revert_label = mksymbol("revert")
@@ -211,7 +225,10 @@ def compile_to_assembly(code, no_optimize=False):
res = _compile_to_assembly(code)
_add_postambles(res)
- if not no_optimize:
+
+ _relocate_segments(res)
+
+ if optimize != OptimizationLevel.NONE:
_optimize_assembly(res)
return res
@@ -295,6 +312,7 @@ def _height_of(witharg):
return o
# batch copy from data section of the currently executing code to memory
+ # (probably should have named this dcopy but oh well)
elif code.value == "dloadbytes":
dst = code.args[0]
src = code.args[1]
@@ -398,9 +416,8 @@ def _height_of(witharg):
)
# stack: i, rounds, rounds_bound
# assert rounds <= rounds_bound
- # TODO this runtime assertion should never fail for
+ # TODO this runtime assertion shouldn't fail for
# internally generated repeats.
- # maybe drop it or jump to 0xFE
o.extend(["DUP2", "GT"] + _assert_false())
# stack: i, rounds
@@ -493,31 +510,30 @@ def _height_of(witharg):
elif code.value == "deploy":
memsize = code.args[0].value # used later to calculate _mem_deploy_start
ir = code.args[1]
- padding = code.args[2].value
+ immutables_len = code.args[2].value
assert isinstance(memsize, int), "non-int memsize"
- assert isinstance(padding, int), "non-int padding"
+ assert isinstance(immutables_len, int), "non-int immutables_len"
- begincode = mksymbol("runtime_begin")
+ runtime_begin = mksymbol("runtime_begin")
subcode = _compile_to_assembly(ir)
o = []
# COPY the code to memory for deploy
- o.extend(["_sym_subcode_size", begincode, "_mem_deploy_start", "CODECOPY"])
+ o.extend(["_sym_subcode_size", runtime_begin, "_mem_deploy_start", "CODECOPY"])
# calculate the len of runtime code
- o.extend(["_OFST", "_sym_subcode_size", padding]) # stack: len
+ o.extend(["_OFST", "_sym_subcode_size", immutables_len]) # stack: len
o.extend(["_mem_deploy_start"]) # stack: len mem_ofst
o.extend(["RETURN"])
# since the asm data structures are very primitive, to make sure
# assembly_to_evm is able to calculate data offsets correctly,
# we pass the memsize via magic opcodes to the subcode
- subcode = [f"_DEPLOY_MEM_OFST_{memsize}"] + subcode
+ subcode = [_RuntimeHeader(runtime_begin, memsize, immutables_len)] + subcode
# append the runtime code after the ctor code
- o.extend([begincode, "BLANK"])
# `append(...)` call here is intentional.
# each sublist is essentially its own program with its
# own symbols.
@@ -555,13 +571,10 @@ def _height_of(witharg):
o = _compile_to_assembly(code.args[0], withargs, existing_labels, break_dest, height)
o.extend(
[
- "PUSH1",
- MemoryPositions.FREE_VAR_SPACE,
+ *PUSH(MemoryPositions.FREE_VAR_SPACE),
"MSTORE",
- "PUSH1",
- 32,
- "PUSH1",
- MemoryPositions.FREE_VAR_SPACE,
+ *PUSH(32),
+ *PUSH(MemoryPositions.FREE_VAR_SPACE),
"SHA3",
]
)
@@ -572,16 +585,12 @@ def _height_of(witharg):
o.extend(_compile_to_assembly(code.args[1], withargs, existing_labels, break_dest, height))
o.extend(
[
- "PUSH1",
- MemoryPositions.FREE_VAR_SPACE2,
+ *PUSH(MemoryPositions.FREE_VAR_SPACE2),
"MSTORE",
- "PUSH1",
- MemoryPositions.FREE_VAR_SPACE,
+ *PUSH(MemoryPositions.FREE_VAR_SPACE),
"MSTORE",
- "PUSH1",
- 64,
- "PUSH1",
- MemoryPositions.FREE_VAR_SPACE,
+ *PUSH(64),
+ *PUSH(MemoryPositions.FREE_VAR_SPACE),
"SHA3",
]
)
@@ -665,16 +674,36 @@ def _height_of(witharg):
height,
)
+ elif code.value == "data":
+ data_node = [_DataHeader("_sym_" + code.args[0].value)]
+
+ for c in code.args[1:]:
+ if isinstance(c.value, int):
+ assert 0 <= c < 256, f"invalid data byte {c}"
+ data_node.append(c.value)
+ elif isinstance(c.value, bytes):
+ data_node.append(c.value)
+ elif isinstance(c, IRnode):
+ assert c.value == "symbol"
+ data_node.extend(
+ _compile_to_assembly(c, withargs, existing_labels, break_dest, height)
+ )
+ else:
+ raise ValueError(f"Invalid data: {type(c)} {c}")
+
+ # intentionally return a sublist.
+ return [data_node]
+
# jump to a symbol, and push variable # of arguments onto stack
elif code.value == "goto":
o = []
for i, c in enumerate(reversed(code.args[1:])):
o.extend(_compile_to_assembly(c, withargs, existing_labels, break_dest, height + i))
- o.extend(["_sym_" + str(code.args[0]), "JUMP"])
+ o.extend(["_sym_" + code.args[0].value, "JUMP"])
return o
# push a literal symbol
- elif isinstance(code.value, str) and is_symbol(code.value):
- return [code.value]
+ elif code.value == "symbol":
+ return ["_sym_" + code.args[0].value]
# set a symbol as a location.
elif code.value == "label":
label_name = code.args[0].value
@@ -732,8 +761,8 @@ def _height_of(witharg):
# inject debug opcode.
elif code.value == "pc_debugger":
return mkdebug(pc_debugger=True, source_pos=code.source_pos)
- else:
- raise Exception("Weird code element: " + repr(code))
+ else: # pragma: no cover
+ raise ValueError(f"Weird code element: {type(code)} {code}")
def note_line_num(line_number_map, item, pos):
@@ -764,17 +793,21 @@ def note_breakpoint(line_number_map, item, pos):
line_number_map["breakpoints"].add(item.lineno + 1)
+_TERMINAL_OPS = ("JUMP", "RETURN", "REVERT", "STOP", "INVALID")
+
+
def _prune_unreachable_code(assembly):
- # In converting IR to assembly we sometimes end up with unreachable
- # instructions - POPing to clear the stack or STOPing execution at the
- # end of a function that has already returned or reverted. This should
- # be addressed in the IR, but for now we do a final sanity check here
- # to avoid unnecessary bytecode bloat.
+ # delete code between terminal ops and JUMPDESTS as those are
+ # unreachable
changed = False
i = 0
- while i < len(assembly) - 1:
- if assembly[i] in ("JUMP", "RETURN", "REVERT", "STOP") and not (
- is_symbol(assembly[i + 1]) or assembly[i + 1] == "JUMPDEST"
+ while i < len(assembly) - 2:
+ instr = assembly[i]
+ if isinstance(instr, list):
+ instr = assembly[i][-1]
+
+ if assembly[i] in _TERMINAL_OPS and not (
+ is_symbol(assembly[i + 1]) or isinstance(assembly[i + 1], list)
):
changed = True
del assembly[i + 1]
@@ -886,6 +919,14 @@ def _merge_iszero(assembly):
return changed
+# a symbol _sym_x in assembly can either mean to push _sym_x to the stack,
+# or it can precede a location in code which we want to add to symbol map.
+# this helper function tells us if we want to add the previous instruction
+# to the symbol map.
+def is_symbol_map_indicator(asm_node):
+ return asm_node == "JUMPDEST"
+
+
def _prune_unused_jumpdests(assembly):
changed = False
@@ -893,9 +934,17 @@ def _prune_unused_jumpdests(assembly):
# find all used jumpdests
for i in range(len(assembly) - 1):
- if is_symbol(assembly[i]) and assembly[i + 1] != "JUMPDEST":
+ if is_symbol(assembly[i]) and not is_symbol_map_indicator(assembly[i + 1]):
used_jumpdests.add(assembly[i])
+ for item in assembly:
+ if isinstance(item, list) and isinstance(item[0], _DataHeader):
+ # add symbols used in data sections as they are likely
+ # used for a jumptable.
+ for t in item:
+ if is_symbol(t):
+ used_jumpdests.add(t)
+
# delete jumpdests that aren't used
i = 0
while i < len(assembly) - 2:
@@ -934,7 +983,7 @@ def _stack_peephole_opts(assembly):
# optimize assembly, in place
def _optimize_assembly(assembly):
for x in assembly:
- if isinstance(x, list):
+ if isinstance(x, list) and isinstance(x[0], _RuntimeHeader):
_optimize_assembly(x)
for _ in range(1024):
@@ -967,16 +1016,101 @@ def adjust_pc_maps(pc_maps, ofst):
return ret
-def assembly_to_evm(
- assembly, pc_ofst=0, insert_vyper_signature=False, disable_bytecode_metadata=False
-):
+SYMBOL_SIZE = 2 # size of a PUSH instruction for a code symbol
+
+
+def _data_to_evm(assembly, symbol_map):
+ ret = bytearray()
+ assert isinstance(assembly[0], _DataHeader)
+ for item in assembly[1:]:
+ if is_symbol(item):
+ symbol = symbol_map[item].to_bytes(SYMBOL_SIZE, "big")
+ ret.extend(symbol)
+ elif isinstance(item, int):
+ ret.append(item)
+ elif isinstance(item, bytes):
+ ret.extend(item)
+ else:
+ raise ValueError(f"invalid data {type(item)} {item}")
+
+ return ret
+
+
+# predict what length of an assembly [data] node will be in bytecode
+def _length_of_data(assembly):
+ ret = 0
+ assert isinstance(assembly[0], _DataHeader)
+ for item in assembly[1:]:
+ if is_symbol(item):
+ ret += SYMBOL_SIZE
+ elif isinstance(item, int):
+ assert 0 <= item < 256, f"invalid data byte {item}"
+ ret += 1
+ elif isinstance(item, bytes):
+ ret += len(item)
+ else:
+ raise ValueError(f"invalid data {type(item)} {item}")
+
+ return ret
+
+
+@dataclass
+class _RuntimeHeader:
+ label: str
+ ctor_mem_size: int
+ immutables_len: int
+
+ def __repr__(self):
+ return f""
+
+
+@dataclass
+class _DataHeader:
+ label: str
+
+ def __repr__(self):
+ return f"DATA {self.label}"
+
+
+def _relocate_segments(assembly):
+ # relocate all data segments to the end, otherwise data could be
+ # interpreted as PUSH instructions and mangle otherwise valid jumpdests
+ # relocate all runtime segments to the end as well
+ data_segments = []
+ non_data_segments = []
+ code_segments = []
+ for t in assembly:
+ if isinstance(t, list):
+ if isinstance(t[0], _DataHeader):
+ data_segments.append(t)
+ else:
+ _relocate_segments(t) # recurse
+ assert isinstance(t[0], _RuntimeHeader)
+ code_segments.append(t)
+ else:
+ non_data_segments.append(t)
+ assembly.clear()
+ assembly.extend(non_data_segments)
+ assembly.extend(code_segments)
+ assembly.extend(data_segments)
+
+
+# TODO: change API to split assembly_to_evm and assembly_to_source/symbol_maps
+def assembly_to_evm(assembly, pc_ofst=0, insert_compiler_metadata=False):
+ bytecode, source_maps, _ = assembly_to_evm_with_symbol_map(
+ assembly, pc_ofst=pc_ofst, insert_compiler_metadata=insert_compiler_metadata
+ )
+ return bytecode, source_maps
+
+
+def assembly_to_evm_with_symbol_map(assembly, pc_ofst=0, insert_compiler_metadata=False):
"""
Assembles assembly into EVM
assembly: list of asm instructions
pc_ofst: when constructing the source map, the amount to offset all
pcs by (no effect until we add deploy code source map)
- insert_vyper_signature: whether to append vyper metadata to output
+ insert_compiler_metadata: whether to append vyper metadata to output
(should be true for runtime code)
"""
line_number_map = {
@@ -992,14 +1126,6 @@ def assembly_to_evm(
runtime_code, runtime_code_start, runtime_code_end = None, None, None
- bytecode_suffix = b""
- if (not disable_bytecode_metadata) and insert_vyper_signature:
- # CBOR encoded: {"vyper": [major,minor,patch]}
- bytecode_suffix += b"\xa1\x65vyper\x83" + bytes(list(version_tuple))
- bytecode_suffix += len(bytecode_suffix).to_bytes(2, "big")
-
- CODE_OFST_SIZE = 2 # size of a PUSH instruction for a code symbol
-
# to optimize the size of deploy code - we want to use the smallest
# PUSH instruction possible which can support all memory symbols
# (and also works with linear pass symbol resolution)
@@ -1008,17 +1134,13 @@ def assembly_to_evm(
mem_ofst_size, ctor_mem_size = None, None
max_mem_ofst = 0
for i, item in enumerate(assembly):
- if isinstance(item, list):
+ if isinstance(item, list) and isinstance(item[0], _RuntimeHeader):
assert runtime_code is None, "Multiple subcodes"
- runtime_code, runtime_map = assembly_to_evm(
- item,
- insert_vyper_signature=True,
- disable_bytecode_metadata=disable_bytecode_metadata,
- )
- assert item[0].startswith("_DEPLOY_MEM_OFST_")
assert ctor_mem_size is None
- ctor_mem_size = int(item[0][len("_DEPLOY_MEM_OFST_") :])
+ ctor_mem_size = item[0].ctor_mem_size
+
+ runtime_code, runtime_map = assembly_to_evm(item[1:])
runtime_code_start, runtime_code_end = _runtime_code_offsets(
ctor_mem_size, len(runtime_code)
@@ -1031,6 +1153,9 @@ def assembly_to_evm(
if runtime_code_end is not None:
mem_ofst_size = calc_mem_ofst_size(runtime_code_end + max_mem_ofst)
+ data_section_lengths = []
+ immutables_len = None
+
# go through the code, resolving symbolic locations
# (i.e. JUMPDEST locations) to actual code locations
for i, item in enumerate(assembly):
@@ -1056,14 +1181,14 @@ def assembly_to_evm(
# update pc
if is_symbol(item):
- if assembly[i + 1] == "JUMPDEST" or assembly[i + 1] == "BLANK":
+ if is_symbol_map_indicator(assembly[i + 1]):
# Don't increment pc as the symbol itself doesn't go into code
if item in symbol_map:
raise CompilerPanic(f"duplicate jumpdest {item}")
symbol_map[item] = pc
else:
- pc += CODE_OFST_SIZE + 1 # PUSH2 highbits lowbits
+ pc += SYMBOL_SIZE + 1 # PUSH2 highbits lowbits
elif is_mem_sym(item):
# PUSH item
pc += mem_ofst_size + 1
@@ -1073,22 +1198,42 @@ def assembly_to_evm(
# [_OFST, _sym_foo, bar] -> PUSH2 (foo+bar)
# [_OFST, _mem_foo, bar] -> PUSHN (foo+bar)
pc -= 1
- elif item == "BLANK":
- pc += 0
- elif isinstance(item, str) and item.startswith("_DEPLOY_MEM_OFST_"):
- # _DEPLOY_MEM_OFST is assembly magic which will
- # get removed during final assembly-to-bytecode
- pc += 0
- elif isinstance(item, list):
+ elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader):
+ # we are in initcode
+ symbol_map[item[0].label] = pc
# add source map for all items in the runtime map
t = adjust_pc_maps(runtime_map, pc)
for key in line_number_map:
line_number_map[key].update(t[key])
+ immutables_len = item[0].immutables_len
pc += len(runtime_code)
-
+ # grab lengths of data sections from the runtime
+ for t in item:
+ if isinstance(t, list) and isinstance(t[0], _DataHeader):
+ data_section_lengths.append(_length_of_data(t))
+
+ elif isinstance(item, list) and isinstance(item[0], _DataHeader):
+ symbol_map[item[0].label] = pc
+ pc += _length_of_data(item)
else:
pc += 1
+ bytecode_suffix = b""
+ if insert_compiler_metadata:
+ # this will hold true when we are in initcode
+ assert immutables_len is not None
+ metadata = (
+ len(runtime_code),
+ data_section_lengths,
+ immutables_len,
+ {"vyper": version_tuple},
+ )
+ bytecode_suffix += cbor2.dumps(metadata)
+ # append the length of the footer, *including* the length
+ # of the length bytes themselves.
+ suffix_len = len(bytecode_suffix) + 2
+ bytecode_suffix += suffix_len.to_bytes(2, "big")
+
pc += len(bytecode_suffix)
symbol_map["_sym_code_end"] = pc
@@ -1097,13 +1242,9 @@ def assembly_to_evm(
if runtime_code is not None:
symbol_map["_sym_subcode_size"] = len(runtime_code)
- # (NOTE CMC 2022-06-17 this way of generating bytecode did not
- # seem to be a perf hotspot. if it is, may want to use bytearray()
- # instead).
-
- # TODO refactor into two functions, create posmap and assemble
+ # TODO refactor into two functions, create symbol_map and assemble
- o = b""
+ ret = bytearray()
# now that all symbols have been resolved, generate bytecode
# using the symbol map
@@ -1113,47 +1254,47 @@ def assembly_to_evm(
to_skip -= 1
continue
- if item in ("DEBUG", "BLANK"):
+ if item in ("DEBUG",):
continue # skippable opcodes
- elif isinstance(item, str) and item.startswith("_DEPLOY_MEM_OFST_"):
- continue
-
elif is_symbol(item):
- if assembly[i + 1] != "JUMPDEST" and assembly[i + 1] != "BLANK":
- bytecode, _ = assembly_to_evm(PUSH_N(symbol_map[item], n=CODE_OFST_SIZE))
- o += bytecode
+ # push a symbol to stack
+ if not is_symbol_map_indicator(assembly[i + 1]):
+ bytecode, _ = assembly_to_evm(PUSH_N(symbol_map[item], n=SYMBOL_SIZE))
+ ret.extend(bytecode)
elif is_mem_sym(item):
bytecode, _ = assembly_to_evm(PUSH_N(symbol_map[item], n=mem_ofst_size))
- o += bytecode
+ ret.extend(bytecode)
elif is_ofst(item):
# _OFST _sym_foo 32
ofst = symbol_map[assembly[i + 1]] + assembly[i + 2]
- n = mem_ofst_size if is_mem_sym(assembly[i + 1]) else CODE_OFST_SIZE
+ n = mem_ofst_size if is_mem_sym(assembly[i + 1]) else SYMBOL_SIZE
bytecode, _ = assembly_to_evm(PUSH_N(ofst, n))
- o += bytecode
+ ret.extend(bytecode)
to_skip = 2
elif isinstance(item, int):
- o += bytes([item])
+ ret.append(item)
elif isinstance(item, str) and item.upper() in get_opcodes():
- o += bytes([get_opcodes()[item.upper()][0]])
+ ret.append(get_opcodes()[item.upper()][0])
elif item[:4] == "PUSH":
- o += bytes([PUSH_OFFSET + int(item[4:])])
+ ret.append(PUSH_OFFSET + int(item[4:]))
elif item[:3] == "DUP":
- o += bytes([DUP_OFFSET + int(item[3:])])
+ ret.append(DUP_OFFSET + int(item[3:]))
elif item[:4] == "SWAP":
- o += bytes([SWAP_OFFSET + int(item[4:])])
- elif isinstance(item, list):
- o += runtime_code
- else:
- # Should never reach because, assembly is create in _compile_to_assembly.
- raise Exception("Weird symbol in assembly: " + str(item)) # pragma: no cover
+ ret.append(SWAP_OFFSET + int(item[4:]))
+ elif isinstance(item, list) and isinstance(item[0], _RuntimeHeader):
+ ret.extend(runtime_code)
+ elif isinstance(item, list) and isinstance(item[0], _DataHeader):
+ ret.extend(_data_to_evm(item, symbol_map))
+ else: # pragma: no cover
+ # unreachable
+ raise ValueError(f"Weird symbol in assembly: {type(item)} {item}")
- o += bytecode_suffix
+ ret.extend(bytecode_suffix)
line_number_map["breakpoints"] = list(line_number_map["breakpoints"])
line_number_map["pc_breakpoints"] = list(line_number_map["pc_breakpoints"])
- return o, line_number_map
+ return bytes(ret), line_number_map, symbol_map
diff --git a/vyper/ir/optimizer.py b/vyper/ir/optimizer.py
index f54b0e8115..08c2168381 100644
--- a/vyper/ir/optimizer.py
+++ b/vyper/ir/optimizer.py
@@ -78,6 +78,11 @@ def _deep_contains(node_or_list, node):
COMMUTATIVE_OPS = {"add", "mul", "eq", "ne", "and", "or", "xor"}
COMPARISON_OPS = {"gt", "sgt", "ge", "sge", "lt", "slt", "le", "sle"}
+STRICT_COMPARISON_OPS = {t for t in COMPARISON_OPS if t.endswith("t")}
+UNSTRICT_COMPARISON_OPS = {t for t in COMPARISON_OPS if t.endswith("e")}
+
+assert not (STRICT_COMPARISON_OPS & UNSTRICT_COMPARISON_OPS)
+assert STRICT_COMPARISON_OPS | UNSTRICT_COMPARISON_OPS == COMPARISON_OPS
def _flip_comparison_op(opname):
@@ -256,11 +261,15 @@ def _conservative_eq(x, y):
return finalize("seq", [args[0]])
if binop in {"sub", "xor", "ne"} and _conservative_eq(args[0], args[1]):
- # x - x == x ^ x == x != x == 0
+ # (x - x) == (x ^ x) == (x != x) == 0
+ return finalize(0, [])
+
+ if binop in STRICT_COMPARISON_OPS and _conservative_eq(args[0], args[1]):
+ # (x < x) == (x > x) == 0
return finalize(0, [])
- if binop == "eq" and _conservative_eq(args[0], args[1]):
- # (x == x) == 1
+ if binop in {"eq"} | UNSTRICT_COMPARISON_OPS and _conservative_eq(args[0], args[1]):
+ # (x == x) == (x >= x) == (x <= x) == 1
return finalize(1, [])
# TODO associativity rules
@@ -331,19 +340,17 @@ def _conservative_eq(x, y):
if binop == "mod":
return finalize("and", [args[0], _int(args[1]) - 1])
- if binop == "div" and version_check(begin="constantinople"):
+ if binop == "div":
# x / 2**n == x >> n
# recall shr/shl have unintuitive arg order
return finalize("shr", [int_log2(_int(args[1])), args[0]])
# note: no rule for sdiv since it rounds differently from sar
- if binop == "mul" and version_check(begin="constantinople"):
+ if binop == "mul":
# x * 2**n == x << n
return finalize("shl", [int_log2(_int(args[1])), args[0]])
- # reachable but only before constantinople
- if version_check(begin="constantinople"): # pragma: no cover
- raise CompilerPanic("unreachable")
+ raise CompilerPanic("unreachable") # pragma: no cover
##
# COMPARISONS
@@ -466,6 +473,9 @@ def finalize(val, args):
if value == "seq":
changed |= _merge_memzero(argz)
changed |= _merge_calldataload(argz)
+ changed |= _merge_dload(argz)
+ changed |= _rewrite_mstore_dload(argz)
+ changed |= _merge_mload(argz)
changed |= _remove_empty_seqs(argz)
# (seq x) => (x) for cleanliness and
@@ -630,12 +640,38 @@ def _remove_empty_seqs(argz):
def _merge_calldataload(argz):
- # look for sequential operations copying from calldata to memory
- # and merge them into a single calldatacopy operation
+ return _merge_load(argz, "calldataload", "calldatacopy")
+
+
+def _merge_dload(argz):
+ return _merge_load(argz, "dload", "dloadbytes")
+
+
+def _rewrite_mstore_dload(argz):
+ changed = False
+ for i, arg in enumerate(argz):
+ if arg.value == "mstore" and arg.args[1].value == "dload":
+ dst = arg.args[0]
+ src = arg.args[1].args[0]
+ len_ = 32
+ argz[i] = IRnode.from_list(["dloadbytes", dst, src, len_], source_pos=arg.source_pos)
+ changed = True
+ return changed
+
+
+def _merge_mload(argz):
+ if not version_check(begin="cancun"):
+ return False
+ return _merge_load(argz, "mload", "mcopy")
+
+
+def _merge_load(argz, _LOAD, _COPY):
+ # look for sequential operations copying from X to Y
+ # and merge them into a single copy operation
changed = False
mstore_nodes: List = []
- initial_mem_offset = 0
- initial_calldata_offset = 0
+ initial_dst_offset = 0
+ initial_src_offset = 0
total_length = 0
idx = None
for i, ir_node in enumerate(argz):
@@ -643,19 +679,19 @@ def _merge_calldataload(argz):
if (
ir_node.value == "mstore"
and isinstance(ir_node.args[0].value, int)
- and ir_node.args[1].value == "calldataload"
+ and ir_node.args[1].value == _LOAD
and isinstance(ir_node.args[1].args[0].value, int)
):
# mstore of a zero value
- mem_offset = ir_node.args[0].value
- calldata_offset = ir_node.args[1].args[0].value
+ dst_offset = ir_node.args[0].value
+ src_offset = ir_node.args[1].args[0].value
if not mstore_nodes:
- initial_mem_offset = mem_offset
- initial_calldata_offset = calldata_offset
+ initial_dst_offset = dst_offset
+ initial_src_offset = src_offset
idx = i
if (
- initial_mem_offset + total_length == mem_offset
- and initial_calldata_offset + total_length == calldata_offset
+ initial_dst_offset + total_length == dst_offset
+ and initial_src_offset + total_length == src_offset
):
mstore_nodes.append(ir_node)
total_length += 32
@@ -670,7 +706,7 @@ def _merge_calldataload(argz):
if len(mstore_nodes) > 1:
changed = True
new_ir = IRnode.from_list(
- ["calldatacopy", initial_mem_offset, initial_calldata_offset, total_length],
+ [_COPY, initial_dst_offset, initial_src_offset, total_length],
source_pos=mstore_nodes[0].source_pos,
)
# replace first copy operation with optimized node and remove the rest
@@ -678,8 +714,8 @@ def _merge_calldataload(argz):
# note: del xs[k:l] deletes l - k items
del argz[idx + 1 : idx + len(mstore_nodes)]
- initial_mem_offset = 0
- initial_calldata_offset = 0
+ initial_dst_offset = 0
+ initial_src_offset = 0
total_length = 0
mstore_nodes.clear()
diff --git a/vyper/semantics/README.md b/vyper/semantics/README.md
index 64be46c0e8..1d81a0979b 100644
--- a/vyper/semantics/README.md
+++ b/vyper/semantics/README.md
@@ -10,47 +10,43 @@ Vyper abstract syntax tree (AST).
`vyper.semantics` has the following structure:
* [`types/`](types): Subpackage of classes and methods used to represent types
- * [`types/indexable/`](types/indexable)
- * [`mapping.py`](types/indexable/mapping.py): Mapping type
- * [`sequence.py`](types/indexable/sequence.py): Array and Tuple types
- * [`types/user/`](types/user)
- * [`interface.py`](types/user/interface.py): Contract interface types and getter functions
- * [`struct.py`](types/user/struct.py): Struct types and getter functions
- * [`types/value/`](types/value)
- * [`address.py`](types/value/address.py): Address type
- * [`array_value.py`](types/value/array_value.py): Single-value subscript types (bytes, string)
- * [`boolean.py`](types/value/boolean.py): Boolean type
- * [`bytes_fixed.py`](types/value/bytes_fixed.py): Fixed length byte types
- * [`numeric.py`](types/value/numeric.py): Integer and decimal types
- * [`abstract.py`](types/abstract.py): Abstract data type classes
* [`bases.py`](types/bases.py): Common base classes for all type objects
- * [`event.py`](types/user/event.py): `Event` type class
- * [`function.py`](types/function.py): `ContractFunction` type class
+ * [`bytestrings.py`](types/bytestrings.py): Single-value subscript types (bytes, string)
+ * [`function.py`](types/function.py): Contract function and member function types
+ * [`primitives.py`](types/primitives.py): Address, boolean, fixed length byte, integer and decimal types
+ * [`shortcuts.py`](types/shortcuts.py): Helper constants for commonly used types
+ * [`subscriptable.py`](types/subscriptable.py): Mapping, array and tuple types
+ * [`user.py`](types/user.py): Enum, event, interface and struct types
* [`utils.py`](types/utils.py): Functions for generating and fetching type objects
-* [`validation/`](validation): Subpackage for type checking and syntax verification logic
- * [`base.py`](validation/base.py): Base validation class
- * [`local.py`](validation/local.py): Validates the local namespace of each function within a contract
- * [`module.py`](validation/module.py): Validates the module namespace of a contract.
- * [`utils.py`](validation/utils.py): Functions for comparing and validating types
+* [`analysis/`](analysis): Subpackage for type checking and syntax verification logic
+ * [`annotation.py`](analysis/annotation.py): Annotates statements and expressions with the appropriate type information
+ * [`base.py`](analysis/base.py): Base validation class
+ * [`common.py`](analysis/common.py): Base AST visitor class
+ * [`data_positions`](analysis/data_positions.py): Functions for tracking storage variables and allocating storage slots
+ * [`levenhtein_utils.py`](analysis/levenshtein_utils.py): Helper for better error messages
+ * [`local.py`](analysis/local.py): Validates the local namespace of each function within a contract
+ * [`module.py`](analysis/module.py): Validates the module namespace of a contract.
+ * [`utils.py`](analysis/utils.py): Functions for comparing and validating types
+* [`data_locations.py`](data_locations.py): `DataLocation` object for type location information
* [`environment.py`](environment.py): Environment variables and builtin constants
* [`namespace.py`](namespace.py): `Namespace` object, a `dict` subclass representing the namespace of a contract
## Control Flow
-The [`validation`](validation) subpackage contains the top-level `validate_semantics`
+The [`analysis`](analysis) subpackage contains the top-level `validate_semantics`
function. This function is used to verify and type-check a contract. The process
consists of three steps:
1. Preparing the builtin namespace
2. Validating the module-level scope
-3. Validating local scopes
+3. Annotating and validating local scopes
### 1. Preparing the builtin namespace
The [`Namespace`](namespace.py) object represents the namespace for a contract.
Builtins are added upon initialization of the object. This includes:
-* Adding primitive type classes from the [`types/`](types) subpackage
+* Adding type classes from the [`types/`](types) subpackage
* Adding environment variables and builtin constants from [`environment.py`](environment.py)
* Adding builtin functions from the [`functions`](../builtins/functions.py) package
* Adding / resetting `self` and `log`
@@ -65,11 +61,11 @@ of a contract. This includes:
and functions
* Validating import statements and function signatures
-### 3. Validating the Local Scopes
+### 3. Annotating and validating the Local Scopes
[`validation/local.py`](validation/local.py) validates the local scope within each
function in a contract. `FunctionNodeVisitor` is used to iterate over the statement
-nodes in each function body and apply appropriate checks.
+nodes in each function body, annotate them and apply appropriate checks.
To learn more about the checks on each node type, read the docstrings on the methods
of `FunctionNodeVisitor`.
@@ -106,45 +102,6 @@ The array is given a type of `int128[2]`.
All type classes are found within the [`semantics/types/`](types) subpackage.
-Type classes rely on inheritance to define their structure and functionlity.
-Vyper uses three broad categories to represent types within the compiler.
-
-#### Primitive Types
-
-A **primitive type** (or just primitive) defines the base attributes of a given type.
-There is only one primitive type object created for each Vyper type. All primitive
-classes are subclasses of `BasePrimitive`.
-
-Along with the builtin primitive types, user-defined ones may be created. These
-primitives are defined in the modules within [`semantics/types/user`](types/user).
-See the docstrings there for more information.
-
-#### Type Definitions
-
-A **type definition** (or just definition) is a type that has been assigned to a
-specific variable, literal, or other value. Definition objects are typically derived
-from primitives. They include additional information such as the constancy,
-visibility and scope of the associated value.
-
-A primitive type always has a corresponding type definition. However, not all
-type definitions have a primitive type, e.g. arrays and tuples.
-
-Comparing a definition to it's related primitive type will always evaluate `True`.
-Comparing two definitions of the same class can sometimes evaluate false depending
-on certain attributes. All definition classes are subclasses of `BaseTypeDefinition`.
-
-Additionally, literal values sometimes have multiple _potential type definitions_.
-In this case, a membership check determines if the literal is valid by comparing
-the list of potential types against a specific type.
-
-#### Abstract Types
-
-An **abstract type** is an inherited class shared by two or more definition
-classes. Abstract types do not implement any functionality and may not be directly
-assigned to any values. They are used for broad type checking, in cases where
-e.g. a function expects any numeric value, or any bytes value. All abstract type
-classes are subclasses of `AbstractDataType`.
-
### Namespace
[`namespace.py`](namespace.py) contains the `Namespace` object. `Namespace` is a
@@ -190,12 +147,12 @@ namespace['foo'] # this raises an UndeclaredDefinition
Validation is handled by calling methods within each type object. In general:
-* Primitive type objects include one or both of `from_annotation` and `from_literal`
-methods, which validate an AST node and a produce definition object
-* Definition objects include a variety of `get_` and `validate_` methods,
+* Type objects include one or both of `from_annotation` and `from_literal`
+methods, which validate an AST node and produce a type object
+* Type objects include a variety of `get_` and `validate_` methods,
which are used to validate interactions and obtain new types based on AST nodes
-All possible methods for primitives and definitions are outlined within the base
+All possible methods for type objects are outlined within the base
classes in [`types/bases.py`](types/bases.py). The functionality within the methods
of the base classes is typically to raise and give a meaningful explanation
for _why_ the syntax not valid.
@@ -208,9 +165,7 @@ Here are some examples:
foo: int128
```
-1. We look up `int128` in `namespace`. We retrieve an `Int128Primitive` object.
-2. We call `Int128Primitive.from_annotation` with the AST node of the statement. This
-method validates the statement and returns an `Int128Definition` object.
+1. We look up `int128` in `namespace`. We retrieve an `IntegerT` object.
3. We store the new definition under the key `foo` within `namespace`.
#### 2. Modifying the value of a variable
@@ -219,15 +174,14 @@ method validates the statement and returns an `Int128Definition` object.
foo += 6
```
-1. We look up `foo` in `namespace` and retrieve the `Int128Definition`.
+1. We look up `foo` in `namespace` and retrieve an `IntegerT` with `_is_signed=True` and `_bits=128`.
2. We call `get_potential_types_from_node` with the target node
and are returned a list of types that are valid for the literal `6`. In this
-case, the list includes `Int128Definition`. The type check for the statement
-passes.
+case, the list includes an `IntegerT` with `_is_signed=True` and `_bits=128`. The type check for the statement passes.
3. We call the `validate_modification` method on the definition object
for `foo` to confirm that it is a value that may be modified (not a constant).
4. Because the statement involves a mathematical operator, we also call the
-`validate_numeric_operation` method on `foo` to confirm that the operation is
+`validate_numeric_op` method on `foo` to confirm that the operation is
allowed.
#### 3. Calling a builtin function
@@ -240,7 +194,7 @@ bar: bytes32 = sha256(b"hash me!")
function.
2. We call `fetch_call_return` on the function definition object, with the AST
node representing the call. This method validates the input arguments, and returns
-a `Bytes32Definition`.
+a `BytesM_T` with `m=32`.
3. We validation of the delcaration of `bar` in the same manner as the first
example, and compare the generated type to that returned by `sha256`.
diff --git a/vyper/semantics/analysis/__init__.py b/vyper/semantics/analysis/__init__.py
index 0b580a834c..9e987d1cd0 100644
--- a/vyper/semantics/analysis/__init__.py
+++ b/vyper/semantics/analysis/__init__.py
@@ -1,7 +1,10 @@
+import vyper.ast as vy_ast
+
from .. import types # break a dependency cycle.
from ..namespace import get_namespace
from .local import validate_functions
from .module import add_module_namespace
+from .utils import _ExprAnalyser
def validate_semantics(vyper_ast, interface_codes):
@@ -10,4 +13,5 @@ def validate_semantics(vyper_ast, interface_codes):
with namespace.enter_scope():
add_module_namespace(vyper_ast, interface_codes)
+ vy_ast.expansion.expand_annotated_ast(vyper_ast)
validate_functions(vyper_ast)
diff --git a/vyper/semantics/analysis/annotation.py b/vyper/semantics/analysis/annotation.py
index 3d2397f30d..01ca51d7f4 100644
--- a/vyper/semantics/analysis/annotation.py
+++ b/vyper/semantics/analysis/annotation.py
@@ -1,11 +1,11 @@
from vyper import ast as vy_ast
-from vyper.exceptions import StructureException
+from vyper.exceptions import StructureException, TypeCheckFailure
from vyper.semantics.analysis.utils import (
get_common_types,
get_exact_type_from_node,
get_possible_types_from_node,
)
-from vyper.semantics.types import TYPE_T, EnumT, EventT, SArrayT, StructT, is_type_t
+from vyper.semantics.types import TYPE_T, BoolT, EnumT, EventT, SArrayT, StructT, is_type_t
from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT
@@ -43,9 +43,9 @@ def __init__(self, fn_node: vy_ast.FunctionDef, namespace: dict) -> None:
self.namespace = namespace
self.expr_visitor = ExpressionAnnotationVisitor(self.func)
- assert len(self.func.kwarg_keys) == len(fn_node.args.defaults)
- for kw, val in zip(self.func.kwarg_keys, fn_node.args.defaults):
- self.expr_visitor.visit(val, self.func.arguments[kw])
+ assert self.func.n_keyword_args == len(fn_node.args.defaults)
+ for kwarg in self.func.keyword_args:
+ self.expr_visitor.visit(kwarg.default_value, kwarg.typ)
def visit(self, node):
super().visit(node)
@@ -85,16 +85,19 @@ def visit_Return(self, node):
def visit_For(self, node):
if isinstance(node.iter, (vy_ast.Name, vy_ast.Attribute)):
self.expr_visitor.visit(node.iter)
- # typecheck list literal as static array
+
+ iter_type = node.target._metadata["type"]
if isinstance(node.iter, vy_ast.List):
- value_type = get_common_types(*node.iter.elements).pop()
+ # typecheck list literal as static array
len_ = len(node.iter.elements)
- self.expr_visitor.visit(node.iter, SArrayT(value_type, len_))
+ self.expr_visitor.visit(node.iter, SArrayT(iter_type, len_))
if isinstance(node.iter, vy_ast.Call) and node.iter.func.id == "range":
- iter_type = node.target._metadata["type"]
for a in node.iter.args:
self.expr_visitor.visit(a, iter_type)
+ for a in node.iter.keywords:
+ if a.arg == "bound":
+ self.expr_visitor.visit(a.value, iter_type)
class ExpressionAnnotationVisitor(_AnnotationVisitorBase):
@@ -136,7 +139,7 @@ def visit_Call(self, node, type_):
# function calls
if call_type.is_internal:
self.func.called_functions.add(call_type)
- for arg, typ in zip(node.args, list(call_type.arguments.values())):
+ for arg, typ in zip(node.args, call_type.argument_types):
self.visit(arg, typ)
for kwarg in node.keywords:
# We should only see special kwargs
@@ -231,14 +234,23 @@ def visit_Subscript(self, node, type_):
elif type_ is not None and len(possible_base_types) > 1:
for possible_type in possible_base_types:
- if isinstance(possible_type.value_type, type(type_)):
+ if type_.compare_type(possible_type.value_type):
base_type = possible_type
break
+ else:
+ # this should have been caught in
+ # `get_possible_types_from_node` but wasn't.
+ raise TypeCheckFailure(f"Expected {type_} but it is not a possible type", node)
else:
base_type = get_exact_type_from_node(node.value)
- self.visit(node.slice, base_type.key_type)
+ # get the correct type for the index, it might
+ # not be base_type.key_type
+ index_types = get_possible_types_from_node(node.slice.value)
+ index_type = index_types.pop()
+
+ self.visit(node.slice, index_type)
self.visit(node.value, base_type)
def visit_Tuple(self, node, type_):
@@ -258,3 +270,14 @@ def visit_UnaryOp(self, node, type_):
type_ = type_.pop()
node._metadata["type"] = type_
self.visit(node.operand, type_)
+
+ def visit_IfExp(self, node, type_):
+ if type_ is None:
+ ts = get_common_types(node.body, node.orelse)
+ if len(ts) == 1:
+ type_ = ts.pop()
+
+ node._metadata["type"] = type_
+ self.visit(node.test, BoolT())
+ self.visit(node.body, type_)
+ self.visit(node.orelse, type_)
diff --git a/vyper/semantics/analysis/base.py b/vyper/semantics/analysis/base.py
index 27f04577af..449e6ca338 100644
--- a/vyper/semantics/analysis/base.py
+++ b/vyper/semantics/analysis/base.py
@@ -9,6 +9,7 @@
StateAccessViolation,
VyperInternalException,
)
+from vyper.semantics.data_locations import DataLocation
from vyper.semantics.types.base import VyperType
@@ -92,15 +93,6 @@ def from_abi(cls, abi_dict: Dict) -> "StateMutability":
# specifying a state mutability modifier at all. Do the same here.
-# TODO: move me to locations.py?
-class DataLocation(enum.Enum):
- UNSET = 0
- MEMORY = 1
- STORAGE = 2
- CALLDATA = 3
- CODE = 4
-
-
class DataPosition:
_location: DataLocation
@@ -170,9 +162,13 @@ class VarInfo:
is_constant: bool = False
is_public: bool = False
is_immutable: bool = False
+ is_transient: bool = False
is_local_var: bool = False
decl_node: Optional[vy_ast.VyperNode] = None
+ def __hash__(self):
+ return hash(id(self))
+
def __post_init__(self):
self._modification_count = 0
diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py
index 5c1347d5a3..87ec45c40d 100644
--- a/vyper/semantics/analysis/data_positions.py
+++ b/vyper/semantics/analysis/data_positions.py
@@ -1,11 +1,11 @@
-# TODO this doesn't really belong in "validation"
-import math
+# TODO this module doesn't really belong in "validation"
from typing import Dict, List
from vyper import ast as vy_ast
from vyper.exceptions import StorageLayoutException
from vyper.semantics.analysis.base import CodeOffset, StorageSlot
from vyper.typing import StorageLayout
+from vyper.utils import ceil32
def set_data_positions(
@@ -121,8 +121,7 @@ def set_storage_slots_with_overrides(
# Expect to find this variable within the storage layout overrides
if node.target.id in storage_layout_overrides:
var_slot = storage_layout_overrides[node.target.id]["slot"]
- # Calculate how many storage slots are required
- storage_length = math.ceil(varinfo.typ.size_in_bytes / 32)
+ storage_length = varinfo.typ.storage_size_in_words
# Ensure that all required storage slots are reserved, and prevents other variables
# from using these slots
reserved_slots.reserve_slot_range(var_slot, storage_length, node.target.id)
@@ -139,6 +138,21 @@ def set_storage_slots_with_overrides(
return ret
+class SimpleStorageAllocator:
+ def __init__(self, starting_slot: int = 0):
+ self._slot = starting_slot
+
+ def allocate_slot(self, n, var_name):
+ ret = self._slot
+ if self._slot + n >= 2**256:
+ raise StorageLayoutException(
+ f"Invalid storage slot for var {var_name}, tried to allocate"
+ f" slots {self._slot} through {self._slot + n}"
+ )
+ self._slot += n
+ return ret
+
+
def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout:
"""
Parse module-level Vyper AST to calculate the layout of storage variables.
@@ -146,7 +160,7 @@ def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout:
"""
# Allocate storage slots from 0
# note storage is word-addressable, not byte-addressable
- storage_slot = 0
+ allocator = SimpleStorageAllocator()
ret: Dict[str, Dict] = {}
@@ -165,16 +179,16 @@ def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout:
type_.set_reentrancy_key_position(StorageSlot(_slot))
continue
- type_.set_reentrancy_key_position(StorageSlot(storage_slot))
-
- # TODO this could have better typing but leave it untyped until
- # we nail down the format better
- ret[variable_name] = {"type": "nonreentrant lock", "slot": storage_slot}
-
# TODO use one byte - or bit - per reentrancy key
# requires either an extra SLOAD or caching the value of the
# location in memory at entrance
- storage_slot += 1
+ slot = allocator.allocate_slot(1, variable_name)
+
+ type_.set_reentrancy_key_position(StorageSlot(slot))
+
+ # TODO this could have better typing but leave it untyped until
+ # we nail down the format better
+ ret[variable_name] = {"type": "nonreentrant lock", "slot": slot}
for node in vyper_module.get_children(vy_ast.VariableDecl):
# skip non-storage variables
@@ -182,19 +196,20 @@ def set_storage_slots(vyper_module: vy_ast.Module) -> StorageLayout:
continue
varinfo = node.target._metadata["varinfo"]
- varinfo.set_position(StorageSlot(storage_slot))
-
type_ = varinfo.typ
- # this could have better typing but leave it untyped until
- # we understand the use case better
- ret[node.target.id] = {"type": str(type_), "slot": storage_slot}
-
# CMC 2021-07-23 note that HashMaps get assigned a slot here.
# I'm not sure if it's safe to avoid allocating that slot
# for HashMaps because downstream code might use the slot
# ID as a salt.
- storage_slot += math.ceil(type_.size_in_bytes / 32)
+ n_slots = type_.storage_size_in_words
+ slot = allocator.allocate_slot(n_slots, node.target.id)
+
+ varinfo.set_position(StorageSlot(slot))
+
+ # this could have better typing but leave it untyped until
+ # we understand the use case better
+ ret[node.target.id] = {"type": str(type_), "slot": slot}
return ret
@@ -216,7 +231,7 @@ def set_code_offsets(vyper_module: vy_ast.Module) -> Dict:
type_ = varinfo.typ
varinfo.set_position(CodeOffset(offset))
- len_ = math.ceil(type_.size_in_bytes / 32) * 32
+ len_ = ceil32(type_.size_in_bytes)
# this could have better typing but leave it untyped until
# we understand the use case better
diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py
index ae90f82528..c10df3b8fd 100644
--- a/vyper/semantics/analysis/local.py
+++ b/vyper/semantics/analysis/local.py
@@ -1,6 +1,7 @@
from typing import Optional
from vyper import ast as vy_ast
+from vyper.ast.metadata import NodeMetadata
from vyper.ast.validation import validate_call_args
from vyper.exceptions import (
ExceptionList,
@@ -18,7 +19,7 @@
VyperException,
)
from vyper.semantics.analysis.annotation import StatementAnnotationVisitor
-from vyper.semantics.analysis.base import DataLocation, VarInfo
+from vyper.semantics.analysis.base import VarInfo
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.analysis.utils import (
get_common_types,
@@ -27,6 +28,7 @@
get_possible_types_from_node,
validate_expected_type,
)
+from vyper.semantics.data_locations import DataLocation
# TODO consolidate some of these imports
from vyper.semantics.environment import CONSTANT_ENVIRONMENT_VARS, MUTABLE_ENVIRONMENT_VARS
@@ -171,8 +173,13 @@ def __init__(
self.func = fn_node._metadata["type"]
self.annotation_visitor = StatementAnnotationVisitor(fn_node, namespace)
self.expr_visitor = _LocalExpressionVisitor()
- for argname, argtype in self.func.arguments.items():
- namespace[argname] = VarInfo(argtype, location=DataLocation.CALLDATA, is_immutable=True)
+
+ # allow internal function params to be mutable
+ location, is_immutable = (
+ (DataLocation.MEMORY, False) if self.func.is_internal else (DataLocation.CALLDATA, True)
+ )
+ for arg in self.func.arguments:
+ namespace[arg.name] = VarInfo(arg.typ, location=location, is_immutable=is_immutable)
for node in fn_node.body:
self.visit(node)
@@ -231,7 +238,7 @@ def visit_AnnAssign(self, node):
"Memory variables must be declared with an initial value", node
)
- type_ = type_from_annotation(node.annotation)
+ type_ = type_from_annotation(node.annotation, DataLocation.MEMORY)
validate_expected_type(node.value, type_)
try:
@@ -339,18 +346,39 @@ def visit_For(self, node):
raise IteratorException(
"Cannot iterate over the result of a function call", node.iter
)
- validate_call_args(node.iter, (1, 2))
+ range_ = node.iter
+ validate_call_args(range_, (1, 2), kwargs=["bound"])
- args = node.iter.args
+ args = range_.args
+ kwargs = {s.arg: s.value for s in range_.keywords or []}
if len(args) == 1:
# range(CONSTANT)
- if not isinstance(args[0], vy_ast.Num):
- raise StateAccessViolation("Value must be a literal", node)
- if args[0].value <= 0:
- raise StructureException("For loop must have at least 1 iteration", args[0])
- validate_expected_type(args[0], IntegerT.any())
- type_list = get_possible_types_from_node(args[0])
+ n = args[0]
+ bound = kwargs.pop("bound", None)
+ validate_expected_type(n, IntegerT.any())
+
+ if bound is None:
+ if not isinstance(n, vy_ast.Num):
+ raise StateAccessViolation("Value must be a literal", n)
+ if n.value <= 0:
+ raise StructureException("For loop must have at least 1 iteration", args[0])
+ type_list = get_possible_types_from_node(n)
+
+ else:
+ if not isinstance(bound, vy_ast.Num):
+ raise StateAccessViolation("bound must be a literal", bound)
+ if bound.value <= 0:
+ raise StructureException("bound must be at least 1", args[0])
+ type_list = get_common_types(n, bound)
+
else:
+ if range_.keywords:
+ raise StructureException(
+ "Keyword arguments are not supported for `range(N, M)` and"
+ "`range(x, x + N)` expressions",
+ range_.keywords[0],
+ )
+
validate_expected_type(args[0], IntegerT.any())
type_list = get_common_types(*args)
if not isinstance(args[0], vy_ast.Constant):
@@ -452,14 +480,22 @@ def visit_For(self, node):
raise exc.with_annotation(node) from None
try:
- for n in node.body:
- self.visit(n)
- # type information is applied directly because the scope is
- # closed prior to the call to `StatementAnnotationVisitor`
- node.target._metadata["type"] = type_
- return
+ with NodeMetadata.enter_typechecker_speculation():
+ for n in node.body:
+ self.visit(n)
except (TypeMismatch, InvalidOperation) as exc:
for_loop_exceptions.append(exc)
+ else:
+ # type information is applied directly here because the
+ # scope is closed prior to the call to
+ # `StatementAnnotationVisitor`
+ node.target._metadata["type"] = type_
+
+ # success -- bail out instead of error handling.
+ return
+
+ # if we have gotten here, there was an error for
+ # every type tried for the iterator
if len(set(str(i) for i in for_loop_exceptions)) == 1:
# if every attempt at type checking raised the same exception
@@ -531,6 +567,10 @@ def visit_Log(self, node):
f = get_exact_type_from_node(node.value.func)
if not is_type_t(f, EventT):
raise StructureException("Value is not an event", node.value)
+ if self.func.mutability <= StateMutability.VIEW:
+ raise StructureException(
+ f"Cannot emit logs from {self.func.mutability.value.lower()} functions", node
+ )
f.fetch_call_return(node.value)
self.expr_visitor.visit(node.value)
@@ -586,3 +626,8 @@ def visit_Tuple(self, node: vy_ast.Tuple) -> None:
def visit_UnaryOp(self, node: vy_ast.UnaryOp) -> None:
self.visit(node.operand) # type: ignore[attr-defined]
+
+ def visit_IfExp(self, node: vy_ast.IfExp) -> None:
+ self.visit(node.test)
+ self.visit(node.body)
+ self.visit(node.orelse)
diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py
index db9b9d7c91..02ae82faac 100644
--- a/vyper/semantics/analysis/module.py
+++ b/vyper/semantics/analysis/module.py
@@ -4,6 +4,7 @@
import vyper.builtins.interfaces
from vyper import ast as vy_ast
+from vyper.evm.opcodes import version_check
from vyper.exceptions import (
CallViolation,
CompilerPanic,
@@ -18,14 +19,11 @@
VariableDeclarationException,
VyperException,
)
-from vyper.semantics.analysis.base import DataLocation, VarInfo
+from vyper.semantics.analysis.base import VarInfo
from vyper.semantics.analysis.common import VyperNodeVisitorBase
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
-from vyper.semantics.analysis.utils import (
- check_constant,
- validate_expected_type,
- validate_unique_method_ids,
-)
+from vyper.semantics.analysis.utils import check_constant, validate_expected_type
+from vyper.semantics.data_locations import DataLocation
from vyper.semantics.namespace import Namespace, get_namespace
from vyper.semantics.types import EnumT, EventT, InterfaceT, StructT
from vyper.semantics.types.function import ContractFunctionT
@@ -87,7 +85,8 @@ def __init__(
if count == len(module_nodes):
err_list.raise_if_not_empty()
- # generate an `InterfacePrimitive` from the top-level node - used for building the ABI
+ # generate an `InterfaceT` from the top-level node - used for building the ABI
+ # note: also validates unique method ids
interface = InterfaceT.from_ast(module_node)
module_node._metadata["type"] = interface
self.interface = interface # this is useful downstream
@@ -96,15 +95,11 @@ def __init__(
_ns = Namespace()
# note that we don't just copy the namespace because
# there are constructor issues.
- _ns.update({k: namespace[k] for k in namespace._scopes[-1]})
+ _ns.update({k: namespace[k] for k in namespace._scopes[-1]}) # type: ignore
module_node._metadata["namespace"] = _ns
# check for collisions between 4byte function selectors
- # internal functions are intentionally included in this check, to prevent breaking
- # changes in in case of a future change to their calling convention
self_members = namespace["self"].typ.members
- functions = [i for i in self_members.values() if isinstance(i, ContractFunctionT)]
- validate_unique_method_ids(functions)
# get list of internal function calls made by each function
function_defs = self.ast.get_children(vy_ast.FunctionDef)
@@ -116,11 +111,6 @@ def __init__(
# anything that is not a function call will get semantically checked later
calls_to_self = calls_to_self.intersection(function_names)
self_members[node.name].internal_calls = calls_to_self
- if node.name in self_members[node.name].internal_calls:
- self_node = node.get_descendants(
- vy_ast.Attribute, {"value.id": "self", "attr": node.name}
- )[0]
- raise CallViolation(f"Function '{node.name}' calls into itself", self_node)
for fn_name in sorted(function_names):
if fn_name not in self_members:
@@ -193,10 +183,17 @@ def visit_VariableDecl(self, node):
if node.is_immutable
else DataLocation.UNSET
if node.is_constant
+ # XXX: needed if we want separate transient allocator
+ # else DataLocation.TRANSIENT
+ # if node.is_transient
else DataLocation.STORAGE
)
- type_ = type_from_annotation(node.annotation)
+ type_ = type_from_annotation(node.annotation, data_loc)
+
+ if node.is_transient and not version_check(begin="cancun"):
+ raise StructureException("`transient` is not available pre-cancun", node.annotation)
+
var_info = VarInfo(
type_,
decl_node=node,
@@ -204,6 +201,7 @@ def visit_VariableDecl(self, node):
is_constant=node.is_constant,
is_public=node.is_public,
is_immutable=node.is_immutable,
+ is_transient=node.is_transient,
)
node.target._metadata["varinfo"] = var_info # TODO maybe put this in the global namespace
node._metadata["type"] = type_
@@ -225,6 +223,17 @@ def _finalize():
except VyperException as exc:
raise exc.with_annotation(node) from None
+ def _validate_self_namespace():
+ # block globals if storage variable already exists
+ try:
+ if name in self.namespace["self"].typ.members:
+ raise NamespaceCollision(
+ f"Value '{name}' has already been declared", node
+ ) from None
+ self.namespace[name] = var_info
+ except VyperException as exc:
+ raise exc.with_annotation(node) from None
+
if node.is_constant:
if not node.value:
raise VariableDeclarationException("Constant must be declared with a value", node)
@@ -232,10 +241,7 @@ def _finalize():
raise StateAccessViolation("Value must be a literal", node.value)
validate_expected_type(node.value, type_)
- try:
- self.namespace[name] = var_info
- except VyperException as exc:
- raise exc.with_annotation(node) from None
+ _validate_self_namespace()
return _finalize()
@@ -246,16 +252,7 @@ def _finalize():
)
if node.is_immutable:
- try:
- # block immutable if storage variable already exists
- if name in self.namespace["self"].typ.members:
- raise NamespaceCollision(
- f"Value '{name}' has already been declared", node
- ) from None
- self.namespace[name] = var_info
- except VyperException as exc:
- raise exc.with_annotation(node) from None
-
+ _validate_self_namespace()
return _finalize()
try:
diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py
index abe06f0580..4f911764e0 100644
--- a/vyper/semantics/analysis/utils.py
+++ b/vyper/semantics/analysis/utils.py
@@ -3,6 +3,7 @@
from vyper import ast as vy_ast
from vyper.exceptions import (
+ CompilerPanic,
InvalidLiteral,
InvalidOperation,
InvalidReference,
@@ -23,7 +24,7 @@
from vyper.semantics.types.bytestrings import BytesT, StringT
from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT
from vyper.semantics.types.subscriptable import DArrayT, SArrayT, TupleT
-from vyper.utils import checksum_encode
+from vyper.utils import checksum_encode, int_to_fourbytes
def _validate_op(node, types_list, validation_fn_name):
@@ -123,6 +124,7 @@ def get_exact_type_from_node(self, node, include_type_exprs=False):
if len(types_list) > 1:
raise StructureException("Ambiguous type", node)
+
return types_list[0]
def get_possible_types_from_node(self, node, include_type_exprs=False):
@@ -145,20 +147,26 @@ def get_possible_types_from_node(self, node, include_type_exprs=False):
if "type" in node._metadata:
return [node._metadata["type"]]
- fn = self._find_fn(node)
- ret = fn(node)
+ # this method is a perf hotspot, so we cache the result and
+ # try to return it if found.
+ k = f"possible_types_from_node_{include_type_exprs}"
+ if k not in node._metadata:
+ fn = self._find_fn(node)
+ ret = fn(node)
- if not include_type_exprs:
- invalid = next((i for i in ret if isinstance(i, TYPE_T)), None)
- if invalid is not None:
- raise InvalidReference(f"not a variable or literal: '{invalid.typedef}'", node)
+ if not include_type_exprs:
+ invalid = next((i for i in ret if isinstance(i, TYPE_T)), None)
+ if invalid is not None:
+ raise InvalidReference(f"not a variable or literal: '{invalid.typedef}'", node)
- if all(isinstance(i, IntegerT) for i in ret):
- # for numeric types, sort according by number of bits descending
- # this ensures literals are cast with the largest possible type
- ret.sort(key=lambda k: (k.bits, not k.is_signed), reverse=True)
+ if all(isinstance(i, IntegerT) for i in ret):
+ # for numeric types, sort according by number of bits descending
+ # this ensures literals are cast with the largest possible type
+ ret.sort(key=lambda k: (k.bits, not k.is_signed), reverse=True)
- return ret
+ node._metadata[k] = ret
+
+ return node._metadata[k].copy()
def _find_fn(self, node):
# look for a type-check method for each class in the given class mro
@@ -172,24 +180,30 @@ def _find_fn(self, node):
raise StructureException("Cannot determine type of this object", node)
def types_from_Attribute(self, node):
+ is_self_reference = node.get("value.id") == "self"
# variable attribute, e.g. `foo.bar`
t = self.get_exact_type_from_node(node.value, include_type_exprs=True)
name = node.attr
+
+ def _raise_invalid_reference(name, node):
+ raise InvalidReference(
+ f"'{name}' is not a storage variable, it should not be prepended with self", node
+ )
+
try:
s = t.get_member(name, node)
if isinstance(s, VyperType):
# ex. foo.bar(). bar() is a ContractFunctionT
return [s]
+ if is_self_reference and (s.is_constant or s.is_immutable):
+ _raise_invalid_reference(name, node)
# general case. s is a VarInfo, e.g. self.foo
return [s.typ]
except UnknownAttribute:
- if node.get("value.id") != "self":
+ if not is_self_reference:
raise
if name in self.namespace:
- raise InvalidReference(
- f"'{name}' is not a storage variable, it should not be prepended with self",
- node,
- ) from None
+ _raise_invalid_reference(name, node)
suggestions_str = get_levenshtein_error_suggestions(name, t.members, 0.4)
raise UndeclaredDefinition(
@@ -198,15 +212,20 @@ def types_from_Attribute(self, node):
def types_from_BinOp(self, node):
# binary operation: `x + y`
- types_list = get_common_types(node.left, node.right)
+ if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)):
+ # ad-hoc handling for LShift and RShift, since operands
+ # can be different types
+ types_list = get_possible_types_from_node(node.left)
+ # check rhs is unsigned integer
+ validate_expected_type(node.right, IntegerT.unsigneds())
+ else:
+ types_list = get_common_types(node.left, node.right)
if (
isinstance(node.op, (vy_ast.Div, vy_ast.Mod))
and isinstance(node.right, vy_ast.Num)
and not node.right.value
):
- # CMC 2022-07-20 this seems like unreachable code -
- # should be handled in evaluate()
raise ZeroDivisionException(f"{node.op.description} by zero", node)
return _validate_op(node, types_list, "validate_numeric_op")
@@ -261,7 +280,7 @@ def types_from_Call(self, node):
def types_from_Constant(self, node):
# literal value (integer, string, etc)
types_list = []
- for t in types.get_primitive_types().values():
+ for t in types.PRIMITIVE_TYPES.values():
try:
# clarity and perf note: will be better to construct a
# map from node types to valid vyper types
@@ -294,11 +313,20 @@ def types_from_List(self, node):
# literal array
if _is_empty_list(node):
# empty list literal `[]`
+ ret = []
# subtype can be anything
- types_list = types.get_types()
- # 1 is minimum possible length for dynarray, assignable to anything
- ret = [DArrayT(t, 1) for t in types_list.values()]
+ for t in types.PRIMITIVE_TYPES.values():
+ # 1 is minimum possible length for dynarray,
+ # can be assigned to anything
+ if isinstance(t, VyperType):
+ ret.append(DArrayT(t, 1))
+ elif isinstance(t, type) and issubclass(t, VyperType):
+ # for typeclasses like bytestrings, use a generic type acceptor
+ ret.append(DArrayT(t.any(), 1))
+ else:
+ raise CompilerPanic("busted type {t}", node)
return ret
+
types_list = get_common_types(*node.elements)
if len(types_list) > 0:
@@ -312,7 +340,11 @@ def types_from_List(self, node):
def types_from_Name(self, node):
# variable name, e.g. `foo`
name = node.id
- if name not in self.namespace and name in self.namespace["self"].typ.members:
+ if (
+ name not in self.namespace
+ and "self" in self.namespace
+ and name in self.namespace["self"].typ.members
+ ):
raise InvalidReference(
f"'{name}' is a storage variable, access it as self.{name}", node
)
@@ -351,6 +383,17 @@ def types_from_UnaryOp(self, node):
types_list = self.get_possible_types_from_node(node.operand)
return _validate_op(node, types_list, "validate_numeric_op")
+ def types_from_IfExp(self, node):
+ validate_expected_type(node.test, BoolT())
+ types_list = get_common_types(node.body, node.orelse)
+
+ if not types_list:
+ a = get_possible_types_from_node(node.body)[0]
+ b = get_possible_types_from_node(node.orelse)[0]
+ raise TypeMismatch(f"Dislike types: {a} and {b}", node)
+
+ return types_list
+
def _is_empty_list(node):
# Checks if a node is a `List` node with an empty list for `elements`,
@@ -414,7 +457,6 @@ def get_exact_type_from_node(node):
BaseType
Type object.
"""
-
return _ExprAnalyser().get_exact_type_from_node(node, include_type_exprs=True)
@@ -444,6 +486,7 @@ def get_common_types(*nodes: vy_ast.VyperNode, filter_fn: Callable = None) -> Li
new_types = _ExprAnalyser().get_possible_types_from_node(item)
common = [i for i in common_types if _is_type_in_list(i, new_types)]
+
rejected = [i for i in common_types if i not in common]
common += [i for i in new_types if _is_type_in_list(i, rejected)]
@@ -496,8 +539,8 @@ def validate_expected_type(node, expected_type):
if not isinstance(expected_type, tuple):
expected_type = (expected_type,)
- if isinstance(node, (vy_ast.List, vy_ast.Tuple)):
- # special case - for literal arrays or tuples we individually validate each item
+ if isinstance(node, vy_ast.List):
+ # special case - for literal arrays we individually validate each item
for expected in expected_type:
if not isinstance(expected, (DArrayT, SArrayT)):
continue
@@ -556,8 +599,13 @@ def validate_unique_method_ids(functions: List) -> None:
seen = set()
for method_id in method_ids:
if method_id in seen:
- collision_str = ", ".join(i.name for i in functions if method_id in i.method_ids)
- raise StructureException(f"Methods have conflicting IDs: {collision_str}")
+ collision_str = ", ".join(
+ x for i in functions for x in i.method_ids.keys() if i.method_ids[x] == method_id
+ )
+ collision_hex = int_to_fourbytes(method_id).hex()
+ raise StructureException(
+ f"Methods produce colliding method ID `0x{collision_hex}`: {collision_str}"
+ )
seen.add(method_id)
diff --git a/vyper/semantics/data_locations.py b/vyper/semantics/data_locations.py
new file mode 100644
index 0000000000..2f259b1766
--- /dev/null
+++ b/vyper/semantics/data_locations.py
@@ -0,0 +1,11 @@
+import enum
+
+
+class DataLocation(enum.Enum):
+ UNSET = 0
+ MEMORY = 1
+ STORAGE = 2
+ CALLDATA = 3
+ CODE = 4
+ # XXX: needed for separate transient storage allocator
+ # TRANSIENT = 5
diff --git a/vyper/semantics/namespace.py b/vyper/semantics/namespace.py
index db9affed24..613ac0c03b 100644
--- a/vyper/semantics/namespace.py
+++ b/vyper/semantics/namespace.py
@@ -1,12 +1,7 @@
import contextlib
-import re
-
-from vyper.exceptions import (
- CompilerPanic,
- NamespaceCollision,
- StructureException,
- UndeclaredDefinition,
-)
+
+from vyper.ast.identifiers import validate_identifier
+from vyper.exceptions import CompilerPanic, NamespaceCollision, UndeclaredDefinition
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
@@ -20,17 +15,21 @@ class Namespace(dict):
List of sets containing the key names for each scope
"""
+ def __new__(cls, *args, **kwargs):
+ self = super().__new__(cls, *args, **kwargs)
+ self._scopes = []
+ return self
+
def __init__(self):
super().__init__()
- self._scopes = []
# NOTE cyclic imports!
# TODO: break this cycle by providing an `init_vyper_namespace` in 3rd module
from vyper.builtins.functions import get_builtin_functions
from vyper.semantics import environment
from vyper.semantics.analysis.base import VarInfo
- from vyper.semantics.types import get_types
+ from vyper.semantics.types import PRIMITIVE_TYPES
- self.update(get_types())
+ self.update(PRIMITIVE_TYPES)
self.update(environment.get_constant_vars())
self.update({k: VarInfo(b) for (k, b) in get_builtin_functions().items()})
@@ -117,64 +116,3 @@ def override_global_namespace(ns):
finally:
# unclobber
_namespace = tmp
-
-
-def validate_identifier(attr):
- if not re.match("^[_a-zA-Z][a-zA-Z0-9_]*$", attr):
- raise StructureException(f"'{attr}' contains invalid character(s)")
- if attr.lower() in RESERVED_KEYWORDS:
- raise StructureException(f"'{attr}' is a reserved keyword")
-
-
-# Cannot be used for variable or member naming
-RESERVED_KEYWORDS = {
- # decorators
- "public",
- "external",
- "nonpayable",
- "constant",
- "immutable",
- "internal",
- "payable",
- "nonreentrant",
- # "class" keywords
- "interface",
- "struct",
- "event",
- "enum",
- # EVM operations
- "unreachable",
- # special functions (no name mangling)
- "init",
- "_init_",
- "___init___",
- "____init____",
- "default",
- "_default_",
- "___default___",
- "____default____",
- # boolean literals
- "true",
- "false",
- # more control flow and special operations
- "this",
- "range",
- # None sentinal value
- "none",
- # more special operations
- "indexed",
- # denominations
- "ether",
- "wei",
- "finney",
- "szabo",
- "shannon",
- "lovelace",
- "ada",
- "babbage",
- "gwei",
- "kwei",
- "mwei",
- "twei",
- "pwei",
-}
diff --git a/vyper/semantics/types/__init__.py b/vyper/semantics/types/__init__.py
index 246dcfdf34..ad470718c8 100644
--- a/vyper/semantics/types/__init__.py
+++ b/vyper/semantics/types/__init__.py
@@ -7,7 +7,7 @@
from .user import EnumT, EventT, InterfaceT, StructT
-def get_primitive_types():
+def _get_primitive_types():
res = [BoolT(), DecimalT()]
res.extend(IntegerT.all())
@@ -46,8 +46,5 @@ def _get_sequence_types():
return ret
-def get_types():
- result = {}
- result.update(get_primitive_types())
-
- return result
+# note: it might be good to make this a frozen dict of some sort
+PRIMITIVE_TYPES = _get_primitive_types()
diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py
index 0ac4f7b06d..c5af5c2a39 100644
--- a/vyper/semantics/types/base.py
+++ b/vyper/semantics/types/base.py
@@ -3,6 +3,7 @@
from vyper import ast as vy_ast
from vyper.abi_types import ABIType
+from vyper.ast.identifiers import validate_identifier
from vyper.exceptions import (
CompilerPanic,
InvalidLiteral,
@@ -12,7 +13,6 @@
UnknownAttribute,
)
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
-from vyper.semantics.namespace import validate_identifier
# Some fake type with an overridden `compare_type` which accepts any RHS
@@ -40,6 +40,8 @@ class VyperType:
If `True`, this type can be used as the base member for an array.
_valid_literal : Tuple
A tuple of Vyper ast classes that may be assigned this type.
+ _invalid_locations : Tuple
+ A tuple of invalid `DataLocation`s for this type
_is_prim_word: bool, optional
This is a word type like uint256, int8, bytesM or address
"""
@@ -47,6 +49,7 @@ class VyperType:
_id: str
_type_members: Optional[Dict] = None
_valid_literal: Tuple = ()
+ _invalid_locations: Tuple = ()
_is_prim_word: bool = False
_equality_attrs: Optional[Tuple] = None
_is_array_type: bool = False
@@ -311,7 +314,9 @@ def __init__(self, typ, default, require_literal=False):
self.require_literal = require_literal
-# A type type. Only used internally for builtins
+# A type type. Used internally for types which can live in expression
+# position, ex. constructors (events, interfaces and structs), and also
+# certain builtins which take types as parameters
class TYPE_T:
def __init__(self, typedef):
self.typedef = typedef
diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py
index 21f5087c66..77b9efb13d 100644
--- a/vyper/semantics/types/function.py
+++ b/vyper/semantics/types/function.py
@@ -1,9 +1,11 @@
import re
import warnings
-from collections import OrderedDict
-from typing import Any, Dict, List, Optional, Set, Tuple
+from dataclasses import dataclass
+from functools import cached_property
+from typing import Any, Dict, List, Optional, Tuple
from vyper import ast as vy_ast
+from vyper.ast.identifiers import validate_identifier
from vyper.ast.validation import validate_call_args
from vyper.exceptions import (
ArgumentException,
@@ -11,24 +13,35 @@
CompilerPanic,
FunctionDeclarationException,
InvalidType,
- NamespaceCollision,
StateAccessViolation,
StructureException,
)
-from vyper.semantics.analysis.base import (
- DataLocation,
- FunctionVisibility,
- StateMutability,
- StorageSlot,
-)
+from vyper.semantics.analysis.base import FunctionVisibility, StateMutability, StorageSlot
from vyper.semantics.analysis.utils import check_kwargable, validate_expected_type
-from vyper.semantics.namespace import get_namespace
+from vyper.semantics.data_locations import DataLocation
from vyper.semantics.types.base import KwargSettings, VyperType
from vyper.semantics.types.primitives import BoolT
from vyper.semantics.types.shortcuts import UINT256_T
from vyper.semantics.types.subscriptable import TupleT
from vyper.semantics.types.utils import type_from_abi, type_from_annotation
-from vyper.utils import keccak256
+from vyper.utils import OrderedSet, keccak256
+
+
+@dataclass
+class _FunctionArg:
+ name: str
+ typ: VyperType
+
+
+@dataclass
+class PositionalArg(_FunctionArg):
+ ast_source: Optional[vy_ast.VyperNode] = None
+
+
+@dataclass
+class KeywordArg(_FunctionArg):
+ default_value: vy_ast.VyperNode
+ ast_source: Optional[vy_ast.VyperNode] = None
class ContractFunctionT(VyperType):
@@ -37,26 +50,23 @@ class ContractFunctionT(VyperType):
Functions compare false against all types and so cannot be assigned without
being called. Calls are validated by `fetch_call_return`, check the call
- arguments against `arguments`, and return `return_type`.
+ arguments against `positional_args` and `keyword_arg`, and return `return_type`.
Attributes
----------
name : str
The name of the function.
- arguments : OrderedDict
- Function input arguments as {'name': VyperType}
- min_arg_count : int
- The minimum number of required input arguments.
- max_arg_count : int
- The maximum number of required input arguments. When a function has no
- default arguments, this value is the same as `min_arg_count`.
- kwarg_keys : List
- List of optional input argument keys.
+ positional_args: list[PositionalArg]
+ Positional args for this function
+ keyword_args: list[KeywordArg]
+ Keyword args for this function
+ return_type: Optional[VyperType]
+ Type of return value
function_visibility : FunctionVisibility
enum indicating the external visibility of a function.
state_mutability : StateMutability
enum indicating the authority a function has to mutate it's own state.
- nonreentrant : str
+ nonreentrant : Optional[str]
Re-entrancy lock name.
"""
@@ -65,10 +75,8 @@ class ContractFunctionT(VyperType):
def __init__(
self,
name: str,
- arguments: OrderedDict,
- # TODO rename to something like positional_args, keyword_args
- min_arg_count: int,
- max_arg_count: int,
+ positional_args: list[PositionalArg],
+ keyword_args: list[KeywordArg],
return_type: Optional[VyperType],
function_visibility: FunctionVisibility,
state_mutability: StateMutability,
@@ -77,32 +85,38 @@ def __init__(
super().__init__()
self.name = name
- self.arguments = arguments
- self.min_arg_count = min_arg_count
- self.max_arg_count = max_arg_count
+ self.positional_args = positional_args
+ self.keyword_args = keyword_args
self.return_type = return_type
- self.kwarg_keys = []
- if min_arg_count < max_arg_count:
- self.kwarg_keys = list(self.arguments)[min_arg_count:]
self.visibility = function_visibility
self.mutability = state_mutability
self.nonreentrant = nonreentrant
# a list of internal functions this function calls
- self.called_functions: Set["ContractFunctionT"] = set()
+ self.called_functions = OrderedSet()
+ # to be populated during codegen
+ self._ir_info: Any = None
+
+ @cached_property
+ def call_site_kwargs(self):
# special kwargs that are allowed in call site
- self.call_site_kwargs = {
+ return {
"gas": KwargSettings(UINT256_T, "gas"),
"value": KwargSettings(UINT256_T, 0),
"skip_contract_check": KwargSettings(BoolT(), False, require_literal=True),
- "default_return_value": KwargSettings(return_type, None),
+ "default_return_value": KwargSettings(self.return_type, None),
}
def __repr__(self):
- arg_types = ",".join(repr(a) for a in self.arguments.values())
+ arg_types = ",".join(repr(a) for a in self.argument_types)
return f"contract function {self.name}({arg_types})"
+ def __str__(self):
+ ret_sig = "" if not self.return_type else f" -> {self.return_type}"
+ args_sig = ",".join([str(t) for t in self.argument_types])
+ return f"def {self.name} {args_sig}{ret_sig}:"
+
# override parent implementation. function type equality does not
# make too much sense.
def __eq__(self, other):
@@ -125,10 +139,9 @@ def from_abi(cls, abi: Dict) -> "ContractFunctionT":
-------
ContractFunctionT object.
"""
-
- arguments = OrderedDict()
+ positional_args = []
for item in abi["inputs"]:
- arguments[item["name"]] = type_from_abi(item)
+ positional_args.append(PositionalArg(item["name"], type_from_abi(item)))
return_type = None
if len(abi["outputs"]) == 1:
return_type = type_from_abi(abi["outputs"][0])
@@ -136,9 +149,8 @@ def from_abi(cls, abi: Dict) -> "ContractFunctionT":
return_type = TupleT(tuple(type_from_abi(i) for i in abi["outputs"]))
return cls(
abi["name"],
- arguments,
- len(arguments),
- len(arguments),
+ positional_args,
+ [],
return_type,
function_visibility=FunctionVisibility.EXTERNAL,
state_mutability=StateMutability.from_abi(abi),
@@ -209,7 +221,10 @@ def from_FunctionDef(
msg = "Nonreentrant decorator disallowed on `__init__`"
raise FunctionDeclarationException(msg, decorator)
- kwargs["nonreentrant"] = decorator.args[0].value
+ nonreentrant_key = decorator.args[0].value
+ validate_identifier(nonreentrant_key, decorator.args[0])
+
+ kwargs["nonreentrant"] = nonreentrant_key
elif isinstance(decorator, vy_ast.Name):
if FunctionVisibility.is_valid_value(decorator.id):
@@ -278,35 +293,39 @@ def from_FunctionDef(
"Constructor may not use default arguments", node.args.defaults[0]
)
- arguments = OrderedDict()
- max_arg_count = len(node.args.args)
- min_arg_count = max_arg_count - len(node.args.defaults)
- defaults = [None] * min_arg_count + node.args.defaults
+ argnames = set() # for checking uniqueness
+ n_total_args = len(node.args.args)
+ n_positional_args = n_total_args - len(node.args.defaults)
- namespace = get_namespace()
- for arg, value in zip(node.args.args, defaults):
- if arg.arg in ("gas", "value", "skip_contract_check", "default_return_value"):
+ positional_args: list[PositionalArg] = []
+ keyword_args: list[KeywordArg] = []
+
+ for i, arg in enumerate(node.args.args):
+ argname = arg.arg
+ if argname in ("gas", "value", "skip_contract_check", "default_return_value"):
raise ArgumentException(
- f"Cannot use '{arg.arg}' as a variable name in a function input", arg
+ f"Cannot use '{argname}' as a variable name in a function input", arg
)
- if arg.arg in arguments:
- raise ArgumentException(f"Function contains multiple inputs named {arg.arg}", arg)
- if arg.arg in namespace:
- raise NamespaceCollision(arg.arg, arg)
+ if argname in argnames:
+ raise ArgumentException(f"Function contains multiple inputs named {argname}", arg)
if arg.annotation is None:
- raise ArgumentException(f"Function argument '{arg.arg}' is missing a type", arg)
+ raise ArgumentException(f"Function argument '{argname}' is missing a type", arg)
- type_ = type_from_annotation(arg.annotation)
+ type_ = type_from_annotation(arg.annotation, DataLocation.CALLDATA)
- if value is not None:
+ if i < n_positional_args:
+ positional_args.append(PositionalArg(argname, type_, ast_source=arg))
+ else:
+ value = node.args.defaults[i - n_positional_args]
if not check_kwargable(value):
raise StateAccessViolation(
"Value must be literal or environment variable", value
)
validate_expected_type(value, type_)
+ keyword_args.append(KeywordArg(argname, type_, value, ast_source=arg))
- arguments[arg.arg] = type_
+ argnames.add(argname)
# return types
if node.returns is None:
@@ -315,17 +334,13 @@ def from_FunctionDef(
raise FunctionDeclarationException(
"Constructor may not have a return type", node.returns
)
- elif isinstance(node.returns, (vy_ast.Name, vy_ast.Subscript)):
- return_type = type_from_annotation(node.returns)
- elif isinstance(node.returns, vy_ast.Tuple):
- tuple_types: Tuple = ()
- for n in node.returns.elements:
- tuple_types += (type_from_annotation(n),)
- return_type = TupleT(tuple_types)
+ elif isinstance(node.returns, (vy_ast.Name, vy_ast.Subscript, vy_ast.Tuple)):
+ # note: consider, for cleanliness, adding DataLocation.RETURN_VALUE
+ return_type = type_from_annotation(node.returns, DataLocation.MEMORY)
else:
raise InvalidType("Function return value must be a type name or tuple", node.returns)
- return cls(node.name, arguments, min_arg_count, max_arg_count, return_type, **kwargs)
+ return cls(node.name, positional_args, keyword_args, return_type, **kwargs)
def set_reentrancy_key_position(self, position: StorageSlot) -> None:
if hasattr(self, "reentrancy_key_position"):
@@ -355,17 +370,16 @@ def getter_from_VariableDecl(cls, node: vy_ast.VariableDecl) -> "ContractFunctio
"""
if not node.is_public:
raise CompilerPanic("getter generated for non-public function")
- type_ = type_from_annotation(node.annotation)
+ type_ = type_from_annotation(node.annotation, DataLocation.STORAGE)
arguments, return_type = type_.getter_signature
- args_dict: OrderedDict = OrderedDict()
- for item in arguments:
- args_dict[f"arg{len(args_dict)}"] = item
+ args = []
+ for i, item in enumerate(arguments):
+ args.append(PositionalArg(f"arg{i}", item))
return cls(
node.target.id,
- args_dict,
- len(arguments),
- len(arguments),
+ args,
+ [],
return_type,
function_visibility=FunctionVisibility.EXTERNAL,
state_mutability=StateMutability.VIEW,
@@ -375,11 +389,12 @@ def getter_from_VariableDecl(cls, node: vy_ast.VariableDecl) -> "ContractFunctio
# convenience property for compare_signature, as it would
# appear in a public interface
def _iface_sig(self) -> Tuple[Tuple, Optional[VyperType]]:
- return tuple(self.arguments.values()), self.return_type
+ return tuple(self.argument_types), self.return_type
- def compare_signature(self, other: "ContractFunctionT") -> bool:
+ def implements(self, other: "ContractFunctionT") -> bool:
"""
- Compare the signature of this function with another function.
+ Checks if this function implements the signature of another
+ function.
Used when determining if an interface has been implemented. This method
should not be directly implemented by any inherited classes.
@@ -396,11 +411,40 @@ def compare_signature(self, other: "ContractFunctionT") -> bool:
for atyp, btyp in zip(arguments, other_arguments):
if not atyp.compare_type(btyp):
return False
+
if return_type and not return_type.compare_type(other_return_type): # type: ignore
return False
+ if self.mutability > other.mutability:
+ return False
+
return True
+ @cached_property
+ def default_values(self) -> dict[str, vy_ast.VyperNode]:
+ return {arg.name: arg.default_value for arg in self.keyword_args}
+
+ # for backwards compatibility
+ @cached_property
+ def arguments(self) -> list[_FunctionArg]:
+ return self.positional_args + self.keyword_args # type: ignore
+
+ @cached_property
+ def argument_types(self) -> list[VyperType]:
+ return [arg.typ for arg in self.arguments]
+
+ @property
+ def n_positional_args(self) -> int:
+ return len(self.positional_args)
+
+ @property
+ def n_keyword_args(self) -> int:
+ return len(self.keyword_args)
+
+ @cached_property
+ def n_total_args(self) -> int:
+ return self.n_positional_args + self.n_keyword_args
+
@property
def is_external(self) -> bool:
return self.visibility == FunctionVisibility.EXTERNAL
@@ -409,6 +453,22 @@ def is_external(self) -> bool:
def is_internal(self) -> bool:
return self.visibility == FunctionVisibility.INTERNAL
+ @property
+ def is_mutable(self) -> bool:
+ return self.mutability > StateMutability.VIEW
+
+ @property
+ def is_payable(self) -> bool:
+ return self.mutability == StateMutability.PAYABLE
+
+ @property
+ def is_constructor(self) -> bool:
+ return self.name == "__init__"
+
+ @property
+ def is_fallback(self) -> bool:
+ return self.name == "__default__"
+
@property
def method_ids(self) -> Dict[str, int]:
"""
@@ -418,58 +478,32 @@ def method_ids(self) -> Dict[str, int]:
* For functions with default arguments, there is one key for each
function signature.
"""
- arg_types = [i.canonical_abi_type for i in self.arguments.values()]
+ arg_types = [i.canonical_abi_type for i in self.argument_types]
- if not self.has_default_args:
+ if self.n_keyword_args == 0:
return _generate_method_id(self.name, arg_types)
method_ids = {}
- for i in range(self.min_arg_count, self.max_arg_count + 1):
+ for i in range(self.n_positional_args, self.n_total_args + 1):
method_ids.update(_generate_method_id(self.name, arg_types[:i]))
return method_ids
- # for caller-fills-args calling convention
- def get_args_buffer_offset(self) -> int:
- """
- Get the location of the args buffer in the function frame (caller sets)
- """
- return 0
-
- # TODO is this needed?
- def get_args_buffer_len(self) -> int:
- """
- Get the length of the argument buffer in the function frame
- """
- return sum(arg_t.size_in_bytes() for arg_t in self.arguments.values())
-
- @property
- def is_constructor(self) -> bool:
- return self.name == "__init__"
-
- @property
- def is_fallback(self) -> bool:
- return self.name == "__default__"
-
- @property
- def has_default_args(self) -> bool:
- return self.min_arg_count < self.max_arg_count
-
def fetch_call_return(self, node: vy_ast.Call) -> Optional[VyperType]:
if node.get("func.value.id") == "self" and self.visibility == FunctionVisibility.EXTERNAL:
raise CallViolation("Cannot call external functions via 'self'", node)
+ kwarg_keys = []
# for external calls, include gas and value as optional kwargs
- kwarg_keys = self.kwarg_keys.copy()
- if node.get("func.value.id") != "self":
+ if not self.is_internal:
kwarg_keys += list(self.call_site_kwargs.keys())
- validate_call_args(node, (self.min_arg_count, self.max_arg_count), kwarg_keys)
+ validate_call_args(node, (self.n_positional_args, self.n_total_args), kwarg_keys)
if self.mutability < StateMutability.PAYABLE:
kwarg_node = next((k for k in node.keywords if k.arg == "value"), None)
if kwarg_node is not None:
raise CallViolation("Cannot send ether to nonpayable function", kwarg_node)
- for arg, expected in zip(node.args, self.arguments.values()):
+ for arg, expected in zip(node.args, self.argument_types):
validate_expected_type(arg, expected)
# TODO this should be moved to validate_call_args
@@ -524,7 +558,7 @@ def to_toplevel_abi_dict(self):
abi_dict["type"] = "function"
abi_dict["name"] = self.name
- abi_dict["inputs"] = [v.to_abi_arg(name=k) for k, v in self.arguments.items()]
+ abi_dict["inputs"] = [arg.typ.to_abi_arg(name=arg.name) for arg in self.arguments]
typ = self.return_type
if typ is None:
@@ -534,16 +568,21 @@ def to_toplevel_abi_dict(self):
else:
abi_dict["outputs"] = [typ.to_abi_arg()]
- if self.has_default_args:
+ if self.n_keyword_args > 0:
# for functions with default args, return a dict for each possible arg count
result = []
- for i in range(self.min_arg_count, self.max_arg_count + 1):
+ for i in range(self.n_positional_args, self.n_total_args + 1):
result.append(abi_dict.copy())
result[-1]["inputs"] = result[-1]["inputs"][:i]
return result
else:
return [abi_dict]
+ # calculate the abi signature for a given set of kwargs
+ def abi_signature_for_kwargs(self, kwargs: list[KeywordArg]) -> str:
+ args = self.positional_args + kwargs # type: ignore
+ return self.name + "(" + ",".join([arg.typ.abi_type.selector_name() for arg in args]) + ")"
+
class MemberFunctionT(VyperType):
"""
diff --git a/vyper/semantics/types/primitives.py b/vyper/semantics/types/primitives.py
index 5b1ab5ab8e..07d1a21a94 100644
--- a/vyper/semantics/types/primitives.py
+++ b/vyper/semantics/types/primitives.py
@@ -141,14 +141,23 @@ def validate_numeric_op(
if isinstance(node.op, self._invalid_ops):
self._raise_invalid_op(node)
- if isinstance(node.op, vy_ast.Pow):
+ def _get_lr():
if isinstance(node, vy_ast.BinOp):
- left, right = node.left, node.right
+ return node.left, node.right
elif isinstance(node, vy_ast.AugAssign):
- left, right = node.target, node.value
+ return node.target, node.value
else:
raise CompilerPanic(f"Unexpected node type for numeric op: {type(node).__name__}")
+ if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)):
+ if self._bits != 256:
+ raise InvalidOperation(
+ f"Cannot perform {node.op.description} on non-int256/uint256 type!", node
+ )
+
+ if isinstance(node.op, vy_ast.Pow):
+ left, right = _get_lr()
+
value_bits = self._bits - (1 if self._is_signed else 0)
# TODO double check: this code seems duplicated with constant eval
diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py
index a5ae075b73..6a2d3aae73 100644
--- a/vyper/semantics/types/subscriptable.py
+++ b/vyper/semantics/types/subscriptable.py
@@ -1,8 +1,10 @@
+import warnings
from typing import Any, Dict, Optional, Tuple, Union
from vyper import ast as vy_ast
from vyper.abi_types import ABI_DynamicArray, ABI_StaticArray, ABI_Tuple, ABIType
from vyper.exceptions import ArrayIndexException, InvalidType, StructureException
+from vyper.semantics.data_locations import DataLocation
from vyper.semantics.types.base import VyperType
from vyper.semantics.types.primitives import IntegerT
from vyper.semantics.types.shortcuts import UINT256_T
@@ -43,6 +45,14 @@ class HashMapT(_SubscriptableT):
_equality_attrs = ("key_type", "value_type")
+ # disallow everything but storage
+ _invalid_locations = (
+ DataLocation.UNSET,
+ DataLocation.CALLDATA,
+ DataLocation.CODE,
+ DataLocation.MEMORY,
+ )
+
def __repr__(self):
return f"HashMap[{self.key_type}, {self.value_type}]"
@@ -72,15 +82,13 @@ def from_annotation(cls, node: Union[vy_ast.Name, vy_ast.Call, vy_ast.Subscript]
),
node,
)
- # if location != DataLocation.STORAGE or is_immutable:
- # raise StructureException("HashMap can only be declared as a storage variable", node)
k_ast, v_ast = node.slice.value.elements
- key_type = type_from_annotation(k_ast)
+ key_type = type_from_annotation(k_ast, DataLocation.STORAGE)
if not key_type._as_hashmap_key:
raise InvalidType("can only use primitive types as HashMap key!", k_ast)
- value_type = type_from_annotation(v_ast)
+ value_type = type_from_annotation(v_ast, DataLocation.STORAGE)
return cls(key_type, value_type)
@@ -103,6 +111,9 @@ def __init__(self, value_type: VyperType, length: int):
if not 0 < length < 2**256:
raise InvalidType("Array length is invalid")
+ if length >= 2**64:
+ warnings.warn("Use of large arrays can be unsafe!")
+
super().__init__(UINT256_T, value_type)
self.length = length
@@ -287,12 +298,19 @@ class TupleT(VyperType):
"""
Tuple type definition.
- This class is used to represent multiple return values from
- functions.
+ This class is used to represent multiple return values from functions.
"""
_equality_attrs = ("members",)
+ # note: docs say that tuples are not instantiable but they
+ # are in fact instantiable and the codegen works. if we
+ # wanted to be stricter in the typechecker, we could
+ # add _invalid_locations = everything but UNSET and RETURN_VALUE.
+ # (we would need to add a DataLocation.RETURN_VALUE in order for
+ # tuples to be instantiable as return values but not in memory).
+ # _invalid_locations = ...
+
def __init__(self, member_types: Tuple[VyperType, ...]) -> None:
super().__init__()
self.member_types = member_types
diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py
index 536d482e75..ce82731c34 100644
--- a/vyper/semantics/types/user.py
+++ b/vyper/semantics/types/user.py
@@ -17,6 +17,7 @@
from vyper.semantics.analysis.base import VarInfo
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids
+from vyper.semantics.data_locations import DataLocation
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.base import VyperType
from vyper.semantics.types.function import ContractFunctionT
@@ -31,6 +32,10 @@ class _UserType(VyperType):
def __eq__(self, other):
return self is other
+ # TODO: revisit this once user types can be imported via modules
+ def compare_type(self, other):
+ return super().compare_type(other) and self._id == other._id
+
def __hash__(self):
return hash(id(self))
@@ -153,10 +158,13 @@ class EventT(_UserType):
Name of the event.
"""
+ _invalid_locations = tuple(iter(DataLocation)) # not instantiable in any location
+
def __init__(self, name: str, arguments: dict, indexed: list) -> None:
super().__init__(members=arguments)
self.name = name
self.indexed = indexed
+ assert len(self.indexed) == len(self.arguments)
self.event_id = int(keccak256(self.signature.encode()).hex(), 16)
# backward compatible
@@ -165,8 +173,13 @@ def arguments(self):
return self.members
def __repr__(self):
- arg_types = ",".join(repr(a) for a in self.arguments.values())
- return f"event {self.name}({arg_types})"
+ args = []
+ for is_indexed, (_, argtype) in zip(self.indexed, self.arguments.items()):
+ argtype_str = repr(argtype)
+ if is_indexed:
+ argtype_str = f"indexed({argtype_str})"
+ args.append(f"{argtype_str}")
+ return f"event {self.name}({','.join(args)})"
# TODO rename to abi_signature
@property
@@ -306,18 +319,18 @@ def validate_implements(self, node: vy_ast.ImplementsDecl) -> None:
def _is_function_implemented(fn_name, fn_type):
vyper_self = namespace["self"].typ
- if name not in vyper_self.members:
+ if fn_name not in vyper_self.members:
return False
- s = vyper_self.members[name]
+ s = vyper_self.members[fn_name]
if isinstance(s, ContractFunctionT):
- to_compare = vyper_self.members[name]
+ to_compare = vyper_self.members[fn_name]
# this is kludgy, rework order of passes in ModuleNodeVisitor
elif isinstance(s, VarInfo) and s.is_public:
to_compare = s.decl_node._metadata["func_type"]
else:
return False
- return to_compare.compare_signature(fn_type)
+ return to_compare.implements(fn_type)
# check for missing functions
for name, type_ in self.members.items():
@@ -330,14 +343,22 @@ def _is_function_implemented(fn_name, fn_type):
# check for missing events
for name, event in self.events.items():
+ if name not in namespace:
+ unimplemented.append(name)
+ continue
+
+ if not isinstance(namespace[name], EventT):
+ unimplemented.append(f"{name} is not an event!")
if (
- name not in namespace
- or not isinstance(namespace[name], EventT)
- or namespace[name].event_id != event.event_id
+ namespace[name].event_id != event.event_id
+ or namespace[name].indexed != event.indexed
):
- unimplemented.append(name)
+ unimplemented.append(f"{name} is not implemented! (should be {event})")
if len(unimplemented) > 0:
+ # TODO: improve the error message for cases where the
+ # mismatch is small (like mutability, or just one argument
+ # is off, etc).
missing_str = ", ".join(sorted(unimplemented))
raise InterfaceViolation(
f"Contract does not implement all interface functions or events: {missing_str}",
@@ -396,7 +417,7 @@ def from_json_abi(cls, name: str, abi: dict) -> "InterfaceT":
@classmethod
def from_ast(cls, node: Union[vy_ast.InterfaceDef, vy_ast.Module]) -> "InterfaceT":
"""
- Generate an `InterfacePrimitive` object from a Vyper ast node.
+ Generate an `InterfaceT` object from a Vyper ast node.
Arguments
---------
@@ -404,7 +425,7 @@ def from_ast(cls, node: Union[vy_ast.InterfaceDef, vy_ast.Module]) -> "Interface
Vyper ast node defining the interface
Returns
-------
- InterfacePrimitive
+ InterfaceT
primitive interface type
"""
if isinstance(node, vy_ast.Module):
@@ -471,10 +492,6 @@ def __init__(self, _id, members, ast_def=None):
self.ast_def = ast_def
- for n, t in self.members.items():
- if isinstance(t, HashMapT):
- raise StructureException(f"Struct contains a mapping '{n}'", ast_def)
-
@cached_property
def name(self) -> str:
# Alias for API compatibility with codegen
@@ -538,10 +555,6 @@ def from_ast_def(cls, base_node: vy_ast.StructDef) -> "StructT":
def __repr__(self):
return f"{self._id} declaration object"
- # TODO check me
- def compare_type(self, other):
- return super().compare_type(other) and self._id == other._id
-
@property
def size_in_bytes(self):
return sum(i.size_in_bytes for i in self.member_types.values())
diff --git a/vyper/semantics/types/utils.py b/vyper/semantics/types/utils.py
index 6ae677e451..1187080ca9 100644
--- a/vyper/semantics/types/utils.py
+++ b/vyper/semantics/types/utils.py
@@ -1,8 +1,15 @@
from typing import Dict
from vyper import ast as vy_ast
-from vyper.exceptions import ArrayIndexException, InvalidType, StructureException, UnknownType
+from vyper.exceptions import (
+ ArrayIndexException,
+ InstantiationException,
+ InvalidType,
+ StructureException,
+ UnknownType,
+)
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
+from vyper.semantics.data_locations import DataLocation
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.base import VyperType
@@ -60,9 +67,11 @@ def type_from_abi(abi_type: Dict) -> VyperType:
raise UnknownType(f"ABI contains unknown type: {type_string}") from None
-def type_from_annotation(node: vy_ast.VyperNode) -> VyperType:
+def type_from_annotation(
+ node: vy_ast.VyperNode, location: DataLocation = DataLocation.UNSET
+) -> VyperType:
"""
- Return a type object for the given AST node.
+ Return a type object for the given AST node after validating its location.
Arguments
---------
@@ -74,6 +83,16 @@ def type_from_annotation(node: vy_ast.VyperNode) -> VyperType:
VyperType
Type definition object.
"""
+ typ_ = _type_from_annotation(node)
+
+ if location in typ_._invalid_locations:
+ location_str = "" if location is DataLocation.UNSET else f"in {location.name.lower()}"
+ raise InstantiationException(f"{typ_} is not instantiable {location_str}", node)
+
+ return typ_
+
+
+def _type_from_annotation(node: vy_ast.VyperNode) -> VyperType:
namespace = get_namespace()
def _failwith(type_name):
@@ -84,7 +103,6 @@ def _failwith(type_name):
if isinstance(node, vy_ast.Tuple):
tuple_t = namespace["$TupleT"]
-
return tuple_t.from_annotation(node)
if isinstance(node, vy_ast.Subscript):
@@ -104,7 +122,14 @@ def _failwith(type_name):
if node.id not in namespace:
_failwith(node.node_source_code)
- return namespace[node.id]
+ typ_ = namespace[node.id]
+ if hasattr(typ_, "from_annotation"):
+ # cases where the object in the namespace is an uninstantiated
+ # type object, ex. Bytestring or DynArray (with no length provided).
+ # call from_annotation to produce a better error message.
+ typ_.from_annotation(node)
+
+ return typ_
def get_index_value(node: vy_ast.Index) -> int:
diff --git a/vyper/utils.py b/vyper/utils.py
index 37a3f13b3d..3d9d9cb416 100644
--- a/vyper/utils.py
+++ b/vyper/utils.py
@@ -11,6 +11,19 @@
from vyper.exceptions import DecimalOverrideException, InvalidLiteral
+class OrderedSet(dict):
+ """
+ a minimal "ordered set" class. this is needed in some places
+ because, while dict guarantees you can recover insertion order
+ vanilla sets do not.
+ no attempt is made to fully implement the set API, will add
+ functionality as needed.
+ """
+
+ def add(self, item):
+ self[item] = None
+
+
class DecimalContextOverride(decimal.Context):
def __setattr__(self, name, value):
if name == "prec":
@@ -183,8 +196,7 @@ def calc_mem_gas(memsize):
# Specific gas usage
GAS_IDENTITY = 15
GAS_IDENTITYWORD = 3
-GAS_CODECOPY_WORD = 3
-GAS_CALLDATACOPY_WORD = 3
+GAS_COPY_WORD = 3 # i.e., W_copy from YP
# A decimal value can store multiples of 1/DECIMAL_DIVISOR
MAX_DECIMAL_PLACES = 10