From 25d11ec3beb6426acb5a38efdc3b30227c6500ac Mon Sep 17 00:00:00 2001 From: Max Bischof <106820326+bischofmax@users.noreply.github.com> Date: Thu, 12 Dec 2024 11:48:03 +0100 Subject: [PATCH 1/2] BC-7906 - Replace WITH_TLDRAW2 with WITH_TLDRAW in ansible conditions (#44) --- ansible/roles/tldraw-server/tasks/main.yml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ansible/roles/tldraw-server/tasks/main.yml b/ansible/roles/tldraw-server/tasks/main.yml index 954c87c1..260de077 100644 --- a/ansible/roles/tldraw-server/tasks/main.yml +++ b/ansible/roles/tldraw-server/tasks/main.yml @@ -6,7 +6,7 @@ state: "{{ 'present' if WITH_BRANCH_MONGO_DB_MANAGEMENT is defined and WITH_BRANCH_MONGO_DB_MANAGEMENT|bool else 'absent'}}" when: - EXTERNAL_SECRETS_OPERATOR is defined and EXTERNAL_SECRETS_OPERATOR|bool - - WITH_TLDRAW2 is defined and WITH_TLDRAW2|bool + - WITH_TLDRAW is defined and WITH_TLDRAW|bool tags: - 1password @@ -17,7 +17,7 @@ template: onepassword.yml.j2 when: - ONEPASSWORD_OPERATOR is defined and ONEPASSWORD_OPERATOR|bool - - WITH_TLDRAW2 is defined and WITH_TLDRAW2|bool + - WITH_TLDRAW is defined and WITH_TLDRAW|bool tags: - 1password @@ -28,7 +28,7 @@ template: configmap.yml.j2 when: - ONEPASSWORD_OPERATOR is defined and ONEPASSWORD_OPERATOR|bool - - WITH_TLDRAW2 is defined and WITH_TLDRAW2|bool + - WITH_TLDRAW is defined and WITH_TLDRAW|bool tags: - configmap @@ -37,7 +37,7 @@ kubeconfig: ~/.kube/config namespace: "{{ NAMESPACE }}" template: deployment.yml.j2 - state: "{{ 'present' if WITH_TLDRAW2 else 'absent'}}" + state: "{{ 'present' if WITH_TLDRAW else 'absent'}}" tags: - deployment @@ -46,7 +46,7 @@ kubeconfig: ~/.kube/config namespace: "{{ NAMESPACE }}" template: worker-deployment.yml.j2 - state: "{{ 'present' if WITH_TLDRAW2 else 'absent'}}" + state: "{{ 'present' if WITH_TLDRAW else 'absent'}}" tags: - deployment @@ -55,7 +55,7 @@ kubeconfig: ~/.kube/config namespace: "{{ NAMESPACE }}" template: server-svc.yml.j2 - when: WITH_TLDRAW2 is defined and WITH_TLDRAW2|bool + when: WITH_TLDRAW is defined and WITH_TLDRAW|bool tags: - service @@ -64,7 +64,7 @@ kubeconfig: ~/.kube/config namespace: "{{ NAMESPACE }}" template: pod-monitor.yml.j2 - when: WITH_TLDRAW2 is defined and WITH_TLDRAW2|bool + when: WITH_TLDRAW is defined and WITH_TLDRAW|bool tags: - prometheus @@ -74,6 +74,6 @@ namespace: "{{ NAMESPACE }}" template: ingress.yml.j2 apply: yes - when: WITH_TLDRAW2 is defined and WITH_TLDRAW2|bool + when: WITH_TLDRAW is defined and WITH_TLDRAW|bool tags: - ingress From 75374c510c23c1bb8fef91c409de00f0b394dd7e Mon Sep 17 00:00:00 2001 From: Cedric Evers <12080057+CeEv@users.noreply.github.com> Date: Tue, 17 Dec 2024 07:24:23 +0100 Subject: [PATCH 2/2] BC-8403 Code improvements (#43) Co-authored-by: SevenWaysDP Co-authored-by: Max Bischof --- .env.test | 3 +- src/apps/tldraw-server.app.ts | 4 +- src/apps/tldraw-worker.app.ts | 10 + .../authorization/authorization.service.ts | 1 - src/infra/logger/logger.spec.ts | 14 +- src/infra/logger/logger.ts | 4 +- src/infra/metrics/index.ts | 4 +- src/infra/redis/index.ts | 2 +- src/infra/redis/interfaces/redis-adapter.ts | 6 +- .../redis/interfaces/stream-message-reply.ts | 14 +- src/infra/redis/ioredis.adapter.spec.ts | 6 +- src/infra/redis/ioredis.adapter.ts | 8 +- src/infra/redis/mapper.ts | 2 +- ....service.spec.ts => redis.factory.spec.ts} | 36 +- .../{redis.service.ts => redis.factory.ts} | 12 +- src/infra/redis/redis.module.ts | 36 +- .../testing/stream-message-reply.factory.ts | 13 + .../testing/stream-messages-reply.factory.ts | 4 +- .../testing/x-auto-claim-response.factory.ts | 14 +- src/infra/storage/storage.service.ts | 8 +- src/infra/y-redis/api.service.spec.ts | 362 ------ src/infra/y-redis/helper.spec.ts | 4 +- src/infra/y-redis/helper.ts | 19 +- .../{y-redis-doc.ts => y-redis-doc-props.ts} | 3 +- src/infra/y-redis/subscriber.service.spec.ts | 243 ++-- src/infra/y-redis/subscriber.service.ts | 97 +- .../testing/stream-message-reply.factory.ts | 14 - .../y-redis/testing/y-redis-doc.factory.ts | 21 + .../y-redis/testing/y-redis-user.factory.ts | 20 + src/infra/y-redis/ws.service.spec.ts | 1075 ----------------- src/infra/y-redis/ws.service.ts | 319 ----- src/infra/y-redis/y-redis-client.module.ts | 31 + src/infra/y-redis/y-redis-doc.factory.ts | 10 + src/infra/y-redis/y-redis-doc.ts | 25 + src/infra/y-redis/y-redis-service.module.ts | 48 + src/infra/y-redis/y-redis-user.factory.ts | 16 + src/infra/y-redis/y-redis-user.ts | 25 + src/infra/y-redis/y-redis.client.spec.ts | 346 ++++++ .../{api.service.ts => y-redis.client.ts} | 110 +- src/infra/y-redis/y-redis.const.ts | 3 + src/infra/y-redis/y-redis.service.spec.ts | 428 +++++++ src/infra/y-redis/y-redis.service.ts | 143 +++ .../server/api/test/tldraw-config.api.spec.ts | 2 +- .../server/api/test/websocket.api.spec.ts | 80 +- .../server/api/websocket.gateway.spec.ts | 174 --- src/modules/server/api/websocket.gateway.ts | 219 +++- src/modules/server/server.const.ts | 3 + src/modules/server/server.module.ts | 10 +- .../service/tldraw-document.service.spec.ts | 38 +- .../server/service/tldraw-document.service.ts | 19 +- src/modules/worker/worker.config.ts | 8 +- src/modules/worker/worker.const.ts | 1 + src/modules/worker/worker.module.ts | 10 +- src/modules/worker/worker.service.spec.ts | 403 +++--- src/modules/worker/worker.service.ts | 277 +++-- 55 files changed, 2183 insertions(+), 2624 deletions(-) rename src/infra/redis/{redis.service.spec.ts => redis.factory.spec.ts} (77%) rename src/infra/redis/{redis.service.ts => redis.factory.ts} (86%) create mode 100644 src/infra/redis/testing/stream-message-reply.factory.ts rename src/infra/{y-redis => redis}/testing/stream-messages-reply.factory.ts (74%) delete mode 100644 src/infra/y-redis/api.service.spec.ts rename src/infra/y-redis/interfaces/{y-redis-doc.ts => y-redis-doc-props.ts} (78%) delete mode 100644 src/infra/y-redis/testing/stream-message-reply.factory.ts create mode 100644 src/infra/y-redis/testing/y-redis-doc.factory.ts create mode 100644 src/infra/y-redis/testing/y-redis-user.factory.ts delete mode 100644 src/infra/y-redis/ws.service.spec.ts delete mode 100644 src/infra/y-redis/ws.service.ts create mode 100644 src/infra/y-redis/y-redis-client.module.ts create mode 100644 src/infra/y-redis/y-redis-doc.factory.ts create mode 100644 src/infra/y-redis/y-redis-doc.ts create mode 100644 src/infra/y-redis/y-redis-service.module.ts create mode 100644 src/infra/y-redis/y-redis-user.factory.ts create mode 100644 src/infra/y-redis/y-redis-user.ts create mode 100644 src/infra/y-redis/y-redis.client.spec.ts rename src/infra/y-redis/{api.service.ts => y-redis.client.ts} (53%) create mode 100644 src/infra/y-redis/y-redis.const.ts create mode 100644 src/infra/y-redis/y-redis.service.spec.ts create mode 100644 src/infra/y-redis/y-redis.service.ts delete mode 100644 src/modules/server/api/websocket.gateway.spec.ts create mode 100644 src/modules/server/server.const.ts create mode 100644 src/modules/worker/worker.const.ts diff --git a/.env.test b/.env.test index aad73219..f8020351 100644 --- a/.env.test +++ b/.env.test @@ -10,6 +10,7 @@ S3_SECRET_KEY=miniouser S3_SSL=false FEATURE_TLDRAW_ENABLED=true -TLDRAW_WEBSOCKET_URL=ws://localhost:3345 +TLDRAW_WEBSOCKET_URL=ws://localhost:3399 +TLDRAW_WEBSOCKET_PORT=3399 X_API_ALLOWED_KEYS=randomString \ No newline at end of file diff --git a/src/apps/tldraw-server.app.ts b/src/apps/tldraw-server.app.ts index c3112c39..79d85de7 100644 --- a/src/apps/tldraw-server.app.ts +++ b/src/apps/tldraw-server.app.ts @@ -29,13 +29,13 @@ async function bootstrap(): Promise { await metricsApp.listen(metricsPort, async () => { const logger = await metricsApp.resolve(Logger); logger.setContext('METRICS'); - logger.log(`Metrics server is running on port ${metricsPort}`); + logger.info(`Metrics server is running on port ${metricsPort}`); }); await nestApp.listen(httpPort, async () => { const logger = await nestApp.resolve(Logger); logger.setContext('TLDRAW'); - logger.log(`Server is running on port ${httpPort}`); + logger.info(`Server is running on port ${httpPort}`); }); } bootstrap(); diff --git a/src/apps/tldraw-worker.app.ts b/src/apps/tldraw-worker.app.ts index 5d2f0e48..e9908cf6 100644 --- a/src/apps/tldraw-worker.app.ts +++ b/src/apps/tldraw-worker.app.ts @@ -1,9 +1,19 @@ import { NestFactory } from '@nestjs/core'; import { WorkerModule } from '../modules/worker/worker.module.js'; +import { WorkerService } from '../modules/worker/worker.service.js'; async function bootstrap(): Promise { const nestApp = await NestFactory.createApplicationContext(WorkerModule); await nestApp.init(); + const workerService = await nestApp.resolve(WorkerService); + + try { + workerService.start(); + } catch (error) { + console.error(error); + workerService.stop(); + process.exit(1); + } } bootstrap(); diff --git a/src/infra/authorization/authorization.service.ts b/src/infra/authorization/authorization.service.ts index 230d7042..380d9407 100644 --- a/src/infra/authorization/authorization.service.ts +++ b/src/infra/authorization/authorization.service.ts @@ -94,7 +94,6 @@ export class AuthorizationService { private createErrorResponsePayload(code: number, reason: string): ResponsePayload { const response = ResponsePayloadBuilder.buildWithError(code, reason); - this.logger.log(`Error: ${code} - ${reason}`); return response; } diff --git a/src/infra/logger/logger.spec.ts b/src/infra/logger/logger.spec.ts index 418037a6..0c9155be 100644 --- a/src/infra/logger/logger.spec.ts +++ b/src/infra/logger/logger.spec.ts @@ -6,7 +6,7 @@ import { RequestLoggingBody } from './interfaces/logger.interface.js'; import { Logger } from './logger.js'; describe('Logger', () => { - let service: Logger; + let logger: Logger; let processStdoutWriteSpy: jest.SpyInstance< boolean, [str: string | Uint8Array, encoding?: BufferEncoding | undefined, cb?: ((err?: Error) => void) | undefined], @@ -30,7 +30,7 @@ describe('Logger', () => { ], }).compile(); - service = await module.resolve(Logger); + logger = await module.resolve(Logger); winstonLogger = module.get(WINSTON_MODULE_PROVIDER); }); @@ -44,10 +44,10 @@ describe('Logger', () => { processStderrWriteSpy.mockRestore(); }); - describe('WHEN log logging', () => { + describe('WHEN info logging', () => { it('should call winstonLogger.info', () => { const error = new Error('custom error'); - service.log(error.message, error.stack); + logger.info(error.message, error.stack); expect(winstonLogger.info).toHaveBeenCalled(); }); }); @@ -55,7 +55,7 @@ describe('Logger', () => { describe('WHEN warn logging', () => { it('should call winstonLogger.warning', () => { const error = new Error('custom error'); - service.warn(error.message, error.stack); + logger.warning(error.message, error.stack); expect(winstonLogger.warning).toHaveBeenCalled(); }); }); @@ -63,7 +63,7 @@ describe('Logger', () => { describe('WHEN debug logging', () => { it('should call winstonLogger.debug', () => { const error = new Error('custom error'); - service.debug(error.message, error.stack); + logger.debug(error.message, error.stack); expect(winstonLogger.debug).toHaveBeenCalled(); }); }); @@ -81,7 +81,7 @@ describe('Logger', () => { }, error, }; - service.http(message, error.stack); + logger.http(message, error.stack); expect(winstonLogger.notice).toHaveBeenCalled(); }); }); diff --git a/src/infra/logger/logger.ts b/src/infra/logger/logger.ts index 76aee069..0b15437c 100644 --- a/src/infra/logger/logger.ts +++ b/src/infra/logger/logger.ts @@ -16,11 +16,11 @@ export class Logger { public constructor(@Inject(WINSTON_MODULE_PROVIDER) private readonly logger: winston.Logger) {} - public log(message: unknown, context?: string): void { + public info(message: unknown, context?: string): void { this.logger.info(this.createMessage(message, context)); } - public warn(message: unknown, context?: string): void { + public warning(message: unknown, context?: string): void { this.logger.warning(this.createMessage(message, context)); } diff --git a/src/infra/metrics/index.ts b/src/infra/metrics/index.ts index 0c575cca..3c2e6576 100644 --- a/src/infra/metrics/index.ts +++ b/src/infra/metrics/index.ts @@ -1,2 +1,2 @@ -export * from './metrics.module.js'; -export * from './metrics.service.js'; +export { MetricsModule } from './metrics.module.js'; +export { MetricsService } from './metrics.service.js'; diff --git a/src/infra/redis/index.ts b/src/infra/redis/index.ts index e6ab38cf..1cbe04b2 100644 --- a/src/infra/redis/index.ts +++ b/src/infra/redis/index.ts @@ -1,2 +1,2 @@ +export * from './redis.factory.js'; export * from './redis.module.js'; -export * from './redis.service.js'; diff --git a/src/infra/redis/interfaces/redis-adapter.ts b/src/infra/redis/interfaces/redis-adapter.ts index 0241eeeb..520c8deb 100644 --- a/src/infra/redis/interfaces/redis-adapter.ts +++ b/src/infra/redis/interfaces/redis-adapter.ts @@ -11,11 +11,11 @@ export interface RedisAdapter { exists(stream: string): Promise; createGroup(): Promise; quit(): Promise; - readStreams(streams: StreamNameClockPair[]): Promise; - readMessagesFromStream(streamName: string): Promise; + readStreams(streams: StreamNameClockPair[]): Promise; + readMessagesFromStream(streamName: string): Promise; reclaimTasks(consumerName: string, redisTaskDebounce: number, tryClaimCount: number): Promise; getDeletedDocEntries(): Promise; - deleteDeleteDocEntry(id: string): Promise; + deleteDeletedDocEntry(id: string): Promise; tryClearTask(task: Task): Promise; tryDeduplicateTask(task: Task, lastId: number, redisMinMessageLifetime: number): Promise; } diff --git a/src/infra/redis/interfaces/stream-message-reply.ts b/src/infra/redis/interfaces/stream-message-reply.ts index 45713040..de07eede 100644 --- a/src/infra/redis/interfaces/stream-message-reply.ts +++ b/src/infra/redis/interfaces/stream-message-reply.ts @@ -1,23 +1,17 @@ import { RedisKey } from 'ioredis'; interface Message { - key: RedisKey; - m?: RedisKey; + m?: Buffer; docName?: string; - compact?: string; + compact?: Buffer; } export interface StreamMessageReply { id: RedisKey; - message: Record; + message: Message; } -export interface StreamMessagesSingleReply { +export interface StreamMessagesReply { name: string; messages: StreamMessageReply[] | null; } - -export type StreamMessagesReply = { - name: string; - messages: StreamMessageReply[] | null; -}[]; diff --git a/src/infra/redis/ioredis.adapter.spec.ts b/src/infra/redis/ioredis.adapter.spec.ts index 7b4c7ce2..8af842a2 100644 --- a/src/infra/redis/ioredis.adapter.spec.ts +++ b/src/infra/redis/ioredis.adapter.spec.ts @@ -194,7 +194,7 @@ describe(IoRedisAdapter.name, () => { await redisAdapter.createGroup(); - expect(logger.log).toHaveBeenCalledWith(error); + expect(logger.info).toHaveBeenCalledWith(error); }); }); @@ -412,7 +412,7 @@ describe(IoRedisAdapter.name, () => { it('should call redis xdel with correct values', async () => { const { id, xdelSpy, expectedProps, redisAdapter } = await setup(); - await redisAdapter.deleteDeleteDocEntry(id); + await redisAdapter.deleteDeletedDocEntry(id); expect(xdelSpy).toHaveBeenCalledWith(...expectedProps); }); @@ -420,7 +420,7 @@ describe(IoRedisAdapter.name, () => { it('should return correct value', async () => { const { id, redisAdapter } = await setup(); - const result = await redisAdapter.deleteDeleteDocEntry(id); + const result = await redisAdapter.deleteDeletedDocEntry(id); expect(result).toBe(1); }); diff --git a/src/infra/redis/ioredis.adapter.ts b/src/infra/redis/ioredis.adapter.ts index d90c1d5a..98a9888c 100644 --- a/src/infra/redis/ioredis.adapter.ts +++ b/src/infra/redis/ioredis.adapter.ts @@ -76,7 +76,7 @@ export class IoRedisAdapter implements RedisAdapter { try { await this.redis.xgroup('CREATE', this.redisWorkerStreamName, this.redisWorkerGroupName, '0', 'MKSTREAM'); } catch (e) { - this.logger.log(e); + this.logger.info(e); // It is okay when the group already exists, so we can ignore this error. if (e.message !== 'BUSYGROUP Consumer Group name already exists') { throw e; @@ -88,7 +88,7 @@ export class IoRedisAdapter implements RedisAdapter { await this.redis.quit(); } - public async readStreams(streams: StreamNameClockPair[]): Promise { + public async readStreams(streams: StreamNameClockPair[]): Promise { const reads = await this.redis.xreadBuffer( 'COUNT', 1000, @@ -104,7 +104,7 @@ export class IoRedisAdapter implements RedisAdapter { return streamReplyRes; } - public async readMessagesFromStream(streamName: string): Promise { + public async readMessagesFromStream(streamName: string): Promise { const reads = await this.redis.xreadBuffer('STREAMS', streamName, '0'); const streamReplyRes = mapToStreamMessagesReply(reads); @@ -140,7 +140,7 @@ export class IoRedisAdapter implements RedisAdapter { return transformedDeletedTasks; } - public deleteDeleteDocEntry(id: string): Promise { + public deleteDeletedDocEntry(id: string): Promise { const result = this.redis.xdel(this.redisDeleteStreamName, id); return result; diff --git a/src/infra/redis/mapper.ts b/src/infra/redis/mapper.ts index 39b2073c..84f18403 100644 --- a/src/infra/redis/mapper.ts +++ b/src/infra/redis/mapper.ts @@ -42,7 +42,7 @@ export function mapToStreamMessagesReplies(messages: XItems | unknown): StreamMe return result; } -export function mapToStreamMessagesReply(streamReply: XReadBufferReply | unknown): StreamMessagesReply { +export function mapToStreamMessagesReply(streamReply: XReadBufferReply | unknown): StreamMessagesReply[] { if (streamReply === null) { return []; } diff --git a/src/infra/redis/redis.service.spec.ts b/src/infra/redis/redis.factory.spec.ts similarity index 77% rename from src/infra/redis/redis.service.spec.ts rename to src/infra/redis/redis.factory.spec.ts index 0c7a1da1..25c6da60 100644 --- a/src/infra/redis/redis.service.spec.ts +++ b/src/infra/redis/redis.factory.spec.ts @@ -4,7 +4,7 @@ import * as util from 'util'; import { Logger } from '../logger/index.js'; import { IoRedisAdapter } from './ioredis.adapter.js'; import { RedisConfig } from './redis.config.js'; -import { RedisService } from './redis.service.js'; +import { RedisFactory } from './redis.factory.js'; jest.mock('ioredis', () => { return { @@ -14,7 +14,7 @@ jest.mock('ioredis', () => { jest.mock('./ioredis.adapter.js'); -describe('Redis Service', () => { +describe(RedisFactory.name, () => { beforeEach(() => { jest.resetAllMocks(); }); @@ -50,7 +50,7 @@ describe('Redis Service', () => { const constructorSpy = jest.spyOn(Redis.prototype, 'constructor'); const logger = createMock(); - const service = new RedisService(config, logger); + const factory = new RedisFactory(config, logger); const expectedProps = { sentinels: [ @@ -62,29 +62,29 @@ describe('Redis Service', () => { name: 'sentinelName', }; - return { resolveSrv, sentinelServiceName, service, constructorSpy, expectedProps }; + return { resolveSrv, sentinelServiceName, factory, constructorSpy, expectedProps }; }; it('calls resolveSrv', async () => { - const { resolveSrv, sentinelServiceName, service } = setup(); + const { resolveSrv, sentinelServiceName, factory } = setup(); - await service.createRedisInstance(); + await factory.createRedisInstance(); expect(resolveSrv).toHaveBeenLastCalledWith(sentinelServiceName); }); it('create new Redis instance with correctly props', async () => { - const { service, constructorSpy, expectedProps } = setup(); + const { factory, constructorSpy, expectedProps } = setup(); - await service.createRedisInstance(); + await factory.createRedisInstance(); expect(constructorSpy).toHaveBeenCalledWith(expectedProps); }); it('creates a new Redis instance', async () => { - const { service } = setup(); + const { factory } = setup(); - const redisInstance = await service.createRedisInstance(); + const redisInstance = await factory.createRedisInstance(); expect(redisInstance).toBeInstanceOf(IoRedisAdapter); }); @@ -105,33 +105,33 @@ describe('Redis Service', () => { const constructorSpy = jest.spyOn(Redis.prototype, 'constructor'); const logger = createMock(); - const service = new RedisService(config, logger); + const factory = new RedisFactory(config, logger); const expectedProps = redisUrl; - return { resolveSrv, service, redisMock, constructorSpy, expectedProps }; + return { resolveSrv, factory, redisMock, constructorSpy, expectedProps }; }; it('calls resolveSrv', async () => { - const { resolveSrv, service } = setup(); + const { resolveSrv, factory } = setup(); - await service.createRedisInstance(); + await factory.createRedisInstance(); expect(resolveSrv).not.toHaveBeenCalled(); }); it('create new Redis instance with correctly props', async () => { - const { service, constructorSpy, expectedProps } = setup(); + const { factory, constructorSpy, expectedProps } = setup(); - await service.createRedisInstance(); + await factory.createRedisInstance(); expect(constructorSpy).toHaveBeenCalledWith(expectedProps); }); it('creates a new Redis instance', async () => { - const { service } = setup(); + const { factory } = setup(); - const redisInstance = await service.createRedisInstance(); + const redisInstance = await factory.createRedisInstance(); expect(redisInstance).toBeInstanceOf(IoRedisAdapter); }); diff --git a/src/infra/redis/redis.service.ts b/src/infra/redis/redis.factory.ts similarity index 86% rename from src/infra/redis/redis.service.ts rename to src/infra/redis/redis.factory.ts index 095ceba8..07c2822d 100644 --- a/src/infra/redis/redis.service.ts +++ b/src/infra/redis/redis.factory.ts @@ -1,4 +1,3 @@ -import { Injectable } from '@nestjs/common'; import * as dns from 'dns'; import { Redis } from 'ioredis'; import * as util from 'util'; @@ -7,14 +6,11 @@ import { RedisAdapter } from './interfaces/index.js'; import { IoRedisAdapter } from './ioredis.adapter.js'; import { RedisConfig } from './redis.config.js'; -@Injectable() -export class RedisService { +export class RedisFactory { public constructor( private readonly config: RedisConfig, private readonly logger: Logger, - ) { - this.logger.setContext(RedisService.name); - } + ) {} public async createRedisInstance(): Promise { let redisInstance: Redis; @@ -39,7 +35,7 @@ export class RedisService { const sentinelName = this.config.REDIS_SENTINEL_NAME; const sentinelPassword = this.config.REDIS_SENTINEL_PASSWORD; const sentinels = await this.discoverSentinelHosts(); - this.logger.log(`Discovered sentinels: ${JSON.stringify(sentinels)}`); + this.logger.info(`Discovered sentinels: ${JSON.stringify(sentinels)}`); const redisInstance = new Redis({ sentinels, @@ -63,7 +59,7 @@ export class RedisService { return hosts; } catch (err) { - this.logger.log('Error during service discovery:', err); + this.logger.info('Error during service discovery:', err); throw err; } } diff --git a/src/infra/redis/redis.module.ts b/src/infra/redis/redis.module.ts index 3f64d58d..bc8fb909 100644 --- a/src/infra/redis/redis.module.ts +++ b/src/infra/redis/redis.module.ts @@ -1,12 +1,32 @@ -import { Module } from '@nestjs/common'; +import { DynamicModule, Module } from '@nestjs/common'; import { ConfigurationModule } from '../configuration/configuration.module.js'; +import { Logger } from '../logger/logger.js'; import { LoggerModule } from '../logger/logger.module.js'; +import { RedisAdapter } from './interfaces/redis-adapter.js'; import { RedisConfig } from './redis.config.js'; -import { RedisService } from './redis.service.js'; +import { RedisFactory } from './redis.factory.js'; -@Module({ - imports: [LoggerModule, ConfigurationModule.register(RedisConfig)], - providers: [RedisService], - exports: [RedisService], -}) -export class RedisModule {} +@Module({}) +export class RedisModule { + public static registerFor(token: string): DynamicModule { + return { + module: RedisModule, + imports: [LoggerModule, ConfigurationModule.register(RedisConfig)], + providers: [ + { + provide: token, + useFactory: async (config: RedisConfig, logger: Logger): Promise => { + logger.setContext(`${RedisFactory.name} - ${token}`); + + const redisFactory = new RedisFactory(config, logger); + const redisAdapter = await redisFactory.createRedisInstance(); + + return redisAdapter; + }, + inject: [RedisConfig, Logger], + }, + ], + exports: [token], + }; + } +} diff --git a/src/infra/redis/testing/stream-message-reply.factory.ts b/src/infra/redis/testing/stream-message-reply.factory.ts new file mode 100644 index 00000000..d9262b66 --- /dev/null +++ b/src/infra/redis/testing/stream-message-reply.factory.ts @@ -0,0 +1,13 @@ +import { Factory } from 'fishery'; +import { StreamMessageReply } from '../interfaces/index.js'; + +export const streamMessageReplyFactory = Factory.define(({ sequence }) => { + return { + id: `redis-id-${sequence}`, + message: { + m: Buffer.from(`message-${sequence}-2`), + docName: `prefix:room:room:docid-${sequence.toString()}`, + compact: Buffer.from(`prefix:room:room:docid-${sequence.toString()}`), + }, + }; +}); diff --git a/src/infra/y-redis/testing/stream-messages-reply.factory.ts b/src/infra/redis/testing/stream-messages-reply.factory.ts similarity index 74% rename from src/infra/y-redis/testing/stream-messages-reply.factory.ts rename to src/infra/redis/testing/stream-messages-reply.factory.ts index 0a603631..a56fdeac 100644 --- a/src/infra/y-redis/testing/stream-messages-reply.factory.ts +++ b/src/infra/redis/testing/stream-messages-reply.factory.ts @@ -1,8 +1,8 @@ import { Factory } from 'fishery'; -import { StreamMessagesReply } from '../../../infra/redis/interfaces/index.js'; +import { StreamMessagesReply } from '../interfaces/index.js'; import { streamMessageReplyFactory } from './stream-message-reply.factory.js'; -export const streamMessagesReplyFactory = Factory.define(({ sequence }) => { +export const streamMessagesReplyFactory = Factory.define(({ sequence }) => { return [ { name: `prefix:room:roomid-${sequence}:docid`, diff --git a/src/infra/redis/testing/x-auto-claim-response.factory.ts b/src/infra/redis/testing/x-auto-claim-response.factory.ts index f0a7245d..a5458e70 100644 --- a/src/infra/redis/testing/x-auto-claim-response.factory.ts +++ b/src/infra/redis/testing/x-auto-claim-response.factory.ts @@ -1,5 +1,5 @@ import { Factory } from 'fishery'; -import { StreamMessageReply, XAutoClaimResponse } from '../interfaces/index.js'; +import { XAutoClaimResponse } from '../interfaces/index.js'; export const xAutoClaimResponseFactory = Factory.define(({ sequence }) => { return { @@ -7,15 +7,3 @@ export const xAutoClaimResponseFactory = Factory.define(({ s messages: [], }; }); - -export const streamMessageReplyFactory = Factory.define(({ sequence }) => { - return { - id: sequence.toString(), - message: { - key: sequence.toString(), - m: sequence.toString(), - docName: `prefix:room:room:docid-${sequence.toString()}`, - compact: `prefix:room:room:docid-${sequence.toString()}`, - }, - }; -}); diff --git a/src/infra/storage/storage.service.ts b/src/infra/storage/storage.service.ts index a3e477b9..6f7d71c2 100644 --- a/src/infra/storage/storage.service.ts +++ b/src/infra/storage/storage.service.ts @@ -1,6 +1,6 @@ import { Injectable, OnModuleInit } from '@nestjs/common'; -import { Client } from 'minio'; import { randomUUID } from 'crypto'; +import { Client } from 'minio'; import { Stream } from 'stream'; import * as Y from 'yjs'; import { Logger } from '../logger/index.js'; @@ -42,14 +42,14 @@ export class StorageService implements DocumentStorage, OnModuleInit { } public async retrieveDoc(room: string, docname: string): Promise<{ doc: Uint8Array; references: string[] } | null> { - this.logger.log('retrieving doc room=' + room + ' docname=' + docname); + this.logger.info('retrieving doc room=' + room + ' docname=' + docname); const objNames = await this.client .listObjectsV2(this.config.S3_BUCKET, encodeS3ObjectName(room, docname), true) .toArray(); const references: string[] = objNames.map((obj) => obj.name); - this.logger.log('retrieved doc room=' + room + ' docname=' + docname + ' refs=' + JSON.stringify(references)); + this.logger.info('retrieved doc room=' + room + ' docname=' + docname + ' refs=' + JSON.stringify(references)); if (references.length === 0) { return null; @@ -67,7 +67,7 @@ export class StorageService implements DocumentStorage, OnModuleInit { }), ); updates = updates.filter((update) => update != null); - this.logger.log('retrieved doc room=' + room + ' docname=' + docname + ' updatesLen=' + updates.length); + this.logger.info('retrieved doc room=' + room + ' docname=' + docname + ' updatesLen=' + updates.length); return { doc: Y.mergeUpdatesV2(updates), references }; } diff --git a/src/infra/y-redis/api.service.spec.ts b/src/infra/y-redis/api.service.spec.ts deleted file mode 100644 index 0b6f2f77..00000000 --- a/src/infra/y-redis/api.service.spec.ts +++ /dev/null @@ -1,362 +0,0 @@ -import { createMock } from '@golevelup/ts-jest'; -import * as Awareness from 'y-protocols/awareness'; -import * as Y from 'yjs'; -import { RedisService } from '../../infra/redis/redis.service.js'; -import { RedisAdapter } from '../redis/interfaces/index.js'; -import { Api, createApiClient, handleMessageUpdates } from './api.service.js'; -import * as helper from './helper.js'; -import * as protocol from './protocol.js'; -import { DocumentStorage } from './storage.js'; -import { streamMessagesReplyFactory } from './testing/stream-messages-reply.factory.js'; -import { yRedisMessageFactory } from './testing/y-redis-message.factory.js'; - -describe(Api.name, () => { - const setupApi = () => { - const store = createMock(); - const redis = createMock({ - redisPrefix: 'prefix', - }); - const api = new Api(store, redis); - - return { store, redis, api }; - }; - - afterEach(() => { - jest.restoreAllMocks(); - }); - - describe('getMessages', () => { - describe('when streams is empty', () => { - it('should return empty array', async () => { - const { api } = setupApi(); - - const result = await api.getMessages([]); - - expect(result).toEqual([]); - }); - }); - - describe('when streams is not empty', () => { - const setup = () => { - const { redis, api } = setupApi(); - - const m = streamMessagesReplyFactory.build(); - redis.readStreams.mockResolvedValueOnce(m); - - const props = [ - { - key: 'stream1', - id: '1', - }, - ]; - const spyMergeMessages = jest.spyOn(protocol, 'mergeMessages').mockReturnValueOnce([]); - - const { name, messages } = m[0]; - // @ts-ignore - const lastId = messages[messages.length - 1].id; - - const expectedResult = [ - { - lastId, - messages: [], - stream: name, - }, - ]; - - const expectedMessages = messages?.map((message) => message.message.m).filter((m) => m != null); - - return { redis, api, spyMergeMessages, expectedResult, expectedMessages, props }; - }; - - it('should call redis.readStreams with correctly params', async () => { - const { api, redis, props } = setup(); - - await api.getMessages(props); - - expect(redis.readStreams).toHaveBeenCalledTimes(1); - expect(redis.readStreams).toHaveBeenCalledWith(props); - }); - - it('should call protocol.mergeMessages with correctly values', async () => { - const { api, spyMergeMessages, expectedMessages, props } = setup(); - - await api.getMessages(props); - - expect(spyMergeMessages).toHaveBeenCalledTimes(1); - expect(spyMergeMessages).toHaveBeenCalledWith(expectedMessages); - }); - - it('should return expected messages', async () => { - const { api, expectedResult, props } = setup(); - - const result = await api.getMessages(props); - - expect(result).toEqual(expectedResult); - }); - }); - }); - - describe('addMessage', () => { - describe('when m is a sync step 2 message', () => { - const setup = () => { - const { api, redis } = setupApi(); - - const room = 'room'; - const docid = 'docid'; - const message = Buffer.from([protocol.messageSync, protocol.messageSyncStep2]); - - const props = { room, docid, message }; - - return { api, redis, props }; - }; - it('should return a promise', async () => { - const { api, redis, props } = setup(); - - const result = await api.addMessage(props.room, props.docid, props.message); - - expect(result).toBeUndefined(); - expect(redis.addMessage).not.toHaveBeenCalled(); - }); - }); - - describe('when m is not a sync step 2 message', () => { - const setup = () => { - const { api, redis } = setupApi(); - - const room = 'room'; - const docid = 'docid'; - const message = Buffer.from([protocol.messageSync, protocol.messageSyncUpdate]); - - const props = { room, docid, message }; - - return { api, redis, props }; - }; - - it('should call redis.addMessage with correctly params', async () => { - const { api, redis, props } = setup(); - - await api.addMessage(props.room, props.docid, props.message); - - expect(redis.addMessage).toHaveBeenCalledTimes(1); - expect(redis.addMessage).toHaveBeenCalledWith('prefix:room:room:docid', props.message); - }); - }); - - describe('when m is correctly message', () => { - const setup = () => { - const { api } = setupApi(); - - const room = 'room'; - const docid = 'docid'; - const message = Buffer.from([protocol.messageSync, protocol.messageSyncStep2, 0x54, 0x45, 0x53, 0x54]); - - const props = { room, docid, message }; - - return { api, props }; - }; - it('should set correctly protocol type', async () => { - const { api, props } = setup(); - - await api.addMessage(props.room, props.docid, props.message); - - expect(props.message[1]).toEqual(protocol.messageSyncUpdate); - }); - }); - }); - - describe('getStateVector', () => { - const setup = () => { - const { api, store } = setupApi(); - - const room = 'room'; - const docid = 'docid'; - - const props = { room, docid }; - - return { api, store, props }; - }; - - it('should call store.retrieveStateVector with correctly params', async () => { - const { api, store, props } = setup(); - const { room, docid } = props; - - await api.getStateVector(room, docid); - - expect(store.retrieveStateVector).toHaveBeenCalledTimes(1); - expect(store.retrieveStateVector).toHaveBeenCalledWith(room, docid); - }); - }); - - describe('getDoc', () => { - const setup = () => { - const { api, store, redis } = setupApi(); - const spyComputeRedisRoomStreamName = jest.spyOn(helper, 'computeRedisRoomStreamName'); - const spyExtractMessagesFromStreamReply = jest.spyOn(helper, 'extractMessagesFromStreamReply'); - - const ydoc = new Y.Doc(); - const doc = Y.encodeStateAsUpdateV2(ydoc); - const streamReply = streamMessagesReplyFactory.build(); - redis.readMessagesFromStream.mockResolvedValueOnce(streamReply); - store.retrieveDoc.mockResolvedValueOnce({ doc, references: [] }); - - const room = 'roomid-1'; - const docid = 'docid'; - - const props = { room, docid }; - - return { - api, - store, - redis, - props, - spyComputeRedisRoomStreamName, - spyExtractMessagesFromStreamReply, - streamReply, - }; - }; - - it('should call computeRedisRoomStreamName with correctly params', async () => { - const { api, props, spyComputeRedisRoomStreamName } = setup(); - const { room, docid } = props; - - const result = await api.getDoc(room, docid); - result.awareness.destroy(); - - expect(spyComputeRedisRoomStreamName).toHaveBeenCalledWith(room, docid, 'prefix'); - }); - - it('should call redis.readMessagesFromStream with correctly params', async () => { - const { api, props, redis } = setup(); - const { room, docid } = props; - - const result = await api.getDoc(room, docid); - result.awareness.destroy(); - - expect(redis.readMessagesFromStream).toHaveBeenCalledTimes(1); - expect(redis.readMessagesFromStream).toHaveBeenCalledWith('prefix:room:roomid-1:docid'); - }); - - it('should call extractMessagesFromStreamReply with correctly params', async () => { - const { api, props, spyExtractMessagesFromStreamReply, streamReply } = setup(); - const { room, docid } = props; - - const result = await api.getDoc(room, docid); - result.awareness.destroy(); - - expect(spyExtractMessagesFromStreamReply).toHaveBeenCalledWith(streamReply, 'prefix'); - }); - - it('should return expected result', async () => { - const { api, props } = setup(); - const { room, docid } = props; - - const result = await api.getDoc(room, docid); - result.awareness.destroy(); - - expect(result).toBeDefined(); - expect(result).toEqual(expect.objectContaining({ ydoc: expect.any(Y.Doc) })); - }); - }); - - describe('destroy', () => { - const setup = () => { - const { api, redis } = setupApi(); - - return { api, redis }; - }; - - it('should set _destroyed to true', () => { - const { api } = setup(); - - api.destroy(); - - expect(api._destroyed).toBeTruthy(); - }); - - it('should call store.destroy with correctly params', async () => { - const { api, redis } = setup(); - - await api.destroy(); - - expect(redis.quit).toHaveBeenCalledTimes(1); - }); - }); -}); - -describe('handleMessageUpdates', () => { - describe('when a message is messageSyncUpdate', () => { - const setup = () => { - const ydoc = new Y.Doc(); - const awareness = createMock(); - const message = Buffer.from([protocol.messageSync, protocol.messageSyncUpdate, 0x54, 0x45, 0x53, 0x54]); - - const messages = yRedisMessageFactory.build({ messages: [message] }); - - const spyApplyUpdate = jest.spyOn(Y, 'applyUpdate'); - spyApplyUpdate.mockReturnValueOnce(undefined); - - return { ydoc, awareness, messages, spyApplyUpdate }; - }; - - it('should call Y.applyUpdate with correctly params', () => { - const { ydoc, awareness, messages, spyApplyUpdate } = setup(); - - handleMessageUpdates(messages, ydoc, awareness); - - expect(spyApplyUpdate).toHaveBeenCalledWith(ydoc, expect.anything()); - }); - }); - - describe('when a message is messageSyncAwareness', () => { - const setup = () => { - const ydoc = new Y.Doc(); - const awareness = createMock(); - const message = Buffer.from([protocol.messageAwareness, 0x54, 0x45, 0x53, 0x54]); - - const messages = yRedisMessageFactory.build({ messages: [message] }); - - const spyApplyAwarenessUpdate = jest.spyOn(Awareness, 'applyAwarenessUpdate'); - spyApplyAwarenessUpdate.mockReturnValueOnce(undefined); - - return { ydoc, awareness, messages, spyApplyAwarenessUpdate }; - }; - - it('should call Y.applyAwarenessUpdate with correctly params', () => { - const { ydoc, awareness, messages, spyApplyAwarenessUpdate } = setup(); - - handleMessageUpdates(messages, ydoc, awareness); - - expect(spyApplyAwarenessUpdate).toHaveBeenCalledWith(awareness, expect.anything(), null); - }); - }); -}); - -describe('createApiClient', () => { - const setup = () => { - const store = createMock(); - const redisService = createMock(); - const redisInstance = createMock(); - const apiInstance = createMock({ - redis: redisInstance, - }); - - return { store, redisService, redisInstance, apiInstance }; - }; - - it('should call createRedisInstance.createRedisInstance', async () => { - const { store, redisService } = setup(); - - await createApiClient(store, redisService); - - expect(redisService.createRedisInstance).toHaveBeenCalledTimes(1); - }); - - it('should return an instance of Api', async () => { - const { store, redisService } = setup(); - - const result = await createApiClient(store, redisService); - - expect(result).toBeDefined(); - expect(result).toBeInstanceOf(Api); - }); -}); diff --git a/src/infra/y-redis/helper.spec.ts b/src/infra/y-redis/helper.spec.ts index 3a53ef87..5e277062 100644 --- a/src/infra/y-redis/helper.spec.ts +++ b/src/infra/y-redis/helper.spec.ts @@ -1,10 +1,10 @@ +import { streamMessagesReplyFactory } from '../redis/testing/stream-messages-reply.factory.js'; import { computeRedisRoomStreamName, decodeRedisRoomStreamName, extractMessagesFromStreamReply, isSmallerRedisId, } from './helper.js'; -import { streamMessagesReplyFactory } from './testing/stream-messages-reply.factory.js'; describe('helper', () => { describe('isSmallerRedisId', () => { @@ -116,7 +116,7 @@ describe('helper', () => { 'docid', { lastId: 'redis-id-2', - messages: ['message-1-2', 'message-2-2'], + messages: [Buffer.from('message-1-2'), Buffer.from('message-2-2')], }, ], ]), diff --git a/src/infra/y-redis/helper.ts b/src/infra/y-redis/helper.ts index 68aee03d..5c943bb4 100644 --- a/src/infra/y-redis/helper.ts +++ b/src/infra/y-redis/helper.ts @@ -1,11 +1,7 @@ import { RedisKey } from 'ioredis'; import { array, map } from 'lib0'; import { TypeGuard } from '../../infra/redis/guards/type.guard.js'; -import { - StreamMessageReply, - StreamMessagesReply, - StreamMessagesSingleReply, -} from '../../infra/redis/interfaces/index.js'; +import { StreamMessageReply, StreamMessagesReply } from '../../infra/redis/interfaces/index.js'; import { YRedisMessage } from './interfaces/stream-message.js'; /* This file contains the implementation of the functions, @@ -27,10 +23,11 @@ export const isSmallerRedisId = (a: string, b: string): boolean => { export const computeRedisRoomStreamName = (room: string, docid: string, prefix: string): string => `${prefix}:room:${encodeURIComponent(room)}:${encodeURIComponent(docid)}`; -export const decodeRedisRoomStreamName = ( - rediskey: string, - expectedPrefix: string, -): { room: string; docid: string } => { +export interface RoomStreamInfos { + room: string; + docid: string; +} +export const decodeRedisRoomStreamName = (rediskey: string, expectedPrefix: string): RoomStreamInfos => { const match = /^(.*):room:(.*):(.*)$/.exec(rediskey); if (match == null || match[1] !== expectedPrefix) { throw new Error( @@ -41,7 +38,7 @@ export const decodeRedisRoomStreamName = ( return { room: decodeURIComponent(match[2]), docid: decodeURIComponent(match[3]) }; }; -const getIdFromLastStreamMessageReply = (docStreamReplay: StreamMessagesSingleReply): RedisKey | undefined => { +const getIdFromLastStreamMessageReply = (docStreamReplay: StreamMessagesReply): RedisKey | undefined => { let id = undefined; if (TypeGuard.isArrayWithElements(docStreamReplay.messages)) { id = array.last(docStreamReplay.messages).id; @@ -58,7 +55,7 @@ const castRedisKeyToUnit8Array = (redisKey: RedisKey): Uint8Array => { }; export const extractMessagesFromStreamReply = ( - streamReply: StreamMessagesReply, + streamReply: StreamMessagesReply[], prefix: string, ): Map> => { const messages = new Map>(); diff --git a/src/infra/y-redis/interfaces/y-redis-doc.ts b/src/infra/y-redis/interfaces/y-redis-doc-props.ts similarity index 78% rename from src/infra/y-redis/interfaces/y-redis-doc.ts rename to src/infra/y-redis/interfaces/y-redis-doc-props.ts index 0fa2cf47..9b5cc96e 100644 --- a/src/infra/y-redis/interfaces/y-redis-doc.ts +++ b/src/infra/y-redis/interfaces/y-redis-doc-props.ts @@ -1,10 +1,11 @@ import { Awareness } from 'y-protocols/awareness'; import { Doc } from 'yjs'; -export interface YRedisDoc { +export interface YRedisDocProps { ydoc: Doc; awareness: Awareness; redisLastId: string; storeReferences: string[] | null; docChanged: boolean; + streamName: string; } diff --git a/src/infra/y-redis/subscriber.service.spec.ts b/src/infra/y-redis/subscriber.service.spec.ts index 2bed3175..6e731660 100644 --- a/src/infra/y-redis/subscriber.service.spec.ts +++ b/src/infra/y-redis/subscriber.service.spec.ts @@ -1,178 +1,123 @@ -import { createMock } from '@golevelup/ts-jest'; -import { RedisService } from '../../infra/redis/redis.service.js'; -import { Api } from './api.service.js'; -import { DocumentStorage } from './storage.js'; -import * as subscriberService from './subscriber.service.js'; +import { createMock, DeepMocked } from '@golevelup/ts-jest'; +import { Test, TestingModule } from '@nestjs/testing'; +import { Logger } from '../logger/logger.js'; +import { SubscriberService } from './subscriber.service.js'; import { yRedisMessageFactory } from './testing/y-redis-message.factory.js'; +import { YRedisClient } from './y-redis.client.js'; describe('SubscriberService', () => { - describe('run', () => { - let callCount = 0; - - beforeEach(() => { - callCount = 0; - }); - - const setup = () => { - const subscriber = createMock({ - run: jest.fn(), - }); - - Object.defineProperty(subscriberService, 'running', { - get: () => { - if (callCount === 0) { - callCount++; - - return true; - } - - return false; - }, - }); - - return { subscriber }; - }; - - it('should call subscriber.run', async () => { - const { subscriber } = setup(); - - await subscriberService.run(subscriber); - - expect(subscriber.run).toHaveBeenCalled(); - }); - }); - - describe('createSubscriber', () => { - const setup = () => { - const store = createMock(); - const createRedisInstance = createMock(); - - const runSpy = jest.spyOn(subscriberService, 'run').mockResolvedValue(); - - return { store, createRedisInstance, runSpy }; - }; - - it('should call run', async () => { - const { store, createRedisInstance, runSpy } = setup(); - - await subscriberService.createSubscriber(store, createRedisInstance); - - expect(runSpy).toHaveBeenCalled(); + describe(SubscriberService.name, () => { + let module: TestingModule; + let service: SubscriberService; + let yRedisClient: DeepMocked; + + beforeEach(async () => { + module = await Test.createTestingModule({ + providers: [ + SubscriberService, + { + provide: YRedisClient, + useValue: createMock(), + }, + { + provide: Logger, + useValue: createMock(), + }, + ], + }).compile(); + + service = module.get(SubscriberService); + yRedisClient = module.get(YRedisClient); }); - it('should return subscriber', async () => { - const { store, createRedisInstance } = setup(); - - const subscriber = await subscriberService.createSubscriber(store, createRedisInstance); - - expect(subscriber).toBeDefined(); - expect(subscriber).toBeInstanceOf(subscriberService.Subscriber); + afterEach(() => { + jest.restoreAllMocks(); }); - }); - - describe(subscriberService.Subscriber.name, () => { - const setup = () => { - const api = createMock(); - const subscriber = new subscriberService.Subscriber(api); - - return { subscriber, api }; - }; it('should be defined', () => { - const { subscriber } = setup(); - - expect(subscriber).toBeDefined(); + expect(service).toBeDefined(); }); describe('ensureSubId', () => { it('should update nextId when id is smaller', () => { - const { subscriber } = setup(); const id = '1'; const stream = 'test'; - subscriber.subscribers.set(stream, { fs: new Set(), id: '2', nextId: null }); + service.subscribers.set(stream, { fs: new Set(), id: '2', nextId: null }); - subscriber.ensureSubId(stream, id); + service.ensureSubId(stream, id); - expect(subscriber.subscribers.get(stream)?.nextId).toEqual(id); + expect(service.subscribers.get(stream)?.nextId).toEqual(id); }); it('should not update nextId when id is not smaller', () => { - const { subscriber } = setup(); const id = '3'; const stream = 'test'; - subscriber.subscribers.set(stream, { fs: new Set(), id: '2', nextId: null }); + service.subscribers.set(stream, { fs: new Set(), id: '2', nextId: null }); - subscriber.ensureSubId(stream, id); + service.ensureSubId(stream, id); - expect(subscriber.subscribers.get(stream)?.nextId).toBeNull(); + expect(service.subscribers.get(stream)?.nextId).toBeNull(); }); }); describe('subscribe', () => { describe('when stream is not present', () => { it('should add stream to subscribers', () => { - const { subscriber } = setup(); const subscriptionHandler = jest.fn(); - subscriber.subscribe('test', subscriptionHandler); + service.subscribe('test', subscriptionHandler); - expect(subscriber.subscribers.size).toEqual(1); + expect(service.subscribers.size).toEqual(1); }); it('should add subscription handler to stream', () => { - const { subscriber } = setup(); const subscriptionHandler = jest.fn(); - subscriber.subscribe('test', subscriptionHandler); + service.subscribe('test', subscriptionHandler); - expect(subscriber.subscribers.get('test')?.fs.size).toEqual(1); + expect(service.subscribers.get('test')?.fs.size).toEqual(1); }); - it('should have many subscriber', () => { - const { subscriber } = setup(); + it('should have two subscribers', () => { const subscriptionHandler = jest.fn(); - subscriber.subscribe('test', subscriptionHandler); + service.subscribe('test', subscriptionHandler); - expect(subscriber.subscribers.size).toEqual(1); - subscriber.subscribe('test1', subscriptionHandler); - expect(subscriber.subscribers.size).toEqual(2); + expect(service.subscribers.size).toEqual(1); + service.subscribe('test1', subscriptionHandler); + expect(service.subscribers.size).toEqual(2); }); it('should add stream to subscribers with next id as null', () => { - const { subscriber } = setup(); const subscriptionHandler = jest.fn(); - subscriber.subscribe('test', subscriptionHandler); + service.subscribe('test', subscriptionHandler); - expect(subscriber.subscribers.get('test')?.nextId).toBeNull(); + expect(service.subscribers.get('test')?.nextId).toBeNull(); }); it('should add stream to subscribers with id as 0', () => { - const { subscriber } = setup(); const subscriptionHandler = jest.fn(); - subscriber.subscribe('test', subscriptionHandler); + service.subscribe('test', subscriptionHandler); - expect(subscriber.subscribers.get('test')?.id).toEqual('0'); + expect(service.subscribers.get('test')?.id).toEqual('0'); }); it('should add stream to subscribers with subscription handler', () => { - const { subscriber } = setup(); const subscriptionHandler = jest.fn(); - subscriber.subscribe('test', subscriptionHandler); + service.subscribe('test', subscriptionHandler); - expect(subscriber.subscribers.get('test')?.fs.has(subscriptionHandler)).toBeTruthy(); + expect(service.subscribers.get('test')?.fs.has(subscriptionHandler)).toBeTruthy(); }); it('should return correctly result', () => { - const { subscriber } = setup(); const subscriptionHandler = jest.fn(); - const result = subscriber.subscribe('test', subscriptionHandler); + const result = service.subscribe('test', subscriptionHandler); expect(result).toEqual({ redisId: '0' }); }); @@ -182,55 +127,68 @@ describe('SubscriberService', () => { describe('unsubscribe', () => { describe('when stream is present', () => { it('should remove just once subscription handler from stream', () => { - const { subscriber } = setup(); const subscriptionHandler = jest.fn(); const subscriptionHandler1 = jest.fn(); - subscriber.subscribe('test', subscriptionHandler); - subscriber.subscribe('test', subscriptionHandler1); - subscriber.unsubscribe('test', subscriptionHandler); + service.subscribe('test', subscriptionHandler); + service.subscribe('test', subscriptionHandler1); + service.unsubscribe('test', subscriptionHandler); - expect(subscriber.subscribers.get('test')?.fs.size).toEqual(1); + expect(service.subscribers.get('test')?.fs.size).toEqual(1); }); it('should remove stream from subscribers when fs size is 0', () => { - const { subscriber } = setup(); const subscriptionHandler = jest.fn(); - subscriber.subscribe('test', subscriptionHandler); - subscriber.unsubscribe('test', subscriptionHandler); + service.subscribe('test', subscriptionHandler); + service.unsubscribe('test', subscriptionHandler); - expect(subscriber.subscribers.size).toEqual(0); + expect(service.subscribers.size).toEqual(0); + }); + }); + }); + + describe('status', () => { + describe('when running is true', () => { + it('should return true', () => { + expect(service.status()).toBeTruthy(); + }); + }); + + describe('when running is false', () => { + it('should return false', () => { + service.stop(); + + expect(service.status()).toBeFalsy(); }); }); }); describe('destroy', () => { it('should call client destroy', async () => { - const { subscriber, api } = setup(); - - await subscriber.destroy(); + await service.onModuleDestroy(); - expect(api.destroy).toHaveBeenCalled(); + expect(yRedisClient.destroy).toHaveBeenCalled(); }); }); describe('run', () => { const setupRun = () => { - const { api, subscriber } = setup(); const subscriptionHandler = jest.fn(); - subscriber.subscribe('test', subscriptionHandler); + service.subscribe('test', subscriptionHandler); + const messages = yRedisMessageFactory.build({ stream: 'test' }); + yRedisClient.getMessages.mockResolvedValue([messages]); - return { api, subscriber, subscriptionHandler }; + return { subscriptionHandler }; }; it('should call client getMessages', async () => { - const { api, subscriber } = setupRun(); + setupRun(); - await subscriber.run(); + await service.run(); - expect(api.getMessages).toHaveBeenCalledWith( + expect(yRedisClient.getMessages).toHaveBeenCalledWith( expect.arrayContaining([ { key: expect.any(String), @@ -241,22 +199,22 @@ describe('SubscriberService', () => { }); it('should call subscription handler', async () => { - const { api, subscriber } = setupRun(); + setupRun(); const messages = yRedisMessageFactory.buildList(3, { stream: 'test' }); - const spyGetSubscribers = jest.spyOn(subscriber.subscribers, 'get'); - api.getMessages.mockResolvedValueOnce(messages); + const spyGetSubscribers = jest.spyOn(service.subscribers, 'get'); + yRedisClient.getMessages.mockResolvedValueOnce(messages); - await subscriber.run(); + await service.run(); expect(spyGetSubscribers).toHaveBeenCalledTimes(3); }); it('should call subscription handler', async () => { - const { api, subscriber, subscriptionHandler } = setupRun(); + const { subscriptionHandler } = setupRun(); const messages = yRedisMessageFactory.buildList(3, { stream: 'test' }); - api.getMessages.mockResolvedValue(messages); + yRedisClient.getMessages.mockResolvedValue(messages); - await subscriber.run(); + await service.run(); expect(subscriptionHandler).toHaveBeenCalledWith(messages[0].stream, messages[0].messages); expect(subscriptionHandler).toHaveBeenCalledWith(messages[1].stream, messages[1].messages); @@ -264,25 +222,24 @@ describe('SubscriberService', () => { }); it('should skip subscription handler', async () => { - const { api, subscriber, subscriptionHandler } = setupRun(); + const { subscriptionHandler } = setupRun(); const messages = yRedisMessageFactory.buildList(3, { stream: 'skip' }); - api.getMessages.mockResolvedValue(messages); + yRedisClient.getMessages.mockResolvedValue(messages); - await subscriber.run(); + await service.run(); expect(subscriptionHandler).not.toHaveBeenCalled(); }); describe('when nextId is not null', () => { const setupRun = () => { - const { api, subscriber } = setup(); const subscriptionHandler = jest.fn(); - subscriber.subscribe('test', subscriptionHandler); + service.subscribe('test', subscriptionHandler); const messages = yRedisMessageFactory.build({ stream: 'test' }); - api.getMessages.mockResolvedValue([messages]); + yRedisClient.getMessages.mockResolvedValue([messages]); - const testSubscriber = subscriber.subscribers.get('test'); + const testSubscriber = service.subscribers.get('test'); if (testSubscriber) { testSubscriber.nextId = '1'; } @@ -292,13 +249,13 @@ describe('SubscriberService', () => { id: '1', }; - return { api, subscriber, testSubscriber, expectedMessages }; + return { yRedisClient, service, testSubscriber, expectedMessages }; }; it('should set id and nextId ', async () => { - const { subscriber, testSubscriber, expectedMessages } = setupRun(); + const { service, testSubscriber, expectedMessages } = setupRun(); - await subscriber.run(); + await service.run(); expect(testSubscriber).toEqual(expect.objectContaining(expectedMessages)); }); diff --git a/src/infra/y-redis/subscriber.service.ts b/src/infra/y-redis/subscriber.service.ts index 8614e648..04467fb7 100644 --- a/src/infra/y-redis/subscriber.service.ts +++ b/src/infra/y-redis/subscriber.service.ts @@ -5,43 +5,50 @@ The original code from the `y-redis` repository is licensed under the AGPL-3.0 license. https://github.com/yjs/y-redis */ +import { Injectable, OnModuleDestroy } from '@nestjs/common'; import * as map from 'lib0/map'; -import { RedisService } from '../redis/redis.service.js'; -import { Api, createApiClient } from './api.service.js'; +import { Logger } from '../../infra/logger/logger.js'; +import { StreamNameClockPair } from '../../infra/redis/interfaces/stream-name-clock-pair.js'; import { isSmallerRedisId } from './helper.js'; -import { DocumentStorage } from './storage.js'; +import { YRedisClient } from './y-redis.client.js'; -export const running = true; - -export const run = async (subscriber: Subscriber): Promise => { - while (running) { - await subscriber.run(); - } -}; - -type SubscriptionHandler = (stream: string, message: Uint8Array[]) => void; +export type SubscriptionHandler = (stream: string, message: Uint8Array[]) => void; interface Subscriptions { fs: Set; id: string; nextId?: string | null; } -export const createSubscriber = async ( - store: DocumentStorage, - createRedisInstance: RedisService, -): Promise => { - const client = await createApiClient(store, createRedisInstance); - const subscriber = new Subscriber(client); - // Here we are not using an "await", as it would block further execution - // of our code, as the subscriber.run() is an infinite loop. - run(subscriber); - - return subscriber; -}; - -export class Subscriber { + +@Injectable() +export class SubscriberService implements OnModuleDestroy { + private running = true; public readonly subscribers = new Map(); - public constructor(private readonly client: Api) {} + public constructor( + private readonly yRedisClient: YRedisClient, + private readonly logger: Logger, + ) { + this.logger.setContext(SubscriberService.name); + } + + public async start(): Promise { + this.running = true; + this.logger.info(`Start sync messages process`); + + while (this.running) { + const streams = await this.run(); + await this.waitIfStreamsEmpty(streams); + } + } + + public stop(): void { + this.running = false; + this.logger.info(`Ended sync messages process`); + } + + public status(): boolean { + return this.running; + } public ensureSubId(stream: string, id: string): void { const sub = this.subscribers.get(stream); @@ -69,14 +76,29 @@ export class Subscriber { } } - public async destroy(): Promise { - await this.client.destroy(); + public async onModuleDestroy(): Promise { + this.stop(); + await this.yRedisClient.destroy(); + } + + private async waitIfStreamsEmpty(streams: StreamNameClockPair[], waitInMs = 50): Promise { + if (streams.length === 0) { + await new Promise((resolve) => setTimeout(resolve, waitInMs)); + } + } + + public async run(): Promise { + const streams = this.getSubscriberStreams(); + + if (streams.length > 0) { + await this.publishMessages(streams); + } + + return streams; } - public async run(): Promise { - const messages = await this.client.getMessages( - Array.from(this.subscribers.entries()).map(([stream, s]) => ({ key: stream, id: s.id })), - ); + private async publishMessages(streams: StreamNameClockPair[]): Promise { + const messages = await this.yRedisClient.getMessages(streams); for (const message of messages) { const sub = this.subscribers.get(message.stream); @@ -86,7 +108,14 @@ export class Subscriber { sub.id = sub.nextId; sub.nextId = null; } - sub.fs.forEach((f) => f(message.stream, message.messages)); + sub.fs.forEach((subscriberCallback) => subscriberCallback(message.stream, message.messages)); } } + + private getSubscriberStreams(): StreamNameClockPair[] { + const subscribers = Array.from(this.subscribers.entries()); + const streams = subscribers.map(([stream, s]) => ({ key: stream, id: s.id })); + + return streams; + } } diff --git a/src/infra/y-redis/testing/stream-message-reply.factory.ts b/src/infra/y-redis/testing/stream-message-reply.factory.ts deleted file mode 100644 index de02ab31..00000000 --- a/src/infra/y-redis/testing/stream-message-reply.factory.ts +++ /dev/null @@ -1,14 +0,0 @@ -import { Factory } from 'fishery'; -import { StreamMessageReply } from '../../../infra/redis/interfaces/index.js'; - -export const streamMessageReplyFactory = Factory.define(({ sequence }) => { - return { - id: `redis-id-${sequence}`, - message: { - key: `redis-key-${sequence}-1`, - m: `message-${sequence}-2`, - docName: `doc-name-${sequence}`, - compact: `compact-${sequence}`, - }, - }; -}); diff --git a/src/infra/y-redis/testing/y-redis-doc.factory.ts b/src/infra/y-redis/testing/y-redis-doc.factory.ts new file mode 100644 index 00000000..110a137a --- /dev/null +++ b/src/infra/y-redis/testing/y-redis-doc.factory.ts @@ -0,0 +1,21 @@ +import { createMock } from '@golevelup/ts-jest'; +import { Factory } from 'fishery'; +import { Awareness } from 'y-protocols/awareness'; +import { Doc } from 'yjs'; +import { YRedisDoc } from '../y-redis-doc.js'; + +export const yRedisDocFactory = Factory.define(({ sequence }) => { + return { + ydoc: createMock(), + awareness: createMock({ + destroy: () => { + return; + }, + }), + redisLastId: `last-id-${sequence}`, + storeReferences: null, + docChanged: false, + streamName: `prefix:room:roomid:docid`, + getAwarenessStateSize: (): number => 0, + }; +}); diff --git a/src/infra/y-redis/testing/y-redis-user.factory.ts b/src/infra/y-redis/testing/y-redis-user.factory.ts new file mode 100644 index 00000000..7763583e --- /dev/null +++ b/src/infra/y-redis/testing/y-redis-user.factory.ts @@ -0,0 +1,20 @@ +import { Factory } from 'fishery'; +import { YRedisUser } from '../y-redis-user.js'; + +export const yRedisUserFactory = Factory.define(({ sequence }) => { + const error = null; + + return { + initialRedisSubId: '0', + room: `room-${sequence}`, + hasWriteAccess: false, + userid: `userid-${sequence}`, + error, + subs: new Set(), + id: sequence, + awarenessId: sequence, + awarenessLastClock: 0, + awarenessLastUpdated: new Date(), + isClosed: false, + }; +}); diff --git a/src/infra/y-redis/ws.service.spec.ts b/src/infra/y-redis/ws.service.spec.ts deleted file mode 100644 index a5a2cb2d..00000000 --- a/src/infra/y-redis/ws.service.spec.ts +++ /dev/null @@ -1,1075 +0,0 @@ -import { createMock } from '@golevelup/ts-jest'; -import { encoding } from 'lib0'; -import * as uws from 'uWebSockets.js'; -import { Awareness } from 'y-protocols/awareness.js'; -import * as Y from 'yjs'; -import { RedisService } from '../redis/redis.service.js'; -import * as apiClass from './api.service.js'; -import { Api } from './api.service.js'; -import { computeRedisRoomStreamName } from './helper.js'; -import * as protocol from './protocol.js'; -import { DocumentStorage } from './storage.js'; -import * as subscriberClass from './subscriber.service.js'; -import { Subscriber } from './subscriber.service.js'; -import { - closeCallback, - messageCallback, - openCallback, - registerYWebsocketServer, - upgradeCallback, - User, - YWebsocketServer, -} from './ws.service.js'; - -describe('ws service', () => { - beforeEach(() => { - jest.resetAllMocks(); - }); - - const buildUpdate = (props: { - messageType: number; - length: number; - numberOfUpdates: number; - awarenessId: number; - lastClock: number; - }): Buffer => { - const { messageType, length, numberOfUpdates, awarenessId, lastClock } = props; - const encoder = encoding.createEncoder(); - encoding.writeVarUint(encoder, messageType); // - encoding.writeVarUint(encoder, length); // Length of update - encoding.writeVarUint(encoder, numberOfUpdates); // Number of awareness updates - encoding.writeVarUint(encoder, awarenessId); // Awareness id - encoding.writeVarUint(encoder, lastClock); // Lasclocl - - return Buffer.from(encoding.toUint8Array(encoder)); - }; - - describe('registerYWebsocketServer', () => { - const setup = () => { - const app = createMock(); - const pattern = 'pattern'; - const store = createMock(); - const checkAuth = jest.fn(); - const options = {}; - const createRedisInstance = createMock(); - const client = createMock(); - jest.spyOn(apiClass, 'createApiClient').mockResolvedValueOnce(client); - const subscriber = createMock(); - jest.spyOn(subscriberClass, 'createSubscriber').mockResolvedValueOnce(subscriber); - - return { app, pattern, store, checkAuth, options, createRedisInstance, subscriber, client }; - }; - - it('returns YWebsocketServer', async () => { - const { app, pattern, store, checkAuth, options, createRedisInstance } = setup(); - - const result = await registerYWebsocketServer(app, pattern, store, checkAuth, options, createRedisInstance); - - expect(result).toEqual(expect.any(YWebsocketServer)); - }); - - describe('yWebsocketServer.destroy', () => { - it('should destroy client and subscriber', async () => { - const { app, pattern, store, checkAuth, options, createRedisInstance, subscriber, client } = setup(); - - const server = await registerYWebsocketServer(app, pattern, store, checkAuth, options, createRedisInstance); - server.destroy(); - - expect(subscriber.destroy).toHaveBeenCalledTimes(1); - expect(client.destroy).toHaveBeenCalledTimes(1); - }); - }); - }); - - describe('upgradeCallback', () => { - describe('when aborted is emitted from response', () => { - it('should return', async () => { - const res = createMock(); - const req = createMock(); - const context = createMock(); - const checkAuth = jest.fn(); - - await upgradeCallback(res, req, context, checkAuth); - - res.aborted(); - - expect(res.upgrade).not.toHaveBeenCalled(); - }); - }); - - describe('when checkAuth rejects', () => { - it('should return 500 Internal Server Error', async () => { - const res = createMock(); - const req = createMock(); - const context = createMock(); - const checkAuth = jest.fn().mockRejectedValue(new Error('error')); - res.writeStatus.mockImplementationOnce(() => res); - - await upgradeCallback(res, req, context, checkAuth); - - expect(res.cork).toHaveBeenCalledTimes(1); - expect(res.cork).toHaveBeenCalledWith(expect.any(Function)); - res.cork.mock.calls[0][0](); - expect(res.writeStatus).toHaveBeenCalledWith('500 Internal Server Error'); - expect(res.end).toHaveBeenCalledWith('Internal Server Error'); - }); - }); - - describe('when checkAuth resolves ', () => { - describe('when connection is not aborted', () => { - it('should upgrade the connection', async () => { - const res = createMock(); - const req = createMock(); - const context = createMock(); - const checkAuth = jest.fn().mockResolvedValue({ hasWriteAccess: true, room: 'room', userid: 'userid' }); - - await upgradeCallback(res, req, context, checkAuth); - - expect(res.cork).toHaveBeenCalledTimes(1); - expect(res.cork).toHaveBeenCalledWith(expect.any(Function)); - res.cork.mock.calls[0][0](); - expect(res.upgrade).toHaveBeenCalledWith( - expect.objectContaining({ - awarenessId: null, - awarenessLastClock: 0, - error: null, - hasWriteAccess: true, - id: 0, - initialRedisSubId: '0', - isClosed: false, - room: 'room', - userid: 'userid', - }), - req.getHeader('sec-websocket-key'), - req.getHeader('sec-websocket-protocol'), - req.getHeader('sec-websocket-extensions'), - context, - ); - }); - }); - - describe('when connection is aborted', () => { - it('should not upgrade the connection', async () => { - const res = createMock(); - const req = createMock(); - const context = createMock(); - const checkAuth = jest.fn().mockImplementationOnce(async () => { - res.onAborted.mock.calls[0][0](); - - return await Promise.resolve({ hasWriteAccess: true, room: 'room', userid: 'userid' }); - }); - - await upgradeCallback(res, req, context, checkAuth); - - expect(res.cork).not.toHaveBeenCalled(); - }); - }); - }); - }); - - describe('openCallback', () => { - const buildParams = () => { - const ws = createMock>(); - const subscriber = createMock(); - const client = createMock({ redisPrefix: 'prefix' }); - const redisMessageSubscriber = jest.fn(); - const openWsCallback = jest.fn(); - const initDocCallback = jest.fn(); - - return { ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback }; - }; - - it('should call getUserData', async () => { - const { ws, subscriber, client, redisMessageSubscriber } = buildParams(); - - await openCallback(ws, subscriber, client, redisMessageSubscriber); - - expect(ws.getUserData).toHaveBeenCalledTimes(1); - }); - - describe('when user has error property', () => { - const setup = () => { - const { ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback } = buildParams(); - - const code = 111; - const reason = 'reason'; - const user = createMock({ error: { code, reason } }); - ws.getUserData.mockReturnValue(user); - - return { ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback, code, reason }; - }; - - it('should call ws.end', async () => { - const { ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback, code, reason } = - setup(); - - await openCallback(ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback); - - expect(ws.end).toHaveBeenCalledWith(code, reason); - }); - - it('should not call openWsCallback', async () => { - const { ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback } = setup(); - - await openCallback(ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback); - - expect(openWsCallback).not.toHaveBeenCalled(); - }); - }); - - describe('when users room property is null', () => { - const setup = () => { - const { ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback } = buildParams(); - - const user = createMock({ room: null, error: null }); - ws.getUserData.mockReturnValue(user); - - return { ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback }; - }; - - it('should call ws.end', async () => { - const { ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback } = setup(); - - await openCallback(ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback); - - expect(ws.end).toHaveBeenCalledWith(1008); - }); - - it('should not call openWsCallback', async () => { - const { ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback } = setup(); - - await openCallback(ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback); - - expect(openWsCallback).not.toHaveBeenCalled(); - }); - }); - - describe('when users userid property is null', () => { - const setup = () => { - const { ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback } = buildParams(); - - const user = createMock({ userid: null, error: null }); - ws.getUserData.mockReturnValue(user); - - return { ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback }; - }; - - it('should call ws.end', async () => { - const { ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback } = setup(); - - await openCallback(ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback); - - expect(ws.end).toHaveBeenCalledWith(1008); - }); - - it('should not call openWsCallback', async () => { - const { ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback } = setup(); - - await openCallback(ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback); - - expect(openWsCallback).not.toHaveBeenCalled(); - }); - }); - - describe('when user has room and no error and is not closed', () => { - const setup = () => { - const { ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback } = buildParams(); - - const user = createMock({ error: null, room: 'room', isClosed: false }); - ws.getUserData.mockReturnValue(user); - - const redisStream = computeRedisRoomStreamName(user.room ?? '', 'index', client.redisPrefix); - - return { ws, subscriber, client, user, redisStream, redisMessageSubscriber, openWsCallback, initDocCallback }; - }; - - it('should call openWsCallback', async () => { - const { ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback } = setup(); - - await openCallback(ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback); - - expect(openWsCallback).toHaveBeenCalledWith(ws); - }); - - it('should add stream to user subscriptions', async () => { - const { ws, subscriber, client, user, redisStream, redisMessageSubscriber, openWsCallback, initDocCallback } = - setup(); - - await openCallback(ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback); - - expect(user.subs.add).toHaveBeenCalledWith(redisStream); - }); - - it('should subscribe ws to stream', async () => { - const { ws, subscriber, client, redisStream, redisMessageSubscriber, openWsCallback, initDocCallback } = - setup(); - - await openCallback(ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback); - - expect(ws.subscribe).toHaveBeenCalledWith(redisStream); - }); - - it('should subscribe subscriber to stream', async () => { - const { ws, subscriber, client, redisStream, redisMessageSubscriber, openWsCallback, initDocCallback } = - setup(); - - await openCallback(ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback); - - expect(subscriber.subscribe).toHaveBeenCalledWith(redisStream, redisMessageSubscriber); - }); - - it('should get doc from client', async () => { - const { ws, subscriber, client, user, redisMessageSubscriber, openWsCallback, initDocCallback } = setup(); - - await openCallback(ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback); - - expect(client.getDoc).toHaveBeenCalledWith(user.room, 'index'); - }); - - describe('when getDoc rejects', () => { - it('should call ws.end', async () => { - const { ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback } = setup(); - client.getDoc.mockRejectedValue(new Error('error')); - - await openCallback(ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback); - - expect(ws.end).toHaveBeenCalledWith(1011); - }); - }); - - describe('when getDoc resolves with ydoc.store.clients.size 0', () => { - it('should call initDocCallback', async () => { - const { ws, subscriber, client, user, redisMessageSubscriber, openWsCallback, initDocCallback } = setup(); - const ydoc = createMock({ store: { clients: { size: 0 } } }); - const awareness = createMock(); - client.getDoc.mockResolvedValueOnce({ - ydoc, - awareness, - redisLastId: '0', - storeReferences: [], - docChanged: true, - }); - - await openCallback(ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback); - - expect(initDocCallback).toHaveBeenCalledWith(user.room, 'index', client); - }); - - describe('when ydoc.awareness.states.size > 0', () => { - it('should call ws.cork', async () => { - const { ws, subscriber, client, redisMessageSubscriber } = setup(); - const ydoc = createMock({ store: { clients: { size: 0 } } }); - const awareness = createMock({ states: new Map([['1', '1']]) }); - client.getDoc.mockResolvedValueOnce({ - ydoc, - awareness, - redisLastId: '0', - storeReferences: [], - docChanged: true, - }); - const encodedArray = new Uint8Array([1, 2, 3]); - jest.spyOn(protocol, 'encodeSyncStep1').mockReturnValueOnce(encodedArray); - jest.spyOn(Y, 'encodeStateVector').mockReturnValueOnce(encodedArray); - jest.spyOn(protocol, 'encodeSyncStep2').mockReturnValueOnce(encodedArray); - jest.spyOn(Y, 'encodeStateAsUpdate').mockReturnValueOnce(encodedArray); - jest.spyOn(protocol, 'encodeAwarenessUpdate').mockReturnValueOnce(encodedArray); - - await openCallback(ws, subscriber, client, redisMessageSubscriber); - - expect(ws.cork).toHaveBeenCalledTimes(1); - expect(ws.cork).toHaveBeenCalledWith(expect.any(Function)); - ws.cork.mock.calls[0][0](); - expect(ws.send).toHaveBeenNthCalledWith(1, encodedArray, true, false); - expect(ws.send).toHaveBeenNthCalledWith(2, encodedArray, true, true); - expect(ws.send).toHaveBeenNthCalledWith(3, encodedArray, true, true); - }); - }); - - describe('when ydoc.awareness.states.size = 0', () => { - it('should call ws.cork', async () => { - const { ws, subscriber, client, redisMessageSubscriber } = setup(); - const ydoc = createMock({ store: { clients: { size: 0 } } }); - const awareness = createMock({ states: new Map([]) }); - client.getDoc.mockResolvedValueOnce({ - ydoc, - awareness, - redisLastId: '0', - storeReferences: [], - docChanged: true, - }); - const encodedArray = new Uint8Array([1, 2, 3]); - jest.spyOn(protocol, 'encodeSyncStep1').mockReturnValueOnce(encodedArray); - jest.spyOn(Y, 'encodeStateVector').mockReturnValueOnce(encodedArray); - jest.spyOn(protocol, 'encodeSyncStep2').mockReturnValueOnce(encodedArray); - jest.spyOn(Y, 'encodeStateAsUpdate').mockReturnValueOnce(encodedArray); - jest.spyOn(protocol, 'encodeAwarenessUpdate').mockReturnValueOnce(encodedArray); - - await openCallback(ws, subscriber, client, redisMessageSubscriber); - - expect(ws.cork).toHaveBeenCalledTimes(1); - expect(ws.cork).toHaveBeenCalledWith(expect.any(Function)); - ws.cork.mock.calls[0][0](); - expect(ws.send).toHaveBeenCalledTimes(2); - expect(ws.send).toHaveBeenNthCalledWith(1, encodedArray, true, false); - expect(ws.send).toHaveBeenNthCalledWith(2, encodedArray, true, true); - }); - }); - - describe('when lastId is smaller than initial redis id', () => { - it('should call subscriber.ensureSubId', async () => { - const { ws, subscriber, client, user, redisMessageSubscriber } = setup(); - const ydoc = createMock({ store: { clients: { size: 0 } } }); - const awareness = createMock(); - client.getDoc.mockResolvedValueOnce({ - ydoc, - awareness, - redisLastId: '0-1', - storeReferences: [], - docChanged: true, - }); - subscriber.subscribe.mockReturnValueOnce({ redisId: '1-2' }); - const redisStream = computeRedisRoomStreamName(user.room ?? '', 'index', client.redisPrefix); - - await openCallback(ws, subscriber, client, redisMessageSubscriber); - - expect(subscriber.ensureSubId).toHaveBeenCalledWith(redisStream, '0-1'); - }); - }); - - describe('when lastId is bigger than initial redis id', () => { - it('should call subscriber.ensureSubId', async () => { - const { ws, subscriber, client, redisMessageSubscriber } = setup(); - const ydoc = createMock({ store: { clients: { size: 0 } } }); - const awareness = createMock(); - client.getDoc.mockResolvedValueOnce({ - ydoc, - awareness, - redisLastId: '2-1', - storeReferences: [], - docChanged: true, - }); - subscriber.subscribe.mockReturnValueOnce({ redisId: '1-2' }); - - await openCallback(ws, subscriber, client, redisMessageSubscriber); - - expect(subscriber.ensureSubId).not.toHaveBeenCalled(); - }); - }); - }); - - describe('when getDoc resolves with ydoc.store.clients.size > 0', () => { - it('should call ws.end', async () => { - const { ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback } = setup(); - const ydoc = createMock({ store: { clients: { size: 1 } } }); - const awareness = createMock(); - client.getDoc.mockResolvedValue({ - ydoc, - awareness, - redisLastId: '0', - storeReferences: [], - docChanged: true, - }); - - await openCallback(ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback); - - expect(initDocCallback).not.toHaveBeenCalled(); - }); - }); - }); - - describe('when user is closed', () => { - const setup = () => { - const { ws, subscriber, client, redisMessageSubscriber } = buildParams(); - - const user = createMock({ error: null, room: 'room', isClosed: true }); - ws.getUserData.mockReturnValue(user); - - return { ws, subscriber, client, redisMessageSubscriber }; - }; - - it('should not call ws.cork', async () => { - const { ws, subscriber, client, redisMessageSubscriber } = setup(); - - await openCallback(ws, subscriber, client, redisMessageSubscriber); - - expect(ws.cork).not.toHaveBeenCalled(); - }); - }); - }); - - describe('messageCallback', () => { - const buildParams = () => { - const ws = createMock>(); - const client = createMock({ redisPrefix: 'prefix' }); - - return { ws, client }; - }; - - describe('when user has write access', () => { - describe('when user has room', () => { - describe('when error is thrown', () => { - const setup = () => { - const { ws, client } = buildParams(); - - ws.getUserData.mockImplementationOnce(() => { - throw new Error('error'); - }); - const messageBuffer = buildUpdate({ - messageType: protocol.messageAwareness, - length: 0, - numberOfUpdates: 1, - awarenessId: 75, - lastClock: 76, - }); - - return { ws, client, messageBuffer }; - }; - - it('should not pass the error and call ws.end', () => { - const { ws, client, messageBuffer } = setup(); - - messageCallback(ws, messageBuffer, client); - - expect(ws.end).toHaveBeenCalledWith(1011); - }); - }); - - describe('when message is awareness update and users awarenessid is null', () => { - const setup = () => { - const { ws, client } = buildParams(); - const user = createMock({ - hasWriteAccess: true, - room: 'room', - awarenessId: null, - awarenessLastClock: 99, - }); - ws.getUserData.mockReturnValueOnce(user); - const messageBuffer = buildUpdate({ - messageType: protocol.messageAwareness, - length: 0, - numberOfUpdates: 1, - awarenessId: 75, - lastClock: 76, - }); - - return { ws, client, messageBuffer, user }; - }; - - it('should update users awarenessId and awarenessLastClock', () => { - const { ws, client, messageBuffer, user } = setup(); - - messageCallback(ws, messageBuffer, client); - - expect(user.awarenessId).toBe(75); - expect(user.awarenessLastClock).toBe(76); - }); - - it('should call addMessage', () => { - const { ws, client, messageBuffer, user } = setup(); - - messageCallback(ws, messageBuffer, client); - - expect(client.addMessage).toHaveBeenCalledWith(user.room, 'index', messageBuffer); - }); - }); - - describe('when message is awareness update and users awarenessid is messages awarenessid', () => { - const setup = () => { - const { ws, client } = buildParams(); - const user = createMock({ - hasWriteAccess: true, - room: 'room', - awarenessId: 75, - awarenessLastClock: 99, - }); - ws.getUserData.mockReturnValueOnce(user); - const messageBuffer = buildUpdate({ - messageType: protocol.messageAwareness, - length: 0, - numberOfUpdates: 1, - awarenessId: 75, - lastClock: 76, - }); - - return { ws, client, messageBuffer, user }; - }; - - it('should update users awarenessId and awarenessLastClock', () => { - const { ws, client, messageBuffer, user } = setup(); - - messageCallback(ws, messageBuffer, client); - - expect(user.awarenessId).toBe(75); - expect(user.awarenessLastClock).toBe(76); - }); - - it('should call addMessage', () => { - const { ws, client, messageBuffer, user } = setup(); - - messageCallback(ws, messageBuffer, client); - - expect(client.addMessage).toHaveBeenCalledWith(user.room, 'index', messageBuffer); - }); - }); - - describe('when message is sync update', () => { - const setup = () => { - const { ws, client } = buildParams(); - const user = createMock({ - hasWriteAccess: true, - room: 'room', - awarenessId: null, - awarenessLastClock: 99, - }); - ws.getUserData.mockReturnValueOnce(user); - const messageBuffer = buildUpdate({ - messageType: protocol.messageSync, - length: protocol.messageSyncUpdate, - numberOfUpdates: 1, - awarenessId: 75, - lastClock: 76, - }); - - return { ws, client, messageBuffer, user }; - }; - - it('should not update users awarenessId and awarenessLastClock', () => { - const { ws, client, messageBuffer, user } = setup(); - - messageCallback(ws, messageBuffer, client); - - expect(user.awarenessId).toBe(null); - expect(user.awarenessLastClock).toBe(99); - }); - - it('should call addMessage', () => { - const { ws, client, messageBuffer, user } = setup(); - - messageCallback(ws, messageBuffer, client); - - expect(client.addMessage).toHaveBeenCalledWith(user.room, 'index', messageBuffer); - }); - }); - - describe('when message is sync step 2 update', () => { - const setup = () => { - const { ws, client } = buildParams(); - const user = createMock({ - hasWriteAccess: true, - room: 'room', - awarenessId: null, - awarenessLastClock: 99, - }); - ws.getUserData.mockReturnValueOnce(user); - const messageBuffer = buildUpdate({ - messageType: protocol.messageSync, - length: protocol.messageSyncStep2, - numberOfUpdates: 1, - awarenessId: 75, - lastClock: 76, - }); - - return { ws, client, messageBuffer, user }; - }; - - it('should not update users awarenessId and awarenessLastClock', () => { - const { ws, client, messageBuffer, user } = setup(); - - messageCallback(ws, messageBuffer, client); - - expect(user.awarenessId).toBe(null); - expect(user.awarenessLastClock).toBe(99); - }); - - it('should call addMessage', () => { - const { ws, client, messageBuffer, user } = setup(); - - messageCallback(ws, messageBuffer, client); - - expect(client.addMessage).toHaveBeenCalledWith(user.room, 'index', messageBuffer); - }); - }); - - describe('when message is sync step 1 update', () => { - const setup = () => { - const { ws, client } = buildParams(); - const user = createMock({ - hasWriteAccess: true, - room: 'room', - awarenessId: null, - awarenessLastClock: 99, - }); - ws.getUserData.mockReturnValueOnce(user); - const messageBuffer = buildUpdate({ - messageType: protocol.messageSync, - length: protocol.messageSyncStep1, - numberOfUpdates: 1, - awarenessId: 75, - lastClock: 76, - }); - - return { ws, client, messageBuffer, user }; - }; - - it('should not update users awarenessId and awarenessLastClock', () => { - const { ws, client, messageBuffer, user } = setup(); - - messageCallback(ws, messageBuffer, client); - - expect(user.awarenessId).toBe(null); - expect(user.awarenessLastClock).toBe(99); - }); - - it('should not call addMessage', () => { - const { ws, client, messageBuffer, user } = setup(); - - messageCallback(ws, messageBuffer, client); - - expect(client.addMessage).not.toHaveBeenCalledWith(user.room, 'index', messageBuffer); - }); - }); - - describe('when message is of unknown type', () => { - const setup = () => { - const { ws, client } = buildParams(); - const user = createMock({ - hasWriteAccess: true, - room: 'room', - awarenessId: null, - awarenessLastClock: 99, - }); - ws.getUserData.mockReturnValueOnce(user); - const messageBuffer = buildUpdate({ - messageType: 999, - length: 999, - numberOfUpdates: 1, - awarenessId: 75, - lastClock: 76, - }); - - return { ws, client, messageBuffer, user }; - }; - - it('should not update users awarenessId and awarenessLastClock', () => { - const { ws, client, messageBuffer, user } = setup(); - - messageCallback(ws, messageBuffer, client); - - expect(user.awarenessId).toBe(null); - expect(user.awarenessLastClock).toBe(99); - }); - - it('should not call addMessage', () => { - const { ws, client, messageBuffer, user } = setup(); - - messageCallback(ws, messageBuffer, client); - - expect(client.addMessage).not.toHaveBeenCalledWith(user.room, 'index', messageBuffer); - }); - }); - }); - - describe('when user has no room', () => { - describe('when message is awareness update', () => { - const setup = () => { - const { ws, client } = buildParams(); - const user = createMock({ - hasWriteAccess: true, - room: null, - awarenessId: null, - awarenessLastClock: 99, - }); - ws.getUserData.mockReturnValueOnce(user); - const messageBuffer = buildUpdate({ - messageType: protocol.messageAwareness, - length: 0, - numberOfUpdates: 1, - awarenessId: 75, - lastClock: 76, - }); - - return { ws, client, messageBuffer, user }; - }; - - it('should not update users awarenessId and awarenessLastClock', () => { - const { ws, client, messageBuffer, user } = setup(); - - messageCallback(ws, messageBuffer, client); - - expect(user.awarenessId).toBe(null); - expect(user.awarenessLastClock).toBe(99); - }); - - it('should not call addMessage', () => { - const { ws, client, messageBuffer, user } = setup(); - - messageCallback(ws, messageBuffer, client); - - expect(client.addMessage).not.toHaveBeenCalledWith(user.room, 'index', messageBuffer); - }); - }); - }); - }); - - describe('when user has no write access', () => { - describe('when user has room', () => { - describe('when message is awareness update', () => { - const setup = () => { - const { ws, client } = buildParams(); - const user = createMock({ - hasWriteAccess: false, - room: 'room', - awarenessId: null, - awarenessLastClock: 99, - }); - ws.getUserData.mockReturnValueOnce(user); - const messageBuffer = buildUpdate({ - messageType: protocol.messageAwareness, - length: 0, - numberOfUpdates: 1, - awarenessId: 75, - lastClock: 76, - }); - - return { ws, client, messageBuffer, user }; - }; - - it('should not update users awarenessId and awarenessLastClock', () => { - const { ws, client, messageBuffer, user } = setup(); - - messageCallback(ws, messageBuffer, client); - - expect(user.awarenessId).toBe(null); - expect(user.awarenessLastClock).toBe(99); - }); - - it('should not call addMessage', () => { - const { ws, client, messageBuffer, user } = setup(); - - messageCallback(ws, messageBuffer, client); - - expect(client.addMessage).not.toHaveBeenCalledWith(user.room, 'index', messageBuffer); - }); - }); - }); - }); - }); - - describe('closeCallback', () => { - const buildParams = () => { - const ws = createMock>(); - const client = createMock({ redisPrefix: 'prefix' }); - const app = createMock(); - const subscriber = createMock(); - - return { ws, client, app, subscriber }; - }; - - describe('when user has room', () => { - describe('when user has awarenessId', () => { - describe('when error is thrown', () => { - const setup = () => { - const { ws, client, app, subscriber } = buildParams(); - - ws.getUserData.mockImplementationOnce(() => { - throw new Error('error'); - }); - const code = 0; - const message = buildUpdate({ - messageType: protocol.messageAwareness, - length: 0, - numberOfUpdates: 1, - awarenessId: 75, - lastClock: 76, - }); - const redisMessageSubscriber = jest.fn(); - - return { ws, client, app, code, subscriber, message, redisMessageSubscriber }; - }; - - it('should not pass error', () => { - const { app, ws, client, subscriber, code, message, redisMessageSubscriber } = setup(); - - closeCallback(app, ws, client, subscriber, code, message, redisMessageSubscriber); - }); - }); - - describe('when app has 0 subscribers', () => { - const setup = () => { - const { ws, client, app, subscriber } = buildParams(); - app.numSubscribers.mockReturnValue(0); - - const user = createMock({ - room: 'room', - awarenessId: 22, - awarenessLastClock: 1, - subs: new Set(['topic1', 'topic2']), - }); - ws.getUserData.mockReturnValueOnce(user); - const code = 0; - const message = buildUpdate({ - messageType: protocol.messageAwareness, - length: 0, - numberOfUpdates: 1, - awarenessId: 75, - lastClock: 76, - }); - const redisMessageSubscriber = jest.fn(); - const closeWsCallback = jest.fn(); - - return { ws, client, app, code, subscriber, message, redisMessageSubscriber, user, closeWsCallback }; - }; - - it('should call addMessage', () => { - const { app, ws, client, subscriber, code, message, redisMessageSubscriber, user } = setup(); - - closeCallback(app, ws, client, subscriber, code, message, redisMessageSubscriber); - - expect(client.addMessage).toHaveBeenCalledWith( - user.room, - 'index', - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - Buffer.from(protocol.encodeAwarenessUserDisconnected(user.awarenessId!, user.awarenessLastClock)), - ); - }); - - it('should set users isClosed to true', () => { - const { app, ws, client, subscriber, code, message, redisMessageSubscriber, user } = setup(); - - closeCallback(app, ws, client, subscriber, code, message, redisMessageSubscriber); - - expect(user.isClosed).toBe(true); - }); - - it('should call closeWsCallback', () => { - const { app, ws, client, subscriber, code, message, redisMessageSubscriber, closeWsCallback } = setup(); - - closeCallback(app, ws, client, subscriber, code, message, redisMessageSubscriber, closeWsCallback); - - expect(closeWsCallback).toHaveBeenCalledWith(ws, code, message); - }); - - it('should call subscriber.unsubscribe for every topic of user', () => { - const { app, ws, client, subscriber, code, message, redisMessageSubscriber } = setup(); - - closeCallback(app, ws, client, subscriber, code, message, redisMessageSubscriber); - - expect(subscriber.unsubscribe).toHaveBeenNthCalledWith(1, 'topic1', redisMessageSubscriber); - expect(subscriber.unsubscribe).toHaveBeenNthCalledWith(2, 'topic2', redisMessageSubscriber); - }); - }); - - describe('when app has 1 subscriber', () => { - const setup = () => { - const { ws, client, app, subscriber } = buildParams(); - app.numSubscribers.mockReturnValue(1); - const user = createMock({ - room: 'room', - awarenessId: 22, - awarenessLastClock: 1, - subs: new Set(['topic1', 'topic2']), - }); - ws.getUserData.mockReturnValueOnce(user); - const code = 0; - const message = buildUpdate({ - messageType: protocol.messageAwareness, - length: 0, - numberOfUpdates: 1, - awarenessId: 75, - lastClock: 76, - }); - const redisMessageSubscriber = jest.fn(); - - return { ws, client, app, code, subscriber, message, redisMessageSubscriber }; - }; - - it('should not call addMessage', () => { - const { app, ws, client, subscriber, code, message, redisMessageSubscriber } = setup(); - - closeCallback(app, ws, client, subscriber, code, message, redisMessageSubscriber); - - expect(subscriber.unsubscribe).not.toHaveBeenCalled(); - }); - }); - }); - - describe('when user has no awarenessId', () => { - describe('when app has 0 subscribers', () => { - const setup = () => { - const { ws, client, app, subscriber } = buildParams(); - app.numSubscribers.mockReturnValueOnce(0); - const user = createMock({ - room: 'room', - awarenessId: null, - awarenessLastClock: 1, - subs: new Set(['topic1', 'topic2']), - }); - ws.getUserData.mockReturnValueOnce(user); - const code = 0; - const message = buildUpdate({ - messageType: protocol.messageAwareness, - length: 0, - numberOfUpdates: 1, - awarenessId: 75, - lastClock: 76, - }); - const redisMessageSubscriber = jest.fn(); - - return { ws, client, app, code, subscriber, message, redisMessageSubscriber }; - }; - - it('should not call addMessage', () => { - const { app, ws, client, subscriber, code, message, redisMessageSubscriber } = setup(); - - closeCallback(app, ws, client, subscriber, code, message, redisMessageSubscriber); - - expect(client.addMessage).not.toHaveBeenCalled(); - }); - }); - }); - }); - - describe('when user has no room', () => { - const setup = () => { - const { ws, client, app, subscriber } = buildParams(); - app.numSubscribers.mockReturnValue(0); - - const user = createMock({ - room: null, - awarenessId: 22, - awarenessLastClock: 1, - subs: new Set(['topic1', 'topic2']), - }); - ws.getUserData.mockReturnValueOnce(user); - const code = 0; - const message = buildUpdate({ - messageType: protocol.messageAwareness, - length: 0, - numberOfUpdates: 1, - awarenessId: 75, - lastClock: 76, - }); - const redisMessageSubscriber = jest.fn(); - const closeWsCallback = jest.fn(); - - return { ws, client, app, code, subscriber, message, redisMessageSubscriber, user, closeWsCallback }; - }; - - it('should not call addMessage', () => { - const { app, ws, client, subscriber, code, message, redisMessageSubscriber } = setup(); - - closeCallback(app, ws, client, subscriber, code, message, redisMessageSubscriber); - - expect(client.addMessage).not.toHaveBeenCalled(); - }); - - it('should not call closeWsCallback', () => { - const { app, ws, client, subscriber, code, message, redisMessageSubscriber, closeWsCallback } = setup(); - - closeCallback(app, ws, client, subscriber, code, message, redisMessageSubscriber, closeWsCallback); - - expect(closeWsCallback).not.toHaveBeenCalled(); - }); - }); - }); -}); diff --git a/src/infra/y-redis/ws.service.ts b/src/infra/y-redis/ws.service.ts deleted file mode 100644 index 5b1f6467..00000000 --- a/src/infra/y-redis/ws.service.ts +++ /dev/null @@ -1,319 +0,0 @@ -/* This file contains the implementation of the functions, - which was copied from the y-redis repository. - Adopting this code allows us to integrate proven and - optimized logic into the current project. - The original code from the `y-redis` repository is licensed under the AGPL-3.0 license. - https://github.com/yjs/y-redis -*/ -/* eslint-disable max-classes-per-file */ -import * as array from 'lib0/array'; -import * as decoding from 'lib0/decoding'; -import * as encoding from 'lib0/encoding'; -import * as promise from 'lib0/promise'; -import * as uws from 'uWebSockets.js'; -import * as Y from 'yjs'; -import { ResponsePayload } from '../authorization/interfaces/response.payload.js'; -import { RedisService } from '../redis/redis.service.js'; -import { Api, createApiClient } from './api.service.js'; -import { computeRedisRoomStreamName, isSmallerRedisId } from './helper.js'; -import * as protocol from './protocol.js'; -import { DocumentStorage } from './storage.js'; -import { createSubscriber, Subscriber } from './subscriber.service.js'; - -/** - * how to sync - * receive sync-step 1 - * // @todo y-websocket should only accept updates after receiving sync-step 2 - * redisId = ws.sub(conn) - * {doc,redisDocLastId} = api.getdoc() - * compute sync-step 2 - * if (redisId > redisDocLastId) { - * subscriber.ensureId(redisDocLastId) - * } - */ - -export class YWebsocketServer { - public constructor( - public readonly app: uws.TemplatedApp, - public readonly client: Api, - public readonly subscriber: Subscriber, - ) {} - - public async destroy(): Promise { - this.subscriber.destroy(); - await this.client.destroy(); - } -} - -let _idCnt = 0; - -export class User { - public subs: Set; - public id: number; - public awarenessId: number | null; - public awarenessLastClock: number; - public isClosed: boolean; - public initialRedisSubId: string; - - public constructor( - public readonly room: string | null, - public readonly hasWriteAccess: boolean, - /** - * Identifies the User globally. - * Note that several clients can have the same userid (e.g. if a user opened several browser - * windows) - */ - public readonly userid: string | null, - public readonly error: Partial | null = null, - ) { - this.initialRedisSubId = '0'; - this.subs = new Set(); - /** - * This is just an identifier to keep track of the user for logging purposes. - */ - this.id = _idCnt++; - this.awarenessId = null; - this.awarenessLastClock = 0; - this.isClosed = false; - } -} - -export const upgradeCallback = async ( - res: uws.HttpResponse, - req: uws.HttpRequest, - context: uws.us_socket_context_t, - checkAuth: (req: uws.HttpRequest) => Promise, -): Promise => { - try { - const headerWsKey = req.getHeader('sec-websocket-key'); - const headerWsProtocol = req.getHeader('sec-websocket-protocol'); - const headerWsExtensions = req.getHeader('sec-websocket-extensions'); - let aborted = false; - res.onAborted(() => { - aborted = true; - }); - - const { hasWriteAccess, room, userid, error } = await checkAuth(req); - if (aborted) return; - res.cork(() => { - res.upgrade( - new User(room, hasWriteAccess, userid, error), - headerWsKey, - headerWsProtocol, - headerWsExtensions, - context, - ); - }); - } catch (error) { - res.cork(() => { - res.writeStatus('500 Internal Server Error').end('Internal Server Error'); - }); - console.error(error); - } -}; - -export const openCallback = async ( - ws: uws.WebSocket, - subscriber: Subscriber, - client: Api, - redisMessageSubscriber: (stream: string, messages: Uint8Array[]) => void, - openWsCallback?: (ws: uws.WebSocket) => void, - initDocCallback?: (room: string, docname: string, client: Api) => void, -): Promise => { - try { - const user = ws.getUserData(); - if (user.error != null) { - const { code, reason } = user.error; - ws.end(code, reason); - - return; - } - - if (user.room === null || user.userid === null) { - ws.end(1008); - - return; - } - - if (openWsCallback) { - openWsCallback(ws); - } - const stream = computeRedisRoomStreamName(user.room, 'index', client.redisPrefix); - user.subs.add(stream); - ws.subscribe(stream); - user.initialRedisSubId = subscriber.subscribe(stream, redisMessageSubscriber).redisId; - const indexDoc = await client.getDoc(user.room, 'index'); - if (indexDoc.ydoc.store.clients.size === 0) { - if (initDocCallback) { - initDocCallback(user.room, 'index', client); - } - } - if (user.isClosed) return; - ws.cork(() => { - ws.send(protocol.encodeSyncStep1(Y.encodeStateVector(indexDoc.ydoc)), true, false); - ws.send(protocol.encodeSyncStep2(Y.encodeStateAsUpdate(indexDoc.ydoc)), true, true); - if (indexDoc.awareness.states.size > 0) { - ws.send( - protocol.encodeAwarenessUpdate(indexDoc.awareness, array.from(indexDoc.awareness.states.keys())), - true, - true, - ); - } - }); - - // awareness is destroyed here to avoid memory leaks, see: https://github.com/yjs/y-redis/issues/24 - indexDoc.awareness.destroy(); - - if (isSmallerRedisId(indexDoc.redisLastId, user.initialRedisSubId)) { - // our subscription is newer than the content that we received from the api - // need to renew subscription id and make sure that we catch the latest content. - subscriber.ensureSubId(stream, indexDoc.redisLastId); - } - } catch (error) { - console.error(error); - ws.end(1011); - } -}; - -const isAwarenessUpdate = (message: Buffer): boolean => message[0] === protocol.messageAwareness; -const isSyncUpdateOrSyncStep2OrAwarenessUpdate = (message: Buffer): boolean => - (message[0] === protocol.messageSync && - (message[1] === protocol.messageSyncUpdate || message[1] === protocol.messageSyncStep2)) || - isAwarenessUpdate(message); - -export const messageCallback = (ws: uws.WebSocket, messageBuffer: ArrayBuffer, client: Api): void => { - try { - const user = ws.getUserData(); - // don't read any messages from users without write access - if (!user.hasWriteAccess || !user.room) return; - // It is important to copy the data here - const message = Buffer.from(messageBuffer.slice(0, messageBuffer.byteLength)); - - if ( - // filter out messages that we simply want to propagate to all clients - // sync update or sync step 2 - // awareness update - isSyncUpdateOrSyncStep2OrAwarenessUpdate(message) - ) { - if (isAwarenessUpdate(message)) { - const decoder = decoding.createDecoder(message); - decoding.readVarUint(decoder); // read message type - decoding.readVarUint(decoder); // read length of awareness update - const alen = decoding.readVarUint(decoder); // number of awareness updates - const awId = decoding.readVarUint(decoder); - if (alen === 1 && (user.awarenessId === null || user.awarenessId === awId)) { - // only update awareness if len=1 - user.awarenessId = awId; - user.awarenessLastClock = decoding.readVarUint(decoder); - } - } - client.addMessage(user.room, 'index', message); - } else if (message[0] === protocol.messageSync && message[1] === protocol.messageSyncStep1) { - // sync step 1 - // can be safely ignored because we send the full initial state at the beginning - } else { - console.error('Unexpected message type', message); - } - } catch (error) { - console.error(error); - ws.end(1011); - } -}; - -export const closeCallback = ( - app: uws.TemplatedApp, - ws: uws.WebSocket, - client: Api, - subscriber: Subscriber, - code: number, - message: ArrayBuffer, - redisMessageSubscriber: (stream: string, messages: Uint8Array[]) => void, - closeWsCallback?: (ws: uws.WebSocket, code: number, message: ArrayBuffer) => void, -): void => { - try { - const user = ws.getUserData(); - if (!user.room) return; - - user.awarenessId && - client.addMessage( - user.room, - 'index', - Buffer.from(protocol.encodeAwarenessUserDisconnected(user.awarenessId, user.awarenessLastClock)), - ); - user.isClosed = true; - - if (closeWsCallback) { - closeWsCallback(ws, code, message); - } - user.subs.forEach((topic) => { - if (app.numSubscribers(topic) === 0) { - subscriber.unsubscribe(topic, redisMessageSubscriber); - } - }); - } catch (error) { - console.error(error); - } -}; - -/** - * @param {uws.TemplatedApp} app - * @param {uws.RecognizedString} pattern - * @param {import('./storage.js').AbstractStorage} store - * @param {function(uws.HttpRequest): Promise} checkAuth - * @param {Object} conf - * @param {string} [conf.redisPrefix] - * @param {(room:string,docname:string,client:api.Api)=>void} [conf.initDocCallback] - this is called when a doc is - * accessed, but it doesn't exist. You could populate the doc here. However, this function could be - * called several times, until some content exists. So you need to handle concurrent calls. - * @param {(ws:uws.WebSocket)=>void} [conf.openWsCallback] - called when a websocket connection is opened - * @param {(ws:uws.WebSocket,code:number,message:ArrayBuffer)=>void} [conf.closeWsCallback] - called when a websocket connection is closed - * @param {() => Promise} createRedisInstance - */ -export const registerYWebsocketServer = async ( - app: uws.TemplatedApp, - pattern: string, - store: DocumentStorage, - checkAuth: (req: uws.HttpRequest) => Promise, - options: { - initDocCallback?: (room: string, docname: string, client: Api) => void; - openWsCallback?: (ws: uws.WebSocket) => void; - closeWsCallback?: (ws: uws.WebSocket, code: number, message: ArrayBuffer) => void; - }, - createRedisInstance: RedisService, -): Promise => { - const { initDocCallback, openWsCallback, closeWsCallback } = options; - const [client, subscriber] = await promise.all([ - createApiClient(store, createRedisInstance), - createSubscriber(store, createRedisInstance), - ]); - - const redisMessageSubscriber = (stream: string, messages: Uint8Array[]): void => { - if (app.numSubscribers(stream) === 0) { - subscriber.unsubscribe(stream, redisMessageSubscriber); - } - const message = - messages.length === 1 - ? messages[0] - : encoding.encode((encoder) => - messages.forEach((message) => { - encoding.writeUint8Array(encoder, message); - }), - ); - app.publish(stream, message, true, false); - }; - - app.ws(pattern, { - compression: uws.SHARED_COMPRESSOR, - maxPayloadLength: 100 * 1024 * 1024, - idleTimeout: 60, - sendPingsAutomatically: true, - upgrade: (res, req, context) => upgradeCallback(res, req, context, checkAuth), - open: (ws: uws.WebSocket) => - openCallback(ws, subscriber, client, redisMessageSubscriber, openWsCallback, initDocCallback), - message: (ws, messageBuffer) => messageCallback(ws, messageBuffer, client), - close: (ws, code, message) => - closeCallback(app, ws, client, subscriber, code, message, redisMessageSubscriber, closeWsCallback), - }); - - return new YWebsocketServer(app, client, subscriber); -}; diff --git a/src/infra/y-redis/y-redis-client.module.ts b/src/infra/y-redis/y-redis-client.module.ts new file mode 100644 index 00000000..db999b28 --- /dev/null +++ b/src/infra/y-redis/y-redis-client.module.ts @@ -0,0 +1,31 @@ +import { DynamicModule, Module } from '@nestjs/common'; +import { Logger } from '../logger/logger.js'; +import { LoggerModule } from '../logger/logger.module.js'; +import { RedisAdapter } from '../redis/interfaces/redis-adapter.js'; +import { RedisModule } from '../redis/redis.module.js'; +import { StorageModule } from '../storage/storage.module.js'; +import { StorageService } from '../storage/storage.service.js'; +import { YRedisClient } from './y-redis.client.js'; +import { REDIS_FOR_API } from './y-redis.const.js'; + +@Module({}) +export class YRedisClientModule { + public static register(): DynamicModule { + return { + module: YRedisClientModule, + imports: [RedisModule.registerFor(REDIS_FOR_API), StorageModule, LoggerModule], + providers: [ + { + provide: YRedisClient, + useFactory: (redisAdapter: RedisAdapter, storageService: StorageService, logger: Logger): YRedisClient => { + const yRedisClient = new YRedisClient(storageService, redisAdapter, logger); + + return yRedisClient; + }, + inject: [REDIS_FOR_API, StorageService, Logger], + }, + ], + exports: [YRedisClient], + }; + } +} diff --git a/src/infra/y-redis/y-redis-doc.factory.ts b/src/infra/y-redis/y-redis-doc.factory.ts new file mode 100644 index 00000000..67d60b10 --- /dev/null +++ b/src/infra/y-redis/y-redis-doc.factory.ts @@ -0,0 +1,10 @@ +import { YRedisDocProps } from './interfaces/y-redis-doc-props.js'; +import { YRedisDoc } from './y-redis-doc.js'; + +export class YRedisDocFactory { + public static build(props: YRedisDocProps): YRedisDoc { + const yRedisDoc = new YRedisDoc(props); + + return yRedisDoc; + } +} diff --git a/src/infra/y-redis/y-redis-doc.ts b/src/infra/y-redis/y-redis-doc.ts new file mode 100644 index 00000000..7b736fd1 --- /dev/null +++ b/src/infra/y-redis/y-redis-doc.ts @@ -0,0 +1,25 @@ +import { Awareness } from 'y-protocols/awareness'; +import { Doc } from 'yjs'; +import { YRedisDocProps } from './interfaces/y-redis-doc-props.js'; + +export class YRedisDoc { + public readonly ydoc: Doc; + public readonly awareness: Awareness; + public readonly redisLastId: string; + public readonly storeReferences: string[] | null; + public readonly docChanged: boolean; + public readonly streamName: string; + + public constructor(props: YRedisDocProps) { + this.ydoc = props.ydoc; + this.awareness = props.awareness; + this.redisLastId = props.redisLastId; + this.storeReferences = props.storeReferences; + this.docChanged = props.docChanged; + this.streamName = props.streamName; + } + + public getAwarenessStateSize(): number { + return this.awareness.getStates().size; + } +} diff --git a/src/infra/y-redis/y-redis-service.module.ts b/src/infra/y-redis/y-redis-service.module.ts new file mode 100644 index 00000000..ae326809 --- /dev/null +++ b/src/infra/y-redis/y-redis-service.module.ts @@ -0,0 +1,48 @@ +import { DynamicModule, Module } from '@nestjs/common'; +import { Logger } from '../logger/logger.js'; +import { LoggerModule } from '../logger/logger.module.js'; +import { RedisAdapter } from '../redis/interfaces/redis-adapter.js'; +import { RedisModule } from '../redis/redis.module.js'; +import { StorageModule } from '../storage/storage.module.js'; +import { StorageService } from '../storage/storage.service.js'; +import { SubscriberService } from './subscriber.service.js'; +import { YRedisClient } from './y-redis.client.js'; +import { API_FOR_SUBSCRIBER, REDIS_FOR_API, REDIS_FOR_SUBSCRIBER } from './y-redis.const.js'; +import { YRedisService } from './y-redis.service.js'; + +@Module({}) +export class YRedisServiceModule { + public static register(): DynamicModule { + return { + module: YRedisServiceModule, + imports: [ + RedisModule.registerFor(REDIS_FOR_SUBSCRIBER), + RedisModule.registerFor(REDIS_FOR_API), + StorageModule, + LoggerModule, + ], + providers: [ + YRedisService, + { + provide: API_FOR_SUBSCRIBER, + useFactory: (redisAdapter: RedisAdapter, storageService: StorageService, logger: Logger): YRedisClient => { + const yRedisClient = new YRedisClient(storageService, redisAdapter, logger); + + return yRedisClient; + }, + inject: [REDIS_FOR_SUBSCRIBER, StorageService, Logger], + }, + { + provide: SubscriberService, + useFactory: (yRedisClient: YRedisClient, logger: Logger): SubscriberService => { + const subscriber = new SubscriberService(yRedisClient, logger); + + return subscriber; + }, + inject: [API_FOR_SUBSCRIBER, Logger], + }, + ], + exports: [YRedisService], + }; + } +} diff --git a/src/infra/y-redis/y-redis-user.factory.ts b/src/infra/y-redis/y-redis-user.factory.ts new file mode 100644 index 00000000..b60f0385 --- /dev/null +++ b/src/infra/y-redis/y-redis-user.factory.ts @@ -0,0 +1,16 @@ +import { YRedisUser } from './y-redis-user.js'; + +interface YRedisUserProps { + room: string | null; + hasWriteAccess: boolean; + userid: string | null; + error: Partial | null; +} + +export class YRedisUserFactory { + public static build(props: YRedisUserProps): YRedisUser { + const user = new YRedisUser(props.room, props.hasWriteAccess, props.userid, props.error); + + return user; + } +} diff --git a/src/infra/y-redis/y-redis-user.ts b/src/infra/y-redis/y-redis-user.ts new file mode 100644 index 00000000..21c4cc70 --- /dev/null +++ b/src/infra/y-redis/y-redis-user.ts @@ -0,0 +1,25 @@ +export class YRedisUser { + public subs: Set; + public awarenessId: number | null; + public awarenessLastClock: number; + public isClosed: boolean; + public initialRedisSubId: string; + + public constructor( + public readonly room: string | null, + public readonly hasWriteAccess: boolean, + /** + * Identifies the User globally. + * Note that several clients can have the same userid (e.g. if a user opened several browser + * windows) + */ + public readonly userid: string | null, + public readonly error: Partial | null = null, + ) { + this.initialRedisSubId = '0'; + this.subs = new Set(); + this.awarenessId = null; + this.awarenessLastClock = 0; + this.isClosed = false; + } +} diff --git a/src/infra/y-redis/y-redis.client.spec.ts b/src/infra/y-redis/y-redis.client.spec.ts new file mode 100644 index 00000000..48904a36 --- /dev/null +++ b/src/infra/y-redis/y-redis.client.spec.ts @@ -0,0 +1,346 @@ +import { createMock, DeepMocked } from '@golevelup/ts-jest'; +import { Test, TestingModule } from '@nestjs/testing'; +import * as Awareness from 'y-protocols/awareness'; +import * as Y from 'yjs'; +import { Doc, encodeStateAsUpdateV2 } from 'yjs'; +import { Logger } from '../logger/logger.js'; +import { RedisAdapter } from '../redis/interfaces/index.js'; +import { IoRedisAdapter } from '../redis/ioredis.adapter.js'; +import { streamMessagesReplyFactory } from '../redis/testing/stream-messages-reply.factory.js'; +import { StorageService } from '../storage/storage.service.js'; +import * as helper from './helper.js'; +import * as protocol from './protocol.js'; +import { DocumentStorage } from './storage.js'; +import { yRedisMessageFactory } from './testing/y-redis-message.factory.js'; +import { YRedisClient } from './y-redis.client.js'; + +describe(YRedisClient.name, () => { + let module: TestingModule; + let redis: DeepMocked; + let store: DeepMocked; + let yRedisClient: YRedisClient; + + beforeEach(async () => { + module = await Test.createTestingModule({ + providers: [ + { + provide: YRedisClient, + useFactory: (redisAdapter: RedisAdapter, storageService: StorageService, logger: Logger): YRedisClient => { + const yRedisClient = new YRedisClient(storageService, redisAdapter, logger); + + return yRedisClient; + }, + inject: [IoRedisAdapter, StorageService, Logger], + }, + { + provide: StorageService, + useValue: createMock(), + }, + { + provide: IoRedisAdapter, + useValue: createMock({ + redisPrefix: 'prefix', + }), + }, + { + provide: Logger, + useValue: createMock(), + }, + ], + }).compile(); + + redis = module.get(IoRedisAdapter); + store = module.get(StorageService); + yRedisClient = module.get(YRedisClient); + }); + + afterEach(() => { + jest.restoreAllMocks(); + }); + + describe('getMessages', () => { + describe('when streams is empty', () => { + it('should return empty array', async () => { + const result = await yRedisClient.getMessages([]); + + expect(result).toEqual([]); + }); + }); + + describe('when streams is not empty', () => { + const setup = () => { + const m = streamMessagesReplyFactory.build(); + redis.readStreams.mockResolvedValueOnce(m); + + const props = [ + { + key: 'stream1', + id: '1', + }, + ]; + const spyMergeMessages = jest.spyOn(protocol, 'mergeMessages').mockReturnValueOnce([]); + + const { name, messages } = m[0]; + // @ts-ignore + const lastId = messages[messages.length - 1].id; + + const expectedResult = [ + { + lastId, + messages: [], + stream: name, + }, + ]; + + const expectedMessages = messages?.map((message) => message.message.m).filter((m) => m != null); + + return { spyMergeMessages, expectedResult, expectedMessages, props }; + }; + + it('should call redis.readStreams with correct params', async () => { + const { props } = setup(); + + await yRedisClient.getMessages(props); + + expect(redis.readStreams).toHaveBeenCalledTimes(1); + expect(redis.readStreams).toHaveBeenCalledWith(props); + }); + + it('should call protocol.mergeMessages with correct values', async () => { + const { spyMergeMessages, expectedMessages, props } = setup(); + + await yRedisClient.getMessages(props); + + expect(spyMergeMessages).toHaveBeenCalledTimes(1); + expect(spyMergeMessages).toHaveBeenCalledWith(expectedMessages); + }); + + it('should return expected messages', async () => { + const { expectedResult, props } = setup(); + + const result = await yRedisClient.getMessages(props); + + expect(result).toEqual(expectedResult); + }); + }); + }); + + describe('addMessage', () => { + describe('when m is a sync step 2 message', () => { + const setup = () => { + const room = 'room'; + const docid = 'docid'; + const message = Buffer.from([protocol.messageSync, protocol.messageSyncStep2]); + + const props = { room, docid, message }; + + return { props }; + }; + it('should return a promise', async () => { + const { props } = setup(); + + const result = await yRedisClient.addMessage(props.room, props.docid, props.message); + + expect(result).toBeUndefined(); + expect(redis.addMessage).not.toHaveBeenCalled(); + }); + }); + + describe('when m is not a sync step 2 message', () => { + const setup = () => { + const room = 'room'; + const docid = 'docid'; + const message = Buffer.from([protocol.messageSync, protocol.messageSyncUpdate]); + + const props = { room, docid, message }; + + return { props }; + }; + + it('should call redis.addMessage with correct params', async () => { + const { props } = setup(); + + await yRedisClient.addMessage(props.room, props.docid, props.message); + + expect(redis.addMessage).toHaveBeenCalledTimes(1); + expect(redis.addMessage).toHaveBeenCalledWith('prefix:room:room:docid', props.message); + }); + }); + + describe('when m is a correct message', () => { + const setup = () => { + const room = 'room'; + const docid = 'docid'; + const message = Buffer.from([protocol.messageSync, protocol.messageSyncStep2, 0x54, 0x45, 0x53, 0x54]); + + const props = { room, docid, message }; + + return { props }; + }; + it('should set correctly protocol type', async () => { + const { props } = setup(); + + await yRedisClient.addMessage(props.room, props.docid, props.message); + + expect(props.message[1]).toEqual(protocol.messageSyncUpdate); + }); + }); + }); + + describe('getStateVector', () => { + const setup = () => { + const room = 'room'; + const docid = 'docid'; + + const props = { room, docid }; + + return { props }; + }; + + it('should call store.retrieveStateVector with correct params', async () => { + const { props } = setup(); + const { room, docid } = props; + + await yRedisClient.getStateVector(room, docid); + + expect(store.retrieveStateVector).toHaveBeenCalledTimes(1); + expect(store.retrieveStateVector).toHaveBeenCalledWith(room, docid); + }); + }); + + describe('getDoc', () => { + const setup = () => { + const spyComputeRedisRoomStreamName = jest.spyOn(helper, 'computeRedisRoomStreamName'); + const spyExtractMessagesFromStreamReply = jest.spyOn(helper, 'extractMessagesFromStreamReply'); + + const ydoc = new Doc(); + const doc = encodeStateAsUpdateV2(ydoc); + const streamReply = streamMessagesReplyFactory.build(); + redis.readMessagesFromStream.mockResolvedValueOnce(streamReply); + store.retrieveDoc.mockResolvedValueOnce({ doc, references: [] }); + + const room = 'roomid-1'; + const docid = 'docid'; + + const props = { room, docid }; + + return { + props, + spyComputeRedisRoomStreamName, + spyExtractMessagesFromStreamReply, + streamReply, + }; + }; + + it('should call computeRedisRoomStreamName with correct params', async () => { + const { props, spyComputeRedisRoomStreamName } = setup(); + const { room, docid } = props; + + const result = await yRedisClient.getDoc(room, docid); + result.awareness.destroy(); + + expect(spyComputeRedisRoomStreamName).toHaveBeenCalledWith(room, docid, 'prefix'); + }); + + it('should call redis.readMessagesFromStream with correct params', async () => { + const { props } = setup(); + const { room, docid } = props; + + const result = await yRedisClient.getDoc(room, docid); + result.awareness.destroy(); + + expect(redis.readMessagesFromStream).toHaveBeenCalledTimes(1); + expect(redis.readMessagesFromStream).toHaveBeenCalledWith('prefix:room:roomid-1:docid'); + }); + + it('should call extractMessagesFromStreamReply with correct params', async () => { + const { props, spyExtractMessagesFromStreamReply, streamReply } = setup(); + const { room, docid } = props; + + const result = await yRedisClient.getDoc(room, docid); + result.awareness.destroy(); + + expect(spyExtractMessagesFromStreamReply).toHaveBeenCalledWith(streamReply, 'prefix'); + }); + + it('should return expected result', async () => { + const { props } = setup(); + const { room, docid } = props; + + const result = await yRedisClient.getDoc(room, docid); + result.awareness.destroy(); + + expect(result).toBeDefined(); + expect(result).toEqual(expect.objectContaining({ ydoc: expect.any(Doc) })); + }); + + it('should return awarenessStateSize', async () => { + const { props } = setup(); + const { room, docid } = props; + + const result = await yRedisClient.getDoc(room, docid); + result.awareness.states.set(0, new Map()); + + expect(result.getAwarenessStateSize()).toBe(1); + result.awareness.destroy(); + }); + }); + + describe('destroy', () => { + it('should call store.destroy with correct params', async () => { + await yRedisClient.destroy(); + + expect(redis.quit).toHaveBeenCalledTimes(1); + }); + }); + + describe('handleMessageUpdates', () => { + describe('when a message is messageSyncUpdate', () => { + const setup = () => { + const ydoc = new Doc(); + const awareness = createMock(); + const message = Buffer.from([protocol.messageSync, protocol.messageSyncUpdate, 0x54, 0x45, 0x53, 0x54]); + + const messages = yRedisMessageFactory.build({ messages: [message] }); + + const spyApplyUpdate = jest.spyOn(Y, 'applyUpdate'); + spyApplyUpdate.mockReturnValueOnce(undefined); + + return { ydoc, awareness, messages, spyApplyUpdate }; + }; + + it('should call Y.applyUpdate with correct params', () => { + const { ydoc, awareness, messages, spyApplyUpdate } = setup(); + + // @ts-ignore it is private method + yRedisClient.handleMessageUpdates(messages, ydoc, awareness); + + expect(spyApplyUpdate).toHaveBeenCalledWith(ydoc, expect.anything()); + }); + }); + + describe('when a message is messageSyncAwareness', () => { + const setup = () => { + const ydoc = new Doc(); + const awareness = createMock(); + const message = Buffer.from([protocol.messageAwareness, 0x54, 0x45, 0x53, 0x54]); + + const messages = yRedisMessageFactory.build({ messages: [message] }); + + const spyApplyAwarenessUpdate = jest.spyOn(Awareness, 'applyAwarenessUpdate'); + spyApplyAwarenessUpdate.mockReturnValueOnce(undefined); + + return { ydoc, awareness, messages, spyApplyAwarenessUpdate }; + }; + + it('should call Y.applyAwarenessUpdate with correct params', () => { + const { ydoc, awareness, messages, spyApplyAwarenessUpdate } = setup(); + + // @ts-ignore it is private method + yRedisClient.handleMessageUpdates(messages, ydoc, awareness); + + expect(spyApplyAwarenessUpdate).toHaveBeenCalledWith(awareness, expect.anything(), null); + }); + }); + }); +}); diff --git a/src/infra/y-redis/api.service.ts b/src/infra/y-redis/y-redis.client.ts similarity index 53% rename from src/infra/y-redis/api.service.ts rename to src/infra/y-redis/y-redis.client.ts index 8af2d418..752a287a 100644 --- a/src/infra/y-redis/api.service.ts +++ b/src/infra/y-redis/y-redis.client.ts @@ -1,73 +1,44 @@ +import { Injectable, OnModuleInit } from '@nestjs/common'; import { array, decoding, promise } from 'lib0'; import { applyAwarenessUpdate, Awareness } from 'y-protocols/awareness'; import { applyUpdate, applyUpdateV2, Doc } from 'yjs'; +import { Logger } from '../logger/logger.js'; import { MetricsService } from '../metrics/metrics.service.js'; -import { RedisAdapter, StreamNameClockPair } from '../redis/interfaces/index.js'; -import { RedisService } from '../redis/redis.service.js'; +import { RedisAdapter, StreamMessageReply, StreamNameClockPair } from '../redis/interfaces/index.js'; import { computeRedisRoomStreamName, extractMessagesFromStreamReply } from './helper.js'; import { YRedisMessage } from './interfaces/stream-message.js'; -import { YRedisDoc } from './interfaces/y-redis-doc.js'; import * as protocol from './protocol.js'; import { DocumentStorage } from './storage.js'; +import { YRedisDocFactory } from './y-redis-doc.factory.js'; +import { YRedisDoc } from './y-redis-doc.js'; -export const handleMessageUpdates = (docMessages: YRedisMessage | null, ydoc: Doc, awareness: Awareness): void => { - docMessages?.messages.forEach((m) => { - const decoder = decoding.createDecoder(m); - const messageType = decoding.readVarUint(decoder); - switch (messageType) { - case protocol.messageSync: { - // The methode readVarUnit work with pointer, that increase by each execution. The second execution get the second value. - const syncType = decoding.readVarUint(decoder); - if (syncType === protocol.messageSyncUpdate) { - applyUpdate(ydoc, decoding.readVarUint8Array(decoder)); - } - break; - } - case protocol.messageAwareness: { - applyAwarenessUpdate(awareness, decoding.readVarUint8Array(decoder), null); - break; - } - } - }); -}; - -export const createApiClient = async (store: DocumentStorage, createRedisInstance: RedisService): Promise => { - const a = new Api(store, await createRedisInstance.createRedisInstance()); - - await a.redis.createGroup(); - - return a; -}; - -export class Api { +@Injectable() +export class YRedisClient implements OnModuleInit { public readonly redisPrefix: string; - public _destroyed; public constructor( - private readonly store: DocumentStorage, + private readonly storage: DocumentStorage, public readonly redis: RedisAdapter, + private readonly logger: Logger, ) { - this.store = store; + this.logger.setContext(YRedisClient.name); this.redisPrefix = redis.redisPrefix; - this._destroyed = false; } - public async getMessages(streams: StreamNameClockPair[]): Promise { - if (streams.length === 0) { - await promise.wait(50); - - return []; - } + public async onModuleInit(): Promise { + await this.redis.createGroup(); + } + public async getMessages(streams: StreamNameClockPair[]): Promise { const streamReplyRes = await this.redis.readStreams(streams); const res: YRedisMessage[] = []; streamReplyRes?.forEach((stream) => { + const messages = this.extractMessages(stream.messages); res.push({ stream: stream.name.toString(), - // @ts-ignore - messages: protocol.mergeMessages(stream.messages.map((message) => message.message.m).filter((m) => m != null)), + messages: protocol.mergeMessages(messages), lastId: stream.messages ? array.last(stream.messages).id.toString() : '', }); }); @@ -89,20 +60,20 @@ export class Api { } public getStateVector(room: string, docid = '/'): Promise { - return this.store.retrieveStateVector(room, docid); + return this.storage.retrieveStateVector(room, docid); } public async getDoc(room: string, docid: string): Promise { const end = MetricsService.methodDurationHistogram.startTimer(); let docChanged = false; - const roomComputed = computeRedisRoomStreamName(room, docid, this.redisPrefix); - const streamReply = await this.redis.readMessagesFromStream(roomComputed); + const streamName = computeRedisRoomStreamName(room, docid, this.redisPrefix); + const streamReply = await this.redis.readMessagesFromStream(streamName); const ms = extractMessagesFromStreamReply(streamReply, this.redisPrefix); const docMessages = ms.get(room)?.get(docid) ?? null; - const docstate = await this.store.retrieveDoc(room, docid); + const docstate = await this.storage.retrieveDoc(room, docid); const ydoc = new Doc(); const awareness = new Awareness(ydoc); @@ -117,28 +88,59 @@ export class Api { }); ydoc.transact(() => { - handleMessageUpdates(docMessages, ydoc, awareness); + this.handleMessageUpdates(docMessages, ydoc, awareness); }); end(); - const response = { + const response = YRedisDocFactory.build({ ydoc, awareness, redisLastId: docMessages?.lastId.toString() ?? '0', storeReferences: docstate?.references ?? null, docChanged, - }; + streamName, + }); if (ydoc.store.pendingStructs !== null) { - console.warn(`Document ${room} has pending structs ${JSON.stringify(ydoc.store.pendingStructs)}.`); + this.logger.warning(`Document ${room} has pending structs ${JSON.stringify(ydoc.store.pendingStructs)}.`); } return response; } public async destroy(): Promise { - this._destroyed = true; await this.redis.quit(); } + + private handleMessageUpdates(docMessages: YRedisMessage | null, ydoc: Doc, awareness: Awareness): void { + docMessages?.messages.forEach((m) => { + const decoder = decoding.createDecoder(m); + const messageType = decoding.readVarUint(decoder); + switch (messageType) { + case protocol.messageSync: { + // The methode readVarUnit works with a pointer, that increases by each execution. The second execution gets the second value. + const syncType = decoding.readVarUint(decoder); + if (syncType === protocol.messageSyncUpdate) { + applyUpdate(ydoc, decoding.readVarUint8Array(decoder)); + } + break; + } + case protocol.messageAwareness: { + applyAwarenessUpdate(awareness, decoding.readVarUint8Array(decoder), null); + break; + } + } + }); + } + + private extractMessages(messages: StreamMessageReply[] | null): Buffer[] { + if (messages === null) { + return []; + } + + const filteredMessages = messages.map((message) => message.message.m).filter((m) => m != null); + + return filteredMessages; + } } diff --git a/src/infra/y-redis/y-redis.const.ts b/src/infra/y-redis/y-redis.const.ts new file mode 100644 index 00000000..9c1b5af2 --- /dev/null +++ b/src/infra/y-redis/y-redis.const.ts @@ -0,0 +1,3 @@ +export const API_FOR_SUBSCRIBER = 'API_FOR_SUBSCRIBER'; +export const REDIS_FOR_SUBSCRIBER = 'REDIS_FOR_SUBSCRIBER'; +export const REDIS_FOR_API = 'REDIS_FOR_API'; diff --git a/src/infra/y-redis/y-redis.service.spec.ts b/src/infra/y-redis/y-redis.service.spec.ts new file mode 100644 index 00000000..475f676c --- /dev/null +++ b/src/infra/y-redis/y-redis.service.spec.ts @@ -0,0 +1,428 @@ +import { createMock } from '@golevelup/ts-jest'; +import { Test, TestingModule } from '@nestjs/testing'; +import { encoding } from 'lib0'; +import { Awareness } from 'y-protocols/awareness'; +import { Doc, encodeStateAsUpdate, encodeStateVector } from 'yjs'; +import * as protocol from './protocol.js'; +import { SubscriberService } from './subscriber.service.js'; +import { yRedisDocFactory } from './testing/y-redis-doc.factory.js'; +import { yRedisUserFactory } from './testing/y-redis-user.factory.js'; +import { YRedisService } from './y-redis.service.js'; + +const buildUpdate = (props: { + messageType: number; + length: number; + numberOfUpdates: number; + awarenessId: number; + lastClock: number; +}): Buffer => { + const { messageType, length, numberOfUpdates, awarenessId, lastClock } = props; + const encoder = encoding.createEncoder(); + encoding.writeVarUint(encoder, messageType); + encoding.writeVarUint(encoder, length); + encoding.writeVarUint(encoder, numberOfUpdates); + encoding.writeVarUint(encoder, awarenessId); + encoding.writeVarUint(encoder, lastClock); + + return Buffer.from(encoding.toUint8Array(encoder)); +}; + +describe(YRedisService.name, () => { + let module: TestingModule; + let yRedisService: YRedisService; + let subscriberService: SubscriberService; + + beforeEach(async () => { + module = await Test.createTestingModule({ + providers: [ + YRedisService, + { + provide: SubscriberService, + useValue: createMock(), + }, + ], + }).compile(); + + yRedisService = module.get(YRedisService); + subscriberService = module.get(SubscriberService); + }); + + afterEach(() => { + jest.restoreAllMocks(); + }); + + it('should be defined', () => { + expect(yRedisService).toBeDefined(); + }); + + describe('SubscriberService wrap methods', () => { + const setup = () => { + const stream = 'test'; + const callback = jest.fn(); + + return { stream, callback }; + }; + + it('should call subscriberService.start', () => { + yRedisService.start(); + + expect(subscriberService.start).toHaveBeenCalled(); + }); + + it('should call subscriberService.subscribe', () => { + const { stream, callback } = setup(); + + yRedisService.subscribe(stream, callback); + + expect(subscriberService.subscribe).toHaveBeenCalledWith(stream, callback); + }); + + it('should call subscriberService.unsubscribe', () => { + const { stream, callback } = setup(); + + yRedisService.unsubscribe(stream, callback); + + expect(subscriberService.unsubscribe).toHaveBeenCalledWith(stream, callback); + }); + + describe('when isSmallerRedisId returns true', () => { + const setup = () => { + const yRedisDoc = yRedisDocFactory.build({ redisLastId: '0' }); + const yRedisUser = yRedisUserFactory.build({ initialRedisSubId: '1' }); + + return { yRedisDoc, yRedisUser }; + }; + + it('should call subscriberService.ensureSubId', () => { + const { yRedisDoc, yRedisUser } = setup(); + + yRedisService.ensureLatestContentSubscription(yRedisDoc, yRedisUser); + + expect(subscriberService.ensureSubId).toHaveBeenCalledWith(yRedisDoc.streamName, yRedisDoc.redisLastId); + }); + }); + + describe('when isSmallerRedisId returns false', () => { + const setup = () => { + const yRedisDoc = yRedisDocFactory.build({ redisLastId: '1' }); + const yRedisUser = yRedisUserFactory.build({ initialRedisSubId: '0' }); + + return { yRedisDoc, yRedisUser }; + }; + + it('should call subscriberService.ensureSubId', () => { + const { yRedisDoc, yRedisUser } = setup(); + + yRedisService.ensureLatestContentSubscription(yRedisDoc, yRedisUser); + + expect(subscriberService.ensureSubId).not.toHaveBeenCalledWith(yRedisDoc.streamName, yRedisDoc.redisLastId); + }); + }); + }); + + describe('filterMessageForPropagation', () => { + describe('when message is awareness update and users awarenessid is null', () => { + const setup = () => { + const user = yRedisUserFactory.build({ + awarenessId: null, + awarenessLastClock: 99, + }); + const messageBuffer = buildUpdate({ + messageType: protocol.messageAwareness, + length: 0, + numberOfUpdates: 1, + awarenessId: 75, + lastClock: 76, + }); + + return { messageBuffer, user }; + }; + + it('should update users awarenessId and awarenessLastClock', () => { + const { messageBuffer, user } = setup(); + + yRedisService.filterMessageForPropagation(messageBuffer, user); + + expect(user.awarenessId).toBe(75); + expect(user.awarenessLastClock).toBe(76); + }); + + it('should return message', () => { + const { messageBuffer, user } = setup(); + + const result = yRedisService.filterMessageForPropagation(messageBuffer, user); + + expect(result).toEqual(Buffer.from(messageBuffer)); + }); + }); + + describe('when message is awareness update and users awarenessid is messages awarenessid', () => { + const setup = () => { + const user = yRedisUserFactory.build({ + awarenessId: 75, + awarenessLastClock: 99, + }); + + const messageBuffer = buildUpdate({ + messageType: protocol.messageAwareness, + length: 0, + numberOfUpdates: 1, + awarenessId: 75, + lastClock: 76, + }); + + return { messageBuffer, user }; + }; + + it('should update users awarenessId and awarenessLastClock', () => { + const { messageBuffer, user } = setup(); + + yRedisService.filterMessageForPropagation(messageBuffer, user); + + expect(user.awarenessId).toBe(75); + expect(user.awarenessLastClock).toBe(76); + }); + }); + + describe('when message is sync update', () => { + const setup = () => { + const user = yRedisUserFactory.build({ + awarenessId: null, + awarenessLastClock: 99, + }); + + const messageBuffer = buildUpdate({ + messageType: protocol.messageSync, + length: protocol.messageSyncUpdate, + numberOfUpdates: 1, + awarenessId: 75, + lastClock: 76, + }); + + return { messageBuffer, user }; + }; + + it('should not update users awarenessId and awarenessLastClock', () => { + const { messageBuffer, user } = setup(); + + yRedisService.filterMessageForPropagation(messageBuffer, user); + + expect(user.awarenessId).toBe(null); + expect(user.awarenessLastClock).toBe(99); + }); + }); + + describe('when message is sync step 2 update', () => { + const setup = () => { + const user = yRedisUserFactory.build({ + awarenessId: null, + awarenessLastClock: 99, + }); + + const messageBuffer = buildUpdate({ + messageType: protocol.messageSync, + length: protocol.messageSyncStep2, + numberOfUpdates: 1, + awarenessId: 75, + lastClock: 76, + }); + + return { messageBuffer, user }; + }; + + it('should not update users awarenessId and awarenessLastClock', () => { + const { messageBuffer, user } = setup(); + + yRedisService.filterMessageForPropagation(messageBuffer, user); + + expect(user.awarenessId).toBe(null); + expect(user.awarenessLastClock).toBe(99); + }); + }); + + describe('when message is sync step 1 update', () => { + const setup = () => { + const user = yRedisUserFactory.build({ + awarenessId: null, + awarenessLastClock: 99, + }); + + const messageBuffer = buildUpdate({ + messageType: protocol.messageSync, + length: protocol.messageSyncStep1, + numberOfUpdates: 1, + awarenessId: 75, + lastClock: 76, + }); + + return { messageBuffer, user }; + }; + + it('should not update users awarenessId and awarenessLastClock', () => { + const { messageBuffer, user } = setup(); + + yRedisService.filterMessageForPropagation(messageBuffer, user); + + expect(user.awarenessId).toBe(null); + expect(user.awarenessLastClock).toBe(99); + }); + + it('should return null', () => { + const { messageBuffer, user } = setup(); + + const result = yRedisService.filterMessageForPropagation(messageBuffer, user); + + expect(result).toBe(null); + }); + }); + + describe('when message is of unknown type', () => { + const setup = () => { + const user = yRedisUserFactory.build({ + awarenessId: null, + awarenessLastClock: 99, + }); + + const messageBuffer = buildUpdate({ + messageType: 999, + length: 999, + numberOfUpdates: 1, + awarenessId: 75, + lastClock: 76, + }); + + return { messageBuffer, user }; + }; + + it('should throw an error', () => { + const { messageBuffer, user } = setup(); + + expect(() => yRedisService.filterMessageForPropagation(messageBuffer, user)).toThrow( + `Unexpected message type ${messageBuffer}`, + ); + }); + }); + }); + + describe('createAwarenessUserDisconnectedMessage', () => { + describe('when awarenessId is null', () => { + it('should throw an error', () => { + const user = yRedisUserFactory.build({ awarenessId: null }); + + expect(() => yRedisService.createAwarenessUserDisconnectedMessage(user)).toThrow( + 'Missing awarenessId in YRedisUser.', + ); + }); + }); + + describe('when awarenessId is not null', () => { + const setup = () => { + const user = yRedisUserFactory.build({ awarenessId: 1 }); + + const expected = Buffer.from(protocol.encodeAwarenessUserDisconnected(1, 0)); + + return { user, expected }; + }; + + it('should return a buffer', () => { + const { user, expected } = setup(); + const result = yRedisService.createAwarenessUserDisconnectedMessage(user); + + expect(result).toEqual(expected); + }); + }); + + describe('encodeSyncStep1StateVectorMessage', () => { + const setup = () => { + const yDoc = new Doc(); + + const expected = protocol.encodeSyncStep1(encodeStateVector(yDoc)); + + return { yDoc, expected }; + }; + + it('should return a buffer', () => { + const { yDoc, expected } = setup(); + const result = yRedisService.encodeSyncStep1StateVectorMessage(yDoc); + + expect(result).toEqual(expected); + }); + }); + + describe('encodeSyncStep2StateAsUpdateMessage', () => { + const setup = () => { + const yDoc = new Doc(); + + const expected = protocol.encodeSyncStep2(encodeStateAsUpdate(yDoc)); + + return { yDoc, expected }; + }; + + it('should return a buffer', () => { + const { yDoc, expected } = setup(); + const result = yRedisService.encodeSyncStep2StateAsUpdateMessage(yDoc); + + expect(result).toEqual(expected); + }); + }); + + describe('encodeAwarenessUpdateMessage', () => { + const setup = () => { + const awareness = new Awareness(new Doc()); + + const expected = protocol.encodeAwarenessUpdate(awareness, Array.from(awareness.states.keys())); + + return { awareness, expected }; + }; + + it('should return a buffer', () => { + const { awareness, expected } = setup(); + const result = yRedisService.encodeAwarenessUpdateMessage(awareness); + + expect(result).toEqual(expected); + + awareness.destroy(); + }); + }); + + describe('mergeMessagesToMessage', () => { + describe('when messages length is 1', () => { + const setup = () => { + const messages = [Buffer.from('hello')]; + + const expected = messages[0]; + + return { messages, expected }; + }; + + it('should return a buffer', () => { + const { messages, expected } = setup(); + const result = yRedisService.mergeMessagesToMessage(messages); + + expect(result).toEqual(expected); + }); + }); + + describe('when messages length greater than 1', () => { + const setup = () => { + const messages = [Buffer.from('hello'), Buffer.from('world')]; + + const expected = encoding.encode((encoder) => { + messages.forEach((message) => { + encoding.writeUint8Array(encoder, message); + }); + }); + + return { messages, expected }; + }; + + it('should return a buffer', () => { + const { messages, expected } = setup(); + const result = yRedisService.mergeMessagesToMessage(messages); + + expect(result).toEqual(expected); + }); + }); + }); + }); +}); diff --git a/src/infra/y-redis/y-redis.service.ts b/src/infra/y-redis/y-redis.service.ts new file mode 100644 index 00000000..890ae886 --- /dev/null +++ b/src/infra/y-redis/y-redis.service.ts @@ -0,0 +1,143 @@ +import { Injectable } from '@nestjs/common'; +import { encoding } from 'lib0'; +import * as array from 'lib0/array'; +import * as decoding from 'lib0/decoding'; +import { Awareness } from 'y-protocols/awareness.js'; +import { Doc, encodeStateAsUpdate, encodeStateVector } from 'yjs'; +import { isSmallerRedisId } from './helper.js'; +import * as protocol from './protocol.js'; +import { SubscriberService, SubscriptionHandler } from './subscriber.service.js'; +import { YRedisDoc } from './y-redis-doc.js'; +import { YRedisUser } from './y-redis-user.js'; + +@Injectable() +export class YRedisService { + public constructor(private readonly subscriberService: SubscriberService) {} + + // subscriber wrappers + public async start(): Promise { + await this.subscriberService.start(); + } + + public subscribe(stream: string, callback: SubscriptionHandler): { redisId: string } { + const { redisId } = this.subscriberService.subscribe(stream, callback); + + return { redisId }; + } + + public unsubscribe(stream: string, callback: SubscriptionHandler): void { + this.subscriberService.unsubscribe(stream, callback); + } + + public ensureLatestContentSubscription(yRedisDoc: YRedisDoc, yRedisUser: YRedisUser): void { + if (isSmallerRedisId(yRedisDoc.redisLastId, yRedisUser.initialRedisSubId)) { + // our subscription is newer than the content that we received from the y-redis-client + // need to renew subscription id and make sure that we catch the latest content. + this.subscriberService.ensureSubId(yRedisDoc.streamName, yRedisDoc.redisLastId); + } + } + + // state helper + public filterMessageForPropagation(messageBuffer: ArrayBuffer, yRedisUser: YRedisUser): Buffer | null { + const messageBufferCopy = this.copyMessageBuffer(messageBuffer); + const message = Buffer.from(messageBufferCopy); + + if (this.isSyncUpdateAndSyncStep2(message)) { + return message; + } + + if (this.isAwarenessUpdate(message)) { + this.updateUserAwareness(message, yRedisUser); + + return message; + } + + if (this.isSyncMessageStep1(message)) { + // can be safely ignored because we send the full initial state at the beginning + return null; + } + + throw new Error(`Unexpected message type ${message}`); + } + + public createAwarenessUserDisconnectedMessage(yRedisUser: YRedisUser): Buffer { + if (!yRedisUser.awarenessId) { + throw new Error('Missing awarenessId in YRedisUser.'); + } + + const awarenessMessage = Buffer.from( + protocol.encodeAwarenessUserDisconnected(yRedisUser.awarenessId, yRedisUser.awarenessLastClock), + ); + + return awarenessMessage; + } + + public encodeSyncStep1StateVectorMessage(yDoc: Doc): Uint8Array { + const message = protocol.encodeSyncStep1(encodeStateVector(yDoc)); + + return message; + } + + public encodeSyncStep2StateAsUpdateMessage(ydoc: Doc): Uint8Array { + const message = protocol.encodeSyncStep2(encodeStateAsUpdate(ydoc)); + + return message; + } + + public encodeAwarenessUpdateMessage(awareness: Awareness): Uint8Array { + const message = protocol.encodeAwarenessUpdate(awareness, array.from(awareness.states.keys())); + + return message; + } + + public mergeMessagesToMessage(messages: Uint8Array[]): Uint8Array { + const mergedMessage = messages.length === 1 ? messages[0] : this.useEncodingToMergeMessages(messages); + + return mergedMessage; + } + + // private + private useEncodingToMergeMessages(messages: Uint8Array[]): Uint8Array { + const mergedMessage = encoding.encode((encoder) => + messages.forEach((message) => { + encoding.writeUint8Array(encoder, message); + }), + ); + + return mergedMessage; + } + + private copyMessageBuffer(messageBuffer: ArrayBuffer): ArrayBuffer { + const messageBufferCopy = messageBuffer.slice(0, messageBuffer.byteLength); + + return messageBufferCopy; + } + + private isSyncMessageStep1(message: Buffer): boolean { + return message[0] === protocol.messageSync && message[1] === protocol.messageSyncStep1; + } + + private isSyncUpdateAndSyncStep2(message: Buffer): boolean { + return ( + message[0] === protocol.messageSync && + (message[1] === protocol.messageSyncUpdate || message[1] === protocol.messageSyncStep2) + ); + } + + private isAwarenessUpdate(message: Buffer): boolean { + return message[0] === protocol.messageAwareness; + } + + private updateUserAwareness(message: Buffer, yRedisUser: YRedisUser): void { + const decoder = decoding.createDecoder(message); + decoding.readVarUint(decoder); // read message type + decoding.readVarUint(decoder); // read length of awareness update + const alen = decoding.readVarUint(decoder); // number of awareness updates + const awId = decoding.readVarUint(decoder); + if (alen === 1 && (yRedisUser.awarenessId === null || yRedisUser.awarenessId === awId)) { + // only update awareness if len=1 + yRedisUser.awarenessId = awId; + yRedisUser.awarenessLastClock = decoding.readVarUint(decoder); + } + } +} diff --git a/src/modules/server/api/test/tldraw-config.api.spec.ts b/src/modules/server/api/test/tldraw-config.api.spec.ts index 7f294f0c..848f6840 100644 --- a/src/modules/server/api/test/tldraw-config.api.spec.ts +++ b/src/modules/server/api/test/tldraw-config.api.spec.ts @@ -31,7 +31,7 @@ describe('Tldraw-Config Api Test', () => { TLDRAW_ASSETS_ALLOWED_MIME_TYPES_LIST: ['image/png', 'image/jpeg', 'image/gif', 'image/svg+xml'], TLDRAW_ASSETS_ENABLED: true, TLDRAW_ASSETS_MAX_SIZE_BYTES: 10485760, - TLDRAW_WEBSOCKET_URL: 'ws://localhost:3345', + TLDRAW_WEBSOCKET_URL: 'ws://localhost:3399', }); }); }); diff --git a/src/modules/server/api/test/websocket.api.spec.ts b/src/modules/server/api/test/websocket.api.spec.ts index a809b8f8..10094965 100644 --- a/src/modules/server/api/test/websocket.api.spec.ts +++ b/src/modules/server/api/test/websocket.api.spec.ts @@ -9,11 +9,12 @@ import { Doc, encodeStateAsUpdateV2 } from 'yjs'; import { ResponsePayloadBuilder } from '../../../../infra//authorization/response.builder.js'; import { AuthorizationService } from '../../../../infra/authorization/authorization.service.js'; import { ServerModule } from '../../server.module.js'; +import { TldrawServerConfig } from '../../tldraw-server.config.js'; describe('Websocket Api Test', () => { let app: INestApplication; let authorizationService: DeepMocked; - const prefix = 'y'; + let tldrawServerConfig: TldrawServerConfig; beforeAll(async () => { const moduleFixture = await Test.createTestingModule({ @@ -25,8 +26,8 @@ describe('Websocket Api Test', () => { app = moduleFixture.createNestApplication(); await app.init(); - authorizationService = await app.resolve(AuthorizationService); + tldrawServerConfig = await app.resolve(TldrawServerConfig); }); afterAll(async () => { @@ -35,7 +36,8 @@ describe('Websocket Api Test', () => { const createWsClient = (room: string) => { const ydoc = new Doc(); - const serverUrl = 'ws://localhost:3345'; + const serverUrl = tldrawServerConfig.TLDRAW_WEBSOCKET_URL; + const prefix = 'y'; const provider = new WebsocketProvider(serverUrl, prefix + '-' + room, ydoc, { // @ts-ignore WebSocketPolyfill: WebSocket, @@ -66,8 +68,7 @@ describe('Websocket Api Test', () => { describe('when clients have permission for room', () => { describe('when two clients connect to the same doc before any changes', () => { const setup = () => { - const randomString = Math.random().toString(36).substring(7); - const room = randomString; + const room = Math.random().toString(36).substring(7); authorizationService.hasPermission.mockResolvedValueOnce({ hasWriteAccess: true, @@ -225,5 +226,74 @@ describe('Websocket Api Test', () => { expect(error.code).toBe(4401); }); }); + + describe('when client connects and has not a room', () => { + const setup = () => { + const randomString = Math.random().toString(36).substring(7); + const room = randomString; + + const response = ResponsePayloadBuilder.build(null, 'userId'); + authorizationService.hasPermission.mockResolvedValue(response); + + const { ydoc: client1Doc, provider } = createWsClient(room); + + return { client1Doc, provider }; + }; + + it('syncs doc changes of first client to second client', async () => { + const { provider } = setup(); + + let error: CloseEvent; + if (provider.ws) { + provider.ws.onclose = (event: Event) => { + error = event as CloseEvent; + }; + } + + await promise.until(0, () => { + return error as unknown as boolean; + }); + + // @ts-ignore + expect(error.reason).toBe('Missing room or userid'); + // @ts-ignore + expect(error.code).toBe(1008); + }); + }); }); + + /*describe('when openCallback catch an error', () => { + const setup = () => { + const randomString = Math.random().toString(36).substring(7); + const room = randomString; + + const response = ResponsePayloadBuilder.build(room, 'userId'); + authorizationService.hasPermission.mockResolvedValue(response); + + const { ydoc: client1Doc, provider } = createWsClient(room); + + return { client1Doc, provider }; + }; + + it('syncs doc changes of first client to second client', async () => { + const { provider } = setup(); + + let error: CloseEvent; + if (provider.ws) { + provider.ws.onclose = (event: Event) => { + error = event as CloseEvent; + }; + } + //spyOn(provider.ws, 'end').and.callThrough(); + + await promise.until(0, () => { + return error as unknown as boolean; + }); + + // @ts-ignore + //expect(error.reason).toBe('Internal Server Error'); + // @ts-ignore + expect(error.code).toBe(1011); + }); + });*/ }); diff --git a/src/modules/server/api/websocket.gateway.spec.ts b/src/modules/server/api/websocket.gateway.spec.ts deleted file mode 100644 index 6845d86f..00000000 --- a/src/modules/server/api/websocket.gateway.spec.ts +++ /dev/null @@ -1,174 +0,0 @@ -import { createMock, DeepMocked } from '@golevelup/ts-jest'; -import { Test, TestingModule } from '@nestjs/testing'; -import { TemplatedApp } from 'uWebSockets.js'; -import { AuthorizationService } from '../../../infra/authorization/authorization.service.js'; -import { Logger } from '../../../infra/logger/logger.js'; -import { MetricsService } from '../../../infra/metrics/metrics.service.js'; -import { IoRedisAdapter } from '../../../infra/redis/ioredis.adapter.js'; -import { RedisService } from '../../../infra/redis/redis.service.js'; -import { StorageService } from '../../../infra/storage/storage.service.js'; -import * as WsService from '../../../infra/y-redis/ws.service.js'; -import { registerYWebsocketServer } from '../../../infra/y-redis/ws.service.js'; -import { TldrawServerConfig } from '../tldraw-server.config.js'; -import { WebsocketGateway } from './websocket.gateway.js'; - -describe(WebsocketGateway.name, () => { - let service: WebsocketGateway; - let storageService: StorageService; - let redisService: DeepMocked; - let webSocketServer: DeepMocked; - let logger: DeepMocked; - - beforeAll(async () => { - const module: TestingModule = await Test.createTestingModule({ - providers: [ - WebsocketGateway, - { - provide: 'UWS', - useValue: createMock(), - }, - { - provide: StorageService, - useValue: createMock(), - }, - { - provide: AuthorizationService, - useValue: createMock(), - }, - { - provide: RedisService, - useValue: createMock(), - }, - { - provide: Logger, - useValue: createMock(), - }, - { - provide: TldrawServerConfig, - useValue: { - TLDRAW_WEBSOCKET_PATH: 'tests', - TLDRAW_WEBSOCKET_PORT: 3345, - }, - }, - ], - }).compile(); - - service = await module.resolve(WebsocketGateway); - storageService = module.get(StorageService); - redisService = module.get(RedisService); - webSocketServer = module.get('UWS'); - logger = module.get(Logger); - }); - - afterEach(() => { - jest.restoreAllMocks(); - }); - - it('should be defined', () => { - expect(service).toBeDefined(); - }); - - describe('onModuleInit', () => { - const setup = () => { - const yWebsocketServer = createMock(); - jest.spyOn(WsService, 'registerYWebsocketServer').mockResolvedValueOnce(yWebsocketServer); - - const redisAdapter: DeepMocked = createMock(); - - return { redisAdapter }; - }; - - it('should call registerYWebsocketServer', async () => { - setup(); - - await service.onModuleInit(); - - expect(registerYWebsocketServer).toHaveBeenCalledWith( - webSocketServer, - 'tests/:room', - storageService, - expect.any(Function), - { - openWsCallback: expect.any(Function), - closeWsCallback: expect.any(Function), - }, - redisService, - ); - }); - - it('should increment openConnectionsGauge on openWsCallback', async () => { - setup(); - const openConnectionsGaugeIncSpy = jest.spyOn(MetricsService.openConnectionsGauge, 'inc'); - - await service.onModuleInit(); - - const openWsCallback = (registerYWebsocketServer as jest.Mock).mock.calls[0][4].openWsCallback; - openWsCallback(); - - expect(openConnectionsGaugeIncSpy).toHaveBeenCalled(); - }); - - it('should decrement openConnectionsGauge on closeWsCallback', async () => { - setup(); - const openConnectionsGaugeDecSpy = jest.spyOn(MetricsService.openConnectionsGauge, 'dec'); - - await service.onModuleInit(); - const closeWsCallback = (registerYWebsocketServer as jest.Mock).mock.calls[0][4].closeWsCallback; - closeWsCallback(); - - expect(openConnectionsGaugeDecSpy).toHaveBeenCalled(); - }); - - it('should call webSocketServer.listen', async () => { - setup(); - await service.onModuleInit(); - - expect(webSocketServer.listen).toHaveBeenCalledWith(3345, expect.any(Function)); - }); - - it('should call redisAdapter.subscribeToDeleteChannel', async () => { - const { redisAdapter } = setup(); - redisService.createRedisInstance.mockResolvedValueOnce(redisAdapter); - - await service.onModuleInit(); - - expect(redisService.createRedisInstance).toHaveBeenCalled(); - expect(redisAdapter.subscribeToDeleteChannel).toHaveBeenCalledWith(expect.any(Function)); - }); - - it('should call webSocketServer.publish', async () => { - const { redisAdapter } = setup(); - - redisService.createRedisInstance.mockResolvedValueOnce(redisAdapter); - redisAdapter.subscribeToDeleteChannel.mockImplementation((cb) => cb('test')); - - await service.onModuleInit(); - - expect(webSocketServer.publish).toHaveBeenCalledWith('test', 'action:delete'); - }); - - it('should log if webSocketServer.listen return true', async () => { - setup(); - // @ts-ignore - webSocketServer.listen.mockImplementationOnce((_, cb) => cb(true)); - - await service.onModuleInit(); - - expect(logger.log).toHaveBeenCalledWith('Websocket Server is running on port 3345'); - }); - }); - - describe('onModuleDestroy', () => { - const setup = () => { - const yWebsocketServer = createMock(); - jest.spyOn(WsService, 'registerYWebsocketServer').mockResolvedValueOnce(yWebsocketServer); - }; - - it('should call webSocketServer.close', () => { - setup(); - service.onModuleDestroy(); - - expect(webSocketServer.close).toHaveBeenCalled(); - }); - }); -}); diff --git a/src/modules/server/api/websocket.gateway.ts b/src/modules/server/api/websocket.gateway.ts index cfc76cc2..bf45ce8c 100644 --- a/src/modules/server/api/websocket.gateway.ts +++ b/src/modules/server/api/websocket.gateway.ts @@ -1,22 +1,46 @@ import { Inject, Injectable, OnModuleDestroy, OnModuleInit } from '@nestjs/common'; -import { TemplatedApp } from 'uWebSockets.js'; +import { + HttpRequest, + HttpResponse, + SHARED_COMPRESSOR, + TemplatedApp, + us_socket_context_t, + WebSocket, +} from 'uWebSockets.js'; import { AuthorizationService } from '../../../infra/authorization/authorization.service.js'; import { Logger } from '../../../infra/logger/index.js'; import { MetricsService } from '../../../infra/metrics/metrics.service.js'; -import { RedisService } from '../../../infra/redis/redis.service.js'; -import { StorageService } from '../../../infra/storage/storage.service.js'; -import { registerYWebsocketServer } from '../../../infra/y-redis/ws.service.js'; +import { RedisAdapter } from '../../../infra/redis/interfaces/redis-adapter.js'; +import { YRedisDoc } from '../../../infra/y-redis/y-redis-doc.js'; +import { YRedisUserFactory } from '../../../infra/y-redis/y-redis-user.factory.js'; +import { YRedisUser } from '../../../infra/y-redis/y-redis-user.js'; +import { YRedisClient } from '../../../infra/y-redis/y-redis.client.js'; +import { YRedisService } from '../../../infra/y-redis/y-redis.service.js'; +import { REDIS_FOR_SUBSCRIBE_OF_DELETION, UWS } from '../server.const.js'; import { TldrawServerConfig } from '../tldraw-server.config.js'; -export const UWS = 'UWS'; +interface RequestHeaderInfos { + headerWsExtensions: string; + headerWsKey: string; + headerWsProtocol: string; +} + +// https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent/code +enum WebSocketErrorCodes { + InternalError = 1011, + PolicyViolation = 1008, + TldrawPolicyViolation = 4401, + TldrawInternalError = 4500, +} @Injectable() export class WebsocketGateway implements OnModuleInit, OnModuleDestroy { public constructor( @Inject(UWS) private readonly webSocketServer: TemplatedApp, - private readonly storageService: StorageService, + private readonly yRedisService: YRedisService, + private readonly yRedisClient: YRedisClient, private readonly authorizationService: AuthorizationService, - private readonly redisService: RedisService, + @Inject(REDIS_FOR_SUBSCRIBE_OF_DELETION) private readonly redisAdapter: RedisAdapter, private readonly config: TldrawServerConfig, private readonly logger: Logger, ) { @@ -27,31 +51,176 @@ export class WebsocketGateway implements OnModuleInit, OnModuleDestroy { this.webSocketServer.close(); } - public async onModuleInit(): Promise { - const wsPathPrefix = this.config.TLDRAW_WEBSOCKET_PATH; - const wsPort = this.config.TLDRAW_WEBSOCKET_PORT; + public onModuleInit(): void { + this.yRedisService.start(); - await registerYWebsocketServer( - this.webSocketServer, - `${wsPathPrefix}/:room`, - this.storageService, - this.authorizationService.hasPermission.bind(this.authorizationService), - { - openWsCallback: () => MetricsService.openConnectionsGauge.inc(), - closeWsCallback: () => MetricsService.openConnectionsGauge.dec(), - }, - this.redisService, - ); + this.webSocketServer.ws(`${this.config.TLDRAW_WEBSOCKET_PATH}/:room`, { + compression: SHARED_COMPRESSOR, + maxPayloadLength: 100 * 1024 * 1024, + idleTimeout: 60, + sendPingsAutomatically: true, + upgrade: (res, req, context) => this.upgradeCallback(res, req, context), + open: (ws: WebSocket) => this.openCallback(ws), + message: (ws, messageBuffer) => this.messageCallback(ws, messageBuffer), + close: (ws) => this.closeCallback(ws), + }); - this.webSocketServer.listen(wsPort, (t) => { + this.webSocketServer.listen(this.config.TLDRAW_WEBSOCKET_PORT, (t) => { if (t) { - this.logger.log(`Websocket Server is running on port ${wsPort}`); + this.logger.info(`Websocket Server is running on port ${this.config.TLDRAW_WEBSOCKET_PORT}`); } }); - const redisAdapter = await this.redisService.createRedisInstance(); - redisAdapter.subscribeToDeleteChannel((message: string) => { + this.redisAdapter.subscribeToDeleteChannel((message: string) => { this.webSocketServer.publish(message, 'action:delete'); }); } + + private async upgradeCallback(res: HttpResponse, req: HttpRequest, context: us_socket_context_t): Promise { + try { + let aborted = false; + const { headerWsKey, headerWsProtocol, headerWsExtensions } = this.extractHeaderInfos(req); + + res.onAborted(() => { + aborted = true; + }); + + const authPayload = await this.authorizationService.hasPermission(req); + if (aborted) return; + + res.cork(() => { + const yRedisUser = YRedisUserFactory.build(authPayload); + res.upgrade(yRedisUser, headerWsKey, headerWsProtocol, headerWsExtensions, context); + }); + } catch (error) { + res.cork(() => { + res.writeStatus('500 Internal Server Error').end('Internal Server Error'); + }); + this.logger.warning(error); + } + } + + private extractHeaderInfos(req: HttpRequest): RequestHeaderInfos { + const headerWsKey = req.getHeader('sec-websocket-key'); + const headerWsProtocol = req.getHeader('sec-websocket-protocol'); + const headerWsExtensions = req.getHeader('sec-websocket-extensions'); + + return { + headerWsExtensions, + headerWsKey, + headerWsProtocol, + }; + } + + private async openCallback(ws: WebSocket): Promise { + try { + const user = ws.getUserData(); + if (user.error != null) { + const { code: authorizationRequestErrorCode, reason } = user.error; + this.logger.warning(`Error: ${authorizationRequestErrorCode} - ${reason}`); + ws.end(authorizationRequestErrorCode, reason); + + return; + } + + if (user.room === null || user.userid === null) { + ws.end(WebSocketErrorCodes.PolicyViolation, 'Missing room or userid'); + + return; + } + + MetricsService.openConnectionsGauge.inc(); + + const yRedisDoc = await this.yRedisClient.getDoc(user.room, 'index'); + user.subs.add(yRedisDoc.streamName); + ws.subscribe(yRedisDoc.streamName); + + const { redisId } = this.yRedisService.subscribe(yRedisDoc.streamName, this.redisMessageSubscriber); + user.initialRedisSubId = redisId; + + if (user.isClosed) return; + + ws.cork(() => { + ws.send(this.yRedisService.encodeSyncStep1StateVectorMessage(yRedisDoc.ydoc), true, false); + ws.send(this.yRedisService.encodeSyncStep2StateAsUpdateMessage(yRedisDoc.ydoc), true, true); + if (yRedisDoc.getAwarenessStateSize() > 0) { + ws.send(this.yRedisService.encodeAwarenessUpdateMessage(yRedisDoc.awareness), true, true); + } + }); + + this.destroyAwarenessToAvoidMemoryLeak(yRedisDoc); + + this.yRedisService.ensureLatestContentSubscription(yRedisDoc, user); + } catch (error) { + this.logger.warning(error); + ws.end(WebSocketErrorCodes.InternalError, 'Internal Server Error'); + } + } + + private destroyAwarenessToAvoidMemoryLeak(indexDoc: YRedisDoc): void { + // @see: https://github.com/yjs/y-redis/issues/24 + indexDoc.awareness.destroy(); + } + + private readonly redisMessageSubscriber = (stream: string, messages: Uint8Array[]): void => { + if (!this.isSubscriberAvailable(stream)) { + this.yRedisService.unsubscribe(stream, this.redisMessageSubscriber); + } + + const message = this.yRedisService.mergeMessagesToMessage(messages); + this.webSocketServer.publish(stream, message, true, false); + }; + + private isSubscriberAvailable(stream: string): boolean { + return this.webSocketServer.numSubscribers(stream) > 0; + } + + private messageCallback(ws: WebSocket, messageBuffer: ArrayBuffer): void { + try { + const user = ws.getUserData(); + + if (!user.hasWriteAccess || !user.room) { + ws.end(WebSocketErrorCodes.TldrawPolicyViolation, 'User has no write access or room is missing'); + + return; + } + + const message = this.yRedisService.filterMessageForPropagation(messageBuffer, user); + + if (message) { + this.yRedisClient.addMessage(user.room, 'index', message); + } + } catch (error) { + this.logger.warning(error); + ws.end(WebSocketErrorCodes.InternalError); + } + } + + private closeCallback(ws: WebSocket): void { + try { + const user = ws.getUserData(); + if (!user.room) return; + + if (user.awarenessId) { + const awarenessMessage = this.yRedisService.createAwarenessUserDisconnectedMessage(user); + this.yRedisClient.addMessage(user.room, 'index', awarenessMessage); + } + + this.unsubscribeUser(user); + + MetricsService.openConnectionsGauge.dec(); + } catch (error) { + this.logger.warning(error); + ws.end(WebSocketErrorCodes.InternalError); + } + } + + private unsubscribeUser(user: YRedisUser): void { + user.isClosed = true; + user.subs.forEach((topic) => { + if (this.webSocketServer.numSubscribers(topic) === 0) { + this.yRedisService.unsubscribe(topic, this.redisMessageSubscriber); + } + }); + } } diff --git a/src/modules/server/server.const.ts b/src/modules/server/server.const.ts new file mode 100644 index 00000000..0da74a97 --- /dev/null +++ b/src/modules/server/server.const.ts @@ -0,0 +1,3 @@ +export const REDIS_FOR_DELETION = 'REDIS_FOR_DELETION'; +export const REDIS_FOR_SUBSCRIBE_OF_DELETION = 'REDIS_FOR_SUBSCRIBE_OF_DELETION'; +export const UWS = 'UWS'; diff --git a/src/modules/server/server.module.ts b/src/modules/server/server.module.ts index d14b6698..563a2c87 100644 --- a/src/modules/server/server.module.ts +++ b/src/modules/server/server.module.ts @@ -7,16 +7,22 @@ import { ConfigurationModule } from '../../infra/configuration/configuration.mod import { LoggerModule } from '../../infra/logger/logger.module.js'; import { RedisModule } from '../../infra/redis/index.js'; import { StorageModule } from '../../infra/storage/storage.module.js'; +import { YRedisClientModule } from '../../infra/y-redis/y-redis-client.module.js'; +import { YRedisServiceModule } from '../../infra/y-redis/y-redis-service.module.js'; import { TldrawConfigController } from './api/tldraw-confg.controller.js'; import { TldrawDocumentController } from './api/tldraw-document.controller.js'; -import { UWS, WebsocketGateway } from './api/websocket.gateway.js'; +import { WebsocketGateway } from './api/websocket.gateway.js'; +import { REDIS_FOR_DELETION, REDIS_FOR_SUBSCRIBE_OF_DELETION, UWS } from './server.const.js'; import { TldrawDocumentService } from './service/tldraw-document.service.js'; import { TldrawServerConfig } from './tldraw-server.config.js'; @Module({ imports: [ ConfigurationModule.register(TldrawServerConfig), - RedisModule, + YRedisClientModule.register(), + YRedisServiceModule.register(), + RedisModule.registerFor(REDIS_FOR_DELETION), + RedisModule.registerFor(REDIS_FOR_SUBSCRIBE_OF_DELETION), StorageModule, AuthorizationModule, LoggerModule, diff --git a/src/modules/server/service/tldraw-document.service.spec.ts b/src/modules/server/service/tldraw-document.service.spec.ts index 1200d290..0d59383f 100644 --- a/src/modules/server/service/tldraw-document.service.spec.ts +++ b/src/modules/server/service/tldraw-document.service.spec.ts @@ -1,30 +1,35 @@ import { createMock, DeepMocked } from '@golevelup/ts-jest'; import { Test } from '@nestjs/testing'; import { TemplatedApp } from 'uWebSockets.js'; -import { RedisService } from '../../../infra/redis/index.js'; -import { IoRedisAdapter } from '../../../infra/redis/ioredis.adapter.js'; +import { RedisFactory } from '../../../infra/redis/index.js'; +import { RedisAdapter } from '../../../infra/redis/interfaces/redis-adapter.js'; +import { REDIS_FOR_DELETION, UWS } from '../server.const.js'; import { TldrawDocumentService } from './tldraw-document.service.js'; describe('Tldraw-Document Service', () => { let service: TldrawDocumentService; let webSocketServer: TemplatedApp; - let redisService: DeepMocked; + let redisAdapter: DeepMocked; beforeAll(async () => { const moduleFixture = await Test.createTestingModule({ providers: [ TldrawDocumentService, { - provide: RedisService, - useValue: createMock(), + provide: RedisFactory, + useValue: createMock(), + }, + { provide: UWS, useValue: createMock() }, + { + provide: REDIS_FOR_DELETION, + useValue: createMock({ redisPrefix: 'y' }), }, - { provide: 'UWS', useValue: createMock() }, ], }).compile(); service = moduleFixture.get(TldrawDocumentService); - webSocketServer = moduleFixture.get('UWS'); - redisService = moduleFixture.get(RedisService); + webSocketServer = moduleFixture.get(UWS); + redisAdapter = moduleFixture.get(REDIS_FOR_DELETION); }); describe('when redis and storage service returns successfully', () => { @@ -33,32 +38,23 @@ describe('Tldraw-Document Service', () => { const docName = `y:room:${parentId}:index`; const expectedMessage = 'action:delete'; - const redisInstance = createMock({ - redisPrefix: 'y', - }); - - redisService.createRedisInstance.mockResolvedValueOnce(redisInstance); - - return { parentId, docName, expectedMessage, redisInstance }; + return { parentId, docName, expectedMessage }; }; it('should call webSocketServer.publish', async () => { const { parentId, docName, expectedMessage } = setup(); - await service.onModuleInit(); await service.deleteByDocName(parentId); expect(webSocketServer.publish).toHaveBeenCalledWith(docName, expectedMessage); }); - it('should call redisInstance.markToDeleteByDocName', async () => { - const { parentId, docName, redisInstance } = setup(); - - await service.onModuleInit(); + it('should call redisAdapter.markToDeleteByDocName', async () => { + const { parentId, docName } = setup(); await service.deleteByDocName(parentId); - expect(redisInstance.markToDeleteByDocName).toHaveBeenCalledWith(docName); + expect(redisAdapter.markToDeleteByDocName).toHaveBeenCalledWith(docName); }); }); }); diff --git a/src/modules/server/service/tldraw-document.service.ts b/src/modules/server/service/tldraw-document.service.ts index 681ed7e7..d5bdab4e 100644 --- a/src/modules/server/service/tldraw-document.service.ts +++ b/src/modules/server/service/tldraw-document.service.ts @@ -1,29 +1,22 @@ -import { Inject, Injectable, OnModuleInit } from '@nestjs/common'; +import { Inject, Injectable } from '@nestjs/common'; import { TemplatedApp } from 'uWebSockets.js'; -import { RedisService } from '../../../infra/redis/index.js'; import { RedisAdapter } from '../../../infra/redis/interfaces/index.js'; import { computeRedisRoomStreamName } from '../../../infra/y-redis/helper.js'; -const UWS = 'UWS'; +import { REDIS_FOR_DELETION, UWS } from '../server.const.js'; @Injectable() -export class TldrawDocumentService implements OnModuleInit { - private redisInstance!: RedisAdapter; - +export class TldrawDocumentService { public constructor( @Inject(UWS) private readonly webSocketServer: TemplatedApp, - private readonly redisService: RedisService, + @Inject(REDIS_FOR_DELETION) private readonly redisAdapter: RedisAdapter, ) {} - public async onModuleInit(): Promise { - this.redisInstance = await this.redisService.createRedisInstance(); - } - public async deleteByDocName(parentId: string): Promise { - const redisPrefix = this.redisInstance.redisPrefix; + const redisPrefix = this.redisAdapter.redisPrefix; const docName = computeRedisRoomStreamName(parentId, 'index', redisPrefix); this.webSocketServer.publish(docName, 'action:delete'); - await this.redisInstance.markToDeleteByDocName(docName); + await this.redisAdapter.markToDeleteByDocName(docName); } } diff --git a/src/modules/worker/worker.config.ts b/src/modules/worker/worker.config.ts index ff419623..34eb7391 100644 --- a/src/modules/worker/worker.config.ts +++ b/src/modules/worker/worker.config.ts @@ -1,5 +1,5 @@ import { Transform } from 'class-transformer'; -import { IsNumber } from 'class-validator'; +import { IsNumber, IsPositive } from 'class-validator'; export class WorkerConfig { /** @@ -13,10 +13,16 @@ export class WorkerConfig { * Minimum lifetime of y* update messages in redis streams. */ @IsNumber() + @IsPositive() @Transform(({ value }) => parseInt(value)) public WORKER_MIN_MESSAGE_LIFETIME = 60000; @IsNumber() @Transform(({ value }) => parseInt(value)) public WORKER_TRY_CLAIM_COUNT = 5; + + @IsNumber() + @IsPositive() + @Transform(({ value }) => parseInt(value)) + public WORKER_IDLE_BREAK_MS = 1000; } diff --git a/src/modules/worker/worker.const.ts b/src/modules/worker/worker.const.ts new file mode 100644 index 00000000..a70fe4eb --- /dev/null +++ b/src/modules/worker/worker.const.ts @@ -0,0 +1 @@ +export const REDIS_FOR_WORKER = 'REDIS_FOR_WORKER'; diff --git a/src/modules/worker/worker.module.ts b/src/modules/worker/worker.module.ts index a27c2295..0ca6ec0b 100644 --- a/src/modules/worker/worker.module.ts +++ b/src/modules/worker/worker.module.ts @@ -3,11 +3,19 @@ import { ConfigurationModule } from '../../infra/configuration/configuration.mod import { LoggerModule } from '../../infra/logger/logger.module.js'; import { RedisModule } from '../../infra/redis/redis.module.js'; import { StorageModule } from '../../infra/storage/storage.module.js'; +import { YRedisClientModule } from '../../infra/y-redis/y-redis-client.module.js'; import { WorkerConfig } from './worker.config.js'; +import { REDIS_FOR_WORKER } from './worker.const.js'; import { WorkerService } from './worker.service.js'; @Module({ - imports: [ConfigurationModule.register(WorkerConfig), RedisModule, StorageModule, LoggerModule], + imports: [ + ConfigurationModule.register(WorkerConfig), + RedisModule.registerFor(REDIS_FOR_WORKER), + StorageModule, + LoggerModule, + YRedisClientModule.register(), + ], providers: [WorkerService], }) export class WorkerModule {} diff --git a/src/modules/worker/worker.service.spec.ts b/src/modules/worker/worker.service.spec.ts index 9dcd9bc2..5ab31cd8 100644 --- a/src/modules/worker/worker.service.spec.ts +++ b/src/modules/worker/worker.service.spec.ts @@ -1,35 +1,50 @@ import { createMock, DeepMocked } from '@golevelup/ts-jest'; import { Test, TestingModule } from '@nestjs/testing'; -import { Awareness } from 'y-protocols/awareness.js'; +import { Awareness } from 'y-protocols/awareness'; import { Doc } from 'yjs'; import { Logger } from '../../infra/logger/logger.js'; -import { RedisAdapter } from '../../infra/redis/interfaces/index.js'; -import { RedisService } from '../../infra/redis/redis.service.js'; -import { - streamMessageReplyFactory, - xAutoClaimResponseFactory, -} from '../../infra/redis/testing/x-auto-claim-response.factory.js'; +import { RedisAdapter, StreamMessageReply } from '../../infra/redis/interfaces/index.js'; +import { streamMessageReplyFactory } from '../../infra/redis/testing/stream-message-reply.factory.js'; +import { xAutoClaimResponseFactory } from '../../infra/redis/testing/x-auto-claim-response.factory.js'; import { StorageService } from '../../infra/storage/storage.service.js'; -import * as apiClass from '../../infra/y-redis/api.service.js'; -import { Api } from '../../infra/y-redis/api.service.js'; +import { yRedisDocFactory } from '../../infra/y-redis/testing/y-redis-doc.factory.js'; +import { YRedisClient } from '../../infra/y-redis/y-redis.client.js'; import { WorkerConfig } from './worker.config.js'; +import { REDIS_FOR_WORKER } from './worker.const.js'; import { WorkerService } from './worker.service.js'; +const mapStreamMessageRepliesToTask = (streamMessageReplies: StreamMessageReply[]) => { + const tasks = streamMessageReplies.map((message) => ({ + stream: message.message.compact?.toString(), + id: message.id.toString(), + })); + + return tasks; +}; + describe(WorkerService.name, () => { + let module: TestingModule; let service: WorkerService; - let redisService: DeepMocked; + let redisAdapter: DeepMocked; + let yRedisClient: DeepMocked; + let storageService: DeepMocked; beforeAll(async () => { - const module: TestingModule = await Test.createTestingModule({ + // TODO: should we start the app as api-test for this job? + module = await Test.createTestingModule({ providers: [ WorkerService, + { + provide: YRedisClient, + useValue: createMock(), + }, { provide: StorageService, useValue: createMock(), }, { - provide: RedisService, - useValue: createMock(), + provide: REDIS_FOR_WORKER, + useValue: createMock({ redisPrefix: 'prefix' }), }, { provide: Logger, @@ -41,124 +56,124 @@ describe(WorkerService.name, () => { WORKER_TRY_CLAIM_COUNT: 1, WORKER_TASK_DEBOUNCE: 1, WORKER_MIN_MESSAGE_LIFETIME: 1, + WORKER_IDLE_BREAK_MS: 1, }, }, ], }).compile(); service = await module.resolve(WorkerService); - redisService = module.get(RedisService); + redisAdapter = module.get(REDIS_FOR_WORKER); + yRedisClient = module.get(YRedisClient); + storageService = module.get(StorageService); }); afterEach(() => { jest.restoreAllMocks(); + service.stop(); + }); + + afterAll(async () => { + await module.close(); }); it('should be defined', () => { expect(service).toBeDefined(); }); - describe('onModuleInit', () => { - describe('when _destroyed is false', () => { + describe('job', () => { + describe('when new service instance is running', () => { const setup = () => { - const client = createMock({ _destroyed: false }); - jest.spyOn(apiClass, 'createApiClient').mockResolvedValueOnce(client); - - const error = new Error('Error to break while loop!'); + service.start(); - const consumeWorkerQueueSpy = jest.spyOn(service, 'consumeWorkerQueue').mockRejectedValueOnce(error); + const spy = jest.spyOn(service, 'consumeWorkerQueue'); - return { error, consumeWorkerQueueSpy }; + return { spy }; }; + it('and stop is called, it should be stopped', () => { + setup(); + + service.stop(); + + expect(service.status()).toBe(false); + }); + + it('and start is called, it should run', () => { + setup(); + + service.start(); + + expect(service.status()).toBe(true); + }); + it('should call consumeWorkerQueue', async () => { - const { error, consumeWorkerQueueSpy } = setup(); + const { spy } = setup(); - await expect(service.onModuleInit()).rejects.toThrow(error); + await new Promise((resolve) => setTimeout(resolve, 10)); - expect(consumeWorkerQueueSpy).toHaveBeenCalled(); + expect(spy).toHaveBeenCalled(); }); }); - describe('when _destroyed is true', () => { + describe('when new service instance is not running', () => { const setup = () => { - const client = createMock({ _destroyed: true }); - jest.spyOn(apiClass, 'createApiClient').mockResolvedValueOnce(client); + service.stop(); + const spy = jest.spyOn(service, 'consumeWorkerQueue'); - const consumeWorkerQueueSpy = jest.spyOn(service, 'consumeWorkerQueue').mockResolvedValueOnce([]); - - return { consumeWorkerQueueSpy }; + return { spy }; }; - it('should call not consumeWorkerQueue', async () => { - const { consumeWorkerQueueSpy } = setup(); + it('and start is called, it should be started', () => { + setup(); - await service.onModuleInit(); + service.start(); - expect(consumeWorkerQueueSpy).not.toHaveBeenCalled(); + expect(service.status()).toBe(true); }); - }); - }); - describe('consumeWorkerQueue', () => { - describe('when there are no tasks', () => { - const setup = async () => { - const client = createMock({ _destroyed: true }); - jest.spyOn(apiClass, 'createApiClient').mockResolvedValueOnce(client); + it('and stop is called, it should stopped', () => { + setup(); - await service.onModuleInit(); - }; + service.stop(); + + expect(service.status()).toBe(false); + }); - it('should return an empty array', async () => { - await setup(); + it('should not call consumeWorkerQueue', async () => { + const { spy } = setup(); - const result = await service.consumeWorkerQueue(); + await new Promise((resolve) => setTimeout(resolve, 10)); - expect(result).toEqual([]); + expect(spy).not.toHaveBeenCalled(); }); }); + }); + describe('consumeWorkerQueue', () => { describe('when there are tasks', () => { describe('when stream length is 0', () => { describe('when deletedDocEntries is empty', () => { - const setup = async () => { - const awareness = createMock(); - const client = createMock({ _destroyed: true }); - client.getDoc.mockResolvedValue({ - ydoc: createMock(), - awareness, - redisLastId: '0', - storeReferences: null, - docChanged: false, - }); - jest.spyOn(apiClass, 'createApiClient').mockResolvedValueOnce(client); - - const streamMessageReply1 = streamMessageReplyFactory.build(); - const streamMessageReply2 = streamMessageReplyFactory.build(); - const streamMessageReply3 = streamMessageReplyFactory.build(); + const setup = () => { + const yRedisDocMock = yRedisDocFactory.build(); + yRedisClient.getDoc.mockResolvedValue(yRedisDocMock); + const streamMessageReplys = streamMessageReplyFactory.buildList(3); const reclaimedTasks = xAutoClaimResponseFactory.build(); - reclaimedTasks.messages = [streamMessageReply1, streamMessageReply2, streamMessageReply3]; - - const redisAdapterMock = createMock({ redisPrefix: 'prefix' }); - redisAdapterMock.reclaimTasks.mockResolvedValueOnce(reclaimedTasks); - redisAdapterMock.getDeletedDocEntries.mockResolvedValueOnce([]); - redisAdapterMock.tryClearTask.mockResolvedValueOnce(0).mockResolvedValueOnce(0).mockResolvedValueOnce(0); + reclaimedTasks.messages = streamMessageReplys; + const deletedDocEntries: StreamMessageReply[] = []; - redisService.createRedisInstance.mockResolvedValueOnce(redisAdapterMock); + redisAdapter.reclaimTasks.mockResolvedValueOnce(reclaimedTasks); + redisAdapter.getDeletedDocEntries.mockResolvedValueOnce(deletedDocEntries); + redisAdapter.tryClearTask.mockResolvedValueOnce(0).mockResolvedValueOnce(0).mockResolvedValueOnce(0); - const expectedTasks = reclaimedTasks.messages.map((m) => ({ - stream: m.message.compact.toString(), - id: m?.id.toString(), - })); - - await service.onModuleInit(); + const expectedTasks = mapStreamMessageRepliesToTask(reclaimedTasks.messages); return { expectedTasks }; }; it('should return an array of tasks', async () => { - const { expectedTasks } = await setup(); + const { expectedTasks } = setup(); const result = await service.consumeWorkerQueue(); @@ -167,46 +182,26 @@ describe(WorkerService.name, () => { }); describe('when deletedDocEntries contains element', () => { - const setup = async () => { - const awareness = createMock(); - const client = createMock({ _destroyed: true }); - client.getDoc.mockResolvedValue({ - ydoc: createMock(), - awareness, - redisLastId: '0', - storeReferences: null, - docChanged: false, - }); - jest.spyOn(apiClass, 'createApiClient').mockResolvedValueOnce(client); - - const streamMessageReply1 = streamMessageReplyFactory.build(); - const streamMessageReply2 = streamMessageReplyFactory.build(); - const streamMessageReply3 = streamMessageReplyFactory.build(); + const setup = () => { + const yRedisDocMock = yRedisDocFactory.build(); + yRedisClient.getDoc.mockResolvedValue(yRedisDocMock); + const streamMessageReplys = streamMessageReplyFactory.buildList(3); const reclaimedTasks = xAutoClaimResponseFactory.build(); - reclaimedTasks.messages = [streamMessageReply1, streamMessageReply2, streamMessageReply3]; - - const deletedDocEntries = [streamMessageReply2]; - - const redisAdapterMock = createMock({ redisPrefix: 'prefix' }); - redisAdapterMock.reclaimTasks.mockResolvedValueOnce(reclaimedTasks); - redisAdapterMock.getDeletedDocEntries.mockResolvedValueOnce(deletedDocEntries); - redisAdapterMock.tryClearTask.mockResolvedValueOnce(0).mockResolvedValueOnce(0).mockResolvedValueOnce(0); - - redisService.createRedisInstance.mockResolvedValueOnce(redisAdapterMock); + reclaimedTasks.messages = streamMessageReplys; + const deletedDocEntries = [streamMessageReplys[2]]; - const expectedTasks = reclaimedTasks.messages.map((m) => ({ - stream: m.message.compact.toString(), - id: m?.id.toString(), - })); + redisAdapter.reclaimTasks.mockResolvedValueOnce(reclaimedTasks); + redisAdapter.getDeletedDocEntries.mockResolvedValueOnce(deletedDocEntries); + redisAdapter.tryClearTask.mockResolvedValueOnce(0).mockResolvedValueOnce(0).mockResolvedValueOnce(0); - await service.onModuleInit(); + const expectedTasks = mapStreamMessageRepliesToTask(reclaimedTasks.messages); return { expectedTasks }; }; it('should return an array of tasks', async () => { - const { expectedTasks } = await setup(); + const { expectedTasks } = setup(); const result = await service.consumeWorkerQueue(); @@ -217,31 +212,18 @@ describe(WorkerService.name, () => { describe('when stream length is not 0', () => { describe('when docChanged is false', () => { - const setup = async () => { - const awareness = createMock(); - const client = createMock({ _destroyed: true }); - client.getDoc.mockResolvedValue({ - ydoc: createMock(), - awareness, - redisLastId: '0', - storeReferences: null, - docChanged: false, - }); - jest.spyOn(apiClass, 'createApiClient').mockResolvedValueOnce(client); - - const streamMessageReply1 = streamMessageReplyFactory.build(); - const streamMessageReply2 = streamMessageReplyFactory.build(); - const streamMessageReply3 = streamMessageReplyFactory.build(); + const setup = () => { + const yRedisDocMock = yRedisDocFactory.build(); + yRedisClient.getDoc.mockResolvedValue(yRedisDocMock); + const streamMessageReplys = streamMessageReplyFactory.buildList(3); const reclaimedTasks = xAutoClaimResponseFactory.build(); - reclaimedTasks.messages = [streamMessageReply1, streamMessageReply2, streamMessageReply3]; - - const deletedDocEntries = [streamMessageReply2]; + reclaimedTasks.messages = streamMessageReplys; + const deletedDocEntries = [streamMessageReplys[2]]; - const redisAdapterMock = createMock({ redisPrefix: 'prefix' }); - redisAdapterMock.reclaimTasks.mockResolvedValueOnce(reclaimedTasks); - redisAdapterMock.getDeletedDocEntries.mockResolvedValueOnce(deletedDocEntries); - redisAdapterMock.tryClearTask + redisAdapter.reclaimTasks.mockResolvedValueOnce(reclaimedTasks); + redisAdapter.getDeletedDocEntries.mockResolvedValueOnce(deletedDocEntries); + redisAdapter.tryClearTask .mockImplementationOnce(async (task) => { return await Promise.resolve(task.stream.length); }) @@ -252,20 +234,13 @@ describe(WorkerService.name, () => { return await Promise.resolve(task.stream.length); }); - redisService.createRedisInstance.mockResolvedValueOnce(redisAdapterMock); - - const expectedTasks = reclaimedTasks.messages.map((m) => ({ - stream: m.message.compact.toString(), - id: m?.id.toString(), - })); - - await service.onModuleInit(); + const expectedTasks = mapStreamMessageRepliesToTask(reclaimedTasks.messages); return { expectedTasks }; }; it('should return an array of tasks', async () => { - const { expectedTasks } = await setup(); + const { expectedTasks } = setup(); const result = await service.consumeWorkerQueue(); @@ -274,61 +249,105 @@ describe(WorkerService.name, () => { }); describe('when docChanged is true', () => { - const setup = async () => { - const awareness = createMock(); - const client = createMock({ _destroyed: true }); - client.getDoc.mockResolvedValue({ - ydoc: createMock(), - awareness, - redisLastId: '0', - storeReferences: null, - docChanged: true, + describe('when storeReferences is null', () => { + const setup = () => { + const yRedisDocMock = yRedisDocFactory.build(); + yRedisClient.getDoc.mockResolvedValue(yRedisDocMock); + + const streamMessageReplys = streamMessageReplyFactory.buildList(3); + const reclaimedTasks = xAutoClaimResponseFactory.build(); + reclaimedTasks.messages = streamMessageReplys; + const deletedDocEntries = [streamMessageReplys[2]]; + + redisAdapter.reclaimTasks.mockResolvedValueOnce(reclaimedTasks); + redisAdapter.getDeletedDocEntries.mockResolvedValueOnce(deletedDocEntries); + redisAdapter.tryClearTask + .mockImplementationOnce(async (task) => { + return await Promise.resolve(task.stream.length); + }) + .mockImplementationOnce(async (task) => { + return await Promise.resolve(task.stream.length); + }) + .mockImplementationOnce(async (task) => { + return await Promise.resolve(task.stream.length); + }); + + const expectedTasks = mapStreamMessageRepliesToTask(reclaimedTasks.messages); + + return { expectedTasks }; + }; + + it('should return an array of tasks', async () => { + const { expectedTasks } = setup(); + + const result = await service.consumeWorkerQueue(); + + expect(result).toEqual(expectedTasks); }); - jest.spyOn(apiClass, 'createApiClient').mockResolvedValueOnce(client); - - const streamMessageReply1 = streamMessageReplyFactory.build(); - const streamMessageReply2 = streamMessageReplyFactory.build(); - const streamMessageReply3 = streamMessageReplyFactory.build(); - - const reclaimedTasks = xAutoClaimResponseFactory.build(); - reclaimedTasks.messages = [streamMessageReply1, streamMessageReply2, streamMessageReply3]; - - const deletedDocEntries = [streamMessageReply2]; - - const redisAdapterMock = createMock({ redisPrefix: 'prefix' }); - redisAdapterMock.reclaimTasks.mockResolvedValueOnce(reclaimedTasks); - redisAdapterMock.getDeletedDocEntries.mockResolvedValueOnce(deletedDocEntries); - redisAdapterMock.tryClearTask - .mockImplementationOnce(async (task) => { - return await Promise.resolve(task.stream.length); - }) - .mockImplementationOnce(async (task) => { - return await Promise.resolve(task.stream.length); - }) - .mockImplementationOnce(async (task) => { - return await Promise.resolve(task.stream.length); - }); - - redisService.createRedisInstance.mockResolvedValueOnce(redisAdapterMock); - - const expectedTasks = reclaimedTasks.messages.map((m) => ({ - stream: m.message.compact.toString(), - id: m?.id.toString(), - })); - - await service.onModuleInit(); - - return { expectedTasks }; - }; - - it('should return an array of tasks', async () => { - // docChanged = true; - - const { expectedTasks } = await setup(); - - const result = await service.consumeWorkerQueue(); + }); - expect(result).toEqual(expectedTasks); + describe('when storeReferences is defined', () => { + const setup = () => { + const storeReferences = ['storeReference1', 'storeReference2']; + const yRedisDocMock = { + ydoc: createMock(), + awareness: createMock(), + redisLastId: '0', + storeReferences, + docChanged: true, + streamName: '', + getAwarenessStateSize: () => 1, + }; + yRedisClient.getDoc.mockResolvedValue(yRedisDocMock); + + const streamMessageReplys = streamMessageReplyFactory.buildList(3); + const reclaimedTasks = xAutoClaimResponseFactory.build(); + reclaimedTasks.messages = streamMessageReplys; + const deletedDocEntries = [streamMessageReplys[2]]; + + redisAdapter.reclaimTasks.mockResolvedValueOnce(reclaimedTasks); + redisAdapter.getDeletedDocEntries.mockResolvedValueOnce(deletedDocEntries); + redisAdapter.tryClearTask + .mockImplementationOnce(async (task) => { + return await Promise.resolve(task.stream.length); + }) + .mockImplementationOnce(async (task) => { + return await Promise.resolve(task.stream.length); + }) + .mockImplementationOnce(async (task) => { + return await Promise.resolve(task.stream.length); + }); + + const expectedTasks = mapStreamMessageRepliesToTask(reclaimedTasks.messages); + + return { expectedTasks, storeReferences }; + }; + + it('should return an array of tasks and delete references', async () => { + const { expectedTasks, storeReferences } = setup(); + + const result = await service.consumeWorkerQueue(); + + expect(result).toEqual(expectedTasks); + expect(storageService.deleteReferences).toHaveBeenNthCalledWith( + 1, + 'room', + expect.stringContaining('docid-'), + storeReferences, + ); + expect(storageService.deleteReferences).toHaveBeenNthCalledWith( + 2, + 'room', + expect.stringContaining('docid-'), + storeReferences, + ); + expect(storageService.deleteReferences).toHaveBeenNthCalledWith( + 3, + 'room', + expect.stringContaining('docid-'), + storeReferences, + ); + }); }); }); }); diff --git a/src/modules/worker/worker.service.ts b/src/modules/worker/worker.service.ts index abc1792a..3fc7d431 100644 --- a/src/modules/worker/worker.service.ts +++ b/src/modules/worker/worker.service.ts @@ -1,119 +1,228 @@ -import { Injectable, OnModuleInit } from '@nestjs/common'; +import { Inject, Injectable, OnModuleDestroy } from '@nestjs/common'; import { randomUUID } from 'crypto'; +import { RedisKey } from 'ioredis'; import { Logger } from '../../infra/logger/index.js'; -import { RedisAdapter, Task } from '../../infra/redis/interfaces/index.js'; -import { RedisService } from '../../infra/redis/redis.service.js'; +import { RedisAdapter, StreamMessageReply, Task, XAutoClaimResponse } from '../../infra/redis/interfaces/index.js'; import { StorageService } from '../../infra/storage/storage.service.js'; -import { Api, createApiClient } from '../../infra/y-redis/api.service.js'; -import { decodeRedisRoomStreamName } from '../../infra/y-redis/helper.js'; +import { decodeRedisRoomStreamName, RoomStreamInfos } from '../../infra/y-redis/helper.js'; +import { YRedisDoc } from '../../infra/y-redis/y-redis-doc.js'; +import { YRedisClient } from '../../infra/y-redis/y-redis.client.js'; import { WorkerConfig } from './worker.config.js'; +import { REDIS_FOR_WORKER } from './worker.const.js'; + +interface Job { + status(): boolean; + start(): void; + stop(): void; +} @Injectable() -export class WorkerService implements OnModuleInit { - private client!: Api; +export class WorkerService implements Job, OnModuleDestroy { private readonly consumerId = randomUUID(); - private redis!: RedisAdapter; + private running = true; public constructor( private readonly storageService: StorageService, - private readonly redisService: RedisService, + @Inject(REDIS_FOR_WORKER) private readonly redis: RedisAdapter, private readonly logger: Logger, private readonly config: WorkerConfig, + private readonly yRedisClient: YRedisClient, ) { this.logger.setContext(WorkerService.name); } - public async onModuleInit(): Promise { - this.client = await createApiClient(this.storageService, this.redisService); - this.redis = await this.redisService.createRedisInstance(); + public async onModuleDestroy(): Promise { + this.stop(); + await this.yRedisClient.destroy(); + } - this.logger.log(`Created worker process ${this.consumerId}`); - while (!this.client._destroyed) { - await this.consumeWorkerQueue(); + public async start(): Promise { + this.running = true; + + while (this.running) { + const tasks = await this.consumeWorkerQueue(); + await this.waitIfNoOpenTask(tasks, this.config.WORKER_IDLE_BREAK_MS); } - this.logger.log(`Ended worker process ${this.consumerId}`); + + this.logger.info(`Start worker process ${this.consumerId}`); + } + + public stop(): void { + this.running = false; + this.logger.info(`Ended worker process ${this.consumerId}`); + } + + public status(): boolean { + return this.running; } public async consumeWorkerQueue(): Promise { - const tryClaimCount = this.config.WORKER_TRY_CLAIM_COUNT; - const taskDebounce = this.config.WORKER_TASK_DEBOUNCE; - const minMessageLifetime = this.config.WORKER_MIN_MESSAGE_LIFETIME; - const tasks: Task[] = []; + const reclaimedTasks = await this.reclaimTasksInRedis(); + const tasks = this.mapReclaimTaskToTask(reclaimedTasks); - const reclaimedTasks = await this.redis.reclaimTasks(this.consumerId, taskDebounce, tryClaimCount); + const promises = tasks.map((task: Task) => this.processTask(task)); + await Promise.all(promises); - const deletedDocEntries = await this.redis.getDeletedDocEntries(); - const deletedDocNames = deletedDocEntries?.map((entry) => { - return entry.message.docName; - }); + return tasks; + } + + private async processTask(task: Task): Promise { + const [deletedDocEntries, streamLength] = await Promise.all([ + this.redis.getDeletedDocEntries(), + this.redis.tryClearTask(task), + ]); + + try { + if (this.streamIsEmpty(streamLength)) { + this.removingRecurringTaskFromQueue(task, deletedDocEntries); + } else { + await this.processUpdateChanges(deletedDocEntries, task); + } + } catch (error: unknown) { + this.logger.warning({ error, deletedDocEntries, task, message: 'processTask' }); + } + } + + private async processUpdateChanges(deletedDocEntries: StreamMessageReply[], task: Task): Promise { + this.logger.info('requesting doc from store'); + const roomStreamInfos = decodeRedisRoomStreamName(task.stream.toString(), this.redis.redisPrefix); + const yRedisDoc = await this.yRedisClient.getDoc(roomStreamInfos.room, roomStreamInfos.docid); + + this.destroyAwarenessToAvoidMemoryLeaks(yRedisDoc); + this.logDoc(yRedisDoc); + const lastId = this.determineLastId(yRedisDoc, task); + + const deletedDocNames = this.extractDocNamesFromStreamMessageReply(deletedDocEntries); + if (this.docChangedButNotDeleted(yRedisDoc, deletedDocNames, task)) { + await this.storageService.persistDoc(roomStreamInfos.room, roomStreamInfos.docid, yRedisDoc.ydoc); + } + + await Promise.all([ + this.redis.tryDeduplicateTask(task, lastId, this.config.WORKER_MIN_MESSAGE_LIFETIME), + this.deleteStorageReferencesIfExist(yRedisDoc, roomStreamInfos), + ]); + + this.logStream(task, lastId - this.config.WORKER_MIN_MESSAGE_LIFETIME); + } + + private async waitIfNoOpenTask(tasks: Task[], waitInMs: number): Promise { + if (tasks.length === 0) { + this.logger.info(`No tasks available, pausing... ${JSON.stringify({ tasks })}`); + await new Promise((resolve) => setTimeout(resolve, waitInMs)); + } + } + + private async reclaimTasksInRedis(): Promise { + const reclaimedTasks = await this.redis.reclaimTasks( + this.consumerId, + this.config.WORKER_TASK_DEBOUNCE, + this.config.WORKER_TRY_CLAIM_COUNT, + ); + + return reclaimedTasks; + } + + private destroyAwarenessToAvoidMemoryLeaks(yRedisDoc: YRedisDoc): void { + // @see: https://github.com/yjs/y-redis/issues/24 + yRedisDoc.awareness.destroy(); + } + + private async removingRecurringTaskFromQueue(task: Task, deletedDocEntries: StreamMessageReply[]): Promise { + this.logger.info( + `Stream still empty, removing recurring task from queue ${JSON.stringify({ stream: task.stream })}`, + ); + + const deleteEntryId = deletedDocEntries.find((entry) => entry.message.docName === task.stream)?.id.toString(); + + if (deleteEntryId) { + const roomStreamInfos = decodeRedisRoomStreamName(task.stream.toString(), this.redis.redisPrefix); + await Promise.all([ + this.redis.deleteDeletedDocEntry(deleteEntryId), + this.storageService.deleteDocument(roomStreamInfos.room, roomStreamInfos.docid), + ]); + } + } + private deleteStorageReferencesIfExist(yRedisDoc: YRedisDoc, roomStreamInfos: RoomStreamInfos): Promise { + let promise = Promise.resolve(); + + if (this.isDocumentChangedAndReferencesAvaible(yRedisDoc)) { + const storeReferences = this.castToStringArray(yRedisDoc.storeReferences); + promise = this.storageService.deleteReferences(roomStreamInfos.room, roomStreamInfos.docid, storeReferences); + } + + return promise; + } + + // helper + private mapReclaimTaskToTask(reclaimedTasks: XAutoClaimResponse): Task[] { + const tasks: Task[] = []; reclaimedTasks.messages?.forEach((m) => { const stream = m?.message.compact; stream && tasks.push({ stream: stream.toString(), id: m?.id.toString() }); }); - if (tasks.length === 0) { - this.logger.log(`No tasks available, pausing... ${JSON.stringify({ tasks })}`); - await new Promise((resolve) => setTimeout(resolve, 1000)); - return []; + if (tasks.length > 0) { + this.logger.info(`Accepted tasks ${JSON.stringify({ tasks })}`); + } + + return tasks; + } + + private determineLastId(yRedisDoc: YRedisDoc, task: Task): number { + const lastId = Math.max(parseInt(yRedisDoc.redisLastId.split('-')[0]), parseInt(task.id.split('-')[0])); + + return lastId; + } + + private extractDocNamesFromStreamMessageReply(docEntries: StreamMessageReply[]): string[] { + const docNames = docEntries + .map((entry) => { + return entry.message.docName; + }) + .filter((docName) => docName !== undefined); + + return docNames; + } + + private castToStringArray(input: string[] | null): string[] { + if (input) { + return input; } - this.logger.log(`Accepted tasks ${JSON.stringify({ tasks })}`); - - await Promise.all( - tasks.map(async (task) => { - const streamlen = await this.redis.tryClearTask(task); - const { room, docid } = decodeRedisRoomStreamName(task.stream.toString(), this.redis.redisPrefix); - - if (streamlen === 0) { - this.logger.log( - `Stream still empty, removing recurring task from queue ${JSON.stringify({ stream: task.stream })}`, - ); - - const deleteEntryId = deletedDocEntries.find((entry) => entry.message.docName === task.stream)?.id.toString(); - - if (deleteEntryId) { - this.redis.deleteDeleteDocEntry(deleteEntryId); - this.storageService.deleteDocument(room, docid); - } - } else { - // @todo, make sure that awareness by this.getDoc is eventually destroyed, or doesn't - // register a timeout anymore - this.logger.log('requesting doc from store'); - const { ydoc, storeReferences, redisLastId, docChanged, awareness } = await this.client.getDoc(room, docid); - - // awareness is destroyed here to avoid memory leaks, see: https://github.com/yjs/y-redis/issues/24 - awareness.destroy(); - this.logger.log( - 'retrieved doc from store. redisLastId=' + redisLastId + ' storeRefs=' + JSON.stringify(storeReferences), - ); - const lastId = Math.max(parseInt(redisLastId.split('-')[0]), parseInt(task.id.split('-')[0])); - if (docChanged) { - this.logger.log('persisting doc'); - if (!deletedDocNames.includes(task.stream)) { - await this.storageService.persistDoc(room, docid, ydoc); - } - } - - await Promise.all([ - storeReferences && docChanged - ? this.storageService.deleteReferences(room, docid, storeReferences) - : Promise.resolve(), - this.redis.tryDeduplicateTask(task, lastId, minMessageLifetime), - ]); - - this.logger.log( - `Compacted stream - ${JSON.stringify({ - stream: task.stream, - taskId: task.id, - newLastId: lastId - minMessageLifetime, - })}`, - ); - } - }), + throw new Error(`Input ${input} can not be castet to string[].`); + } + + private docChangedButNotDeleted(yRedisDoc: YRedisDoc, deletedDocNames: RedisKey[], task: Task): boolean { + return yRedisDoc.docChanged && !deletedDocNames.includes(task.stream); + } + + private isDocumentChangedAndReferencesAvaible(yRedisDoc: YRedisDoc): boolean { + return yRedisDoc.storeReferences !== null && yRedisDoc.docChanged === true; + } + + private streamIsEmpty(streamLength: number): boolean { + return streamLength === 0; + } + + // logs + private logDoc(yRedisDoc: YRedisDoc): void { + this.logger.info( + 'retrieved doc from store. redisLastId=' + + yRedisDoc.redisLastId + + ' storeRefs=' + + JSON.stringify(yRedisDoc.storeReferences), ); + } - return tasks; + private logStream(task: Task, newLastId: number): void { + this.logger.info( + `Compacted stream + ${JSON.stringify({ + stream: task.stream, + taskId: task.id, + newLastId, + })}`, + ); } }