Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support throwing responses in upgrade hook #91

Merged
merged 19 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/adapters/bun.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,18 @@ export default defineWebSocketAdapter<BunAdapter, BunOptions>(
return {
...adapterUtils(peers),
async handleUpgrade(request, server) {
const res = await hooks.callHook("upgrade", request);
if (res instanceof Response) {
return res;
const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
if (endResponse) {
return endResponse;
}
const upgradeOK = server.upgrade(request, {
data: {
server,
request,
} satisfies ContextData,
headers: res?.headers,
headers: upgradeHeaders,
});

if (!upgradeOK) {
return new Response("Upgrade failed", { status: 500 });
}
Expand Down
12 changes: 8 additions & 4 deletions src/adapters/cloudflare-durable.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@ export default defineWebSocketAdapter<
// placeholder
},
handleDurableUpgrade: async (obj, request) => {
const res = await hooks.callHook("upgrade", request as Request);
if (res instanceof Response) {
return res;
const { upgradeHeaders, endResponse } = await hooks.upgrade(
request as Request,
);
if (endResponse) {
return endResponse;
}

const pair = new WebSocketPair();
const client = pair[0];
const server = pair[1];
Expand All @@ -46,11 +49,12 @@ export default defineWebSocketAdapter<
peers.add(peer);
(obj as DurableObjectPub).ctx.acceptWebSocket(server);
await hooks.callHook("open", peer);

// eslint-disable-next-line unicorn/no-null
return new Response(null, {
status: 101,
webSocket: client,
headers: res?.headers,
headers: upgradeHeaders,
});
},
handleDurableMessage: async (obj, ws, message) => {
Expand Down
10 changes: 5 additions & 5 deletions src/adapters/cloudflare.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ export default defineWebSocketAdapter<CloudflareAdapter, CloudflareOptions>(
return {
...adapterUtils(peers),
handleUpgrade: async (request, env, context) => {
const res = await hooks.callHook(
"upgrade",
const { upgradeHeaders, endResponse } = await hooks.upgrade(
request as unknown as Request,
);
if (res instanceof Response) {
return res;
if (endResponse) {
return endResponse as unknown as _cf.Response;
}

const pair = new WebSocketPair();
const client = pair[0];
const server = pair[1];
Expand Down Expand Up @@ -73,7 +73,7 @@ export default defineWebSocketAdapter<CloudflareAdapter, CloudflareOptions>(
return new Response(null, {
status: 101,
webSocket: client,
headers: res?.headers,
headers: upgradeHeaders,
});
},
};
Expand Down
9 changes: 5 additions & 4 deletions src/adapters/deno.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ export default defineWebSocketAdapter<DenoAdapter, DenoOptions>(
return {
...adapterUtils(peers),
handleUpgrade: async (request, info) => {
const res = await hooks.callHook("upgrade", request);
if (res instanceof Response) {
return res;
const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
if (endResponse) {
return endResponse;
}

const upgrade = Deno.upgradeWebSocket(request, {
// @ts-expect-error https://github.com/denoland/deno/pull/22242
headers: res?.headers,
headers: upgradeHeaders,
});
const peer = new DenoPeer({
ws: upgrade.socket,
Expand Down
10 changes: 6 additions & 4 deletions src/adapters/node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,14 @@ export default defineWebSocketAdapter<NodeAdapter, NodeOptions>(
...adapterUtils(peers),
handleUpgrade: async (nodeReq, socket, head) => {
const request = new NodeReqProxy(nodeReq);
const res = await hooks.callHook("upgrade", request);
if (res instanceof Response) {
return sendResponse(socket, res);

const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
if (endResponse) {
return sendResponse(socket, endResponse);
}

(nodeReq as AugmentedReq)._request = request;
(nodeReq as AugmentedReq)._upgradeHeaders = res?.headers;
(nodeReq as AugmentedReq)._upgradeHeaders = upgradeHeaders;
wss.handleUpgrade(nodeReq, socket, head, (ws) => {
wss.emit("connection", ws, nodeReq);
});
Expand Down
14 changes: 8 additions & 6 deletions src/adapters/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ export default defineWebSocketAdapter<SSEAdapter, SSEOptions>((opts = {}) => {
return {
...adapterUtils(peers),
fetch: async (request: Request) => {
const _res = await hooks.callHook("upgrade", request);
if (_res instanceof Response) {
return _res;
const { upgradeHeaders, endResponse } = await hooks.upgrade(request);
if (endResponse) {
return endResponse;
}

let peer: SSEPeer;
Expand Down Expand Up @@ -73,17 +73,19 @@ export default defineWebSocketAdapter<SSEAdapter, SSEOptions>((opts = {}) => {
"Cache-Control": "no-cache",
Connection: "keep-alive",
};

if (opts.bidir) {
headers["x-crossws-id"] = peer.id;
}
if (_res?.headers) {

if (upgradeHeaders) {
headers = new Headers(headers);
for (const [key, value] of new Headers(_res.headers)) {
for (const [key, value] of new Headers(upgradeHeaders)) {
headers.set(key, value);
}
}

return new Response(peer._sseStream, { ..._res, headers });
return new Response(peer._sseStream, { headers });
},
};
});
Expand Down
33 changes: 19 additions & 14 deletions src/adapters/uws.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,18 @@
res.onAborted(() => {
aborted = true;
});
const _res = await hooks.callHook("upgrade", new UWSReqProxy(req));
if (aborted) {
return;
}
if (_res instanceof Response) {
res.writeStatus(`${_res.status} ${_res.statusText}`);
for (const [key, value] of _res.headers) {

const { upgradeHeaders, endResponse } = await hooks.upgrade(
new UWSReqProxy(req),
);
if (endResponse) {
res.writeStatus(`${endResponse.status} ${endResponse.statusText}`);
for (const [key, value] of endResponse.headers) {
res.writeHeader(key, value);
}
if (_res.body) {
for await (const chunk of _res.body) {
if (aborted) {
break;
}
if (endResponse.body) {
for await (const chunk of endResponse.body) {
if (aborted) break;
res.write(chunk);
}
}
Expand All @@ -97,9 +95,16 @@
}
return;
}

if (aborted) {
return;

Check warning on line 100 in src/adapters/uws.ts

View check run for this annotation

Codecov / codecov/patch

src/adapters/uws.ts#L100

Added line #L100 was not covered by tests
}

res.writeStatus("101 Switching Protocols");
if (_res?.headers) {
for (const [key, value] of new Headers(_res.headers)) {
if (upgradeHeaders) {
// prettier-ignore
const headers = upgradeHeaders instanceof Headers ? upgradeHeaders : new Headers(upgradeHeaders);
for (const [key, value] of headers) {
res.writeHeader(key, value);
}
}
Expand Down
47 changes: 40 additions & 7 deletions src/hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,32 @@
},
) as Promise<any>;
}

async upgrade(request: UpgradeRequest): Promise<{
upgradeHeaders?: HeadersInit;
endResponse?: Response;
}> {
try {
const res = await this.callHook("upgrade", request);
if (!res) {
return {};
}

Check warning on line 51 in src/hooks.ts

View check run for this annotation

Codecov / codecov/patch

src/hooks.ts#L50-L51

Added lines #L50 - L51 were not covered by tests
if ((res as Response).ok === false) {
return { endResponse: res as Response };
}
if (res.headers) {
return {
upgradeHeaders: res.headers,
};
}
} catch (error) {
if (error instanceof Response) {
return { endResponse: error };
}
throw error;
}
return {};

Check warning on line 66 in src/hooks.ts

View check run for this annotation

Codecov / codecov/patch

src/hooks.ts#L61-L66

Added lines #L61 - L66 were not covered by tests
}
}

// --- types ---
Expand All @@ -60,16 +86,23 @@
...args: ArgsT
) => MaybePromise<RT>;

export type UpgradeRequest =
| Request
| {
url: string;
headers: Headers;
};

export interface Hooks {
/** Upgrading */
/**
*
* @param request
* @throws {Response}
*/
upgrade: (
request:
| Request
| {
url: string;
headers: Headers;
},
) => MaybePromise<Response | ResponseInit | void>;
request: UpgradeRequest,
) => MaybePromise<Response | ResponseInit | undefined>;

/** A message is received */
message: (peer: Peer, message: Message) => MaybePromise<void>;
Expand Down
Loading