Skip to content

Commit

Permalink
Merge pull request #6069 from logto-io/gao-org-jit-sso-impl
Browse files Browse the repository at this point in the history
feat(core): organization jit sso
  • Loading branch information
gao-sun authored Jun 21, 2024
2 parents 651a027 + 2cf30d2 commit c51eab1
Show file tree
Hide file tree
Showing 21 changed files with 394 additions and 250 deletions.
3 changes: 2 additions & 1 deletion packages/core/src/libraries/user.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ mockEsm('#src/utils/password.js', () => ({
}));

const { MockQueries } = await import('#src/test-utils/tenant.js');
const { encryptUserPassword, createUserLibrary } = await import('./user.js');
const { createUserLibrary } = await import('./user.js');
const { encryptUserPassword } = await import('./user.utils.js');

const hasUserWithId = jest.fn();
const updateUserById = jest.fn();
Expand Down
172 changes: 68 additions & 104 deletions packages/core/src/libraries/user.ts
Original file line number Diff line number Diff line change
@@ -1,83 +1,22 @@
import type { BindMfa, CreateUser, MfaVerification, Scope, User } from '@logto/schemas';
import { MfaFactor, RoleType, Users, UsersPasswordEncryptionMethod } from '@logto/schemas';
import type { BindMfa, CreateUser, Scope, User } from '@logto/schemas';
import { RoleType, Users, UsersPasswordEncryptionMethod } from '@logto/schemas';
import { generateStandardId, generateStandardShortId } from '@logto/shared';
import { deduplicateByKey, type Nullable } from '@silverhand/essentials';
import { condArray, deduplicateByKey, type Nullable } from '@silverhand/essentials';
import { argon2Verify, bcryptVerify, md5, sha1, sha256 } from 'hash-wasm';
import pRetry from 'p-retry';

import { buildInsertIntoWithPool } from '#src/database/insert-into.js';
import { EnvSet } from '#src/env-set/index.js';
import RequestError from '#src/errors/RequestError/index.js';
import { type JitOrganization } from '#src/queries/organization/email-domains.js';
import OrganizationQueries from '#src/queries/organization/index.js';
import { createUsersRolesQueries } from '#src/queries/users-roles.js';
import type Queries from '#src/tenants/Queries.js';
import assertThat from '#src/utils/assert-that.js';
import { encryptPassword } from '#src/utils/password.js';
import type { OmitAutoSetFields } from '#src/utils/sql.js';

export const encryptUserPassword = async (
password: string
): Promise<{
passwordEncrypted: string;
passwordEncryptionMethod: UsersPasswordEncryptionMethod;
}> => {
const passwordEncryptionMethod = UsersPasswordEncryptionMethod.Argon2i;
const passwordEncrypted = await encryptPassword(password, passwordEncryptionMethod);
import { convertBindMfaToMfaVerification, encryptUserPassword } from './user.utils.js';

return { passwordEncrypted, passwordEncryptionMethod };
};

/**
* Convert bindMfa to mfaVerification, add common fields like "id" and "createdAt"
* and transpile formats like "codes" to "code" for backup code
*/
const converBindMfaToMfaVerification = (bindMfa: BindMfa): MfaVerification => {
const { type } = bindMfa;
const base = {
id: generateStandardId(),
createdAt: new Date().toISOString(),
};

if (type === MfaFactor.BackupCode) {
const { codes } = bindMfa;

return {
...base,
type,
codes: codes.map((code) => ({ code })),
};
}

if (type === MfaFactor.TOTP) {
const { secret } = bindMfa;

return {
...base,
type,
key: secret,
};
}

const { credentialId, counter, publicKey, transports, agent } = bindMfa;
return {
...base,
type,
credentialId,
counter,
publicKey,
transports,
agent,
};
};

export type InsertUserResult = [
User,
{
/** The organizations and organization roles that the user has been provisioned into. */
organizations: readonly JitOrganization[];
},
];
export type InsertUserResult = [User];

export type UserLibrary = ReturnType<typeof createUserLibrary>;

Expand Down Expand Up @@ -143,43 +82,7 @@ export const createUserLibrary = (queries: Queries) => {
);
}

// TODO: If the user's email is not verified, we should not provision the user into any organization.
const provisionOrganizations = async (): Promise<readonly JitOrganization[]> => {
// Just-in-time organization provisioning
const userEmailDomain = data.primaryEmail?.split('@')[1];
if (userEmailDomain) {
const organizationQueries = new OrganizationQueries(connection);
const organizations = await organizationQueries.jit.emailDomains.getJitOrganizations(
userEmailDomain
);

if (organizations.length > 0) {
await organizationQueries.relations.users.insert(
...organizations.map(({ organizationId }) => ({
organizationId,
userId: user.id,
}))
);

const data = organizations.flatMap(({ organizationId, organizationRoleIds }) =>
organizationRoleIds.map((organizationRoleId) => ({
organizationId,
organizationRoleId,
userId: user.id,
}))
);
if (data.length > 0) {
await organizationQueries.relations.rolesUsers.insert(...data);
}

return organizations;
}
}

return [];
};

return [user, { organizations: await provisionOrganizations() }];
return [user];
});
};

Expand Down Expand Up @@ -261,7 +164,7 @@ export const createUserLibrary = (queries: Queries) => {
// TODO @sijie use jsonb array append
const { mfaVerifications } = await findUserById(userId);
await updateUserById(userId, {
mfaVerifications: [...mfaVerifications, converBindMfaToMfaVerification(payload)],
mfaVerifications: [...mfaVerifications, convertBindMfaToMfaVerification(payload)],
});
};

Expand Down Expand Up @@ -338,6 +241,66 @@ export const createUserLibrary = (queries: Queries) => {
const findUserSsoIdentities = async (userId: string) =>
userSsoIdentities.findUserSsoIdentitiesByUserId(userId);

type ProvisionOrganizationsParams =
| {
/** The user ID to provision organizations for. */
userId: string;
/** The user's email to determine JIT organizations. */
email: string;
/** The SSO connector ID to determine JIT organizations. */
ssoConnectorId?: undefined;
}
| {
/** The user ID to provision organizations for. */
userId: string;
/** The user's email to determine JIT organizations. */
email?: undefined;
/** The SSO connector ID to determine JIT organizations. */
ssoConnectorId: string;
};

// TODO: If the user's email is not verified, we should not provision the user into any organization.
/**
* Provision the user with JIT organizations and roles based on the user's email domain and the
* enterprise SSO connector.
*/
const provisionOrganizations = async ({
userId,
email,
ssoConnectorId,
}: ProvisionOrganizationsParams): Promise<readonly JitOrganization[]> => {
const userEmailDomain = email?.split('@')[1];
const jitOrganizations = condArray(
userEmailDomain &&
(await organizations.jit.emailDomains.getJitOrganizations(userEmailDomain)),
ssoConnectorId && (await organizations.jit.ssoConnectors.getJitOrganizations(ssoConnectorId))
);

if (jitOrganizations.length === 0) {
return [];
}

await organizations.relations.users.insert(
...jitOrganizations.map(({ organizationId }) => ({
organizationId,
userId,
}))
);

const data = jitOrganizations.flatMap(({ organizationId, organizationRoleIds }) =>
organizationRoleIds.map((organizationRoleId) => ({
organizationId,
organizationRoleId,
userId,
}))
);
if (data.length > 0) {
await organizations.relations.rolesUsers.insert(...data);
}

return jitOrganizations;
};

return {
generateUserId,
insertUser,
Expand All @@ -349,5 +312,6 @@ export const createUserLibrary = (queries: Queries) => {
verifyUserPassword,
signOutUser,
findUserSsoIdentities,
provisionOrganizations,
};
};
60 changes: 60 additions & 0 deletions packages/core/src/libraries/user.utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import type { BindMfa, MfaVerification } from '@logto/schemas';
import { MfaFactor, UsersPasswordEncryptionMethod } from '@logto/schemas';
import { generateStandardId } from '@logto/shared';

import { encryptPassword } from '#src/utils/password.js';

export const encryptUserPassword = async (
password: string
): Promise<{
passwordEncrypted: string;
passwordEncryptionMethod: UsersPasswordEncryptionMethod;
}> => {
const passwordEncryptionMethod = UsersPasswordEncryptionMethod.Argon2i;
const passwordEncrypted = await encryptPassword(password, passwordEncryptionMethod);

return { passwordEncrypted, passwordEncryptionMethod };
};

/**
* Convert bindMfa to mfaVerification, add common fields like "id" and "createdAt"
* and transpile formats like "codes" to "code" for backup code
*/
export const convertBindMfaToMfaVerification = (bindMfa: BindMfa): MfaVerification => {
const { type } = bindMfa;
const base = {
id: generateStandardId(),
createdAt: new Date().toISOString(),
};

if (type === MfaFactor.BackupCode) {
const { codes } = bindMfa;

return {
...base,
type,
codes: codes.map((code) => ({ code })),
};
}

if (type === MfaFactor.TOTP) {
const { secret } = bindMfa;

return {
...base,
type,
key: secret,
};
}

const { credentialId, counter, publicKey, transports, agent } = bindMfa;
return {
...base,
type,
credentialId,
counter,
publicKey,
transports,
agent,
};
};
6 changes: 5 additions & 1 deletion packages/core/src/queries/organization/email-domains.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ export class EmailDomainQueries {
* Given an email domain, return the organizations and organization roles that need to be
* provisioned.
*/
async getJitOrganizations(emailDomain: string): Promise<readonly JitOrganization[]> {
async getJitOrganizations(emailDomain?: string): Promise<readonly JitOrganization[]> {
if (!emailDomain) {
return [];
}

const { fields } = convertToIdentifiers(OrganizationJitEmailDomains, true);
const organizationJitRoles = convertToIdentifiers(OrganizationJitRoles, true);
return this.pool.any<JitOrganization>(sql`
Expand Down
10 changes: 2 additions & 8 deletions packages/core/src/queries/organization/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ import {
OrganizationJitRoles,
OrganizationApplicationRelations,
Applications,
OrganizationJitSsoConnectors,
SsoConnectors,
} from '@logto/schemas';
import { sql, type CommonQueryMethods } from '@silverhand/slonik';

Expand All @@ -36,6 +34,7 @@ import { conditionalSql, convertToIdentifiers } from '#src/utils/sql.js';

import { EmailDomainQueries } from './email-domains.js';
import { RoleUserRelationQueries } from './role-user-relations.js';
import { SsoConnectorQueries } from './sso-connectors.js';
import { UserRelationQueries } from './user-relations.js';

/**
Expand Down Expand Up @@ -311,12 +310,7 @@ export default class OrganizationQueries extends SchemaQueries<
Organizations,
OrganizationRoles
),
ssoConnectors: new TwoRelationsQueries(
this.pool,
OrganizationJitSsoConnectors.table,
Organizations,
SsoConnectors
),
ssoConnectors: new SsoConnectorQueries(this.pool),
};

constructor(pool: CommonQueryMethods) {
Expand Down
45 changes: 45 additions & 0 deletions packages/core/src/queries/organization/sso-connectors.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import {
OrganizationJitRoles,
OrganizationJitSsoConnectors,
Organizations,
SsoConnectors,
} from '@logto/schemas';
import { type CommonQueryMethods, sql } from '@silverhand/slonik';

import { TwoRelationsQueries } from '#src/utils/RelationQueries.js';
import { convertToIdentifiers } from '#src/utils/sql.js';

import { type JitOrganization } from './email-domains.js';

const { table, fields } = convertToIdentifiers(OrganizationJitSsoConnectors);

export class SsoConnectorQueries extends TwoRelationsQueries<
typeof Organizations,
typeof SsoConnectors
> {
constructor(pool: CommonQueryMethods) {
super(pool, OrganizationJitSsoConnectors.table, Organizations, SsoConnectors);
}

async getJitOrganizations(ssoConnectorId?: string): Promise<readonly JitOrganization[]> {
if (!ssoConnectorId) {
return [];
}

const { fields } = convertToIdentifiers(OrganizationJitSsoConnectors, true);
const organizationJitRoles = convertToIdentifiers(OrganizationJitRoles, true);
return this.pool.any<JitOrganization>(sql`
select
${fields.organizationId},
array_remove(
array_agg(${organizationJitRoles.fields.organizationRoleId}),
null
) as "organizationRoleIds"
from ${table}
left join ${organizationJitRoles.table}
on ${fields.organizationId} = ${organizationJitRoles.fields.organizationId}
where ${fields.ssoConnectorId} = ${ssoConnectorId}
group by ${fields.organizationId}
`);
}
}
2 changes: 1 addition & 1 deletion packages/core/src/routes-me/user.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { conditional, pick } from '@silverhand/essentials';
import { literal, object, string } from 'zod';

import RequestError from '#src/errors/RequestError/index.js';
import { encryptUserPassword } from '#src/libraries/user.js';
import { encryptUserPassword } from '#src/libraries/user.utils.js';
import koaGuard from '#src/middleware/koa-guard.js';
import assertThat from '#src/utils/assert-that.js';

Expand Down
Loading

0 comments on commit c51eab1

Please sign in to comment.