Skip to content

Commit

Permalink
proura脱却
Browse files Browse the repository at this point in the history
  • Loading branch information
eatski committed Aug 12, 2024
1 parent c5e6fc8 commit 23b7882
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 88 deletions.
39 changes: 39 additions & 0 deletions src/common/util/type.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
export type ToUnionWithField<T, F extends keyof T> = _ToUnionWithField<
T,
F,
T[F]
>;

type _ToUnionWithField<T, F extends keyof T, FV extends T[F]> = FV extends never
? never
: T & Record<F, FV>;

type ObjectWithNullableField = {
name: string;
message: string | null;
};
type Unionized = ToUnionWithField<ObjectWithNullableField, "message">;

const expectNonError: (
x: Unionized,
) => x is { message: string } & ObjectWithNullableField = (hoge: Unionized) => {
return typeof hoge.message === "string";
};

// @ts-expect-error
const expectError: (
x: ObjectWithNullableField,
) => x is { message: string } & ObjectWithNullableField = (
hoge: ObjectWithNullableField,
) => {
return typeof hoge.message === "string";
};

expectNonError({
message: "hoge",
name: "",
});
expectError({
message: "hoge",
name: "",
});
15 changes: 0 additions & 15 deletions src/server/model/story.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,3 @@ export type QuestionExample = z.infer<typeof questionExample>;
export type QuestionExampleWithCustomMessage = QuestionExample & {
customMessage: string;
};

export const filterWithCustomMessage = (
examples: QuestionExample[],
): QuestionExampleWithCustomMessage[] => {
const filterd: QuestionExampleWithCustomMessage[] = [];
for (const example of examples) {
if (example.customMessage) {
filterd.push({
...example,
customMessage: example.customMessage,
});
}
}
return filterd;
};
149 changes: 76 additions & 73 deletions src/server/services/question/index.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import { ToUnionWithField } from "@/common/util/type";
import DataLoader from "dataloader";
import {
type ABTestingVariant,
AB_TESTING_VARIANTS,
} from "../../../common/abtesting";
import { calculateEuclideanDistance } from "../../../libs/math";
import { openai } from "../../../libs/openai";
import { prepareProura } from "../../../libs/proura";
import type {
QuestionExample,
QuestionExampleWithCustomMessage,
Expand All @@ -24,92 +24,95 @@ export const askQuestion = async (
story: Story,
abPromise: Promise<ABTestingVariant>,
) => {
const proura = prepareProura();
const embeddingsDataLoader = new DataLoader((texts: readonly string[]) => {
return openai.embeddings
.create({
model: "text-embedding-ada-002",
input: [...texts],
})
.then((res) => res.data);
});
const { answer, hitQuestionExample } = await proura
.add("questionEmbedding", async () => {
return embeddingsDataLoader.load(question);
})
.add("sortedExamplesWithDistance", async (dependsOn) => {
const embeddings = await embeddingsDataLoader.loadMany(
story.questionExamples.map(({ question }) => question),
.then((res) =>
res.data.map(({ index, embedding }) => {
const text = texts[index];
if (text === undefined) {
throw new Error("index out of range");
}
return {
input: text,
embedding: embedding,
};
}),
);
const questionEmbedding = await dependsOn("questionEmbedding");
const result: {
example: QuestionExample;
distance: number;
}[] = [];
embeddings.forEach((embedding, index) => {
const example = story.questionExamples[index];
if (!example) {
throw new Error("index out of range");
}
if (embedding instanceof Error) {
console.error(embedding);
return;
}
const distance = calculateEuclideanDistance(
questionEmbedding.embedding,
embedding.embedding,
);
result.push({
example,
distance,
});
});
const questonEmbedding = embeddingsDataLoader.load(question);
const examplesEmbeddings = embeddingsDataLoader.loadMany(
story.questionExamples.map(({ question }) => question),
);

const sortedExamplesWithDistance = examplesEmbeddings.then(
async (embeddings) => {
return questonEmbedding.then((questonEmbedding) => {
return embeddings
.map((item) => {
if (item instanceof Error) {
console.error(item);
return null;
}
const distance = calculateEuclideanDistance(
questonEmbedding.embedding,
item.embedding,
);
const example = story.questionExamples.find(
({ question }) => question === item.input,
);
if (!example) {
throw new Error("example not found");
}
return {
example,
distance,
};
})
.filter((e) => e !== null)
.toSorted((a, b) => a.distance - b.distance);
});
result.sort((a, b) => a.distance - b.distance);
return result;
},
);

const answer = sortedExamplesWithDistance
.then((examples) => {
return ["True", "False", "Unknown"]
.map((answer) => {
return examples.find(({ example }) => example.answer === answer);
})
.filter((e) => e !== undefined);
})
.add("answer", async (dependsOn) => {
const examples = await dependsOn("sortedExamplesWithDistance");
const pickedFewExamples: typeof examples = [];
["True", "False", "Unknown"].forEach((answer) => {
const example = examples.find(
({ example }) => example.answer === answer,
);
example && pickedFewExamples.push(example);
});
.then((pickedFewExamples) => {
const inputStory = {
quiz: story.quiz,
truth: story.truth,
questionExamples: pickedFewExamples.map(({ example }) => example),
};
const ab = await abPromise;
const selected = abTestVarToQuestionToAI[ab];
return selected(inputStory, question);
})
.add("hitQuestionExample", async (dependsOn) => {
const answer = await dependsOn("answer");
const examples = await dependsOn("sortedExamplesWithDistance");
const recur = ([
head,
...tail
]: typeof examples): QuestionExampleWithCustomMessage | null => {
if (!head) {
return null;
}
const { example, distance } = head;
if (example.customMessage && distance < 0.3) {
return {
...example,
customMessage: example.customMessage,
};
}
return recur(tail);
};
const hit = recur(examples);
return hit?.answer === answer ? hit : null;
})
.exec();
return abPromise.then((ab) => {
const selected = abTestVarToQuestionToAI[ab];
return selected(inputStory, question);
});
});

const hitQuestionExample = sortedExamplesWithDistance.then((examples) => {
return answer.then((answer) => {
return examples
.filter(({ example }) => example.answer === answer)
.filter(({ distance }) => distance < 0.3)
.map(({ example }) => example)
.find<QuestionExampleWithCustomMessage>(
(example: ToUnionWithField<QuestionExample, "customMessage">) =>
typeof example.customMessage === "string",
);
});
});

return {
answer,
hitQuestionExample,
answer: await answer,
hitQuestionExample: (await hitQuestionExample) || null,
};
};

0 comments on commit 23b7882

Please sign in to comment.