Skip to content

Commit

Permalink
fix(oidc): store OIDC user info in local DB (#302)
Browse files Browse the repository at this point in the history
  • Loading branch information
tamirGer authored Jan 20, 2024
1 parent 427b4c2 commit d596d55
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 52 deletions.
6 changes: 3 additions & 3 deletions pg.sql
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
set names 'utf8';
set session_replication_role = 'replica';

create table "user" ("id" serial primary key, "created_at" timestamptz(0) not null, "updated_at" timestamptz(0) not null, "email" varchar(255) not null, "password" varchar(255) not null, "first_name" varchar(255) not null, "last_name" varchar(255) not null, "is_admin" bool not null, "photo" bytea null, "company" varchar(255) not null, "card_number" varchar(255) not null, "phone_number" varchar(255) not null);
create table "user" ("id" serial primary key, "created_at" timestamptz(0) not null, "updated_at" timestamptz(0) not null, "email" varchar(255) not null, "password" varchar(255) not null, "first_name" varchar(255) not null, "last_name" varchar(255) not null, "is_admin" bool not null, "photo" bytea null, "company" varchar(255) not null, "card_number" varchar(255) not null, "phone_number" varchar(255) not null, "is_basic" boolean not null);

create table "testimonial" ("id" serial primary key, "created_at" timestamptz(0) not null, "updated_at" timestamptz(0) not null, "name" varchar(255) not null, "title" varchar(255) not null, "message" varchar(255) not null);

create table "product" ("id" serial primary key, "created_at" timestamptz(0) not null default now(), "category" varchar(255) not null, "photo_url" varchar(255) not null, "name" varchar(255) not null, "description" varchar(255) null, "views_count" int DEFAULT 0);

set session_replication_role = 'origin';
--password is admin
INSERT INTO "user" (created_at, updated_at, email, password, first_name, last_name, is_admin, photo, company, card_number, phone_number) VALUES (now(), now(), 'admin', '$2b$10$BBJjmVNNdyEgv7pV/zQR9u/ssIuwZsdDJbowW/Dgp28uws3GmO0Ky', 'admin', 'admin', true, null, 'Brightsec', '1234 5678 9012 3456', '+1 234 567 890');
INSERT INTO "user" (created_at, updated_at, email, password, first_name, last_name, is_admin, photo, company, card_number, phone_number) VALUES (now(), now(), 'user', '$2b$10$edsq4aqzAHnrJu68t8GS2.v0Z7hJSstAo7wBBDmmbpjYGxMMTYpVi', 'user', 'user', false, null, 'Brightsec', '1234 5678 9012 3456', '+1 234 567 890');
INSERT INTO "user" (created_at, updated_at, email, password, first_name, last_name, is_admin, photo, company, card_number, phone_number, is_basic) VALUES (now(), now(), 'admin', '$2b$10$BBJjmVNNdyEgv7pV/zQR9u/ssIuwZsdDJbowW/Dgp28uws3GmO0Ky', 'admin', 'admin', true, null, 'Brightsec', '1234 5678 9012 3456', '+1 234 567 890', true);
INSERT INTO "user" (created_at, updated_at, email, password, first_name, last_name, is_admin, photo, company, card_number, phone_number, is_basic) VALUES (now(), now(), 'user', '$2b$10$edsq4aqzAHnrJu68t8GS2.v0Z7hJSstAo7wBBDmmbpjYGxMMTYpVi', 'user', 'user', false, null, 'Brightsec', '1234 5678 9012 3456', '+1 234 567 890', true);

--insert default products into the table
INSERT INTO "product" ("category", "photo_url", "name", "description") VALUES ('Healing', '/api/file?path=config/products/crystals/amethyst.jpg&type=image/jpg', 'Amethyst', 'a violet variety of quartz');
Expand Down
26 changes: 17 additions & 9 deletions src/auth/auth.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import {
BadRequestException,
Body,
Controller,
ForbiddenException,
Get,
HttpStatus,
InternalServerErrorException,
Expand Down Expand Up @@ -131,7 +132,7 @@ export class AuthController {
if (req.op === FormMode.OIDC) {
loginData = await this.loginOidc(req);
} else {
loginData = await this.login(req);
loginData = await this.loginBasic(req);
}

const { token, ...loginResponse } = loginData;
Expand Down Expand Up @@ -264,7 +265,7 @@ export class AuthController {
@Res({ passthrough: true }) res: FastifyReply,
): Promise<LoginResponse> {
this.logger.debug('Call loginWithKIDSqlJwt');
const profile = await this.login(req);
const profile = await this.loginBasic(req);

res.header(
'authorization',
Expand Down Expand Up @@ -324,7 +325,7 @@ export class AuthController {
@Res({ passthrough: true }) res: FastifyReply,
): Promise<LoginResponse> {
this.logger.debug('Call loginWithKIDSqlJwt');
const profile = await this.login(req);
const profile = await this.loginBasic(req);

res.header(
'authorization',
Expand Down Expand Up @@ -384,7 +385,7 @@ export class AuthController {
@Res({ passthrough: true }) res: FastifyReply,
): Promise<LoginResponse> {
this.logger.debug('Call loginWithJKUJwt');
const profile = await this.login(req);
const profile = await this.loginBasic(req);

res.header(
'authorization',
Expand Down Expand Up @@ -444,7 +445,7 @@ export class AuthController {
@Res({ passthrough: true }) res: FastifyReply,
): Promise<LoginResponse> {
this.logger.debug('Call loginWithJWKJwt');
const profile = await this.login(req);
const profile = await this.loginBasic(req);

res.header(
'authorization',
Expand Down Expand Up @@ -504,7 +505,7 @@ export class AuthController {
@Res({ passthrough: true }) res: FastifyReply,
): Promise<LoginResponse> {
this.logger.debug('Call loginWithX5CJwt');
const profile = await this.login(req);
const profile = await this.loginBasic(req);

res.header(
'authorization',
Expand Down Expand Up @@ -564,7 +565,7 @@ export class AuthController {
@Res({ passthrough: true }) res: FastifyReply,
): Promise<LoginResponse> {
this.logger.debug('Call loginWithX5UJwt');
const profile = await this.login(req);
const profile = await this.loginBasic(req);

res.header(
'Authorization',
Expand Down Expand Up @@ -624,7 +625,7 @@ export class AuthController {
@Res({ passthrough: true }) res: FastifyReply,
): Promise<LoginResponse> {
this.logger.debug('Call loginWithHMACJwt');
const profile = await this.login(req);
const profile = await this.loginBasic(req);

res.header(
'authorization',
Expand Down Expand Up @@ -690,7 +691,7 @@ export class AuthController {
}
}

private async login(req: LoginRequest): Promise<LoginData> {
private async loginBasic(req: LoginRequest): Promise<LoginData> {
let user: User;

try {
Expand All @@ -709,6 +710,13 @@ export class AuthController {
});
}

if (!user.isBasic) {
throw new ForbiddenException({
error: 'Invalid authentication method for this user',
location: __filename,
});
}

const token = await this.authService.createToken(
{
user: user.email,
Expand Down
56 changes: 30 additions & 26 deletions src/auth/auth.guard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import { GqlContextType, GqlExecutionContext } from '@nestjs/graphql';
@Injectable()
export class AuthGuard implements CanActivate {
private static readonly AUTH_HEADER = 'authorization';
private static readonly BEARER_PREFIX = 'bearer';
private readonly logger = new Logger(AuthGuard.name);

constructor(
Expand All @@ -25,8 +26,13 @@ export class AuthGuard implements CanActivate {
try {
this.logger.debug('Called canActivate');
const request = this.getRequest(context);
const token = this.extractToken(request);

return !!(await this.verifyToken(request, context));
if (!token) {
return false;
}

return await this.verifyToken(token, context);
} catch (err) {
this.logger.debug(`Failed to validate token: ${err.message}`);
throw new UnauthorizedException({
Expand All @@ -36,48 +42,46 @@ export class AuthGuard implements CanActivate {
}
}

private getRequest(context: ExecutionContext): FastifyRequest {
return context.getType<GqlContextType>() === 'graphql'
? GqlExecutionContext.create(context).getContext().req
: context.switchToHttp().getRequest();
}

private async verifyToken(
request: FastifyRequest,
context: ExecutionContext,
): Promise<boolean> {
private extractToken(request: FastifyRequest): string | undefined {
let token = request.headers[AuthGuard.AUTH_HEADER];

if (!token?.length) {
token = request.cookies[AuthGuard.AUTH_HEADER];
}

if (this.checkIsBearer(token)) {
token = token.substring(7);
token = token.substring(AuthGuard.BEARER_PREFIX.length).trim();
}

if (!token?.length) {
return false;
}
return token?.length ? token : undefined;
}

private getRequest(context: ExecutionContext): FastifyRequest {
return context.getType<GqlContextType>() === 'graphql'
? GqlExecutionContext.create(context).getContext().req
: context.switchToHttp().getRequest();
}

private async verifyToken(
token: string,
context: ExecutionContext,
): Promise<boolean> {
const processorType = this.reflector.get<JwtProcessorType>(
JwTypeMetadataField,
context.getHandler(),
);

return this.authService.validateToken(
token,
processorType ?? JwtProcessorType.BEARER,
);
try {
return await this.authService.validateToken(token, processorType);
} catch (err) {
return this.authService.validateToken(token, JwtProcessorType.BEARER);
}
}

private checkIsBearer(bearer: string): boolean {
if (!bearer || bearer.length < 10) {
return false;
}

const prefix = bearer.substring(0, 7).toLowerCase();

return prefix === 'bearer ';
return (
!!bearer &&
bearer.toLowerCase().startsWith(AuthGuard.BEARER_PREFIX.toLowerCase())
);
}
}
3 changes: 2 additions & 1 deletion src/keycloak/keycloak.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,9 @@ export class KeyCloakService implements OnModuleInit {
);
}

return new Map<string, string>(
const jwks = new Map<string, string>(
data.keys.map((key: JWK & { kid: string }) => [key.kid, jwkToPem(key)]),
);
return jwks;
}
}
3 changes: 3 additions & 0 deletions src/model/user.entity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,7 @@ export class User extends Base {

@Property()
phoneNumber: string;

@Property()
isBasic: boolean;
}
7 changes: 6 additions & 1 deletion src/users/api/CreateUserRequest.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import { ApiProperty } from '@nestjs/swagger';
import { UserDto } from './UserDto';

export enum SignupMode {
BASIC = 'basic',
OIDC = 'oidc',
}

export class CreateUserRequest extends UserDto {
@ApiProperty()
company: string;
Expand All @@ -10,5 +15,5 @@ export class CreateUserRequest extends UserDto {
phoneNumber: string;
@ApiProperty()
password: string;
op: string;
op: SignupMode;
}
7 changes: 4 additions & 3 deletions src/users/api/UserDto.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ export class UserDto {
@Expose({ groups: [FULL_USER_INFO] })
createdAt: Date;

constructor(params: {
[P in keyof UserDto]: UserDto[P];
}) {
@Exclude()
isBasic: boolean;

constructor(params: UserDto) {
Object.assign(this, params);
}
}
44 changes: 36 additions & 8 deletions src/users/users.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import {
ApiTags,
ApiUnauthorizedResponse,
} from '@nestjs/swagger';
import { CreateUserRequest } from './api/CreateUserRequest';
import { CreateUserRequest, SignupMode } from './api/CreateUserRequest';
import { UserDto } from './api/UserDto';
import { LdapQueryHandler } from './ldap.query.handler';
import { UsersService } from './users.service';
Expand Down Expand Up @@ -346,17 +346,18 @@ export class UsersController {
try {
this.logger.debug(`Create a basic user: ${user}`);

const userExists = await this.usersService.findByEmail(user.email);
const userExists = await this.doesUserExist(user);
if (userExists) {
throw new HttpException('User already exists', 409);
throw new HttpException('User already exists', HttpStatus.CONFLICT);
}

return new UserDto(
await this.usersService.createUser(user, user.op === SignupMode.BASIC),
);
} catch (err) {
if (err.status === 404) {
return new UserDto(await this.usersService.createUser(user));
}
throw new HttpException(
err.message ?? 'Something went wrong',
err.status ?? 500,
err.status ?? HttpStatus.INTERNAL_SERVER_ERROR,
);
}
}
Expand All @@ -381,14 +382,24 @@ export class UsersController {
try {
this.logger.debug(`Create a OIDC user: ${user}`);

return new UserDto(
const userExists = await this.doesUserExist(user);

if (userExists) {
throw new HttpException('User already exists', HttpStatus.CONFLICT);
}

const keycloakUser = new UserDto(
await this.keyCloakService.registerUser({
email: user.email,
firstName: user.firstName,
lastName: user.lastName,
password: user.password,
}),
);

this.createUser(user);

return keycloakUser;
} catch (err) {
throw new HttpException(
err.response.data ?? 'Something went wrong',
Expand Down Expand Up @@ -559,4 +570,21 @@ export class UsersController {
).toString(),
).user;
}

private async doesUserExist(user: UserDto): Promise<boolean> {
try {
const userExists = await this.usersService.findByEmail(user.email);
if (userExists) {
return true;
}
} catch (err) {
if (err.status === HttpStatus.NOT_FOUND) {
return false;
}
throw new HttpException(
err.message ?? 'Something went wrong',
err.status ?? HttpStatus.INTERNAL_SERVER_ERROR,
);
}
}
}
3 changes: 2 additions & 1 deletion src/users/users.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export class UsersService {
private readonly usersRepository: EntityRepository<User>,
) {}

async createUser(user: UserDto): Promise<User> {
async createUser(user: UserDto, isBasicUser: boolean = true): Promise<User> {
this.log.debug(`Called createUser`);

const u = new User();
Expand All @@ -35,6 +35,7 @@ export class UsersService {
u.cardNumber = user.cardNumber;
u.phoneNumber = user.phoneNumber;
u.password = await hashPassword(user.password);
u.isBasic = isBasicUser;

await this.usersRepository.persistAndFlush(u);
this.log.debug(`Saved new user`);
Expand Down

0 comments on commit d596d55

Please sign in to comment.