Skip to content

Commit

Permalink
feat: add permission boundary (#448)
Browse files Browse the repository at this point in the history
* feat: add permission boundary

* chore: add mock for index test
  • Loading branch information
stanleyyconsensys authored Dec 3, 2024
1 parent 5d569a8 commit 1e80070
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 42 deletions.
5 changes: 5 additions & 0 deletions packages/starknet-snap/src/index.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { onHomePage, onRpcRequest } from '.';
import * as createAccountApi from './createAccount';
import { HomePageController } from './on-home-page';
import * as keyPairUtils from './utils/keyPair';
import * as permissionUtil from './utils/permission';

jest.mock('./utils/logger');

Expand Down Expand Up @@ -41,7 +42,11 @@ describe('onRpcRequest', () => {
expect(createAccountSpy).toHaveBeenCalledTimes(1);
});

// It is a never case, as the permission of each method is checked in the `validateOrigin` function.
// But to increase the coverage, we keep this test case.
it('throws `MethodNotFoundError` if the request method not found', async () => {
jest.spyOn(permissionUtil, 'validateOrigin').mockReturnThis();

await expect(
onRpcRequest({
...createMockRequest(),
Expand Down
75 changes: 33 additions & 42 deletions packages/starknet-snap/src/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,14 @@ import type {
import { MethodNotFoundError } from '@metamask/snaps-sdk';
import { Box, Link, Text } from '@metamask/snaps-sdk/jsx';

import { addNetwork } from './addNetwork';
import { Config } from './config';
import { createAccount } from './createAccount';
import { estimateAccDeployFee } from './estimateAccountDeployFee';
import { extractPublicKey } from './extractPublicKey';
import { getCurrentNetwork } from './getCurrentNetwork';
import { getErc20TokenBalance } from './getErc20TokenBalance';
import { getStarkName } from './getStarkName';
import { getStoredErc20Tokens } from './getStoredErc20Tokens';
import { getStoredNetworks } from './getStoredNetworks';
import { getStoredTransactions } from './getStoredTransactions';
import { getStoredUserAccounts } from './getStoredUserAccounts';
import { getTransactions } from './getTransactions';
import { getValue } from './getValue';
Expand Down Expand Up @@ -81,6 +78,7 @@ import { UnknownError } from './utils/exceptions';
import { getAddressKeyDeriver } from './utils/keyPair';
import { acquireLock } from './utils/lock';
import { logger } from './utils/logger';
import { RpcMethod, validateOrigin } from './utils/permission';
import { toJson } from './utils/serializer';
import {
upsertErc20Token,
Expand All @@ -91,12 +89,17 @@ import {
declare const snap;
logger.logLevel = parseInt(Config.logLevel, 10);

export const onRpcRequest: OnRpcRequestHandler = async ({ request }) => {
export const onRpcRequest: OnRpcRequestHandler = async ({
origin,
request,
}) => {
const requestParams = request?.params as unknown as ApiRequestParams;

logger.log(`${request.method}:\nrequestParams: ${toJson(requestParams)}`);

try {
validateOrigin(origin, request.method);

if (request.method === 'ping') {
logger.log('pong');
return 'pong';
Expand Down Expand Up @@ -142,13 +145,13 @@ export const onRpcRequest: OnRpcRequestHandler = async ({ request }) => {
};

switch (request.method) {
case 'starkNet_createAccount':
case RpcMethod.CreateAccount:
apiParams.keyDeriver = await getAddressKeyDeriver(snap);
return await createAccount(
apiParams as unknown as ApiParamsWithKeyDeriver,
);

case 'starkNet_createAccountLegacy':
case RpcMethod.DeployCario0Account:
apiParams.keyDeriver = await getAddressKeyDeriver(snap);
return await createAccount(
apiParams as unknown as ApiParamsWithKeyDeriver,
Expand All @@ -157,123 +160,111 @@ export const onRpcRequest: OnRpcRequestHandler = async ({ request }) => {
CAIRO_VERSION_LEGACY,
);

case 'starkNet_getStoredUserAccounts':
case RpcMethod.ListAccounts:
return await getStoredUserAccounts(apiParams);

case 'starkNet_displayPrivateKey':
case RpcMethod.DisplayPrivateKey:
return await displayPrivateKey.execute(
apiParams.requestParams as unknown as DisplayPrivateKeyParams,
);

case 'starkNet_extractPublicKey':
case RpcMethod.ExtractPublicKey:
apiParams.keyDeriver = await getAddressKeyDeriver(snap);
return await extractPublicKey(
apiParams as unknown as ApiParamsWithKeyDeriver,
);

case 'starkNet_signMessage':
case RpcMethod.SignMessage:
return await signMessage.execute(
apiParams.requestParams as unknown as SignMessageParams,
);

case 'starkNet_signTransaction':
case RpcMethod.SignTransaction:
apiParams.keyDeriver = await getAddressKeyDeriver(snap);
return await signTransaction.execute(
apiParams.requestParams as unknown as SignTransactionParams,
);

case 'starkNet_signDeclareTransaction':
case RpcMethod.SignDeclareTransaction:
return await signDeclareTransaction.execute(
apiParams.requestParams as unknown as SignDeclareTransactionParams,
);

case 'starkNet_signDeployAccountTransaction':
case RpcMethod.SignDeployAccountTransaction:
apiParams.keyDeriver = await getAddressKeyDeriver(snap);
return await signDeployAccountTransaction(
apiParams as unknown as ApiParamsWithKeyDeriver,
);

case 'starkNet_verifySignedMessage':
case RpcMethod.VerifySignedMessage:
return await verifySignature.execute(
apiParams.requestParams as unknown as VerifySignatureParams,
);

case 'starkNet_getErc20TokenBalance':
case RpcMethod.GetErc20TokenBalance:
return await getErc20TokenBalance(apiParams);

case 'starkNet_getTransactionStatus':
case RpcMethod.GetTransactionStatus:
return await getTransactionStatus.execute(
apiParams.requestParams as unknown as GetTransactionStatusParams,
);

case 'starkNet_getValue':
case RpcMethod.ReadContract:
return await getValue(apiParams);

case 'starkNet_estimateFee':
case RpcMethod.EstimateFee:
return await estimateFee.execute(
apiParams.requestParams as unknown as EstimateFeeParams,
);

case 'starkNet_estimateAccountDeployFee':
apiParams.keyDeriver = await getAddressKeyDeriver(snap);
return await estimateAccDeployFee(
apiParams as unknown as ApiParamsWithKeyDeriver,
);

case 'starkNet_addErc20Token':
case RpcMethod.AddErc20Token:
return await watchAsset.execute(
apiParams.requestParams as unknown as WatchAssetParams,
);

case 'starkNet_getStoredErc20Tokens':
case RpcMethod.GetStoredErc20Tokens:
return await getStoredErc20Tokens(apiParams);

case 'starkNet_addNetwork':
return await addNetwork(apiParams);

case 'starkNet_switchNetwork':
case RpcMethod.SwitchNetwork:
return await switchNetwork.execute(
apiParams.requestParams as unknown as SwitchNetworkParams,
);

case 'starkNet_getCurrentNetwork':
case RpcMethod.GetCurrentNetwork:
return await getCurrentNetwork(apiParams);

case 'starkNet_getStoredNetworks':
case RpcMethod.GetStoredNetworks:
return await getStoredNetworks(apiParams);

case 'starkNet_getStoredTransactions':
return await getStoredTransactions(apiParams);

case 'starkNet_getTransactions':
case RpcMethod.GetTransactions:
return await getTransactions(apiParams);

case 'starkNet_recoverAccounts':
case RpcMethod.RecoverAccounts:
apiParams.keyDeriver = await getAddressKeyDeriver(snap);
return await recoverAccounts(
apiParams as unknown as ApiParamsWithKeyDeriver,
);

case 'starkNet_executeTxn':
case RpcMethod.ExecuteTxn:
return await executeTxn.execute(
apiParams.requestParams as unknown as ExecuteTxnParams,
);

case 'starkNet_upgradeAccContract':
case RpcMethod.UpgradeAccContract:
apiParams.keyDeriver = await getAddressKeyDeriver(snap);
return upgradeAccContract(
apiParams as unknown as ApiParamsWithKeyDeriver,
);

case 'starkNet_declareContract':
case RpcMethod.DeclareContract:
return await declareContract.execute(
apiParams.requestParams as unknown as DeclareContractParams,
);

case 'starkNet_getStarkName':
case RpcMethod.GetStarkName:
return await getStarkName(apiParams);

case 'starkNet_getDeploymentData':
case RpcMethod.GetDeploymentData:
return await getDeploymentData.execute(
apiParams.requestParams as unknown as GetDeploymentDataParams,
);
Expand Down
59 changes: 59 additions & 0 deletions packages/starknet-snap/src/utils/permission.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import { originPermissions, validateOrigin, RpcMethod } from './permission';

describe('validateOrigin', () => {
const walletUIDappPermissions = Array.from(
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
originPermissions.get('https://snaps.consensys.io')!,
);
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const publicPermissions = Array.from(originPermissions.get('*')!);
const restrictedPermissions = [
RpcMethod.DeployCario0Account,
RpcMethod.ListAccounts,
RpcMethod.GetTransactions,
RpcMethod.UpgradeAccContract,
RpcMethod.GetStarkName,
RpcMethod.ReadContract,
RpcMethod.GetStoredErc20Tokens,
];

it.each(walletUIDappPermissions)(
"pass the validation with a valid Wallet UI Dapp's origin and a whitelisted method. method - %s",
(method: string) => {
expect(() =>
validateOrigin('https://snaps.consensys.io', method),
).not.toThrow();
},
);

it.each(publicPermissions)(
'pass the validation with any origin and a whitelisted method. method - %s',
(method: string) => {
expect(() => validateOrigin('https://any.io', method)).not.toThrow();
},
);

it.each(restrictedPermissions)(
'throw a `Permission denied` if the method is restricted for public. method - %s',
(method: string) => {
expect(() => validateOrigin('https://any.io', method)).toThrow(
'Permission denied',
);
},
);

it('throw a `Permission denied` if the method is not exist.', () => {
expect(() => validateOrigin('https://any.io', 'method_not_exist')).toThrow(
'Permission denied',
);
expect(() =>
validateOrigin('https://snaps.consensys.io', 'method_not_exist'),
).toThrow('Permission denied');
});

it('throw a `Origin not found` if the orgin is not given or empty.', () => {
expect(() => validateOrigin('', 'method_not_exist')).toThrow(
'Origin not found',
);
});
});
104 changes: 104 additions & 0 deletions packages/starknet-snap/src/utils/permission.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import { UnauthorizedError } from '@metamask/snaps-sdk';

export enum RpcMethod {
ExtractPublicKey = 'starkNet_extractPublicKey',
GetCurrentNetwork = 'starkNet_getCurrentNetwork',
GetStoredNetworks = 'starkNet_getStoredNetworks',
SwitchNetwork = 'starkNet_switchNetwork',
AddErc20Token = 'starkNet_addErc20Token',
RecoverAccounts = 'starkNet_recoverAccounts',
ExecuteTxn = 'starkNet_executeTxn',
DeclareContract = 'starkNet_declareContract',
GetDeploymentData = 'starkNet_getDeploymentData',
SignMessage = 'starkNet_signMessage',
SignTransaction = 'starkNet_signTransaction',
SignDeclareTransaction = 'starkNet_signDeclareTransaction',
SignDeployAccountTransaction = 'starkNet_signDeployAccountTransaction',

CreateAccount = 'starkNet_createAccount',
DisplayPrivateKey = 'starkNet_displayPrivateKey',
GetErc20TokenBalance = 'starkNet_getErc20TokenBalance',
GetTransactionStatus = 'starkNet_getTransactionStatus',
EstimateFee = 'starkNet_estimateFee',
VerifySignedMessage = 'starkNet_verifySignedMessage',
DeployCario0Account = 'starkNet_createAccountLegacy',
ListAccounts = 'starkNet_getStoredUserAccounts',
GetTransactions = 'starkNet_getTransactions',
UpgradeAccContract = 'starkNet_upgradeAccContract',
GetStarkName = 'starkNet_getStarkName',
ReadContract = 'starkNet_getValue',
GetStoredErc20Tokens = 'starkNet_getStoredErc20Tokens',
}
// RpcMethod that are allowed to be called by any origin
const publicPermissions = [
RpcMethod.ExtractPublicKey,
RpcMethod.GetCurrentNetwork,
RpcMethod.GetStoredNetworks,
RpcMethod.SwitchNetwork,
RpcMethod.AddErc20Token,
RpcMethod.RecoverAccounts,
RpcMethod.ExecuteTxn,
RpcMethod.DeclareContract,
RpcMethod.GetDeploymentData,
RpcMethod.SignMessage,
RpcMethod.SignTransaction,
RpcMethod.SignDeclareTransaction,
RpcMethod.SignDeployAccountTransaction,
RpcMethod.CreateAccount,
RpcMethod.DisplayPrivateKey,
RpcMethod.GetErc20TokenBalance,
RpcMethod.GetTransactionStatus,
RpcMethod.EstimateFee,
RpcMethod.VerifySignedMessage,
];
// RpcMethod that are restricted to be called by wallet UI origins
const walletUIDappPermissions = publicPermissions.concat([
RpcMethod.DeployCario0Account,
RpcMethod.ListAccounts,
RpcMethod.GetTransactions,
RpcMethod.UpgradeAccContract,
RpcMethod.GetStarkName,
RpcMethod.ReadContract,
RpcMethod.GetStoredErc20Tokens,
]);

const publicPermissionsSet = new Set(publicPermissions);
const walletUIDappPermissionsSet = new Set(walletUIDappPermissions);

const walletUIDappOrigins = [
'http://localhost:3000',
'https://snaps.consensys.io',
'https://dev.snaps.consensys.io',
'https://staging.snaps.consensys.io',
];

export const originPermissions = new Map<string, Set<string>>([]);
for (const origin of walletUIDappOrigins) {
originPermissions.set(origin, walletUIDappPermissionsSet);
}
originPermissions.set('*', publicPermissionsSet);

/**
* Validate the origin and method pair.
* If the origin is not found or the method is not allowed, throw an error.
*
* @param origin - The origin of the request.
* @param method - The method of the request.
* @throws {UnauthorizedError} If the origin is not found or the method is not allowed.
*/
export function validateOrigin(origin: string, method: string): void {
if (!origin) {
// eslint-disable-next-line @typescript-eslint/no-throw-literal
throw new UnauthorizedError('Origin not found');
}
// As public permissions are a subset of wallet UI Dapp permissions,
// If the origin and method pair are not in the wallet UI Dapp permissions,
// then fallback and validate whether it hits the common permission.
if (
!originPermissions.get(origin)?.has(method) &&
!originPermissions.get('*')?.has(method)
) {
// eslint-disable-next-line @typescript-eslint/no-throw-literal
throw new UnauthorizedError(`Permission denied`);
}
}

0 comments on commit 1e80070

Please sign in to comment.