diff --git a/src/adapters/bun.ts b/src/adapters/bun.ts index 8a76ce8..a52b26f 100644 --- a/src/adapters/bun.ts +++ b/src/adapters/bun.ts @@ -31,17 +31,18 @@ export default defineWebSocketAdapter( return { ...adapterUtils(peers), async handleUpgrade(request, server) { - const res = await hooks.callHook("upgrade", request); - if (res instanceof Response) { - return res; + const { upgradeHeaders, endResponse } = await hooks.upgrade(request); + if (endResponse) { + return endResponse; } const upgradeOK = server.upgrade(request, { data: { server, request, } satisfies ContextData, - headers: res?.headers, + headers: upgradeHeaders, }); + if (!upgradeOK) { return new Response("Upgrade failed", { status: 500 }); } diff --git a/src/adapters/cloudflare-durable.ts b/src/adapters/cloudflare-durable.ts index 526fa78..b285d08 100644 --- a/src/adapters/cloudflare-durable.ts +++ b/src/adapters/cloudflare-durable.ts @@ -31,10 +31,13 @@ export default defineWebSocketAdapter< // placeholder }, handleDurableUpgrade: async (obj, request) => { - const res = await hooks.callHook("upgrade", request as Request); - if (res instanceof Response) { - return res; + const { upgradeHeaders, endResponse } = await hooks.upgrade( + request as Request, + ); + if (endResponse) { + return endResponse; } + const pair = new WebSocketPair(); const client = pair[0]; const server = pair[1]; @@ -46,11 +49,12 @@ export default defineWebSocketAdapter< peers.add(peer); (obj as DurableObjectPub).ctx.acceptWebSocket(server); await hooks.callHook("open", peer); + // eslint-disable-next-line unicorn/no-null return new Response(null, { status: 101, webSocket: client, - headers: res?.headers, + headers: upgradeHeaders, }); }, handleDurableMessage: async (obj, ws, message) => { diff --git a/src/adapters/cloudflare.ts b/src/adapters/cloudflare.ts index fe4a23e..f00e9b5 100644 --- a/src/adapters/cloudflare.ts +++ b/src/adapters/cloudflare.ts @@ -33,13 +33,13 @@ export default defineWebSocketAdapter( return { ...adapterUtils(peers), handleUpgrade: async (request, env, context) => { - const res = await hooks.callHook( - "upgrade", + const { upgradeHeaders, endResponse } = await hooks.upgrade( request as unknown as Request, ); - if (res instanceof Response) { - return res; + if (endResponse) { + return endResponse as unknown as _cf.Response; } + const pair = new WebSocketPair(); const client = pair[0]; const server = pair[1]; @@ -73,7 +73,7 @@ export default defineWebSocketAdapter( return new Response(null, { status: 101, webSocket: client, - headers: res?.headers, + headers: upgradeHeaders, }); }, }; diff --git a/src/adapters/deno.ts b/src/adapters/deno.ts index 0fd1160..886061e 100644 --- a/src/adapters/deno.ts +++ b/src/adapters/deno.ts @@ -31,13 +31,14 @@ export default defineWebSocketAdapter( return { ...adapterUtils(peers), handleUpgrade: async (request, info) => { - const res = await hooks.callHook("upgrade", request); - if (res instanceof Response) { - return res; + const { upgradeHeaders, endResponse } = await hooks.upgrade(request); + if (endResponse) { + return endResponse; } + const upgrade = Deno.upgradeWebSocket(request, { // @ts-expect-error https://github.com/denoland/deno/pull/22242 - headers: res?.headers, + headers: upgradeHeaders, }); const peer = new DenoPeer({ ws: upgrade.socket, diff --git a/src/adapters/node.ts b/src/adapters/node.ts index 10449d3..fafce73 100644 --- a/src/adapters/node.ts +++ b/src/adapters/node.ts @@ -86,12 +86,14 @@ export default defineWebSocketAdapter( ...adapterUtils(peers), handleUpgrade: async (nodeReq, socket, head) => { const request = new NodeReqProxy(nodeReq); - const res = await hooks.callHook("upgrade", request); - if (res instanceof Response) { - return sendResponse(socket, res); + + const { upgradeHeaders, endResponse } = await hooks.upgrade(request); + if (endResponse) { + return sendResponse(socket, endResponse); } + (nodeReq as AugmentedReq)._request = request; - (nodeReq as AugmentedReq)._upgradeHeaders = res?.headers; + (nodeReq as AugmentedReq)._upgradeHeaders = upgradeHeaders; wss.handleUpgrade(nodeReq, socket, head, (ws) => { wss.emit("connection", ws, nodeReq); }); diff --git a/src/adapters/sse.ts b/src/adapters/sse.ts index c443f3e..4da1fc5 100644 --- a/src/adapters/sse.ts +++ b/src/adapters/sse.ts @@ -27,9 +27,9 @@ export default defineWebSocketAdapter((opts = {}) => { return { ...adapterUtils(peers), fetch: async (request: Request) => { - const _res = await hooks.callHook("upgrade", request); - if (_res instanceof Response) { - return _res; + const { upgradeHeaders, endResponse } = await hooks.upgrade(request); + if (endResponse) { + return endResponse; } let peer: SSEPeer; @@ -73,17 +73,19 @@ export default defineWebSocketAdapter((opts = {}) => { "Cache-Control": "no-cache", Connection: "keep-alive", }; + if (opts.bidir) { headers["x-crossws-id"] = peer.id; } - if (_res?.headers) { + + if (upgradeHeaders) { headers = new Headers(headers); - for (const [key, value] of new Headers(_res.headers)) { + for (const [key, value] of new Headers(upgradeHeaders)) { headers.set(key, value); } } - return new Response(peer._sseStream, { ..._res, headers }); + return new Response(peer._sseStream, { headers }); }, }; }); diff --git a/src/adapters/uws.ts b/src/adapters/uws.ts index 1b81ea1..8379cef 100644 --- a/src/adapters/uws.ts +++ b/src/adapters/uws.ts @@ -75,20 +75,18 @@ export default defineWebSocketAdapter( res.onAborted(() => { aborted = true; }); - const _res = await hooks.callHook("upgrade", new UWSReqProxy(req)); - if (aborted) { - return; - } - if (_res instanceof Response) { - res.writeStatus(`${_res.status} ${_res.statusText}`); - for (const [key, value] of _res.headers) { + + const { upgradeHeaders, endResponse } = await hooks.upgrade( + new UWSReqProxy(req), + ); + if (endResponse) { + res.writeStatus(`${endResponse.status} ${endResponse.statusText}`); + for (const [key, value] of endResponse.headers) { res.writeHeader(key, value); } - if (_res.body) { - for await (const chunk of _res.body) { - if (aborted) { - break; - } + if (endResponse.body) { + for await (const chunk of endResponse.body) { + if (aborted) break; res.write(chunk); } } @@ -97,9 +95,16 @@ export default defineWebSocketAdapter( } return; } + + if (aborted) { + return; + } + res.writeStatus("101 Switching Protocols"); - if (_res?.headers) { - for (const [key, value] of new Headers(_res.headers)) { + if (upgradeHeaders) { + // prettier-ignore + const headers = upgradeHeaders instanceof Headers ? upgradeHeaders : new Headers(upgradeHeaders); + for (const [key, value] of headers) { res.writeHeader(key, value); } } diff --git a/src/hooks.ts b/src/hooks.ts index dc73f72..dd1c42d 100644 --- a/src/hooks.ts +++ b/src/hooks.ts @@ -39,6 +39,32 @@ export class AdapterHookable { }, ) as Promise; } + + async upgrade(request: UpgradeRequest): Promise<{ + upgradeHeaders?: HeadersInit; + endResponse?: Response; + }> { + try { + const res = await this.callHook("upgrade", request); + if (!res) { + return {}; + } + if ((res as Response).ok === false) { + return { endResponse: res as Response }; + } + if (res.headers) { + return { + upgradeHeaders: res.headers, + }; + } + } catch (error) { + if (error instanceof Response) { + return { endResponse: error }; + } + throw error; + } + return {}; + } } // --- types --- @@ -60,16 +86,23 @@ type HookFn = ( ...args: ArgsT ) => MaybePromise; +export type UpgradeRequest = + | Request + | { + url: string; + headers: Headers; + }; + export interface Hooks { /** Upgrading */ + /** + * + * @param request + * @throws {Response} + */ upgrade: ( - request: - | Request - | { - url: string; - headers: Headers; - }, - ) => MaybePromise; + request: UpgradeRequest, + ) => MaybePromise; /** A message is received */ message: (peer: Peer, message: Message) => MaybePromise;