Skip to content

Commit

Permalink
Add validateAccessToken function to providers to allow verifying the …
Browse files Browse the repository at this point in the history
…access token received from the providers
  • Loading branch information
nkshah2 committed Sep 25, 2023
1 parent 7b98039 commit f662e27
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 0 deletions.
12 changes: 12 additions & 0 deletions lib/build/recipe/thirdparty/providers/custom.js
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,18 @@ function NewProvider(input) {
);
rawUserInfoFromProvider.fromUserInfoAPI = userInfoFromAccessToken;
}
/**
* This is intentionally not part of the above if block. This is because the user may want to validate the access
* token payload even if the user info API has not been provided by the provider. In this case they would get an
* empty object and they can fail if they always expect a non-empty object.
*/
if (impl.config.validateAccessToken !== undefined) {
await impl.config.validateAccessToken({
accessToken: accessToken,
clientConfig: impl.config,
userContext,
});
}
const userInfoResult = getSupertokensUserInfoResultFromRawUserInfo(impl.config, rawUserInfoFromProvider);
return {
thirdPartyUserId: userInfoResult.thirdPartyUserId,
Expand Down
27 changes: 27 additions & 0 deletions lib/build/recipe/thirdparty/providers/github.js
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,33 @@ function Github(input) {
if (input.config.tokenEndpoint === undefined) {
input.config.tokenEndpoint = "https://github.com/login/oauth/access_token";
}
if (input.config.validateAccessToken === undefined) {
input.config.validateAccessToken = async ({ accessToken, clientConfig }) => {
const basicAuthToken = Buffer.from(
`${clientConfig.clientId}:${clientConfig.clientSecret === undefined ? "" : clientConfig.clientSecret}`
).toString("base64");
const applicationsResponse = await cross_fetch_1.default(
`https://api.github.com/applications/${clientConfig.clientId}/token`,
{
headers: {
Authorization: `Basic ${basicAuthToken}`,
"Content-Type": "application/json",
},
method: "POST",
body: JSON.stringify({
access_token: accessToken,
}),
}
);
if (applicationsResponse.status !== 200) {
throw new Error("Invalid access token");
}
const body = await applicationsResponse.json();
if (body.app === undefined || body.app.client_id !== clientConfig.clientId) {
throw new Error("Access token does not belong to your application");
}
};
}
const oOverride = input.override;
input.override = function (originalImplementation) {
const oGetConfig = originalImplementation.getConfigForClientType;
Expand Down
7 changes: 7 additions & 0 deletions lib/build/recipe/thirdparty/providers/google.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ function Google(input) {
{ included_grant_scopes: "true", access_type: "offline" },
input.config.authorizationEndpointQueryParams
);
// if (input.config.validateAccessToken === undefined) {
// input.config.validateAccessToken = async ({ accessTokenPayload, clientConfig }) => {
// if (accessTokenPayload.aud !== clientConfig.clientId) {
// throw Error("accessTokenPayload.aud does not match clientId");
// }
// };
// }
const oOverride = input.override;
input.override = function (originalImplementation) {
const oGetConfig = originalImplementation.getConfigForClientType;
Expand Down
15 changes: 15 additions & 0 deletions lib/build/recipe/thirdparty/types.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,21 @@ declare type CommonProviderConfig = {
clientConfig: ProviderConfigForClientType;
userContext: any;
}) => Promise<void>;
/**
* This function is responsible for validating the access token received from the third party provider.
* This check can include checking the expiry of the access token, checking the audience of the access token, etc.
*
* This function should throw an error if the access token should be considered invalid, or return nothing if it is valid
*
* @param input.accessToken The access token to be validated
* @param input.clientConfig The configuration provided for the third party provider when initialising SuperTokens
* @param input.userContext Refer to https://supertokens.com/docs/thirdparty/advanced-customizations/user-context
*/
validateAccessToken?: (input: {
accessToken: string;
clientConfig: ProviderConfigForClientType;
userContext: any;
}) => Promise<void>;
requireEmail?: boolean;
generateFakeEmail?: (input: { thirdPartyUserId: string; tenantId: string; userContext: any }) => Promise<string>;
};
Expand Down
13 changes: 13 additions & 0 deletions lib/ts/recipe/thirdparty/providers/custom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,19 @@ export default function NewProvider(input: ProviderInput): TypeProvider {
rawUserInfoFromProvider.fromUserInfoAPI = userInfoFromAccessToken;
}

/**
* This is intentionally not part of the above if block. This is because the user may want to validate the access
* token payload even if the user info API has not been provided by the provider. In this case they would get an
* empty object and they can fail if they always expect a non-empty object.
*/
if (impl.config.validateAccessToken !== undefined) {
await impl.config.validateAccessToken({
accessToken: accessToken,
clientConfig: impl.config,
userContext,
});
}

const userInfoResult = getSupertokensUserInfoResultFromRawUserInfo(impl.config, rawUserInfoFromProvider);

return {
Expand Down
32 changes: 32 additions & 0 deletions lib/ts/recipe/thirdparty/providers/github.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,38 @@ export default function Github(input: ProviderInput): TypeProvider {
input.config.tokenEndpoint = "https://github.com/login/oauth/access_token";
}

if (input.config.validateAccessToken === undefined) {
input.config.validateAccessToken = async ({ accessToken, clientConfig }) => {
const basicAuthToken = Buffer.from(
`${clientConfig.clientId}:${clientConfig.clientSecret === undefined ? "" : clientConfig.clientSecret}`
).toString("base64");

const applicationsResponse = await fetch(
`https://api.github.com/applications/${clientConfig.clientId}/token`,
{
headers: {
Authorization: `Basic ${basicAuthToken}`,
"Content-Type": "application/json",
},
method: "POST",
body: JSON.stringify({
access_token: accessToken,
}),
}
);

if (applicationsResponse.status !== 200) {
throw new Error("Invalid access token");
}

const body = await applicationsResponse.json();

if (body.app === undefined || body.app.client_id !== clientConfig.clientId) {
throw new Error("Access token does not belong to your application");
}
};
}

const oOverride = input.override;

input.override = function (originalImplementation) {
Expand Down
8 changes: 8 additions & 0 deletions lib/ts/recipe/thirdparty/providers/google.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ export default function Google(input: ProviderInput): TypeProvider {
...input.config.authorizationEndpointQueryParams,
};

// if (input.config.validateAccessToken === undefined) {
// input.config.validateAccessToken = async ({ accessTokenPayload, clientConfig }) => {
// if (accessTokenPayload.aud !== clientConfig.clientId) {
// throw Error("accessTokenPayload.aud does not match clientId");
// }
// };
// }

const oOverride = input.override;

input.override = function (originalImplementation) {
Expand Down
15 changes: 15 additions & 0 deletions lib/ts/recipe/thirdparty/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ type CommonProviderConfig = {
clientConfig: ProviderConfigForClientType;
userContext: any;
}) => Promise<void>;
/**
* This function is responsible for validating the access token received from the third party provider.
* This check can include checking the expiry of the access token, checking the audience of the access token, etc.
*
* This function should throw an error if the access token should be considered invalid, or return nothing if it is valid
*
* @param input.accessToken The access token to be validated
* @param input.clientConfig The configuration provided for the third party provider when initialising SuperTokens
* @param input.userContext Refer to https://supertokens.com/docs/thirdparty/advanced-customizations/user-context
*/
validateAccessToken?: (input: {
accessToken: string;
clientConfig: ProviderConfigForClientType;
userContext: any;
}) => Promise<void>;
requireEmail?: boolean;
generateFakeEmail?: (input: { thirdPartyUserId: string; tenantId: string; userContext: any }) => Promise<string>;
};
Expand Down

0 comments on commit f662e27

Please sign in to comment.