From 9ff8fb82d986340c93d33c9935519beb6c525695 Mon Sep 17 00:00:00 2001 From: Nemi Shah Date: Fri, 29 Sep 2023 14:09:34 +0530 Subject: [PATCH] feat: Add validateAccessToken function to providers (#701) * Add validateAccessToken function to providers to allow verifying the access token received from the providers * Update CHANGELOG * Update version and CHANGELOG * Update CHANGELOG * Refactor based on PR comments * Add tests --- CHANGELOG.md | 6 + .../recipe/thirdparty/providers/custom.js | 7 + .../recipe/thirdparty/providers/github.js | 27 +++ lib/build/recipe/thirdparty/types.d.ts | 15 ++ lib/build/version.d.ts | 2 +- lib/build/version.js | 2 +- lib/ts/recipe/thirdparty/providers/custom.ts | 8 + lib/ts/recipe/thirdparty/providers/github.ts | 32 +++ lib/ts/recipe/thirdparty/types.ts | 15 ++ lib/ts/version.ts | 2 +- package-lock.json | 4 +- package.json | 2 +- test/thirdparty/provider.test.js | 192 ++++++++++++++++++ 13 files changed, 308 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c719054e5..19a135b7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html) +## [16.2.0] - 2023-09-29 + +### Changes + +- Added `validateAccessToken` to the configuration for social login providers, this function allows you to verify the access token returned by the social provider. If you are using Github as a provider, there is a default implmentation provided for this function. + ## [16.1.0] - 2023-09-26 - Added `twitter` as a built-in thirdparty provider diff --git a/lib/build/recipe/thirdparty/providers/custom.js b/lib/build/recipe/thirdparty/providers/custom.js index bfb600b48..4d4a4bd2b 100644 --- a/lib/build/recipe/thirdparty/providers/custom.js +++ b/lib/build/recipe/thirdparty/providers/custom.js @@ -278,6 +278,13 @@ function NewProvider(input) { }); } } + if (impl.config.validateAccessToken !== undefined && accessToken !== undefined) { + await impl.config.validateAccessToken({ + accessToken: accessToken, + clientConfig: impl.config, + userContext, + }); + } if (accessToken && impl.config.userInfoEndpoint !== undefined) { const headers = { Authorization: "Bearer " + accessToken, diff --git a/lib/build/recipe/thirdparty/providers/github.js b/lib/build/recipe/thirdparty/providers/github.js index e7fd91dbd..14b22ce21 100644 --- a/lib/build/recipe/thirdparty/providers/github.js +++ b/lib/build/recipe/thirdparty/providers/github.js @@ -53,6 +53,33 @@ function Github(input) { if (input.config.tokenEndpoint === undefined) { input.config.tokenEndpoint = "https://github.com/login/oauth/access_token"; } + if (input.config.validateAccessToken === undefined) { + input.config.validateAccessToken = async ({ accessToken, clientConfig }) => { + const basicAuthToken = Buffer.from( + `${clientConfig.clientId}:${clientConfig.clientSecret === undefined ? "" : clientConfig.clientSecret}` + ).toString("base64"); + const applicationsResponse = await cross_fetch_1.default( + `https://api.github.com/applications/${clientConfig.clientId}/token`, + { + headers: { + Authorization: `Basic ${basicAuthToken}`, + "Content-Type": "application/json", + }, + method: "POST", + body: JSON.stringify({ + access_token: accessToken, + }), + } + ); + if (applicationsResponse.status !== 200) { + throw new Error("Invalid access token"); + } + const body = await applicationsResponse.json(); + if (body.app === undefined || body.app.client_id !== clientConfig.clientId) { + throw new Error("Access token does not belong to your application"); + } + }; + } const oOverride = input.override; input.override = function (originalImplementation) { const oGetConfig = originalImplementation.getConfigForClientType; diff --git a/lib/build/recipe/thirdparty/types.d.ts b/lib/build/recipe/thirdparty/types.d.ts index 7473fc040..ada01a8d4 100644 --- a/lib/build/recipe/thirdparty/types.d.ts +++ b/lib/build/recipe/thirdparty/types.d.ts @@ -70,6 +70,21 @@ declare type CommonProviderConfig = { clientConfig: ProviderConfigForClientType; userContext: any; }) => Promise; + /** + * This function is responsible for validating the access token received from the third party provider. + * This check can include checking the expiry of the access token, checking the audience of the access token, etc. + * + * This function should throw an error if the access token should be considered invalid, or return nothing if it is valid + * + * @param input.accessToken The access token to be validated + * @param input.clientConfig The configuration provided for the third party provider when initialising SuperTokens + * @param input.userContext Refer to https://supertokens.com/docs/thirdparty/advanced-customizations/user-context + */ + validateAccessToken?: (input: { + accessToken: string; + clientConfig: ProviderConfigForClientType; + userContext: any; + }) => Promise; requireEmail?: boolean; generateFakeEmail?: (input: { thirdPartyUserId: string; tenantId: string; userContext: any }) => Promise; }; diff --git a/lib/build/version.d.ts b/lib/build/version.d.ts index 4ad7264a6..b4855d369 100644 --- a/lib/build/version.d.ts +++ b/lib/build/version.d.ts @@ -1,4 +1,4 @@ // @ts-nocheck -export declare const version = "16.1.0"; +export declare const version = "16.2.0"; export declare const cdiSupported: string[]; export declare const dashboardVersion = "0.8"; diff --git a/lib/build/version.js b/lib/build/version.js index ba568973f..f924b34b1 100644 --- a/lib/build/version.js +++ b/lib/build/version.js @@ -15,7 +15,7 @@ exports.dashboardVersion = exports.cdiSupported = exports.version = void 0; * License for the specific language governing permissions and limitations * under the License. */ -exports.version = "16.1.0"; +exports.version = "16.2.0"; exports.cdiSupported = ["4.0"]; // Note: The actual script import for dashboard uses v{DASHBOARD_VERSION} exports.dashboardVersion = "0.8"; diff --git a/lib/ts/recipe/thirdparty/providers/custom.ts b/lib/ts/recipe/thirdparty/providers/custom.ts index 69daa03fc..2564049c2 100644 --- a/lib/ts/recipe/thirdparty/providers/custom.ts +++ b/lib/ts/recipe/thirdparty/providers/custom.ts @@ -305,6 +305,14 @@ export default function NewProvider(input: ProviderInput): TypeProvider { } } + if (impl.config.validateAccessToken !== undefined && accessToken !== undefined) { + await impl.config.validateAccessToken({ + accessToken: accessToken, + clientConfig: impl.config, + userContext, + }); + } + if (accessToken && impl.config.userInfoEndpoint !== undefined) { const headers: { [key: string]: string } = { Authorization: "Bearer " + accessToken, diff --git a/lib/ts/recipe/thirdparty/providers/github.ts b/lib/ts/recipe/thirdparty/providers/github.ts index cc733154f..e9307658d 100644 --- a/lib/ts/recipe/thirdparty/providers/github.ts +++ b/lib/ts/recipe/thirdparty/providers/github.ts @@ -58,6 +58,38 @@ export default function Github(input: ProviderInput): TypeProvider { input.config.tokenEndpoint = "https://github.com/login/oauth/access_token"; } + if (input.config.validateAccessToken === undefined) { + input.config.validateAccessToken = async ({ accessToken, clientConfig }) => { + const basicAuthToken = Buffer.from( + `${clientConfig.clientId}:${clientConfig.clientSecret === undefined ? "" : clientConfig.clientSecret}` + ).toString("base64"); + + const applicationsResponse = await fetch( + `https://api.github.com/applications/${clientConfig.clientId}/token`, + { + headers: { + Authorization: `Basic ${basicAuthToken}`, + "Content-Type": "application/json", + }, + method: "POST", + body: JSON.stringify({ + access_token: accessToken, + }), + } + ); + + if (applicationsResponse.status !== 200) { + throw new Error("Invalid access token"); + } + + const body = await applicationsResponse.json(); + + if (body.app === undefined || body.app.client_id !== clientConfig.clientId) { + throw new Error("Access token does not belong to your application"); + } + }; + } + const oOverride = input.override; input.override = function (originalImplementation) { diff --git a/lib/ts/recipe/thirdparty/types.ts b/lib/ts/recipe/thirdparty/types.ts index f27bed3ed..edbc7106c 100644 --- a/lib/ts/recipe/thirdparty/types.ts +++ b/lib/ts/recipe/thirdparty/types.ts @@ -67,6 +67,21 @@ type CommonProviderConfig = { clientConfig: ProviderConfigForClientType; userContext: any; }) => Promise; + /** + * This function is responsible for validating the access token received from the third party provider. + * This check can include checking the expiry of the access token, checking the audience of the access token, etc. + * + * This function should throw an error if the access token should be considered invalid, or return nothing if it is valid + * + * @param input.accessToken The access token to be validated + * @param input.clientConfig The configuration provided for the third party provider when initialising SuperTokens + * @param input.userContext Refer to https://supertokens.com/docs/thirdparty/advanced-customizations/user-context + */ + validateAccessToken?: (input: { + accessToken: string; + clientConfig: ProviderConfigForClientType; + userContext: any; + }) => Promise; requireEmail?: boolean; generateFakeEmail?: (input: { thirdPartyUserId: string; tenantId: string; userContext: any }) => Promise; }; diff --git a/lib/ts/version.ts b/lib/ts/version.ts index c307a287a..b7ad51705 100644 --- a/lib/ts/version.ts +++ b/lib/ts/version.ts @@ -12,7 +12,7 @@ * License for the specific language governing permissions and limitations * under the License. */ -export const version = "16.1.0"; +export const version = "16.2.0"; export const cdiSupported = ["4.0"]; diff --git a/package-lock.json b/package-lock.json index 11e305f6b..4f03f222c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "supertokens-node", - "version": "16.1.0", + "version": "16.2.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "supertokens-node", - "version": "16.0.0", + "version": "16.2.0", "license": "Apache-2.0", "dependencies": { "content-type": "^1.0.5", diff --git a/package.json b/package.json index 4d99b0097..c45de4c30 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "supertokens-node", - "version": "16.1.0", + "version": "16.2.0", "description": "NodeJS driver for SuperTokens core", "main": "index.js", "scripts": { diff --git a/test/thirdparty/provider.test.js b/test/thirdparty/provider.test.js index e23e8a231..0202e7eac 100644 --- a/test/thirdparty/provider.test.js +++ b/test/thirdparty/provider.test.js @@ -18,8 +18,11 @@ let assert = require("assert"); let { ProcessState } = require("../../lib/build/processState"); let ThirdPartyRecipe = require("../../lib/build/recipe/thirdparty/recipe").default; let ThirdParty = require("../../lib/build/recipe/thirdparty"); +let Session = require("../../lib/build/recipe/session"); let { middleware, errorHandler } = require("../../framework/express"); let nock = require("nock"); +let express = require("express"); +const request = require("supertest"); const privateKey = "-----BEGIN PRIVATE KEY-----\nMIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgu8gXs+XYkqXD6Ala9Sf/iJXzhbwcoG5dMh1OonpdJUmgCgYIKoZIzj0DAQehRANCAASfrvlFbFCYqn3I2zeknYXLwtH30JuOKestDbSfZYxZNMqhF/OzdZFTV0zc5u5s3eN+oCWbnvl0hM+9IW0UlkdA\n-----END PRIVATE KEY-----"; @@ -1052,4 +1055,193 @@ describe(`providerTest: ${printPath("[test/thirdparty/provider.test.js]")}`, fun ); } }); + + it("Test that sign in up fails if validateAccessToken throws", async function () { + const connectionURI = await startST(); + STExpress.init({ + supertokens: { + connectionURI, + }, + appInfo: { + apiDomain: "api.supertokens.io", + appName: "SuperTokens", + websiteDomain: "supertokens.io", + }, + recipeList: [ + ThirdParty.init({ + signInAndUpFeature: { + providers: [ + { + override: (original) => { + return { + ...original, + exchangeAuthCodeForOAuthTokens: async (input) => { + return { + access_token: "wrongaccesstoken", + id_token: "wrongidtoken", + }; + }, + }; + }, + config: { + thirdPartyId: "custom", + clients: [ + { + clientId: "test2", + clientSecret: "test-secret2", + }, + ], + validateAccessToken: async ({ accessToken }) => { + if (accessToken === "wrongaccesstoken") { + throw new Error("Invalid access token"); + } + + return; + }, + }, + }, + ], + }, + }), + ], + }); + + const app = express(); + + app.use(middleware()); + + app.use(errorHandler()); + + // default error handler for app + app.use(function (err, req, res, next) { + res.status(500).send(err.message); + }); + + let response = await new Promise((resolve) => + request(app) + .post("/auth/signinup") + .send({ + thirdPartyId: "custom", + redirectURIInfo: { + redirectURIOnProviderDashboard: "http://127.0.0.1/callback", + redirectURIQueryParams: { + code: "abcdefghj", + }, + }, + }) + .end((err, res) => { + if (err) { + resolve(undefined); + } else { + resolve({ + status: res.status, + message: res.text, + }); + } + }) + ); + + assert.strictEqual(response.status, 500); + assert.strictEqual(response.message, "Invalid access token"); + }); + + it("Test that sign in up works if validateAccessToken does not throw", async function () { + const connectionURI = await startST(); + STExpress.init({ + supertokens: { + connectionURI, + }, + appInfo: { + apiDomain: "api.supertokens.io", + appName: "SuperTokens", + websiteDomain: "supertokens.io", + }, + recipeList: [ + Session.init(), + ThirdParty.init({ + signInAndUpFeature: { + providers: [ + { + override: (original) => { + return { + ...original, + exchangeAuthCodeForOAuthTokens: async (input) => { + return { + access_token: "accesstoken", + id_token: "idtoken", + }; + }, + getUserInfo: async function ({ oAuthTokens }) { + const time = Date.now(); + return { + thirdPartyUserId: "" + time, + email: { + id: `johndoeprovidertest+${time}@supertokens.com`, + isVerified: true, + }, + }; + }, + }; + }, + config: { + thirdPartyId: "custom", + clients: [ + { + clientId: "test2", + clientSecret: "test-secret2", + }, + ], + validateAccessToken: async ({ accessToken }) => { + if (accessToken === "accesstoken") { + return; + } + + throw new Error("Unexpected access token"); + }, + }, + }, + ], + }, + }), + ], + }); + + const app = express(); + + app.use(middleware()); + + app.use(errorHandler()); + + // default error handler for app + app.use(function (err, req, res, next) { + res.status(500).send(err.message); + }); + + let response = await new Promise((resolve) => + request(app) + .post("/auth/signinup") + .send({ + thirdPartyId: "custom", + redirectURIInfo: { + redirectURIOnProviderDashboard: "http://127.0.0.1/callback", + redirectURIQueryParams: { + code: "abcdefghj", + }, + }, + }) + .end((err, res) => { + if (err) { + resolve(undefined); + } else { + resolve({ + status: res.status, + body: res.body, + }); + } + }) + ); + + assert.strictEqual(response.status, 200); + assert.strictEqual(response.body.status, "OK"); + }); });