diff --git a/index.js b/index.js index b29ab88..d9ad233 100644 --- a/index.js +++ b/index.js @@ -5,6 +5,7 @@ const hyperid = require('hyperid') const { getGlobalDispatcher, setGlobalDispatcher } = require('undici') const { threadId, MessageChannel, parentPort } = require('worker_threads') const inject = require('light-my-request') +const Hooks = require('./lib/hooks') const kAddress = Symbol('undici-thread-interceptor.address') @@ -14,6 +15,7 @@ function createThreadInterceptor (opts) { const forwarded = new Map() const nextId = hyperid() const domain = opts?.domain + const hooks = new Hooks(opts) let timeout = opts?.timeout if (timeout === true) { @@ -59,12 +61,14 @@ function createThreadInterceptor (opts) { delete newOpts.dispatcher + hooks.fireOnRequest(newOpts) + if (newOpts.body?.[Symbol.asyncIterator]) { collectBodyAndDispatch(newOpts, handler).then(() => { port.postMessage({ type: 'request', id, opts: newOpts, threadId }) }, (err) => { clearTimeout(handle) - + hooks.fireOnError(err) handler.onError(err) }) } else { @@ -85,9 +89,11 @@ function createThreadInterceptor (opts) { clearTimeout(handle) if (err) { + hooks.fireOnError(err) handler.onError(err) return } + hooks.fireOnResponse(res) const headers = [] for (const [key, value] of Object.entries(res.headers)) { diff --git a/lib/hooks.js b/lib/hooks.js new file mode 100644 index 0000000..b6b1fcd --- /dev/null +++ b/lib/hooks.js @@ -0,0 +1,46 @@ +'use strict' + +const supportedHooks = ['onRequest', 'onResponse', 'onError'] + +class Hooks { + onRequest = [] + onResponse = [] + onError = [] + + constructor (opts) { + for (const hook of supportedHooks) { + const value = opts?.[hook] + if (value) { + const hooks = Array.isArray(value) ? value : [value] + this.#validate(hooks) + this[`${hook}`].push(...hooks) + } + } + } + + #validate (hooks) { + for (const hook of hooks) { + if (typeof hook !== 'function') throw new Error(`Expected a function, got ${typeof hook}`) + } + } + + async run (hooks, ...args) { + for (const fn of hooks) { + await fn(...args) + } + } + + fireOnRequest (...args) { + return this.run(this.onRequest, ...args) + } + + fireOnResponse (...args) { + return this.run(this.onResponse, ...args) + } + + fireOnError (...args) { + return this.run(this.onError, ...args) + } +} + +module.exports = Hooks diff --git a/test/hooks.test.js b/test/hooks.test.js new file mode 100644 index 0000000..70ec0c8 --- /dev/null +++ b/test/hooks.test.js @@ -0,0 +1,208 @@ +'use strict' + +const { test } = require('node:test') +const { deepStrictEqual, strictEqual } = require('node:assert') +const { join } = require('node:path') +const { Worker } = require('node:worker_threads') +const { createThreadInterceptor } = require('../') +const { Agent, request } = require('undici') + +test('hooks - onRequest', async (t) => { + const worker = new Worker(join(__dirname, 'fixtures', 'worker1.js')) + t.after(() => worker.terminate()) + let hookCalled = null + + const interceptor = createThreadInterceptor({ + domain: '.local', + onRequest: (opts) => { + hookCalled = opts + } + }) + interceptor.route('myserver', worker) + + const agent = new Agent().compose(interceptor) + + const { statusCode } = await request('http://myserver.local', { + dispatcher: agent, + }) + + strictEqual(statusCode, 200) + deepStrictEqual(hookCalled, { + headers: { + host: 'myserver.local', + }, + method: 'GET', + origin: 'http://myserver.local', + path: '/' + }) +}) + +test('hooks - multiple onRequests', async (t) => { + const worker = new Worker(join(__dirname, 'fixtures', 'worker1.js')) + t.after(() => worker.terminate()) + const hookCalled = [] + + const firstHook = (opts) => { + hookCalled.push({ first: opts }) + } + + const secondHook = (opts) => { + hookCalled.push({ second: opts }) + } + + const interceptor = createThreadInterceptor({ + domain: '.local', + onRequest: [firstHook, secondHook] + }) + interceptor.route('myserver', worker) + + const agent = new Agent().compose(interceptor) + + const { statusCode } = await request('http://myserver.local', { + dispatcher: agent, + }) + + strictEqual(statusCode, 200) + deepStrictEqual(hookCalled, [ + { + first: { + headers: { + host: 'myserver.local', + }, + method: 'GET', + origin: 'http://myserver.local', + path: '/' + } + }, { + second: { + headers: { + host: 'myserver.local', + }, + method: 'GET', + origin: 'http://myserver.local', + path: '/' + } + } + ]) +}) + +test('hooks - onResponse', async (t) => { + const worker = new Worker(join(__dirname, 'fixtures', 'worker1.js')) + t.after(() => worker.terminate()) + let hookCalled = null + + const interceptor = createThreadInterceptor({ + domain: '.local', + onResponse: (opts) => { + hookCalled = Buffer.from(opts.rawPayload).toString() + } + }) + interceptor.route('myserver', worker) + + const agent = new Agent().compose(interceptor) + const { statusCode } = await request('http://myserver.local', { + dispatcher: agent, + }) + + strictEqual(statusCode, 200) + deepStrictEqual(hookCalled, '{"hello":"world"}') +}) + +test('hooks - multiple onResponses', async (t) => { + const worker = new Worker(join(__dirname, 'fixtures', 'worker1.js')) + t.after(() => worker.terminate()) + const hookCalled = [] + + const onResponse1 = (opts) => { + hookCalled.push({ res1: Buffer.from(opts.rawPayload).toString() }) + } + + const onResponse2 = (opts) => { + hookCalled.push({ res2: Buffer.from(opts.rawPayload).toString() }) + } + + const interceptor = createThreadInterceptor({ + domain: '.local', + onResponse: [onResponse1, onResponse2] + }) + interceptor.route('myserver', worker) + + const agent = new Agent().compose(interceptor) + const { statusCode } = await request('http://myserver.local', { + dispatcher: agent, + }) + + strictEqual(statusCode, 200) + deepStrictEqual(hookCalled, [{ res1: '{"hello":"world"}' }, { res2: '{"hello":"world"}' }]) +}) + +test('hooks - onError', async (t) => { + const worker = new Worker(join(__dirname, 'fixtures', 'error.js')) + t.after(() => worker.terminate()) + let hookCalled = null + + const interceptor = createThreadInterceptor({ + domain: '.local', + onError: (error) => { + hookCalled = error + } + }) + interceptor.route('myserver', worker) + + try { + const agent = new Agent().compose(interceptor) + await request('http://myserver.local', { + dispatcher: agent, + }) + throw new Error('should not be here') + } catch (err) { + strictEqual(err.message, 'kaboom') + deepStrictEqual(hookCalled.message, 'kaboom') + } +}) + +test('hooks - multiple onErrors', async (t) => { + const worker = new Worker(join(__dirname, 'fixtures', 'error.js')) + t.after(() => worker.terminate()) + const hookCalled = [] + + const onError1 = (error) => { + hookCalled.push({ error1: error.message }) + } + + const onError2 = (error) => { + hookCalled.push({ error2: error.message }) + } + + const interceptor = createThreadInterceptor({ + domain: '.local', + onError: [onError1, onError2] + }) + interceptor.route('myserver', worker) + + try { + const agent = new Agent().compose(interceptor) + await request('http://myserver.local', { + dispatcher: agent, + }) + throw new Error('should not be here') + } catch (err) { + strictEqual(err.message, 'kaboom') + deepStrictEqual(hookCalled, [{ error1: 'kaboom' }, { error2: 'kaboom' }]) + } +}) + +test('hooks - should throw if handler not a function', async (t) => { + const worker = new Worker(join(__dirname, 'fixtures', 'worker1.js')) + t.after(() => worker.terminate()) + + try { + createThreadInterceptor({ + domain: '.local', + onResponse: 'nor a function', + }) + throw new Error('should not be here') + } catch (err) { + strictEqual(err.message, 'Expected a function, got string') + } +})