Skip to content

Commit

Permalink
Merge pull request nucypher#147 from derekpierre/common-scripts
Browse files Browse the repository at this point in the history
Establish common scripts and add `click` support
  • Loading branch information
derekpierre authored Oct 11, 2023
2 parents 261c8ec + 1a53949 commit 31cdefc
Show file tree
Hide file tree
Showing 20 changed files with 197 additions and 143 deletions.
14 changes: 8 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,22 @@ from ape import project
from deployment.constants import (
CONSTRUCTOR_PARAMS_DIR,
CURRENT_NETWORK,
LOCAL_BLOCKCHAIN_ENVIRONMENTS,
)
from deployment.networks import is_local_network
from deployment.params import Deployer
VERIFY = CURRENT_NETWORK not in LOCAL_BLOCKCHAIN_ENVIRONMENTS
VERIFY = not is_local_network()
CONSTRUCTOR_PARAMS_FILEPATH = CONSTRUCTOR_PARAMS_DIR / "my-domain" / "example.yml"
def main():
deployer = Deployer.from_yaml(filepath=CONSTRUCTOR_PARAMS_FILEPATH, verify=VERIFY)
deployer = Deployer.from_yaml(filepath=CONSTRUCTOR_PARAMS_FILEPATH,
verify=VERIFY)
token = deployer.deploy(project.MyToken)
my_contract_with_no_parameters = deployer.deploy(project.MyContractWithNoParameters)
my_contract_with_parameters = deployer.deploy(project.MyContractWithParameters)
my_contract_with_no_parameters = deployer.deploy(
project.MyContractWithNoParameters)
my_contract_with_parameters = deployer.deploy(
project.MyContractWithParameters)
deployments = [
token,
my_contract_with_no_parameters,
Expand Down
4 changes: 1 addition & 3 deletions deployment/constants.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from pathlib import Path

from ape import networks, project
from ape import project

import deployment

LOCAL_BLOCKCHAIN_ENVIRONMENTS = ["local"]
CURRENT_NETWORK = networks.network.name
DEPLOYMENT_DIR = Path(deployment.__file__).parent
CONSTRUCTOR_PARAMS_DIR = DEPLOYMENT_DIR / "constructor_params"
ARTIFACTS_DIR = DEPLOYMENT_DIR / "artifacts"
Expand Down
5 changes: 5 additions & 0 deletions deployment/networks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ape import networks


def is_local_network():
return networks.network.name in ["local"]
36 changes: 18 additions & 18 deletions deployment/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@
from ape.cli import get_user_selected_account
from ape.contracts.base import ContractContainer, ContractInstance, ContractTransactionHandler
from ape.utils import ZERO_ADDRESS
from eth_typing import ChecksumAddress
from ethpm_types import MethodABI
from web3.auto.gethdev import w3

from deployment.confirm import _confirm_resolution, _continue
from deployment.constants import (
BYTES_PREFIX,
Expand All @@ -24,12 +20,14 @@
from deployment.registry import registry_from_ape_deployments
from deployment.utils import (
_load_yaml,
verify_contracts,
get_contract_container,
check_plugins,
get_artifact_filepath,
validate_config
get_contract_container,
validate_config,
verify_contracts,
)
from eth_typing import ChecksumAddress
from ethpm_types import MethodABI
from web3.auto.gethdev import w3


def _is_variable(param: Any) -> bool:
Expand Down Expand Up @@ -263,23 +261,23 @@ def validate_constructor_parameters(config: typing.OrderedDict[str, Any]) -> Non
contract_name=contract,
abi_inputs=contract_container.constructor.abi.inputs,
parameters=parameters,
)
)


def _get_contracts_config(config: typing.Dict) -> OrderedDict:
"""Returns the contracts config from a constructor parameters file."""
try:
contracts = config['contracts']
contracts = config["contracts"]
except KeyError:
raise ValueError(f"Constructor parameters file missing 'contracts' field.")
raise ValueError("Constructor parameters file missing 'contracts' field.")
result = OrderedDict()
for contract in contracts:
if isinstance(contract, str):
contract = {contract: OrderedDict()}
elif isinstance(contract, dict):
contract = OrderedDict(contract)
else:
raise ValueError(f"Malformed constructor parameters YAML.")
raise ValueError("Malformed constructor parameters YAML.")
result.update(contract)
return result

Expand Down Expand Up @@ -314,6 +312,8 @@ class Transactor:
def __init__(self, account: typing.Optional[AccountAPI] = None):
if account is None:
self._account = get_user_selected_account()
else:
self._account = account

def get_account(self) -> AccountAPI:
"""Returns the transactor account."""
Expand Down Expand Up @@ -347,11 +347,11 @@ class Deployer(Transactor):
__DEPLOYER_ACCOUNT: AccountAPI = None

def __init__(
self,
config: typing.Dict,
path: Path,
verify: bool,
account: typing.Optional[AccountAPI] = None
self,
config: typing.Dict,
path: Path,
verify: bool,
account: typing.Optional[AccountAPI] = None,
):
check_plugins()
self.path = path
Expand Down Expand Up @@ -426,6 +426,6 @@ def _confirm_start(self) -> None:
f"Network: {networks.provider.network.name}",
f"Chain ID: {networks.provider.network.chain_id}",
f"Gas Price: {networks.provider.gas_price}",
sep="\n"
sep="\n",
)
_continue()
63 changes: 31 additions & 32 deletions deployment/registry.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
import json
import os
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional

from ape.contracts import ContractInstance
from deployment.utils import _load_json, get_contract_container
from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address
from web3.types import ABI

from deployment.utils import validate_config, get_contract_container, _load_json

ChainId = int
ContractName = str


class RegistryEntry(NamedTuple):
"""Represents a single entry in a nucypher-style contract registry."""

chain_id: ChainId
name: ContractName
address: ChecksumAddress
Expand Down Expand Up @@ -63,22 +62,18 @@ def _get_entry(
chain_id=receipt.chain_id,
tx_hash=receipt.txn_hash,
block_number=receipt.block_number,
deployer=receipt.transaction.sender
deployer=receipt.transaction.sender,
)
return entry


def _get_entries(
contract_instances: List[ContractInstance],
registry_names: Dict[ContractName, ContractName]
contract_instances: List[ContractInstance], registry_names: Dict[ContractName, ContractName]
) -> List[RegistryEntry]:
"""Returns a list of contract entries from a list of contract instances."""
entries = list()
for contract_instance in contract_instances:
entry = _get_entry(
contract_instance=contract_instance,
registry_names=registry_names
)
entry = _get_entry(contract_instance=contract_instance, registry_names=registry_names)
entries.append(entry)
return entries

Expand All @@ -92,11 +87,11 @@ def read_registry(filepath: Path) -> List[RegistryEntry]:
registry_entry = RegistryEntry(
chain_id=int(chain_id),
name=contract_name,
address=artifacts['address'],
abi=artifacts['abi'],
tx_hash=artifacts['tx_hash'],
block_number=artifacts['block_number'],
deployer=artifacts['deployer']
address=artifacts["address"],
abi=artifacts["abi"],
tx_hash=artifacts["tx_hash"],
block_number=artifacts["block_number"],
deployer=artifacts["deployer"],
)
registry_entries.append(registry_entry)
return registry_entries
Expand Down Expand Up @@ -128,9 +123,11 @@ def write_registry(entries: List[RegistryEntry], filepath: Path) -> Path:
existing_data = _load_json(filepath)

if any(chain_id in existing_data for chain_id in data):
filepath = filepath.with_suffix('.unmerged.json')
print("Cannot merge registries with overlapping chain IDs.\n"
f"Writing to {filepath} to avoid overwriting existing data.")
filepath = filepath.with_suffix(".unmerged.json")
print(
"Cannot merge registries with overlapping chain IDs.\n"
f"Writing to {filepath} to avoid overwriting existing data."
)
else:
existing_data.update(data)
data = existing_data
Expand All @@ -153,12 +150,10 @@ def _select_conflict_resolution(
) -> ConflictResolution:
print(f"\n! Conflict detected for {registry_1_entry.name}:")
print(
f"[1]: {registry_1_entry.name} at {registry_1_entry.address} "
f"for {registry_1_filepath}"
f"[1]: {registry_1_entry.name} at {registry_1_entry.address} " f"for {registry_1_filepath}"
)
print(
f"[2]: {registry_2_entry.name} at {registry_2_entry.address} "
f"for {registry_2_filepath}"
f"[2]: {registry_2_entry.name} at {registry_2_entry.address} " f"for {registry_2_filepath}"
)
print("[A]: Abort merge")

Expand Down Expand Up @@ -191,20 +186,22 @@ def registry_from_ape_deployments(


def merge_registries(
registry_1_filepath: Path,
registry_2_filepath: Path,
output_filepath: Path,
deprecated_contracts: Optional[List[ContractName]] = None,
registry_1_filepath: Path,
registry_2_filepath: Path,
output_filepath: Path,
deprecated_contracts: Optional[List[ContractName]] = None,
) -> Path:
"""Merges two nucypher-style contract registries created from ape deployments API."""
validate_config(registry_filepath=output_filepath)

# If no deprecated contracts are specified, use an empty list
deprecated_contracts = deprecated_contracts or []

# Read the registries, excluding deprecated contracts
reg1 = {e.name: e for e in read_registry(registry_1_filepath) if e.name not in deprecated_contracts}
reg2 = {e.name: e for e in read_registry(registry_2_filepath) if e.name not in deprecated_contracts}
reg1 = {
e.name: e for e in read_registry(registry_1_filepath) if e.name not in deprecated_contracts
}
reg2 = {
e.name: e for e in read_registry(registry_2_filepath) if e.name not in deprecated_contracts
}

merged: List[RegistryEntry] = list()

Expand All @@ -218,7 +215,7 @@ def merge_registries(
registry_1_entry=entry_1,
registry_2_entry=entry_2,
registry_1_filepath=registry_1_filepath,
registry_2_filepath=registry_2_filepath
registry_2_filepath=registry_2_filepath,
)
selected_entry = entry_1 if resolution == ConflictResolution.USE_1 else entry_2
else:
Expand All @@ -233,11 +230,13 @@ def merge_registries(
return output_filepath


def contracts_from_registry(filepath: Path) -> Dict[str, ContractInstance]:
def contracts_from_registry(filepath: Path, chain_id: ChainId) -> Dict[str, ContractInstance]:
"""Returns a dictionary of contract instances from a nucypher-style contract registry."""
registry_entries = read_registry(filepath=filepath)
deployments = dict()
for registry_entry in registry_entries:
if registry_entry.chain_id != chain_id:
continue
contract_type = registry_entry.name
contract_container = get_contract_container(contract_type)
contract_instance = contract_container.at(registry_entry.address)
Expand Down
10 changes: 5 additions & 5 deletions deployment/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from ape_etherscan.utils import API_KEY_ENV_KEY_MAP

from deployment.constants import (
CURRENT_NETWORK,
LOCAL_BLOCKCHAIN_ENVIRONMENTS, ARTIFACTS_DIR
ARTIFACTS_DIR
)
from deployment.networks import is_local_network


def _load_yaml(filepath: Path) -> dict:
Expand Down Expand Up @@ -53,7 +53,7 @@ def validate_config(config: Dict) -> Path:

config_chain_id = int(config_chain_id) # Convert chain_id to int here after ensuring it is not None
chain_mismatch = config_chain_id != networks.provider.network.chain_id
live_deployment = CURRENT_NETWORK not in LOCAL_BLOCKCHAIN_ENVIRONMENTS
live_deployment = not is_local_network()
if chain_mismatch and live_deployment:
raise ValueError(
f"chain_id in params file ({config_chain_id}) does not match "
Expand All @@ -76,7 +76,7 @@ def check_etherscan_plugin() -> None:
Checks that the ape-etherscan plugin is installed and that
the appropriate API key environment variable is set.
"""
if CURRENT_NETWORK in LOCAL_BLOCKCHAIN_ENVIRONMENTS:
if is_local_network():
# unnecessary for local deployment
return
try:
Expand All @@ -92,7 +92,7 @@ def check_etherscan_plugin() -> None:

def check_infura_plugin() -> None:
"""Checks that the ape-infura plugin is installed."""
if CURRENT_NETWORK in LOCAL_BLOCKCHAIN_ENVIRONMENTS:
if is_local_network():
return # unnecessary for local deployment
try:
import ape_infura # noqa: F401
Expand Down
34 changes: 34 additions & 0 deletions scripts/grant_initiator_role.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/python3
from pathlib import Path

import click
from ape import networks, project
from ape.cli import NetworkBoundCommand, account_option, network_option
from deployment.params import Transactor
from deployment.registry import contracts_from_registry
from deployment.utils import check_plugins


@click.command(cls=NetworkBoundCommand)
@network_option(required=True)
@account_option()
@click.option(
"--registry-filepath",
"-r",
help="Filepath to registry file",
type=click.Path(dir_okay=False, exists=True, path_type=Path),
required=True,
)
def cli(network, account, registry_filepath):
check_plugins()
transactor = Transactor(account)
deployments = contracts_from_registry(
filepath=registry_filepath, chain_id=networks.active_provider.chain_id
)
coordinator = deployments[project.Coordinator.contract_type.name]
initiator_role_hash = coordinator.INITIATOR_ROLE()
transactor.transact(
coordinator.grantRole,
initiator_role_hash,
transactor.get_account().address, # <- new initiator
)
Loading

0 comments on commit 31cdefc

Please sign in to comment.