Skip to content

Commit

Permalink
feat: sf 549 update all method to support upgraded address (#188)
Browse files Browse the repository at this point in the history
* fix: recoverAccount to avoid duplicate called on isUpgradeRequired

* chore: hardcode cairo version 1 to constant CAIRO_VERSION 

* fix: findAddressIndex should return address if the address match either cairo0 or cairo{N}

* fix: getCorrectContractAddress, use getVersion instead of getOwner to determine account has deployed or not

* chore: getCorrectContractAddress to return upgrade required attribute
  • Loading branch information
stanleyyconsensys authored Jan 5, 2024
1 parent bd45f53 commit c166dde
Show file tree
Hide file tree
Showing 14 changed files with 608 additions and 644 deletions.
2 changes: 1 addition & 1 deletion packages/starknet-snap/src/estimateFees.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export async function estimateFees(params: ApiParams) {
senderAddress,
senderPrivateKey,
requestParamsObj.invocations,
requestParamsObj.invocationDetails ? requestParamsObj.invocationDetails : undefined,
requestParamsObj.invocationDetails,
);

return fees.map((fee) => ({
Expand Down
12 changes: 5 additions & 7 deletions packages/starknet-snap/src/recoverAccounts.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { toJson } from './utils/serializer';
import { num } from 'starknet';
import { getKeysFromAddressIndex, getCorrectContractAddress, isUpgradeRequired } from './utils/starknetUtils';
import { getKeysFromAddressIndex, getCorrectContractAddress } from './utils/starknetUtils';
import { getNetworkFromChainId, getValidNumber, upsertAccount } from './utils/snapUtils';
import { AccContract } from './types/snapState';
import { ApiParams, RecoverAccountsRequestParams } from './types/snapApi';
Expand Down Expand Up @@ -29,17 +29,15 @@ export async function recoverAccounts(params: ApiParams) {
state,
i,
);
const { address: contractAddress, signerPubKey: signerPublicKey } = await getCorrectContractAddress(
const { address: contractAddress, signerPubKey: signerPublicKey, upgradeRequired } = await getCorrectContractAddress(
network,
publicKey,
);
logger.log(`recoverAccounts: index ${i}:\ncontractAddress = ${contractAddress}\npublicKey = ${publicKey}`);
logger.log(`recoverAccounts: index ${i}:\ncontractAddress = ${contractAddress}\npublicKey = ${publicKey}\nisUpgradeRequired = ${upgradeRequired}`);

let _isUpgradeRequired = false;
if (signerPublicKey) {
_isUpgradeRequired = await isUpgradeRequired(network, contractAddress);
logger.log(
`recoverAccounts: index ${i}:\ncontractAddress = ${contractAddress}\nisUpgradeRequired = ${_isUpgradeRequired}`,
`recoverAccounts: index ${i}:\ncontractAddress = ${contractAddress}\n`,
);
if (num.toBigInt(signerPublicKey) === num.toBigInt(publicKey)) {
logger.log(`recoverAccounts: index ${i} matched\npublicKey: ${publicKey}`);
Expand All @@ -57,7 +55,7 @@ export async function recoverAccounts(params: ApiParams) {
derivationPath,
deployTxnHash: '',
chainId: network.chainId,
upgradeRequired: _isUpgradeRequired,
upgradeRequired: upgradeRequired,
};

logger.log(`recoverAccounts: index ${i}\nuserAccount: ${toJson(userAccount)}`);
Expand Down
4 changes: 4 additions & 0 deletions packages/starknet-snap/src/utils/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,7 @@ export const PRELOADED_NETWORKS = [STARKNET_MAINNET_NETWORK, STARKNET_TESTNET_NE
export const PROXY_CONTRACT_HASH = '0x25ec026985a3bf9d0cc1fe17326b245dfdc3ff89b8fde106542a3ea56c5a918'; // for cairo 0 proxy contract

export const MIN_ACC_CONTRACT_VERSION = [0, 3, 0];

export const CAIRO_VERSION = '1';

export const CAIRO_VERSION_LEGACY = '0';
135 changes: 76 additions & 59 deletions packages/starknet-snap/src/utils/starknetUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import {
Abi,
DeclareSignerDetails,
DeployAccountSignerDetails,
CairoVersion,
} from 'starknet';
import type { Hex } from '@noble/curves/abstract/utils';
import { Network, SnapState, Transaction, TransactionType } from '../types/snapState';
Expand All @@ -41,6 +42,8 @@ import {
MIN_ACC_CONTRACT_VERSION,
ACCOUNT_CLASS_HASH_V0,
ACCOUNT_CLASS_HASH_V1,
CAIRO_VERSION,
CAIRO_VERSION_LEGACY,
} from './constants';
import { getAddressKey } from './keyPair';
import { getAccount, getAccounts, getTransactionFromVoyagerUrl, getTransactionsFromVoyagerUrl } from './snapUtils';
Expand Down Expand Up @@ -110,7 +113,7 @@ export const estimateFee = async (
txnInvocation: Call | Call[],
): Promise<EstimateFee> => {
const provider = getProvider(network);
const account = new Account(provider, senderAddress, privateKey, '1');
const account = new Account(provider, senderAddress, privateKey, CAIRO_VERSION);
return account.estimateInvokeFee(txnInvocation, { blockIdentifier: 'latest' });
};

Expand All @@ -124,7 +127,7 @@ export const estimateFeeBulk = async (
// ensure always calling the sequencer endpoint since the rpc endpoint and
// starknet.js are not supported yet.
const provider = getProvider(network);
const account = new Account(provider, senderAddress, privateKey, '1');
const account = new Account(provider, senderAddress, privateKey, CAIRO_VERSION);
return account.estimateFeeBulk(txnInvocation, invocationsDetails);
};

Expand All @@ -137,7 +140,7 @@ export const executeTxn = async (
invocationsDetails?: InvocationsDetails,
): Promise<InvokeFunctionResponse> => {
const provider = getProvider(network);
const account = new Account(provider, senderAddress, privateKey, '1');
const account = new Account(provider, senderAddress, privateKey, CAIRO_VERSION);
return account.execute(txnInvocation, abis, invocationsDetails);
};

Expand All @@ -150,7 +153,7 @@ export const deployAccount = async (
maxFee: num.BigNumberish,
): Promise<DeployContractResponse> => {
const provider = getProvider(network);
const account = new Account(provider, contractAddress, privateKey, '1');
const account = new Account(provider, contractAddress, privateKey, CAIRO_VERSION);
const deployAccountPayload = {
classHash: ACCOUNT_CLASS_HASH_V1,
contractAddress: contractAddress,
Expand All @@ -168,7 +171,7 @@ export const estimateAccountDeployFee = async (
privateKey: string | Uint8Array,
): Promise<EstimateFee> => {
const provider = getProvider(network);
const account = new Account(provider, contractAddress, privateKey, '1');
const account = new Account(provider, contractAddress, privateKey, CAIRO_VERSION);
const deployAccountPayload = {
classHash: ACCOUNT_CLASS_HASH_V1,
contractAddress: contractAddress,
Expand All @@ -193,6 +196,10 @@ export const getOwner = async (userAccAddress: string, network: Network): Promis
return resp.result[0];
};

export const getContractOwner = async (userAccAddress: string, network: Network, version: CairoVersion): Promise<string> => {
return version === '0' ? getSigner(userAccAddress, network) : getOwner(userAccAddress, network);
};

export const getBalance = async (address: string, tokenAddress: string, network: Network) => {
const resp = await callContract(network, tokenAddress, 'balanceOf', [num.toBigInt(address).toString(10)]);
return resp.result[0];
Expand Down Expand Up @@ -423,7 +430,7 @@ export const getNextAddressIndex = (chainId: string, state: SnapState, derivatio
};

/**
* calculate contract address by publicKey, supported for carioVersions [1]
* calculate contract address by publicKey
*
* @param publicKey - address's publicKey.
* @returns - address and calldata.
Expand All @@ -446,12 +453,12 @@ export const getAccContractAddressAndCallData = (publicKey) => {
};

/**
* calculate contract address by publicKey, supported for carioVersions [0]
* calculate contract address by publicKey
*
* @param publicKey - address's publicKey.
* @returns - address and calldata.
*/
export const getAccContractAddressAndCallDataCairo0 = (publicKey) => {
export const getAccContractAddressAndCallDataLegacy = (publicKey) => {
const callData = CallData.compile({
implementation: ACCOUNT_CLASS_HASH_V0,
selector: hash.getSelectorFromName('initialize'),
Expand Down Expand Up @@ -510,7 +517,7 @@ export const getKeysFromAddressIndex = async (
};

/**
* Check address is deployed by using getVersion, supported for carioVersions [0,1]
* Check address is deployed by using getVersion
*
* @param network - Network.
* @param address - Input address.
Expand Down Expand Up @@ -552,7 +559,7 @@ export const validateAndParseAddress = (address: num.BigNumberish, length = 63)
};

/**
* Find address index from the keyDeriver, supported for carioVersions [0,1]
* Find address index from the keyDeriver
*
* @param chainId - Network ChainId.
* @param address - Input address.
Expand All @@ -571,9 +578,9 @@ export const findAddressIndex = async (
const bigIntAddress = num.toBigInt(address);
for (let i = 0; i < maxScan; i++) {
const { publicKey } = await getKeysFromAddressIndex(keyDeriver, chainId, state, i);
const { address: calculatedAddress } = getAccContractAddressAndCallData(publicKey);
const { address: calculatedAddressCairo0 } = getAccContractAddressAndCallDataCairo0(publicKey);
if (num.toBigInt(calculatedAddress) === bigIntAddress || num.toBigInt(calculatedAddressCairo0) === bigIntAddress) {
const { address: calculatedAddress, addressLegacy: calculatedAddressLegacy } = getPermutationAddresses(publicKey);

if (num.toBigInt(calculatedAddress) === bigIntAddress || num.toBigInt(calculatedAddressLegacy) === bigIntAddress) {
logger.log(`findAddressIndex:\nFound address in scan: ${i} ${address}`);
return {
index: i,
Expand All @@ -585,7 +592,23 @@ export const findAddressIndex = async (
};

/**
* Check address needed upgrade by using getVersion and compare with MIN_ACC_CONTRACT_VERSION, supported for carioVersions [0,1]
* Get address permutation by public key
*
* @param pk - Public key.
* @returns - address and addressLegacy.
*/
export const getPermutationAddresses = (pk:string) => {
const { address } = getAccContractAddressAndCallData(pk);
const { address: addressLegacy } = getAccContractAddressAndCallDataLegacy(pk);

return {
address,
addressLegacy
}
}

/**
* Check address needed upgrade by using getVersion and compare with MIN_ACC_CONTRACT_VERSION
*
* @param network - Network.
* @param address - Input address.
Expand All @@ -595,80 +618,74 @@ export const isUpgradeRequired = async (network: Network, address: string) => {
try {
logger.log(`isUpgradeRequired: address = ${address}`);
const hexResp = await getVersion(address, network);
const version = hexToString(hexResp);
logger.log(`isUpgradeRequired: hexResp = ${hexResp}, version = ${version}`);
const versionArr = version.split('.');
return Number(versionArr[1]) < MIN_ACC_CONTRACT_VERSION[1];
return isGTEMinVersion(hexToString(hexResp)) ? false : true;
} catch (err) {
if (!err.message.includes('Contract not found')) {
throw err;
}
logger.error(`isUpgradeRequired: error:`, err);
//[TODO] if address is cario0 but not deployed we should throw error
//[TODO] if address is cairo0 but not deployed we throw error
return false;
}
};

/**
* Get user address by public key, return address if the address has deployed, prioritize cario 1 over cario 0, supported for carioVersions [0,1]
* Compare version number with MIN_ACC_CONTRACT_VERSION
*
* @param version - version, e.g (2.3.0).
* @returns - boolean.
*/
export const isGTEMinVersion = (version: string) => {
logger.log(`isGTEMinVersion: version = ${version}`);
const versionArr = version.split('.');
return Number(versionArr[1]) >= MIN_ACC_CONTRACT_VERSION[1];
};

/**
* Get user address by public key, return address if the address has deployed
*
* @param network - Network.
* @param publicKey - address's public key.
* @returns - address and address's public key.
*/
export const getCorrectContractAddress = async (network: Network, publicKey: string) => {
const { address: contractAddress } = getAccContractAddressAndCallData(publicKey);
const { address: contractAddressCairo0 } = getAccContractAddressAndCallDataCairo0(publicKey);
let pk = '';
const {address: contractAddress, addressLegacy: contractAddressLegacy} = getPermutationAddresses(publicKey)

logger.log(
`getContractAddressByKey: contractAddressCario1 = ${contractAddress}\ncontractAddressCairo0 = ${contractAddressCairo0}\npublicKey = ${publicKey}`,
`getContractAddressByKey: contractAddress = ${contractAddress}\ncontractAddressLegacy = ${contractAddressLegacy}\npublicKey = ${publicKey}`,
);

//test if it is a cairo 1 account
let address = contractAddress;
let upgradeRequired = false;
let pk = '';

try {
pk = await getOwner(contractAddress, network);
logger.log(`getContractAddressByKey: cairo 1 contract found`);
} catch (err) {
if (!err.message.includes('Contract not found')) {
throw err;
await getVersion(contractAddress, network);
pk = await getContractOwner(address, network, CAIRO_VERSION);
} catch (e) {
if (!e.message.includes('Contract not found')) {
throw e;
}
logger.log(`getContractAddressByKey: cairo 1 contract not found`);

logger.log(`getContractAddressByKey: cairo ${CAIRO_VERSION} contract cant found, try cairo ${CAIRO_VERSION_LEGACY}`);

//test if it is a upgraded cairo 0 account
try {
pk = await getOwner(contractAddressCairo0, network);
logger.log(`getContractAddressByKey: upgraded cairo 0 contract found`);
return {
address: contractAddressCairo0,
signerPubKey: pk,
};
} catch (err) {
if (!err.message.includes('Contract not found')) {
throw err;
const version = await getVersion(contractAddressLegacy, network);
upgradeRequired = isGTEMinVersion(hexToString(version)) ? false : true;
pk = await getContractOwner(contractAddressLegacy, network, upgradeRequired ? CAIRO_VERSION_LEGACY : CAIRO_VERSION);
address = contractAddressLegacy
} catch (e) {
if (!e.message.includes('Contract not found')) {
throw e;
}
logger.log(`getContractAddressByKey: upgraded cairo 0 contract not found`);
}

//test if it is a deployed cairo 0 account
try {
pk = await getSigner(contractAddressCairo0, network);
logger.log(`getContractAddressByKey: cairo 0 contract found`);
return {
address: contractAddressCairo0,
signerPubKey: pk,
};
} catch (err) {
if (!err.message.includes('Contract not found')) {
throw err;
}
logger.log(`getContractAddressByKey: cairo 0 contract not found`);
logger.log(`getContractAddressByKey: no deployed contract found, fallback to cairo ${CAIRO_VERSION}`);
}
}

//return new/deployed cairo 1 account
return {
address: contractAddress,
address,
signerPubKey: pk,
upgradeRequired: upgradeRequired,
};
};

Expand Down
4 changes: 2 additions & 2 deletions packages/starknet-snap/test/constants.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ export const account4: AccContract = {
chainId: constants.StarknetChainId.SN_GOERLI,
};

export const Cario1Account1: AccContract = {
export const Cairo1Account1: AccContract = {
address: '0x043e3d703b005b8367a9783fb680713349c519202aa01e9beb170bdf710ae20b',
addressSalt: '0x019e59f349e1aa813ab4556c5836d0472e5e1ae82d1e5c3b3e8aabfeb290befd',
addressIndex: 1,
Expand Down Expand Up @@ -99,7 +99,7 @@ export const signature1 =
export const signature2 =
'30440220052956ac852275b6004c4e8042450f6dce83059f068029b037cc47338c80d062022002bc0e712f03e341bb3532fc356b779d84fcb4dbfe8ed34de2db66e121971d92';

export const signature4Cario1SignMessage = [
export const signature4Cairo1SignMessage = [
'2941323345698930086258187297320132789256148405011604592758945785805412997864',
'1024747634926675542679366527128384456926978174336360356924884281219915547518',
];
Expand Down
Loading

0 comments on commit c166dde

Please sign in to comment.