diff --git a/packages/starknet-snap/src/index.test.tsx b/packages/starknet-snap/src/index.test.tsx index 2426ddf5..da3181cb 100644 --- a/packages/starknet-snap/src/index.test.tsx +++ b/packages/starknet-snap/src/index.test.tsx @@ -4,6 +4,7 @@ import { onHomePage, onRpcRequest } from '.'; import * as createAccountApi from './createAccount'; import { HomePageController } from './on-home-page'; import * as keyPairUtils from './utils/keyPair'; +import * as permissionUtil from './utils/permission'; jest.mock('./utils/logger'); @@ -41,7 +42,11 @@ describe('onRpcRequest', () => { expect(createAccountSpy).toHaveBeenCalledTimes(1); }); + // It is a never case, as the permission of each method is checked in the `validateOrigin` function. + // But to increase the coverage, we keep this test case. it('throws `MethodNotFoundError` if the request method not found', async () => { + jest.spyOn(permissionUtil, 'validateOrigin').mockReturnThis(); + await expect( onRpcRequest({ ...createMockRequest(), diff --git a/packages/starknet-snap/src/index.tsx b/packages/starknet-snap/src/index.tsx index d43e696a..0b46b166 100644 --- a/packages/starknet-snap/src/index.tsx +++ b/packages/starknet-snap/src/index.tsx @@ -10,17 +10,14 @@ import type { import { MethodNotFoundError } from '@metamask/snaps-sdk'; import { Box, Link, Text } from '@metamask/snaps-sdk/jsx'; -import { addNetwork } from './addNetwork'; import { Config } from './config'; import { createAccount } from './createAccount'; -import { estimateAccDeployFee } from './estimateAccountDeployFee'; import { extractPublicKey } from './extractPublicKey'; import { getCurrentNetwork } from './getCurrentNetwork'; import { getErc20TokenBalance } from './getErc20TokenBalance'; import { getStarkName } from './getStarkName'; import { getStoredErc20Tokens } from './getStoredErc20Tokens'; import { getStoredNetworks } from './getStoredNetworks'; -import { getStoredTransactions } from './getStoredTransactions'; import { getStoredUserAccounts } from './getStoredUserAccounts'; import { getTransactions } from './getTransactions'; import { getValue } from './getValue'; @@ -81,6 +78,7 @@ import { UnknownError } from './utils/exceptions'; import { getAddressKeyDeriver } from './utils/keyPair'; import { acquireLock } from './utils/lock'; import { logger } from './utils/logger'; +import { RpcMethod, validateOrigin } from './utils/permission'; import { toJson } from './utils/serializer'; import { upsertErc20Token, @@ -91,12 +89,17 @@ import { declare const snap; logger.logLevel = parseInt(Config.logLevel, 10); -export const onRpcRequest: OnRpcRequestHandler = async ({ request }) => { +export const onRpcRequest: OnRpcRequestHandler = async ({ + origin, + request, +}) => { const requestParams = request?.params as unknown as ApiRequestParams; logger.log(`${request.method}:\nrequestParams: ${toJson(requestParams)}`); try { + validateOrigin(origin, request.method); + if (request.method === 'ping') { logger.log('pong'); return 'pong'; @@ -142,13 +145,13 @@ export const onRpcRequest: OnRpcRequestHandler = async ({ request }) => { }; switch (request.method) { - case 'starkNet_createAccount': + case RpcMethod.CreateAccount: apiParams.keyDeriver = await getAddressKeyDeriver(snap); return await createAccount( apiParams as unknown as ApiParamsWithKeyDeriver, ); - case 'starkNet_createAccountLegacy': + case RpcMethod.DeployCario0Account: apiParams.keyDeriver = await getAddressKeyDeriver(snap); return await createAccount( apiParams as unknown as ApiParamsWithKeyDeriver, @@ -157,123 +160,111 @@ export const onRpcRequest: OnRpcRequestHandler = async ({ request }) => { CAIRO_VERSION_LEGACY, ); - case 'starkNet_getStoredUserAccounts': + case RpcMethod.ListAccounts: return await getStoredUserAccounts(apiParams); - case 'starkNet_displayPrivateKey': + case RpcMethod.DisplayPrivateKey: return await displayPrivateKey.execute( apiParams.requestParams as unknown as DisplayPrivateKeyParams, ); - case 'starkNet_extractPublicKey': + case RpcMethod.ExtractPublicKey: apiParams.keyDeriver = await getAddressKeyDeriver(snap); return await extractPublicKey( apiParams as unknown as ApiParamsWithKeyDeriver, ); - case 'starkNet_signMessage': + case RpcMethod.SignMessage: return await signMessage.execute( apiParams.requestParams as unknown as SignMessageParams, ); - case 'starkNet_signTransaction': + case RpcMethod.SignTransaction: apiParams.keyDeriver = await getAddressKeyDeriver(snap); return await signTransaction.execute( apiParams.requestParams as unknown as SignTransactionParams, ); - case 'starkNet_signDeclareTransaction': + case RpcMethod.SignDeclareTransaction: return await signDeclareTransaction.execute( apiParams.requestParams as unknown as SignDeclareTransactionParams, ); - case 'starkNet_signDeployAccountTransaction': + case RpcMethod.SignDeployAccountTransaction: apiParams.keyDeriver = await getAddressKeyDeriver(snap); return await signDeployAccountTransaction( apiParams as unknown as ApiParamsWithKeyDeriver, ); - case 'starkNet_verifySignedMessage': + case RpcMethod.VerifySignedMessage: return await verifySignature.execute( apiParams.requestParams as unknown as VerifySignatureParams, ); - case 'starkNet_getErc20TokenBalance': + case RpcMethod.GetErc20TokenBalance: return await getErc20TokenBalance(apiParams); - case 'starkNet_getTransactionStatus': + case RpcMethod.GetTransactionStatus: return await getTransactionStatus.execute( apiParams.requestParams as unknown as GetTransactionStatusParams, ); - case 'starkNet_getValue': + case RpcMethod.ReadContract: return await getValue(apiParams); - case 'starkNet_estimateFee': + case RpcMethod.EstimateFee: return await estimateFee.execute( apiParams.requestParams as unknown as EstimateFeeParams, ); - case 'starkNet_estimateAccountDeployFee': - apiParams.keyDeriver = await getAddressKeyDeriver(snap); - return await estimateAccDeployFee( - apiParams as unknown as ApiParamsWithKeyDeriver, - ); - - case 'starkNet_addErc20Token': + case RpcMethod.AddErc20Token: return await watchAsset.execute( apiParams.requestParams as unknown as WatchAssetParams, ); - case 'starkNet_getStoredErc20Tokens': + case RpcMethod.GetStoredErc20Tokens: return await getStoredErc20Tokens(apiParams); - case 'starkNet_addNetwork': - return await addNetwork(apiParams); - - case 'starkNet_switchNetwork': + case RpcMethod.SwitchNetwork: return await switchNetwork.execute( apiParams.requestParams as unknown as SwitchNetworkParams, ); - case 'starkNet_getCurrentNetwork': + case RpcMethod.GetCurrentNetwork: return await getCurrentNetwork(apiParams); - case 'starkNet_getStoredNetworks': + case RpcMethod.GetStoredNetworks: return await getStoredNetworks(apiParams); - case 'starkNet_getStoredTransactions': - return await getStoredTransactions(apiParams); - - case 'starkNet_getTransactions': + case RpcMethod.GetTransactions: return await getTransactions(apiParams); - case 'starkNet_recoverAccounts': + case RpcMethod.RecoverAccounts: apiParams.keyDeriver = await getAddressKeyDeriver(snap); return await recoverAccounts( apiParams as unknown as ApiParamsWithKeyDeriver, ); - case 'starkNet_executeTxn': + case RpcMethod.ExecuteTxn: return await executeTxn.execute( apiParams.requestParams as unknown as ExecuteTxnParams, ); - case 'starkNet_upgradeAccContract': + case RpcMethod.UpgradeAccContract: apiParams.keyDeriver = await getAddressKeyDeriver(snap); return upgradeAccContract( apiParams as unknown as ApiParamsWithKeyDeriver, ); - case 'starkNet_declareContract': + case RpcMethod.DeclareContract: return await declareContract.execute( apiParams.requestParams as unknown as DeclareContractParams, ); - case 'starkNet_getStarkName': + case RpcMethod.GetStarkName: return await getStarkName(apiParams); - case 'starkNet_getDeploymentData': + case RpcMethod.GetDeploymentData: return await getDeploymentData.execute( apiParams.requestParams as unknown as GetDeploymentDataParams, ); diff --git a/packages/starknet-snap/src/utils/permission.test.ts b/packages/starknet-snap/src/utils/permission.test.ts new file mode 100644 index 00000000..9d6a0f93 --- /dev/null +++ b/packages/starknet-snap/src/utils/permission.test.ts @@ -0,0 +1,59 @@ +import { originPermissions, validateOrigin, RpcMethod } from './permission'; + +describe('validateOrigin', () => { + const walletUIDappPermissions = Array.from( + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + originPermissions.get('https://snaps.consensys.io')!, + ); + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + const publicPermissions = Array.from(originPermissions.get('*')!); + const restrictedPermissions = [ + RpcMethod.DeployCario0Account, + RpcMethod.ListAccounts, + RpcMethod.GetTransactions, + RpcMethod.UpgradeAccContract, + RpcMethod.GetStarkName, + RpcMethod.ReadContract, + RpcMethod.GetStoredErc20Tokens, + ]; + + it.each(walletUIDappPermissions)( + "pass the validation with a valid Wallet UI Dapp's origin and a whitelisted method. method - %s", + (method: string) => { + expect(() => + validateOrigin('https://snaps.consensys.io', method), + ).not.toThrow(); + }, + ); + + it.each(publicPermissions)( + 'pass the validation with any origin and a whitelisted method. method - %s', + (method: string) => { + expect(() => validateOrigin('https://any.io', method)).not.toThrow(); + }, + ); + + it.each(restrictedPermissions)( + 'throw a `Permission denied` if the method is restricted for public. method - %s', + (method: string) => { + expect(() => validateOrigin('https://any.io', method)).toThrow( + 'Permission denied', + ); + }, + ); + + it('throw a `Permission denied` if the method is not exist.', () => { + expect(() => validateOrigin('https://any.io', 'method_not_exist')).toThrow( + 'Permission denied', + ); + expect(() => + validateOrigin('https://snaps.consensys.io', 'method_not_exist'), + ).toThrow('Permission denied'); + }); + + it('throw a `Origin not found` if the orgin is not given or empty.', () => { + expect(() => validateOrigin('', 'method_not_exist')).toThrow( + 'Origin not found', + ); + }); +}); diff --git a/packages/starknet-snap/src/utils/permission.ts b/packages/starknet-snap/src/utils/permission.ts new file mode 100644 index 00000000..216452ef --- /dev/null +++ b/packages/starknet-snap/src/utils/permission.ts @@ -0,0 +1,104 @@ +import { UnauthorizedError } from '@metamask/snaps-sdk'; + +export enum RpcMethod { + ExtractPublicKey = 'starkNet_extractPublicKey', + GetCurrentNetwork = 'starkNet_getCurrentNetwork', + GetStoredNetworks = 'starkNet_getStoredNetworks', + SwitchNetwork = 'starkNet_switchNetwork', + AddErc20Token = 'starkNet_addErc20Token', + RecoverAccounts = 'starkNet_recoverAccounts', + ExecuteTxn = 'starkNet_executeTxn', + DeclareContract = 'starkNet_declareContract', + GetDeploymentData = 'starkNet_getDeploymentData', + SignMessage = 'starkNet_signMessage', + SignTransaction = 'starkNet_signTransaction', + SignDeclareTransaction = 'starkNet_signDeclareTransaction', + SignDeployAccountTransaction = 'starkNet_signDeployAccountTransaction', + + CreateAccount = 'starkNet_createAccount', + DisplayPrivateKey = 'starkNet_displayPrivateKey', + GetErc20TokenBalance = 'starkNet_getErc20TokenBalance', + GetTransactionStatus = 'starkNet_getTransactionStatus', + EstimateFee = 'starkNet_estimateFee', + VerifySignedMessage = 'starkNet_verifySignedMessage', + DeployCario0Account = 'starkNet_createAccountLegacy', + ListAccounts = 'starkNet_getStoredUserAccounts', + GetTransactions = 'starkNet_getTransactions', + UpgradeAccContract = 'starkNet_upgradeAccContract', + GetStarkName = 'starkNet_getStarkName', + ReadContract = 'starkNet_getValue', + GetStoredErc20Tokens = 'starkNet_getStoredErc20Tokens', +} +// RpcMethod that are allowed to be called by any origin +const publicPermissions = [ + RpcMethod.ExtractPublicKey, + RpcMethod.GetCurrentNetwork, + RpcMethod.GetStoredNetworks, + RpcMethod.SwitchNetwork, + RpcMethod.AddErc20Token, + RpcMethod.RecoverAccounts, + RpcMethod.ExecuteTxn, + RpcMethod.DeclareContract, + RpcMethod.GetDeploymentData, + RpcMethod.SignMessage, + RpcMethod.SignTransaction, + RpcMethod.SignDeclareTransaction, + RpcMethod.SignDeployAccountTransaction, + RpcMethod.CreateAccount, + RpcMethod.DisplayPrivateKey, + RpcMethod.GetErc20TokenBalance, + RpcMethod.GetTransactionStatus, + RpcMethod.EstimateFee, + RpcMethod.VerifySignedMessage, +]; +// RpcMethod that are restricted to be called by wallet UI origins +const walletUIDappPermissions = publicPermissions.concat([ + RpcMethod.DeployCario0Account, + RpcMethod.ListAccounts, + RpcMethod.GetTransactions, + RpcMethod.UpgradeAccContract, + RpcMethod.GetStarkName, + RpcMethod.ReadContract, + RpcMethod.GetStoredErc20Tokens, +]); + +const publicPermissionsSet = new Set(publicPermissions); +const walletUIDappPermissionsSet = new Set(walletUIDappPermissions); + +const walletUIDappOrigins = [ + 'http://localhost:3000', + 'https://snaps.consensys.io', + 'https://dev.snaps.consensys.io', + 'https://staging.snaps.consensys.io', +]; + +export const originPermissions = new Map>([]); +for (const origin of walletUIDappOrigins) { + originPermissions.set(origin, walletUIDappPermissionsSet); +} +originPermissions.set('*', publicPermissionsSet); + +/** + * Validate the origin and method pair. + * If the origin is not found or the method is not allowed, throw an error. + * + * @param origin - The origin of the request. + * @param method - The method of the request. + * @throws {UnauthorizedError} If the origin is not found or the method is not allowed. + */ +export function validateOrigin(origin: string, method: string): void { + if (!origin) { + // eslint-disable-next-line @typescript-eslint/no-throw-literal + throw new UnauthorizedError('Origin not found'); + } + // As public permissions are a subset of wallet UI Dapp permissions, + // If the origin and method pair are not in the wallet UI Dapp permissions, + // then fallback and validate whether it hits the common permission. + if ( + !originPermissions.get(origin)?.has(method) && + !originPermissions.get('*')?.has(method) + ) { + // eslint-disable-next-line @typescript-eslint/no-throw-literal + throw new UnauthorizedError(`Permission denied`); + } +}