From 1ab24fec4df98d7d97e1aca48a74252f48ad0920 Mon Sep 17 00:00:00 2001 From: Maddison Hellstrom Date: Thu, 28 Sep 2023 04:08:10 -0700 Subject: [PATCH] feat: support anyOf in function parameters (#13) --- src/functions.ts | 15 ++++ tests/token-counts.test.ts | 154 +++++++++++++++++++++++++++++++++++++ 2 files changed, 169 insertions(+) diff --git a/src/functions.ts b/src/functions.ts index a5a5eb5..cd917f3 100644 --- a/src/functions.ts +++ b/src/functions.ts @@ -19,9 +19,14 @@ interface ObjectProp { required?: string[]; } +interface AnyOfProp { + anyOf: Prop[]; +} + type Prop = { description?: string; } & ( + | AnyOfProp | ObjectProp | { type: "string"; @@ -41,6 +46,13 @@ type Prop = { } ); +function isAnyOfProp(prop: Prop): prop is AnyOfProp { + return ( + (prop as AnyOfProp).anyOf !== undefined && + Array.isArray((prop as AnyOfProp).anyOf) + ); +} + // When OpenAI use functions in the prompt, they format them as TypeScript definitions rather than OpenAPI JSON schemas. // This function converts the JSON schemas into TypeScript definitions. export function formatFunctionDefinitions(functions: FunctionDef[]) { @@ -80,6 +92,9 @@ function formatObjectProperties(obj: ObjectProp, indent: number): string { // Format a single property type function formatType(param: Prop, indent: number): string { + if (isAnyOfProp(param)) { + return param.anyOf.map((v) => formatType(v, indent)).join(" | "); + } switch (param.type) { case "string": if (param.enum) { diff --git a/tests/token-counts.test.ts b/tests/token-counts.test.ts index 229cf91..df93e4f 100644 --- a/tests/token-counts.test.ts +++ b/tests/token-counts.test.ts @@ -264,6 +264,160 @@ const TEST_CASES: Example[] = [ ], tokens: 25, }, + { + messages: [ + { + role: "system", + content: "You are an AI assistant trained to foo or bar", + }, + { role: "user", content: "My name is suzie" }, + { + role: "function", + name: "foo", + content: '{"res":{"kind":"user","name":"suzie"}}', + }, + { + role: "user", + content: 'I want to post a message with the text "hello world"', + }, + { + role: "function", + name: "foo", + content: '{"res":{"kind":"post","text":"hello world"}}', + }, + ], + functions: [ + { + name: "foo", + description: "Return the foo or the bar", + parameters: { + type: "object", + properties: { + res: { + anyOf: [ + { + type: "object", + properties: { + kind: { type: "string", const: "post" }, + text: { type: "string" }, + }, + required: ["kind", "text"], + additionalProperties: false, + }, + { + type: "object", + properties: { + kind: { type: "string", const: "user" }, + name: { + type: "string", + enum: ["jane", "suzie", "adam"], + }, + }, + required: ["kind", "name"], + additionalProperties: false, + }, + ], + description: "The foo or the bar", + }, + }, + required: ["res"], + additionalProperties: false, + }, + }, + ], + function_call: { name: "foo" }, + tokens: 158, + }, + { + messages: [ + { role: "system", content: "Hello" }, + { role: "user", content: "Hi there" }, + ], + functions: [ + { + name: "do_stuff", + parameters: { + type: "object", + properties: { + foo: { + anyOf: [ + { + type: "object", + properties: { + kind: { type: "string", const: "a" }, + baz: { type: "boolean" }, + }, + }, + ], + }, + }, + }, + }, + ], + tokens: 52, + }, + { + messages: [ + { role: "system", content: "Hello" }, + { role: "user", content: "Hi there" }, + ], + functions: [ + { + name: "do_stuff", + parameters: { + type: "object", + properties: { + foo: { + anyOf: [ + { + type: "object", + properties: { + kind: { type: "string", const: "a" }, + baz: { type: "boolean" }, + }, + }, + { + type: "object", + properties: { + kind: { type: "string", const: "b" }, + qux: { type: "number" }, + }, + }, + { + type: "object", + properties: { + kind: { type: "string", const: "c" }, + idk: { type: "string" }, + }, + }, + ], + }, + }, + }, + }, + ], + tokens: 80, + }, + { + messages: [ + { role: "system", content: "Hello" }, + { role: "user", content: "Hi there" }, + ], + functions: [ + { + name: "do_stuff", + parameters: { + type: "object", + properties: { + foo: { + anyOf: [{ type: "string", const: "a" }, { type: "number" }], + }, + }, + }, + }, + ], + tokens: 44, + }, { messages: [ { role: "system", content: "Hello" },