Skip to content

Commit

Permalink
Merge branch 'main' into stream-support
Browse files Browse the repository at this point in the history
  • Loading branch information
jxnl authored Jan 3, 2024
2 parents b9feb0d + 1021cdc commit 2887a17
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 17 deletions.
30 changes: 19 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,32 @@ import OpenAI from "openai";

const UserSchema = z.object({
age: z.number(),
name: z.string().refine((name) => name.includes(" "), {
message: "Name must contain a space",
}),
});
name: z.string().refine(name => name.includes(" "), {
message: "Name must contain a space"
})
})

type User = z.infer<typeof UserSchema>;
type User = z.infer<typeof UserSchema>

const client = instruct.patch({
client: OpenAI(process.env.OPENAI_API_KEY, process.env.OPENAI_ORG_ID),
mode: instruct.MODES.TOOLS,
});
const oai = new OpenAI({
apiKey: process.env.OPENAI_API_KEY ?? undefined,
organization: process.env.OPENAI_ORG_ID ?? undefined
})

const client = Instructor({
client: oai,
mode: "FUNCTIONS" // or TOOLS or MD_JSON or JSON_SCHEMA or JSON
})

const user: User = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
model: "gpt-3.5-turbo",
response_model: UserSchema,
max_retries: 3,
});
max_retries: 3
})

assert(user.age === 30)
assert(user.name === "Jason Liu")
```

Or if it makes more sense to you, you can use the builder pattern:
Expand Down
50 changes: 50 additions & 0 deletions examples/classification/multi_prediction/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import assert from "assert"
import Instructor from "@/instructor"
import OpenAI from "openai"
import { z } from "zod"

// Labels a support ticket may be tagged with; one ticket can match several.
enum MULTI_CLASSIFICATION_LABELS {
  "BILLING" = "billing",
  "GENERAL_QUERY" = "general_query",
  "HARDWARE" = "hardware"
}

// Response contract: the model must return an array of known labels.
const MultiClassificationSchema = z.object({
  predicted_labels: z.array(z.nativeEnum(MULTI_CLASSIFICATION_LABELS))
})

type MultiClassification = z.infer<typeof MultiClassificationSchema>

const oai = new OpenAI({
  apiKey: process.env.OPENAI_API_KEY ?? undefined,
  organization: process.env.OPENAI_ORG_ID ?? undefined
})

const client = Instructor({
  client: oai,
  mode: "FUNCTIONS"
})

/**
 * Ask the model to assign zero or more support-ticket labels to `data`.
 *
 * @param data - the raw support-ticket text to classify
 * @returns the parsed classification, or undefined when none was produced
 */
const createClassification = async (data: string): Promise<MultiClassification | undefined> => {
  const classification: MultiClassification = await client.chat.completions.create({
    // Fixed: the original prompt began with a stray '"' character.
    messages: [{ role: "user", content: `Classify the following support ticket: ${data}` }],
    model: "gpt-3.5-turbo",
    response_model: MultiClassificationSchema,
    max_retries: 3
  })

  return classification ?? undefined
}

const classification = await createClassification(
  "My account is locked and I can't access my billing info. Phone is also broken"
)
// OUTPUT: { predicted_labels: [ 'billing', 'hardware' ] }

console.log({ classification })

// Narrow away `undefined` before touching `predicted_labels`; the original
// dereferenced it unguarded, which fails under strictNullChecks.
assert(classification !== undefined, "Expected a classification result")

assert(
  classification.predicted_labels.includes(MULTI_CLASSIFICATION_LABELS.BILLING) &&
    classification.predicted_labels.includes(MULTI_CLASSIFICATION_LABELS.HARDWARE),
  `Expected ${classification.predicted_labels} to include ${MULTI_CLASSIFICATION_LABELS.BILLING} and ${MULTI_CLASSIFICATION_LABELS.HARDWARE}`
)
48 changes: 48 additions & 0 deletions examples/classification/simple_prediction/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import assert from "assert"
import Instructor from "@/instructor"
import OpenAI from "openai"
import { z } from "zod"

// Binary spam verdict for a piece of text.
enum CLASSIFICATION_LABELS {
  "SPAM" = "SPAM",
  "NOT_SPAM" = "NOT_SPAM"
}

// Response contract: the model must return exactly one known label.
const SimpleClassificationSchema = z.object({
  class_label: z.nativeEnum(CLASSIFICATION_LABELS)
})

type SimpleClassification = z.infer<typeof SimpleClassificationSchema>

const oai = new OpenAI({
  apiKey: process.env.OPENAI_API_KEY ?? undefined,
  organization: process.env.OPENAI_ORG_ID ?? undefined
})

const client = Instructor({
  client: oai,
  mode: "FUNCTIONS"
})

/**
 * Ask the model whether `data` is spam.
 *
 * @param data - the raw text to classify
 * @returns the parsed classification, or undefined when none was produced
 */
const createClassification = async (data: string): Promise<SimpleClassification | undefined> => {
  const classification: SimpleClassification = await client.chat.completions.create({
    // Fixed: the original prompt began with a stray '"' character.
    messages: [{ role: "user", content: `Classify the following text: ${data}` }],
    model: "gpt-3.5-turbo",
    response_model: SimpleClassificationSchema,
    max_retries: 3
  })

  return classification ?? undefined
}

const classification = await createClassification(
  "Hello there I'm a nigerian prince and I want to give you money"
)
// OUTPUT: { class_label: 'SPAM' }

console.log({ classification })

// Narrow away `undefined` before touching `class_label`; the original
// dereferenced it unguarded, which fails under strictNullChecks.
assert(classification !== undefined, "Expected a classification result")

assert(
  classification.class_label === CLASSIFICATION_LABELS.SPAM,
  `Expected ${classification.class_label} to be ${CLASSIFICATION_LABELS.SPAM}`
)
2 changes: 1 addition & 1 deletion src/constants/modes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ export const MODE = {
JSON: "JSON",
MD_JSON: "MD_JSON",
JSON_SCHEMA: "JSON_SCHEMA"
}
} as const

export type MODE = keyof typeof MODE
20 changes: 16 additions & 4 deletions src/oai/params.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,24 @@ export function OAIBuildFunctionParams(definition, params) {
}

export function OAIBuildToolFunctionParams(definition, params) {
const { name, ...definitionParams } = definition

return {
...params,
tool_choice: {
type: "function",
function: { name: definition.name }
function: { name }
},
tools: [...(params?.tools ?? []), definition]
tools: [
{
type: "function",
function: {
name,
parameters: definitionParams
}
},
...(params?.tools ?? [])
]
}
}

Expand All @@ -27,7 +38,8 @@ export function OAIBuildMessageBasedParams(definition, params, mode) {
response_format: { type: "json_object" }
},
[MODE.JSON_SCHEMA]: {
response_format: { type: "json_object", schema: definition }
//TODO: not sure what is different about this mode — the OpenAI SDK doesn't accept a schema here
response_format: { type: "json_object" }
}
}

Expand All @@ -39,7 +51,7 @@ export function OAIBuildMessageBasedParams(definition, params, mode) {
messages: [
...(params?.messages ?? []),
{
role: "SYSTEM",
role: "system",
content: `
Given a user prompt, you will return fully valid JSON based on the following description and schema.
You will return no other prose. You will take into account the descriptions for each paramater within the schema
Expand Down
1 change: 0 additions & 1 deletion src/oai/parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ export function OAIResponseFnArgsParser(
| OpenAI.Chat.Completions.ChatCompletion
) {
const parsedData = typeof data === "string" ? JSON.parse(data) : data

const text =
parsedData.choices?.[0].delta?.function_call?.arguments ??
parsedData.choices?.[0]?.message?.function_call?.arguments ??
Expand Down
62 changes: 62 additions & 0 deletions tests/mode.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import Instructor from "@/instructor"
import { describe, expect, test } from "bun:test"
import OpenAI from "openai"
import { z } from "zod"

import { MODE } from "@/constants/modes"

// Models released with tool/JSON support get every non-FUNCTIONS mode;
// older models are exercised with FUNCTIONS only.
const models_latest = ["gpt-3.5-turbo-1106", "gpt-4-1106-preview"]
const models_old = ["gpt-3.5-turbo", "gpt-4"]

/** Build the full (model, mode) matrix to run. */
const createTestCases = (): { model: string; mode: MODE }[] => {
  const { FUNCTIONS, ...rest } = MODE
  const modes = Object.values(rest)

  return [
    ...models_latest.flatMap(model => modes.map(mode => ({ model, mode }))),
    // Fixed: the callback returns a single object, so `map` is correct here —
    // the original used `flatMap`, which only worked because non-array values
    // pass through flattening unchanged.
    ...models_old.map(model => ({ model, mode: FUNCTIONS }))
  ]
}

const UserSchema = z.object({
  age: z.number(),
  name: z.string().refine(name => name.includes(" "), {
    message: "Name must contain a space"
  })
})

type User = z.infer<typeof UserSchema>

/** Extract a User from a fixed prompt using the given model and mode. */
async function extractUser(model: string, mode: MODE) {
  const oai = new OpenAI({
    apiKey: process.env.OPENAI_API_KEY ?? undefined,
    organization: process.env.OPENAI_ORG_ID ?? undefined
  })

  const client = Instructor({
    client: oai,
    mode
  })

  const user: User = await client.chat.completions.create({
    messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
    model,
    response_model: UserSchema,
    max_retries: 3
  })

  return user
}

describe("Modes", () => {
  const testCases = createTestCases()

  // Fixed: `testCases` is a plain synchronous array, so a regular `for..of`
  // is correct — `for await` (and the async describe callback) were unneeded.
  for (const { model, mode } of testCases) {
    test(`Should return extracted name and age for model ${model} and mode ${mode}`, async () => {
      const user = await extractUser(model, mode)

      expect(user.name).toEqual("Jason Liu")
      expect(user.age).toEqual(30)
    })
  }
})

0 comments on commit 2887a17

Please sign in to comment.