diff --git a/CHANGELOG b/CHANGELOG index aa719420..47dcc824 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,5 +1,10 @@ # Changelog +## 0.1.35 +- Refactored password handling to support environment variables +- Added COMX_UNIVERSAL_PASSWORD for a default password +- Added COMX_KEY_PASSWORDS for a key-password mapping, e.g., COMX_KEY_PASSWORDS='{"foo": "bar"}' + ## 0.1.34.6 - Add `py.typed` so type-checkers will know to use our type annotations. diff --git a/pyproject.toml b/pyproject.toml index 48a1e6ee..09b84751 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "communex" -version = "0.1.34.6" +version = "0.1.35" description = "A library for Commune network focused on simplicity" authors = ["agicommies "] license = "MIT" diff --git a/src/communex/_common.py b/src/communex/_common.py index 4af1dd9b..233afdfc 100644 --- a/src/communex/_common.py +++ b/src/communex/_common.py @@ -1,9 +1,11 @@ import random import re +import warnings from collections import defaultdict from enum import Enum -from typing import Mapping, TypeVar +from typing import Any, Callable, Mapping, TypeVar +from pydantic import SecretStr from pydantic_settings import BaseSettings, SettingsConfigDict from communex.balance import from_nano @@ -12,6 +14,17 @@ IPFS_REGEX = re.compile(r"^Qm[1-9A-HJ-NP-Za-km-z]{44}$") +def deprecated(func: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args: Any, **kwargs: Any) -> Any: + warnings.warn( + f"The function {func.__name__} is deprecated and may be removed in a future version.", + DeprecationWarning, + ) + return func(*args, **kwargs) + + return wrapper + + class ComxSettings(BaseSettings): model_config = SettingsConfigDict(env_prefix="COMX_") # TODO: improve node lists @@ -19,6 +32,8 @@ class ComxSettings(BaseSettings): "wss://api.communeai.net", ] TESTNET_NODE_URLS: list[str] = ["wss://testnet.api.communeai.net"] + UNIVERSAL_PASSWORD: SecretStr | None = None + KEY_PASSWORDS: dict[str, SecretStr] | None = None def get_node_url( diff --git a/src/communex/cli/_common.py b/src/communex/cli/_common.py index df154c8e..32195c8f 100644 --- a/src/communex/cli/_common.py +++ b/src/communex/cli/_common.py @@ -1,20 +1,25 @@ from dataclasses import dataclass from getpass import getpass -from typing import Any, Mapping, TypeVar, cast +from typing import Any, Callable, Mapping, TypeVar, cast import rich +import rich.prompt import typer from rich import box from rich.console import Console from rich.table import Table +from substrateinterface import Keypair from typer import Context -from communex._common import get_node_url +from communex._common import ComxSettings, get_node_url from communex.balance import dict_from_nano, from_horus, from_nano from communex.client import CommuneClient +from communex.compat.key import resolve_key_ss58_encrypted, try_classic_load_key +from communex.errors import InvalidPasswordError, PasswordNotProvidedError from communex.types import ( ModuleInfoWithOptionalBalance, NetworkParams, + Ss58Address, SubnetParamsWithEmission, ) @@ -30,17 +35,68 @@ class ExtendedContext(Context): obj: ExtraCtxData -@dataclass +class CliPasswordProvider: + def __init__( + self, settings: ComxSettings, prompt_secret: Callable[[str], str] + ): + self.settings = settings + self.prompt_secret = prompt_secret + + def get_password(self, key_name: str) -> str | None: + key_map = self.settings.KEY_PASSWORDS + if key_map is not None: + password = key_map.get(key_name) + if password is not None: + return password.get_secret_value() + # fallback to universal password + password = self.settings.UNIVERSAL_PASSWORD + if password is not None: + return password.get_secret_value() + else: + return None + + def ask_password(self, key_name: str) -> str: + password = self.prompt_secret( + f"Please provide the password for the key '{key_name}'" + ) + return password + + class CustomCtx: ctx: ExtendedContext + settings: ComxSettings console: rich.console.Console console_err: rich.console.Console + password_manager: CliPasswordProvider _com_client: CommuneClient | None = None + def __init__( + self, + ctx: ExtendedContext, + settings: ComxSettings, + console: rich.console.Console, + console_err: rich.console.Console, + com_client: CommuneClient | None = None, + ): + self.ctx = ctx + self.settings = settings + self.console = console + self.console_err = console_err + self._com_client = com_client + self.password_manager = CliPasswordProvider( + self.settings, self.prompt_secret + ) + + def get_use_testnet(self) -> bool: + return self.ctx.obj.use_testnet + + def get_node_url(self) -> str: + use_testnet = self.get_use_testnet() + return get_node_url(self.settings, use_testnet=use_testnet) + def com_client(self) -> CommuneClient: - use_testnet = self.ctx.obj.use_testnet if self._com_client is None: - node_url = get_node_url(None, use_testnet=use_testnet) + node_url = self.get_node_url() self.info(f"Using node: {node_url}") for _ in range(5): try: @@ -51,7 +107,7 @@ def com_client(self) -> CommuneClient: ) except Exception: self.info(f"Failed to connect to node: {node_url}") - node_url = get_node_url(None, use_testnet=use_testnet) + node_url = self.get_node_url() self.info(f"Will retry with node {node_url}") continue if self._com_client is None: @@ -59,9 +115,6 @@ def com_client(self) -> CommuneClient: return self._com_client - def get_use_testnet(self) -> bool: - return self.ctx.obj.use_testnet - def output( self, message: str, @@ -78,9 +131,14 @@ def info( ) -> None: self.console_err.print(message, *args, **kwargs) # type: ignore - def error(self, message: str) -> None: + def error( + self, + message: str, + *args: tuple[Any, ...], + **kwargs: dict[str, Any], + ) -> None: message = f"ERROR: {message}" - self.console_err.print(message, style="bold red") + self.console_err.print(message, *args, style="bold red", **kwargs) # type: ignore def progress_status(self, message: str): return self.console_err.status(message) @@ -89,12 +147,46 @@ def confirm(self, message: str) -> bool: if self.ctx.obj.yes_to_all: print(f"{message} (--yes)") return True - return typer.confirm(message) + return typer.confirm(message, err=True) + + def prompt_secret(self, message: str) -> str: + return rich.prompt.Prompt.ask( + message, password=True, console=self.console_err + ) + + def load_key(self, key: str, password: str | None = None) -> Keypair: + try: + keypair = try_classic_load_key( + key, password, password_provider=self.password_manager + ) + return keypair + except PasswordNotProvidedError: + self.error(f"Password not provided for key '{key}'") + raise typer.Exit(code=1) + except InvalidPasswordError: + self.error(f"Incorrect password for key '{key}'") + raise typer.Exit(code=1) + + def resolve_key_ss58( + self, key: Ss58Address | Keypair | str, password: str | None = None + ) -> Ss58Address: + try: + address = resolve_key_ss58_encrypted( + key, password, password_provider=self.password_manager + ) + return address + except PasswordNotProvidedError: + self.error(f"Password not provided for key '{key}'") + raise typer.Exit(code=1) + except InvalidPasswordError: + self.error(f"Incorrect password for key '{key}'") + raise typer.Exit(code=1) def make_custom_context(ctx: typer.Context) -> CustomCtx: return CustomCtx( - ctx=cast(ExtendedContext, ctx), + ctx=cast(ExtendedContext, ctx), # TODO: better check + settings=ComxSettings(), console=Console(), console_err=Console(stderr=True), ) @@ -114,7 +206,7 @@ def eprint(e: Any) -> None: def print_table_from_plain_dict( - result: Mapping[str, str | int | float | dict[Any, Any]], + result: Mapping[str, str | int | float | dict[Any, Any] | Ss58Address], column_names: list[str], console: Console, ) -> None: diff --git a/src/communex/cli/balance.py b/src/communex/cli/balance.py index 20a5d5af..c814a93b 100644 --- a/src/communex/cli/balance.py +++ b/src/communex/cli/balance.py @@ -10,7 +10,6 @@ make_custom_context, print_table_from_plain_dict, ) -from communex.compat.key import resolve_key_ss58_encrypted, try_classic_load_key from communex.errors import ChainTransactionError from communex.faucet.powv2 import solve_for_difficulty_fast @@ -30,7 +29,7 @@ def free_balance( context = make_custom_context(ctx) client = context.com_client() - key_address = resolve_key_ss58_encrypted(key, context, password) + key_address = context.resolve_key_ss58(key, password) with context.progress_status( f"Getting free balance of key {key_address}..." @@ -53,7 +52,7 @@ def staked_balance( context = make_custom_context(ctx) client = context.com_client() - key_address = resolve_key_ss58_encrypted(key, context, password) + key_address = context.resolve_key_ss58(key, password) with context.progress_status( f"Getting staked balance of key {key_address}..." @@ -76,7 +75,7 @@ def show( context = make_custom_context(ctx) client = context.com_client() - key_address = resolve_key_ss58_encrypted(key, context, password) + key_address = context.resolve_key_ss58(key, password) with context.progress_status(f"Getting value of key {key_address}..."): staked_balance = sum(client.get_staketo(key=key_address).values()) @@ -107,7 +106,7 @@ def get_staked( context = make_custom_context(ctx) client = context.com_client() - key_address = resolve_key_ss58_encrypted(key, context, password) + key_address = context.resolve_key_ss58(key, password) with context.progress_status(f"Getting stake of {key_address}..."): result = sum(client.get_staketo(key=key_address).values()) @@ -124,8 +123,9 @@ def transfer(ctx: Context, key: str, amount: float, dest: str): client = context.com_client() nano_amount = to_nano(amount) - resolved_key = try_classic_load_key(key, context) - resolved_dest = resolve_key_ss58_encrypted(dest, context) + + resolved_key = context.load_key(key, None) + resolved_dest = context.resolve_key_ss58(dest, None) if not context.confirm( f"Are you sure you want to transfer {amount} tokens to {dest}?" @@ -153,16 +153,16 @@ def transfer_stake( context = make_custom_context(ctx) client = context.com_client() - resolved_from = resolve_key_ss58_encrypted(from_key, context) - resolved_dest = resolve_key_ss58_encrypted(dest, context) - resolved_key = try_classic_load_key(key, context) nano_amount = to_nano(amount) + keypair = context.load_key(key, None) + resolved_from = context.resolve_key_ss58(from_key) + resolved_dest = context.resolve_key_ss58(dest) with context.progress_status( f"Transferring {amount} tokens from {from_key} to {dest}' ..." ): response = client.transfer_stake( - key=resolved_key, + key=keypair, amount=nano_amount, from_module_key=resolved_from, dest_module_address=resolved_dest, @@ -188,8 +188,9 @@ def stake( client = context.com_client() nano_amount = to_nano(amount) - resolved_key = try_classic_load_key(key, context) - resolved_dest = resolve_key_ss58_encrypted(dest, context) + keypair = context.load_key(key, None) + resolved_dest = context.resolve_key_ss58(dest, None) + delegating_message = ( "By default you delegate DAO " "voting power to the validator you stake to. " @@ -198,9 +199,10 @@ def stake( ) context.info("INFO: ", style="bold green", end="") # type: ignore context.info(delegating_message) # type: ignore + with context.progress_status(f"Staking {amount} tokens to {dest}..."): response = client.stake( - key=resolved_key, amount=nano_amount, dest=resolved_dest + key=keypair, amount=nano_amount, dest=resolved_dest ) if response.is_success: @@ -218,12 +220,12 @@ def unstake(ctx: Context, key: str, amount: float, dest: str): client = context.com_client() nano_amount = to_nano(amount) - resolved_key = try_classic_load_key(key, context) - resolved_dest = resolve_key_ss58_encrypted(dest, context) + keypair = context.load_key(key, None) + resolved_dest = context.resolve_key_ss58(dest, None) with context.progress_status(f"Unstaking {amount} tokens from {dest}'..."): response = client.unstake( - key=resolved_key, amount=nano_amount, dest=resolved_dest + key=keypair, amount=nano_amount, dest=resolved_dest ) # TODO: is it right? if response.is_success: @@ -241,10 +243,13 @@ def run_faucet( ): context = make_custom_context(ctx) use_testnet = ctx.obj.use_testnet + if not use_testnet: context.error("Faucet only enabled on testnet") - return - resolved_key = try_classic_load_key(key, context) + raise typer.Exit(code=1) + + resolved_key = context.load_key(key, None) + client = context.com_client() for _ in range(num_executions): with context.progress_status("Solving PoW..."): @@ -279,17 +284,17 @@ def transfer_dao_funds( dest: str, ): context = make_custom_context(ctx) + if not re.match(IPFS_REGEX, cid_hash): context.error(f"CID provided is invalid: {cid_hash}") - exit(1) + raise typer.Exit(code=1) + ipfs_prefix = "ipfs://" cid = ipfs_prefix + cid_hash - client = context.com_client() - nano_amount = to_nano(amount) - dest = resolve_key_ss58_encrypted(dest, context) - signer_keypair = try_classic_load_key(signer_key, context) - client.add_transfer_dao_treasury_proposal( - signer_keypair, cid, nano_amount, dest - ) + keypair = context.load_key(signer_key, None) + dest = context.resolve_key_ss58(dest, None) + + client = context.com_client() + client.add_transfer_dao_treasury_proposal(keypair, cid, nano_amount, dest) diff --git a/src/communex/cli/key.py b/src/communex/cli/key.py index 14ee16dc..cd4fdc43 100644 --- a/src/communex/cli/key.py +++ b/src/communex/cli/key.py @@ -1,25 +1,22 @@ -import json import re from enum import Enum from typing import Any, Optional, cast import typer from substrateinterface import Keypair +from typeguard import check_type from typer import Context +import communex.compat.key as comx_key from communex._common import BalanceUnit, format_balance from communex.cli._common import ( - get_universal_password, make_custom_context, print_table_from_plain_dict, print_table_standardize, ) from communex.compat.key import ( - classic_key_path, classic_store_key, local_key_addresses, - try_classic_load_key, - try_load_key, ) from communex.key import check_ss58_address, generate_keypair, is_ss58_address from communex.misc import ( @@ -61,8 +58,9 @@ def regen( """ Stores the given key on a disk. Works with private key or mnemonic. """ - # TODO: secret input from env var and stdin context = make_custom_context(ctx) + # TODO: secret input from env var and stdin + # Determine the input type based on the presence of spaces. if re.search(r"\s", key_input): # If mnemonic (contains spaces between words). @@ -96,15 +94,16 @@ def show( """ context = make_custom_context(ctx) - path = classic_key_path(key) - key_dict_json = try_load_key(path, context, password=password) - key_dict = json.loads(key_dict_json) + keypair = context.load_key(key, password) + key_dict = comx_key.to_classic_dict(keypair, path=key) if show_private is not True: key_dict["private_key"] = "[SENSITIVE-MODE]" key_dict["seed_hex"] = "[SENSITIVE-MODE]" key_dict["mnemonic"] = "[SENSITIVE-MODE]" + key_dict = check_type(key_dict, dict[str, Any]) + print_table_from_plain_dict(key_dict, ["Key", "Value"], context.console) @@ -113,29 +112,14 @@ def balances( ctx: Context, unit: BalanceUnit = BalanceUnit.joule, sort_balance: SortBalance = SortBalance.all, - use_universal_password: bool = typer.Option( - False, - help=""" - If you want to use a password to decrypt all keys. - This will only work if all encrypted keys uses the same password. - If this is not the case, leave it blank and you will be prompted to give - every password. - """, - ), ): """ Gets balances of all keys. """ context = make_custom_context(ctx) client = context.com_client() - if use_universal_password: - universal_password = get_universal_password(context) - else: - universal_password = None - local_keys = local_key_addresses( - context, universal_password=universal_password - ) + local_keys = local_key_addresses(context.password_manager) with context.console.status( "Getting balances of all keys, this might take a while..." ): @@ -199,25 +183,13 @@ def balances( @key_app.command(name="list") def inventory( ctx: Context, - use_universal_password: bool = typer.Option( - False, - help=""" - Password to decrypt all keys. - This will only work if all encrypted keys uses the same password. - If this is not the case, leave it blank and you will be prompted to give - every password. - """, - ), ): """ Lists all keys stored on disk. """ context = make_custom_context(ctx) - if use_universal_password: - universal_password = get_universal_password(context) - else: - universal_password = None - key_to_address = local_key_addresses(context, universal_password) + + key_to_address = local_key_addresses(context.password_manager) general_key_to_address: dict[str, str] = cast( dict[str, str], key_to_address ) @@ -242,7 +214,7 @@ def stakefrom( if is_ss58_address(key): key_address = key else: - keypair = try_classic_load_key(key, context, password) + keypair = context.load_key(key, password) key_address = keypair.ss58_address key_address = check_ss58_address(key_address) with context.progress_status( @@ -271,7 +243,7 @@ def staketo( if is_ss58_address(key): key_address = key else: - keypair = try_classic_load_key(key, context, password) + keypair = context.load_key(key, password) key_address = keypair.ss58_address key_address = check_ss58_address(key_address) @@ -303,11 +275,7 @@ def total_free_balance( context = make_custom_context(ctx) client = context.com_client() - if use_universal_password: - universal_password = get_universal_password(context) - else: - universal_password = None - local_keys = local_key_addresses(context, universal_password) + local_keys = local_key_addresses(context.password_manager) with context.progress_status("Getting total free balance of all keys..."): key2balance: dict[str, int] = local_keys_to_freebalance( client, local_keys @@ -338,11 +306,7 @@ def total_staked_balance( context = make_custom_context(ctx) client = context.com_client() - if use_universal_password: - universal_password = get_universal_password(context) - else: - universal_password = None - local_keys = local_key_addresses(context, universal_password) + local_keys = local_key_addresses(context.password_manager) with context.progress_status("Getting total staked balance of all keys..."): key2stake: dict[str, int] = local_keys_to_stakedbalance( client, @@ -374,11 +338,7 @@ def total_balance( context = make_custom_context(ctx) client = context.com_client() - if use_universal_password: - universal_password = get_universal_password(context) - else: - universal_password = None - local_keys = local_key_addresses(context, universal_password) + local_keys = local_key_addresses(context.password_manager) with context.progress_status("Getting total tokens of all keys..."): key2balance, key2stake = local_keys_allbalance(client, local_keys) key2tokens = {k: v + key2stake[k] for k, v in key2balance.items()} @@ -392,7 +352,6 @@ def power_delegation( ctx: Context, key: Optional[str] = None, enable: bool = typer.Option(True, "--disable"), - use_universal_password: bool = typer.Option(False), ): """ Gets power delegation of a key. @@ -409,15 +368,11 @@ def power_delegation( context.info("Aborted.") exit(0) - if use_universal_password: - universal_password = get_universal_password(context) - else: - universal_password = None - local_keys = local_key_addresses(context, universal_password) + local_keys = local_key_addresses(context.password_manager) else: local_keys = {key: None} for key_name in local_keys.keys(): - keypair = try_classic_load_key(key_name, context) + keypair = context.load_key(key_name, None) if enable is True: context.info( f"Enabling vote power delegation on key {key_name} ..." diff --git a/src/communex/cli/misc.py b/src/communex/cli/misc.py index 2a40caad..67f9cc1e 100644 --- a/src/communex/cli/misc.py +++ b/src/communex/cli/misc.py @@ -5,11 +5,7 @@ from communex.balance import from_nano from communex.cli._common import make_custom_context, print_module_info from communex.client import CommuneClient -from communex.compat.key import ( - local_key_addresses, - resolve_key_ss58_encrypted, - try_classic_load_key, -) +from communex.compat.key import local_key_addresses from communex.misc import get_map_modules from communex.types import ModuleInfoWithOptionalBalance @@ -89,7 +85,7 @@ def stats(ctx: Context, balances: bool = False, netuid: int = 0): client, netuid=netuid, include_balances=balances ) modules_to_list = [value for _, value in modules.items()] - local_keys = local_key_addresses() + local_keys = local_key_addresses(password_provider=context.password_manager) local_modules = [ *filter( lambda module: module["key"] in local_keys.values(), modules_to_list @@ -132,8 +128,8 @@ def delegate_rootnet_control(ctx: Context, key: str, target: str): """ context = make_custom_context(ctx) client = context.com_client() - resolved_key = try_classic_load_key(key, context) - ss58_target = resolve_key_ss58_encrypted(target, context) + resolved_key = context.load_key(key, None) + ss58_target = context.resolve_key_ss58(target, None) with context.progress_status("Delegating control of the rootnet..."): client.delegate_rootnet_control(resolved_key, ss58_target) diff --git a/src/communex/cli/module.py b/src/communex/cli/module.py index fa8a5508..24547b9a 100644 --- a/src/communex/cli/module.py +++ b/src/communex/cli/module.py @@ -12,7 +12,6 @@ print_module_info, print_table_from_plain_dict, ) -from communex.compat.key import try_classic_load_key from communex.errors import ChainTransactionError from communex.key import check_ss58_address from communex.misc import get_map_modules @@ -71,7 +70,8 @@ def register( context.info("Not registering") raise typer.Abort() - resolved_key = try_classic_load_key(key, context) + resolved_key = context.load_key(key, None) + with context.progress_status(f"Registering Module {name}..."): subnet_name = client.get_subnet_name(netuid) address = f"{ip}:{port}" @@ -98,7 +98,8 @@ def deregister(ctx: Context, key: str, netuid: int): context = make_custom_context(ctx) client = context.com_client() - resolved_key = try_classic_load_key(key, context) + resolved_key = context.load_key(key, None) + with context.progress_status( f"Deregistering your module on subnet {netuid}..." ): @@ -127,9 +128,11 @@ def update( context = make_custom_context(ctx) client = context.com_client() + if metadata and len(metadata) > 59: raise ValueError("Metadata must be less than 60 characters") - resolved_key = try_classic_load_key(key) + + resolved_key = context.load_key(key, None) if ip and not is_ip_valid(ip): raise ValueError("Invalid ip address") @@ -249,7 +252,8 @@ def serve( context.error(f"Class `{class_name}` not found in module `{module}`") raise typer.Exit(code=1) - keypair = try_classic_load_key(key, context) + keypair = context.load_key(key, None) + if test_mode: subnets_whitelist = None token_refill_rate = token_refill_rate_base_multiplier or 1 diff --git a/src/communex/cli/network.py b/src/communex/cli/network.py index d7b0f94f..9568040b 100644 --- a/src/communex/cli/network.py +++ b/src/communex/cli/network.py @@ -13,7 +13,7 @@ tranform_network_params, ) from communex.client import CommuneClient -from communex.compat.key import local_key_addresses, try_classic_load_key +from communex.compat.key import local_key_addresses from communex.misc import ( IPFS_REGEX, get_global_params, @@ -124,9 +124,10 @@ def propose_globally( Adds a global proposal to the network. """ context = make_custom_context(ctx) - resolved_key = try_classic_load_key(key, context) client = context.com_client() + resolved_key = context.load_key(key, None) + provided_params = cast(NetworkParams, provided_params) global_params = get_global_params(client) global_params_config = global_params["governance_config"] @@ -139,7 +140,7 @@ def propose_globally( if not re.match(IPFS_REGEX, cid): context.error(f"CID provided is invalid: {cid}") - exit(1) + typer.Exit(code=1) with context.progress_status("Adding a proposal..."): client.add_global_proposal(resolved_key, global_params, cid) context.info("Proposal added.") @@ -150,7 +151,7 @@ def get_valid_voting_keys( client: CommuneClient, threshold: int = 25000000000, # 25 $COMAI ) -> dict[str, int]: - local_keys = local_key_addresses(ctx=ctx, universal_password=None) + local_keys = local_key_addresses(password_provider=ctx.password_manager) keys_stake = local_keys_to_stakedbalance(client, local_keys) keys_stake = { key: stake for key, stake in keys_stake.items() if stake >= threshold @@ -184,9 +185,9 @@ def vote_proposal( keys_stake = {key: None} for voting_key in track(keys_stake.keys(), description="Voting..."): - resolved_key = try_classic_load_key(voting_key, context) + keypair = context.load_key(voting_key, None) try: - client.vote_on_proposal(resolved_key, proposal_id, agree) + client.vote_on_proposal(keypair, proposal_id, agree) except Exception as e: print(f"Error while voting with key {key}: ", e) print("Skipping...") @@ -201,7 +202,7 @@ def unvote_proposal(ctx: Context, key: str, proposal_id: int): context = make_custom_context(ctx) client = context.com_client() - resolved_key = try_classic_load_key(key, context) + resolved_key = context.load_key(key, None) with context.progress_status(f"Unvoting on a proposal {proposal_id}..."): client.unvote_on_proposal(resolved_key, proposal_id) @@ -223,7 +224,7 @@ def add_custom_proposal(ctx: Context, key: str, cid: str): ipfs_prefix = "ipfs://" cid = ipfs_prefix + cid - resolved_key = try_classic_load_key(key, context) + resolved_key = context.load_key(key, None) with context.progress_status("Adding a proposal..."): client.add_custom_proposal(resolved_key, cid) @@ -266,7 +267,7 @@ def set_root_weights(ctx: Context, key: str): for uid, weight in zip(uids, weights): typer.echo(f"Subnet {uid} ({subnet_names[uid]}): {weight}") - resolved_key = try_classic_load_key(key, context) + resolved_key = context.load_key(key, None) client.vote(netuid=rootnet_id, uids=uids, weights=weights, key=resolved_key) diff --git a/src/communex/cli/subnet.py b/src/communex/cli/subnet.py index 3a834efa..5396d228 100644 --- a/src/communex/cli/subnet.py +++ b/src/communex/cli/subnet.py @@ -9,7 +9,7 @@ print_table_from_plain_dict, print_table_standardize, ) -from communex.compat.key import resolve_key_ss58, try_classic_load_key +from communex.compat.key import resolve_key_ss58 from communex.errors import ChainTransactionError from communex.misc import ( IPFS_REGEX, @@ -117,8 +117,8 @@ def register( Registers a new subnet. """ context = make_custom_context(ctx) - resolved_key = try_classic_load_key(key) client = context.com_client() + resolved_key = context.load_key(key, None) with context.progress_status("Registering subnet ..."): response = client.register_subnet(resolved_key, name, metadata) @@ -201,7 +201,7 @@ def update( subnet_params["maximum_set_weight_calls_per_epoch"] = client.query( "MaximumSetWeightCallsPerEpoch" ) - resolved_key = try_classic_load_key(key) + resolved_key = context.load_key(key, None) with context.progress_status("Updating subnet ..."): response = client.update_subnet( key=resolved_key, params=subnet_params, netuid=netuid @@ -298,7 +298,7 @@ def propose_on_subnet( "MaximumSetWeightCallsPerEpoch" ) - resolved_key = try_classic_load_key(key) + resolved_key = context.load_key(key, None) with context.progress_status("Adding a proposal..."): client.add_subnet_proposal( resolved_key, subnet_params, cid, netuid=netuid @@ -321,7 +321,7 @@ def submit_general_subnet_application( client = context.com_client() - resolved_key = try_classic_load_key(key) + resolved_key = context.load_key(key, None) resolved_application_key = resolve_key_ss58(application_key) # append the ipfs hash @@ -342,17 +342,16 @@ def add_custom_proposal( """ Adds a custom proposal to a specific subnet. """ - context = make_custom_context(ctx) + if not re.match(IPFS_REGEX, cid): context.error(f"CID provided is invalid: {cid}") exit(1) client = context.com_client() - resolved_key = try_classic_load_key(key) + resolved_key = context.load_key(key, None) - # append the ipfs hash ipfs_prefix = "ipfs://" cid = ipfs_prefix + cid diff --git a/src/communex/compat/key.py b/src/communex/compat/key.py index 6c8aa7da..cedd1779 100644 --- a/src/communex/compat/key.py +++ b/src/communex/compat/key.py @@ -6,23 +6,25 @@ import json import os -from getpass import getpass from pathlib import Path -from typing import Any, Protocol, cast +from typing import Any, cast +from nacl.exceptions import CryptoError from substrateinterface import Keypair from communex.compat.storage import COMMUNE_HOME, classic_load, classic_put from communex.compat.types import CommuneKeyDict +from communex.errors import ( + InvalidPasswordError, + KeyNotFoundError, + PasswordNotProvidedError, +) from communex.key import check_ss58_address, is_ss58_address +from communex.password import NoPassword, PasswordProvider from communex.types import Ss58Address from communex.util import bytes_to_hex, check_str -class GenericCtx(Protocol): - def info(self, message: str): ... - - def check_key_dict(key_dict: Any) -> CommuneKeyDict: """ Validates a given dictionary as a commune key dictionary and returns it. @@ -68,7 +70,7 @@ def classic_key_path(name: str) -> str: def from_classic_dict( - data: dict[Any, Any], from_mnemonic: bool = False + data: dict[Any, Any], from_mnemonic: bool = True ) -> Keypair: """ Creates a `Key` from a dict conforming to the classic `commune` format. @@ -126,7 +128,7 @@ def to_classic_dict(keypair: Keypair, path: str) -> CommuneKeyDict: def classic_load_key( name: str, password: str | None = None, - from_mnemonic: bool = False, + from_mnemonic: bool = True, ) -> Keypair: """ Loads the keypair with the given name from a disk. @@ -137,6 +139,47 @@ def classic_load_key( return from_classic_dict(key_dict, from_mnemonic=from_mnemonic) +def try_classic_load_key( + key_name: str, + password: str | None = None, + *, + password_provider: PasswordProvider = NoPassword(), +) -> Keypair: + password = password or password_provider.get_password(key_name) + try: + try: + keypair = classic_load_key(key_name, password=password) + except PasswordNotProvidedError: + password = password_provider.ask_password(key_name) + keypair = classic_load_key(key_name, password=password) + except FileNotFoundError as err: + raise KeyNotFoundError( + f"Key '{key_name}' is not a valid SS58 address nor a valid key name", + err, + ) + except CryptoError as err: + raise InvalidPasswordError( + f"Invalid password for key '{key_name}'", err + ) + + return keypair + + +def try_load_key(name: str, password: str | None = None): + """ + DEPRECATED + """ + raise DeprecationWarning("Use try_classic_load_key instead") + # try: + # key_dict = classic_load(name, password=password) + # except json.JSONDecodeError: + # prompt = f"Please provide the password for the key {name}" + # print(prompt) + # password = getpass() + # key_dict = classic_load(name, password=password) + # return key_dict + + def is_encrypted(name: str) -> bool: """ Checks if the key with the given name is encrypted. @@ -160,97 +203,13 @@ def classic_store_key( classic_put(path, key_dict_json, password=password) -def try_classic_load_key( - name: str, - context: GenericCtx | None = None, - password: str | None = None, - from_mnemonic: bool = False, -) -> Keypair: - try: - keypair = classic_load_key( - name, password=password, from_mnemonic=from_mnemonic - ) - except json.JSONDecodeError: - prompt = f"Please provide the password for the key {name}" - if context is not None: - context.info(prompt) - else: - print(prompt) - password = getpass() - keypair = classic_load_key( - name, password=password, from_mnemonic=from_mnemonic - ) - return keypair - - -def try_load_key( - name: str, context: GenericCtx | None = None, password: str | None = None -): - try: - key_dict = classic_load(name, password=password) - except json.JSONDecodeError: - prompt = f"Please provide the password for the key {name}" - if context is not None: - context.info(prompt) - else: - print(prompt) - password = getpass() - key_dict = classic_load(name, password=password) - return key_dict - - -def local_key_addresses( - ctx: GenericCtx | None = None, universal_password: str | None = None -) -> dict[str, Ss58Address]: - """ - Retrieves a mapping of local key names to their SS58 addresses. - If password is passed, it will be used to decrypt every key. - If password is not passed and ctx is, - the user will be prompted for the password. - """ - home = Path.home() - key_dir = home / ".commune" / "key" - - key_names = [ - f.stem - for f in key_dir.iterdir() - if f.is_file() and not f.name.startswith(".") - ] - - addresses_map: dict[str, Ss58Address] = {} - - for key_name in key_names: - # issue #11 https://github.com/agicommies/communex/issues/12 added check for key2address to stop error from being thrown by wrong key type. - if key_name == "key2address": - print( - "key2address is saved in an invalid format. It will be ignored." - ) - continue - encrypted = is_encrypted(key_name) - if encrypted: - if universal_password: - password = universal_password - elif ctx: - ctx.info( - f"Please provide the password for the key '{key_name}'" - ) - password = getpass() - else: - print(f"Please provide the password for the key '{key_name}'") - password = getpass() - else: - password = None - key_dict = classic_load_key(key_name, password=password) - addresses_map[key_name] = check_ss58_address(key_dict.ss58_address) - - return addresses_map - - def resolve_key_ss58(key: Ss58Address | Keypair | str) -> Ss58Address: """ Resolves a keypair or key name to its corresponding SS58 address. If the input is already an SS58 address, it is returned as is. + + DEPRECATED """ if isinstance(key, Keypair): @@ -273,8 +232,8 @@ def resolve_key_ss58(key: Ss58Address | Keypair | str) -> Ss58Address: def resolve_key_ss58_encrypted( key: Ss58Address | Keypair | str, - context: GenericCtx, password: str | None = None, + password_provider: PasswordProvider = NoPassword(), ) -> Ss58Address: """ Resolves a keypair or key name to its corresponding SS58 address. @@ -288,17 +247,56 @@ def resolve_key_ss58_encrypted( if is_ss58_address(key): return key - try: - keypair = classic_load_key(key, password=password) - except json.JSONDecodeError: - context.info(f"Please provide the password for the key {key}") - password = getpass() - keypair = classic_load_key(key, password=password) - except FileNotFoundError: - raise ValueError( - f"Key is not a valid SS58 address nor a valid key name: {key}" - ) + keypair = try_classic_load_key( + key, password=password, password_provider=password_provider + ) address = keypair.ss58_address return check_ss58_address(address, keypair.ss58_format) + + +def local_key_addresses( + password_provider: PasswordProvider = NoPassword(), +) -> dict[str, Ss58Address]: + """ + Retrieves a mapping of local key names to their SS58 addresses. + If password is passed, it will be used to decrypt every key. + If password is not passed and ctx is, + the user will be prompted for the password. + """ + + # TODO: refactor to return mapping of (key_name -> Keypair) + # Outside of this, Keypair can be mapped to Ss58Address + + home = Path.home() + key_dir = home / ".commune" / "key" + + key_names = [ + f.stem + for f in key_dir.iterdir() + if f.is_file() and not f.name.startswith(".") + ] + + addresses_map: dict[str, Ss58Address] = {} + + for key_name in key_names: + # issue #12 https://github.com/agicommies/communex/issues/12 + # added check for key2address to stop error + # from being thrown by wrong key type. + if key_name == "key2address": + print( + "key2address is saved in an invalid format. It will be ignored." + ) + continue + + password = password_provider.get_password(key_name) + try: + keypair = classic_load_key(key_name, password=password) + except PasswordNotProvidedError: + password = password_provider.ask_password(key_name) + keypair = classic_load_key(key_name, password=password) + + addresses_map[key_name] = check_ss58_address(keypair.ss58_address) + + return addresses_map diff --git a/src/communex/compat/storage.py b/src/communex/compat/storage.py index 8a3b185f..79f8d4c2 100644 --- a/src/communex/compat/storage.py +++ b/src/communex/compat/storage.py @@ -14,6 +14,7 @@ from nacl.secret import SecretBox from nacl.utils import random +from communex.errors import PasswordNotProvidedError from communex.util import ensure_parent_dir_exists # from cryptography.fernet import Fernet @@ -84,14 +85,16 @@ def classic_load( full_path = os.path.expanduser(os.path.join(COMMUNE_HOME, path)) with open(full_path, "r") as file: body = json.load(file) - if body["encrypted"] and password is None: - raise json.JSONDecodeError( - "Data is encrypted but no password provided", "", 0 - ) - if body["encrypted"] and password is not None: - content = _decrypt_data(password, body["data"]) - else: - content = body["data"] + + if body["encrypted"] and password is None: + raise PasswordNotProvidedError( + "Data is encrypted but no password provided" + ) + if body["encrypted"] and password is not None: + content = _decrypt_data(password, body["data"]) + else: + content = body["data"] + assert isinstance(body, dict) assert isinstance(body["timestamp"], int) assert isinstance(content, (dict, list, tuple, set, float, str, int)) @@ -111,7 +114,6 @@ def classic_put( encrypt: Whether to encrypt the data. Todo: - * Encryption support. * Other serialization modes support. Only json mode is supported now. Raises: @@ -127,6 +129,7 @@ def classic_put( raise TypeError( f"Invalid type for commune data storage value: {type(value)}" ) + timestamp = int(time.time()) full_path = os.path.expanduser(os.path.join(COMMUNE_HOME, path)) @@ -140,13 +143,13 @@ def classic_put( if password: value = _encrypt_data(password, value) - encrypt = True + encrypted = True else: - encrypt = False + encrypted = False with open(full_path, "w") as file: json.dump( - {"data": value, "encrypted": encrypt, "timestamp": timestamp}, + {"data": value, "encrypted": encrypted, "timestamp": timestamp}, file, indent=4, ) diff --git a/src/communex/errors.py b/src/communex/errors.py index 35b67f09..71b25f56 100644 --- a/src/communex/errors.py +++ b/src/communex/errors.py @@ -12,3 +12,19 @@ class NetworkQueryError(NetworkError): class NetworkTimeoutError(NetworkError): """Timeout error""" + + +class PasswordError(Exception): + """Password related error.""" + + +class PasswordNotProvidedError(PasswordError): + """Password is not provided.""" + + +class InvalidPasswordError(PasswordError): + """Password is invalid.""" + + +class KeyNotFoundError(Exception): + """Key not found error.""" diff --git a/src/communex/password.py b/src/communex/password.py new file mode 100644 index 00000000..e1211a07 --- /dev/null +++ b/src/communex/password.py @@ -0,0 +1,37 @@ +from typing import Protocol + +from communex.errors import PasswordNotProvidedError + + +class PasswordProvider(Protocol): + def get_password(self, key_name: str) -> str | None: + """ + Provides a password for the given key name, if it is know. If not, + returns None. In that case, `ask_password` can be called to ask for the + password depending on the implementation. + """ + return None + + def ask_password(self, key_name: str) -> str: + """ + Either provides a password for the given key or raises an + PasswordNotProvidedError error. + """ + raise PasswordNotProvidedError( + f"Password not provided for key '{key_name}'" + ) + + +class NoPassword(PasswordProvider): + pass + + +class Password(PasswordProvider): + def __init__(self, password: str): + self._password = password + + def get_password(self, key_name: str) -> str: + return self._password + + def ask_password(self, key_name: str) -> str: + return self._password