-
Notifications
You must be signed in to change notification settings - Fork 76
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
♻️ refactor(gentest,base_types): Improved architecture, implement Eth…
…ereumTestBaseModel, EthereumTestRootModel (#901) * ✨ feat(gentest): Improved architecture * 🥢 nit: Rename provider to context provider * 🥢 nit: Simplfied docstring * ✨ feat: Custom repr * 🐞 fix: Quirk: prestate account nonce * 🧹 chore: Unifiy json and python serialization --------- Co-authored-by: rahul <[email protected]>
- Loading branch information
Showing
13 changed files
with
334 additions
and
176 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
""" | ||
Pytest source code generator. | ||
This module maps a test provider instance to pytest source code. | ||
""" | ||
|
||
import subprocess | ||
import sys | ||
import tempfile | ||
from pathlib import Path | ||
|
||
import jinja2 | ||
|
||
from .test_context_providers import Provider | ||
|
||
template_loader = jinja2.PackageLoader("cli.gentest") | ||
template_env = jinja2.Environment(loader=template_loader, keep_trailing_newline=True) | ||
|
||
# This filter maps python objects to string | ||
template_env.filters["stringify"] = lambda input: repr(input) | ||
|
||
|
||
# generates a formatted pytest source code by writing provided data on a given template. | ||
def get_test_source(provider: Provider, template_path: str) -> str: | ||
""" | ||
Generates formatted pytest source code by rendering a template with provided data. | ||
This function uses the given template path to create a pytest-compatible source | ||
code string. It retrieves context data from the specified provider and applies | ||
it to the template. | ||
Args: | ||
provider: An object that provides the necessary context for rendering the template. | ||
template_path (str): The path to the Jinja2 template file used to generate tests. | ||
Returns: | ||
str: The formatted pytest source code. | ||
""" | ||
template = template_env.get_template(template_path) | ||
rendered_template = template.render(provider.get_context()) | ||
# return rendered_template | ||
return format_code(rendered_template) | ||
|
||
|
||
def format_code(code: str) -> str: | ||
""" | ||
Formats the provided Python code using the Black code formatter. | ||
This function writes the given code to a temporary Python file, formats it using | ||
the Black formatter, and returns the formatted code as a string. | ||
Args: | ||
code (str): The Python code to be formatted. | ||
Returns: | ||
str: The formatted Python code. | ||
""" | ||
# Create a temporary python file | ||
with tempfile.NamedTemporaryFile(suffix=".py") as temp_file: | ||
# Write the code to the temporary file | ||
temp_file.write(code.encode("utf-8")) | ||
# Ensure the file is written | ||
temp_file.flush() | ||
|
||
# Create a Path object for the input file | ||
input_file_path = Path(temp_file.name) | ||
|
||
# Get the path to the black executable in the virtual environment | ||
if sys.platform.startswith("win"): | ||
black_path = Path(sys.prefix) / "Scripts" / "black.exe" | ||
else: | ||
black_path = Path(sys.prefix) / "bin" / "black" | ||
|
||
# Call black to format the file | ||
config_path = Path(sys.prefix).parent / "pyproject.toml" | ||
|
||
subprocess.run( | ||
[str(black_path), str(input_file_path), "--quiet", "--config", str(config_path)], | ||
check=True, | ||
) | ||
|
||
# Return the formatted source code | ||
return input_file_path.read_text() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
""" | ||
This module contains various providers which generates context required to create test scripts. | ||
Classes: | ||
- Provider: An provider generates required context for creating a test. | ||
- BlockchainTestProvider: The BlockchainTestProvider takes a transaction hash and creates | ||
required context to create a test. | ||
Example: | ||
provider = BlockchainTestContextProvider(transaction=transaction) | ||
context = provider.get_context() | ||
""" | ||
|
||
from abc import ABC, abstractmethod | ||
from sys import stderr | ||
from typing import Any, Dict, Optional | ||
|
||
from pydantic import BaseModel | ||
|
||
from ethereum_test_base_types import Account, Hash | ||
from ethereum_test_tools import Environment, Transaction | ||
|
||
from .request_manager import RPCRequest | ||
|
||
|
||
class Provider(ABC, BaseModel): | ||
""" | ||
An provider generates required context for creating a test. | ||
""" | ||
|
||
@abstractmethod | ||
def get_context(self) -> Dict: | ||
""" | ||
Get the context for generating a test. | ||
""" | ||
|
||
pass | ||
|
||
|
||
class BlockchainTestContextProvider(Provider): | ||
""" | ||
Provides context required to generate a `blockchain_test` using pytest. | ||
""" | ||
|
||
transaction_hash: Hash | ||
block: Optional[RPCRequest.RemoteBlock] = None | ||
transaction: Optional[RPCRequest.RemoteTransaction] = None | ||
state: Optional[Dict[str, Dict]] = None | ||
|
||
def _make_rpc_calls(self): | ||
request = RPCRequest() | ||
print( | ||
f"Perform tx request: eth_get_transaction_by_hash({self.transaction_hash})", | ||
file=stderr, | ||
) | ||
self.transaction = request.eth_get_transaction_by_hash(self.transaction_hash) | ||
|
||
print("Perform debug_trace_call", file=stderr) | ||
self.state = request.debug_trace_call(self.transaction) | ||
|
||
print("Perform eth_get_block_by_number", file=stderr) | ||
self.block = request.eth_get_block_by_number(self.transaction.block_number) | ||
|
||
print("Generate py test", file=stderr) | ||
|
||
def _get_environment(self) -> Environment: | ||
assert self.block is not None | ||
return Environment(**self.block.model_dump()) | ||
|
||
def _get_pre_state(self) -> Dict[str, Account]: | ||
assert self.state is not None | ||
assert self.transaction is not None | ||
|
||
pre_state: Dict[str, Account] = {} | ||
for address, account_data in self.state.items(): | ||
|
||
# TODO: Check if this is required. Ideally, | ||
# the pre-state tracer should have the correct | ||
# values without requiring any additional modifications. | ||
if address == self.transaction.sender: | ||
account_data["nonce"] = self.transaction.nonce | ||
|
||
pre_state[address] = Account(**account_data) | ||
return pre_state | ||
|
||
def _get_transaction(self) -> Transaction: | ||
assert self.transaction is not None | ||
return Transaction(**self.transaction.model_dump()) | ||
|
||
def get_context(self) -> Dict[str, Any]: | ||
""" | ||
Get the context for generating a blockchain test. | ||
Returns: | ||
Dict[str, Any]: A dictionary containing environment, | ||
pre-state, a transaction and its hash. | ||
""" | ||
self._make_rpc_calls() | ||
return { | ||
"environment": self._get_environment(), | ||
"pre_state": self._get_pre_state(), | ||
"transaction": self._get_transaction(), | ||
"tx_hash": self.transaction_hash, | ||
} |
Oops, something went wrong.