diff --git a/.changeset/violet-laws-compare.md b/.changeset/violet-laws-compare.md new file mode 100644 index 000000000..cbf4399e0 --- /dev/null +++ b/.changeset/violet-laws-compare.md @@ -0,0 +1,5 @@ +--- +"synckit": patch +--- + +fix: handle outdated message in channel queue diff --git a/src/index.ts b/src/index.ts index df193d1a5..9fc11e8aa 100644 --- a/src/index.ts +++ b/src/index.ts @@ -5,6 +5,7 @@ import path from 'node:path' import { fileURLToPath, pathToFileURL } from 'node:url' import { MessageChannel, + MessagePort, type TransferListItem, Worker, parentPort, @@ -19,6 +20,7 @@ import type { AnyAsyncFn, AnyFn, GlobalShim, + MainToWorkerCommandMessage, MainToWorkerMessage, Syncify, ValueOf, @@ -522,36 +524,59 @@ function startWorkerThread>( let nextID = 0 - const syncFn = (...args: Parameters): R => { - const id = nextID++ - - const msg: MainToWorkerMessage> = { id, args } - - worker.postMessage(msg) - - const status = Atomics.wait(sharedBufferView!, 0, 0, timeout) - - // Reset SharedArrayBuffer for next call + const receiveMessageWithId = ( + port: MessagePort, + expectedId: number, + waitingTimeout?: number, + ): WorkerToMainMessage => { + const start = Date.now() + const status = Atomics.wait(sharedBufferView!, 0, 0, waitingTimeout) Atomics.store(sharedBufferView!, 0, 0) - /* istanbul ignore if */ if (!['ok', 'not-equal'].includes(status)) { + const abortMsg: MainToWorkerCommandMessage = { + id: expectedId, + cmd: 'abort', + } + port.postMessage(abortMsg) throw new Error('Internal error: Atomics.wait() failed: ' + status) } - const { - id: id2, - result, - error, - properties, - } = (receiveMessageOnPort(mainPort) as { message: WorkerToMainMessage }) - .message + const { id, ...message } = ( + receiveMessageOnPort(mainPort) as { message: WorkerToMainMessage } + ).message - /* istanbul ignore if */ - if (id !== id2) { - throw new Error(`Internal error: Expected id ${id} but got id ${id2}`) + if (id < expectedId) { + const waitingTime = Date.now() - start + return receiveMessageWithId( + port, + expectedId, + waitingTimeout ? waitingTimeout - waitingTime : undefined, + ) } + if (expectedId !== id) { + throw new Error( + `Internal error: Expected id ${expectedId} but got id ${id}`, + ) + } + + return { id, ...message } + } + + const syncFn = (...args: Parameters): R => { + const id = nextID++ + + const msg: MainToWorkerMessage> = { id, args } + + worker.postMessage(msg) + + const { result, error, properties } = receiveMessageWithId( + mainPort, + id, + timeout, + ) + if (error) { throw Object.assign(error as object, properties) } @@ -587,12 +612,24 @@ export function runAsWorker< ({ id, args }: MainToWorkerMessage>) => { // eslint-disable-next-line @typescript-eslint/no-floating-promises ;(async () => { + let isAborted = false + const handleAbortMessage = (msg: MainToWorkerCommandMessage) => { + if (msg.id === id && msg.cmd === 'abort') { + isAborted = true + } + } + workerPort.on('message', handleAbortMessage) let msg: WorkerToMainMessage try { msg = { id, result: await fn(...args) } } catch (error: unknown) { msg = { id, error, properties: extractProperties(error) } } + workerPort.off('message', handleAbortMessage) + // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition + if (isAborted) { + return + } workerPort.postMessage(msg) Atomics.add(sharedBufferView, 0, 1) Atomics.notify(sharedBufferView, 0) diff --git a/src/types.ts b/src/types.ts index b3fd7516a..5f7556093 100644 --- a/src/types.ts +++ b/src/types.ts @@ -25,6 +25,11 @@ export interface MainToWorkerMessage { args: T } +export interface MainToWorkerCommandMessage { + id: number + cmd: string +} + export interface WorkerData { sharedBuffer: SharedArrayBuffer workerPort: MessagePort diff --git a/test/fn.spec.ts b/test/fn.spec.ts index 3e68e5316..e1c6bd948 100644 --- a/test/fn.spec.ts +++ b/test/fn.spec.ts @@ -3,7 +3,12 @@ import path from 'node:path' import { jest } from '@jest/globals' -import { _dirname, testIf, tsUseEsmSupported } from './helpers.js' +import { + _dirname, + setupReceiveMessageOnPortMock, + testIf, + tsUseEsmSupported, +} from './helpers.js' import type { AsyncWorkerFn } from './types.js' import { createSyncFn } from 'synckit' @@ -12,6 +17,7 @@ const { SYNCKIT_TIMEOUT } = process.env beforeEach(() => { jest.resetModules() + jest.restoreAllMocks() delete process.env.SYNCKIT_GLOBAL_SHIMS @@ -104,6 +110,114 @@ test('timeout', async () => { ) }) +test('subsequent executions after timeout', async () => { + const executionTimeout = 30 + const longRunningTaskDuration = executionTimeout * 10 + process.env.SYNCKIT_TIMEOUT = executionTimeout.toString() + + const { createSyncFn } = await import('synckit') + const syncFn = createSyncFn(workerCjsPath) + + // start an execution in worker that will definitely time out + expect(() => syncFn(1, longRunningTaskDuration)).toThrow() + + // wait for timed out execution to finish inside worker + await new Promise(resolve => setTimeout(resolve, longRunningTaskDuration)) + + // subsequent executions should work correctly + expect(syncFn(2, 1)).toBe(2) + expect(syncFn(3, 1)).toBe(3) +}) + +test('handling of outdated message from worker', async () => { + const executionTimeout = 60 + process.env.SYNCKIT_TIMEOUT = executionTimeout.toString() + const receiveMessageOnPortMock = await setupReceiveMessageOnPortMock() + + jest.spyOn(Atomics, 'wait').mockReturnValue('ok') + + receiveMessageOnPortMock + .mockReturnValueOnce({ message: { id: -1 } }) + .mockReturnValueOnce({ message: { id: 0, result: 1 } }) + + const { createSyncFn } = await import('synckit') + const syncFn = createSyncFn(workerCjsPath) + expect(syncFn(1)).toBe(1) + expect(receiveMessageOnPortMock).toHaveBeenCalledTimes(2) +}) + +test('propagation of undefined timeout', async () => { + delete process.env.SYNCKIT_TIMEOUT + const receiveMessageOnPortMock = await setupReceiveMessageOnPortMock() + + const atomicsWaitSpy = jest.spyOn(Atomics, 'wait').mockReturnValue('ok') + + receiveMessageOnPortMock + .mockReturnValueOnce({ message: { id: -1 } }) + .mockReturnValueOnce({ message: { id: 0, result: 1 } }) + + const { createSyncFn } = await import('synckit') + const syncFn = createSyncFn(workerCjsPath) + expect(syncFn(1)).toBe(1) + expect(receiveMessageOnPortMock).toHaveBeenCalledTimes(2) + + const [firstAtomicsWaitArgs, secondAtomicsWaitArgs] = + atomicsWaitSpy.mock.calls + const [, , , firstAtomicsWaitCallTimeout] = firstAtomicsWaitArgs + const [, , , secondAtomicsWaitCallTimeout] = secondAtomicsWaitArgs + + expect(typeof firstAtomicsWaitCallTimeout).toBe('undefined') + expect(typeof secondAtomicsWaitCallTimeout).toBe('undefined') +}) + +test('reduction of waiting time', async () => { + const synckitTimeout = 60 + process.env.SYNCKIT_TIMEOUT = synckitTimeout.toString() + const receiveMessageOnPortMock = await setupReceiveMessageOnPortMock() + + const atomicsWaitSpy = jest.spyOn(Atomics, 'wait').mockImplementation(() => { + const start = Date.now() + // simulate waiting 10ms for worker to respond + while (Date.now() - start < 10) { + continue + } + + return 'ok' + }) + + receiveMessageOnPortMock + .mockReturnValueOnce({ message: { id: -1 } }) + .mockReturnValueOnce({ message: { id: 0, result: 1 } }) + + const { createSyncFn } = await import('synckit') + const syncFn = createSyncFn(workerCjsPath) + expect(syncFn(1)).toBe(1) + expect(receiveMessageOnPortMock).toHaveBeenCalledTimes(2) + + const [firstAtomicsWaitArgs, secondAtomicsWaitArgs] = + atomicsWaitSpy.mock.calls + const [, , , firstAtomicsWaitCallTimeout] = firstAtomicsWaitArgs + const [, , , secondAtomicsWaitCallTimeout] = secondAtomicsWaitArgs + + expect(typeof firstAtomicsWaitCallTimeout).toBe('number') + expect(firstAtomicsWaitCallTimeout).toBe(synckitTimeout) + expect(typeof secondAtomicsWaitCallTimeout).toBe('number') + expect(secondAtomicsWaitCallTimeout).toBeLessThan(synckitTimeout) +}) + +test('unexpected message from worker', async () => { + jest.spyOn(Atomics, 'wait').mockReturnValue('ok') + + const receiveMessageOnPortMock = await setupReceiveMessageOnPortMock() + receiveMessageOnPortMock.mockReturnValueOnce({ message: { id: 100 } }) + + const { createSyncFn } = await import('synckit') + const syncFn = createSyncFn(workerCjsPath) + expect(() => syncFn(1)).toThrow( + 'Internal error: Expected id 0 but got id 100', + ) +}) + test('globalShims env', async () => { process.env.SYNCKIT_GLOBAL_SHIMS = '1' diff --git a/test/helpers.ts b/test/helpers.ts index 086922403..b1a224bf5 100644 --- a/test/helpers.ts +++ b/test/helpers.ts @@ -1,5 +1,8 @@ import path from 'node:path' import { fileURLToPath } from 'node:url' +import WorkerThreads from 'node:worker_threads' + +import { jest } from '@jest/globals' import { MTS_SUPPORTED_NODE_VERSION } from 'synckit' @@ -13,3 +16,24 @@ export const tsUseEsmSupported = nodeVersion >= MTS_SUPPORTED_NODE_VERSION && nodeVersion <= 18.18 export const testIf = (condition: boolean) => (condition ? it : it.skip) + +type ReceiveMessageOnPortMock = jest.Mock< + typeof WorkerThreads.receiveMessageOnPort +> +export const setupReceiveMessageOnPortMock = + async (): Promise => { + jest.unstable_mockModule('node:worker_threads', () => { + return { + ...WorkerThreads, + receiveMessageOnPort: jest.fn(WorkerThreads.receiveMessageOnPort), + } + }) + + const { receiveMessageOnPort: receiveMessageOnPortMock } = (await import( + 'node:worker_threads' + )) as unknown as { + receiveMessageOnPort: ReceiveMessageOnPortMock + } + + return receiveMessageOnPortMock + }