diff --git a/src/common/util/type.ts b/src/common/util/type.ts new file mode 100644 index 00000000..ed34d385 --- /dev/null +++ b/src/common/util/type.ts @@ -0,0 +1,39 @@ +export type ToUnionWithField = _ToUnionWithField< + T, + F, + T[F] +>; + +type _ToUnionWithField = FV extends never + ? never + : T & Record; + +type ObjectWithNullableField = { + name: string; + message: string | null; +}; +type Unionized = ToUnionWithField; + +const expectNonError: ( + x: Unionized, +) => x is { message: string } & ObjectWithNullableField = (hoge: Unionized) => { + return typeof hoge.message === "string"; +}; + +// @ts-expect-error +const expectError: ( + x: ObjectWithNullableField, +) => x is { message: string } & ObjectWithNullableField = ( + hoge: ObjectWithNullableField, +) => { + return typeof hoge.message === "string"; +}; + +expectNonError({ + message: "hoge", + name: "", +}); +expectError({ + message: "hoge", + name: "", +}); diff --git a/src/server/model/story.ts b/src/server/model/story.ts index fd057257..d1f723ae 100644 --- a/src/server/model/story.ts +++ b/src/server/model/story.ts @@ -92,18 +92,3 @@ export type QuestionExample = z.infer; export type QuestionExampleWithCustomMessage = QuestionExample & { customMessage: string; }; - -export const filterWithCustomMessage = ( - examples: QuestionExample[], -): QuestionExampleWithCustomMessage[] => { - const filterd: QuestionExampleWithCustomMessage[] = []; - for (const example of examples) { - if (example.customMessage) { - filterd.push({ - ...example, - customMessage: example.customMessage, - }); - } - } - return filterd; -}; diff --git a/src/server/services/question/index.ts b/src/server/services/question/index.ts index c0de3599..a80158dd 100644 --- a/src/server/services/question/index.ts +++ b/src/server/services/question/index.ts @@ -1,3 +1,4 @@ +import { ToUnionWithField } from "@/common/util/type"; import DataLoader from "dataloader"; import { type ABTestingVariant, @@ -5,7 +6,6 @@ import { } from "../../../common/abtesting"; import { calculateEuclideanDistance } from "../../../libs/math"; import { openai } from "../../../libs/openai"; -import { prepareProura } from "../../../libs/proura"; import type { QuestionExample, QuestionExampleWithCustomMessage, @@ -24,92 +24,95 @@ export const askQuestion = async ( story: Story, abPromise: Promise, ) => { - const proura = prepareProura(); const embeddingsDataLoader = new DataLoader((texts: readonly string[]) => { return openai.embeddings .create({ model: "text-embedding-ada-002", input: [...texts], }) - .then((res) => res.data); - }); - const { answer, hitQuestionExample } = await proura - .add("questionEmbedding", async () => { - return embeddingsDataLoader.load(question); - }) - .add("sortedExamplesWithDistance", async (dependsOn) => { - const embeddings = await embeddingsDataLoader.loadMany( - story.questionExamples.map(({ question }) => question), + .then((res) => + res.data.map(({ index, embedding }) => { + const text = texts[index]; + if (text === undefined) { + throw new Error("index out of range"); + } + return { + input: text, + embedding: embedding, + }; + }), ); - const questionEmbedding = await dependsOn("questionEmbedding"); - const result: { - example: QuestionExample; - distance: number; - }[] = []; - embeddings.forEach((embedding, index) => { - const example = story.questionExamples[index]; - if (!example) { - throw new Error("index out of range"); - } - if (embedding instanceof Error) { - console.error(embedding); - return; - } - const distance = calculateEuclideanDistance( - questionEmbedding.embedding, - embedding.embedding, - ); - result.push({ - example, - distance, - }); + }); + const questonEmbedding = embeddingsDataLoader.load(question); + const examplesEmbeddings = embeddingsDataLoader.loadMany( + story.questionExamples.map(({ question }) => question), + ); + + const sortedExamplesWithDistance = examplesEmbeddings.then( + async (embeddings) => { + return questonEmbedding.then((questonEmbedding) => { + return embeddings + .map((item) => { + if (item instanceof Error) { + console.error(item); + return null; + } + const distance = calculateEuclideanDistance( + questonEmbedding.embedding, + item.embedding, + ); + const example = story.questionExamples.find( + ({ question }) => question === item.input, + ); + if (!example) { + throw new Error("example not found"); + } + return { + example, + distance, + }; + }) + .filter((e) => e !== null) + .toSorted((a, b) => a.distance - b.distance); }); - result.sort((a, b) => a.distance - b.distance); - return result; + }, + ); + + const answer = sortedExamplesWithDistance + .then((examples) => { + return ["True", "False", "Unknown"] + .map((answer) => { + return examples.find(({ example }) => example.answer === answer); + }) + .filter((e) => e !== undefined); }) - .add("answer", async (dependsOn) => { - const examples = await dependsOn("sortedExamplesWithDistance"); - const pickedFewExamples: typeof examples = []; - ["True", "False", "Unknown"].forEach((answer) => { - const example = examples.find( - ({ example }) => example.answer === answer, - ); - example && pickedFewExamples.push(example); - }); + .then((pickedFewExamples) => { const inputStory = { quiz: story.quiz, truth: story.truth, questionExamples: pickedFewExamples.map(({ example }) => example), }; - const ab = await abPromise; - const selected = abTestVarToQuestionToAI[ab]; - return selected(inputStory, question); - }) - .add("hitQuestionExample", async (dependsOn) => { - const answer = await dependsOn("answer"); - const examples = await dependsOn("sortedExamplesWithDistance"); - const recur = ([ - head, - ...tail - ]: typeof examples): QuestionExampleWithCustomMessage | null => { - if (!head) { - return null; - } - const { example, distance } = head; - if (example.customMessage && distance < 0.3) { - return { - ...example, - customMessage: example.customMessage, - }; - } - return recur(tail); - }; - const hit = recur(examples); - return hit?.answer === answer ? hit : null; - }) - .exec(); + return abPromise.then((ab) => { + const selected = abTestVarToQuestionToAI[ab]; + return selected(inputStory, question); + }); + }); + + const hitQuestionExample = sortedExamplesWithDistance.then((examples) => { + return answer.then((answer) => { + return examples + .filter(({ example }) => example.answer === answer) + .filter(({ distance }) => distance < 0.3) + .map(({ example }) => example) + .find( + (example: ToUnionWithField) => + typeof example.customMessage === "string", + ); + }); + }); + return { - answer, - hitQuestionExample, + answer: await answer, + hitQuestionExample: (await hitQuestionExample) || null, }; };