Skip to content

Commit

Permalink
Merge pull request #18 from njk112/feat/examples-clasification
Browse files Browse the repository at this point in the history
Feat/examples classification
  • Loading branch information
jxnl authored Jan 3, 2024
2 parents d2b773e + 9c31962 commit 1021cdc
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 0 deletions.
50 changes: 50 additions & 0 deletions examples/classification/multi_prediction/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import assert from "assert"
import Instructor from "@/instructor"
import OpenAI from "openai"
import { z } from "zod"

enum MULTI_CLASSIFICATION_LABELS {
"BILLING" = "billing",
"GENERAL_QUERY" = "general_query",
"HARDWARE" = "hardware"
}

const MultiClassificationSchema = z.object({
predicted_labels: z.array(z.nativeEnum(MULTI_CLASSIFICATION_LABELS))
})

type MultiClassification = z.infer<typeof MultiClassificationSchema>

const oai = new OpenAI({
apiKey: process.env.OPENAI_API_KEY ?? undefined,
organization: process.env.OPENAI_ORG_ID ?? undefined
})

const client = Instructor({
client: oai,
mode: "FUNCTIONS"
})

const createClassification = async (data: string): Promise<MultiClassification | undefined> => {
const classification: MultiClassification = await client.chat.completions.create({
messages: [{ role: "user", content: `"Classify the following support ticket: ${data}` }],
model: "gpt-3.5-turbo",
response_model: MultiClassificationSchema,
max_retries: 3
})

return classification || undefined
}

const classification = await createClassification(
"My account is locked and I can't access my billing info. Phone is also broken"
)
// OUTPUT: { predicted_labels: [ 'billing', 'hardware' ] }

console.log({ classification })

assert(
classification.predicted_labels.includes(MULTI_CLASSIFICATION_LABELS.BILLING) &&
classification.predicted_labels.includes(MULTI_CLASSIFICATION_LABELS.HARDWARE),
`Expected ${classification.predicted_labels} to include ${MULTI_CLASSIFICATION_LABELS.BILLING} and ${MULTI_CLASSIFICATION_LABELS.HARDWARE}`
)
48 changes: 48 additions & 0 deletions examples/classification/simple_prediction/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import assert from "assert"
import Instructor from "@/instructor"
import OpenAI from "openai"
import { z } from "zod"

enum CLASSIFICATION_LABELS {
"SPAM" = "SPAM",
"NOT_SPAM" = "NOT_SPAM"
}

const SimpleClassificationSchema = z.object({
class_label: z.nativeEnum(CLASSIFICATION_LABELS)
})

type SimpleClassification = z.infer<typeof SimpleClassificationSchema>

const oai = new OpenAI({
apiKey: process.env.OPENAI_API_KEY ?? undefined,
organization: process.env.OPENAI_ORG_ID ?? undefined
})

const client = Instructor({
client: oai,
mode: "FUNCTIONS"
})

const createClassification = async (data: string): Promise<SimpleClassification | undefined> => {
const classification: SimpleClassification = await client.chat.completions.create({
messages: [{ role: "user", content: `"Classify the following text: ${data}` }],
model: "gpt-3.5-turbo",
response_model: SimpleClassificationSchema,
max_retries: 3
})

return classification || undefined
}

const classification = await createClassification(
"Hello there I'm a nigerian prince and I want to give you money"
)
// OUTPUT: { class_label: 'SPAM' }

console.log({ classification })

assert(
classification.class_label === CLASSIFICATION_LABELS.SPAM,
`Expected ${classification.class_label} to be ${CLASSIFICATION_LABELS.SPAM}`
)

0 comments on commit 1021cdc

Please sign in to comment.