diff --git a/bun.lockb b/bun.lockb index 72ad010..f7cdf3c 100755 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/src/actions.ts b/src/actions.ts index 02d13e8..bdd2e10 100644 --- a/src/actions.ts +++ b/src/actions.ts @@ -10,6 +10,8 @@ import { customOctokit } from "./octokit"; import { sanitizeMetadata } from "./util"; import { verifySignature } from "./signature"; import { KERNEL_PUBLIC_KEY } from "./constants"; +import { jsonType } from "./types/util"; +import { commandCallSchema } from "./types/command"; config(); @@ -18,21 +20,24 @@ interface Options { postCommentOnError?: boolean; settingsSchema?: TAnySchema; envSchema?: TAnySchema; + commandSchema?: TAnySchema; kernelPublicKey?: string; + disableSignatureVerification?: boolean; // only use for local development } const inputSchema = T.Object({ stateId: T.String(), eventName: T.String(), - eventPayload: T.String(), + eventPayload: jsonType(T.Record(T.String(), T.Any())), + command: jsonType(commandCallSchema), authToken: T.String(), - settings: T.String(), + settings: jsonType(T.Record(T.String(), T.Any())), ref: T.String(), signature: T.String(), }); -export async function createActionsPlugin( - handler: (context: Context) => Promise | undefined>, +export async function createActionsPlugin( + handler: (context: Context) => Promise | undefined>, options?: Options ) { const pluginOptions = { @@ -40,7 +45,9 @@ export async function createActionsPlugin = { + let command: TCommand | null = null; + if (inputs.command && pluginOptions.commandSchema) { + try { + command = Value.Decode(pluginOptions.commandSchema, Value.Default(pluginOptions.commandSchema, inputs.command)); + } catch (e) { + console.dir(...Value.Errors(pluginOptions.commandSchema, inputs.command), { depth: null }); + throw e; + } + } else if (inputs.command) { + command = inputs.command as TCommand; + } + + const context: Context = { eventName: inputs.eventName as TSupportedEvents, - payload: JSON.parse(inputs.eventPayload), + payload: inputs.eventPayload, + command: command, octokit: new customOctokit({ auth: inputs.authToken }), config: config, env: env, diff --git a/src/context.ts b/src/context.ts index 8502b0d..47b6933 100644 --- a/src/context.ts +++ b/src/context.ts @@ -2,11 +2,12 @@ import { EmitterWebhookEvent as WebhookEvent, EmitterWebhookEventName as Webhook import { Logs } from "@ubiquity-os/ubiquity-os-logger"; import { customOctokit } from "./octokit"; -export interface Context { +export interface Context { eventName: TSupportedEvents; payload: { [K in TSupportedEvents]: K extends WebhookEventName ? WebhookEvent : never; }[TSupportedEvents]["payload"]; + command: TCommand | null; octokit: InstanceType; config: TConfig; env: TEnv; diff --git a/src/server.ts b/src/server.ts index 8e1a050..de5521c 100644 --- a/src/server.ts +++ b/src/server.ts @@ -18,6 +18,7 @@ interface Options { postCommentOnError?: boolean; settingsSchema?: TAnySchema; envSchema?: TAnySchema; + commandSchema?: TAnySchema; bypassSignatureVerification?: boolean; } @@ -25,20 +26,15 @@ const inputSchema = T.Object({ stateId: T.String(), eventName: T.String(), eventPayload: T.Record(T.String(), T.Any()), + command: T.Union([T.Null(), T.Object({ name: T.String(), parameters: T.Unknown() })]), authToken: T.String(), settings: T.Record(T.String(), T.Any()), ref: T.String(), signature: T.String(), - bypassSignatureVerification: T.Optional( - T.Boolean({ - default: false, - description: "Bypass signature verification (caution: only use this if you know what you're doing)", - }) - ), }); -export function createPlugin( - handler: (context: Context) => Promise | undefined>, +export function createPlugin( + handler: (context: Context) => Promise | undefined>, manifest: Manifest, options?: Options ) { @@ -48,6 +44,8 @@ export function createPlugin = { + let command: TCommand | null = null; + if (inputs.command && pluginOptions.commandSchema) { + try { + command = Value.Decode(pluginOptions.commandSchema, Value.Default(pluginOptions.commandSchema, inputs.command)); + } catch (e) { + console.dir(...Value.Errors(pluginOptions.commandSchema, inputs.command), { depth: null }); + throw e; + } + } else if (inputs.command) { + command = inputs.command as TCommand; + } + + const context: Context = { eventName: inputs.eventName as TSupportedEvents, payload: inputs.eventPayload, + command: command, octokit: new customOctokit({ auth: inputs.authToken }), config: config, env: env, diff --git a/src/signature.ts b/src/signature.ts index 52dfb10..ce6d85b 100644 --- a/src/signature.ts +++ b/src/signature.ts @@ -5,6 +5,7 @@ interface Inputs { authToken: unknown; settings: unknown; ref: unknown; + command: unknown; } export async function verifySignature(publicKeyPem: string, inputs: Inputs, signature: string) { @@ -16,6 +17,7 @@ export async function verifySignature(publicKeyPem: string, inputs: Inputs, sign settings: inputs.settings, authToken: inputs.authToken, ref: inputs.ref, + command: inputs.command, }; const pemContents = publicKeyPem.replace("-----BEGIN PUBLIC KEY-----", "").replace("-----END PUBLIC KEY-----", "").trim(); const binaryDer = Uint8Array.from(atob(pemContents), (c) => c.charCodeAt(0)); diff --git a/src/types/command.ts b/src/types/command.ts new file mode 100644 index 0000000..92ad081 --- /dev/null +++ b/src/types/command.ts @@ -0,0 +1,5 @@ +import { StaticDecode, Type as T } from "@sinclair/typebox"; + +export const commandCallSchema = T.Union([T.Null(), T.Object({ name: T.String(), parameters: T.Unknown() })]); + +export type CommandCall = StaticDecode; diff --git a/src/types/manifest.ts b/src/types/manifest.ts index 56330b5..7cb8ad5 100644 --- a/src/types/manifest.ts +++ b/src/types/manifest.ts @@ -4,16 +4,19 @@ import { emitterEventNames } from "@octokit/webhooks"; export const runEvent = T.Union(emitterEventNames.map((o) => T.Literal(o))); export const commandSchema = T.Object({ + name: T.String({ minLength: 1 }), description: T.String({ minLength: 1 }), "ubiquity:example": T.String({ minLength: 1 }), + parameters: T.Optional(T.Record(T.String(), T.Any())), }); export const manifestSchema = T.Object({ name: T.String({ minLength: 1 }), description: T.Optional(T.String({ default: "" })), - commands: T.Optional(T.Record(T.String(), commandSchema, { default: {} })), + commands: T.Optional(T.Array(commandSchema, { default: [] })), "ubiquity:listeners": T.Optional(T.Array(runEvent, { default: [] })), configuration: T.Optional(T.Record(T.String(), T.Any(), { default: {} })), + skipBotEvents: T.Optional(T.Boolean({ default: true })), }); export type Manifest = Static; diff --git a/src/types/util.ts b/src/types/util.ts new file mode 100644 index 0000000..8aa6275 --- /dev/null +++ b/src/types/util.ts @@ -0,0 +1,11 @@ +import { Type, TAnySchema } from "@sinclair/typebox"; +import { Value } from "@sinclair/typebox/value"; + +export function jsonType(type: TSchema) { + return Type.Transform(Type.String()) + .Decode((value) => { + const parsed = JSON.parse(value); + return Value.Decode(type, Value.Default(type, parsed)); + }) + .Encode((value) => JSON.stringify(value)); +} diff --git a/tests/sdk.test.ts b/tests/sdk.test.ts index 0f4e6df..8b80197 100644 --- a/tests/sdk.test.ts +++ b/tests/sdk.test.ts @@ -7,6 +7,7 @@ import { createPlugin } from "../src/server"; import { signPayload } from "../src/signature"; import { server } from "./__mocks__/node"; import issueCommented from "./__mocks__/requests/issue-comment-post.json"; +import { CommandCall } from "../src/types/command"; const { publicKey, privateKey } = crypto.generateKeyPairSync("rsa", { modulusLength: 2048, @@ -29,7 +30,15 @@ const sdkOctokitImportPath = "../src/octokit"; const githubActionImportPath = "@actions/github"; const githubCoreImportPath = "@actions/core"; -async function getWorkerInputs(stateId: string, eventName: string, eventPayload: object, settings: object, authToken: string, ref: string) { +async function getWorkerInputs( + stateId: string, + eventName: string, + eventPayload: object, + settings: object, + authToken: string, + ref: string, + command: CommandCall | null +) { const inputs = { stateId, eventName, @@ -37,6 +46,7 @@ async function getWorkerInputs(stateId: string, eventName: string, eventPayload: settings, authToken, ref, + command, }; const signature = await signPayload(JSON.stringify(inputs), privateKey); @@ -46,7 +56,15 @@ async function getWorkerInputs(stateId: string, eventName: string, eventPayload: }; } -async function getWorkflowInputs(stateId: string, eventName: string, eventPayload: object, settings: object, authToken: string, ref: string) { +async function getWorkflowInputs( + stateId: string, + eventName: string, + eventPayload: object, + settings: object, + authToken: string, + ref: string, + command: CommandCall | null +) { const inputs = { stateId, eventName, @@ -54,6 +72,7 @@ async function getWorkflowInputs(stateId: string, eventName: string, eventPayloa settings: JSON.stringify(settings), authToken, ref, + command: JSON.stringify(command), }; const signature = await signPayload(JSON.stringify(inputs), privateKey); @@ -71,6 +90,7 @@ const app = createPlugin( return { success: true, event: context.eventName, + command: context.command, }; }, { name: "test" }, @@ -111,13 +131,13 @@ describe("SDK worker tests", () => { expect(res.status).toEqual(400); }); it("Should deny POST request with invalid signature", async () => { - const inputs = getWorkerInputs("stateId", issueCommentedEvent.eventName, issueCommentedEvent.eventPayload, { shouldFail: false }, "test", ""); + const inputs = await getWorkerInputs("stateId", issueCommentedEvent.eventName, issueCommentedEvent.eventPayload, { shouldFail: false }, "test", "", null); const res = await app.request("/", { headers: { "content-type": "application/json", }, - body: JSON.stringify({ ...(await inputs), signature: "invalid_signature" }), + body: JSON.stringify({ ...inputs, signature: "invalid_signature" }), method: "POST", }); expect(res.status).toEqual(400); @@ -151,13 +171,13 @@ describe("SDK worker tests", () => { { kernelPublicKey: publicKey } ); - const inputs = getWorkerInputs("stateId", issueCommentedEvent.eventName, issueCommentedEvent.eventPayload, { shouldFail: true }, "test", ""); + const inputs = await getWorkerInputs("stateId", issueCommentedEvent.eventName, issueCommentedEvent.eventPayload, { shouldFail: true }, "test", "", null); const res = await app.request("/", { headers: { "content-type": "application/json", }, - body: JSON.stringify(await inputs), + body: JSON.stringify(inputs), method: "POST", }); expect(res.status).toEqual(500); @@ -178,18 +198,24 @@ describe("SDK worker tests", () => { }); }); it("Should accept correct request", async () => { - const inputs = getWorkerInputs("stateId", issueCommentedEvent.eventName, issueCommentedEvent.eventPayload, { shouldFail: false }, "test", ""); + const inputs = await getWorkerInputs("stateId", issueCommentedEvent.eventName, issueCommentedEvent.eventPayload, { shouldFail: false }, "test", "", { + name: "test", + parameters: { param1: "test" }, + }); const res = await app.request("/", { headers: { "content-type": "application/json", }, - body: JSON.stringify(await inputs), + body: JSON.stringify(inputs), method: "POST", }); expect(res.status).toEqual(200); const result = await res.json(); - expect(result).toEqual({ stateId: "stateId", output: { success: true, event: issueCommented.eventName } }); + expect(result).toEqual({ + stateId: "stateId", + output: { success: true, event: issueCommented.eventName, command: { name: "test", parameters: { param1: "test" } } }, + }); }); }); @@ -201,8 +227,10 @@ describe("SDK actions tests", () => { }; it("Should accept correct request", async () => { - const inputs = getWorkflowInputs("stateId", issueCommentedEvent.eventName, issueCommentedEvent.eventPayload, {}, "test_token", ""); - const githubInputs = await inputs; + const githubInputs = await getWorkflowInputs("stateId", issueCommentedEvent.eventName, issueCommentedEvent.eventPayload, {}, "test_token", "", { + name: "test", + parameters: { param1: "test" }, + }); jest.unstable_mockModule(githubActionImportPath, () => ({ context: { runId: "1", @@ -238,27 +266,28 @@ describe("SDK actions tests", () => { async (context: Context) => { return { event: context.eventName, + command: context.command, }; }, { kernelPublicKey: publicKey, } ); + const expectedResult = { event: issueCommented.eventName, command: { name: "test", parameters: { param1: "test" } } }; expect(setFailed).not.toHaveBeenCalled(); - expect(setOutput).toHaveBeenCalledWith("result", { event: issueCommented.eventName }); + expect(setOutput).toHaveBeenCalledWith("result", expectedResult); expect(createDispatchEvent).toHaveBeenCalledWith({ event_type: "return-data-to-ubiquity-os-kernel", owner: repo.owner, repo: repo.repo, client_payload: { state_id: "stateId", - output: JSON.stringify({ event: issueCommented.eventName }), + output: JSON.stringify(expectedResult), }, }); }); it("Should deny invalid signature", async () => { - const inputs = getWorkflowInputs("stateId", issueCommentedEvent.eventName, issueCommentedEvent.eventPayload, {}, "test_token", ""); - const githubInputs = await inputs; + const githubInputs = await getWorkflowInputs("stateId", issueCommentedEvent.eventName, issueCommentedEvent.eventPayload, {}, "test_token", "", null); jest.unstable_mockModule("@actions/github", () => ({ context: { @@ -294,8 +323,7 @@ describe("SDK actions tests", () => { expect(setOutput).not.toHaveBeenCalled(); }); it("Should accept inputs in different order", async () => { - const inputs = getWorkflowInputs("stateId", issueCommentedEvent.eventName, issueCommentedEvent.eventPayload, {}, "test_token", ""); - const githubInputs = await inputs; + const githubInputs = await getWorkflowInputs("stateId", issueCommentedEvent.eventName, issueCommentedEvent.eventPayload, {}, "test_token", "", null); jest.unstable_mockModule(githubActionImportPath, () => ({ context: { @@ -309,6 +337,7 @@ describe("SDK actions tests", () => { ref: githubInputs.ref, authToken: githubInputs.authToken, stateId: githubInputs.stateId, + command: githubInputs.command, eventPayload: githubInputs.eventPayload, }, },