diff --git a/src/cli/gentest/cli.py b/src/cli/gentest/cli.py index 14926299aa..db2e00c8ae 100644 --- a/src/cli/gentest/cli.py +++ b/src/cli/gentest/cli.py @@ -9,15 +9,11 @@ from typing import TextIO import click -import jinja2 from ethereum_test_base_types import Hash -from .request_manager import RPCRequest -from .test_providers import BlockchainTestProvider - -template_loader = jinja2.PackageLoader("cli.gentest") -template_env = jinja2.Environment(loader=template_loader, keep_trailing_newline=True) +from .source_code_generator import get_test_source +from .test_context_providers import BlockchainTestContextProvider @click.command() @@ -31,26 +27,9 @@ def generate(transaction_hash: str, output_file: TextIO): OUTPUT_FILE is the path to the output python script. """ - request = RPCRequest() - - print( - "Perform tx request: eth_get_transaction_by_hash(" + f"{transaction_hash}" + ")", - file=stderr, - ) - transaction = request.eth_get_transaction_by_hash(Hash(transaction_hash)) - - print("Perform debug_trace_call", file=stderr) - state = request.debug_trace_call(transaction) - - print("Perform eth_get_block_by_number", file=stderr) - block = request.eth_get_block_by_number(transaction.block_number) - - print("Generate py test", file=stderr) - context = BlockchainTestProvider( - block=block, transaction=transaction, state=state - ).get_context() + provider = BlockchainTestContextProvider(transaction_hash=Hash(transaction_hash)) - template = template_env.get_template("blockchain_test/transaction.py.j2") - output_file.write(template.render(context)) + source = get_test_source(provider=provider, template_path="blockchain_test/transaction.py.j2") + output_file.write(source) print("Finished", file=stderr) diff --git a/src/cli/gentest/request_manager.py b/src/cli/gentest/request_manager.py index 28774e62bc..94fed73cfa 100644 --- a/src/cli/gentest/request_manager.py +++ b/src/cli/gentest/request_manager.py @@ -15,7 +15,7 @@ from pydantic import BaseModel from config import EnvConfig -from ethereum_test_base_types import Account, Address, Hash, HexNumber +from ethereum_test_base_types import Hash, HexNumber from ethereum_test_rpc import BlockNumberType, DebugRPC, EthRPC from ethereum_test_types import Transaction @@ -100,7 +100,7 @@ def eth_get_block_by_number(self, block_number: BlockNumberType) -> RemoteBlock: timestamp=res["timestamp"], ) - def debug_trace_call(self, transaction: RemoteTransaction) -> Dict[Address, Account]: + def debug_trace_call(self, transaction: RemoteTransaction) -> Dict[str, dict]: """ Get pre-state required for transaction """ diff --git a/src/cli/gentest/source_code_generator.py b/src/cli/gentest/source_code_generator.py new file mode 100644 index 0000000000..6159642e0e --- /dev/null +++ b/src/cli/gentest/source_code_generator.py @@ -0,0 +1,83 @@ +""" +Pytest source code generator. + +This module maps a test provider instance to pytest source code. +""" + +import subprocess +import sys +import tempfile +from pathlib import Path + +import jinja2 + +from .test_context_providers import Provider + +template_loader = jinja2.PackageLoader("cli.gentest") +template_env = jinja2.Environment(loader=template_loader, keep_trailing_newline=True) + +# This filter maps python objects to string +template_env.filters["stringify"] = lambda input: repr(input) + + +# generates a formatted pytest source code by writing provided data on a given template. +def get_test_source(provider: Provider, template_path: str) -> str: + """ + Generates formatted pytest source code by rendering a template with provided data. + + This function uses the given template path to create a pytest-compatible source + code string. It retrieves context data from the specified provider and applies + it to the template. + + Args: + provider: An object that provides the necessary context for rendering the template. + template_path (str): The path to the Jinja2 template file used to generate tests. + + Returns: + str: The formatted pytest source code. + """ + template = template_env.get_template(template_path) + rendered_template = template.render(provider.get_context()) + # return rendered_template + return format_code(rendered_template) + + +def format_code(code: str) -> str: + """ + Formats the provided Python code using the Black code formatter. + + This function writes the given code to a temporary Python file, formats it using + the Black formatter, and returns the formatted code as a string. + + Args: + code (str): The Python code to be formatted. + + Returns: + str: The formatted Python code. + """ + # Create a temporary python file + with tempfile.NamedTemporaryFile(suffix=".py") as temp_file: + # Write the code to the temporary file + temp_file.write(code.encode("utf-8")) + # Ensure the file is written + temp_file.flush() + + # Create a Path object for the input file + input_file_path = Path(temp_file.name) + + # Get the path to the black executable in the virtual environment + if sys.platform.startswith("win"): + black_path = Path(sys.prefix) / "Scripts" / "black.exe" + else: + black_path = Path(sys.prefix) / "bin" / "black" + + # Call black to format the file + config_path = Path(sys.prefix).parent / "pyproject.toml" + + subprocess.run( + [str(black_path), str(input_file_path), "--quiet", "--config", str(config_path)], + check=True, + ) + + # Return the formatted source code + return input_file_path.read_text() diff --git a/src/cli/gentest/templates/blockchain_test/transaction.py.j2 b/src/cli/gentest/templates/blockchain_test/transaction.py.j2 index 93836949a7..f5a034d668 100644 --- a/src/cli/gentest/templates/blockchain_test/transaction.py.j2 +++ b/src/cli/gentest/templates/blockchain_test/transaction.py.j2 @@ -1,5 +1,5 @@ """ -gentest autogenerated test with debug_traceCall of tx.hash +Gentest autogenerated test from `tx.hash`: {{ tx_hash }} https://etherscan.io/tx/{{tx_hash}} """ @@ -8,7 +8,14 @@ from typing import Dict import pytest -from ethereum_test_tools import Account, Block, BlockchainTestFiller, Environment, Transaction +from ethereum_test_tools import ( + Account, + Block, + BlockchainTestFiller, + Environment, + Storage, + Transaction, +) REFERENCE_SPEC_GIT_PATH = "N/A" REFERENCE_SPEC_VERSION = "N/A" @@ -16,10 +23,7 @@ REFERENCE_SPEC_VERSION = "N/A" @pytest.fixture def env(): # noqa: D103 - return Environment( -{{ environment_kwargs }} - ) - + return {{ environment | stringify }} @pytest.mark.valid_from("Paris") def test_transaction_{{ tx_hash }}( # noqa: SC200, E501 @@ -27,18 +31,13 @@ def test_transaction_{{ tx_hash }}( # noqa: SC200, E501 blockchain_test: BlockchainTestFiller, ): """ - gentest autogenerated test for tx.hash + Gentest autogenerated test for tx.hash: {{ tx_hash }} """ - pre = { -{{ pre_state_items }} - } + pre = {{ pre_state | stringify }} - post: Dict = { - } + post: Dict = {} - tx = Transaction( -{{ transaction_items }} - ) + tx = {{ transaction | stringify }} - blockchain_test(genesis_environment=env, pre=pre, post=post, blocks=[Block(txs=[tx])]) \ No newline at end of file + blockchain_test(genesis_environment=env, pre=pre, post=post, blocks=[Block(txs=[tx])]) diff --git a/src/cli/gentest/test_context_providers.py b/src/cli/gentest/test_context_providers.py new file mode 100644 index 0000000000..7e8cdfd595 --- /dev/null +++ b/src/cli/gentest/test_context_providers.py @@ -0,0 +1,104 @@ +""" +This module contains various providers which generates context required to create test scripts. + +Classes: +- Provider: An provider generates required context for creating a test. +- BlockchainTestProvider: The BlockchainTestProvider takes a transaction hash and creates + required context to create a test. + +Example: + provider = BlockchainTestContextProvider(transaction=transaction) + context = provider.get_context() +""" + +from abc import ABC, abstractmethod +from sys import stderr +from typing import Any, Dict, Optional + +from pydantic import BaseModel + +from ethereum_test_base_types import Account, Hash +from ethereum_test_tools import Environment, Transaction + +from .request_manager import RPCRequest + + +class Provider(ABC, BaseModel): + """ + An provider generates required context for creating a test. + """ + + @abstractmethod + def get_context(self) -> Dict: + """ + Get the context for generating a test. + """ + + pass + + +class BlockchainTestContextProvider(Provider): + """ + Provides context required to generate a `blockchain_test` using pytest. + """ + + transaction_hash: Hash + block: Optional[RPCRequest.RemoteBlock] = None + transaction: Optional[RPCRequest.RemoteTransaction] = None + state: Optional[Dict[str, Dict]] = None + + def _make_rpc_calls(self): + request = RPCRequest() + print( + f"Perform tx request: eth_get_transaction_by_hash({self.transaction_hash})", + file=stderr, + ) + self.transaction = request.eth_get_transaction_by_hash(self.transaction_hash) + + print("Perform debug_trace_call", file=stderr) + self.state = request.debug_trace_call(self.transaction) + + print("Perform eth_get_block_by_number", file=stderr) + self.block = request.eth_get_block_by_number(self.transaction.block_number) + + print("Generate py test", file=stderr) + + def _get_environment(self) -> Environment: + assert self.block is not None + return Environment(**self.block.model_dump()) + + def _get_pre_state(self) -> Dict[str, Account]: + assert self.state is not None + assert self.transaction is not None + + pre_state: Dict[str, Account] = {} + for address, account_data in self.state.items(): + + # TODO: Check if this is required. Ideally, + # the pre-state tracer should have the correct + # values without requiring any additional modifications. + if address == self.transaction.sender: + account_data["nonce"] = self.transaction.nonce + + pre_state[address] = Account(**account_data) + return pre_state + + def _get_transaction(self) -> Transaction: + assert self.transaction is not None + return Transaction(**self.transaction.model_dump()) + + def get_context(self) -> Dict[str, Any]: + """ + Get the context for generating a blockchain test. + + Returns: + Dict[str, Any]: A dictionary containing environment, + pre-state, a transaction and its hash. + """ + self._make_rpc_calls() + return { + "environment": self._get_environment(), + "pre_state": self._get_pre_state(), + "transaction": self._get_transaction(), + "tx_hash": self.transaction_hash, + } diff --git a/src/cli/gentest/test_providers.py b/src/cli/gentest/test_providers.py deleted file mode 100644 index 0779632e45..0000000000 --- a/src/cli/gentest/test_providers.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -This module contains various providers which generates context required to create test scripts. - -Classes: -- BlockchainTestProvider: The BlockchainTestProvider class takes information about a block, -a transaction, and the associated state, and provides methods to generate various elements -needed for testing, such as module docstrings, test names, and pre-state items. - -Example: - provider = BlockchainTestProvider(block=block, transaction=transaction, state=state) - context = provider.get_context() -""" - -from typing import Any, Dict - -from pydantic import BaseModel - -from ethereum_test_base_types import Account, Address, ZeroPaddedHexNumber - -from .request_manager import RPCRequest - - -class BlockchainTestProvider(BaseModel): - """ - Provides context required to generate a `blockchain_test` using pytest. - """ - - block: RPCRequest.RemoteBlock - transaction: RPCRequest.RemoteTransaction - state: Dict[Address, Account] - - def _get_environment_kwargs(self) -> str: - env_str = "" - pad = " " - for field, value in self.block.dict().items(): - env_str += ( - f'{pad}{field}="{value}",\n' if field == "coinbase" else f"{pad}{field}={value},\n" - ) - - return env_str - - # TODO: Output should be dict. Formatting should happen in the template. - def _get_pre_state_items(self) -> str: - # Print a nice .py storage pre - pad = " " - state_str = "" - for address, account_obj in self.state.items(): - state_str += f' "{address}": Account(\n' - state_str += f"{pad}balance={str(account_obj.balance)},\n" - if address == self.transaction.sender: - state_str += f"{pad}nonce={self.transaction.nonce},\n" - else: - state_str += f"{pad}nonce={str(account_obj.nonce)},\n" - - if account_obj.code is None: - state_str += f'{pad}code="0x",\n' - else: - state_str += f'{pad}code="{str(account_obj.code)}",\n' - state_str += pad + "storage={\n" - - if account_obj.storage is not None: - for record, value in account_obj.storage.root.items(): - pad_record = ZeroPaddedHexNumber(record) - pad_value = ZeroPaddedHexNumber(value) - state_str += f'{pad} "{pad_record}" : "{pad_value}",\n' - - state_str += pad + "}\n" - state_str += " ),\n" - return state_str - - # TODO: Output should be dict. Formatting should happen in the template. - def _get_transaction_items(self) -> str: - """ - Print legacy transaction in .py - """ - pad = " " - tr_str = "" - quoted_fields_array = ["data", "to"] - hex_fields_array = ["v", "r", "s"] - legacy_fields_array = [ - "ty", - "chain_id", - "nonce", - "gas_price", - "protected", - "gas_limit", - "value", - ] - for field, value in iter(self.transaction): - if value is None: - continue - - if field in legacy_fields_array: - tr_str += f"{pad}{field}={value},\n" - - if field in quoted_fields_array: - tr_str += f'{pad}{field}="{value}",\n' - - if field in hex_fields_array: - tr_str += f"{pad}{field}={hex(value)},\n" - - return tr_str - - def get_context(self) -> Dict[str, Any]: - """ - Get the context for generating a blockchain test. - - Returns: - Dict[str, Any]: A dictionary containing module docstring, test name, - test docstring, environment kwargs, pre-state items, and transaction items. - """ - return { - "environment_kwargs": self._get_environment_kwargs(), - "pre_state_items": self._get_pre_state_items(), - "transaction_items": self._get_transaction_items(), - "tx_hash": self.transaction.tx_hash, - } diff --git a/src/ethereum_test_base_types/__init__.py b/src/ethereum_test_base_types/__init__.py index 7348b84044..2e57cd0020 100644 --- a/src/ethereum_test_base_types/__init__.py +++ b/src/ethereum_test_base_types/__init__.py @@ -31,7 +31,7 @@ ) from .conversions import to_bytes, to_hex from .json import to_json -from .pydantic import CamelModel +from .pydantic import CamelModel, EthereumTestBaseModel, EthereumTestRootModel from .reference_spec import ReferenceSpec __all__ = ( @@ -49,6 +49,8 @@ "EmptyOmmersRoot", "EmptyTrieRoot", "FixedSizeBytes", + "EthereumTestBaseModel", + "EthereumTestRootModel", "Hash", "HashInt", "HeaderNonce", diff --git a/src/ethereum_test_base_types/composite_types.py b/src/ethereum_test_base_types/composite_types.py index 8d689c63e0..762f58ffe3 100644 --- a/src/ethereum_test_base_types/composite_types.py +++ b/src/ethereum_test_base_types/composite_types.py @@ -5,11 +5,11 @@ from dataclasses import dataclass from typing import Any, ClassVar, Dict, List, SupportsBytes, Type, TypeAlias -from pydantic import Field, PrivateAttr, RootModel, TypeAdapter +from pydantic import Field, PrivateAttr, TypeAdapter from .base_types import Address, Bytes, Hash, HashInt, HexNumber, ZeroPaddedHexNumber from .conversions import BytesConvertible, NumberConvertible -from .pydantic import CamelModel +from .pydantic import CamelModel, EthereumTestRootModel StorageKeyValueTypeConvertible = NumberConvertible StorageKeyValueType = HashInt @@ -17,7 +17,7 @@ StorageRootType = Dict[NumberConvertible, NumberConvertible] -class Storage(RootModel[Dict[StorageKeyValueType, StorageKeyValueType]]): +class Storage(EthereumTestRootModel[Dict[StorageKeyValueType, StorageKeyValueType]]): """ Definition of a storage in pre or post state of a test """ @@ -468,7 +468,7 @@ def to_kwargs_dict(account: "Dict | Account | None") -> Dict: return cls(**kwargs) -class Alloc(RootModel[Dict[Address, Account | None]]): +class Alloc(EthereumTestRootModel[Dict[Address, Account | None]]): """ Allocation of accounts in the state, pre and post test execution. """ diff --git a/src/ethereum_test_base_types/json.py b/src/ethereum_test_base_types/json.py index 9b63d89502..a4d95fdb92 100644 --- a/src/ethereum_test_base_types/json.py +++ b/src/ethereum_test_base_types/json.py @@ -4,18 +4,23 @@ from typing import Any, AnyStr, List -from pydantic import BaseModel, RootModel +from .pydantic import EthereumTestBaseModel, EthereumTestRootModel def to_json( - input: BaseModel | RootModel | AnyStr | List[BaseModel | RootModel | AnyStr], + input: ( + EthereumTestBaseModel + | EthereumTestRootModel + | AnyStr + | List[EthereumTestBaseModel | EthereumTestRootModel | AnyStr] + ), ) -> Any: """ Converts a model to its json data representation. """ if isinstance(input, list): return [to_json(item) for item in input] - elif isinstance(input, (BaseModel, RootModel)): - return input.model_dump(mode="json", by_alias=True, exclude_none=True) + elif isinstance(input, (EthereumTestBaseModel, EthereumTestRootModel)): + return input.serialize(mode="json", by_alias=True) else: return str(input) diff --git a/src/ethereum_test_base_types/mixins.py b/src/ethereum_test_base_types/mixins.py new file mode 100644 index 0000000000..9d26066228 --- /dev/null +++ b/src/ethereum_test_base_types/mixins.py @@ -0,0 +1,79 @@ +""" +This module provides various mixins for Pydantic models. +""" + +from typing import Any, Literal + +from pydantic import BaseModel + + +class ModelCustomizationsMixin: + """ + A mixin that customizes the behavior of pydantic models. Any pydantic + configuration override that must apply to all models + should be placed here. + + This mixin is applied to both `EthereumTestBaseModel` and `EthereumTestRootModel`. + """ + + def serialize( + self, + mode: Literal["json", "python"], + by_alias: bool, + exclude_none: bool = True, + ) -> dict[str, Any]: + """ + Serializes the model to the specified format with the given parameters. + + :param mode: The mode of serialization. + If mode is 'json', the output will only contain JSON serializable types. + If mode is 'python', the output may contain non-JSON-serializable Python objects. + :param by_alias: Whether to use aliases for field names. + :param exclude_none: Whether to exclude fields with None values, default is True. + :return: The serialized representation of the model. + """ + if not hasattr(self, "model_dump"): + raise NotImplementedError( + f"{self.__class__.__name__} does not have 'model_dump' method." + "Are you sure you are using a Pydantic model?" + ) + return self.model_dump(mode=mode, by_alias=by_alias, exclude_none=exclude_none) + + def __repr_args__(self): + """ + Generate a list of attribute-value pairs for the object representation. + + This method serializes the model, retrieves the attribute names, + and constructs a list of tuples containing attribute names and their corresponding values. + Only attributes with non-None values are included in the list. + + This method is used by the __repr__ method to generate the object representation, + and is used by `gentest` module to generate the test cases. + + See: + - https://pydantic-docs.helpmanual.io/usage/models/#custom-repr + - https://github.com/ethereum/execution-spec-tests/pull/901#issuecomment-2443296835 + + Returns: + List[Tuple[str, Any]]: A list of tuples where each tuple contains an attribute name + and its corresponding non-None value. + """ + attrs_names = self.serialize(mode="python", by_alias=False).keys() + attrs = ((s, getattr(self, s)) for s in attrs_names) + + # Convert field values based on their type. + # This ensures consistency between JSON and Python object representations. + # Should a custom `__repr__` be needed for a specific type, it can added in the + # match statement below. + # Otherwise, the default string representation is used. + repr_attrs = [] + for a, v in attrs: + match v: + + # Note: The `None` case handles an edge case with transactions + # see: https://github.com/ethereum/execution-spec-tests/pull/901#discussion_r1828491918 # noqa: E501 + case list() | dict() | BaseModel() | None: + repr_attrs.append((a, v)) + case _: + repr_attrs.append((a, str(v))) + return repr_attrs diff --git a/src/ethereum_test_base_types/pydantic.py b/src/ethereum_test_base_types/pydantic.py index 6ef2c2b0a4..60bad656fa 100644 --- a/src/ethereum_test_base_types/pydantic.py +++ b/src/ethereum_test_base_types/pydantic.py @@ -1,17 +1,38 @@ """ Base pydantic classes used to define the models for Ethereum tests. """ -from typing import TypeVar -from pydantic import BaseModel, ConfigDict +from typing import Any, TypeVar + +from pydantic import BaseModel, ConfigDict, RootModel from pydantic.alias_generators import to_camel +from .mixins import ModelCustomizationsMixin + Model = TypeVar("Model", bound=BaseModel) +RootModelRootType = TypeVar("RootModelRootType") + + +class EthereumTestBaseModel(BaseModel, ModelCustomizationsMixin): + """ + Base model for all models for Ethereum tests. + """ + + pass + + +class EthereumTestRootModel(RootModel[RootModelRootType], ModelCustomizationsMixin): + """ + Base model for all models for Ethereum tests. + """ + + root: Any + -class CopyValidateModel(BaseModel): +class CopyValidateModel(EthereumTestBaseModel): """ - Base model for Ethereum tests. + Model that supports copying with validation. """ def copy(self: Model, **kwargs) -> Model: diff --git a/src/ethereum_test_fixtures/file.py b/src/ethereum_test_fixtures/file.py index 3d6e573a7a..7115334948 100644 --- a/src/ethereum_test_fixtures/file.py +++ b/src/ethereum_test_fixtures/file.py @@ -1,11 +1,12 @@ """ Defines models for interacting with JSON fixture files. """ + import json from pathlib import Path from typing import Any, Dict, Optional -from pydantic import RootModel +from ethereum_test_base_types import EthereumTestRootModel from .base import FixtureFormat from .blockchain import EngineFixture as BlockchainEngineFixture @@ -16,7 +17,7 @@ FixtureModel = BlockchainFixture | BlockchainEngineFixture | StateFixture | EOFFixture -class BaseFixturesRootModel(RootModel): +class BaseFixturesRootModel(EthereumTestRootModel): """ A base class for defining top-level models that encapsulate multiple test fixtures. Each fixture is stored in a dictionary, where each key is a string diff --git a/whitelist.txt b/whitelist.txt index 1e8486f13b..4d1ac828e9 100644 --- a/whitelist.txt +++ b/whitelist.txt @@ -273,6 +273,8 @@ md mem mempool metaclass +mixin +mixins mixhash mkdocs mkdocstrings