From 797861f39dd74d61fb0ef0a5a2d7342669dbf313 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Tue, 7 Jan 2025 16:02:29 +0800 Subject: [PATCH 01/19] feat: add account service --- .../starknet-snap/src/__tests__/helper.ts | 27 +- .../src/state/__tests__/helper.ts | 3 + .../src/state/account-state-manager.ts | 78 +++++ .../starknet-snap/src/state/state-manager.ts | 5 + packages/starknet-snap/src/types/snapState.ts | 6 +- .../starknet-snap/src/utils/contract.test.ts | 70 +++++ packages/starknet-snap/src/utils/contract.ts | 61 ++++ .../starknet-snap/src/utils/exceptions.ts | 35 +++ packages/starknet-snap/src/utils/index.ts | 1 + .../starknet-snap/src/utils/starknetUtils.ts | 3 +- .../src/wallet/account/__test__/helper.ts | 78 +++++ .../src/wallet/account/account.ts | 66 +++++ .../src/wallet/account/cairo0.test.ts | 32 ++ .../src/wallet/account/cairo0.ts | 22 ++ .../src/wallet/account/cairo1.test.ts | 28 ++ .../src/wallet/account/cairo1.ts | 18 ++ .../src/wallet/account/contract.test.ts | 276 ++++++++++++++++++ .../src/wallet/account/contract.ts | 189 ++++++++++++ .../src/wallet/account/discovery.test.ts | 267 +++++++++++++++++ .../src/wallet/account/discovery.ts | 92 ++++++ .../starknet-snap/src/wallet/account/index.ts | 8 + .../src/wallet/account/keypair.ts | 23 ++ .../src/wallet/account/reader.test.ts | 65 +++++ .../src/wallet/account/reader.ts | 36 +++ .../src/wallet/account/service.test.ts | 193 ++++++++++++ .../src/wallet/account/service.ts | 95 ++++++ .../starknet-snap/src/wallet/account/type.ts | 13 + 27 files changed, 1778 insertions(+), 12 deletions(-) create mode 100644 packages/starknet-snap/src/utils/contract.test.ts create mode 100644 packages/starknet-snap/src/utils/contract.ts create mode 100644 packages/starknet-snap/src/wallet/account/__test__/helper.ts create mode 100644 packages/starknet-snap/src/wallet/account/account.ts create mode 100644 packages/starknet-snap/src/wallet/account/cairo0.test.ts create mode 100644 packages/starknet-snap/src/wallet/account/cairo0.ts create mode 100644 packages/starknet-snap/src/wallet/account/cairo1.test.ts create mode 100644 packages/starknet-snap/src/wallet/account/cairo1.ts create mode 100644 packages/starknet-snap/src/wallet/account/contract.test.ts create mode 100644 packages/starknet-snap/src/wallet/account/contract.ts create mode 100644 packages/starknet-snap/src/wallet/account/discovery.test.ts create mode 100644 packages/starknet-snap/src/wallet/account/discovery.ts create mode 100644 packages/starknet-snap/src/wallet/account/index.ts create mode 100644 packages/starknet-snap/src/wallet/account/keypair.ts create mode 100644 packages/starknet-snap/src/wallet/account/reader.test.ts create mode 100644 packages/starknet-snap/src/wallet/account/reader.ts create mode 100644 packages/starknet-snap/src/wallet/account/service.test.ts create mode 100644 packages/starknet-snap/src/wallet/account/service.ts create mode 100644 packages/starknet-snap/src/wallet/account/type.ts diff --git a/packages/starknet-snap/src/__tests__/helper.ts b/packages/starknet-snap/src/__tests__/helper.ts index 65278361..1d53e7f1 100644 --- a/packages/starknet-snap/src/__tests__/helper.ts +++ b/packages/starknet-snap/src/__tests__/helper.ts @@ -108,6 +108,21 @@ export async function generateBip44Entropy( ]); } +/** + * Method to generate Bip44 Node by index. + * + * @param [mnemonic] - Optional, the provided mnemonic string. + * @returns The deriver function for the derivation path. + */ +export async function generateKeyDeriver(mnemonic?: string) { + let mnemonicString = mnemonic; + if (!mnemonicString) { + mnemonicString = generateMnemonic(); + } + const node = await generateBip44Entropy(mnemonicString); + return await getBIP44AddressKeyDeriver(node); +} + /** * Method to generate starknet account. * @@ -120,18 +135,14 @@ export async function generateAccounts( network: constants.StarknetChainId | string, cnt: number = 1, cairoVersion = '1', - mnemonic?: string, + startIndex = 0, + mnemonicString: string = generateMnemonic(), ) { const accounts: StarknetAccount[] = []; - let mnemonicString = mnemonic; - if (!mnemonicString) { - mnemonicString = generateMnemonic(); - } - for (let i = 0; i < cnt; i++) { + for (let i = startIndex; i < startIndex + cnt; i++) { // simulate the bip44 entropy generation - const node = await generateBip44Entropy(mnemonicString); - const keyDeriver = await getBIP44AddressKeyDeriver(node); + const keyDeriver = await generateKeyDeriver(mnemonicString); const { privateKey } = await keyDeriver(i); if (!privateKey) { diff --git a/packages/starknet-snap/src/state/__tests__/helper.ts b/packages/starknet-snap/src/state/__tests__/helper.ts index 0efdc479..60b5fac4 100644 --- a/packages/starknet-snap/src/state/__tests__/helper.ts +++ b/packages/starknet-snap/src/state/__tests__/helper.ts @@ -34,6 +34,7 @@ export const mockState = async ({ transactions, currentNetwork, transactionRequests, + removedAccounts, }: { accounts?: StarknetAccount[]; tokens?: Erc20Token[]; @@ -41,6 +42,7 @@ export const mockState = async ({ transactions?: Transaction[]; currentNetwork?: Network; transactionRequests?: TransactionRequest[]; + removedAccounts?: Record; }) => { const getDataSpy = jest.spyOn(snapHelper, 'getStateData'); const setDataSpy = jest.spyOn(snapHelper, 'setStateData'); @@ -51,6 +53,7 @@ export const mockState = async ({ transactions: transactions ?? [], currentNetwork, transactionRequests: transactionRequests ?? [], + removedAccounts: removedAccounts ?? {}, }; getDataSpy.mockResolvedValue(state); return { diff --git a/packages/starknet-snap/src/state/account-state-manager.ts b/packages/starknet-snap/src/state/account-state-manager.ts index d23c2a12..75741e2b 100644 --- a/packages/starknet-snap/src/state/account-state-manager.ts +++ b/packages/starknet-snap/src/state/account-state-manager.ts @@ -30,6 +30,9 @@ export class AccountStateManager extends StateManager { if (data.deployRequired !== undefined) { dataInState.deployRequired = data.deployRequired; } + if (data.cairoVersion !== undefined) { + dataInState.cairoVersion = data.cairoVersion; + } } /** @@ -116,6 +119,28 @@ export class AccountStateManager extends StateManager { } } + async upsertAccount(data: AccContract): Promise { + try { + await this.update(async (state: SnapState) => { + const accountInState = await this.getAccount( + { + address: data.address, + chainId: data.chainId, + }, + state, + ); + + if (accountInState) { + this.updateEntity(accountInState, data); + } else { + state.accContracts.push(data); + } + }); + } catch (error) { + throw new StateManagerError(error.message); + } + } + async updateAccountAsDeploy({ address, chainId, @@ -147,4 +172,57 @@ export class AccountStateManager extends StateManager { throw new StateManagerError(error.message); } } + + async getNextIndex(chainId: string): Promise { + let idx = 0; + await this.update(async (state: SnapState) => { + // Choose the deleted account index over the last index (accContracts length). + // If the removedAccounts array is empty, then fallback with the last index. + idx = + state.removedAccounts?.[chainId]?.shift() ?? state.accContracts.length; + }); + return idx; + } + + async removeAccount({ + address, + chainId, + }: { + address: string; + chainId: string; + }): Promise { + try { + await this.update(async (state: SnapState) => { + const accountInState = await this.getAccount( + { + address, + chainId, + }, + state, + ); + + if (!accountInState) { + throw new StateManagerError(`Account does not exist`); + } + + state.accContracts = state.accContracts.filter( + (account) => + account.address !== address && account.chainId === chainId, + ); + + // Safeguard to ensure the removedAccounts object is initialized. + if (!state.removedAccounts) { + state.removedAccounts = {}; + } + + if (!Object.hasOwnProperty.call(state.removedAccounts, chainId)) { + state.removedAccounts[chainId] = []; + } + + state.removedAccounts[chainId].push(accountInState.addressIndex); + }); + } catch (error) { + throw new StateManagerError(error.message); + } + } } diff --git a/packages/starknet-snap/src/state/state-manager.ts b/packages/starknet-snap/src/state/state-manager.ts index 81b7a997..9d138d74 100644 --- a/packages/starknet-snap/src/state/state-manager.ts +++ b/packages/starknet-snap/src/state/state-manager.ts @@ -15,6 +15,7 @@ export abstract class StateManager extends SnapStateManager { networks: [], transactions: [], transactionRequests: [], + removedAccounts: {}, }; } @@ -38,6 +39,10 @@ export abstract class StateManager extends SnapStateManager { state.transactions = []; } + if (!state.removedAccounts) { + state.removedAccounts = {}; + } + return state; }); } diff --git a/packages/starknet-snap/src/types/snapState.ts b/packages/starknet-snap/src/types/snapState.ts index ac9316e5..05089430 100644 --- a/packages/starknet-snap/src/types/snapState.ts +++ b/packages/starknet-snap/src/types/snapState.ts @@ -15,6 +15,7 @@ export type SnapState = { transactions: Transaction[]; currentNetwork?: Network; transactionRequests?: TransactionRequest[]; + removedAccounts?: Record; }; export type TokenTransferData = { @@ -62,11 +63,12 @@ export type AccContract = { publicKey: string; // in hex address: string; // in hex addressIndex: number; - derivationPath: string; - deployTxnHash: string; // in hex + derivationPath?: string; + deployTxnHash?: string; // in hex chainId: string; // in hex upgradeRequired?: boolean; deployRequired?: boolean; + cairoVersion?: string; }; export type Erc20Token = { diff --git a/packages/starknet-snap/src/utils/contract.test.ts b/packages/starknet-snap/src/utils/contract.test.ts new file mode 100644 index 00000000..a0455cd3 --- /dev/null +++ b/packages/starknet-snap/src/utils/contract.test.ts @@ -0,0 +1,70 @@ +import { BlockTag, Provider } from 'starknet'; + +import { generateAccounts } from '../__tests__/helper'; +import { ETHER_MAINNET, STARKNET_SEPOLIA_TESTNET_NETWORK } from './constants'; +import { ContractReader } from './contract'; +import { + ContractNotDeployedError, + ContractReadError, + CONTRACT_NOT_DEPLOYED_ERROR, +} from './exceptions'; + +describe('ContractReader', () => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + + const mockProvider = () => { + const callContractSpy = jest.spyOn(Provider.prototype, 'callContract'); + return callContractSpy; + }; + + const getContractCallArgs = async () => { + const [account] = await generateAccounts(network.chainId, 1); + return { + contractAddress: ETHER_MAINNET.address, + entrypoint: 'balanceOf', + calldata: [account.address], + }; + }; + + describe('callContract', () => { + it('returns the response of the contract method', async () => { + const callContractSpy = mockProvider(); + const balance = '1000000000000000000'; + callContractSpy.mockResolvedValue([balance]); + + const args = await getContractCallArgs(); + + const reader = new ContractReader(network); + const result = await reader.callContract(args); + + expect(result).toStrictEqual([balance]); + expect(callContractSpy).toHaveBeenCalledWith(args, BlockTag.LATEST); + }); + + it('throws a `ContractNotDeployedError` if the contract is not found', async () => { + const callContractSpy = mockProvider(); + callContractSpy.mockRejectedValue(new Error(CONTRACT_NOT_DEPLOYED_ERROR)); + + const args = await getContractCallArgs(); + + const reader = new ContractReader(network); + + await expect(reader.callContract(args)).rejects.toThrow( + ContractNotDeployedError, + ); + }); + + it('throws a `ContractReadError` if an error is thrown', async () => { + const callContractSpy = mockProvider(); + callContractSpy.mockRejectedValue(new Error('Read Error')); + + const args = await getContractCallArgs(); + + const reader = new ContractReader(network); + + await expect(reader.callContract(args)).rejects.toThrow( + ContractReadError, + ); + }); + }); +}); diff --git a/packages/starknet-snap/src/utils/contract.ts b/packages/starknet-snap/src/utils/contract.ts new file mode 100644 index 00000000..0bc1c8a1 --- /dev/null +++ b/packages/starknet-snap/src/utils/contract.ts @@ -0,0 +1,61 @@ +import type { + BlockIdentifier, + CallContractResponse, + Provider, + RawCalldata, +} from 'starknet'; +import { BlockTag } from 'starknet'; + +import type { Network } from '../types/snapState'; +import { + ContractNotDeployedError, + ContractReadError, + CONTRACT_NOT_DEPLOYED_ERROR, +} from './exceptions'; +import { getProvider } from './starknetUtils'; + +export class ContractReader { + rpcProvider: Provider; + + constructor(network: Network) { + this.rpcProvider = getProvider(network); + } + + /** + * Call a contract method. + * + * @param param - The parameters to pass to the contract. + * @param param.contractAddress - The address of the contract to call. + * @param param.entrypoint - The entrypoint of the contract to call. + * @param param.calldata - The calldata to pass to the contract. + * @param [param.blockIdentifier] - Optional, the block to call the contract at, default `lastest`. + * @returns A promise that resolves to the response of the contract call. + */ + async callContract({ + contractAddress, + entrypoint, + calldata = [], + blockIdentifier = BlockTag.LATEST, + }: { + contractAddress: string; + entrypoint: string; + calldata?: RawCalldata; + blockIdentifier?: BlockIdentifier; + }): Promise { + try { + return await this.rpcProvider.callContract( + { + contractAddress, + entrypoint, + calldata, + }, + blockIdentifier, + ); + } catch (error) { + if (!error.message.includes(CONTRACT_NOT_DEPLOYED_ERROR)) { + throw new ContractReadError(error.message); + } + throw new ContractNotDeployedError(); + } + } +} diff --git a/packages/starknet-snap/src/utils/exceptions.ts b/packages/starknet-snap/src/utils/exceptions.ts index 840a4e3a..a3b70421 100644 --- a/packages/starknet-snap/src/utils/exceptions.ts +++ b/packages/starknet-snap/src/utils/exceptions.ts @@ -31,6 +31,41 @@ export class InvalidNetworkError extends SnapError { } } +export class AccountNotFoundError extends SnapError { + constructor(message?: string) { + super( + message ?? 'Account not found', + createWalletRpcErrorWrapper(WalletRpcErrorCode.Unknown), + ); + } +} + +export class AccountDiscoveryError extends SnapError { + constructor(message?: string) { + super( + message ?? 'Account discovery found', + createWalletRpcErrorWrapper(WalletRpcErrorCode.Unknown), + ); + } +} + +export class ContractReadError extends SnapError { + constructor(message: string) { + super(message, createWalletRpcErrorWrapper(WalletRpcErrorCode.Unknown)); + } +} + +export const CONTRACT_NOT_DEPLOYED_ERROR = 'Contract not found'; + +export class ContractNotDeployedError extends SnapError { + constructor(message?: string) { + super( + message ?? CONTRACT_NOT_DEPLOYED_ERROR, + createWalletRpcErrorWrapper(WalletRpcErrorCode.Unknown), + ); + } +} + export class UserRejectedOpError extends UserRejectedRequestError { constructor(message?: string) { super(message, createWalletRpcErrorWrapper(WalletRpcErrorCode.UserDeny)); diff --git a/packages/starknet-snap/src/utils/index.ts b/packages/starknet-snap/src/utils/index.ts index d4a57fa1..c8987927 100644 --- a/packages/starknet-snap/src/utils/index.ts +++ b/packages/starknet-snap/src/utils/index.ts @@ -11,4 +11,5 @@ export * from './string'; export * from './token'; export * from './snap-ui'; export * from './explorer'; +export * from './contract'; // TODO: add other utils diff --git a/packages/starknet-snap/src/utils/starknetUtils.ts b/packages/starknet-snap/src/utils/starknetUtils.ts index 6501eb96..6b1b5ce5 100644 --- a/packages/starknet-snap/src/utils/starknetUtils.ts +++ b/packages/starknet-snap/src/utils/starknetUtils.ts @@ -18,7 +18,6 @@ import type { DeployAccountSignerDetails, CairoVersion, InvocationsSignerDetails, - ProviderInterface, GetTransactionReceiptResponse, BigNumberish, ArraySignatureType, @@ -128,7 +127,7 @@ export const getCallDataArray = (callDataStr: string): string[] => { export const getProvider = ( network: Network, blockIdentifier?: BlockIdentifierEnum, -): ProviderInterface => { +): Provider => { let providerParam: ProviderOptions = {}; providerParam = { nodeUrl: getRPCUrl(network.chainId), diff --git a/packages/starknet-snap/src/wallet/account/__test__/helper.ts b/packages/starknet-snap/src/wallet/account/__test__/helper.ts new file mode 100644 index 00000000..0d566a64 --- /dev/null +++ b/packages/starknet-snap/src/wallet/account/__test__/helper.ts @@ -0,0 +1,78 @@ +import { generateMnemonic } from 'bip39'; + +import { generateAccounts } from '../../../__tests__/helper'; +import { MIN_ACC_CONTRACT_VERSION } from '../../../utils/constants'; +import { Account } from '../account'; +import { Cairo1Contract } from '../cairo1'; +import { AccountContractReader } from '../reader'; + +export const upgradedContractVersion = `2.${MIN_ACC_CONTRACT_VERSION[1]}.0`; +export const upgradedContractVersionInHex = `322e332e30`; +export const nonUpgradedContractVersion = `0.0.0`; + +export const mockAccountContractReader = ({ + balance = BigInt(1000000000000000000), + version = upgradedContractVersion, +}) => { + const getVersionSpy = jest.spyOn( + AccountContractReader.prototype, + 'getVersion', + ); + const getEthBalanceSpy = jest.spyOn( + AccountContractReader.prototype, + 'getEthBalance', + ); + + getVersionSpy.mockResolvedValue(version); + getEthBalanceSpy.mockResolvedValue(balance); + + return { getVersionSpy, getEthBalanceSpy }; +}; + +export const createAccountContract = async ( + network, + hdIndex = 0, + ContractCtor = Cairo1Contract, + mnemonicString = generateMnemonic(), +) => { + const [account] = await generateAccounts( + network.chainId, + 1, + '1', + hdIndex, + mnemonicString, + ); + + const accountContractReader = new AccountContractReader(network); + + const contract = new ContractCtor(account.publicKey, accountContractReader); + + return { + accountContractReader, + contract, + account, + }; +}; + +export const createAccountObject = async (network, hdIndex = 0) => { + const { account, accountContractReader, contract } = + await createAccountContract(network, hdIndex); + + const { privateKey, publicKey, chainId, addressIndex } = account; + + const accountObj = new Account({ + privateKey, + publicKey, + chainId, + hdIndex: addressIndex, + addressSalt: publicKey, + accountContract: contract, + }); + + return { + accountContractReader, + contract, + accountObj, + account, + }; +}; diff --git a/packages/starknet-snap/src/wallet/account/account.ts b/packages/starknet-snap/src/wallet/account/account.ts new file mode 100644 index 00000000..b15e0fde --- /dev/null +++ b/packages/starknet-snap/src/wallet/account/account.ts @@ -0,0 +1,66 @@ +import type { AccContract } from '../../types/snapState'; +import type { CairoAccountContract } from './contract'; + +/** + * Account object that holds the private key, public key, address, chain id, + * hd index, address salt and the `CairoAccountContract`. + * + * It can serialize itself to be persisted in the state. + */ +export class Account { + privateKey: string; + + publicKey: string; + + address: string; + + chainId: string; + + hdIndex: number; + + addressSalt: string; + + /** + * The Cairo version of the account contract. + * `1` referred to Cairo 1. + * `0` referred to Cairo 0. + */ + cairoVersion: string; + + accountContract: CairoAccountContract; + + constructor(props: { + privateKey: string; + publicKey: string; + chainId: string; + hdIndex: number; + addressSalt: string; + accountContract: CairoAccountContract; + }) { + this.privateKey = props.privateKey; + this.publicKey = props.publicKey; + this.chainId = props.chainId; + this.hdIndex = props.hdIndex; + this.addressSalt = props.addressSalt; + this.address = props.accountContract.address; + + this.cairoVersion = props.accountContract.cairoVerion.toString(10); + this.accountContract = props.accountContract; + } + + /** + * Serialize the `Account` object. + * + * @returns A promise that resolves to the serialized `Account` object. + */ + async serialize(): Promise { + return { + addressSalt: this.publicKey, + publicKey: this.publicKey, + address: this.address, + addressIndex: this.hdIndex, + chainId: this.chainId, + cairoVersion: this.cairoVersion, + }; + } +} diff --git a/packages/starknet-snap/src/wallet/account/cairo0.test.ts b/packages/starknet-snap/src/wallet/account/cairo0.test.ts new file mode 100644 index 00000000..f07f0f76 --- /dev/null +++ b/packages/starknet-snap/src/wallet/account/cairo0.test.ts @@ -0,0 +1,32 @@ +import { CallData, hash } from 'starknet'; + +import { + ACCOUNT_CLASS_HASH_LEGACY, + STARKNET_SEPOLIA_TESTNET_NETWORK, +} from '../../utils/constants'; +import { createAccountContract } from './__test__/helper'; +import { Cairo0Contract } from './cairo0'; + +jest.mock('../../utils/logger'); + +describe('Cairo0Contract', () => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + + describe('getCallData', () => { + it('returns the call data', async () => { + const { + contract, + account: { publicKey }, + } = await createAccountContract(network, 0, Cairo0Contract); + + // contract.address is a getter method that making a call to calculateAddress. + expect(contract.getCallData()).toStrictEqual( + CallData.compile({ + implementation: ACCOUNT_CLASS_HASH_LEGACY, + selector: hash.getSelectorFromName('initialize'), + calldata: CallData.compile({ signer: publicKey, guardian: '0' }), + }), + ); + }); + }); +}); diff --git a/packages/starknet-snap/src/wallet/account/cairo0.ts b/packages/starknet-snap/src/wallet/account/cairo0.ts new file mode 100644 index 00000000..ddf7c822 --- /dev/null +++ b/packages/starknet-snap/src/wallet/account/cairo0.ts @@ -0,0 +1,22 @@ +import type { Calldata } from 'starknet'; +import { CallData, hash } from 'starknet'; + +import { + ACCOUNT_CLASS_HASH_LEGACY, + PROXY_CONTRACT_HASH, +} from '../../utils/constants'; +import { CairoAccountContract } from './contract'; + +export class Cairo0Contract extends CairoAccountContract { + cairoVerion = 0; + + classhash: string = PROXY_CONTRACT_HASH; + + getCallData(): Calldata { + return CallData.compile({ + implementation: ACCOUNT_CLASS_HASH_LEGACY, + selector: hash.getSelectorFromName('initialize'), + calldata: CallData.compile({ signer: this.publicKey, guardian: '0' }), + }); + } +} diff --git a/packages/starknet-snap/src/wallet/account/cairo1.test.ts b/packages/starknet-snap/src/wallet/account/cairo1.test.ts new file mode 100644 index 00000000..f9c9e2c2 --- /dev/null +++ b/packages/starknet-snap/src/wallet/account/cairo1.test.ts @@ -0,0 +1,28 @@ +import { CallData } from 'starknet'; + +import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../../utils/constants'; +import { createAccountContract } from './__test__/helper'; +import { Cairo1Contract } from './cairo1'; + +jest.mock('../../utils/logger'); + +describe('Cairo1Contract', () => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + + describe('getCallData', () => { + it('returns the call data', async () => { + const { + contract, + account: { publicKey }, + } = await createAccountContract(network, 0, Cairo1Contract); + + // contract.address is a getter method that making a call to calculateAddress. + expect(contract.getCallData()).toStrictEqual( + CallData.compile({ + signer: publicKey, + guardian: '0', + }), + ); + }); + }); +}); diff --git a/packages/starknet-snap/src/wallet/account/cairo1.ts b/packages/starknet-snap/src/wallet/account/cairo1.ts new file mode 100644 index 00000000..4b2b665a --- /dev/null +++ b/packages/starknet-snap/src/wallet/account/cairo1.ts @@ -0,0 +1,18 @@ +import type { Calldata } from 'starknet'; +import { CallData } from 'starknet'; + +import { ACCOUNT_CLASS_HASH } from '../../utils/constants'; +import { CairoAccountContract } from './contract'; + +export class Cairo1Contract extends CairoAccountContract { + cairoVerion = 1; + + classhash: string = ACCOUNT_CLASS_HASH; + + getCallData(): Calldata { + return CallData.compile({ + signer: this.publicKey, + guardian: '0', + }); + } +} diff --git a/packages/starknet-snap/src/wallet/account/contract.test.ts b/packages/starknet-snap/src/wallet/account/contract.test.ts new file mode 100644 index 00000000..aec4f51b --- /dev/null +++ b/packages/starknet-snap/src/wallet/account/contract.test.ts @@ -0,0 +1,276 @@ +import { generateMnemonic } from 'bip39'; + +import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../../utils/constants'; +import { + ContractNotDeployedError, + ContractReadError, +} from '../../utils/exceptions'; +import { + createAccountContract, + mockAccountContractReader, + nonUpgradedContractVersion, + upgradedContractVersion, +} from './__test__/helper'; +import { Cairo0Contract } from './cairo0'; +import { Cairo1Contract } from './cairo1'; + +jest.mock('../../utils/logger'); + +describe('CairoAccountContract', () => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + + describe('getVersion', () => { + it('returns the contract version', async () => { + const { contract } = await createAccountContract(network); + + mockAccountContractReader({}); + + const result = await contract.getVersion(); + + expect(result).toStrictEqual(upgradedContractVersion); + }); + + it('caches the result if `getVersion` was called', async () => { + const { contract } = await createAccountContract(network); + + const { getVersionSpy } = mockAccountContractReader({}); + + await contract.getVersion(); + await contract.getVersion(); + + expect(getVersionSpy).toHaveBeenCalledTimes(1); + }); + }); + + describe('isDeployed', () => { + it('returns true if the account has deployed', async () => { + const { contract } = await createAccountContract(network); + + const { getVersionSpy } = mockAccountContractReader({}); + + const result = await contract.isDeployed(); + + expect(result).toBe(true); + expect(getVersionSpy).toHaveBeenCalledTimes(1); + }); + + it('returns false if the account has not deployed', async () => { + const { contract } = await createAccountContract(network); + + const { getVersionSpy } = mockAccountContractReader({}); + getVersionSpy.mockRejectedValue(new ContractNotDeployedError()); + + const result = await contract.isDeployed(); + + expect(result).toBe(false); + }); + + it('throws an error if a `ContractReadError` was throw', async () => { + const { contract } = await createAccountContract(network); + + const { getVersionSpy } = mockAccountContractReader({}); + getVersionSpy.mockRejectedValue( + new ContractReadError('Read contract error'), + ); + + await expect(contract.isDeployed()).rejects.toThrow(ContractReadError); + }); + }); + + describe('isUpgraded', () => { + it('returns true if the contract version meet the requirement', async () => { + const { contract } = await createAccountContract(network); + + const { getVersionSpy } = mockAccountContractReader({}); + + const result = await contract.isUpgraded(); + + expect(result).toBe(true); + expect(getVersionSpy).toHaveBeenCalledTimes(1); + }); + + it('returns false if the contract version does not meet the requirement', async () => { + const { contract } = await createAccountContract(network); + + const { getVersionSpy } = mockAccountContractReader({}); + getVersionSpy.mockResolvedValue(nonUpgradedContractVersion); + + const result = await contract.isUpgraded(); + + expect(result).toBe(false); + }); + + it('throws an error if the contract is not deployed', async () => { + const { contract } = await createAccountContract(network); + + const { getVersionSpy } = mockAccountContractReader({}); + getVersionSpy.mockRejectedValue(new ContractNotDeployedError()); + + await expect(contract.isUpgraded()).rejects.toThrow( + ContractNotDeployedError, + ); + }); + }); + + describe('getEthBalance', () => { + it('returns the ETH token balance', async () => { + const balance = BigInt(1000000000000000000); + mockAccountContractReader({ + balance, + }); + const { contract } = await createAccountContract(network); + + const result = await contract.getEthBalance(); + + expect(result).toStrictEqual(balance); + }); + + it('caches the result if `getEthBalance` was called', async () => { + const { getEthBalanceSpy } = mockAccountContractReader({}); + const { contract } = await createAccountContract(network); + + await contract.getEthBalance(); + await contract.getEthBalance(); + + expect(getEthBalanceSpy).toHaveBeenCalledTimes(1); + }); + }); + + describe('fromAccountContract', () => { + it('creates a new `CairoAccountContract` object from an `CairoAccountContract` object', async () => { + const { getEthBalanceSpy, getVersionSpy } = mockAccountContractReader({}); + // Make sure the mnemonic is the same for both contracts + const mnemonicString = generateMnemonic(); + const { contract: cairo0Contract } = await createAccountContract( + network, + 0, + Cairo0Contract, + mnemonicString, + ); + const { contract: cairo1Contract } = await createAccountContract( + network, + 0, + Cairo1Contract, + mnemonicString, + ); + + // assgin _version and _balance to the instance before copy to new Cairo1Contract + const versionFromCairo0Contract = await cairo0Contract.getVersion(); + const ethBalanceFromCairo0Contract = await cairo0Contract.getEthBalance(); + const newContract = Cairo1Contract.fromAccountContract(cairo0Contract); + const versionFromCairo1Contract = await newContract.getVersion(); + const ethBalanceFromCairo1Contract = await newContract.getEthBalance(); + + expect(newContract).toBeInstanceOf(Cairo1Contract); + expect(newContract.address).toStrictEqual(cairo0Contract.address); + expect(newContract.address).not.toStrictEqual(cairo1Contract.address); + expect(newContract.callData).toStrictEqual(cairo1Contract.callData); + expect(newContract.callData).not.toStrictEqual(cairo0Contract.callData); + expect(versionFromCairo1Contract).toStrictEqual( + versionFromCairo0Contract, + ); + expect(ethBalanceFromCairo1Contract).toStrictEqual( + ethBalanceFromCairo0Contract, + ); + expect(getEthBalanceSpy).toHaveBeenCalledTimes(1); + expect(getVersionSpy).toHaveBeenCalledTimes(1); + }); + }); + + describe('isRequireUpgrade', () => { + it('returns true if the contract requires upgrade', async () => { + mockAccountContractReader({ + version: nonUpgradedContractVersion, + }); + const { contract } = await createAccountContract( + network, + 0, + Cairo0Contract, + ); + + const result = await contract.isRequireUpgrade(); + + expect(result).toBe(true); + }); + + it('returns false if the contract has already upgraded', async () => { + mockAccountContractReader({ + version: upgradedContractVersion, + }); + const { contract } = await createAccountContract( + network, + 0, + Cairo1Contract, + ); + + const result = await contract.isRequireUpgrade(); + + expect(result).toBe(false); + }); + + it('returns false if the contract is not deployed', async () => { + const { getVersionSpy } = mockAccountContractReader({}); + getVersionSpy.mockRejectedValue(new ContractNotDeployedError()); + + const { contract } = await createAccountContract( + network, + 0, + Cairo1Contract, + ); + + const result = await contract.isRequireUpgrade(); + + expect(result).toBe(false); + }); + + it('throws an error if a `ContractReadError` was thrown', async () => { + const { getVersionSpy } = mockAccountContractReader({}); + getVersionSpy.mockRejectedValue( + new ContractReadError('Read contract error'), + ); + + const { contract } = await createAccountContract( + network, + 0, + Cairo1Contract, + ); + + await expect(contract.isRequireUpgrade()).rejects.toThrow( + ContractReadError, + ); + }); + }); + + describe('isRequireDeploy', () => { + it('returns true if the contract requires deploy', async () => { + const { getVersionSpy } = mockAccountContractReader({}); + getVersionSpy.mockRejectedValue(new ContractNotDeployedError()); + + const { contract } = await createAccountContract( + network, + 0, + Cairo0Contract, + ); + + const result = await contract.isRequireDeploy(); + + expect(result).toBe(true); + }); + + it('returns false if the contract is not deployed and does not has ETH', async () => { + const { getVersionSpy } = mockAccountContractReader({ + balance: BigInt(0), + }); + getVersionSpy.mockRejectedValue(new ContractNotDeployedError()); + const { contract } = await createAccountContract( + network, + 0, + Cairo1Contract, + ); + + const result = await contract.isRequireUpgrade(); + + expect(result).toBe(false); + }); + }); +}); diff --git a/packages/starknet-snap/src/wallet/account/contract.ts b/packages/starknet-snap/src/wallet/account/contract.ts new file mode 100644 index 00000000..9b5b159b --- /dev/null +++ b/packages/starknet-snap/src/wallet/account/contract.ts @@ -0,0 +1,189 @@ +import type { Calldata } from 'starknet'; +import { hash, addAddressPadding } from 'starknet'; + +import { ContractNotDeployedError } from '../../utils/exceptions'; +import { isGTEMinVersion } from '../../utils/starknetUtils'; +import type { AccountContractReader } from './reader'; + +export abstract class CairoAccountContract { + protected _version: string; + + protected _balance: bigint; + + protected _address: string; + + publicKey: string; + + contractReader: AccountContractReader; + + /** + * Contract method map of the Cairo contract. + * This map is used to map the method name to the entrypoint of the contract. + * The entrypoint may be different for each Cairo contract. + */ + contractMethodMap: { + getVersion: string; + } = { + getVersion: 'getVersion', + }; + + abstract classhash: string; + + abstract cairoVerion: number; + + constructor(publicKey: string, contractReader: AccountContractReader) { + this.publicKey = publicKey; + this.contractReader = contractReader; + } + + /** + * Gets the call data of the contract. + * The call data is used to calculate the address of the contract. + * The call data may be different for each Cairo contract. + * + * @returns call data of the contract. + */ + protected abstract getCallData(): Calldata; + + get callData(): Calldata { + return this.getCallData(); + } + + get address(): string { + if (this._address === undefined) { + this._address = this.calculateAddress(); + } + return this._address; + } + + /** + * Calculate the address of the contract base on the public key, class hash and callData of the contract. + * + * @returns The address of the contract. + */ + protected calculateAddress(): string { + const address = hash.calculateContractAddressFromHash( + this.publicKey, + this.classhash, + this.callData, + 0, + ); + return addAddressPadding(address); + } + + /** + * Gets the Cario version of the contract. + * + * @param [refresh] - Optional, if true the result will not be cached, otherwise it will be cached. Default false. + * @returns A promise that resolve the version of the contract. + */ + async getVersion(refresh = false): Promise { + // TODO: add cache layer + if (refresh || this._version === undefined) { + this._version = await this.contractReader.getVersion(this); + } + return this._version; + } + + /** + * Gets the ETH balance of the contract. + * + * @param [refresh] - Optional, if true the result will not be cached, otherwise it will be cached. Default false. + * @returns A promise that resolve the ETH balance of the contract. + */ + async getEthBalance(refresh = false): Promise { + // TODO: add cache layer + if (refresh || this._balance === undefined) { + this._balance = await this.contractReader.getEthBalance(this); + } + return this._balance; + } + + /** + * Determines whether the account contract is deployed. + * if an `ContractNotDeployedError` is thrown, it means the contract is not deployed. + * + * @param [refresh] - Optional, if true the result will not be cached, otherwise it will be cached. Default false. + * @returns A promise that resolve true if the contract is deployed, false otherwise. + * @throws {ContractReadError} If an error occurs while reading the contract. + */ + async isDeployed(refresh = false): Promise { + try { + await this.getVersion(refresh); + return true; + } catch (error) { + if (error instanceof ContractNotDeployedError) { + return false; + } + throw error; + } + } + + /** + * Determines whether the contract is upgraded. + * if the contract is not deployed, it will throw an error. + * + * @param [refresh] - Optional, if true the result will not be cached, otherwise it will be cached. Default false. + * @returns A promise that resolve true if the contract is upgraded, false otherwise. + * @throws {ContractNotDeployedError} If the contract is not deployed. + * @throws {ContractReadError} If an error occurs while reading the contract. + */ + async isUpgraded(refresh = false): Promise { + const version = await this.getVersion(refresh); + return isGTEMinVersion(version); + } + + /** + * Determines whether require upgrade is needed. + * Returns true if the contract is upgraded, false otherwise. + * If the contract is not deployed, returns false. + * + * @returns A promise that resolves true if the contract requires an upgrade, false otherwise. + */ + async isRequireUpgrade(): Promise { + try { + return !(await this.isUpgraded()); + } catch (error) { + // If the contract is not deployed, a lastest Cairo Contract will be return, + // hence no upgrade is needed. + if (error instanceof ContractNotDeployedError) { + return false; + } + throw error; + } + } + + /** + * Determines whether require deploy is needed. + * A contract requires a deploy if it is not deployed and has balance. + * + * @returns A promise that resolves true if the contract requires a deploy, false otherwise. + */ + async isRequireDeploy(): Promise { + return ( + !(await this.isDeployed()) && (await this.getEthBalance()) > BigInt(0) + ); + } + + /** + * Creates a new account contract from an existing account contract. + * + * @param contract - The existing account contract to copy with. + * @returns A promise that resolves the new account contract. + */ + static fromAccountContract( + this: new (...args: any[]) => CairoAccountContract, + contract: CairoAccountContract, + ): CairoAccountContract { + const newContact = new this(contract.publicKey, contract.contractReader); + + // inherit the address from the original contract + newContact.calculateAddress = newContact.calculateAddress.bind(contract); + + // Copy the metadata from the original contract to prevent duplicated call. + newContact._balance = contract._balance; + newContact._version = contract._version; + + return newContact; + } +} diff --git a/packages/starknet-snap/src/wallet/account/discovery.test.ts b/packages/starknet-snap/src/wallet/account/discovery.test.ts new file mode 100644 index 00000000..a8c31693 --- /dev/null +++ b/packages/starknet-snap/src/wallet/account/discovery.test.ts @@ -0,0 +1,267 @@ +import { Cairo0Contract, Cairo1Contract } from '.'; +import { generateAccounts } from '../../__tests__/helper'; +import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../../utils/constants'; +import { AccountDiscoveryError } from '../../utils/exceptions'; +import { AccountContractDiscovery } from './discovery'; + +jest.mock('../../utils/logger'); + +describe('AccountContractDiscovery', () => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + + describe('getCairoContract', () => { + const mockContractState = (params: { + cairo1: { + isDeployed: boolean; + isUpgraded: boolean; + }; + cairo0: { + isDeployed: boolean; + isUpgraded: boolean; + }; + }) => { + const { cairo1, cairo0 } = params; + const isCairo1DeployedSpy = jest.spyOn( + Cairo1Contract.prototype, + 'isDeployed', + ); + isCairo1DeployedSpy.mockResolvedValue(cairo1.isDeployed); + const isCairo1UpgradedSpy = jest.spyOn( + Cairo1Contract.prototype, + 'isUpgraded', + ); + isCairo1UpgradedSpy.mockResolvedValue(cairo1.isUpgraded); + const isCairo0DeployedSpy = jest.spyOn( + Cairo0Contract.prototype, + 'isDeployed', + ); + isCairo0DeployedSpy.mockResolvedValue(cairo0.isDeployed); + const isCairo0UpgradedSpy = jest.spyOn( + Cairo0Contract.prototype, + 'isUpgraded', + ); + isCairo0UpgradedSpy.mockResolvedValue(cairo0.isUpgraded); + + return { + isCairo1DeployedSpy, + isCairo1UpgradedSpy, + isCairo0DeployedSpy, + isCairo0UpgradedSpy, + }; + }; + + const mockContractEthBalance = ({ + cairo1HasBalance, + cairo0HasBalance, + }: { + cairo1HasBalance: boolean; + cairo0HasBalance: boolean; + }) => { + const cairo1ContractHasEthBalanceSpy = jest.spyOn( + Cairo1Contract.prototype, + 'getEthBalance', + ); + cairo1ContractHasEthBalanceSpy.mockResolvedValue( + cairo1HasBalance ? BigInt(1) : BigInt(0), + ); + + const cairo0ContractHasEthBalanceSpy = jest.spyOn( + Cairo0Contract.prototype, + 'getEthBalance', + ); + cairo0ContractHasEthBalanceSpy.mockResolvedValue( + cairo0HasBalance ? BigInt(1) : BigInt(0), + ); + + return { + cairo1ContractHasEthBalanceSpy, + cairo0ContractHasEthBalanceSpy, + }; + }; + + // Test cases that assume no contact has balance. + // It tests the following cases: + // - Cairo 0 is deployed and upgraded + // - Cairo 1 is deployed and upgraded + // - Cairo 0 is deployed and Cairo 1 is not deployed + // - Cairo 1 is deployed and Cairo 0 is not deployed + // - Cairo 0 is not deployed and Cairo 1 is not deployed + it.each([ + { + cairo0: { + isDeployed: false, + isUpgraded: false, + }, + cairo1: { + isDeployed: true, + isUpgraded: false, + }, + expected: Cairo1Contract, + title: 'Cairo 1 is deployed, Cairo 0 is not deployed', + }, + { + cairo0: { + isDeployed: false, + isUpgraded: false, + }, + cairo1: { + isDeployed: true, + isUpgraded: true, + }, + expected: Cairo1Contract, + title: 'Cairo 1 is deployed and upgraded, Cairo 0 is not deployed', + }, + { + cairo0: { + isDeployed: true, + isUpgraded: false, + }, + cairo1: { + isDeployed: false, + isUpgraded: false, + }, + expected: Cairo0Contract, + title: 'Cairo 0 is deployed, Cairo 1 is not deployed', + }, + { + cairo0: { + isDeployed: true, + isUpgraded: true, + }, + cairo1: { + isDeployed: false, + isUpgraded: false, + }, + expected: Cairo1Contract, + title: 'Cairo 0 is deployed and upgraded, Cairo 1 is not deployed', + }, + { + cairo0: { + isDeployed: false, + isUpgraded: false, + }, + cairo1: { + isDeployed: false, + isUpgraded: false, + }, + expected: Cairo1Contract, + title: 'Cairo 0 is not deployed and Cairo 1 is not deployed', + }, + ])( + 'returns a $expected.name if $title', + async (param: { + cairo0: { + isDeployed: boolean; + isUpgraded: boolean; + }; + cairo1: { + isDeployed: boolean; + isUpgraded: boolean; + }; + expected: typeof Cairo0Contract | typeof Cairo1Contract; + }) => { + const [account] = await generateAccounts(network.chainId, 1); + + mockContractState({ + cairo0: param.cairo0, + cairo1: param.cairo1, + }); + mockContractEthBalance({ + cairo1HasBalance: false, + cairo0HasBalance: false, + }); + + const service = new AccountContractDiscovery(network); + const contract = await service.getContract(account.publicKey); + + expect(contract).toBeInstanceOf(param.expected); + }, + ); + + it.each([ + { + cairo0HasBalance: false, + cairo1HasBalance: true, + expected: Cairo1Contract, + }, + { + cairo0HasBalance: true, + cairo1HasBalance: false, + expected: Cairo0Contract, + }, + ])( + 'returns a $expected.name if no account contract has deployed and the $expected.name has ETH', + async ({ expected, cairo0HasBalance, cairo1HasBalance }) => { + const [account] = await generateAccounts(network.chainId, 1); + + mockContractState({ + cairo0: { + isDeployed: false, + isUpgraded: false, + }, + cairo1: { + isDeployed: false, + isUpgraded: false, + }, + }); + + mockContractEthBalance({ + cairo0HasBalance, + cairo1HasBalance, + }); + + const service = new AccountContractDiscovery(network); + const contract = await service.getContract(account.publicKey); + + expect(contract).toBeInstanceOf(expected); + }, + ); + + it('throws `AccountDiscoveryError` if more than one contracts deployed', async () => { + const [account] = await generateAccounts(network.chainId, 1); + + mockContractState({ + cairo0: { + isDeployed: true, + isUpgraded: false, + }, + cairo1: { + isDeployed: true, + isUpgraded: false, + }, + }); + + const service = new AccountContractDiscovery(network); + + await expect(service.getContract(account.publicKey)).rejects.toThrow( + AccountDiscoveryError, + ); + }); + + it('throws `AccountDiscoveryError` if more than one contracts has ETH', async () => { + const [account] = await generateAccounts(network.chainId, 1); + + mockContractState({ + cairo0: { + isDeployed: false, + isUpgraded: false, + }, + cairo1: { + isDeployed: false, + isUpgraded: false, + }, + }); + + mockContractEthBalance({ + cairo0HasBalance: true, + cairo1HasBalance: true, + }); + + const service = new AccountContractDiscovery(network); + + await expect(service.getContract(account.publicKey)).rejects.toThrow( + AccountDiscoveryError, + ); + }); + }); +}); diff --git a/packages/starknet-snap/src/wallet/account/discovery.ts b/packages/starknet-snap/src/wallet/account/discovery.ts new file mode 100644 index 00000000..5db66606 --- /dev/null +++ b/packages/starknet-snap/src/wallet/account/discovery.ts @@ -0,0 +1,92 @@ +import type { Network } from '../../types/snapState'; +import { AccountDiscoveryError } from '../../utils/exceptions'; +import { Cairo0Contract } from './cairo0'; +import { Cairo1Contract } from './cairo1'; +import type { CairoAccountContract } from './contract'; +import { AccountContractReader } from './reader'; +import type { CairoAccountContractStatic, ICairoAccountContract } from './type'; + +export class AccountContractDiscovery { + protected defaultContractCtor: CairoAccountContractStatic = Cairo1Contract; + + protected contractCtors: ICairoAccountContract[] = [ + Cairo1Contract, + Cairo0Contract, + ]; + + protected network: Network; + + protected ethBalanceThreshold = BigInt(0); + + constructor(network: Network) { + this.network = network; + } + + /** + * Get the contract for the given public key. + * The contract is determined based on the following rules: + * 1. If a contract is deployed, then use the deployed contract. + * 2. If no contract is deployed, but has balance, then use the contract with balance. + * 3. If neither contract is deployed or has balance, then use the default contract. + * + * @param publicKey - The public key to get the contract for. + * @returns The contract for the given public key. + * @throws {AccountDiscoveryError} If multiple contracts are deployed or have balance. + */ + async getContract(publicKey: string): Promise { + const reader = new AccountContractReader(this.network); + const DefaultContractCtor = this.defaultContractCtor; + + // Use array to store the result to prevent race condition. + const contracts: { + balance: CairoAccountContract[]; + deploy: CairoAccountContract[]; + } = { + balance: [], + deploy: [], + }; + + let cairoContract: CairoAccountContract | undefined; + + // Identify where all available contracts have been deployed, upgraded, + // and whether they have an ETH balance or not. + await Promise.all( + this.contractCtors.map(async (ContractCtor: ICairoAccountContract) => { + const contract = new ContractCtor(publicKey, reader); + + if (await contract.isDeployed()) { + // if contract upgraded, bind the latest contract with current contract interface, + // to inherit the address from current contract. + if (await contract.isUpgraded()) { + contracts.deploy.push( + DefaultContractCtor.fromAccountContract(contract), + ); + } else { + contracts.deploy.push(contract); + } + } else if (await contract.isRequireDeploy()) { + // if contract is not deployed but has balance, then use the contract with balance. + contracts.balance.push(contract); + } + }), + ); + + // In case of multiple contracts are deployed or have balance, + // We will not be able to determine which contract to use. + // Hence, throw an error. + if (contracts.balance.length > 1 || contracts.deploy.length > 1) { + throw new AccountDiscoveryError(); + } + + if (contracts.deploy.length !== 0) { + // if there is a deployed contract, then choose the deployed contract. + cairoContract = contracts.deploy[0]; + } else if (contracts.balance.length !== 0) { + // otherwise, then choose the contract with balance. + cairoContract = contracts.balance[0]; + } + + // Fallback with default contract. + return cairoContract ?? new DefaultContractCtor(publicKey, reader); + } +} diff --git a/packages/starknet-snap/src/wallet/account/index.ts b/packages/starknet-snap/src/wallet/account/index.ts new file mode 100644 index 00000000..35d7b5e1 --- /dev/null +++ b/packages/starknet-snap/src/wallet/account/index.ts @@ -0,0 +1,8 @@ +export * from './cairo0'; +export * from './cairo1'; +export * from './keypair'; +export * from './reader'; +export * from './type'; +export * from './contract'; +export * from './discovery'; +export * from './service'; diff --git a/packages/starknet-snap/src/wallet/account/keypair.ts b/packages/starknet-snap/src/wallet/account/keypair.ts new file mode 100644 index 00000000..c4571d6e --- /dev/null +++ b/packages/starknet-snap/src/wallet/account/keypair.ts @@ -0,0 +1,23 @@ +import { ec, num as numUtils } from 'starknet'; + +import { grindKey } from '../../utils/keyPair'; + +export class AccountKeyPair { + #privateKey: string; + + #publicKey: string; + + constructor(key: string) { + const accountKey = grindKey(key); + this.#publicKey = ec.starkCurve.getStarkKey(accountKey); + this.#privateKey = numUtils.toHex(accountKey); + } + + get privateKey(): string { + return this.#privateKey; + } + + get publicKey(): string { + return this.#publicKey; + } +} diff --git a/packages/starknet-snap/src/wallet/account/reader.test.ts b/packages/starknet-snap/src/wallet/account/reader.test.ts new file mode 100644 index 00000000..1a01fb82 --- /dev/null +++ b/packages/starknet-snap/src/wallet/account/reader.test.ts @@ -0,0 +1,65 @@ +import { + ETHER_MAINNET, + STARKNET_SEPOLIA_TESTNET_NETWORK, +} from '../../utils/constants'; +import { ContractReader } from '../../utils/contract'; +import { + createAccountContract, + upgradedContractVersion, + upgradedContractVersionInHex, +} from './__test__/helper'; + +jest.mock('../../utils/logger'); + +describe('AccountContractReader', () => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + + const mockContractReader = () => { + const callContractSpy = jest.spyOn( + ContractReader.prototype, + 'callContract', + ); + return { callContractSpy }; + }; + + describe('getVersion', () => { + it('returns the contract version', async () => { + const { accountContractReader, contract } = await createAccountContract( + network, + ); + + const { callContractSpy } = mockContractReader(); + callContractSpy.mockResolvedValue([upgradedContractVersionInHex]); + + const result = await accountContractReader.getVersion(contract); + + expect(result).toStrictEqual(upgradedContractVersion); + expect(callContractSpy).toHaveBeenCalledWith({ + contractAddress: contract.address, + entrypoint: contract.contractMethodMap.getVersion, + }); + }); + }); + + describe('getEthBalance', () => { + it('returns the ETH Balance', async () => { + const { accountContractReader, contract } = await createAccountContract( + network, + ); + + const balance = '1000000000000000000'; + const { callContractSpy } = mockContractReader(); + + callContractSpy.mockResolvedValue([balance]); + + const result = await accountContractReader.getEthBalance(contract); + + expect(result).toStrictEqual(BigInt(balance)); + expect(callContractSpy).toHaveBeenCalledWith({ + contractAddress: ETHER_MAINNET.address, + entrypoint: 'balanceOf', + calldata: [contract.address], + }); + }); + }); +}); diff --git a/packages/starknet-snap/src/wallet/account/reader.ts b/packages/starknet-snap/src/wallet/account/reader.ts new file mode 100644 index 00000000..7dac7000 --- /dev/null +++ b/packages/starknet-snap/src/wallet/account/reader.ts @@ -0,0 +1,36 @@ +import { hexToString } from '../../utils'; +import { ETHER_MAINNET } from '../../utils/constants'; +import { ContractReader } from '../../utils/contract'; +import type { CairoAccountContract } from './contract'; + +export class AccountContractReader extends ContractReader { + /** + * Get the version of the account contract. + * + * @param cairoContract - The `CairoAccountContract` object to get the version of. + * @returns A promise that resolves to the version of the contract. + */ + async getVersion(cairoContract: CairoAccountContract): Promise { + const resp = await this.callContract({ + contractAddress: cairoContract.address, + entrypoint: cairoContract.contractMethodMap.getVersion, + }); + + return hexToString(resp[0]); + } + + /** + * Get the ETH balance of the account contract. + * + * @param cairoContract - The `CairoAccountContract` object to get the balance of. + * @returns A promise that resolves to the balance of the contract. + */ + async getEthBalance(cairoContract: CairoAccountContract): Promise { + const resp = await this.callContract({ + contractAddress: ETHER_MAINNET.address, + entrypoint: 'balanceOf', + calldata: [cairoContract.address], + }); + return BigInt(resp[0]); + } +} diff --git a/packages/starknet-snap/src/wallet/account/service.test.ts b/packages/starknet-snap/src/wallet/account/service.test.ts new file mode 100644 index 00000000..14b44138 --- /dev/null +++ b/packages/starknet-snap/src/wallet/account/service.test.ts @@ -0,0 +1,193 @@ +import { generateMnemonic } from 'bip39'; + +import { AccountContractReader, AccountService, Cairo1Contract } from '.'; +import { generateAccounts, generateKeyDeriver } from '../../__tests__/helper'; +import { AccountStateManager } from '../../state/account-state-manager'; +import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../../utils/constants'; +import { AccountNotFoundError } from '../../utils/exceptions'; +import * as snapUtils from '../../utils/snap'; +import { + createAccountObject, + mockAccountContractReader, +} from './__test__/helper'; +import { Account } from './account'; +import { AccountContractDiscovery } from './discovery'; + +jest.mock('../../utils/logger'); + +describe('AccountService', () => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + + describe('deriveAccountByIndex', () => { + const prepareDeriveAccountByIndex = async (hdIndex) => { + const mnemonicString = generateMnemonic(); + + const [account] = await generateAccounts( + network.chainId, + 1, + '1', + hdIndex, + mnemonicString, + ); + const deriver = await generateKeyDeriver(mnemonicString); + + const getNextIndexSpy = jest.spyOn( + AccountStateManager.prototype, + 'getNextIndex', + ); + const upsertAccountSpy = jest.spyOn( + AccountStateManager.prototype, + 'upsertAccount', + ); + const getCairoContractSpy = jest.spyOn( + AccountContractDiscovery.prototype, + 'getContract', + ); + jest.spyOn(snapUtils, 'getBip44Deriver').mockResolvedValue(deriver); + + mockAccountContractReader({}); + + const cairo1Contract = new Cairo1Contract( + account.publicKey, + new AccountContractReader(network), + ); + + getCairoContractSpy.mockResolvedValue(cairo1Contract); + getNextIndexSpy.mockResolvedValue(hdIndex); + + return { + upsertAccountSpy, + getNextIndexSpy, + getCairoContractSpy, + cairo1Contract, + account, + }; + }; + + it('derive an account with the auto increment index', async () => { + const hdIndex = 0; + const { + getNextIndexSpy, + getCairoContractSpy, + upsertAccountSpy, + cairo1Contract, + account, + } = await prepareDeriveAccountByIndex(hdIndex); + + const service = new AccountService(network, new AccountStateManager()); + const accountObject = await service.deriveAccountByIndex(); + + expect(getNextIndexSpy).toHaveBeenCalled(); + expect(upsertAccountSpy).toHaveBeenCalledWith( + await accountObject.serialize(), + ); + expect(getCairoContractSpy).toHaveBeenCalledWith(account.publicKey); + expect(accountObject).toBeInstanceOf(Account); + expect(accountObject).toHaveProperty('accountContract', cairo1Contract); + expect(accountObject).toHaveProperty('address', account.address); + expect(accountObject).toHaveProperty('chainId', account.chainId); + expect(accountObject).toHaveProperty('privateKey', account.privateKey); + expect(accountObject).toHaveProperty('publicKey', account.publicKey); + expect(accountObject).toHaveProperty('hdIndex', hdIndex); + expect(accountObject).toHaveProperty('addressSalt', account.publicKey); + }); + + it('derive an account with the given index', async () => { + const hdIndex = 2; + const { + getNextIndexSpy, + getCairoContractSpy, + cairo1Contract, + account, + upsertAccountSpy, + } = await prepareDeriveAccountByIndex(hdIndex); + + const service = new AccountService(network, new AccountStateManager()); + const accountObject = await service.deriveAccountByIndex(hdIndex); + + expect(getNextIndexSpy).not.toHaveBeenCalled(); + expect(upsertAccountSpy).toHaveBeenCalledWith( + await accountObject.serialize(), + ); + expect(getCairoContractSpy).toHaveBeenCalledWith(account.publicKey); + expect(accountObject).toBeInstanceOf(Account); + expect(accountObject).toHaveProperty('accountContract', cairo1Contract); + expect(accountObject).toHaveProperty('address', account.address); + expect(accountObject).toHaveProperty('chainId', account.chainId); + expect(accountObject).toHaveProperty('privateKey', account.privateKey); + expect(accountObject).toHaveProperty('publicKey', account.publicKey); + expect(accountObject).toHaveProperty('hdIndex', hdIndex); + expect(accountObject).toHaveProperty('addressSalt', account.publicKey); + }); + }); + + describe('deriveAccountFromAddress', () => { + const prepareDeriveAccountByAddress = async () => { + const getAccountSpy = jest.spyOn( + AccountStateManager.prototype, + 'getAccount', + ); + const deriveAccountByIndexSpy = jest.spyOn( + AccountService.prototype, + 'deriveAccountByIndex', + ); + mockAccountContractReader({}); + + const { accountObj } = await createAccountObject(network, 0); + getAccountSpy.mockResolvedValue(await accountObj.serialize()); + deriveAccountByIndexSpy.mockResolvedValue(accountObj); + + return { + deriveAccountByIndexSpy, + getAccountSpy, + accountObj, + }; + }; + + it('derive an account by address', async () => { + const { getAccountSpy, deriveAccountByIndexSpy, accountObj } = + await prepareDeriveAccountByAddress(); + + const service = new AccountService(network, new AccountStateManager()); + const accountObject = await service.deriveAccountByAddress( + accountObj.address, + ); + + expect(getAccountSpy).toHaveBeenCalled(); + expect(deriveAccountByIndexSpy).toHaveBeenCalledWith(accountObj.hdIndex); + expect(accountObject).toStrictEqual(accountObj); + }); + + it('throws `AccountNotFoundError` if the given address is not found', async () => { + const { getAccountSpy, accountObj } = + await prepareDeriveAccountByAddress(); + + getAccountSpy.mockResolvedValue(null); + + const service = new AccountService(network, new AccountStateManager()); + + await expect( + service.deriveAccountByAddress(accountObj.address), + ).rejects.toThrow(AccountNotFoundError); + }); + }); + + describe('removeAccount', () => { + it('remove an account', async () => { + const { accountObj } = await createAccountObject(network, 0); + const removeAccountSpy = jest.spyOn( + AccountStateManager.prototype, + 'removeAccount', + ); + removeAccountSpy.mockResolvedValue(); + + const service = new AccountService(network, new AccountStateManager()); + await service.removeAccount(accountObj); + + expect(removeAccountSpy).toHaveBeenCalledWith({ + address: accountObj.address, + chainId: accountObj.chainId, + }); + }); + }); +}); diff --git a/packages/starknet-snap/src/wallet/account/service.ts b/packages/starknet-snap/src/wallet/account/service.ts new file mode 100644 index 00000000..909cdb37 --- /dev/null +++ b/packages/starknet-snap/src/wallet/account/service.ts @@ -0,0 +1,95 @@ +import type { AccountStateManager } from '../../state/account-state-manager'; +import type { Network } from '../../types/snapState'; +import { getBip44Deriver } from '../../utils'; +import { AccountNotFoundError } from '../../utils/exceptions'; +import { Account } from './account'; +import { AccountContractDiscovery } from './discovery'; +import { AccountKeyPair } from './keypair'; + +export class AccountService { + protected network: Network; + + protected accountStateMgr: AccountStateManager; + + protected accountContractDiscoveryService: AccountContractDiscovery; + + constructor(network: Network, accountStateMgr: AccountStateManager) { + this.network = network; + this.accountStateMgr = accountStateMgr; + this.accountContractDiscoveryService = new AccountContractDiscovery( + network, + ); + } + + /** + * Removes an account from the state. + * + * @param account - The `Account` object to remove. + */ + async removeAccount(account: Account): Promise { + await this.accountStateMgr.removeAccount({ + address: account.address, + chainId: account.chainId, + }); + } + + /** + * Derives a BIP44 node from an index and constructs a new `Account` object using the derived private key and public key. + * The `Account` object is assigned a `CairoAccountContract` contract and is then serialized and persisted to the state. + * + * @param [index] - Optional. The hd index to derive the account from. If not provided, the next index will be used. + * @returns A promise that resolves to the newly created `Account` object. + */ + async deriveAccountByIndex(index?: number): Promise { + let hdIndex = index; + + if (!hdIndex) { + hdIndex = await this.accountStateMgr.getNextIndex(this.network.chainId); + } + + // Derive a BIP44 node from an index. e.g m/44'/60'/0'/0/{hdIndex} + const deriver = await getBip44Deriver(); + const node = await deriver(hdIndex); + + // Grind a new private key and public key from the derived node. + // Private key and public key are independent from the account contract. + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + const { privateKey, publicKey } = new AccountKeyPair(node.privateKey!); + + const accountContract = + await this.accountContractDiscoveryService.getContract(publicKey); + + const account = new Account({ + privateKey, + publicKey, + chainId: this.network.chainId, + hdIndex, + addressSalt: publicKey, + accountContract, + }); + + await this.accountStateMgr.upsertAccount(await account.serialize()); + + return account; + } + + /** + * Derives an account by address. + * if the account is not found in the state, an `AccountNotFoundError` will be thrown. + * + * @param address - The address of the account to derive. + * @returns A promise that resolves to the derived `Account` object. + */ + async deriveAccountByAddress(address: string): Promise { + const accountFromState = await this.accountStateMgr.getAccount({ + address, + chainId: this.network.chainId, + }); + + if (accountFromState) { + return await this.deriveAccountByIndex(accountFromState.addressIndex); + } + + throw new AccountNotFoundError(); + } +} diff --git a/packages/starknet-snap/src/wallet/account/type.ts b/packages/starknet-snap/src/wallet/account/type.ts new file mode 100644 index 00000000..bc912e69 --- /dev/null +++ b/packages/starknet-snap/src/wallet/account/type.ts @@ -0,0 +1,13 @@ +import type { CairoAccountContract } from './contract'; +import type { AccountContractReader } from './reader'; + +export type ICairoAccountContract = new ( + publicKey: string, + contractReader: AccountContractReader, +) => CairoAccountContract; + +export type CairoAccountContractStatic = { + fromAccountContract( + accountContract: CairoAccountContract, + ): CairoAccountContract; +} & ICairoAccountContract; From 36358e14a94ff008a4e227cb8d659abb48483d51 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Tue, 7 Jan 2025 16:37:56 +0800 Subject: [PATCH 02/19] chore: fix lint --- packages/starknet-snap/src/utils/snapUtils.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/starknet-snap/src/utils/snapUtils.ts b/packages/starknet-snap/src/utils/snapUtils.ts index 3666d29d..ab2a5ff5 100644 --- a/packages/starknet-snap/src/utils/snapUtils.ts +++ b/packages/starknet-snap/src/utils/snapUtils.ts @@ -516,7 +516,7 @@ export async function upsertAccount( storedAccount.derivationPath = userAccount.derivationPath; storedAccount.publicKey = userAccount.publicKey; storedAccount.deployTxnHash = - userAccount.deployTxnHash || storedAccount.deployTxnHash; + userAccount.deployTxnHash ?? storedAccount.deployTxnHash; storedAccount.upgradeRequired = userAccount.upgradeRequired; storedAccount.deployRequired = userAccount.deployRequired; } From 2dc68ae22ca826987007fa260f43ec06995426f1 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Wed, 8 Jan 2025 16:43:31 +0800 Subject: [PATCH 03/19] chore: update account contract discovery logic --- .../src/wallet/account/discovery.test.ts | 2 +- .../src/wallet/account/discovery.ts | 30 +++++-------------- 2 files changed, 8 insertions(+), 24 deletions(-) diff --git a/packages/starknet-snap/src/wallet/account/discovery.test.ts b/packages/starknet-snap/src/wallet/account/discovery.test.ts index a8c31693..b9da93cf 100644 --- a/packages/starknet-snap/src/wallet/account/discovery.test.ts +++ b/packages/starknet-snap/src/wallet/account/discovery.test.ts @@ -9,7 +9,7 @@ jest.mock('../../utils/logger'); describe('AccountContractDiscovery', () => { const network = STARKNET_SEPOLIA_TESTNET_NETWORK; - describe('getCairoContract', () => { + describe('getContract', () => { const mockContractState = (params: { cairo1: { isDeployed: boolean; diff --git a/packages/starknet-snap/src/wallet/account/discovery.ts b/packages/starknet-snap/src/wallet/account/discovery.ts index 5db66606..19158bc3 100644 --- a/packages/starknet-snap/src/wallet/account/discovery.ts +++ b/packages/starknet-snap/src/wallet/account/discovery.ts @@ -16,8 +16,6 @@ export class AccountContractDiscovery { protected network: Network; - protected ethBalanceThreshold = BigInt(0); - constructor(network: Network) { this.network = network; } @@ -38,13 +36,7 @@ export class AccountContractDiscovery { const DefaultContractCtor = this.defaultContractCtor; // Use array to store the result to prevent race condition. - const contracts: { - balance: CairoAccountContract[]; - deploy: CairoAccountContract[]; - } = { - balance: [], - deploy: [], - }; + const contracts: CairoAccountContract[] = []; let cairoContract: CairoAccountContract | undefined; @@ -58,15 +50,13 @@ export class AccountContractDiscovery { // if contract upgraded, bind the latest contract with current contract interface, // to inherit the address from current contract. if (await contract.isUpgraded()) { - contracts.deploy.push( - DefaultContractCtor.fromAccountContract(contract), - ); + contracts.push(DefaultContractCtor.fromAccountContract(contract)); } else { - contracts.deploy.push(contract); + contracts.push(contract); } } else if (await contract.isRequireDeploy()) { // if contract is not deployed but has balance, then use the contract with balance. - contracts.balance.push(contract); + contracts.push(contract); } }), ); @@ -74,16 +64,10 @@ export class AccountContractDiscovery { // In case of multiple contracts are deployed or have balance, // We will not be able to determine which contract to use. // Hence, throw an error. - if (contracts.balance.length > 1 || contracts.deploy.length > 1) { + if (contracts.length > 1) { throw new AccountDiscoveryError(); - } - - if (contracts.deploy.length !== 0) { - // if there is a deployed contract, then choose the deployed contract. - cairoContract = contracts.deploy[0]; - } else if (contracts.balance.length !== 0) { - // otherwise, then choose the contract with balance. - cairoContract = contracts.balance[0]; + } else if (contracts.length === 1) { + cairoContract = contracts[0]; } // Fallback with default contract. From 17b25fc90fb5015861869ab1b8998af0324eff53 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Wed, 8 Jan 2025 17:01:06 +0800 Subject: [PATCH 04/19] fix: code comment --- packages/starknet-snap/src/wallet/account/cairo1.test.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/starknet-snap/src/wallet/account/cairo1.test.ts b/packages/starknet-snap/src/wallet/account/cairo1.test.ts index f9c9e2c2..18b7f021 100644 --- a/packages/starknet-snap/src/wallet/account/cairo1.test.ts +++ b/packages/starknet-snap/src/wallet/account/cairo1.test.ts @@ -16,7 +16,6 @@ describe('Cairo1Contract', () => { account: { publicKey }, } = await createAccountContract(network, 0, Cairo1Contract); - // contract.address is a getter method that making a call to calculateAddress. expect(contract.getCallData()).toStrictEqual( CallData.compile({ signer: publicKey, From cc1efa42771fd25a72ce2ac4a01c45e4417162b0 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Wed, 8 Jan 2025 18:02:37 +0800 Subject: [PATCH 05/19] chore: add discovery logic description --- .../src/wallet/account/discovery.test.ts | 26 ------------------- .../src/wallet/account/discovery.ts | 7 +++-- 2 files changed, 5 insertions(+), 28 deletions(-) diff --git a/packages/starknet-snap/src/wallet/account/discovery.test.ts b/packages/starknet-snap/src/wallet/account/discovery.test.ts index b9da93cf..afcd3cdb 100644 --- a/packages/starknet-snap/src/wallet/account/discovery.test.ts +++ b/packages/starknet-snap/src/wallet/account/discovery.test.ts @@ -237,31 +237,5 @@ describe('AccountContractDiscovery', () => { AccountDiscoveryError, ); }); - - it('throws `AccountDiscoveryError` if more than one contracts has ETH', async () => { - const [account] = await generateAccounts(network.chainId, 1); - - mockContractState({ - cairo0: { - isDeployed: false, - isUpgraded: false, - }, - cairo1: { - isDeployed: false, - isUpgraded: false, - }, - }); - - mockContractEthBalance({ - cairo0HasBalance: true, - cairo1HasBalance: true, - }); - - const service = new AccountContractDiscovery(network); - - await expect(service.getContract(account.publicKey)).rejects.toThrow( - AccountDiscoveryError, - ); - }); }); }); diff --git a/packages/starknet-snap/src/wallet/account/discovery.ts b/packages/starknet-snap/src/wallet/account/discovery.ts index 19158bc3..39b06bb2 100644 --- a/packages/starknet-snap/src/wallet/account/discovery.ts +++ b/packages/starknet-snap/src/wallet/account/discovery.ts @@ -54,8 +54,11 @@ export class AccountContractDiscovery { } else { contracts.push(contract); } - } else if (await contract.isRequireDeploy()) { - // if contract is not deployed but has balance, then use the contract with balance. + } else if (contract instanceof(Cairo0Contract) && await contract.isRequireDeploy()) { + // It should only valid for Cairo 0 contract. + // A Cairo 0 contract can only paying fee with ETH token. + // Therefore if the contract is not deployed, and it has ETH token, we should use this contract. + // And the UI will force the user to deploy the Cairo 0 contract. contracts.push(contract); } }), From a31b9c04a4dd19032b550e78261b15cd682e5e50 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Wed, 8 Jan 2025 18:07:59 +0800 Subject: [PATCH 06/19] fix: lint --- packages/starknet-snap/src/wallet/account/discovery.ts | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/packages/starknet-snap/src/wallet/account/discovery.ts b/packages/starknet-snap/src/wallet/account/discovery.ts index 39b06bb2..490a8614 100644 --- a/packages/starknet-snap/src/wallet/account/discovery.ts +++ b/packages/starknet-snap/src/wallet/account/discovery.ts @@ -54,7 +54,10 @@ export class AccountContractDiscovery { } else { contracts.push(contract); } - } else if (contract instanceof(Cairo0Contract) && await contract.isRequireDeploy()) { + } else if ( + contract instanceof Cairo0Contract && + (await contract.isRequireDeploy()) + ) { // It should only valid for Cairo 0 contract. // A Cairo 0 contract can only paying fee with ETH token. // Therefore if the contract is not deployed, and it has ETH token, we should use this contract. From fa288a482a61ef7d1298059539c874e3c607dacf Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Thu, 9 Jan 2025 10:19:51 +0800 Subject: [PATCH 07/19] feat: add account service factory --- .../starknet-snap/src/utils/factory.test.ts | 15 +++++++++++- packages/starknet-snap/src/utils/factory.ts | 19 +++++++++++++++ .../src/wallet/account/contract.ts | 14 +++++++++++ .../starknet-snap/src/wallet/account/index.ts | 1 + .../src/wallet/account/service.test.ts | 23 ++++++++++--------- .../src/wallet/account/service.ts | 10 ++++++-- 6 files changed, 68 insertions(+), 14 deletions(-) diff --git a/packages/starknet-snap/src/utils/factory.test.ts b/packages/starknet-snap/src/utils/factory.test.ts index c57b07bb..475244ef 100644 --- a/packages/starknet-snap/src/utils/factory.test.ts +++ b/packages/starknet-snap/src/utils/factory.test.ts @@ -1,8 +1,13 @@ import { StarkScanClient } from '../chain/data-client/starkscan'; import { TransactionService } from '../chain/transaction-service'; import { Config, DataClient } from '../config'; +import { AccountService } from '../wallet/account'; import { STARKNET_SEPOLIA_TESTNET_NETWORK } from './constants'; -import { createStarkScanClient, createTransactionService } from './factory'; +import { + createAccountService, + createStarkScanClient, + createTransactionService, +} from './factory'; const config = Config.dataClient[DataClient.STARKSCAN]; @@ -31,3 +36,11 @@ describe('createTransactionService', () => { config.apiKey = undefined; }); }); + +describe('createAccountService', () => { + it('creates a Account service', () => { + expect( + createAccountService(STARKNET_SEPOLIA_TESTNET_NETWORK), + ).toBeInstanceOf(AccountService); + }); +}); diff --git a/packages/starknet-snap/src/utils/factory.ts b/packages/starknet-snap/src/utils/factory.ts index 0076a98a..d4f12d60 100644 --- a/packages/starknet-snap/src/utils/factory.ts +++ b/packages/starknet-snap/src/utils/factory.ts @@ -2,8 +2,10 @@ import type { IDataClient } from '../chain/data-client'; import { StarkScanClient } from '../chain/data-client/starkscan'; import { TransactionService } from '../chain/transaction-service'; import { Config, DataClient } from '../config'; +import type { AccountStateManager } from '../state/account-state-manager'; import type { TransactionStateManager } from '../state/transaction-state-manager'; import type { Network } from '../types/snapState'; +import { AccountService } from '../wallet/account'; /** * Create a StarkScan client. @@ -44,3 +46,20 @@ export function createTransactionService( txnStateMgr, }); } + +/** + * Create a AccountService object. + * + * @param network - The network. + * @param [accountStateMgr] - The `AccountStateManager`. + * @returns A AccountService object. + */ +export function createAccountService( + network: Network, + accountStateMgr?: AccountStateManager, +): AccountService { + return new AccountService({ + network, + accountStateMgr, + }); +} diff --git a/packages/starknet-snap/src/wallet/account/contract.ts b/packages/starknet-snap/src/wallet/account/contract.ts index 9b5b159b..5713996b 100644 --- a/packages/starknet-snap/src/wallet/account/contract.ts +++ b/packages/starknet-snap/src/wallet/account/contract.ts @@ -49,6 +49,20 @@ export abstract class CairoAccountContract { return this.getCallData(); } + get deployPaylod(): { + classHash: string; + contractAddress: string; + constructorCalldata: Calldata; + addressSalt: string; + } { + return { + classHash: this.classhash, + contractAddress: this.address, + constructorCalldata: this.callData, + addressSalt: this.publicKey, + }; + } + get address(): string { if (this._address === undefined) { this._address = this.calculateAddress(); diff --git a/packages/starknet-snap/src/wallet/account/index.ts b/packages/starknet-snap/src/wallet/account/index.ts index 35d7b5e1..2d3ffb7c 100644 --- a/packages/starknet-snap/src/wallet/account/index.ts +++ b/packages/starknet-snap/src/wallet/account/index.ts @@ -6,3 +6,4 @@ export * from './type'; export * from './contract'; export * from './discovery'; export * from './service'; +export * from './account'; diff --git a/packages/starknet-snap/src/wallet/account/service.test.ts b/packages/starknet-snap/src/wallet/account/service.test.ts index 14b44138..d56b9f97 100644 --- a/packages/starknet-snap/src/wallet/account/service.test.ts +++ b/packages/starknet-snap/src/wallet/account/service.test.ts @@ -5,6 +5,7 @@ import { generateAccounts, generateKeyDeriver } from '../../__tests__/helper'; import { AccountStateManager } from '../../state/account-state-manager'; import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../../utils/constants'; import { AccountNotFoundError } from '../../utils/exceptions'; +import { createAccountService } from '../../utils/factory'; import * as snapUtils from '../../utils/snap'; import { createAccountObject, @@ -19,7 +20,7 @@ describe('AccountService', () => { const network = STARKNET_SEPOLIA_TESTNET_NETWORK; describe('deriveAccountByIndex', () => { - const prepareDeriveAccountByIndex = async (hdIndex) => { + const setupDeriveAccountByIndexTest = async (hdIndex) => { const mnemonicString = generateMnemonic(); const [account] = await generateAccounts( @@ -72,9 +73,9 @@ describe('AccountService', () => { upsertAccountSpy, cairo1Contract, account, - } = await prepareDeriveAccountByIndex(hdIndex); + } = await setupDeriveAccountByIndexTest(hdIndex); - const service = new AccountService(network, new AccountStateManager()); + const service = createAccountService(network); const accountObject = await service.deriveAccountByIndex(); expect(getNextIndexSpy).toHaveBeenCalled(); @@ -100,9 +101,9 @@ describe('AccountService', () => { cairo1Contract, account, upsertAccountSpy, - } = await prepareDeriveAccountByIndex(hdIndex); + } = await setupDeriveAccountByIndexTest(hdIndex); - const service = new AccountService(network, new AccountStateManager()); + const service = createAccountService(network); const accountObject = await service.deriveAccountByIndex(hdIndex); expect(getNextIndexSpy).not.toHaveBeenCalled(); @@ -122,7 +123,7 @@ describe('AccountService', () => { }); describe('deriveAccountFromAddress', () => { - const prepareDeriveAccountByAddress = async () => { + const setupDeriveAccountByAddressTest = async () => { const getAccountSpy = jest.spyOn( AccountStateManager.prototype, 'getAccount', @@ -146,9 +147,9 @@ describe('AccountService', () => { it('derive an account by address', async () => { const { getAccountSpy, deriveAccountByIndexSpy, accountObj } = - await prepareDeriveAccountByAddress(); + await setupDeriveAccountByAddressTest(); - const service = new AccountService(network, new AccountStateManager()); + const service = createAccountService(network); const accountObject = await service.deriveAccountByAddress( accountObj.address, ); @@ -160,11 +161,11 @@ describe('AccountService', () => { it('throws `AccountNotFoundError` if the given address is not found', async () => { const { getAccountSpy, accountObj } = - await prepareDeriveAccountByAddress(); + await setupDeriveAccountByAddressTest(); getAccountSpy.mockResolvedValue(null); - const service = new AccountService(network, new AccountStateManager()); + const service = createAccountService(network); await expect( service.deriveAccountByAddress(accountObj.address), @@ -181,7 +182,7 @@ describe('AccountService', () => { ); removeAccountSpy.mockResolvedValue(); - const service = new AccountService(network, new AccountStateManager()); + const service = createAccountService(network); await service.removeAccount(accountObj); expect(removeAccountSpy).toHaveBeenCalledWith({ diff --git a/packages/starknet-snap/src/wallet/account/service.ts b/packages/starknet-snap/src/wallet/account/service.ts index 909cdb37..08bffb6b 100644 --- a/packages/starknet-snap/src/wallet/account/service.ts +++ b/packages/starknet-snap/src/wallet/account/service.ts @@ -1,4 +1,4 @@ -import type { AccountStateManager } from '../../state/account-state-manager'; +import { AccountStateManager } from '../../state/account-state-manager'; import type { Network } from '../../types/snapState'; import { getBip44Deriver } from '../../utils'; import { AccountNotFoundError } from '../../utils/exceptions'; @@ -13,7 +13,13 @@ export class AccountService { protected accountContractDiscoveryService: AccountContractDiscovery; - constructor(network: Network, accountStateMgr: AccountStateManager) { + constructor({ + network, + accountStateMgr = new AccountStateManager(), + }: { + network: Network; + accountStateMgr?: AccountStateManager; + }) { this.network = network; this.accountStateMgr = accountStateMgr; this.accountContractDiscoveryService = new AccountContractDiscovery( From f216dfb1bf90eb7e0adebf7d38f6ded4a6293661 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Thu, 9 Jan 2025 10:24:08 +0800 Subject: [PATCH 08/19] fix: rename deployPayload --- packages/starknet-snap/src/wallet/account/contract.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/starknet-snap/src/wallet/account/contract.ts b/packages/starknet-snap/src/wallet/account/contract.ts index 5713996b..3f4704af 100644 --- a/packages/starknet-snap/src/wallet/account/contract.ts +++ b/packages/starknet-snap/src/wallet/account/contract.ts @@ -49,7 +49,7 @@ export abstract class CairoAccountContract { return this.getCallData(); } - get deployPaylod(): { + get deployPayload(): { classHash: string; contractAddress: string; constructorCalldata: Calldata; From 871156bdb5320de7436bbe539da0a4913647ba05 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Thu, 9 Jan 2025 10:35:21 +0800 Subject: [PATCH 09/19] chore: adopt account discovery in RPCs --- .../src/rpcs/__tests__/helper.ts | 88 ++-- .../abstract/account-rpc-controller.test.ts | 139 +++---- .../rpcs/abstract/account-rpc-controller.ts | 93 +++-- .../src/rpcs/declare-contract.test.ts | 101 +++-- .../src/rpcs/display-private-key.test.ts | 67 ++- .../src/rpcs/estimate-fee.test.ts | 85 ++-- .../src/rpcs/execute-txn.test.ts | 382 +++++++++--------- .../starknet-snap/src/rpcs/execute-txn.ts | 42 +- .../src/rpcs/get-addr-from-starkname.test.ts | 46 +-- .../src/rpcs/get-addr-from-starkname.ts | 25 +- .../src/rpcs/get-deployment-data.test.ts | 80 ++-- .../src/rpcs/get-deployment-data.ts | 39 +- .../src/rpcs/get-transaction-status.test.ts | 4 +- .../src/rpcs/list-transaction.test.ts | 6 +- .../src/rpcs/sign-declare-transaction.test.ts | 68 ++-- .../src/rpcs/sign-message.test.ts | 76 ++-- .../src/rpcs/sign-transaction.test.ts | 71 ++-- .../src/rpcs/switch-network.test.ts | 6 +- .../src/rpcs/verify-signature.test.ts | 63 +-- .../src/rpcs/watch-asset.test.ts | 16 +- .../src/utils/starknetUtils.test.ts | 11 +- 21 files changed, 732 insertions(+), 776 deletions(-) diff --git a/packages/starknet-snap/src/rpcs/__tests__/helper.ts b/packages/starknet-snap/src/rpcs/__tests__/helper.ts index 2cba2739..72019f31 100644 --- a/packages/starknet-snap/src/rpcs/__tests__/helper.ts +++ b/packages/starknet-snap/src/rpcs/__tests__/helper.ts @@ -1,16 +1,18 @@ import { BigNumber } from 'ethers'; import type { constants } from 'starknet'; -import type { StarknetAccount } from '../../__tests__/helper'; import { generateAccounts, generateRandomValue } from '../../__tests__/helper'; import { FeeTokenUnit } from '../../types/snapApi'; -import type { SnapState } from '../../types/snapState'; +import type { Network } from '../../types/snapState'; import * as snapUiUtils from '../../ui/utils'; import { getExplorerUrl, shortenAddress, toJson } from '../../utils'; import { mockEstimateFeeBulkResponse } from '../../utils/__tests__/helper'; +import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../../utils/constants'; import * as snapHelper from '../../utils/snap'; -import * as snapUtils from '../../utils/snapUtils'; import * as starknetUtils from '../../utils/starknetUtils'; +import { AccountService, CairoAccountContract } from '../../wallet/account'; +import { createAccountObject } from '../../wallet/account/__test__/helper'; +import type { Account } from '../../wallet/account/account'; /** * @@ -24,36 +26,62 @@ export async function mockAccount(chainId: constants.StarknetChainId | string) { /** * * @param account - * @param state + * @param account.accountObj + * @param account.network + * @param account.requireUpgrade + * @param account.requireDeploy + * @param account.isDeployed */ -export function prepareMockAccount(account: StarknetAccount, state: SnapState) { - const getStateDataSpy = jest.spyOn(snapHelper, 'getStateData'); - const verifyIfAccountNeedUpgradeOrDeploySpy = jest.spyOn( - snapUtils, - 'verifyIfAccountNeedUpgradeOrDeploy', +export async function setupAccountController({ + accountObj, + network = STARKNET_SEPOLIA_TESTNET_NETWORK, + requireUpgrade = false, + requireDeploy = false, + isDeployed = false, +}: { + accountObj?: Account; + network?: Network; + requireUpgrade?: boolean; + requireDeploy?: boolean; + isDeployed?: boolean; +}) { + const account = accountObj ?? (await createAccountObject(network)).accountObj; + + // Mock the `accountContract` properties in the `Account` object + const isRequireUpgradeSpy = jest.spyOn( + CairoAccountContract.prototype, + 'isRequireUpgrade', ); - const getKeysFromAddressSpy = jest.spyOn(starknetUtils, 'getKeysFromAddress'); + const isRequireDeploySpy = jest.spyOn( + CairoAccountContract.prototype, + 'isRequireDeploy', + ); + const isDeploySpy = jest.spyOn(CairoAccountContract.prototype, 'isDeployed'); + + isRequireUpgradeSpy.mockResolvedValue(requireUpgrade); + isRequireDeploySpy.mockResolvedValue(requireDeploy); + isDeploySpy.mockResolvedValue(isDeployed); - getKeysFromAddressSpy.mockResolvedValue({ - privateKey: account.privateKey, - publicKey: account.publicKey, - addressIndex: account.addressIndex, - derivationPath: account.derivationPath as unknown as any, - }); + const deriveAccountByAddressSpy = jest.spyOn( + AccountService.prototype, + 'deriveAccountByAddress', + ); - verifyIfAccountNeedUpgradeOrDeploySpy.mockReturnThis(); - getStateDataSpy.mockResolvedValue(state); + deriveAccountByAddressSpy.mockResolvedValue(account); return { - getKeysFromAddressSpy, - verifyIfAccountNeedUpgradeOrDeploySpy, + deriveAccountByAddressSpy, + isRequireDeploySpy, + isRequireUpgradeSpy, + isDeploySpy, + account, }; } /** * */ -export function prepareConfirmDialog() { +export function mockConfirmDialog() { const confirmDialogSpy = jest.spyOn(snapHelper, 'confirmDialog'); confirmDialogSpy.mockResolvedValue(true); return { @@ -64,7 +92,7 @@ export function prepareConfirmDialog() { /** * */ -export function prepareRenderWatchAssetUI() { +export function mockRenderWatchAssetUI() { const confirmDialogSpy = jest.spyOn(snapUiUtils, 'renderWatchAssetUI'); confirmDialogSpy.mockResolvedValue(true); return { @@ -75,7 +103,7 @@ export function prepareRenderWatchAssetUI() { /** * */ -export function prepareRenderSwitchNetworkUI() { +export function mockRenderSwitchNetworkUI() { const confirmDialogSpy = jest.spyOn(snapUiUtils, 'renderSwitchNetworkUI'); confirmDialogSpy.mockResolvedValue(true); return { @@ -86,7 +114,7 @@ export function prepareRenderSwitchNetworkUI() { /** * */ -export function prepareRenderSignMessageUI() { +export function mockRenderSignMessageUI() { const confirmDialogSpy = jest.spyOn(snapUiUtils, 'renderSignMessageUI'); confirmDialogSpy.mockResolvedValue(true); return { @@ -97,7 +125,7 @@ export function prepareRenderSignMessageUI() { /** * */ -export function prepareRenderSignTransactionUI() { +export function mockRenderSignTransactionUI() { const confirmDialogSpy = jest.spyOn(snapUiUtils, 'renderSignTransactionUI'); confirmDialogSpy.mockResolvedValue(true); return { @@ -108,7 +136,7 @@ export function prepareRenderSignTransactionUI() { /** * */ -export function prepareRenderSignDeclareTransactionUI() { +export function mockRenderSignDeclareTransactionUI() { const confirmDialogSpy = jest.spyOn( snapUiUtils, 'renderSignDeclareTransactionUI', @@ -122,7 +150,7 @@ export function prepareRenderSignDeclareTransactionUI() { /** * */ -export function prepareRenderDisplayPrivateKeyConfirmUI() { +export function mockRenderDisplayPrivateKeyConfirmUI() { const confirmDialogSpy = jest.spyOn( snapUiUtils, 'renderDisplayPrivateKeyConfirmUI', @@ -136,7 +164,7 @@ export function prepareRenderDisplayPrivateKeyConfirmUI() { /** * */ -export function prepareRenderDisplayPrivateKeyAlertUI() { +export function mockRenderDisplayPrivateKeyAlertUI() { const alertDialogSpy = jest.spyOn( snapUiUtils, 'renderDisplayPrivateKeyAlertUI', @@ -150,7 +178,7 @@ export function prepareRenderDisplayPrivateKeyAlertUI() { * * @param result */ -export function prepareConfirmDialogInteractiveUI(result = true) { +export function mockConfirmDialogInteractiveUI(result = true) { const confirmDialogSpy = jest.spyOn( snapHelper, 'createInteractiveConfirmDialog', @@ -164,7 +192,7 @@ export function prepareConfirmDialogInteractiveUI(result = true) { /** * */ -export function prepareAlertDialog() { +export function mockAlertDialog() { const alertDialogSpy = jest.spyOn(snapHelper, 'alertDialog'); alertDialogSpy.mockResolvedValue(true); return { diff --git a/packages/starknet-snap/src/rpcs/abstract/account-rpc-controller.test.ts b/packages/starknet-snap/src/rpcs/abstract/account-rpc-controller.test.ts index 6a64ca91..2b65807f 100644 --- a/packages/starknet-snap/src/rpcs/abstract/account-rpc-controller.test.ts +++ b/packages/starknet-snap/src/rpcs/abstract/account-rpc-controller.test.ts @@ -1,26 +1,20 @@ -import { constants } from 'starknet'; import { object, string } from 'superstruct'; import type { Infer } from 'superstruct'; -import type { StarknetAccount } from '../../__tests__/helper'; -import { generateAccounts } from '../../__tests__/helper'; -import type { SnapState } from '../../types/snapState'; import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../../utils/constants'; -import * as snapHelper from '../../utils/snap'; +import { + DeployRequiredError, + UpgradeRequiredError, +} from '../../utils/exceptions'; import * as snapUtils from '../../utils/snapUtils'; -import * as starknetUtils from '../../utils/starknetUtils'; +import { setupAccountController } from '../__tests__/helper'; import { AccountRpcController } from './account-rpc-controller'; jest.mock('../../utils/snap'); jest.mock('../../utils/logger'); describe('AccountRpcController', () => { - const state: SnapState = { - accContracts: [], - erc20Tokens: [], - networks: [STARKNET_SEPOLIA_TESTNET_NETWORK], - transactions: [], - }; + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; const RequestStruct = object({ address: string(), @@ -38,89 +32,88 @@ describe('AccountRpcController', () => { async handleRequest(param: Request) { return `done ${param.address} and ${param.chainId}`; } - } - const mockAccount = async (network: constants.StarknetChainId) => { - const accounts = await generateAccounts(network, 1); - return accounts[0]; - }; + async displayAlert(error: Error): Promise { + return super.displayAlert(error); + } + } - const prepareExecute = async (account: StarknetAccount) => { - const verifyIfAccountNeedUpgradeOrDeploySpy = jest.spyOn( + const setupRpcExecuteTest = async ({ + requireDeploy = false, + requireUpgrade = false, + }: { + requireDeploy?: boolean; + requireUpgrade?: boolean; + }) => { + const { account, isRequireUpgradeSpy, isRequireDeploySpy } = + await setupAccountController({}); + + const showDeployRequestModalSpy = jest.spyOn( snapUtils, - 'verifyIfAccountNeedUpgradeOrDeploy', + 'showDeployRequestModal', ); - - const getKeysFromAddressSpy = jest.spyOn( - starknetUtils, - 'getKeysFromAddress', + const showUpgradeRequestModalSpy = jest.spyOn( + snapUtils, + 'showUpgradeRequestModal', ); - const getStateDataSpy = jest.spyOn(snapHelper, 'getStateData'); - - getStateDataSpy.mockResolvedValue(state); - - getKeysFromAddressSpy.mockResolvedValue({ - privateKey: account.privateKey, - publicKey: account.publicKey, - addressIndex: account.addressIndex, - derivationPath: account.derivationPath as unknown as any, - }); - - verifyIfAccountNeedUpgradeOrDeploySpy.mockReturnThis(); + isRequireUpgradeSpy.mockResolvedValue(requireUpgrade); + isRequireDeploySpy.mockResolvedValue(requireDeploy); return { - getKeysFromAddressSpy, - getStateDataSpy, - verifyIfAccountNeedUpgradeOrDeploySpy, + account, + showDeployRequestModalSpy, + showUpgradeRequestModalSpy, + isRequireUpgradeSpy, + isRequireDeploySpy, }; }; it('executes request', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const account = await mockAccount(chainId); - await prepareExecute(account); + const { account } = await setupRpcExecuteTest({}); const rpc = new MockAccountRpc(); const result = await rpc.execute({ address: account.address, - chainId, + chainId: network.chainId, }); - expect(result).toBe(`done ${account.address} and ${chainId}`); + expect(result).toBe(`done ${account.address} and ${network.chainId}`); }); - it('fetchs account before execute', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const account = await mockAccount(chainId); - const { getKeysFromAddressSpy } = await prepareExecute(account); - const rpc = new MockAccountRpc(); + it(`displays a request deploy dialog if account is required deploy and \`showInvalidAccountAlert\` is true`, async () => { + const { account, showDeployRequestModalSpy } = await setupRpcExecuteTest({ + requireDeploy: true, + }); + const rpc = new MockAccountRpc({ + showInvalidAccountAlert: true, + }); - await rpc.execute({ address: account.address, chainId }); + await expect( + rpc.execute({ + address: account.address, + chainId: network.chainId, + }), + ).rejects.toThrow(DeployRequiredError); - expect(getKeysFromAddressSpy).toHaveBeenCalled(); + expect(showDeployRequestModalSpy).toHaveBeenCalled(); }); - it.each([true, false])( - `assign verifyIfAccountNeedUpgradeOrDeploy's argument "showAlert" to %s if the constructor option 'showInvalidAccountAlert' is set to %s`, - async (showInvalidAccountAlert: boolean) => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const account = await mockAccount(chainId); - const { verifyIfAccountNeedUpgradeOrDeploySpy } = await prepareExecute( - account, - ); - const rpc = new MockAccountRpc({ - showInvalidAccountAlert, - }); - - await rpc.execute({ address: account.address, chainId }); - - expect(verifyIfAccountNeedUpgradeOrDeploySpy).toHaveBeenCalledWith( - expect.any(Object), - account.address, - account.publicKey, - showInvalidAccountAlert, - ); - }, - ); + it(`displays a request upgrade dialog if account is required upgrade and \`showInvalidAccountAlert\` is true`, async () => { + const { account, showUpgradeRequestModalSpy } = await setupRpcExecuteTest({ + requireUpgrade: true, + }); + const rpc = new MockAccountRpc({ + showInvalidAccountAlert: true, + }); + + await expect( + rpc.execute({ + address: account.address, + chainId: network.chainId, + }), + ).rejects.toThrow(UpgradeRequiredError); + + expect(showUpgradeRequestModalSpy).toHaveBeenCalled(); + }); }); diff --git a/packages/starknet-snap/src/rpcs/abstract/account-rpc-controller.ts b/packages/starknet-snap/src/rpcs/abstract/account-rpc-controller.ts index d02ed372..3370c33e 100644 --- a/packages/starknet-snap/src/rpcs/abstract/account-rpc-controller.ts +++ b/packages/starknet-snap/src/rpcs/abstract/account-rpc-controller.ts @@ -1,29 +1,22 @@ -import type { getBIP44ChangePathString } from '@metamask/key-tree/dist/types/utils'; import type { Json } from '@metamask/snaps-sdk'; -import type { Network, SnapState } from '../../types/snapState'; -import { getBip44Deriver, getStateData } from '../../utils'; import { - getNetworkFromChainId, - verifyIfAccountNeedUpgradeOrDeploy, + DeployRequiredError, + UpgradeRequiredError, +} from '../../utils/exceptions'; +import { createAccountService } from '../../utils/factory'; +import { + showDeployRequestModal, + showUpgradeRequestModal, } from '../../utils/snapUtils'; -import { getKeysFromAddress } from '../../utils/starknetUtils'; -import { RpcController } from './base-rpc-controller'; +import type { Account } from '../../wallet/account/account'; +import { ChainRpcController } from './chain-rpc-controller'; export type AccountRpcParams = { chainId: string; address: string; }; -// TODO: the Account object should move into a account manager for generate account -export type Account = { - privateKey: string; - publicKey: string; - addressIndex: number; - // This is the derivation path of the address, it is used in `getNextAddressIndex` to find the account in state where matching the same derivation path - derivationPath: ReturnType; -}; - export type AccountRpcControllerOptions = { showInvalidAccountAlert: boolean; }; @@ -38,11 +31,9 @@ export type AccountRpcControllerOptions = { export abstract class AccountRpcController< Request extends AccountRpcParams, Response extends Json, -> extends RpcController { +> extends ChainRpcController { protected account: Account; - protected network: Network; - protected options: AccountRpcControllerOptions; protected defaultOptions: AccountRpcControllerOptions = { @@ -54,33 +45,55 @@ export abstract class AccountRpcController< this.options = Object.assign({}, this.defaultOptions, options); } + /** + * A Pre execute hook of the rpc method execution. + * Derives the account from the address and verifies if it needs to be upgraded or deployed. + * + * @param params - The request parameters. + * @returns The response. + */ protected async preExecute(params: Request): Promise { await super.preExecute(params); + const { address } = params; - const { chainId, address } = params; - const { showInvalidAccountAlert } = this.options; + const accountService = createAccountService(this.network); + this.account = await accountService.deriveAccountByAddress(address); - const deriver = await getBip44Deriver(); - // TODO: Instead of getting the state directly, we should implement state management to consolidate the state fetching - const state = await getStateData(); + try { + await this.verifyAccount(); + } catch (error) { + await this.displayAlert(error); + throw error; + } + } + + /** + * Verify if the account needs to be upgraded or deployed and throw an error if it does. + * + * @throws {DeployRequiredError} If the account needs to be deployed. + * @throws {UpgradeRequiredError} If the account needs to be upgraded. + */ + protected async verifyAccount(): Promise { + const { accountContract } = this.account; - // TODO: getNetworkFromChainId from state is still needed, due to it is supporting in get-starknet at this moment - this.network = getNetworkFromChainId(state, chainId); + if (await accountContract.isRequireUpgrade()) { + throw new UpgradeRequiredError(); + } else if (await accountContract.isRequireDeploy()) { + throw new DeployRequiredError(); + } + } - // TODO: This method should be refactored to get the account from an account manager - this.account = await getKeysFromAddress( - deriver, - this.network, - state, - address, - ); + /** + * Show an alert modal if the account needs to be upgraded or deployed, otherwise do nothing. + * @param error + */ + protected async displayAlert(error: Error): Promise { + const { showInvalidAccountAlert: enableAlert } = this.options; - // TODO: rename this method to verifyAccount - await verifyIfAccountNeedUpgradeOrDeploy( - this.network, - address, - this.account.publicKey, - showInvalidAccountAlert, - ); + if (error instanceof UpgradeRequiredError) { + enableAlert && (await showUpgradeRequestModal()); + } else if (error instanceof DeployRequiredError) { + enableAlert && (await showDeployRequestModal()); + } } } diff --git a/packages/starknet-snap/src/rpcs/declare-contract.test.ts b/packages/starknet-snap/src/rpcs/declare-contract.test.ts index d745397c..6361811e 100644 --- a/packages/starknet-snap/src/rpcs/declare-contract.test.ts +++ b/packages/starknet-snap/src/rpcs/declare-contract.test.ts @@ -1,6 +1,5 @@ import { utils } from 'ethers'; -import type { Abi, UniversalDetails } from 'starknet'; -import { constants } from 'starknet'; +import type { Abi, UniversalDetails, constants } from 'starknet'; import type { Infer } from 'superstruct'; import { type DeclareContractPayloadStruct } from '../utils'; @@ -17,9 +16,8 @@ import { buildRowComponent, buildSignerComponent, generateRandomFee, - mockAccount, - prepareConfirmDialog, - prepareMockAccount, + mockConfirmDialog, + setupAccountController, } from './__tests__/helper'; import { declareContract } from './declare-contract'; import type { @@ -55,50 +53,44 @@ const generateExpectedDeclareTransactionPayload = }, }); -const prepareMockDeclareContract = async ( - transactionHash: string, - payload: DeclareContractPayload, - details: UniversalDetails, -) => { - const state = { - accContracts: [], - erc20Tokens: [], - networks: [STARKNET_SEPOLIA_TESTNET_NETWORK], - transactions: [], - }; - const { confirmDialogSpy } = prepareConfirmDialog(); +describe('DeclareContractRpc', () => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; - const account = await mockAccount(constants.StarknetChainId.SN_SEPOLIA); - prepareMockAccount(account, state); + const setupDeclareContractTest = async ( + transactionHash: string, + payload: DeclareContractPayload, + details: UniversalDetails, + ) => { + const { confirmDialogSpy } = mockConfirmDialog(); - const request = { - chainId: state.networks[0].chainId as unknown as constants.StarknetChainId, - address: account.address, - payload, - details, - }; + const { account } = await setupAccountController({}); - const declareContractRespMock: DeclareContractResponse = { - // eslint-disable-next-line @typescript-eslint/naming-convention - transaction_hash: transactionHash, - // eslint-disable-next-line @typescript-eslint/naming-convention - class_hash: '0x123456789abcdef', - }; + const request = { + chainId: network.chainId as unknown as constants.StarknetChainId, + address: account.address, + payload, + details, + }; + + const declareContractMockResp: DeclareContractResponse = { + // eslint-disable-next-line @typescript-eslint/naming-convention + transaction_hash: transactionHash, + // eslint-disable-next-line @typescript-eslint/naming-convention + class_hash: '0x123456789abcdef', + }; - const declareContractUtilSpy = jest.spyOn(starknetUtils, 'declareContract'); - declareContractUtilSpy.mockResolvedValue(declareContractRespMock); + const declareContractUtilSpy = jest.spyOn(starknetUtils, 'declareContract'); + declareContractUtilSpy.mockResolvedValue(declareContractMockResp); - return { - network: state.networks[0], - account, - request, - confirmDialogSpy, - declareContractRespMock, - declareContractUtilSpy, + return { + account, + request, + confirmDialogSpy, + declareContractMockResp, + declareContractUtilSpy, + }; }; -}; -describe('DeclareContractRpc', () => { it('declares a contract correctly if user confirms the dialog', async () => { const payload = generateExpectedDeclareTransactionPayload(); const details = { @@ -109,17 +101,16 @@ describe('DeclareContractRpc', () => { const { account, request, - network, - declareContractRespMock, + declareContractMockResp, confirmDialogSpy, declareContractUtilSpy, - } = await prepareMockDeclareContract(transactionHash, payload, details); + } = await setupDeclareContractTest(transactionHash, payload, details); confirmDialogSpy.mockResolvedValue(true); const result = await declareContract.execute(request); - expect(result).toStrictEqual(declareContractRespMock); + expect(result).toStrictEqual(declareContractMockResp); expect(declareContractUtilSpy).toHaveBeenCalledWith( network, account.address, @@ -137,7 +128,7 @@ describe('DeclareContractRpc', () => { const transactionHash = '0x07f901c023bac6c874691244c4c2332c6825b916fb68d240c807c6156db84fd3'; - const { request, confirmDialogSpy } = await prepareMockDeclareContract( + const { request, confirmDialogSpy } = await setupDeclareContractTest( transactionHash, payload, details, @@ -158,25 +149,25 @@ describe('DeclareContractRpc', () => { it.each([ { testCase: 'class_hash is missing', - declareContractRespMock: { + declareContractMockResp: { // eslint-disable-next-line @typescript-eslint/naming-convention transaction_hash: '0x123', }, }, { testCase: 'transaction_hash is missing', - declareContractRespMock: { + declareContractMockResp: { // eslint-disable-next-line @typescript-eslint/naming-convention class_hash: '0x123456789abcdef', }, }, { testCase: 'empty object is returned', - declareContractRespMock: {}, + declareContractMockResp: {}, }, ])( 'throws `Unknown Error` when $testCase', - async ({ declareContractRespMock }) => { + async ({ declareContractMockResp }) => { const payload = generateExpectedDeclareTransactionPayload(); const details = { maxFee: generateRandomFee('1000000000000000', '2000000000000000'), @@ -184,10 +175,10 @@ describe('DeclareContractRpc', () => { const transactionHash = '0x123'; const { request, declareContractUtilSpy } = - await prepareMockDeclareContract(transactionHash, payload, details); + await setupDeclareContractTest(transactionHash, payload, details); declareContractUtilSpy.mockResolvedValue( - declareContractRespMock as unknown as DeclareContractResponse, + declareContractMockResp as unknown as DeclareContractResponse, ); await expect(declareContract.execute(request)).rejects.toThrow( @@ -205,8 +196,8 @@ describe('DeclareContractRpc', () => { const maxFeeInEth = utils.formatUnits(details.maxFee, 'ether'); const transactionHash = '0x123'; - const { request, network, confirmDialogSpy, account } = - await prepareMockDeclareContract(transactionHash, payload, details); + const { request, confirmDialogSpy, account } = + await setupDeclareContractTest(transactionHash, payload, details); await declareContract.execute(request); diff --git a/packages/starknet-snap/src/rpcs/display-private-key.test.ts b/packages/starknet-snap/src/rpcs/display-private-key.test.ts index aa0aac00..49e55bb0 100644 --- a/packages/starknet-snap/src/rpcs/display-private-key.test.ts +++ b/packages/starknet-snap/src/rpcs/display-private-key.test.ts @@ -1,16 +1,15 @@ import { constants } from 'starknet'; -import type { SnapState } from '../types/snapState'; import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../utils/constants'; import { UserRejectedOpError, InvalidRequestParamsError, } from '../utils/exceptions'; +import { createAccountObject } from '../wallet/account/__test__/helper'; import { - mockAccount, - prepareMockAccount, - prepareRenderDisplayPrivateKeyAlertUI, - prepareRenderDisplayPrivateKeyConfirmUI, + setupAccountController, + mockRenderDisplayPrivateKeyAlertUI, + mockRenderDisplayPrivateKeyConfirmUI, } from './__tests__/helper'; import { displayPrivateKey } from './display-private-key'; import type { DisplayPrivateKeyParams } from './display-private-key'; @@ -18,32 +17,36 @@ import type { DisplayPrivateKeyParams } from './display-private-key'; jest.mock('../utils/logger'); describe('displayPrivateKey', () => { - const state: SnapState = { - accContracts: [], - erc20Tokens: [], - networks: [STARKNET_SEPOLIA_TESTNET_NETWORK], - transactions: [], - }; + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + const setupDisplayPrivateKeyTest = async () => { + const { chainId } = network; + const { accountObj: account } = await createAccountObject(network); + await setupAccountController({ + accountObj: account, + }); + + const { confirmDialogSpy } = mockRenderDisplayPrivateKeyConfirmUI(); + const { alertDialogSpy } = mockRenderDisplayPrivateKeyAlertUI(); + + mockRenderDisplayPrivateKeyAlertUI(); - const createRequestParam = ( - chainId: constants.StarknetChainId, - address: string, - ): DisplayPrivateKeyParams => { const request: DisplayPrivateKeyParams = { + chainId: chainId as constants.StarknetChainId, + address: account.address, + }; + + return { + request, + account, chainId, - address, + confirmDialogSpy, + alertDialogSpy, }; - return request; }; it('displays private key correctly', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const account = await mockAccount(chainId); - prepareMockAccount(account, state); - prepareRenderDisplayPrivateKeyConfirmUI(); - const { alertDialogSpy } = prepareRenderDisplayPrivateKeyAlertUI(); - - const request = createRequestParam(chainId, account.address); + const { account, alertDialogSpy, request } = + await setupDisplayPrivateKeyTest(); await displayPrivateKey.execute(request); @@ -51,13 +54,7 @@ describe('displayPrivateKey', () => { }); it('renders confirmation dialog', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const account = await mockAccount(chainId); - prepareMockAccount(account, state); - const { confirmDialogSpy } = prepareRenderDisplayPrivateKeyConfirmUI(); - prepareRenderDisplayPrivateKeyAlertUI(); - - const request = createRequestParam(chainId, account.address); + const { confirmDialogSpy, request } = await setupDisplayPrivateKeyTest(); await displayPrivateKey.execute(request); @@ -65,16 +62,10 @@ describe('displayPrivateKey', () => { }); it('throws `UserRejectedOpError` if user denies the operation', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const account = await mockAccount(chainId); - prepareMockAccount(account, state); - const { confirmDialogSpy } = prepareRenderDisplayPrivateKeyConfirmUI(); - prepareRenderDisplayPrivateKeyAlertUI(); + const { confirmDialogSpy, request } = await setupDisplayPrivateKeyTest(); confirmDialogSpy.mockResolvedValue(false); - const request = createRequestParam(chainId, account.address); - await expect(displayPrivateKey.execute(request)).rejects.toThrow( UserRejectedOpError, ); diff --git a/packages/starknet-snap/src/rpcs/estimate-fee.test.ts b/packages/starknet-snap/src/rpcs/estimate-fee.test.ts index 2b23ed8b..fa422e05 100644 --- a/packages/starknet-snap/src/rpcs/estimate-fee.test.ts +++ b/packages/starknet-snap/src/rpcs/estimate-fee.test.ts @@ -6,10 +6,10 @@ import callsExamples from '../__tests__/fixture/callsExamples.json'; import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../utils/constants'; import { InvalidRequestParamsError } from '../utils/exceptions'; import type { TxVersionStruct } from '../utils/superstruct'; +import { createAccountObject } from '../wallet/account/__test__/helper'; import { - mockAccount, mockGetEstimatedFeesResponse, - prepareMockAccount, + setupAccountController, } from './__tests__/helper'; import { estimateFee } from './estimate-fee'; import type { EstimateFeeParams } from './estimate-fee'; @@ -17,52 +17,50 @@ import type { EstimateFeeParams } from './estimate-fee'; jest.mock('../utils/snap'); jest.mock('../utils/logger'); -const prepareMockEstimateFee = ({ - chainId, - address, - version, - includeDeploy = false, -}: { - chainId: constants.StarknetChainId; - address: string; - version: Infer; - includeDeploy?: boolean; -}) => { - const invocations: Invocations = [ - { - type: TransactionType.INVOKE, - payload: callsExamples.singleCall.calls, - }, - ]; +describe('estimateFee', () => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; - const request = { + const setupMockEstimateFeeTest = ({ chainId, address, - invocations, - details: { version }, - } as unknown as EstimateFeeParams; + version, + includeDeploy = false, + }: { + chainId: constants.StarknetChainId; + address: string; + version: Infer; + includeDeploy?: boolean; + }) => { + const invocations: Invocations = [ + { + type: TransactionType.INVOKE, + payload: callsExamples.singleCall.calls, + }, + ]; - return { - invocations, - request, - ...mockGetEstimatedFeesResponse({ - includeDeploy, - }), - }; -}; + const request = { + chainId, + address, + invocations, + details: { version }, + } as unknown as EstimateFeeParams; -describe('estimateFee', () => { - const state = { - accContracts: [], - erc20Tokens: [], - networks: [STARKNET_SEPOLIA_TESTNET_NETWORK], - transactions: [], + return { + invocations, + request, + ...mockGetEstimatedFeesResponse({ + includeDeploy, + }), + }; }; it('estimates fee correctly', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const account = await mockAccount(chainId); - prepareMockAccount(account, state); + const { chainId } = network; + const { accountObj: account } = await createAccountObject(network); + await setupAccountController({ + accountObj: account, + }); + const { request, getEstimatedFeesSpy, @@ -72,9 +70,8 @@ describe('estimateFee', () => { suggestedMaxFee, unit, }, - } = prepareMockEstimateFee({ - includeDeploy: false, - chainId, + } = setupMockEstimateFeeTest({ + chainId: chainId as constants.StarknetChainId, address: account.address, version: constants.TRANSACTION_VERSION.V1, }); @@ -82,7 +79,7 @@ describe('estimateFee', () => { const result = await estimateFee.execute(request); expect(getEstimatedFeesSpy).toHaveBeenCalledWith( - STARKNET_SEPOLIA_TESTNET_NETWORK, + network, account.address, account.privateKey, account.publicKey, diff --git a/packages/starknet-snap/src/rpcs/execute-txn.test.ts b/packages/starknet-snap/src/rpcs/execute-txn.test.ts index 801f2f31..19b2f485 100644 --- a/packages/starknet-snap/src/rpcs/execute-txn.test.ts +++ b/packages/starknet-snap/src/rpcs/execute-txn.test.ts @@ -7,11 +7,7 @@ import { mockTransactionRequestStateManager } from '../state/__tests__/helper'; import { AccountStateManager } from '../state/account-state-manager'; import { TransactionStateManager } from '../state/transaction-state-manager'; import { FeeToken } from '../types/snapApi'; -import type { - FormattedCallData, - Network, - TransactionRequest, -} from '../types/snapState'; +import type { FormattedCallData, TransactionRequest } from '../types/snapState'; import * as uiUtils from '../ui/utils'; import { CAIRO_VERSION, @@ -31,10 +27,9 @@ import { transactionVersionToNumber, } from '../utils/transaction'; import { - mockAccount, mockGetEstimatedFeesResponse, - prepareConfirmDialogInteractiveUI, - prepareMockAccount, + mockConfirmDialogInteractiveUI, + setupAccountController, } from './__tests__/helper'; import type { ConfirmTransactionParams, @@ -73,104 +68,119 @@ class MockExecuteTxnRpc extends ExecuteTxnRpc { } } -const generateAccount = async (network) => { - const state = { - accContracts: [], - erc20Tokens: [], - networks: [network], - transactions: [], - }; +describe('ExecuteTxn', () => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; - const account = await mockAccount(network); - prepareMockAccount(account, state); + const createMockRpc = () => { + const rpc = new MockExecuteTxnRpc({ + showInvalidAccountAlert: true, + }); + return rpc; + }; - return account; -}; + const setupRpcTest = async (calls: Call[]) => { + const controller = await setupAccountController({ network }); -const createMockRpc = () => { - const rpc = new MockExecuteTxnRpc({ - showInvalidAccountAlert: true, - }); - return rpc; -}; + const rpc = createMockRpc(); -const setupMockRpc = async (network: Network, calls: Call[]) => { - const account = await generateAccount(network); + // Setup the rpc, to discover the account and network + await rpc.preExecute({ + chainId: network.chainId, + address: controller.account.address, + calls, + } as unknown as ExecuteTxnParams); - const rpc = createMockRpc(); - - // Setup the rpc, to discover the account and network - await rpc.preExecute({ - chainId: network.chainId, - address: account.address, - calls, - } as unknown as ExecuteTxnParams); + return { + rpc, + ...controller, + }; + }; - return { - rpc, - account, + const mockCallToTransactionReqCall = (calls: Call[]) => { + const callToTransactionReqCallSpy = jest.spyOn( + formatUtils, + 'callToTransactionReqCall', + ); + const formattedCalls: FormattedCallData[] = []; + for (const call of calls) { + formattedCalls.push({ + contractAddress: call.contractAddress, + calldata: call.calldata as unknown as string[], + entrypoint: call.entrypoint, + }); + callToTransactionReqCallSpy.mockResolvedValueOnce( + formattedCalls[formattedCalls.length - 1], + ); + } + return { + callToTransactionReqCallSpy, + formattedCalls, + }; }; -}; - -const mockCallToTransactionReqCall = (calls: Call[]) => { - const callToTransactionReqCallSpy = jest.spyOn( - formatUtils, - 'callToTransactionReqCall', - ); - const formattedCalls: FormattedCallData[] = []; - for (const call of calls) { - formattedCalls.push({ - contractAddress: call.contractAddress, - calldata: call.calldata as unknown as string[], - entrypoint: call.entrypoint, - }); - callToTransactionReqCallSpy.mockResolvedValueOnce( - formattedCalls[formattedCalls.length - 1], + + const mockGenerateExecuteTxnFlow = () => { + const generateExecuteTxnFlowSpy = jest.spyOn( + uiUtils, + 'generateExecuteTxnFlow', ); - } - return { - callToTransactionReqCallSpy, - formattedCalls, + const interfaceId = uuidv4(); + generateExecuteTxnFlowSpy.mockResolvedValue(interfaceId); + return { + interfaceId, + generateExecuteTxnFlowSpy, + }; }; -}; - -const mockGenerateExecuteTxnFlow = () => { - const generateExecuteTxnFlowSpy = jest.spyOn( - uiUtils, - 'generateExecuteTxnFlow', - ); - const interfaceId = uuidv4(); - generateExecuteTxnFlowSpy.mockResolvedValue(interfaceId); - return { - interfaceId, - generateExecuteTxnFlowSpy, + + const getTransactionCalls = () => { + const { calls, details, hash } = callsExamples.multipleCalls; + const { formattedCalls } = mockCallToTransactionReqCall(calls); + return { + calls, + details, + hash, + formattedCalls, + }; }; -}; -describe('ExecuteTxn', () => { describe('confirmTransaction', () => { - const prepareConfirmTransaction = async (confirm = true) => { - const network = STARKNET_SEPOLIA_TESTNET_NETWORK; - const includeDeploy = true; + const setupConfirmTransactionTest = async (confirm = true) => { const txnVersion = constants.TRANSACTION_VERSION.V3; - const { calls } = callsExamples.multipleCalls; + const { calls, formattedCalls } = getTransactionCalls(); + + const { account, rpc } = await setupRpcTest(calls); + const includeDeploy = !(await account.accountContract.isDeployed()); - const { account, rpc } = await setupMockRpc(network, calls); const { getEstimatedFeesResponse: { suggestedMaxFee: maxFee, resourceBounds }, } = mockGetEstimatedFeesResponse({ - includeDeploy: false, + includeDeploy, }); const request = { calls, - address: account.address, maxFee, resourceBounds, txnVersion, includeDeploy, }; + const { interfaceId } = mockGenerateExecuteTxnFlow(); + + const transactionRequest = { + chainId: network.chainId, + networkName: network.name, + id: uuidv4(), + interfaceId, + type: TransactionType.INVOKE, + signer: account.address, + addressIndex: 0, + maxFee, + calls: formattedCalls, + resourceBounds, + selectedFeeToken: transactionVersionToFeeToken(txnVersion), + includeDeploy, + }; + return { request, rpc, @@ -180,9 +190,8 @@ describe('ExecuteTxn', () => { resourceBounds, txnVersion, includeDeploy, - ...prepareConfirmDialogInteractiveUI(confirm), - ...mockCallToTransactionReqCall(calls), - ...mockGenerateExecuteTxnFlow(), + transactionRequest, + ...mockConfirmDialogInteractiveUI(confirm), ...mockTransactionRequestStateManager(), }; }; @@ -191,41 +200,25 @@ describe('ExecuteTxn', () => { const { request, rpc, - interfaceId, - account: { address }, - formattedCalls, - maxFee, - resourceBounds, - txnVersion, - includeDeploy, - network: { chainId, name: networkName }, upsertTransactionRequestSpy, - confirmDialogSpy, getTransactionRequestSpy, removeTransactionRequestSpy, - } = await prepareConfirmTransaction(); + transactionRequest, + } = await setupConfirmTransactionTest(); + + getTransactionRequestSpy.mockResolvedValue(transactionRequest); const result = await rpc.confirmTransaction(request); const expectedTransactionRequest = { - chainId, - networkName, + ...transactionRequest, id: expect.any(String), - interfaceId, - type: TransactionType.INVOKE, - signer: address, - addressIndex: 0, - maxFee, - calls: formattedCalls, - resourceBounds, - selectedFeeToken: transactionVersionToFeeToken(txnVersion), - includeDeploy, }; + expect(result).toStrictEqual(expectedTransactionRequest); expect(upsertTransactionRequestSpy).toHaveBeenCalledWith( expectedTransactionRequest, ); - expect(confirmDialogSpy).toHaveBeenCalledWith(interfaceId); expect(getTransactionRequestSpy).toHaveBeenCalledWith({ requestId: expect.any(String), }); @@ -235,8 +228,15 @@ describe('ExecuteTxn', () => { }); it('does not throw an error if remove request failed', async () => { - const { request, rpc, removeTransactionRequestSpy } = - await prepareConfirmTransaction(); + const { + request, + rpc, + removeTransactionRequestSpy, + transactionRequest, + getTransactionRequestSpy, + } = await setupConfirmTransactionTest(); + + getTransactionRequestSpy.mockResolvedValue(transactionRequest); removeTransactionRequestSpy.mockRejectedValue( new Error('Failed to remove request'), @@ -254,7 +254,7 @@ describe('ExecuteTxn', () => { rpc, getTransactionRequestSpy, removeTransactionRequestSpy, - } = await prepareConfirmTransaction(); + } = await setupConfirmTransactionTest(); getTransactionRequestSpy.mockResolvedValue(null); @@ -268,7 +268,7 @@ describe('ExecuteTxn', () => { }); it('throws UserRejectedOpError if user denied the operation', async () => { - const { request, rpc } = await prepareConfirmTransaction(false); + const { request, rpc } = await setupConfirmTransactionTest(false); await expect(rpc.confirmTransaction(request)).rejects.toThrow( UserRejectedOpError, @@ -277,11 +277,10 @@ describe('ExecuteTxn', () => { }); describe('deployAccount', () => { - const prepareDeployAccount = async () => { - const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + const setupDeployAccountTest = async () => { const txnVersion = constants.TRANSACTION_VERSION.V3; const { calls } = callsExamples.multipleCalls; - const { account, rpc } = await setupMockRpc(network, calls); + const { account, rpc } = await setupRpcTest(calls); const deployAccountSpy = jest.spyOn(starknetUtils, 'deployAccount'); const deployAccountResponse = { @@ -291,7 +290,6 @@ describe('ExecuteTxn', () => { deployAccountSpy.mockResolvedValue(deployAccountResponse); const request = { - address: account.address, txnVersion, }; @@ -304,7 +302,6 @@ describe('ExecuteTxn', () => { accountDeploymentData, request, rpc, - network, account, deployAccountSpy, deployAccountResponse, @@ -315,12 +312,11 @@ describe('ExecuteTxn', () => { const { rpc, request, - network, account: { address, privateKey, publicKey }, deployAccountResponse, deployAccountSpy, accountDeploymentData, - } = await prepareDeployAccount(); + } = await setupDeployAccountTest(); const result = await rpc.deployAccount(request); @@ -338,7 +334,7 @@ describe('ExecuteTxn', () => { it('throws `Failed to deploy account` error if the execution transaction hash is empty', async () => { const { rpc, request, deployAccountSpy, deployAccountResponse } = - await prepareDeployAccount(); + await setupDeployAccountTest(); deployAccountSpy.mockResolvedValue({ ...deployAccountResponse, transaction_hash: '', @@ -351,12 +347,11 @@ describe('ExecuteTxn', () => { }); describe('sendTransaction', () => { - const prepareConfirmTransaction = async () => { - const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + const setupConfirmTransactionTest = async () => { const txnVersion = constants.TRANSACTION_VERSION.V3; const { calls } = callsExamples.multipleCalls; - const { account, rpc } = await setupMockRpc(network, calls); + const { account, rpc } = await setupRpcTest(calls); const executeTxnSpy = jest.spyOn(starknetUtils, 'executeTxn'); const executeTxnResponse = { @@ -366,7 +361,6 @@ describe('ExecuteTxn', () => { const request: SendTransactionParams = { calls, - address: account.address, abis: undefined, details: { version: txnVersion, @@ -387,18 +381,17 @@ describe('ExecuteTxn', () => { const { rpc, request, - network, - account: { privateKey }, + account: { privateKey, address }, executeTxnResponse, executeTxnSpy, - } = await prepareConfirmTransaction(); + } = await setupConfirmTransactionTest(); const result = await rpc.sendTransaction(request); expect(result).toStrictEqual(executeTxnResponse.transaction_hash); expect(executeTxnSpy).toHaveBeenCalledWith( network, - request.address, + address, privateKey, request.calls, request.abis, @@ -407,7 +400,8 @@ describe('ExecuteTxn', () => { }); it('throws `Failed to execute transaction` error if the execution transaction hash is empty', async () => { - const { rpc, request, executeTxnSpy } = await prepareConfirmTransaction(); + const { rpc, request, executeTxnSpy } = + await setupConfirmTransactionTest(); executeTxnSpy.mockResolvedValue({ transaction_hash: '' }); await expect(rpc.sendTransaction(request)).rejects.toThrow( @@ -417,53 +411,71 @@ describe('ExecuteTxn', () => { }); describe('execute', () => { - const prepareExecute = async (accountDeployed = true) => { - const network = STARKNET_SEPOLIA_TESTNET_NETWORK; - const account = await generateAccount(network); - - const { getEstimatedFeesResponse, getEstimatedFeesSpy } = - mockGetEstimatedFeesResponse({ - includeDeploy: !accountDeployed, - }); - const { suggestedMaxFee, resourceBounds } = getEstimatedFeesResponse; - + const mockExecute = () => { const confirmTransactionSpy = jest.spyOn( MockExecuteTxnRpc.prototype, 'confirmTransaction', ); - const transactionRequest = { - selectedFeeToken: FeeToken.STRK, - maxFee: suggestedMaxFee, - resourceBounds, - } as unknown as TransactionRequest; - confirmTransactionSpy.mockResolvedValue(transactionRequest); - - const sendTansactionResponse = callsExamples.multipleCalls.hash; const sendTransactionSpy = jest.spyOn( MockExecuteTxnRpc.prototype, 'sendTransaction', ); - sendTransactionSpy.mockResolvedValue(sendTansactionResponse); - - const deployAccountResponse = callsExamples.singleCall.hash; + const saveDataToStateSpy = jest.spyOn( + MockExecuteTxnRpc.prototype, + 'saveDataToState', + ); const deployAccountSpy = jest.spyOn( MockExecuteTxnRpc.prototype, 'deployAccount', ); - deployAccountSpy.mockResolvedValue(deployAccountResponse); - const saveDataToStateSpy = jest.spyOn( - MockExecuteTxnRpc.prototype, - 'saveDataToState', - ); + return { + deployAccountSpy, + confirmTransactionSpy, + sendTransactionSpy, + saveDataToStateSpy, + }; + }; + + const setupExecuteTest = async (accountDeployed = true) => { + const { + confirmTransactionSpy, + deployAccountSpy, + sendTransactionSpy, + saveDataToStateSpy, + } = mockExecute(); + const { + calls, + details, + hash: sendTansactionResponse, + } = getTransactionCalls(); + const { account, rpc, isDeploySpy } = await setupRpcTest(calls); + + const { getEstimatedFeesResponse, getEstimatedFeesSpy } = + mockGetEstimatedFeesResponse({ + includeDeploy: !accountDeployed, + }); + const { suggestedMaxFee: maxFee, resourceBounds } = + getEstimatedFeesResponse; + + const transactionRequest = { + selectedFeeToken: FeeToken.STRK, + maxFee, + resourceBounds, + } as unknown as TransactionRequest; + const deployAccountResponse = callsExamples.singleCall.hash; + + confirmTransactionSpy.mockResolvedValue(transactionRequest); + sendTransactionSpy.mockResolvedValue(sendTansactionResponse); + deployAccountSpy.mockResolvedValue(deployAccountResponse); saveDataToStateSpy.mockReturnThis(); + isDeploySpy.mockResolvedValue(accountDeployed); - const rpc = createMockRpc(); const request: ExecuteTxnParams = { chainId: network.chainId, address: account.address, - calls: callsExamples.multipleCalls.calls, - details: callsExamples.multipleCalls.details, + calls, + details, } as unknown as ExecuteTxnParams; return { @@ -489,23 +501,19 @@ describe('ExecuteTxn', () => { request, sendTansactionResponse, sendTransactionSpy, - account: { address }, getEstimatedFeesResponse, confirmTransactionSpy, deployAccountSpy, saveDataToStateSpy, transactionRequest, - } = await prepareExecute(); + } = await setupExecuteTest(); const updatedTxnVersion = feeTokenToTransactionVersion( transactionRequest.selectedFeeToken, ); const { maxFee: updatedMaxFee, resourceBounds: updatedResourceBounds } = transactionRequest; - const { - suggestedMaxFee: maxFee, - resourceBounds, - includeDeploy, - } = getEstimatedFeesResponse; + const { suggestedMaxFee: maxFee, resourceBounds } = + getEstimatedFeesResponse; const { calls, abis, details } = request; const result = await rpc.execute(request); @@ -515,15 +523,12 @@ describe('ExecuteTxn', () => { }); expect(confirmTransactionSpy).toHaveBeenCalledWith({ txnVersion: details?.version, - address, calls, maxFee, resourceBounds, - includeDeploy, }); expect(deployAccountSpy).not.toHaveBeenCalled(); expect(sendTransactionSpy).toHaveBeenCalledWith({ - address, calls, abis, details: { @@ -538,7 +543,6 @@ describe('ExecuteTxn', () => { txnHashForExecute: sendTansactionResponse, txnVersion: updatedTxnVersion, maxFee: updatedMaxFee, - address, calls, }); }); @@ -548,13 +552,12 @@ describe('ExecuteTxn', () => { rpc, request, sendTansactionResponse, - account: { address }, deployAccountResponse, deployAccountSpy, saveDataToStateSpy, transactionRequest, sendTransactionSpy, - } = await prepareExecute(false); + } = await setupExecuteTest(false); const updatedTxnVersion = feeTokenToTransactionVersion( transactionRequest.selectedFeeToken, ); @@ -568,11 +571,9 @@ describe('ExecuteTxn', () => { transaction_hash: sendTansactionResponse, }); expect(deployAccountSpy).toHaveBeenCalledWith({ - address, txnVersion: updatedTxnVersion, }); expect(sendTransactionSpy).toHaveBeenCalledWith({ - address, calls, abis, details: { @@ -588,44 +589,49 @@ describe('ExecuteTxn', () => { txnHashForExecute: sendTansactionResponse, txnVersion: updatedTxnVersion, maxFee: updatedMaxFee, - address, calls, }); }); }); describe('saveDataToState', () => { - const prepareSaveDataToState = async () => { - const network = STARKNET_SEPOLIA_TESTNET_NETWORK; - const txnVersion = constants.TRANSACTION_VERSION.V3; - const { hash: txnHashForExecute, calls } = callsExamples.multipleCalls; - const { hash: txnHashForDeploy } = callsExamples.singleCall; - - const { rpc, account } = await setupMockRpc(network, calls); - const { - getEstimatedFeesResponse: { suggestedMaxFee: maxFee }, - } = mockGetEstimatedFeesResponse({ - includeDeploy: false, - }); - + const mockSaveDataToState = () => { const addTransactionSpy = jest.spyOn( TransactionStateManager.prototype, 'addTransaction', ); - addTransactionSpy.mockReturnThis(); - const updateAccountAsDeploySpy = jest.spyOn( AccountStateManager.prototype, 'updateAccountAsDeploy', ); + addTransactionSpy.mockReturnThis(); updateAccountAsDeploySpy.mockReturnThis(); + return { + addTransactionSpy, + updateAccountAsDeploySpy, + }; + }; + + const setupSaveDataToStateTest = async () => { + const txnVersion = constants.TRANSACTION_VERSION.V3; + const { updateAccountAsDeploySpy, addTransactionSpy } = + mockSaveDataToState(); + const { hash: txnHashForExecute, calls } = getTransactionCalls(); + const { hash: txnHashForDeploy } = callsExamples.singleCall; + + const { rpc, account } = await setupRpcTest(calls); + const { + getEstimatedFeesResponse: { suggestedMaxFee: maxFee }, + } = mockGetEstimatedFeesResponse({ + includeDeploy: false, + }); + const request: SaveDataToStateParamas = { txnHashForDeploy, txnHashForExecute, txnVersion, maxFee, - address: account.address, calls, } as unknown as SaveDataToStateParamas; @@ -664,7 +670,7 @@ describe('ExecuteTxn', () => { addTransactionSpy, updateAccountAsDeploySpy, newInvokeTransaction, - } = await prepareSaveDataToState(); + } = await setupSaveDataToStateTest(); await rpc.saveDataToState({ ...request, @@ -685,7 +691,7 @@ describe('ExecuteTxn', () => { updateAccountAsDeploySpy, newInvokeTransaction, newDeployTransaction, - } = await prepareSaveDataToState(); + } = await setupSaveDataToStateTest(); await rpc.saveDataToState(request); diff --git a/packages/starknet-snap/src/rpcs/execute-txn.ts b/packages/starknet-snap/src/rpcs/execute-txn.ts index 3c7714d9..004519fd 100644 --- a/packages/starknet-snap/src/rpcs/execute-txn.ts +++ b/packages/starknet-snap/src/rpcs/execute-txn.ts @@ -26,7 +26,6 @@ import { UserRejectedOpError } from '../utils/exceptions'; import { deployAccount, executeTxn as executeTxnUtil, - getDeployAccountCallData, getEstimatedFees, } from '../utils/starknetUtils'; import { @@ -60,20 +59,16 @@ export type ExecuteTxnResponse = Infer; export type ConfirmTransactionParams = { calls: Call[]; - address: string; maxFee: string; resourceBounds: ResourceBounds; txnVersion: constants.TRANSACTION_VERSION; - includeDeploy: boolean; }; export type DeployAccountParams = { - address: string; txnVersion: constants.TRANSACTION_VERSION; }; export type SendTransactionParams = { - address: string; calls: Call[]; abis?: any[]; details?: Infer; @@ -84,7 +79,6 @@ export type SaveDataToStateParamas = { txnHashForExecute: string; txnVersion: constants.TRANSACTION_VERSION; maxFee: string; - address: string; calls: Call[]; }; @@ -151,11 +145,8 @@ export class ExecuteTxnRpc extends AccountRpcController< const { privateKey, publicKey } = this.account; const callsArray = Array.isArray(calls) ? calls : [calls]; - const { - includeDeploy, - suggestedMaxFee: maxFee, - resourceBounds, - } = await getEstimatedFees( + // FIXME: getEstimatedFees shpuld be refactored to accept account object + const { suggestedMaxFee: maxFee, resourceBounds } = await getEstimatedFees( this.network, address, privateKey, @@ -169,7 +160,7 @@ export class ExecuteTxnRpc extends AccountRpcController< details, ); - const accountDeployed = !includeDeploy; + const accountDeployed = await this.account.accountContract.isDeployed(); const { selectedFeeToken, @@ -177,11 +168,9 @@ export class ExecuteTxnRpc extends AccountRpcController< resourceBounds: updatedResouceBounds, } = await this.confirmTransaction({ txnVersion: details?.version as unknown as constants.TRANSACTION_VERSION, - address, calls: callsArray, maxFee, resourceBounds, - includeDeploy, }); const updatedTxnVersion = feeTokenToTransactionVersion(selectedFeeToken); @@ -190,13 +179,11 @@ export class ExecuteTxnRpc extends AccountRpcController< if (!accountDeployed) { txnHashForDeploy = await this.deployAccount({ - address, txnVersion: updatedTxnVersion, }); } const txnHashForExecute = await this.sendTransaction({ - address, calls: callsArray, abis, details: { @@ -215,7 +202,6 @@ export class ExecuteTxnRpc extends AccountRpcController< txnHashForExecute, txnVersion: updatedTxnVersion, maxFee: updatedMaxFee, - address, calls: callsArray, }); @@ -227,15 +213,13 @@ export class ExecuteTxnRpc extends AccountRpcController< protected async confirmTransaction({ calls, - address, maxFee, resourceBounds, txnVersion, - includeDeploy, }: ConfirmTransactionParams): Promise { const requestId = uuidv4(); const { chainId, name: networkName } = this.network; - const { addressIndex } = this.account; + const { hdIndex: addressIndex, address } = this.account; const formattedCalls = await Promise.all( calls.map(async (call) => @@ -260,7 +244,7 @@ export class ExecuteTxnRpc extends AccountRpcController< calls: formattedCalls, resourceBounds, selectedFeeToken: transactionVersionToFeeToken(txnVersion), - includeDeploy, + includeDeploy: !(await this.account.accountContract.isDeployed()), }; const interfaceId = await generateExecuteTxnFlow(request); @@ -300,13 +284,16 @@ export class ExecuteTxnRpc extends AccountRpcController< } protected async deployAccount({ - address, txnVersion, }: DeployAccountParams): Promise { - const { privateKey, publicKey } = this.account; - - const callData = getDeployAccountCallData(publicKey, CAIRO_VERSION); + const { + privateKey, + address, + publicKey, + accountContract: { callData }, + } = this.account; + // FIXME: deployAccount shpuld be refactored to accept account object const { contract_address: contractAddress, transaction_hash: transactionHash, @@ -334,12 +321,11 @@ export class ExecuteTxnRpc extends AccountRpcController< } protected async sendTransaction({ - address, calls, abis, details, }: SendTransactionParams): Promise { - const { privateKey } = this.account; + const { privateKey, address } = this.account; const executeTxnResp = await executeTxnUtil( this.network, @@ -362,11 +348,11 @@ export class ExecuteTxnRpc extends AccountRpcController< txnHashForExecute, txnVersion, maxFee, - address, calls, }: SaveDataToStateParamas) { const txnVersionInNumber = transactionVersionToNumber(txnVersion); const { chainId } = this.network; + const { address } = this.account; if (txnHashForDeploy) { await this.txnStateManager.addTransaction( diff --git a/packages/starknet-snap/src/rpcs/get-addr-from-starkname.test.ts b/packages/starknet-snap/src/rpcs/get-addr-from-starkname.test.ts index 01973640..2de5a403 100644 --- a/packages/starknet-snap/src/rpcs/get-addr-from-starkname.test.ts +++ b/packages/starknet-snap/src/rpcs/get-addr-from-starkname.test.ts @@ -10,35 +10,35 @@ import { jest.mock('../utils/snap'); jest.mock('../utils/logger'); -const prepareMockGetAddrFromStarkName = ({ - chainId, - starkName, -}: { - chainId: constants.StarknetChainId; - starkName: string; -}) => { - const request = { +describe('getAddrFromStarkName', () => { + const setupGetAddrFromStarkNameTest = ({ chainId, starkName, - } as unknown as GetAddrFromStarkNameParams; - - const getAddrFromStarkNameSpy = jest.spyOn( - starknetUtils, - 'getAddrFromStarkNameUtil', - ); - getAddrFromStarkNameSpy.mockResolvedValue( - '0x01c744953f1d671673f46a9179a58a7e58d9299499b1e076cdb908e7abffe69f', - ); - - return { - request, + }: { + chainId: constants.StarknetChainId; + starkName: string; + }) => { + const request = { + chainId, + starkName, + } as unknown as GetAddrFromStarkNameParams; + + const getAddrFromStarkNameSpy = jest.spyOn( + starknetUtils, + 'getAddrFromStarkNameUtil', + ); + getAddrFromStarkNameSpy.mockResolvedValue( + '0x01c744953f1d671673f46a9179a58a7e58d9299499b1e076cdb908e7abffe69f', + ); + + return { + request, + }; }; -}; -describe('getAddrFromStarkName', () => { it('get address from stark name correctly', async () => { const chainId = constants.StarknetChainId.SN_SEPOLIA; - const { request } = prepareMockGetAddrFromStarkName({ + const { request } = setupGetAddrFromStarkNameTest({ chainId, starkName: 'testname.stark', }); diff --git a/packages/starknet-snap/src/rpcs/get-addr-from-starkname.ts b/packages/starknet-snap/src/rpcs/get-addr-from-starkname.ts index cab1d5a1..ff52af76 100644 --- a/packages/starknet-snap/src/rpcs/get-addr-from-starkname.ts +++ b/packages/starknet-snap/src/rpcs/get-addr-from-starkname.ts @@ -2,11 +2,9 @@ import type { Infer } from 'superstruct'; import { assign, object } from 'superstruct'; import { NetworkStateManager } from '../state/network-state-manager'; -import type { Network } from '../types/snapState'; import { AddressStruct, BaseRequestStruct, StarkNameStruct } from '../utils'; -import { InvalidNetworkError } from '../utils/exceptions'; import { getAddrFromStarkNameUtil } from '../utils/starknetUtils'; -import { RpcController } from './abstract/base-rpc-controller'; +import { ChainRpcController } from './abstract/chain-rpc-controller'; export const GetAddrFromStarkNameRequestStruct = assign( object({ @@ -28,7 +26,7 @@ export type GetAddrFromStarkNameResponse = Infer< /** * The RPC handler to get a StarkName by a Starknet address. */ -export class GetAddrFromStarkNameRpc extends RpcController< +export class GetAddrFromStarkNameRpc extends ChainRpcController< GetAddrFromStarkNameParams, GetAddrFromStarkNameResponse > { @@ -58,27 +56,12 @@ export class GetAddrFromStarkNameRpc extends RpcController< return super.execute(params); } - protected async getNetworkFromChainId(chainId: string): Promise { - const network = await this.networkStateMgr.getNetwork({ - chainId, - }); - - // It should be never happen, as the chainId should be validated by the superstruct - if (!network) { - throw new InvalidNetworkError() as unknown as Error; - } - - return network; - } - protected async handleRequest( params: GetAddrFromStarkNameParams, ): Promise { - const { chainId, starkName } = params; - - const network = await this.getNetworkFromChainId(chainId); + const { starkName } = params; - const address = await getAddrFromStarkNameUtil(network, starkName); + const address = await getAddrFromStarkNameUtil(this.network, starkName); return address; } diff --git a/packages/starknet-snap/src/rpcs/get-deployment-data.test.ts b/packages/starknet-snap/src/rpcs/get-deployment-data.test.ts index 355acb48..1fe1ef25 100644 --- a/packages/starknet-snap/src/rpcs/get-deployment-data.test.ts +++ b/packages/starknet-snap/src/rpcs/get-deployment-data.test.ts @@ -1,17 +1,13 @@ -import { constants } from 'starknet'; +import type { constants } from 'starknet'; -import type { SnapState } from '../types/snapState'; -import { - ACCOUNT_CLASS_HASH, - CAIRO_VERSION, - STARKNET_SEPOLIA_TESTNET_NETWORK, -} from '../utils/constants'; +import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../utils/constants'; import { InvalidRequestParamsError, AccountAlreadyDeployedError, + ContractNotDeployedError, } from '../utils/exceptions'; -import * as starknetUtils from '../utils/starknetUtils'; -import { mockAccount, prepareMockAccount } from './__tests__/helper'; +import { mockAccountContractReader } from '../wallet/account/__test__/helper'; +import { setupAccountController } from './__tests__/helper'; import type { GetDeploymentDataParams } from './get-deployment-data'; import { getDeploymentData } from './get-deployment-data'; @@ -19,33 +15,23 @@ jest.mock('../utils/snap'); jest.mock('../utils/logger'); describe('GetDeploymentDataRpc', () => { - const state: SnapState = { - accContracts: [], - erc20Tokens: [], - networks: [STARKNET_SEPOLIA_TESTNET_NETWORK], - transactions: [], - }; + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; - const createRequest = ( - chainId: constants.StarknetChainId, - address: string, - ) => ({ - address, - chainId, - }); + const setupGetDeploymentDataTest = async (deployed: boolean) => { + const { getVersionSpy } = mockAccountContractReader({}); - const mockIsAccountDeployed = (deployed: boolean) => { - const spy = jest.spyOn(starknetUtils, 'isAccountDeployed'); - spy.mockResolvedValue(deployed); - return spy; - }; + if (!deployed) { + getVersionSpy.mockRejectedValue(new ContractNotDeployedError()); + } + const { account } = await setupAccountController({ + network, + isDeployed: deployed, + }); - const prepareGetDeploymentDataTest = async (deployed: boolean) => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const account = await mockAccount(chainId); - prepareMockAccount(account, state); - mockIsAccountDeployed(deployed); - const request = createRequest(chainId, account.address); + const request = { + address: account.address, + chainId: network.chainId as constants.StarknetChainId, + }; return { account, @@ -54,18 +40,26 @@ describe('GetDeploymentDataRpc', () => { }; it('returns the deployment data', async () => { - const { account, request } = await prepareGetDeploymentDataTest(false); - const { address, publicKey } = account; + const { account, request } = await setupGetDeploymentDataTest(false); + const { + address, + accountContract: { + deployPayload: { + classHash, + addressSalt: salt, + constructorCalldata: calldata, + }, + cairoVerion, + }, + } = account; + const expectedResult = { address, // eslint-disable-next-line @typescript-eslint/naming-convention - class_hash: ACCOUNT_CLASS_HASH, - salt: publicKey, - calldata: starknetUtils.getDeployAccountCallData( - publicKey, - CAIRO_VERSION, - ), - version: CAIRO_VERSION, + class_hash: classHash, + salt, + calldata, + version: cairoVerion.toString(10), }; const result = await getDeploymentData.execute(request); @@ -74,7 +68,7 @@ describe('GetDeploymentDataRpc', () => { }); it('throws `AccountAlreadyDeployedError` if the account has deployed', async () => { - const { request } = await prepareGetDeploymentDataTest(true); + const { request } = await setupGetDeploymentDataTest(true); await expect(getDeploymentData.execute(request)).rejects.toThrow( AccountAlreadyDeployedError, diff --git a/packages/starknet-snap/src/rpcs/get-deployment-data.ts b/packages/starknet-snap/src/rpcs/get-deployment-data.ts index ac8fcc29..82123c9f 100644 --- a/packages/starknet-snap/src/rpcs/get-deployment-data.ts +++ b/packages/starknet-snap/src/rpcs/get-deployment-data.ts @@ -2,12 +2,7 @@ import type { Infer } from 'superstruct'; import { object, string, assign, array } from 'superstruct'; import { AddressStruct, BaseRequestStruct, CairoVersionStruct } from '../utils'; -import { ACCOUNT_CLASS_HASH, CAIRO_VERSION } from '../utils/constants'; import { AccountAlreadyDeployedError } from '../utils/exceptions'; -import { - getDeployAccountCallData, - isAccountDeployed, -} from '../utils/starknetUtils'; import { AccountRpcController } from './abstract/account-rpc-controller'; export const GetDeploymentDataRequestStruct = assign( @@ -61,26 +56,40 @@ export class GetDeploymentDataRpc extends AccountRpcController< } protected async handleRequest( + // eslint-disable-next-line @typescript-eslint/no-unused-vars params: GetDeploymentDataParams, ): Promise { - const { address } = params; + const { accountContract } = this.account; + const { + deployPayload: { + contractAddress: address, + classHash, + addressSalt: salt, + constructorCalldata: calldata, + }, + cairoVerion, + } = accountContract; + // Due to AccountRpcController built-in validation, - // if the account required to force deploy (Cairo 0 with balance), it will alert with a warning dialog. - // if the account required to force upgrade (Cairo 0 without balance), it will alert with a warning dialog. + // if the account required to: + // - deploy (Cairo 0 with balance) + // - upgrade (Cairo 0 without balance) + // it will throw an error // hence we can safely assume that the account is Cairo 1 account. - if (await isAccountDeployed(this.network, address)) { + // therefore if the account is already deployed, we should throw an error. + if (await accountContract.isDeployed()) { throw new AccountAlreadyDeployedError(); } // We only need to take care the deployment data for Cairo 1 account. return { - address: params.address, + address, // eslint-disable-next-line @typescript-eslint/naming-convention - class_hash: ACCOUNT_CLASS_HASH, - salt: this.account.publicKey, - calldata: getDeployAccountCallData(this.account.publicKey, CAIRO_VERSION), - version: CAIRO_VERSION, - }; + class_hash: classHash, + salt, + calldata, + version: cairoVerion.toString(10), + } as GetDeploymentDataResponse; } } diff --git a/packages/starknet-snap/src/rpcs/get-transaction-status.test.ts b/packages/starknet-snap/src/rpcs/get-transaction-status.test.ts index 07e76d46..f0536d38 100644 --- a/packages/starknet-snap/src/rpcs/get-transaction-status.test.ts +++ b/packages/starknet-snap/src/rpcs/get-transaction-status.test.ts @@ -16,7 +16,7 @@ jest.mock('../utils/snap'); jest.mock('../utils/logger'); describe('GetTransactionStatusRpc', () => { - const prepareGetTransactionStatusTest = ({ + const setupGetTransactionStatusTest = ({ network, status, }: { @@ -49,7 +49,7 @@ describe('GetTransactionStatusRpc', () => { finalityStatus: TransactionFinalityStatus.ACCEPTED_ON_L1, executionStatus: TransactionExecutionStatus.SUCCEEDED, }; - const { getTransactionStatusSpy } = prepareGetTransactionStatusTest({ + const { getTransactionStatusSpy } = setupGetTransactionStatusTest({ network, status: expectedResult, }); diff --git a/packages/starknet-snap/src/rpcs/list-transaction.test.ts b/packages/starknet-snap/src/rpcs/list-transaction.test.ts index 26e55ef6..eff34a9a 100644 --- a/packages/starknet-snap/src/rpcs/list-transaction.test.ts +++ b/packages/starknet-snap/src/rpcs/list-transaction.test.ts @@ -17,7 +17,7 @@ import type { ListTransactionsParams } from './list-transactions'; jest.mock('../utils/logger'); describe('listTransactions', () => { - const prepareListTransactions = async () => { + const setupListTransactionsTest = async () => { const network = STARKNET_SEPOLIA_TESTNET_NETWORK; const chainId = network.chainId as unknown as constants.StarknetChainId; const account = await mockAccount(chainId); @@ -46,7 +46,7 @@ describe('listTransactions', () => { it('returns transactions', async () => { const { transactions, getTransactionsSpy, chainId, account } = - await prepareListTransactions(); + await setupListTransactionsTest(); const result = await listTransactions.execute({ chainId, @@ -65,7 +65,7 @@ describe('listTransactions', () => { it('fetchs transactions with config value if input `txnsInLastNumOfDays` has not given', async () => { const { getTransactionsSpy, chainId, account } = - await prepareListTransactions(); + await setupListTransactionsTest(); await listTransactions.execute({ chainId, diff --git a/packages/starknet-snap/src/rpcs/sign-declare-transaction.test.ts b/packages/starknet-snap/src/rpcs/sign-declare-transaction.test.ts index 37db7d59..5990ec1e 100644 --- a/packages/starknet-snap/src/rpcs/sign-declare-transaction.test.ts +++ b/packages/starknet-snap/src/rpcs/sign-declare-transaction.test.ts @@ -1,7 +1,6 @@ import type { DeclareSignerDetails } from 'starknet'; import { constants } from 'starknet'; -import type { SnapState } from '../types/snapState'; import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../utils/constants'; import { UserRejectedOpError, @@ -9,9 +8,8 @@ import { } from '../utils/exceptions'; import * as starknetUtils from '../utils/starknetUtils'; import { - mockAccount, - prepareMockAccount, - prepareRenderSignDeclareTransactionUI, + setupAccountController, + mockRenderSignDeclareTransactionUI, } from './__tests__/helper'; import { signDeclareTransaction } from './sign-declare-transaction'; import type { SignDeclareTransactionParams } from './sign-declare-transaction'; @@ -20,38 +18,39 @@ jest.mock('../utils/snap'); jest.mock('../utils/logger'); describe('signDeclareTransaction', () => { - const state: SnapState = { - accContracts: [], - erc20Tokens: [], - networks: [STARKNET_SEPOLIA_TESTNET_NETWORK], - transactions: [], - }; + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; - const createRequest = ( - chainId: constants.StarknetChainId, - address: string, - ) => ({ + const createRequest = (chainId: string, address: string) => ({ details: { classHash: '0x025ec026985a3bf9d0cc1fe17326b245dfdc3ff89b8fde106542a3ea56c5a918', senderAddress: address, - chainId, + chainId: chainId as constants.StarknetChainId, version: constants.TRANSACTION_VERSION.V2, maxFee: 0, nonce: 0, }, address, - chainId, + chainId: chainId as constants.StarknetChainId, }); - it('signs message correctly', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const account = await mockAccount(chainId); + const setupSignDeclareTransactionTest = async () => { + const { account } = await setupAccountController({ + network, + }); - prepareMockAccount(account, state); - prepareRenderSignDeclareTransactionUI(); + const { confirmDialogSpy } = mockRenderSignDeclareTransactionUI(); + const request = createRequest(network.chainId, account.address); - const request = createRequest(chainId, account.address); + return { + account, + request, + confirmDialogSpy, + }; + }; + + it('signs message correctly', async () => { + const { account, request } = await setupSignDeclareTransactionTest(); const expectedResult = await starknetUtils.signDeclareTransaction( account.privateKey, @@ -64,36 +63,25 @@ describe('signDeclareTransaction', () => { }); it('renders confirmation dialog', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const account = await mockAccount(chainId); - const { address } = account; - - prepareMockAccount(account, state); - const { confirmDialogSpy } = prepareRenderSignDeclareTransactionUI(); - - const request = createRequest(chainId, address); + const { account, request, confirmDialogSpy } = + await setupSignDeclareTransactionTest(); await signDeclareTransaction.execute(request); expect(confirmDialogSpy).toHaveBeenCalledWith({ - senderAddress: address, - chainId, - networkName: STARKNET_SEPOLIA_TESTNET_NETWORK.name, + senderAddress: account.address, + chainId: network.chainId, + networkName: network.name, declareTransactions: request.details, }); }); it('throws `UserRejectedOpError` if user denied the operation', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const account = await mockAccount(chainId); - - prepareMockAccount(account, state); - const { confirmDialogSpy } = prepareRenderSignDeclareTransactionUI(); + const { request, confirmDialogSpy } = + await setupSignDeclareTransactionTest(); confirmDialogSpy.mockResolvedValue(false); - const request = createRequest(chainId, account.address); - await expect(signDeclareTransaction.execute(request)).rejects.toThrow( UserRejectedOpError, ); diff --git a/packages/starknet-snap/src/rpcs/sign-message.test.ts b/packages/starknet-snap/src/rpcs/sign-message.test.ts index 24a3c4ae..91b8605a 100644 --- a/packages/starknet-snap/src/rpcs/sign-message.test.ts +++ b/packages/starknet-snap/src/rpcs/sign-message.test.ts @@ -1,7 +1,6 @@ -import { constants } from 'starknet'; +import type { constants } from 'starknet'; import typedDataExample from '../__tests__/fixture/typedDataExample.json'; -import type { SnapState } from '../types/snapState'; import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../utils/constants'; import { UserRejectedOpError, @@ -9,9 +8,8 @@ import { } from '../utils/exceptions'; import * as starknetUtils from '../utils/starknetUtils'; import { - mockAccount, - prepareMockAccount, - prepareRenderSignMessageUI, + setupAccountController, + mockRenderSignMessageUI, } from './__tests__/helper'; import { signMessage } from './sign-message'; import type { SignMessageParams } from './sign-message'; @@ -20,18 +18,31 @@ jest.mock('../utils/snap'); jest.mock('../utils/logger'); describe('signMessage', () => { - const state: SnapState = { - accContracts: [], - erc20Tokens: [], - networks: [STARKNET_SEPOLIA_TESTNET_NETWORK], - transactions: [], + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + + const setupSignMessageTest = async (enableAuthorize = false) => { + const { account } = await setupAccountController({ + network, + }); + + const { confirmDialogSpy } = mockRenderSignMessageUI(); + + const request = { + chainId: network.chainId as constants.StarknetChainId, + address: account.address, + typedDataMessage: typedDataExample, + enableAuthorize, + }; + + return { + request, + account, + confirmDialogSpy, + }; }; it('signs message correctly', async () => { - const account = await mockAccount(constants.StarknetChainId.SN_SEPOLIA); - - prepareMockAccount(account, state); - prepareRenderSignMessageUI(); + const { account, request } = await setupSignMessageTest(); const expectedResult = await starknetUtils.signMessage( account.privateKey, @@ -39,53 +50,30 @@ describe('signMessage', () => { account.address, ); - const request = { - chainId: constants.StarknetChainId.SN_SEPOLIA, - address: account.address, - typedDataMessage: typedDataExample, - }; const result = await signMessage.execute(request); expect(result).toStrictEqual(expectedResult); }); it('renders confirmation dialog', async () => { - const account = await mockAccount(constants.StarknetChainId.SN_SEPOLIA); - const { address, chainId } = account; - - prepareMockAccount(account, state); - const { confirmDialogSpy } = prepareRenderSignMessageUI(); - - const request = { - chainId: chainId as constants.StarknetChainId, - address, - typedDataMessage: typedDataExample, - enableAuthorize: true, - }; + const { account, request, confirmDialogSpy } = await setupSignMessageTest( + true, + ); await signMessage.execute(request); + expect(confirmDialogSpy).toHaveBeenCalledWith({ - address, - chainId, + address: account.address, + chainId: network.chainId as constants.StarknetChainId, typedDataMessage: typedDataExample, }); }); it('throws `UserRejectedOpError` if user denied the operation', async () => { - const account = await mockAccount(constants.StarknetChainId.SN_SEPOLIA); - - prepareMockAccount(account, state); - const { confirmDialogSpy } = prepareRenderSignMessageUI(); + const { request, confirmDialogSpy } = await setupSignMessageTest(true); confirmDialogSpy.mockResolvedValue(false); - const request = { - chainId: constants.StarknetChainId.SN_SEPOLIA, - address: account.address, - typedDataMessage: typedDataExample, - enableAuthorize: true, - }; - await expect(signMessage.execute(request)).rejects.toThrow( UserRejectedOpError, ); diff --git a/packages/starknet-snap/src/rpcs/sign-transaction.test.ts b/packages/starknet-snap/src/rpcs/sign-transaction.test.ts index 66dbdc7c..c276317e 100644 --- a/packages/starknet-snap/src/rpcs/sign-transaction.test.ts +++ b/packages/starknet-snap/src/rpcs/sign-transaction.test.ts @@ -1,8 +1,6 @@ -import type { InvocationsSignerDetails } from 'starknet'; -import { constants } from 'starknet'; +import type { InvocationsSignerDetails, constants } from 'starknet'; import transactionExample from '../__tests__/fixture/transactionExample.json'; // Assuming you have a similar fixture -import type { SnapState } from '../types/snapState'; import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../utils/constants'; import { UserRejectedOpError, @@ -10,9 +8,8 @@ import { } from '../utils/exceptions'; import * as starknetUtils from '../utils/starknetUtils'; import { - mockAccount, - prepareMockAccount, - prepareRenderSignTransactionUI, + setupAccountController, + mockRenderSignTransactionUI, } from './__tests__/helper'; import { signTransaction } from './sign-transaction'; import type { SignTransactionParams } from './sign-transaction'; @@ -20,37 +17,36 @@ import type { SignTransactionParams } from './sign-transaction'; jest.mock('../utils/logger'); describe('signTransaction', () => { - const state: SnapState = { - accContracts: [], - erc20Tokens: [], - networks: [STARKNET_SEPOLIA_TESTNET_NETWORK], - transactions: [], - }; + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + + const setupSignTransactionTest = async (enableAuthorize = false) => { + const { account } = await setupAccountController({ + network, + }); + + const { confirmDialogSpy } = mockRenderSignTransactionUI(); - const createRequestParam = ( - chainId: constants.StarknetChainId, - address: string, - enableAuthorize?: boolean, - ): SignTransactionParams => { const request: SignTransactionParams = { - chainId, - address, + chainId: network.chainId as constants.StarknetChainId, + address: account.address, transactions: transactionExample.transactions, transactionsDetail: transactionExample.transactionsDetail as unknown as InvocationsSignerDetails, + enableAuthorize, }; + if (enableAuthorize) { request.enableAuthorize = enableAuthorize; } - return request; + return { + request, + account, + confirmDialogSpy, + }; }; it('signs a transaction correctly', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const account = await mockAccount(chainId); - prepareMockAccount(account, state); - prepareRenderSignTransactionUI(); - const request = createRequestParam(chainId, account.address); + const { request, account } = await setupSignTransactionTest(); const expectedResult = await starknetUtils.signTransactions( account.privateKey, @@ -64,28 +60,21 @@ describe('signTransaction', () => { }); it('renders confirmation dialog', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const account = await mockAccount(chainId); - const { address } = account; - prepareMockAccount(account, state); - const { confirmDialogSpy } = prepareRenderSignTransactionUI(); - const request = createRequestParam(chainId, account.address, true); + const { request, account, confirmDialogSpy } = + await setupSignTransactionTest(true); await signTransaction.execute(request); + expect(confirmDialogSpy).toHaveBeenCalledWith({ - senderAddress: address, - chainId, + senderAddress: account.address, + chainId: network.chainId as constants.StarknetChainId, networkName: STARKNET_SEPOLIA_TESTNET_NETWORK.name, transactions: request.transactions, }); }); it('does not render the confirmation dialog if enableAuthorize is false', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const account = await mockAccount(chainId); - prepareMockAccount(account, state); - const { confirmDialogSpy } = prepareRenderSignTransactionUI(); - const request = createRequestParam(chainId, account.address, false); + const { request, confirmDialogSpy } = await setupSignTransactionTest(false); await signTransaction.execute(request); @@ -93,12 +82,8 @@ describe('signTransaction', () => { }); it('throws `UserRejectedOpError` if user denied the operation', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const account = await mockAccount(chainId); - prepareMockAccount(account, state); - const { confirmDialogSpy } = prepareRenderSignTransactionUI(); + const { request, confirmDialogSpy } = await setupSignTransactionTest(true); confirmDialogSpy.mockResolvedValue(false); - const request = createRequestParam(chainId, account.address, true); await expect(signTransaction.execute(request)).rejects.toThrow( UserRejectedOpError, diff --git a/packages/starknet-snap/src/rpcs/switch-network.test.ts b/packages/starknet-snap/src/rpcs/switch-network.test.ts index cca82729..d1a87b29 100644 --- a/packages/starknet-snap/src/rpcs/switch-network.test.ts +++ b/packages/starknet-snap/src/rpcs/switch-network.test.ts @@ -12,7 +12,7 @@ import { InvalidRequestParamsError, UserRejectedOpError, } from '../utils/exceptions'; -import { prepareRenderSwitchNetworkUI } from './__tests__/helper'; +import { mockRenderSwitchNetworkUI } from './__tests__/helper'; import { switchNetwork } from './switch-network'; import type { SwitchNetworkParams } from './switch-network'; @@ -116,7 +116,7 @@ describe('switchNetwork', () => { currentNetwork, network: requestNetwork, }); - const { confirmDialogSpy } = prepareRenderSwitchNetworkUI(); + const { confirmDialogSpy } = mockRenderSwitchNetworkUI(); const request = createRequestParam(requestNetwork.chainId, true); await switchNetwork.execute(request); @@ -134,7 +134,7 @@ describe('switchNetwork', () => { currentNetwork, network: requestNetwork, }); - const { confirmDialogSpy } = prepareRenderSwitchNetworkUI(); + const { confirmDialogSpy } = mockRenderSwitchNetworkUI(); confirmDialogSpy.mockResolvedValue(false); const request = createRequestParam(requestNetwork.chainId, true); diff --git a/packages/starknet-snap/src/rpcs/verify-signature.test.ts b/packages/starknet-snap/src/rpcs/verify-signature.test.ts index 037da54b..d2e7b83b 100644 --- a/packages/starknet-snap/src/rpcs/verify-signature.test.ts +++ b/packages/starknet-snap/src/rpcs/verify-signature.test.ts @@ -1,11 +1,11 @@ import { constants } from 'starknet'; import typedDataExample from '../__tests__/fixture/typedDataExample.json'; -import type { SnapState } from '../types/snapState'; +import { generateAccounts } from '../__tests__/helper'; import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../utils/constants'; import { InvalidRequestParamsError } from '../utils/exceptions'; import * as starknetUtils from '../utils/starknetUtils'; -import { mockAccount, prepareMockAccount } from './__tests__/helper'; +import { setupAccountController } from './__tests__/helper'; import { verifySignature } from './verify-signature'; import type { VerifySignatureParams } from './verify-signature'; @@ -13,16 +13,27 @@ jest.mock('../utils/snap'); jest.mock('../utils/logger'); describe('verifySignature', () => { - const state: SnapState = { - accContracts: [], - erc20Tokens: [], - networks: [STARKNET_SEPOLIA_TESTNET_NETWORK], - transactions: [], + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + + const setupVerifySignatureTest = async () => { + const { account } = await setupAccountController({ + network, + }); + + const request = { + chainId: network.chainId as constants.StarknetChainId, + address: account.address, + typedDataMessage: typedDataExample, + }; + + return { + request, + account, + }; }; it('returns true if the signature is correct', async () => { - const account = await mockAccount(constants.StarknetChainId.SN_SEPOLIA); - prepareMockAccount(account, state); + const { account, request } = await setupVerifySignatureTest(); const signature = await starknetUtils.signMessage( account.privateKey, @@ -30,39 +41,41 @@ describe('verifySignature', () => { account.address, ); - const request = { - chainId: constants.StarknetChainId.SN_SEPOLIA, - address: account.address, - typedDataMessage: typedDataExample, + const result = await verifySignature.execute({ + ...request, signature, - }; - - const result = await verifySignature.execute(request); + }); expect(result).toBe(true); }); it('returns false if the signature is incorrect', async () => { - const account = await mockAccount(constants.StarknetChainId.SN_SEPOLIA); - const invalidSignatueAccount = await mockAccount( - constants.StarknetChainId.SN_MAIN, + const { account } = await setupAccountController({ + network, + }); + + const [invalidSignatureAccount] = await generateAccounts( + network.chainId, + 1, ); - prepareMockAccount(account, state); - const invalidSignatue = await starknetUtils.signMessage( - invalidSignatueAccount.privateKey, + const invalidSignature = await starknetUtils.signMessage( + invalidSignatureAccount.privateKey, typedDataExample, - invalidSignatueAccount.address, + invalidSignatureAccount.address, ); const request = { chainId: constants.StarknetChainId.SN_SEPOLIA, address: account.address, typedDataMessage: typedDataExample, - signature: invalidSignatue, + signature: invalidSignature, }; - const result = await verifySignature.execute(request); + const result = await verifySignature.execute({ + ...request, + signature: invalidSignature, + }); expect(result).toBe(false); }); diff --git a/packages/starknet-snap/src/rpcs/watch-asset.test.ts b/packages/starknet-snap/src/rpcs/watch-asset.test.ts index 8572b407..a0f1bb51 100644 --- a/packages/starknet-snap/src/rpcs/watch-asset.test.ts +++ b/packages/starknet-snap/src/rpcs/watch-asset.test.ts @@ -11,7 +11,7 @@ import { InvalidNetworkError, UserRejectedOpError, } from '../utils/exceptions'; -import { prepareRenderWatchAssetUI } from './__tests__/helper'; +import { mockRenderWatchAssetUI } from './__tests__/helper'; import type { WatchAssetParams } from './watch-asset'; import { watchAsset } from './watch-asset'; @@ -63,7 +63,7 @@ describe('WatchAssetRpc', () => { return { upsertTokenSpy }; }; - const prepareWatchAssetTest = async ({ + const setupWatchAssetTest = async ({ network = STARKNET_SEPOLIA_TESTNET_NETWORK, }: { network?: Network; @@ -71,7 +71,7 @@ describe('WatchAssetRpc', () => { const request = createRequest({ chainId: network.chainId as unknown as constants.StarknetChainId, }); - const { confirmDialogSpy } = prepareRenderWatchAssetUI(); + const { confirmDialogSpy } = mockRenderWatchAssetUI(); const { getNetworkSpy } = mockNetworkStateManager({ network, }); @@ -86,7 +86,7 @@ describe('WatchAssetRpc', () => { }; it('returns true if the token is added', async () => { - const { request } = await prepareWatchAssetTest({}); + const { request } = await setupWatchAssetTest({}); const expectedResult = true; @@ -97,7 +97,7 @@ describe('WatchAssetRpc', () => { it('renders confirmation dialog', async () => { const network = STARKNET_SEPOLIA_TESTNET_NETWORK; - const { request, confirmDialogSpy } = await prepareWatchAssetTest({ + const { request, confirmDialogSpy } = await setupWatchAssetTest({ network, }); @@ -116,7 +116,7 @@ describe('WatchAssetRpc', () => { }); it('throws `InvalidNetworkError` if the network can not be found', async () => { - const { request, getNetworkSpy } = await prepareWatchAssetTest({}); + const { request, getNetworkSpy } = await setupWatchAssetTest({}); getNetworkSpy.mockResolvedValue(null); await expect(watchAsset.execute(request)).rejects.toThrow( @@ -131,7 +131,7 @@ describe('WatchAssetRpc', () => { const network = Config.availableNetworks.find( (net) => net.chainId === chainId, ); - await prepareWatchAssetTest({ + await setupWatchAssetTest({ network, }); const request = createRequest({ @@ -148,7 +148,7 @@ describe('WatchAssetRpc', () => { }); it('throws `UserRejectedOpError` if user denied the operation', async () => { - const { request, confirmDialogSpy } = await prepareWatchAssetTest({}); + const { request, confirmDialogSpy } = await setupWatchAssetTest({}); confirmDialogSpy.mockResolvedValue(false); await expect(watchAsset.execute(request)).rejects.toThrow( diff --git a/packages/starknet-snap/src/utils/starknetUtils.test.ts b/packages/starknet-snap/src/utils/starknetUtils.test.ts index 5b74df56..b43e9b4e 100644 --- a/packages/starknet-snap/src/utils/starknetUtils.test.ts +++ b/packages/starknet-snap/src/utils/starknetUtils.test.ts @@ -2,22 +2,14 @@ import type { Invocations } from 'starknet'; import { constants, TransactionType } from 'starknet'; import callsExamples from '../__tests__/fixture/callsExamples.json'; -import { mockAccount, prepareMockAccount } from '../rpcs/__tests__/helper'; +import { mockAccount } from '../rpcs/__tests__/helper'; import { FeeTokenUnit } from '../types/snapApi'; -import type { SnapState } from '../types/snapState'; import type { TransactionVersion } from '../types/starknet'; import { mockEstimateFeeBulkResponse } from './__tests__/helper'; import { STARKNET_SEPOLIA_TESTNET_NETWORK } from './constants'; import * as starknetUtils from './starknetUtils'; describe('getEstimatedFees', () => { - const state: SnapState = { - accContracts: [], - erc20Tokens: [], - networks: [STARKNET_SEPOLIA_TESTNET_NETWORK], - transactions: [], - }; - const prepareGetEstimatedFees = async (deployed: boolean) => { const chainId = constants.StarknetChainId.SN_SEPOLIA; const account = await mockAccount(chainId); @@ -29,7 +21,6 @@ describe('getEstimatedFees', () => { }, ]; - prepareMockAccount(account, state); const accountDeployedSpy = jest.spyOn(starknetUtils, 'isAccountDeployed'); accountDeployedSpy.mockResolvedValue(deployed); From 21952f33ec5c202ae750bda4c2d03c864c0ae659 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Thu, 9 Jan 2025 11:00:21 +0800 Subject: [PATCH 10/19] chore: update execute txn test --- .../src/rpcs/execute-txn.test.ts | 308 ++++++++---------- .../starknet-snap/src/rpcs/execute-txn.ts | 42 ++- 2 files changed, 171 insertions(+), 179 deletions(-) diff --git a/packages/starknet-snap/src/rpcs/execute-txn.test.ts b/packages/starknet-snap/src/rpcs/execute-txn.test.ts index 19b2f485..2a9be480 100644 --- a/packages/starknet-snap/src/rpcs/execute-txn.test.ts +++ b/packages/starknet-snap/src/rpcs/execute-txn.test.ts @@ -7,7 +7,11 @@ import { mockTransactionRequestStateManager } from '../state/__tests__/helper'; import { AccountStateManager } from '../state/account-state-manager'; import { TransactionStateManager } from '../state/transaction-state-manager'; import { FeeToken } from '../types/snapApi'; -import type { FormattedCallData, TransactionRequest } from '../types/snapState'; +import type { + FormattedCallData, + Network, + TransactionRequest, +} from '../types/snapState'; import * as uiUtils from '../ui/utils'; import { CAIRO_VERSION, @@ -68,102 +72,91 @@ class MockExecuteTxnRpc extends ExecuteTxnRpc { } } -describe('ExecuteTxn', () => { - const network = STARKNET_SEPOLIA_TESTNET_NETWORK; - - const createMockRpc = () => { - const rpc = new MockExecuteTxnRpc({ - showInvalidAccountAlert: true, - }); - return rpc; - }; - - const setupRpcTest = async (calls: Call[]) => { - const controller = await setupAccountController({ network }); +const createMockRpc = () => { + const rpc = new MockExecuteTxnRpc({ + showInvalidAccountAlert: true, + }); + return rpc; +}; - const rpc = createMockRpc(); +const setupMockRpc = async (network: Network, calls: Call[]) => { + const { account } = await setupAccountController({ network }); - // Setup the rpc, to discover the account and network - await rpc.preExecute({ - chainId: network.chainId, - address: controller.account.address, - calls, - } as unknown as ExecuteTxnParams); + const rpc = createMockRpc(); - return { - rpc, - ...controller, - }; - }; + // Setup the rpc, to discover the account and network + await rpc.preExecute({ + chainId: network.chainId, + address: account.address, + calls, + } as unknown as ExecuteTxnParams); - const mockCallToTransactionReqCall = (calls: Call[]) => { - const callToTransactionReqCallSpy = jest.spyOn( - formatUtils, - 'callToTransactionReqCall', - ); - const formattedCalls: FormattedCallData[] = []; - for (const call of calls) { - formattedCalls.push({ - contractAddress: call.contractAddress, - calldata: call.calldata as unknown as string[], - entrypoint: call.entrypoint, - }); - callToTransactionReqCallSpy.mockResolvedValueOnce( - formattedCalls[formattedCalls.length - 1], - ); - } - return { - callToTransactionReqCallSpy, - formattedCalls, - }; + return { + rpc, + account, }; - - const mockGenerateExecuteTxnFlow = () => { - const generateExecuteTxnFlowSpy = jest.spyOn( - uiUtils, - 'generateExecuteTxnFlow', +}; + +const mockCallToTransactionReqCall = (calls: Call[]) => { + const callToTransactionReqCallSpy = jest.spyOn( + formatUtils, + 'callToTransactionReqCall', + ); + const formattedCalls: FormattedCallData[] = []; + for (const call of calls) { + formattedCalls.push({ + contractAddress: call.contractAddress, + calldata: call.calldata as unknown as string[], + entrypoint: call.entrypoint, + }); + callToTransactionReqCallSpy.mockResolvedValueOnce( + formattedCalls[formattedCalls.length - 1], ); - const interfaceId = uuidv4(); - generateExecuteTxnFlowSpy.mockResolvedValue(interfaceId); - return { - interfaceId, - generateExecuteTxnFlowSpy, - }; + } + return { + callToTransactionReqCallSpy, + formattedCalls, }; - - const getTransactionCalls = () => { - const { calls, details, hash } = callsExamples.multipleCalls; - const { formattedCalls } = mockCallToTransactionReqCall(calls); - return { - calls, - details, - hash, - formattedCalls, - }; +}; + +const mockGenerateExecuteTxnFlow = () => { + const generateExecuteTxnFlowSpy = jest.spyOn( + uiUtils, + 'generateExecuteTxnFlow', + ); + const interfaceId = uuidv4(); + generateExecuteTxnFlowSpy.mockResolvedValue(interfaceId); + return { + interfaceId, + generateExecuteTxnFlowSpy, }; +}; +describe('ExecuteTxn', () => { describe('confirmTransaction', () => { const setupConfirmTransactionTest = async (confirm = true) => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + const includeDeploy = true; const txnVersion = constants.TRANSACTION_VERSION.V3; - const { calls, formattedCalls } = getTransactionCalls(); - - const { account, rpc } = await setupRpcTest(calls); - const includeDeploy = !(await account.accountContract.isDeployed()); + const { calls } = callsExamples.multipleCalls; + const { account, rpc } = await setupMockRpc(network, calls); const { getEstimatedFeesResponse: { suggestedMaxFee: maxFee, resourceBounds }, } = mockGetEstimatedFeesResponse({ - includeDeploy, + includeDeploy: false, }); const request = { calls, + address: account.address, maxFee, resourceBounds, txnVersion, includeDeploy, }; + const { formattedCalls } = mockCallToTransactionReqCall(calls); const { interfaceId } = mockGenerateExecuteTxnFlow(); const transactionRequest = { @@ -180,17 +173,13 @@ describe('ExecuteTxn', () => { selectedFeeToken: transactionVersionToFeeToken(txnVersion), includeDeploy, }; + getTransactionRequestSpy.mockResolvedValue(transactionRequest); return { request, rpc, - network, - account, - maxFee, - resourceBounds, - txnVersion, - includeDeploy, transactionRequest, + interfaceId, ...mockConfirmDialogInteractiveUI(confirm), ...mockTransactionRequestStateManager(), }; @@ -200,25 +189,26 @@ describe('ExecuteTxn', () => { const { request, rpc, + interfaceId, + transactionRequest, upsertTransactionRequestSpy, + confirmDialogSpy, getTransactionRequestSpy, removeTransactionRequestSpy, - transactionRequest, } = await setupConfirmTransactionTest(); - getTransactionRequestSpy.mockResolvedValue(transactionRequest); - - const result = await rpc.confirmTransaction(request); - const expectedTransactionRequest = { ...transactionRequest, id: expect.any(String), }; + const result = await rpc.confirmTransaction(request); + expect(result).toStrictEqual(expectedTransactionRequest); expect(upsertTransactionRequestSpy).toHaveBeenCalledWith( expectedTransactionRequest, ); + expect(confirmDialogSpy).toHaveBeenCalledWith(interfaceId); expect(getTransactionRequestSpy).toHaveBeenCalledWith({ requestId: expect.any(String), }); @@ -228,15 +218,8 @@ describe('ExecuteTxn', () => { }); it('does not throw an error if remove request failed', async () => { - const { - request, - rpc, - removeTransactionRequestSpy, - transactionRequest, - getTransactionRequestSpy, - } = await setupConfirmTransactionTest(); - - getTransactionRequestSpy.mockResolvedValue(transactionRequest); + const { request, rpc, removeTransactionRequestSpy } = + await setupConfirmTransactionTest(); removeTransactionRequestSpy.mockRejectedValue( new Error('Failed to remove request'), @@ -278,9 +261,10 @@ describe('ExecuteTxn', () => { describe('deployAccount', () => { const setupDeployAccountTest = async () => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; const txnVersion = constants.TRANSACTION_VERSION.V3; const { calls } = callsExamples.multipleCalls; - const { account, rpc } = await setupRpcTest(calls); + const { account, rpc } = await setupMockRpc(network, calls); const deployAccountSpy = jest.spyOn(starknetUtils, 'deployAccount'); const deployAccountResponse = { @@ -290,6 +274,7 @@ describe('ExecuteTxn', () => { deployAccountSpy.mockResolvedValue(deployAccountResponse); const request = { + address: account.address, txnVersion, }; @@ -302,6 +287,7 @@ describe('ExecuteTxn', () => { accountDeploymentData, request, rpc, + network, account, deployAccountSpy, deployAccountResponse, @@ -312,6 +298,7 @@ describe('ExecuteTxn', () => { const { rpc, request, + network, account: { address, privateKey, publicKey }, deployAccountResponse, deployAccountSpy, @@ -348,10 +335,11 @@ describe('ExecuteTxn', () => { describe('sendTransaction', () => { const setupConfirmTransactionTest = async () => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; const txnVersion = constants.TRANSACTION_VERSION.V3; const { calls } = callsExamples.multipleCalls; - const { account, rpc } = await setupRpcTest(calls); + const { account, rpc } = await setupMockRpc(network, calls); const executeTxnSpy = jest.spyOn(starknetUtils, 'executeTxn'); const executeTxnResponse = { @@ -361,6 +349,7 @@ describe('ExecuteTxn', () => { const request: SendTransactionParams = { calls, + address: account.address, abis: undefined, details: { version: txnVersion, @@ -381,7 +370,8 @@ describe('ExecuteTxn', () => { const { rpc, request, - account: { privateKey, address }, + network, + account: { privateKey }, executeTxnResponse, executeTxnSpy, } = await setupConfirmTransactionTest(); @@ -391,7 +381,7 @@ describe('ExecuteTxn', () => { expect(result).toStrictEqual(executeTxnResponse.transaction_hash); expect(executeTxnSpy).toHaveBeenCalledWith( network, - address, + request.address, privateKey, request.calls, request.abis, @@ -411,71 +401,53 @@ describe('ExecuteTxn', () => { }); describe('execute', () => { - const mockExecute = () => { - const confirmTransactionSpy = jest.spyOn( - MockExecuteTxnRpc.prototype, - 'confirmTransaction', - ); - const sendTransactionSpy = jest.spyOn( - MockExecuteTxnRpc.prototype, - 'sendTransaction', - ); - const saveDataToStateSpy = jest.spyOn( - MockExecuteTxnRpc.prototype, - 'saveDataToState', - ); - const deployAccountSpy = jest.spyOn( - MockExecuteTxnRpc.prototype, - 'deployAccount', - ); - - return { - deployAccountSpy, - confirmTransactionSpy, - sendTransactionSpy, - saveDataToStateSpy, - }; - }; - const setupExecuteTest = async (accountDeployed = true) => { - const { - confirmTransactionSpy, - deployAccountSpy, - sendTransactionSpy, - saveDataToStateSpy, - } = mockExecute(); - const { - calls, - details, - hash: sendTansactionResponse, - } = getTransactionCalls(); - const { account, rpc, isDeploySpy } = await setupRpcTest(calls); + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + const { account } = await setupAccountController({ network }); const { getEstimatedFeesResponse, getEstimatedFeesSpy } = mockGetEstimatedFeesResponse({ includeDeploy: !accountDeployed, }); - const { suggestedMaxFee: maxFee, resourceBounds } = - getEstimatedFeesResponse; + const { suggestedMaxFee, resourceBounds } = getEstimatedFeesResponse; + const confirmTransactionSpy = jest.spyOn( + MockExecuteTxnRpc.prototype, + 'confirmTransaction', + ); const transactionRequest = { selectedFeeToken: FeeToken.STRK, - maxFee, + maxFee: suggestedMaxFee, resourceBounds, } as unknown as TransactionRequest; - const deployAccountResponse = callsExamples.singleCall.hash; - confirmTransactionSpy.mockResolvedValue(transactionRequest); + + const sendTansactionResponse = callsExamples.multipleCalls.hash; + const sendTransactionSpy = jest.spyOn( + MockExecuteTxnRpc.prototype, + 'sendTransaction', + ); sendTransactionSpy.mockResolvedValue(sendTansactionResponse); + + const deployAccountResponse = callsExamples.singleCall.hash; + const deployAccountSpy = jest.spyOn( + MockExecuteTxnRpc.prototype, + 'deployAccount', + ); deployAccountSpy.mockResolvedValue(deployAccountResponse); + + const saveDataToStateSpy = jest.spyOn( + MockExecuteTxnRpc.prototype, + 'saveDataToState', + ); saveDataToStateSpy.mockReturnThis(); - isDeploySpy.mockResolvedValue(accountDeployed); + const rpc = createMockRpc(); const request: ExecuteTxnParams = { chainId: network.chainId, address: account.address, - calls, - details, + calls: callsExamples.multipleCalls.calls, + details: callsExamples.multipleCalls.details, } as unknown as ExecuteTxnParams; return { @@ -501,6 +473,7 @@ describe('ExecuteTxn', () => { request, sendTansactionResponse, sendTransactionSpy, + account: { address }, getEstimatedFeesResponse, confirmTransactionSpy, deployAccountSpy, @@ -512,8 +485,11 @@ describe('ExecuteTxn', () => { ); const { maxFee: updatedMaxFee, resourceBounds: updatedResourceBounds } = transactionRequest; - const { suggestedMaxFee: maxFee, resourceBounds } = - getEstimatedFeesResponse; + const { + suggestedMaxFee: maxFee, + resourceBounds, + includeDeploy, + } = getEstimatedFeesResponse; const { calls, abis, details } = request; const result = await rpc.execute(request); @@ -523,12 +499,15 @@ describe('ExecuteTxn', () => { }); expect(confirmTransactionSpy).toHaveBeenCalledWith({ txnVersion: details?.version, + address, calls, maxFee, resourceBounds, + includeDeploy, }); expect(deployAccountSpy).not.toHaveBeenCalled(); expect(sendTransactionSpy).toHaveBeenCalledWith({ + address, calls, abis, details: { @@ -543,6 +522,7 @@ describe('ExecuteTxn', () => { txnHashForExecute: sendTansactionResponse, txnVersion: updatedTxnVersion, maxFee: updatedMaxFee, + address, calls, }); }); @@ -552,6 +532,7 @@ describe('ExecuteTxn', () => { rpc, request, sendTansactionResponse, + account: { address }, deployAccountResponse, deployAccountSpy, saveDataToStateSpy, @@ -571,9 +552,11 @@ describe('ExecuteTxn', () => { transaction_hash: sendTansactionResponse, }); expect(deployAccountSpy).toHaveBeenCalledWith({ + address, txnVersion: updatedTxnVersion, }); expect(sendTransactionSpy).toHaveBeenCalledWith({ + address, calls, abis, details: { @@ -589,49 +572,44 @@ describe('ExecuteTxn', () => { txnHashForExecute: sendTansactionResponse, txnVersion: updatedTxnVersion, maxFee: updatedMaxFee, + address, calls, }); }); }); describe('saveDataToState', () => { - const mockSaveDataToState = () => { - const addTransactionSpy = jest.spyOn( - TransactionStateManager.prototype, - 'addTransaction', - ); - const updateAccountAsDeploySpy = jest.spyOn( - AccountStateManager.prototype, - 'updateAccountAsDeploy', - ); - addTransactionSpy.mockReturnThis(); - updateAccountAsDeploySpy.mockReturnThis(); - - return { - addTransactionSpy, - updateAccountAsDeploySpy, - }; - }; - const setupSaveDataToStateTest = async () => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; const txnVersion = constants.TRANSACTION_VERSION.V3; - const { updateAccountAsDeploySpy, addTransactionSpy } = - mockSaveDataToState(); - const { hash: txnHashForExecute, calls } = getTransactionCalls(); + const { hash: txnHashForExecute, calls } = callsExamples.multipleCalls; const { hash: txnHashForDeploy } = callsExamples.singleCall; - const { rpc, account } = await setupRpcTest(calls); + const { rpc, account } = await setupMockRpc(network, calls); const { getEstimatedFeesResponse: { suggestedMaxFee: maxFee }, } = mockGetEstimatedFeesResponse({ includeDeploy: false, }); + const addTransactionSpy = jest.spyOn( + TransactionStateManager.prototype, + 'addTransaction', + ); + addTransactionSpy.mockReturnThis(); + + const updateAccountAsDeploySpy = jest.spyOn( + AccountStateManager.prototype, + 'updateAccountAsDeploy', + ); + updateAccountAsDeploySpy.mockReturnThis(); + const request: SaveDataToStateParamas = { txnHashForDeploy, txnHashForExecute, txnVersion, maxFee, + address: account.address, calls, } as unknown as SaveDataToStateParamas; diff --git a/packages/starknet-snap/src/rpcs/execute-txn.ts b/packages/starknet-snap/src/rpcs/execute-txn.ts index 004519fd..db897784 100644 --- a/packages/starknet-snap/src/rpcs/execute-txn.ts +++ b/packages/starknet-snap/src/rpcs/execute-txn.ts @@ -26,6 +26,7 @@ import { UserRejectedOpError } from '../utils/exceptions'; import { deployAccount, executeTxn as executeTxnUtil, + getDeployAccountCallData, getEstimatedFees, } from '../utils/starknetUtils'; import { @@ -59,16 +60,20 @@ export type ExecuteTxnResponse = Infer; export type ConfirmTransactionParams = { calls: Call[]; + address: string; maxFee: string; resourceBounds: ResourceBounds; txnVersion: constants.TRANSACTION_VERSION; + includeDeploy: boolean; }; export type DeployAccountParams = { + address: string; txnVersion: constants.TRANSACTION_VERSION; }; export type SendTransactionParams = { + address: string; calls: Call[]; abis?: any[]; details?: Infer; @@ -79,6 +84,7 @@ export type SaveDataToStateParamas = { txnHashForExecute: string; txnVersion: constants.TRANSACTION_VERSION; maxFee: string; + address: string; calls: Call[]; }; @@ -145,8 +151,11 @@ export class ExecuteTxnRpc extends AccountRpcController< const { privateKey, publicKey } = this.account; const callsArray = Array.isArray(calls) ? calls : [calls]; - // FIXME: getEstimatedFees shpuld be refactored to accept account object - const { suggestedMaxFee: maxFee, resourceBounds } = await getEstimatedFees( + const { + includeDeploy, + suggestedMaxFee: maxFee, + resourceBounds, + } = await getEstimatedFees( this.network, address, privateKey, @@ -160,7 +169,7 @@ export class ExecuteTxnRpc extends AccountRpcController< details, ); - const accountDeployed = await this.account.accountContract.isDeployed(); + const accountDeployed = !includeDeploy; const { selectedFeeToken, @@ -168,9 +177,11 @@ export class ExecuteTxnRpc extends AccountRpcController< resourceBounds: updatedResouceBounds, } = await this.confirmTransaction({ txnVersion: details?.version as unknown as constants.TRANSACTION_VERSION, + address, calls: callsArray, maxFee, resourceBounds, + includeDeploy, }); const updatedTxnVersion = feeTokenToTransactionVersion(selectedFeeToken); @@ -179,11 +190,13 @@ export class ExecuteTxnRpc extends AccountRpcController< if (!accountDeployed) { txnHashForDeploy = await this.deployAccount({ + address, txnVersion: updatedTxnVersion, }); } const txnHashForExecute = await this.sendTransaction({ + address, calls: callsArray, abis, details: { @@ -202,6 +215,7 @@ export class ExecuteTxnRpc extends AccountRpcController< txnHashForExecute, txnVersion: updatedTxnVersion, maxFee: updatedMaxFee, + address, calls: callsArray, }); @@ -213,13 +227,15 @@ export class ExecuteTxnRpc extends AccountRpcController< protected async confirmTransaction({ calls, + address, maxFee, resourceBounds, txnVersion, + includeDeploy, }: ConfirmTransactionParams): Promise { const requestId = uuidv4(); const { chainId, name: networkName } = this.network; - const { hdIndex: addressIndex, address } = this.account; + const { hdIndex: addressIndex } = this.account; const formattedCalls = await Promise.all( calls.map(async (call) => @@ -244,7 +260,7 @@ export class ExecuteTxnRpc extends AccountRpcController< calls: formattedCalls, resourceBounds, selectedFeeToken: transactionVersionToFeeToken(txnVersion), - includeDeploy: !(await this.account.accountContract.isDeployed()), + includeDeploy, }; const interfaceId = await generateExecuteTxnFlow(request); @@ -284,16 +300,13 @@ export class ExecuteTxnRpc extends AccountRpcController< } protected async deployAccount({ + address, txnVersion, }: DeployAccountParams): Promise { - const { - privateKey, - address, - publicKey, - accountContract: { callData }, - } = this.account; + const { privateKey, publicKey } = this.account; + + const callData = getDeployAccountCallData(publicKey, CAIRO_VERSION); - // FIXME: deployAccount shpuld be refactored to accept account object const { contract_address: contractAddress, transaction_hash: transactionHash, @@ -321,11 +334,12 @@ export class ExecuteTxnRpc extends AccountRpcController< } protected async sendTransaction({ + address, calls, abis, details, }: SendTransactionParams): Promise { - const { privateKey, address } = this.account; + const { privateKey } = this.account; const executeTxnResp = await executeTxnUtil( this.network, @@ -348,11 +362,11 @@ export class ExecuteTxnRpc extends AccountRpcController< txnHashForExecute, txnVersion, maxFee, + address, calls, }: SaveDataToStateParamas) { const txnVersionInNumber = transactionVersionToNumber(txnVersion); const { chainId } = this.network; - const { address } = this.account; if (txnHashForDeploy) { await this.txnStateManager.addTransaction( From 4ec0c4f02fe09abaf9a1a34ddfa8cd1595aeddb4 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Thu, 9 Jan 2025 11:01:49 +0800 Subject: [PATCH 11/19] fix: execute test --- packages/starknet-snap/src/rpcs/execute-txn.test.ts | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/packages/starknet-snap/src/rpcs/execute-txn.test.ts b/packages/starknet-snap/src/rpcs/execute-txn.test.ts index 2a9be480..dfe79d4e 100644 --- a/packages/starknet-snap/src/rpcs/execute-txn.test.ts +++ b/packages/starknet-snap/src/rpcs/execute-txn.test.ts @@ -158,6 +158,7 @@ describe('ExecuteTxn', () => { const { formattedCalls } = mockCallToTransactionReqCall(calls); const { interfaceId } = mockGenerateExecuteTxnFlow(); + const txnMgrMocks = mockTransactionRequestStateManager(); const transactionRequest = { chainId: network.chainId, @@ -173,7 +174,9 @@ describe('ExecuteTxn', () => { selectedFeeToken: transactionVersionToFeeToken(txnVersion), includeDeploy, }; - getTransactionRequestSpy.mockResolvedValue(transactionRequest); + txnMgrMocks.getTransactionRequestSpy.mockResolvedValue( + transactionRequest, + ); return { request, @@ -181,7 +184,7 @@ describe('ExecuteTxn', () => { transactionRequest, interfaceId, ...mockConfirmDialogInteractiveUI(confirm), - ...mockTransactionRequestStateManager(), + ...txnMgrMocks, }; }; From 883a3e68caf62382fd327e412e0536437c58cbf0 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Thu, 9 Jan 2025 14:27:42 +0800 Subject: [PATCH 12/19] fix: account discovery bug --- packages/starknet-snap/src/state/account-state-manager.ts | 4 +++- packages/starknet-snap/src/wallet/account/account.ts | 8 ++++++-- packages/starknet-snap/src/wallet/account/service.ts | 2 +- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/packages/starknet-snap/src/state/account-state-manager.ts b/packages/starknet-snap/src/state/account-state-manager.ts index 75741e2b..7b884e02 100644 --- a/packages/starknet-snap/src/state/account-state-manager.ts +++ b/packages/starknet-snap/src/state/account-state-manager.ts @@ -179,7 +179,9 @@ export class AccountStateManager extends StateManager { // Choose the deleted account index over the last index (accContracts length). // If the removedAccounts array is empty, then fallback with the last index. idx = - state.removedAccounts?.[chainId]?.shift() ?? state.accContracts.length; + state.removedAccounts?.[chainId]?.shift() ?? + state.accContracts.filter((account) => account.chainId === chainId) + .length; }); return idx; } diff --git a/packages/starknet-snap/src/wallet/account/account.ts b/packages/starknet-snap/src/wallet/account/account.ts index b15e0fde..d8c2576c 100644 --- a/packages/starknet-snap/src/wallet/account/account.ts +++ b/packages/starknet-snap/src/wallet/account/account.ts @@ -1,3 +1,5 @@ +import type { CairoVersion } from 'starknet'; + import type { AccContract } from '../../types/snapState'; import type { CairoAccountContract } from './contract'; @@ -25,7 +27,7 @@ export class Account { * `1` referred to Cairo 1. * `0` referred to Cairo 0. */ - cairoVersion: string; + cairoVersion: CairoVersion; accountContract: CairoAccountContract; @@ -44,7 +46,9 @@ export class Account { this.addressSalt = props.addressSalt; this.address = props.accountContract.address; - this.cairoVersion = props.accountContract.cairoVerion.toString(10); + this.cairoVersion = props.accountContract.cairoVerion.toString( + 10, + ) as CairoVersion; this.accountContract = props.accountContract; } diff --git a/packages/starknet-snap/src/wallet/account/service.ts b/packages/starknet-snap/src/wallet/account/service.ts index 08bffb6b..23db08d5 100644 --- a/packages/starknet-snap/src/wallet/account/service.ts +++ b/packages/starknet-snap/src/wallet/account/service.ts @@ -49,7 +49,7 @@ export class AccountService { async deriveAccountByIndex(index?: number): Promise { let hdIndex = index; - if (!hdIndex) { + if (hdIndex === undefined) { hdIndex = await this.accountStateMgr.getNextIndex(this.network.chainId); } From c92750acf6ac9a57355c639bd882f2672f3d68e3 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Thu, 9 Jan 2025 15:04:26 +0800 Subject: [PATCH 13/19] fix: discovery logic --- .../src/wallet/account/discovery.test.ts | 31 ++++++++++++----- .../src/wallet/account/discovery.ts | 33 ++++++++----------- 2 files changed, 37 insertions(+), 27 deletions(-) diff --git a/packages/starknet-snap/src/wallet/account/discovery.test.ts b/packages/starknet-snap/src/wallet/account/discovery.test.ts index afcd3cdb..f81137af 100644 --- a/packages/starknet-snap/src/wallet/account/discovery.test.ts +++ b/packages/starknet-snap/src/wallet/account/discovery.test.ts @@ -1,7 +1,6 @@ import { Cairo0Contract, Cairo1Contract } from '.'; import { generateAccounts } from '../../__tests__/helper'; import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../../utils/constants'; -import { AccountDiscoveryError } from '../../utils/exceptions'; import { AccountContractDiscovery } from './discovery'; jest.mock('../../utils/logger'); @@ -145,7 +144,19 @@ describe('AccountContractDiscovery', () => { isUpgraded: false, }, expected: Cairo1Contract, - title: 'Cairo 0 is not deployed and Cairo 1 is not deployed', + title: 'no contract is deployed', + }, + { + cairo0: { + isDeployed: true, + isUpgraded: false, + }, + cairo1: { + isDeployed: true, + isUpgraded: false, + }, + expected: Cairo1Contract, + title: 'all contracts are deployed', }, ])( 'returns a $expected.name if $title', @@ -190,7 +201,7 @@ describe('AccountContractDiscovery', () => { expected: Cairo0Contract, }, ])( - 'returns a $expected.name if no account contract has deployed and the $expected.name has ETH', + 'returns a $expected.name if no account contract is deployed and the $expected.name has ETH', async ({ expected, cairo0HasBalance, cairo1HasBalance }) => { const [account] = await generateAccounts(network.chainId, 1); @@ -217,12 +228,12 @@ describe('AccountContractDiscovery', () => { }, ); - it('throws `AccountDiscoveryError` if more than one contracts deployed', async () => { + it('returns a Cairo1Contract if the Cairo1Contract is deployed and Cairo0Contract has ETH', async () => { const [account] = await generateAccounts(network.chainId, 1); mockContractState({ cairo0: { - isDeployed: true, + isDeployed: false, isUpgraded: false, }, cairo1: { @@ -231,11 +242,15 @@ describe('AccountContractDiscovery', () => { }, }); + mockContractEthBalance({ + cairo0HasBalance: true, + cairo1HasBalance: false, + }); + const service = new AccountContractDiscovery(network); + const contract = await service.getContract(account.publicKey); - await expect(service.getContract(account.publicKey)).rejects.toThrow( - AccountDiscoveryError, - ); + expect(contract).toBeInstanceOf(Cairo1Contract); }); }); }); diff --git a/packages/starknet-snap/src/wallet/account/discovery.ts b/packages/starknet-snap/src/wallet/account/discovery.ts index 490a8614..e71b001b 100644 --- a/packages/starknet-snap/src/wallet/account/discovery.ts +++ b/packages/starknet-snap/src/wallet/account/discovery.ts @@ -1,5 +1,4 @@ import type { Network } from '../../types/snapState'; -import { AccountDiscoveryError } from '../../utils/exceptions'; import { Cairo0Contract } from './cairo0'; import { Cairo1Contract } from './cairo1'; import type { CairoAccountContract } from './contract'; @@ -9,6 +8,7 @@ import type { CairoAccountContractStatic, ICairoAccountContract } from './type'; export class AccountContractDiscovery { protected defaultContractCtor: CairoAccountContractStatic = Cairo1Contract; + // The order of the `contractCtors` array determines the priority of the contract to be selected. protected contractCtors: ICairoAccountContract[] = [ Cairo1Contract, Cairo0Contract, @@ -26,6 +26,7 @@ export class AccountContractDiscovery { * 1. If a contract is deployed, then use the deployed contract. * 2. If no contract is deployed, but has balance, then use the contract with balance. * 3. If neither contract is deployed or has balance, then use the default contract. + * 4. If multiple contracts are deployed, then use the default contract. * * @param publicKey - The public key to get the contract for. * @returns The contract for the given public key. @@ -35,14 +36,9 @@ export class AccountContractDiscovery { const reader = new AccountContractReader(this.network); const DefaultContractCtor = this.defaultContractCtor; - // Use array to store the result to prevent race condition. - const contracts: CairoAccountContract[] = []; - - let cairoContract: CairoAccountContract | undefined; - // Identify where all available contracts have been deployed, upgraded, // and whether they have an ETH balance or not. - await Promise.all( + const contracts = await Promise.all( this.contractCtors.map(async (ContractCtor: ICairoAccountContract) => { const contract = new ContractCtor(publicKey, reader); @@ -50,10 +46,9 @@ export class AccountContractDiscovery { // if contract upgraded, bind the latest contract with current contract interface, // to inherit the address from current contract. if (await contract.isUpgraded()) { - contracts.push(DefaultContractCtor.fromAccountContract(contract)); - } else { - contracts.push(contract); + return DefaultContractCtor.fromAccountContract(contract); } + return contract; } else if ( contract instanceof Cairo0Contract && (await contract.isRequireDeploy()) @@ -62,21 +57,21 @@ export class AccountContractDiscovery { // A Cairo 0 contract can only paying fee with ETH token. // Therefore if the contract is not deployed, and it has ETH token, we should use this contract. // And the UI will force the user to deploy the Cairo 0 contract. - contracts.push(contract); + return contract; } + + return null; }), ); - // In case of multiple contracts are deployed or have balance, - // We will not be able to determine which contract to use. - // Hence, throw an error. - if (contracts.length > 1) { - throw new AccountDiscoveryError(); - } else if (contracts.length === 1) { - cairoContract = contracts[0]; + // If multiple contracts are deployed, the first contract in the `contractCtors` array will be selected. + for (const contract of contracts) { + if (contract !== null) { + return contract; + } } // Fallback with default contract. - return cairoContract ?? new DefaultContractCtor(publicKey, reader); + return new DefaultContractCtor(publicKey, reader); } } From 5bc7498a1222e9c11085a10718af45985b6fe122 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Fri, 10 Jan 2025 14:46:05 +0800 Subject: [PATCH 14/19] feat: add `AddAccount` RPC --- packages/starknet-snap/src/index.tsx | 7 + .../src/rpcs/__tests__/helper.ts | 7 + .../src/rpcs/add-account.test.ts | 48 +++++ .../starknet-snap/src/rpcs/add-account.ts | 46 +++++ packages/starknet-snap/src/rpcs/index.ts | 1 + .../src/state/__tests__/helper.ts | 12 +- .../src/state/account-state-manager.test.ts | 195 +++++++++++------- .../src/state/account-state-manager.ts | 85 +++----- .../src/utils/permission.test.ts | 1 + .../starknet-snap/src/utils/permission.ts | 2 + .../src/utils/superstruct.test.ts | 25 +++ .../starknet-snap/src/utils/superstruct.ts | 11 + .../src/wallet/account/account.ts | 8 + 13 files changed, 313 insertions(+), 135 deletions(-) create mode 100644 packages/starknet-snap/src/rpcs/add-account.test.ts create mode 100644 packages/starknet-snap/src/rpcs/add-account.ts diff --git a/packages/starknet-snap/src/index.tsx b/packages/starknet-snap/src/index.tsx index b799bcb5..3e86bcf7 100644 --- a/packages/starknet-snap/src/index.tsx +++ b/packages/starknet-snap/src/index.tsx @@ -37,6 +37,7 @@ import type { GetAddrFromStarkNameParams, GetTransactionStatusParams, ListTransactionsParams, + AddAccountParams, } from './rpcs'; import { displayPrivateKey, @@ -53,6 +54,7 @@ import { getAddrFromStarkName, getTransactionStatus, listTransactions, + addAccount, } from './rpcs'; import { signDeployAccountTransaction } from './signDeployAccountTransaction'; import type { @@ -279,6 +281,11 @@ export const onRpcRequest: OnRpcRequestHandler = async ({ apiParams.requestParams as unknown as GetAddrFromStarkNameParams, ); + case RpcMethod.AddAccount: + return await addAccount.execute( + requestParams as unknown as AddAccountParams, + ); + default: throw new MethodNotFoundError() as unknown as Error; } diff --git a/packages/starknet-snap/src/rpcs/__tests__/helper.ts b/packages/starknet-snap/src/rpcs/__tests__/helper.ts index 72019f31..b330dcbd 100644 --- a/packages/starknet-snap/src/rpcs/__tests__/helper.ts +++ b/packages/starknet-snap/src/rpcs/__tests__/helper.ts @@ -67,10 +67,17 @@ export async function setupAccountController({ 'deriveAccountByAddress', ); + const deriveAccountByIndexSpy = jest.spyOn( + AccountService.prototype, + 'deriveAccountByIndex', + ); + deriveAccountByAddressSpy.mockResolvedValue(account); + deriveAccountByIndexSpy.mockResolvedValue(account); return { deriveAccountByAddressSpy, + deriveAccountByIndexSpy, isRequireDeploySpy, isRequireUpgradeSpy, isDeploySpy, diff --git a/packages/starknet-snap/src/rpcs/add-account.test.ts b/packages/starknet-snap/src/rpcs/add-account.test.ts new file mode 100644 index 00000000..14962a68 --- /dev/null +++ b/packages/starknet-snap/src/rpcs/add-account.test.ts @@ -0,0 +1,48 @@ +import type { constants } from 'starknet'; + +import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../utils/constants'; +import { InvalidRequestParamsError } from '../utils/exceptions'; +import { setupAccountController } from './__tests__/helper'; +import { addAccount } from './add-account'; +import type { AddAccountParams } from './add-account'; + +jest.mock('../utils/snap'); +jest.mock('../utils/logger'); + +describe('AddAccountRpc', () => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + + const setupAddAccountTest = async () => { + // Although `AddAccountRpc` does not inherit `AccountRpcController`, + // but we can still use `setupAccountController` to mock the `AccountService`. + const { account, deriveAccountByIndexSpy } = await setupAccountController( + {}, + ); + + const request = { + chainId: network.chainId as unknown as constants.StarknetChainId, + }; + + return { + deriveAccountByIndexSpy, + account, + request, + }; + }; + + it('add a `Account`', async () => { + const { account, request, deriveAccountByIndexSpy } = + await setupAddAccountTest(); + + const result = await addAccount.execute(request); + + expect(result).toStrictEqual(await account.serialize()); + expect(deriveAccountByIndexSpy).toHaveBeenCalled(); + }); + + it('throws `InvalidRequestParamsError` when request parameter is not correct', async () => { + await expect( + addAccount.execute({} as unknown as AddAccountParams), + ).rejects.toThrow(InvalidRequestParamsError); + }); +}); diff --git a/packages/starknet-snap/src/rpcs/add-account.ts b/packages/starknet-snap/src/rpcs/add-account.ts new file mode 100644 index 00000000..a5a1ed9e --- /dev/null +++ b/packages/starknet-snap/src/rpcs/add-account.ts @@ -0,0 +1,46 @@ +import { type Infer } from 'superstruct'; + +import { BaseRequestStruct, AccountStruct } from '../utils'; +import { createAccountService } from '../utils/factory'; +import { ChainRpcController } from './abstract/chain-rpc-controller'; + +export const AddAccountRequestStruct = BaseRequestStruct; + +export const AddAccountResponseStruct = AccountStruct; + +export type AddAccountParams = Infer; + +export type AddAccountResponse = Infer; + +/** + * The RPC handler to get a active account by network. + */ +export class AddAccountRpc extends ChainRpcController< + AddAccountParams, + AddAccountResponse +> { + protected requestStruct = AddAccountRequestStruct; + + protected responseStruct = AddAccountResponseStruct; + + /** + * Execute the get active account request handler. + * + * @param params - The parameters of the request. + * @param params.chainId - The chain id of the network. + * @returns A promise that resolves to an active account. + */ + protected async handleRequest( + // eslint-disable-next-line @typescript-eslint/no-unused-vars + params: AddAccountParams, + ): Promise { + const accountService = createAccountService(this.network); + + const account = await accountService.deriveAccountByIndex(); + + // TODO: after derive an account, we should store it as a active account. + return account.serialize() as unknown as AddAccountResponse; + } +} + +export const addAccount = new AddAccountRpc(); diff --git a/packages/starknet-snap/src/rpcs/index.ts b/packages/starknet-snap/src/rpcs/index.ts index 82e2e623..a51d6100 100644 --- a/packages/starknet-snap/src/rpcs/index.ts +++ b/packages/starknet-snap/src/rpcs/index.ts @@ -12,3 +12,4 @@ export * from './watch-asset'; export * from './get-addr-from-starkname'; export * from './get-transaction-status'; export * from './list-transactions'; +export * from './add-account'; diff --git a/packages/starknet-snap/src/state/__tests__/helper.ts b/packages/starknet-snap/src/state/__tests__/helper.ts index 60b5fac4..4edc1fec 100644 --- a/packages/starknet-snap/src/state/__tests__/helper.ts +++ b/packages/starknet-snap/src/state/__tests__/helper.ts @@ -1,4 +1,4 @@ -import type { constants } from 'starknet'; +import { constants } from 'starknet'; import { generateAccounts, type StarknetAccount } from '../../__tests__/helper'; import type { @@ -27,6 +27,14 @@ export const mockAcccounts = async ( return generateAccounts(chainId, cnt); }; +export const generateTestnetAccounts = async (count?: number) => { + return await mockAcccounts(constants.StarknetChainId.SN_SEPOLIA, count); +}; + +export const generateMainnetAccounts = async (count?: number) => { + return await mockAcccounts(constants.StarknetChainId.SN_MAIN, count); +}; + export const mockState = async ({ accounts, tokens, @@ -47,7 +55,7 @@ export const mockState = async ({ const getDataSpy = jest.spyOn(snapHelper, 'getStateData'); const setDataSpy = jest.spyOn(snapHelper, 'setStateData'); const state = { - accContracts: accounts, + accContracts: accounts ?? [], erc20Tokens: tokens ?? [], networks: networks ?? [], transactions: transactions ?? [], diff --git a/packages/starknet-snap/src/state/account-state-manager.test.ts b/packages/starknet-snap/src/state/account-state-manager.test.ts index 5f88cc1e..4454d7c6 100644 --- a/packages/starknet-snap/src/state/account-state-manager.test.ts +++ b/packages/starknet-snap/src/state/account-state-manager.test.ts @@ -1,6 +1,11 @@ import { constants } from 'starknet'; -import { mockAcccounts, mockState } from './__tests__/helper'; +import type { StarknetAccount } from '../__tests__/helper'; +import { + generateMainnetAccounts, + generateTestnetAccounts, + mockState, +} from './__tests__/helper'; import { AddressFilter, ChainIdFilter, @@ -8,13 +13,13 @@ import { } from './account-state-manager'; describe('AccountStateManager', () => { + const testnetChainId = constants.StarknetChainId.SN_SEPOLIA; + const mainnetChainId = constants.StarknetChainId.SN_MAIN; + describe('getAccount', () => { it('returns the account', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const accountsInTestnet = await mockAcccounts(chainId); - const accountsInMainnet = await mockAcccounts( - constants.StarknetChainId.SN_MAIN, - ); + const accountsInTestnet = await generateTestnetAccounts(); + const accountsInMainnet = await generateMainnetAccounts(); await mockState({ accounts: [...accountsInTestnet, ...accountsInMainnet], }); @@ -22,15 +27,14 @@ describe('AccountStateManager', () => { const stateManager = new AccountStateManager(); const result = await stateManager.getAccount({ address: accountsInTestnet[0].address, - chainId, + chainId: testnetChainId, }); expect(result).toStrictEqual(accountsInTestnet[0]); }); it('returns null if the account address can not be found', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const [accountNotExist, ...accounts] = await mockAcccounts(chainId); + const [accountNotExist, ...accounts] = await generateTestnetAccounts(); await mockState({ accounts, }); @@ -38,15 +42,14 @@ describe('AccountStateManager', () => { const stateManager = new AccountStateManager(); const result = await stateManager.getAccount({ address: accountNotExist.address, - chainId, + chainId: testnetChainId, }); expect(result).toBeNull(); }); it('returns null if the account chainId is not match', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const accounts = await mockAcccounts(chainId); + const accounts = await generateTestnetAccounts(); await mockState({ accounts, }); @@ -54,7 +57,7 @@ describe('AccountStateManager', () => { const stateManager = new AccountStateManager(); const result = await stateManager.getAccount({ address: accounts[0].address, - chainId: constants.StarknetChainId.SN_MAIN, + chainId: mainnetChainId, }); expect(result).toBeNull(); @@ -63,11 +66,9 @@ describe('AccountStateManager', () => { describe('list', () => { it('returns the list of account', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const accountsInTestnet = await mockAcccounts(chainId); - const accountsInMainnet = await mockAcccounts( - constants.StarknetChainId.SN_MAIN, - ); + const accountsInTestnet = await generateTestnetAccounts(); + const accountsInMainnet = await generateMainnetAccounts(); + await mockState({ accounts: [...accountsInTestnet, ...accountsInMainnet], }); @@ -86,8 +87,7 @@ describe('AccountStateManager', () => { }); it('returns empty array if the account address can not be found', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const [accountNotExist, ...accounts] = await mockAcccounts(chainId); + const [accountNotExist, ...accounts] = await generateTestnetAccounts(); await mockState({ accounts, }); @@ -95,93 +95,144 @@ describe('AccountStateManager', () => { const stateManager = new AccountStateManager(); const result = await stateManager.list([ new AddressFilter([accountNotExist.address]), - new ChainIdFilter([chainId]), + new ChainIdFilter([testnetChainId]), ]); expect(result).toStrictEqual([]); }); }); - describe('updateAccount', () => { - it('updates the account', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const accounts = await mockAcccounts(chainId); + describe('upsertAccount', () => { + const setupUpserAccountTest = async (accounts: StarknetAccount[] = []) => { + const mainnetAccounts = await generateMainnetAccounts(); + const { state } = await mockState({ - accounts, + accounts: mainnetAccounts.concat(accounts), }); + return state; + }; + + it('adds an account if the account not exist', async () => { + const [account] = await generateTestnetAccounts(1); + const state = await setupUpserAccountTest(); + const originalAccountsFromState = [...state.accContracts]; const stateManager = new AccountStateManager(); - const updatedAccount = { ...accounts[0], deployTxnHash: '0x1234' }; - await stateManager.updateAccount(updatedAccount); + await stateManager.upsertAccount(account); - expect(state.accContracts?.[0]).toStrictEqual(updatedAccount); - expect(state.accContracts?.[0].upgradeRequired).toBeUndefined(); - expect(state.accContracts?.[0].deployRequired).toBeUndefined(); + expect(state.accContracts).toStrictEqual( + originalAccountsFromState.concat([account]), + ); }); - it('updates upgradeRequired and deployRequired of the account', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const accounts = await mockAcccounts(chainId); - const { state } = await mockState({ - accounts, - }); - - const stateManager = new AccountStateManager(); + it('updates the account if the account is found', async () => { + const accounts = await generateTestnetAccounts(); const updatedAccount = { ...accounts[0], upgradeRequired: true, - deployRequired: false, }; - await stateManager.updateAccount(updatedAccount); + const state = await setupUpserAccountTest(accounts); + const originalAccountsLength = state.accContracts.length; + + const stateManager = new AccountStateManager(); + await stateManager.upsertAccount(updatedAccount); - expect(state.accContracts?.[0]).toStrictEqual(updatedAccount); - expect(state.accContracts?.[0].upgradeRequired).toBe(true); - expect(state.accContracts?.[0].deployRequired).toBe(false); + expect(state.accContracts).toHaveLength(originalAccountsLength); + expect( + state.accContracts.find( + (acc) => + acc.address === updatedAccount.address && + acc.chainId === updatedAccount.chainId, + ), + ).toStrictEqual(updatedAccount); }); + }); - it('throws `Account does not exist` error if the update account can not be found', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const [accountNotExist, ...accounts] = await mockAcccounts(chainId); - await mockState({ - accounts, + describe('getNextIndex', () => { + const setupGetNextIndexTest = async () => { + const mainnetAccounts = await generateMainnetAccounts(); + + const { state } = await mockState({ + removedAccounts: { + [mainnetChainId]: [0, 1, 2], + }, + accounts: mainnetAccounts, }); + return state; + }; + + it('returns index 0 if `removedAccounts` and `accContracts` are empty for the given chainId', async () => { + await setupGetNextIndexTest(); const stateManager = new AccountStateManager(); - const account = { ...accountNotExist, deployTxnHash: '0x1234' }; - await expect(stateManager.updateAccount(account)).rejects.toThrow( - 'Account does not exist', - ); + const result = await stateManager.getNextIndex(testnetChainId); + + expect(result).toBe(0); + }); + + it('returns the first index from `removedAccounts` if it is not empty for the given chainId', async () => { + const removedAccounts = [1, 3]; + const state = await setupGetNextIndexTest(); + state.removedAccounts[testnetChainId] = removedAccounts; + + const stateManager = new AccountStateManager(); + const result = await stateManager.getNextIndex(testnetChainId); + + expect(result).toBe(1); + // Ensure that the removed account is removed from the state + expect(state.removedAccounts[testnetChainId]).toStrictEqual([3]); + }); + + it('returns the length of index `accContracts` if `removedAccounts` is empty for the given chainId', async () => { + const accounts = await generateTestnetAccounts(); + const state = await setupGetNextIndexTest(); + state.accContracts = state.accContracts.concat(accounts); + + const stateManager = new AccountStateManager(); + const result = await stateManager.getNextIndex(testnetChainId); + + expect(result).toStrictEqual(accounts.length); }); }); - describe('addAccount', () => { - it('add an account', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const [accountNotExist, ...accounts] = await mockAcccounts(chainId, 5); + describe('removeAccount', () => { + const setupRemoveAccountTest = async (accounts: StarknetAccount[] = []) => { + const mainnetAccounts = await generateMainnetAccounts(); + const { state } = await mockState({ - accounts, + accounts: mainnetAccounts.concat(accounts), }); + return state; + }; + + it('removes an account', async () => { + const accounts = await generateTestnetAccounts(); + const removeAccount = accounts[1]; + const state = await setupRemoveAccountTest(accounts); + const originalAccountsFromState = [...state.accContracts]; + const expectedAccountsAfterRemoved = originalAccountsFromState.filter( + (account) => + account.address !== removeAccount.address && + account.chainId === removeAccount.chainId, + ); const stateManager = new AccountStateManager(); - await stateManager.addAccount(accountNotExist); + await stateManager.removeAccount(removeAccount); - expect(state.accContracts?.length).toBe(5); - expect( - state.accContracts?.[state.accContracts?.length - 1], - ).toStrictEqual(accountNotExist); + expect(state.accContracts).toStrictEqual(expectedAccountsAfterRemoved); + expect(state.removedAccounts).toHaveProperty(testnetChainId); + expect(state.removedAccounts[testnetChainId]).toStrictEqual([ + removeAccount.addressIndex, + ]); }); - it('throws `Account already exist` error if the account is exist', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; - const accounts = await mockAcccounts(chainId); - await mockState({ - accounts, - }); + it('throws an `Account does not exist` error if the removed account is not exist', async () => { + const [removeAccount, ...accounts] = await generateTestnetAccounts(); + await setupRemoveAccountTest(accounts); const stateManager = new AccountStateManager(); - - await expect(stateManager.addAccount(accounts[0])).rejects.toThrow( - 'Account already exist', + await expect(stateManager.removeAccount(removeAccount)).rejects.toThrow( + 'Account does not exist', ); }); }); diff --git a/packages/starknet-snap/src/state/account-state-manager.ts b/packages/starknet-snap/src/state/account-state-manager.ts index 7b884e02..c84ccded 100644 --- a/packages/starknet-snap/src/state/account-state-manager.ts +++ b/packages/starknet-snap/src/state/account-state-manager.ts @@ -61,64 +61,11 @@ export class AccountStateManager extends StateManager { } /** - * Updates an account in the state with the given data. + * Upserts an account in the state. * - * @param data - The AccContract object to update. - * @returns A Promise that resolves when the update is complete. - * @throws {StateManagerError} If there is an error updating the account, such as: - * If the account to be updated does not exist in the state. + * @param data - The AccContract object to upsert. + * @throws {StateManagerError} If an error occurs while updating the state. */ - async updateAccount(data: AccContract): Promise { - try { - await this.update(async (state: SnapState) => { - const accountInState = await this.getAccount( - { - address: data.address, - chainId: data.chainId, - }, - state, - ); - - if (!accountInState) { - throw new StateManagerError(`Account does not exist`); - } - - this.updateEntity(accountInState, data); - }); - } catch (error) { - throw new StateManagerError(error.message); - } - } - - /** - * Adds a new account to the state with the given data. - * - * @param data - The AccContract object to add. - * @returns A Promise that resolves when the add is complete. - * @throws {StateManagerError} If there is an error adding the account, such as: - * If the account to be added already exists in the state. - */ - async addAccount(data: AccContract): Promise { - try { - await this.update(async (state: SnapState) => { - const accountInState = await this.getAccount( - { - address: data.address, - chainId: data.chainId, - }, - state, - ); - - if (accountInState) { - throw new Error(`Account already exist`); - } - state.accContracts.push(data); - }); - } catch (error) { - throw new StateManagerError(error.message); - } - } - async upsertAccount(data: AccContract): Promise { try { await this.update(async (state: SnapState) => { @@ -173,19 +120,34 @@ export class AccountStateManager extends StateManager { } } + /** + * Gets the next index based on the chain ID. + * If `removedAccounts` is not empty for the chain ID, the first index is picked. + * Otherwise, the length of `accContracts` for the chain ID is used. + * + * @param chainId - The chain ID. + * @returns A Promise that resolves to the next index. + */ async getNextIndex(chainId: string): Promise { let idx = 0; await this.update(async (state: SnapState) => { - // Choose the deleted account index over the last index (accContracts length). - // If the removedAccounts array is empty, then fallback with the last index. idx = state.removedAccounts?.[chainId]?.shift() ?? - state.accContracts.filter((account) => account.chainId === chainId) - .length; + state.accContracts.filter((account) => + new ChainIdFilter([chainId]).apply(account), + ).length; }); return idx; } + /** + * Removes account by address and chain ID. + * + * @param params - The parameters for removing the account. + * @param params.address - The address of the account to remove. + * @param params.chainId - The chain ID of the account to remove. + * @throws {StateManagerError} If the account to be removed does not exist. + */ async removeAccount({ address, chainId, @@ -209,7 +171,8 @@ export class AccountStateManager extends StateManager { state.accContracts = state.accContracts.filter( (account) => - account.address !== address && account.chainId === chainId, + new ChainIdFilter([chainId]).apply(account) && + account.address !== address, ); // Safeguard to ensure the removedAccounts object is initialized. diff --git a/packages/starknet-snap/src/utils/permission.test.ts b/packages/starknet-snap/src/utils/permission.test.ts index 69931b28..de3ab716 100644 --- a/packages/starknet-snap/src/utils/permission.test.ts +++ b/packages/starknet-snap/src/utils/permission.test.ts @@ -16,6 +16,7 @@ describe('validateOrigin', () => { RpcMethod.GetAddressByStarkName, RpcMethod.ReadContract, RpcMethod.GetStoredErc20Tokens, + RpcMethod.AddAccount, ]; it.each(walletUIDappPermissions)( diff --git a/packages/starknet-snap/src/utils/permission.ts b/packages/starknet-snap/src/utils/permission.ts index 79277a60..8d6ec6aa 100644 --- a/packages/starknet-snap/src/utils/permission.ts +++ b/packages/starknet-snap/src/utils/permission.ts @@ -15,6 +15,7 @@ export enum RpcMethod { SignDeclareTransaction = 'starkNet_signDeclareTransaction', SignDeployAccountTransaction = 'starkNet_signDeployAccountTransaction', + AddAccount = 'starkNet_addAccount', CreateAccount = 'starkNet_createAccount', DisplayPrivateKey = 'starkNet_displayPrivateKey', GetErc20TokenBalance = 'starkNet_getErc20TokenBalance', @@ -64,6 +65,7 @@ const walletUIDappPermissions = publicPermissions.concat([ RpcMethod.GetAddressByStarkName, RpcMethod.ReadContract, RpcMethod.GetStoredErc20Tokens, + RpcMethod.AddAccount, ]); const publicPermissionsSet = new Set(publicPermissions); diff --git a/packages/starknet-snap/src/utils/superstruct.test.ts b/packages/starknet-snap/src/utils/superstruct.test.ts index 50a4b656..b83f89a7 100644 --- a/packages/starknet-snap/src/utils/superstruct.test.ts +++ b/packages/starknet-snap/src/utils/superstruct.test.ts @@ -6,11 +6,13 @@ import transactionExample from '../__tests__/fixture/transactionExample.json'; import typedDataExample from '../__tests__/fixture/typedDataExample.json'; import { generateTransactions } from '../__tests__/helper'; import { ContractFuncName } from '../types/snapState'; +import { createAccountObject } from '../wallet/account/__test__/helper'; import { ACCOUNT_CLASS_HASH, CAIRO_VERSION, CAIRO_VERSION_LEGACY, ETHER_SEPOLIA_TESTNET, + STARKNET_SEPOLIA_TESTNET_NETWORK, } from './constants'; import { AddressStruct, @@ -30,6 +32,7 @@ import { TokenSymbolStruct, TokenNameStruct, TransactionStruct, + AccountStruct, } from './superstruct'; describe('TokenNameStruct', () => { @@ -562,3 +565,25 @@ describe('TransactionStruct', () => { ).toThrow(StructError); }); }); + +describe('AccountStruct', () => { + it('does not throw error if the account is valid', async () => { + const network = STARKNET_SEPOLIA_TESTNET_NETWORK; + const { accountObj } = await createAccountObject(network); + + jest + .spyOn(accountObj.accountContract, 'isRequireUpgrade') + .mockResolvedValue(false); + jest + .spyOn(accountObj.accountContract, 'isRequireDeploy') + .mockResolvedValue(false); + + const account = await accountObj.serialize(); + + expect(() => assert(account, AccountStruct)).not.toThrow(); + }); + + it('throws error if the account is invalid', () => { + expect(() => assert({}, AccountStruct)).toThrow(StructError); + }); +}); diff --git a/packages/starknet-snap/src/utils/superstruct.ts b/packages/starknet-snap/src/utils/superstruct.ts index 3df68715..773ea90a 100644 --- a/packages/starknet-snap/src/utils/superstruct.ts +++ b/packages/starknet-snap/src/utils/superstruct.ts @@ -426,3 +426,14 @@ export const TransactionStruct = object({ // Snap data Version to support backward compatibility , migration. dataVersion: enums(Object.values(TransactionDataVersion)), }); + +export const AccountStruct = object({ + address: AddressStruct, + chainId: ChainIdStruct, + publicKey: HexStruct, + addressSalt: HexStruct, + addressIndex: number(), + cairoVersion: CairoVersionStruct, + upgradeRequired: boolean(), + deployRequired: boolean(), +}); diff --git a/packages/starknet-snap/src/wallet/account/account.ts b/packages/starknet-snap/src/wallet/account/account.ts index d8c2576c..bfb5197f 100644 --- a/packages/starknet-snap/src/wallet/account/account.ts +++ b/packages/starknet-snap/src/wallet/account/account.ts @@ -58,6 +58,12 @@ export class Account { * @returns A promise that resolves to the serialized `Account` object. */ async serialize(): Promise { + // When a Account object discovery by the account service, + // it should already cached the status of requireDeploy and requireUpgrade. + const [upgradeRequired, deployRequired] = await Promise.all([ + this.accountContract.isRequireDeploy(), + this.accountContract.isRequireUpgrade(), + ]); return { addressSalt: this.publicKey, publicKey: this.publicKey, @@ -65,6 +71,8 @@ export class Account { addressIndex: this.hdIndex, chainId: this.chainId, cairoVersion: this.cairoVersion, + upgradeRequired, + deployRequired, }; } } From 0ebb0fea3dbf2f00355348ac2e463590ae864179 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Fri, 10 Jan 2025 16:35:50 +0800 Subject: [PATCH 15/19] feat: add max account create limit --- packages/starknet-snap/src/config.ts | 7 + .../src/state/__tests__/helper.ts | 24 +++ .../src/state/account-state-manager.ts | 23 +++ .../starknet-snap/src/utils/exceptions.ts | 11 ++ .../src/wallet/account/__test__/helper.ts | 13 +- .../src/wallet/account/service.test.ts | 154 ++++++++++++------ .../src/wallet/account/service.ts | 79 +++++---- 7 files changed, 226 insertions(+), 85 deletions(-) diff --git a/packages/starknet-snap/src/config.ts b/packages/starknet-snap/src/config.ts index fc673a06..ba34fe15 100644 --- a/packages/starknet-snap/src/config.ts +++ b/packages/starknet-snap/src/config.ts @@ -36,6 +36,9 @@ export type SnapConfig = { txnsInLastNumOfDays: number; }; }; + account: { + maxAccountToCreate: number; + }; }; export enum DataClient { @@ -61,6 +64,10 @@ export const Config: SnapConfig = { }, }, + account: { + maxAccountToCreate: 2, + }, + // eslint-disable-next-line no-restricted-globals rpcApiKey: process.env.DIN_API_KEY ?? '', diff --git a/packages/starknet-snap/src/state/__tests__/helper.ts b/packages/starknet-snap/src/state/__tests__/helper.ts index 4edc1fec..76605ca7 100644 --- a/packages/starknet-snap/src/state/__tests__/helper.ts +++ b/packages/starknet-snap/src/state/__tests__/helper.ts @@ -12,6 +12,7 @@ import { STRK_SEPOLIA_TESTNET, } from '../../utils/constants'; import * as snapHelper from '../../utils/snap'; +import { AccountStateManager } from '../account-state-manager'; import { NetworkStateManager } from '../network-state-manager'; import { TransactionRequestStateManager } from '../request-state-manager'; import { TokenStateManager } from '../token-state-manager'; @@ -86,6 +87,29 @@ export const mockTokenStateManager = () => { }; }; +export const mockAccountStateManager = () => { + const getAccountSpy = jest.spyOn(AccountStateManager.prototype, 'getAccount'); + const getNextIndexSpy = jest.spyOn( + AccountStateManager.prototype, + 'getNextIndex', + ); + const upsertAccountSpy = jest.spyOn( + AccountStateManager.prototype, + 'upsertAccount', + ); + const isMaxAccountLimitExceededSpy = jest.spyOn( + AccountStateManager.prototype, + 'isMaxAccountLimitExceeded', + ); + + return { + getAccountSpy, + getNextIndexSpy, + upsertAccountSpy, + isMaxAccountLimitExceededSpy, + }; +}; + export const mockTransactionStateManager = () => { const removeTransactionsSpy = jest.spyOn( TransactionStateManager.prototype, diff --git a/packages/starknet-snap/src/state/account-state-manager.ts b/packages/starknet-snap/src/state/account-state-manager.ts index c84ccded..2c5ab0cb 100644 --- a/packages/starknet-snap/src/state/account-state-manager.ts +++ b/packages/starknet-snap/src/state/account-state-manager.ts @@ -1,3 +1,4 @@ +import { Config } from '../config'; import type { AccContract, SnapState } from '../types/snapState'; import type { IFilter } from './filter'; import { @@ -190,4 +191,26 @@ export class AccountStateManager extends StateManager { throw new StateManagerError(error.message); } } + + /** + * Determines whether max account limit exceeded. + * + * @param params - The parameters for checking the max account limit. + * @param params.chainId - The chain ID. + * @param [state] - The optional SnapState object. + * @returns A Promise that resolves to a boolean indicating whether the max account limit is exceeded. + */ + async isMaxAccountLimitExceeded( + { + chainId, + }: { + chainId: string; + }, + state?: SnapState, + ): Promise { + return ( + (await this.list([new ChainIdFilter([chainId])], undefined, state)) + .length >= Config.account.maxAccountToCreate + ); + } } diff --git a/packages/starknet-snap/src/utils/exceptions.ts b/packages/starknet-snap/src/utils/exceptions.ts index a3b70421..8a6c9ada 100644 --- a/packages/starknet-snap/src/utils/exceptions.ts +++ b/packages/starknet-snap/src/utils/exceptions.ts @@ -4,6 +4,7 @@ import { UserRejectedRequestError, } from '@metamask/snaps-sdk'; +import { Config } from '../config'; import { createWalletRpcErrorWrapper, WalletRpcErrorCode } from './error'; // Extend SnapError to allow error message visible to client @@ -49,6 +50,16 @@ export class AccountDiscoveryError extends SnapError { } } +export class MaxAccountLimitExceededError extends SnapError { + constructor(message?: string) { + super( + message ?? + `Maximum number of accounts reached: ${Config.account.maxAccountToCreate}`, + createWalletRpcErrorWrapper(WalletRpcErrorCode.Unknown), + ); + } +} + export class ContractReadError extends SnapError { constructor(message: string) { super(message, createWalletRpcErrorWrapper(WalletRpcErrorCode.Unknown)); diff --git a/packages/starknet-snap/src/wallet/account/__test__/helper.ts b/packages/starknet-snap/src/wallet/account/__test__/helper.ts index 0d566a64..76857a96 100644 --- a/packages/starknet-snap/src/wallet/account/__test__/helper.ts +++ b/packages/starknet-snap/src/wallet/account/__test__/helper.ts @@ -54,9 +54,18 @@ export const createAccountContract = async ( }; }; -export const createAccountObject = async (network, hdIndex = 0) => { +export const createAccountObject = async ( + network, + hdIndex = 0, + mnemonicString?: string, +) => { const { account, accountContractReader, contract } = - await createAccountContract(network, hdIndex); + await createAccountContract( + network, + hdIndex, + Cairo1Contract, + mnemonicString, + ); const { privateKey, publicKey, chainId, addressIndex } = account; diff --git a/packages/starknet-snap/src/wallet/account/service.test.ts b/packages/starknet-snap/src/wallet/account/service.test.ts index d56b9f97..6e3a8b6e 100644 --- a/packages/starknet-snap/src/wallet/account/service.test.ts +++ b/packages/starknet-snap/src/wallet/account/service.test.ts @@ -1,10 +1,17 @@ import { generateMnemonic } from 'bip39'; -import { AccountContractReader, AccountService, Cairo1Contract } from '.'; -import { generateAccounts, generateKeyDeriver } from '../../__tests__/helper'; +import { AccountService } from '.'; +import { generateKeyDeriver } from '../../__tests__/helper'; +import { + mockAccountStateManager, + mockState, +} from '../../state/__tests__/helper'; import { AccountStateManager } from '../../state/account-state-manager'; import { STARKNET_SEPOLIA_TESTNET_NETWORK } from '../../utils/constants'; -import { AccountNotFoundError } from '../../utils/exceptions'; +import { + AccountNotFoundError, + MaxAccountLimitExceededError, +} from '../../utils/exceptions'; import { createAccountService } from '../../utils/factory'; import * as snapUtils from '../../utils/snap'; import { @@ -15,53 +22,64 @@ import { Account } from './account'; import { AccountContractDiscovery } from './discovery'; jest.mock('../../utils/logger'); +jest.mock('../../utils/snap'); describe('AccountService', () => { const network = STARKNET_SEPOLIA_TESTNET_NETWORK; describe('deriveAccountByIndex', () => { - const setupDeriveAccountByIndexTest = async (hdIndex) => { - const mnemonicString = generateMnemonic(); - - const [account] = await generateAccounts( - network.chainId, - 1, - '1', + const mockDeriveAccount = async ( + hdIndex, + mnemonicString = generateMnemonic(), + ) => { + const { accountObj } = await createAccountObject( + network, hdIndex, mnemonicString, ); - const deriver = await generateKeyDeriver(mnemonicString); - - const getNextIndexSpy = jest.spyOn( - AccountStateManager.prototype, - 'getNextIndex', - ); - const upsertAccountSpy = jest.spyOn( - AccountStateManager.prototype, - 'upsertAccount', - ); const getCairoContractSpy = jest.spyOn( AccountContractDiscovery.prototype, 'getContract', ); + getCairoContractSpy.mockResolvedValue(accountObj.accountContract); + + return { + accountObj, + getCairoContractSpy, + }; + }; + + const mockSnapDeriver = async (mnemonicString) => { + const deriver = await generateKeyDeriver(mnemonicString); jest.spyOn(snapUtils, 'getBip44Deriver').mockResolvedValue(deriver); + }; - mockAccountContractReader({}); + const setupDeriveAccountByIndexTest = async (hdIndex) => { + const mnemonicString = generateMnemonic(); - const cairo1Contract = new Cairo1Contract( - account.publicKey, - new AccountContractReader(network), + const { accountObj, getCairoContractSpy } = await mockDeriveAccount( + hdIndex, + mnemonicString, ); + await mockSnapDeriver(mnemonicString); + + const { + getNextIndexSpy, + isMaxAccountLimitExceededSpy, + upsertAccountSpy, + } = mockAccountStateManager(); + + mockAccountContractReader({}); - getCairoContractSpy.mockResolvedValue(cairo1Contract); getNextIndexSpy.mockResolvedValue(hdIndex); + isMaxAccountLimitExceededSpy.mockResolvedValue(false); return { upsertAccountSpy, getNextIndexSpy, getCairoContractSpy, - cairo1Contract, - account, + account: accountObj, + isMaxAccountLimitExceededSpy, }; }; @@ -71,26 +89,23 @@ describe('AccountService', () => { getNextIndexSpy, getCairoContractSpy, upsertAccountSpy, - cairo1Contract, account, } = await setupDeriveAccountByIndexTest(hdIndex); const service = createAccountService(network); - const accountObject = await service.deriveAccountByIndex(); + const result = await service.deriveAccountByIndex(); expect(getNextIndexSpy).toHaveBeenCalled(); - expect(upsertAccountSpy).toHaveBeenCalledWith( - await accountObject.serialize(), - ); + expect(upsertAccountSpy).toHaveBeenCalledWith(await result.serialize()); expect(getCairoContractSpy).toHaveBeenCalledWith(account.publicKey); - expect(accountObject).toBeInstanceOf(Account); - expect(accountObject).toHaveProperty('accountContract', cairo1Contract); - expect(accountObject).toHaveProperty('address', account.address); - expect(accountObject).toHaveProperty('chainId', account.chainId); - expect(accountObject).toHaveProperty('privateKey', account.privateKey); - expect(accountObject).toHaveProperty('publicKey', account.publicKey); - expect(accountObject).toHaveProperty('hdIndex', hdIndex); - expect(accountObject).toHaveProperty('addressSalt', account.publicKey); + expect(result).toBeInstanceOf(Account); + expect(result).toHaveProperty('accountContract', account.accountContract); + expect(result).toHaveProperty('address', account.address); + expect(result).toHaveProperty('chainId', account.chainId); + expect(result).toHaveProperty('privateKey', account.privateKey); + expect(result).toHaveProperty('publicKey', account.publicKey); + expect(result).toHaveProperty('hdIndex', hdIndex); + expect(result).toHaveProperty('addressSalt', account.publicKey); }); it('derive an account with the given index', async () => { @@ -98,27 +113,60 @@ describe('AccountService', () => { const { getNextIndexSpy, getCairoContractSpy, - cairo1Contract, account, upsertAccountSpy, } = await setupDeriveAccountByIndexTest(hdIndex); const service = createAccountService(network); - const accountObject = await service.deriveAccountByIndex(hdIndex); + const result = await service.deriveAccountByIndex(hdIndex); expect(getNextIndexSpy).not.toHaveBeenCalled(); - expect(upsertAccountSpy).toHaveBeenCalledWith( - await accountObject.serialize(), - ); + expect(upsertAccountSpy).toHaveBeenCalledWith(await result.serialize()); expect(getCairoContractSpy).toHaveBeenCalledWith(account.publicKey); - expect(accountObject).toBeInstanceOf(Account); - expect(accountObject).toHaveProperty('accountContract', cairo1Contract); - expect(accountObject).toHaveProperty('address', account.address); - expect(accountObject).toHaveProperty('chainId', account.chainId); - expect(accountObject).toHaveProperty('privateKey', account.privateKey); - expect(accountObject).toHaveProperty('publicKey', account.publicKey); - expect(accountObject).toHaveProperty('hdIndex', hdIndex); - expect(accountObject).toHaveProperty('addressSalt', account.publicKey); + expect(result).toBeInstanceOf(Account); + expect(result).toHaveProperty('accountContract', account.accountContract); + expect(result).toHaveProperty('address', account.address); + expect(result).toHaveProperty('chainId', account.chainId); + expect(result).toHaveProperty('privateKey', account.privateKey); + expect(result).toHaveProperty('publicKey', account.publicKey); + expect(result).toHaveProperty('hdIndex', hdIndex); + expect(result).toHaveProperty('addressSalt', account.publicKey); + }); + + it('throws `MaxAccountLimitExceededError` error if the account to derive reach the maximum', async () => { + const { isMaxAccountLimitExceededSpy } = + await setupDeriveAccountByIndexTest(0); + isMaxAccountLimitExceededSpy.mockResolvedValue(true); + + const service = createAccountService(network); + + await expect(service.deriveAccountByIndex()).rejects.toThrow( + MaxAccountLimitExceededError, + ); + }); + + it('does not modify the state if an error has thrown', async () => { + const { setDataSpy } = await mockState({}); + // mockAccountStateManager is only returning the spies, + // it will not mock the function to return a value. + const { isMaxAccountLimitExceededSpy } = mockAccountStateManager(); + + const mnemonicString = generateMnemonic(); + await mockDeriveAccount(0, mnemonicString); + await mockSnapDeriver(mnemonicString); + mockAccountContractReader({}); + + // A `MaxAccountLimitExceededError` will be thrown when `isMaxAccountLimitExceeded` is true. + // Since this checking is placed at the end of the function, + // it is the best way to test if the state is not modified if an error occurs. + isMaxAccountLimitExceededSpy.mockResolvedValue(true); + + const service = createAccountService(network); + + await expect(service.deriveAccountByIndex()).rejects.toThrow( + MaxAccountLimitExceededError, + ); + expect(setDataSpy).not.toHaveBeenCalled(); }); }); diff --git a/packages/starknet-snap/src/wallet/account/service.ts b/packages/starknet-snap/src/wallet/account/service.ts index 23db08d5..d19a4173 100644 --- a/packages/starknet-snap/src/wallet/account/service.ts +++ b/packages/starknet-snap/src/wallet/account/service.ts @@ -1,7 +1,10 @@ import { AccountStateManager } from '../../state/account-state-manager'; import type { Network } from '../../types/snapState'; import { getBip44Deriver } from '../../utils'; -import { AccountNotFoundError } from '../../utils/exceptions'; +import { + AccountNotFoundError, + MaxAccountLimitExceededError, +} from '../../utils/exceptions'; import { Account } from './account'; import { AccountContractDiscovery } from './discovery'; import { AccountKeyPair } from './keypair'; @@ -47,36 +50,52 @@ export class AccountService { * @returns A promise that resolves to the newly created `Account` object. */ async deriveAccountByIndex(index?: number): Promise { - let hdIndex = index; - - if (hdIndex === undefined) { - hdIndex = await this.accountStateMgr.getNextIndex(this.network.chainId); - } - - // Derive a BIP44 node from an index. e.g m/44'/60'/0'/0/{hdIndex} - const deriver = await getBip44Deriver(); - const node = await deriver(hdIndex); - - // Grind a new private key and public key from the derived node. - // Private key and public key are independent from the account contract. - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const { privateKey, publicKey } = new AccountKeyPair(node.privateKey!); - - const accountContract = - await this.accountContractDiscoveryService.getContract(publicKey); - - const account = new Account({ - privateKey, - publicKey, - chainId: this.network.chainId, - hdIndex, - addressSalt: publicKey, - accountContract, + const { chainId } = this.network; + + // use `withTransaction` to ensure that the state is not modified if an error occurs. + return this.accountStateMgr.withTransaction(async (state) => { + let hdIndex = index; + if (hdIndex === undefined) { + hdIndex = await this.accountStateMgr.getNextIndex(chainId); + } + + // Derive a BIP44 node from an index. e.g m/44'/60'/0'/0/{hdIndex} + const deriver = await getBip44Deriver(); + const node = await deriver(hdIndex); + + // Grind a new private key and public key from the derived node. + // Private key and public key are independent from the account contract. + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + const { privateKey, publicKey } = new AccountKeyPair(node.privateKey!); + + const accountContract = + await this.accountContractDiscoveryService.getContract(publicKey); + + const account = new Account({ + privateKey, + publicKey, + chainId: this.network.chainId, + hdIndex, + addressSalt: publicKey, + accountContract, + }); + + await this.accountStateMgr.upsertAccount(await account.serialize()); + + // FIXME: this is a convenience way to check if the account limit has been exceeded at the last line of the code. However, it is possible to improve if we can check it before the account is derived. + if ( + await this.accountStateMgr.isMaxAccountLimitExceeded( + { + chainId, + }, + state, + ) + ) { + throw new MaxAccountLimitExceededError(); + } + + return account; }); - - await this.accountStateMgr.upsertAccount(await account.serialize()); - - return account; } /** From d2274afc49a1740329e519e9a92e67565fe8cd08 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Fri, 10 Jan 2025 16:46:28 +0800 Subject: [PATCH 16/19] fix: add `isMaxAccountLimitExceeded` unit test --- .../src/state/account-state-manager.test.ts | 93 +++++++++++-------- .../src/state/account-state-manager.ts | 2 +- 2 files changed, 54 insertions(+), 41 deletions(-) diff --git a/packages/starknet-snap/src/state/account-state-manager.test.ts b/packages/starknet-snap/src/state/account-state-manager.test.ts index 4454d7c6..186e2059 100644 --- a/packages/starknet-snap/src/state/account-state-manager.test.ts +++ b/packages/starknet-snap/src/state/account-state-manager.test.ts @@ -1,6 +1,7 @@ import { constants } from 'starknet'; import type { StarknetAccount } from '../__tests__/helper'; +import { Config } from '../config'; import { generateMainnetAccounts, generateTestnetAccounts, @@ -16,13 +17,21 @@ describe('AccountStateManager', () => { const testnetChainId = constants.StarknetChainId.SN_SEPOLIA; const mainnetChainId = constants.StarknetChainId.SN_MAIN; + const mockStateWithMainnetAccounts = async ( + accounts: StarknetAccount[] = [], + ) => { + const mainnetAccounts = await generateMainnetAccounts(); + + const { state } = await mockState({ + accounts: mainnetAccounts.concat(accounts), + }); + return state; + }; + describe('getAccount', () => { it('returns the account', async () => { const accountsInTestnet = await generateTestnetAccounts(); - const accountsInMainnet = await generateMainnetAccounts(); - await mockState({ - accounts: [...accountsInTestnet, ...accountsInMainnet], - }); + await mockStateWithMainnetAccounts(accountsInTestnet); const stateManager = new AccountStateManager(); const result = await stateManager.getAccount({ @@ -35,9 +44,7 @@ describe('AccountStateManager', () => { it('returns null if the account address can not be found', async () => { const [accountNotExist, ...accounts] = await generateTestnetAccounts(); - await mockState({ - accounts, - }); + await mockStateWithMainnetAccounts(accounts); const stateManager = new AccountStateManager(); const result = await stateManager.getAccount({ @@ -50,9 +57,7 @@ describe('AccountStateManager', () => { it('returns null if the account chainId is not match', async () => { const accounts = await generateTestnetAccounts(); - await mockState({ - accounts, - }); + await mockStateWithMainnetAccounts(accounts); const stateManager = new AccountStateManager(); const result = await stateManager.getAccount({ @@ -103,18 +108,9 @@ describe('AccountStateManager', () => { }); describe('upsertAccount', () => { - const setupUpserAccountTest = async (accounts: StarknetAccount[] = []) => { - const mainnetAccounts = await generateMainnetAccounts(); - - const { state } = await mockState({ - accounts: mainnetAccounts.concat(accounts), - }); - return state; - }; - it('adds an account if the account not exist', async () => { const [account] = await generateTestnetAccounts(1); - const state = await setupUpserAccountTest(); + const state = await mockStateWithMainnetAccounts(); const originalAccountsFromState = [...state.accContracts]; const stateManager = new AccountStateManager(); @@ -131,7 +127,7 @@ describe('AccountStateManager', () => { ...accounts[0], upgradeRequired: true, }; - const state = await setupUpserAccountTest(accounts); + const state = await mockStateWithMainnetAccounts(accounts); const originalAccountsLength = state.accContracts.length; const stateManager = new AccountStateManager(); @@ -150,14 +146,10 @@ describe('AccountStateManager', () => { describe('getNextIndex', () => { const setupGetNextIndexTest = async () => { - const mainnetAccounts = await generateMainnetAccounts(); - - const { state } = await mockState({ - removedAccounts: { - [mainnetChainId]: [0, 1, 2], - }, - accounts: mainnetAccounts, - }); + const state = await mockStateWithMainnetAccounts(); + state.removedAccounts = { + [mainnetChainId]: [0, 1, 2], + }; return state; }; @@ -196,19 +188,10 @@ describe('AccountStateManager', () => { }); describe('removeAccount', () => { - const setupRemoveAccountTest = async (accounts: StarknetAccount[] = []) => { - const mainnetAccounts = await generateMainnetAccounts(); - - const { state } = await mockState({ - accounts: mainnetAccounts.concat(accounts), - }); - return state; - }; - it('removes an account', async () => { const accounts = await generateTestnetAccounts(); const removeAccount = accounts[1]; - const state = await setupRemoveAccountTest(accounts); + const state = await mockStateWithMainnetAccounts(accounts); const originalAccountsFromState = [...state.accContracts]; const expectedAccountsAfterRemoved = originalAccountsFromState.filter( (account) => @@ -228,7 +211,7 @@ describe('AccountStateManager', () => { it('throws an `Account does not exist` error if the removed account is not exist', async () => { const [removeAccount, ...accounts] = await generateTestnetAccounts(); - await setupRemoveAccountTest(accounts); + await mockStateWithMainnetAccounts(accounts); const stateManager = new AccountStateManager(); await expect(stateManager.removeAccount(removeAccount)).rejects.toThrow( @@ -236,4 +219,34 @@ describe('AccountStateManager', () => { ); }); }); + + describe('isMaxAccountLimitExceeded', () => { + it('returns true if the account limit is reached', async () => { + const accounts = await generateTestnetAccounts( + Config.account.maxAccountToCreate + 1, + ); + await mockStateWithMainnetAccounts(accounts); + + const stateManager = new AccountStateManager(); + const result = await stateManager.isMaxAccountLimitExceeded({ + chainId: testnetChainId, + }); + + expect(result).toBe(true); + }); + + it('returns false if the account limit is not reached', async () => { + const accounts = await generateTestnetAccounts( + Config.account.maxAccountToCreate, + ); + await mockStateWithMainnetAccounts(accounts); + + const stateManager = new AccountStateManager(); + const result = await stateManager.isMaxAccountLimitExceeded({ + chainId: testnetChainId, + }); + + expect(result).toBe(false); + }); + }); }); diff --git a/packages/starknet-snap/src/state/account-state-manager.ts b/packages/starknet-snap/src/state/account-state-manager.ts index 2c5ab0cb..8658729a 100644 --- a/packages/starknet-snap/src/state/account-state-manager.ts +++ b/packages/starknet-snap/src/state/account-state-manager.ts @@ -210,7 +210,7 @@ export class AccountStateManager extends StateManager { ): Promise { return ( (await this.list([new ChainIdFilter([chainId])], undefined, state)) - .length >= Config.account.maxAccountToCreate + .length > Config.account.maxAccountToCreate ); } } From 2cb063311abe9a03cd3830ed5fd14564fba93712 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Mon, 13 Jan 2025 11:43:10 +0800 Subject: [PATCH 17/19] fix: account deploy require result --- .../src/wallet/account/contract.test.ts | 78 ++++++++++++------- .../src/wallet/account/contract.ts | 5 +- .../src/wallet/account/discovery.ts | 5 +- 3 files changed, 53 insertions(+), 35 deletions(-) diff --git a/packages/starknet-snap/src/wallet/account/contract.test.ts b/packages/starknet-snap/src/wallet/account/contract.test.ts index aec4f51b..8c9c4347 100644 --- a/packages/starknet-snap/src/wallet/account/contract.test.ts +++ b/packages/starknet-snap/src/wallet/account/contract.test.ts @@ -242,35 +242,53 @@ describe('CairoAccountContract', () => { }); describe('isRequireDeploy', () => { - it('returns true if the contract requires deploy', async () => { - const { getVersionSpy } = mockAccountContractReader({}); - getVersionSpy.mockRejectedValue(new ContractNotDeployedError()); - - const { contract } = await createAccountContract( - network, - 0, - Cairo0Contract, - ); - - const result = await contract.isRequireDeploy(); - - expect(result).toBe(true); - }); - - it('returns false if the contract is not deployed and does not has ETH', async () => { - const { getVersionSpy } = mockAccountContractReader({ - balance: BigInt(0), - }); - getVersionSpy.mockRejectedValue(new ContractNotDeployedError()); - const { contract } = await createAccountContract( - network, - 0, - Cairo1Contract, - ); - - const result = await contract.isRequireUpgrade(); - - expect(result).toBe(false); - }); + it.each([ + { + expected: true, + balance: 1000000, + }, + { + expected: false, + balance: 0, + }, + ])( + 'returns $expected if a Cairo 0 contract is not deployed - Balance: $balance', + async ({ balance, expected }: { balance: number; expected: boolean }) => { + const { getVersionSpy } = mockAccountContractReader({ + balance: BigInt(balance), + }); + getVersionSpy.mockRejectedValue(new ContractNotDeployedError()); + + const { contract } = await createAccountContract( + network, + 0, + Cairo0Contract, + ); + + const result = await contract.isRequireDeploy(); + + expect(result).toBe(expected); + }, + ); + + it.each([0, 1000000])( + 'returns false if the Cairo 1 contract is not deployed regardless if it has ETH or not - Balance: %s', + async (balance: number) => { + const { getVersionSpy } = mockAccountContractReader({ + balance: BigInt(balance), + }); + getVersionSpy.mockRejectedValue(new ContractNotDeployedError()); + + const { contract } = await createAccountContract( + network, + 0, + Cairo1Contract, + ); + + const result = await contract.isRequireDeploy(); + + expect(result).toBe(false); + }, + ); }); }); diff --git a/packages/starknet-snap/src/wallet/account/contract.ts b/packages/starknet-snap/src/wallet/account/contract.ts index 3f4704af..4a2c1a33 100644 --- a/packages/starknet-snap/src/wallet/account/contract.ts +++ b/packages/starknet-snap/src/wallet/account/contract.ts @@ -1,6 +1,7 @@ import type { Calldata } from 'starknet'; import { hash, addAddressPadding } from 'starknet'; +import { CAIRO_VERSION_LEGACY } from '../../utils/constants'; import { ContractNotDeployedError } from '../../utils/exceptions'; import { isGTEMinVersion } from '../../utils/starknetUtils'; import type { AccountContractReader } from './reader'; @@ -175,7 +176,9 @@ export abstract class CairoAccountContract { */ async isRequireDeploy(): Promise { return ( - !(await this.isDeployed()) && (await this.getEthBalance()) > BigInt(0) + this.cairoVerion.toString() === CAIRO_VERSION_LEGACY && + !(await this.isDeployed()) && + (await this.getEthBalance()) > BigInt(0) ); } diff --git a/packages/starknet-snap/src/wallet/account/discovery.ts b/packages/starknet-snap/src/wallet/account/discovery.ts index e71b001b..ee61b9e1 100644 --- a/packages/starknet-snap/src/wallet/account/discovery.ts +++ b/packages/starknet-snap/src/wallet/account/discovery.ts @@ -49,10 +49,7 @@ export class AccountContractDiscovery { return DefaultContractCtor.fromAccountContract(contract); } return contract; - } else if ( - contract instanceof Cairo0Contract && - (await contract.isRequireDeploy()) - ) { + } else if (await contract.isRequireDeploy()) { // It should only valid for Cairo 0 contract. // A Cairo 0 contract can only paying fee with ETH token. // Therefore if the contract is not deployed, and it has ETH token, we should use this contract. From a6e56b27e487a348d1f3bc6ae0c7813a20459f0d Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Tue, 14 Jan 2025 13:05:05 +0800 Subject: [PATCH 18/19] fix: add some detail comment on contract discovery --- .../starknet-snap/src/wallet/account/discovery.ts | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/packages/starknet-snap/src/wallet/account/discovery.ts b/packages/starknet-snap/src/wallet/account/discovery.ts index ee61b9e1..c5da0725 100644 --- a/packages/starknet-snap/src/wallet/account/discovery.ts +++ b/packages/starknet-snap/src/wallet/account/discovery.ts @@ -23,10 +23,14 @@ export class AccountContractDiscovery { /** * Get the contract for the given public key. * The contract is determined based on the following rules: - * 1. If a contract is deployed, then use the deployed contract. - * 2. If no contract is deployed, but has balance, then use the contract with balance. - * 3. If neither contract is deployed or has balance, then use the default contract. - * 4. If multiple contracts are deployed, then use the default contract. + * + * 1. If a Cairo 1 contract has been deployed, it will always be used regardless of whether the other contract has a balance in ETH or has been deployed. + * 2. If a Cairo 0 contract has been deployed and the other contract has not, the Cairo 0 contract will always be used regardless of whether the other contract has a balance or not, and the contract will be forced to upgrade. + * 3. If neither contract has been deployed, but a Cairo 0 contract has a balance in ETH, it will always be used regardless of whether the other contract has a balance or not, and the contract will be forced to deploy. + * 3. If neither contract has been deployed and neither has a balance in ETH, the default contract (Cairo 1) will be used." + * + * Note: The rules accommodate for most use cases, except 1 edge case: + * - Due to rule #1, if a user wont able to operated a Cairo 0 contract if a Cairo 1 contract has been deployed. * * @param publicKey - The public key to get the contract for. * @returns The contract for the given public key. From cf01c6ce601c111e1c757e83522c959d885e83a6 Mon Sep 17 00:00:00 2001 From: stanleyyuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Tue, 14 Jan 2025 13:08:00 +0800 Subject: [PATCH 19/19] fix: comments on test title --- packages/starknet-snap/src/rpcs/add-account.test.ts | 2 +- .../starknet-snap/src/state/account-state-manager.test.ts | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/packages/starknet-snap/src/rpcs/add-account.test.ts b/packages/starknet-snap/src/rpcs/add-account.test.ts index 14962a68..855dbc6b 100644 --- a/packages/starknet-snap/src/rpcs/add-account.test.ts +++ b/packages/starknet-snap/src/rpcs/add-account.test.ts @@ -30,7 +30,7 @@ describe('AddAccountRpc', () => { }; }; - it('add a `Account`', async () => { + it('add an `Account`', async () => { const { account, request, deriveAccountByIndexSpy } = await setupAddAccountTest(); diff --git a/packages/starknet-snap/src/state/account-state-manager.test.ts b/packages/starknet-snap/src/state/account-state-manager.test.ts index 186e2059..1085443d 100644 --- a/packages/starknet-snap/src/state/account-state-manager.test.ts +++ b/packages/starknet-snap/src/state/account-state-manager.test.ts @@ -55,7 +55,7 @@ describe('AccountStateManager', () => { expect(result).toBeNull(); }); - it('returns null if the account chainId is not match', async () => { + it('returns null if the account chainId does not match', async () => { const accounts = await generateTestnetAccounts(); await mockStateWithMainnetAccounts(accounts); @@ -70,7 +70,7 @@ describe('AccountStateManager', () => { }); describe('list', () => { - it('returns the list of account', async () => { + it('returns the list of accounts', async () => { const accountsInTestnet = await generateTestnetAccounts(); const accountsInMainnet = await generateMainnetAccounts(); @@ -108,7 +108,7 @@ describe('AccountStateManager', () => { }); describe('upsertAccount', () => { - it('adds an account if the account not exist', async () => { + it('adds an account if the account does not exist', async () => { const [account] = await generateTestnetAccounts(1); const state = await mockStateWithMainnetAccounts(); const originalAccountsFromState = [...state.accContracts]; @@ -209,7 +209,7 @@ describe('AccountStateManager', () => { ]); }); - it('throws an `Account does not exist` error if the removed account is not exist', async () => { + it('throws an `Account does not exist` error if the removed account does not exist', async () => { const [removeAccount, ...accounts] = await generateTestnetAccounts(); await mockStateWithMainnetAccounts(accounts);