Commit

Use normalized tool_choice value in prompt -> sdk conversion
cephalization committed Feb 6, 2025
1 parent cebb5cc commit 07b1207
Showing 8 changed files with 202 additions and 141 deletions.
38 changes: 38 additions & 0 deletions js/packages/phoenix-client/src/__generated__/api/v1.ts

Some generated files are not rendered by default.

17 changes: 9 additions & 8 deletions js/packages/phoenix-client/src/prompts/sdks/toAnthropic.ts
@@ -6,7 +6,7 @@ import type { Variables, toSDKParamsBase } from "./types";
import { promptMessageToAnthropic } from "../../schemas/llm/messageSchemas";
import { formatPromptMessages } from "../../utils/formatPromptMessages";
import {
- AnthropicToolChoice,
+ phoenixToolChoiceToOpenaiToolChoice,
safelyConvertToolChoiceToProvider,
} from "../../schemas/llm/toolChoiceSchemas";
import {
@@ -28,9 +28,8 @@ export const toAnthropic = <V extends Variables = Variables>({
variables,
}: ToAnthropicParams<V>): MessageCreateParams | null => {
try {
- const { tool_choice: initialToolChoice, ...invocationParameters } =
+ const invocationParameters =
prompt.invocation_parameters as unknown as Record<string, unknown> & {
- tool_choice?: AnthropicToolChoice;
max_tokens: number;
};
// parts of the prompt that can be directly converted to Anthropic params
@@ -57,19 +56,21 @@ export const toAnthropic = <V extends Variables = Variables>({
promptMessageToAnthropic.parse(message)
) as MessageParam[];

- const tools = prompt.tools?.tools.map((tool) => {
+ let tools = prompt.tools?.tools.map((tool) => {
const openaiDefinition = phoenixToolToOpenAI.parse(tool);
invariant(openaiDefinition, "Tool definition is not valid");
return fromOpenAIToolDefinition({
toolDefinition: openaiDefinition,
targetProvider: "ANTHROPIC",
});
});

+ tools = (tools?.length ?? 0) > 0 ? tools : undefined;
const tool_choice =
- (tools?.length ?? 0) > 0 && initialToolChoice
+ (tools?.length ?? 0) > 0 && prompt.tools?.tool_choice
? (safelyConvertToolChoiceToProvider({
- toolChoice: initialToolChoice,
+ toolChoice: phoenixToolChoiceToOpenaiToolChoice.parse(
+ prompt.tools.tool_choice
+ ),
targetProvider: "ANTHROPIC",
}) ?? undefined)
: undefined;
@@ -78,7 +79,7 @@ export const toAnthropic = <V extends Variables = Variables>({
const completionParams = {
...baseCompletionParams,
messages,
- tools: (tools?.length ?? 0) > 0 ? tools : undefined,
+ tools,
tool_choice,
} satisfies Partial<MessageCreateParams>;

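Note: a minimal sketch of the conversion chain toAnthropic now uses. The identifiers and the relative import path are the ones visible in the diff above; the standalone wiring and the "get_weather" tool name are illustrative assumptions, not the package's public API.

```ts
import {
  phoenixToolChoiceToOpenaiToolChoice,
  safelyConvertToolChoiceToProvider,
} from "../../schemas/llm/toolChoiceSchemas";

// A normalized Phoenix tool_choice, as stored on prompt.tools.tool_choice.
const phoenixToolChoice = {
  type: "specific-function-tool",
  function_name: "get_weather",
};

// 1. Normalize the Phoenix value into OpenAI's tool_choice shape.
const openaiToolChoice =
  phoenixToolChoiceToOpenaiToolChoice.parse(phoenixToolChoice);
// => { type: "function", function: { name: "get_weather" } }

// 2. Convert the OpenAI shape to the target provider (Anthropic here).
const anthropicToolChoice = safelyConvertToolChoiceToProvider({
  toolChoice: openaiToolChoice,
  targetProvider: "ANTHROPIC",
});
// => { type: "tool", name: "get_weather" }; toAnthropic falls back to
// undefined when this returns a nullish value.
```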
16 changes: 10 additions & 6 deletions js/packages/phoenix-client/src/prompts/sdks/toOpenAI.ts
@@ -5,6 +5,7 @@ import type {
} from "openai/resources";
import type { Variables, toSDKParamsBase } from "./types";
import {
+ phoenixToolChoiceToOpenaiToolChoice,
phoenixToolToOpenAI,
promptMessageToOpenAI,
safelyConvertToolChoiceToProvider,
@@ -56,18 +57,21 @@ export const toOpenAI = <V extends Variables = Variables>({
promptMessageToOpenAI.parse(message)
);

- const tools = prompt.tools?.tools.map((tool) =>
- phoenixToolToOpenAI.parse(tool)
- );
+ let tools = prompt.tools?.tools
+ .map((tool) => phoenixToolToOpenAI.parse(tool))
+ .filter((tool) => tool !== null);
+ tools = (tools?.length ?? 0) > 0 ? tools : undefined;

const response_format = prompt.response_format
? phoenixResponseFormatToOpenAI.parse(prompt.response_format)
: undefined;

const tool_choice =
(tools?.length ?? 0) > 0 && "tool_choice" in baseCompletionParams
(tools?.length ?? 0) > 0 && prompt.tools?.tool_choice
? (safelyConvertToolChoiceToProvider({
- toolChoice: baseCompletionParams.tool_choice,
+ toolChoice: phoenixToolChoiceToOpenaiToolChoice.parse(
+ prompt.tools?.tool_choice
+ ),
targetProvider: "OPENAI",
}) ?? undefined)
: undefined;
@@ -76,7 +80,7 @@ export const toOpenAI = <V extends Variables = Variables>({
const completionParams = {
...baseCompletionParams,
messages,
- tools: (tools?.length ?? 0) > 0 ? tools : undefined,
+ tools,
tool_choice,
response_format,
} satisfies Partial<ChatCompletionCreateParams>;
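A hedged usage sketch for the OpenAI path, reusing the mock prompt shapes added in test/prompts/sdks/data.ts below. The { prompt } argument shape and the import paths are assumptions inferred from this diff, not confirmed public API.

```ts
import { toOpenAI } from "./toOpenAI";
import {
  BASE_MOCK_PROMPT_VERSION,
  BASE_MOCK_PROMPT_VERSION_TOOLS,
} from "../../../test/prompts/sdks/data";

// A prompt version whose tools block carries the normalized
// tool_choice { type: "zero-or-more" }.
const prompt = {
  ...BASE_MOCK_PROMPT_VERSION,
  ...BASE_MOCK_PROMPT_VERSION_TOOLS,
};

const completionParams = toOpenAI({ prompt });
// With tools present, { type: "zero-or-more" } normalizes to OpenAI's "auto";
// with no tools, both tools and tool_choice stay undefined.
console.log(completionParams?.tool_choice); // "auto"
```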
61 changes: 55 additions & 6 deletions js/packages/phoenix-client/src/schemas/llm/toolChoiceSchemas.ts
@@ -1,6 +1,28 @@
import { z } from "zod";
import type { PhoenixModelProvider } from "../../constants";
import { assertUnreachable } from "../../utils/assertUnreachable";
import { isObject } from "../../utils/isObject";

/**
* Phoenix's tool choice schema
*/
export const phoenixToolChoiceSchema = z.union([
z.object({
type: z.literal("none"),
}),
z.object({
type: z.literal("zero-or-more"),
}),
z.object({
type: z.literal("one-or-more"),
}),
z.object({
type: z.literal("specific-function-tool"),
function_name: z.string(),
}),
]);

export type PhoenixToolChoice = z.infer<typeof phoenixToolChoiceSchema>;

/**
* OpenAI's tool choice schema
@@ -42,10 +64,31 @@ export const anthropicToolChoiceSchema = z.discriminatedUnion("type", [

export type AnthropicToolChoice = z.infer<typeof anthropicToolChoiceSchema>;

/*
*
* Conversion Helpers
*
*/

export const phoenixToolChoiceToOpenaiToolChoice =
phoenixToolChoiceSchema.transform((phoenix): OpenaiToolChoice => {
switch (phoenix.type) {
case "none":
return "none";
case "zero-or-more":
return "auto";
case "one-or-more":
return "required";
case "specific-function-tool":
return { type: "function", function: { name: phoenix.function_name } };
}
});

export const anthropicToolChoiceToOpenaiToolChoice =
anthropicToolChoiceSchema.transform((anthropic): OpenaiToolChoice => {
switch (anthropic.type) {
case "any":
return "required";
case "auto":
return "auto";
case "tool":
@@ -63,13 +106,19 @@ export const anthropicToolChoiceToOpenaiToolChoice =

export const openAIToolChoiceToAnthropicToolChoice =
openAIToolChoiceSchema.transform((openAI): AnthropicToolChoice => {
- if (typeof openAI === "string") {
- return { type: "auto" };
+ if (isObject(openAI)) {
+ return { type: "tool", name: openAI.function.name };
}
+ switch (openAI) {
+ case "auto":
+ return { type: "auto" };
+ case "none":
+ return { type: "auto" };
+ case "required":
+ return { type: "any" };
+ default:
+ assertUnreachable(openAI);
+ }
- return {
- type: "tool",
- name: openAI.function.name ?? "",
- };
});

export const llmProviderToolChoiceSchema = z.union([
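For quick reference, the mappings implied by the transforms above, written out as parse calls with their expected results (a sketch; the values follow the switch statements in this file, and "search" is a made-up tool name):

```ts
// Phoenix -> OpenAI
phoenixToolChoiceToOpenaiToolChoice.parse({ type: "none" });         // "none"
phoenixToolChoiceToOpenaiToolChoice.parse({ type: "zero-or-more" }); // "auto"
phoenixToolChoiceToOpenaiToolChoice.parse({ type: "one-or-more" });  // "required"
phoenixToolChoiceToOpenaiToolChoice.parse({
  type: "specific-function-tool",
  function_name: "search",
}); // { type: "function", function: { name: "search" } }

// OpenAI -> Anthropic (Anthropic has no "none", so it maps to "auto")
openAIToolChoiceToAnthropicToolChoice.parse("auto");     // { type: "auto" }
openAIToolChoiceToAnthropicToolChoice.parse("none");     // { type: "auto" }
openAIToolChoiceToAnthropicToolChoice.parse("required"); // { type: "any" }
openAIToolChoiceToAnthropicToolChoice.parse({
  type: "function",
  function: { name: "search" },
}); // { type: "tool", name: "search" }
```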
6 changes: 6 additions & 0 deletions js/packages/phoenix-client/src/types/prompts.ts
@@ -54,6 +54,12 @@ export type PromptChatMessage = Extract<
export type PromptTool =
components["schemas"]["PromptToolsV1"]["tools"][number];

/**
* The Phoenix prompt tool choice type from the API.
*/
export type PromptToolChoice = NonNullable<
components["schemas"]["PromptToolsV1"]["tool_choice"]
>;
/**
* The Phoenix prompt output schema type from the API.
*/
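A small illustration of the new PromptToolChoice type (a sketch; the import path is assumed from this file's location, and the literal value mirrors the tool_choice used in the test fixtures below):

```ts
import type { PromptToolChoice } from "./prompts";

// The normalized tool_choice carried on a prompt's tools block.
const choice: PromptToolChoice = { type: "zero-or-more" };
```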
59 changes: 59 additions & 0 deletions js/packages/phoenix-client/test/prompts/sdks/data.ts
@@ -0,0 +1,59 @@
import { PromptVersion } from "../../../src/types/prompts";

export const BASE_MOCK_PROMPT_VERSION = {
id: "test",
description: "Test prompt",
model_provider: "openai",
model_name: "gpt-4",
template_type: "CHAT",
template_format: "MUSTACHE",
template: {
version: "chat-template-v1",
messages: [
{
role: "USER",
content: [{ type: "text", text: { text: "Hello" } }],
},
],
},
invocation_parameters: {
temperature: 0.7,
},
} satisfies Partial<PromptVersion>;

export const BASE_MOCK_PROMPT_VERSION_TOOLS = {
tools: {
type: "tools-v1",
tool_choice: { type: "zero-or-more" },
tools: [
{
type: "function-tool-v1",
name: "test",
description: "test function",
schema: {
type: "json-schema-draft-7-object-schema",
json: {
type: "object",
properties: {},
},
},
},
],
},
} satisfies Partial<PromptVersion>;

export const BASE_MOCK_PROMPT_VERSION_RESPONSE_FORMAT = {
response_format: {
type: "response-format-json-schema-v1",
name: "test",
description: "test function",
schema: {
type: "json-schema-draft-7-object-schema",
json: {
type: "object",
properties: {},
},
},
extra_parameters: {},
},
} satisfies Partial<PromptVersion>;