From 03cad34c3b791b4a76cd5bd977982a456dec3bec Mon Sep 17 00:00:00 2001 From: Roger Yang <80478925+RogerHYang@users.noreply.github.com> Date: Wed, 12 Feb 2025 12:21:17 -0800 Subject: [PATCH] feat(prompts): POST method for prompts endpoint (#6347) --- .../src/__generated__/api/v1.ts | 94 ++- packages/phoenix-client/pyproject.toml | 2 +- .../scripts/codegen/transform.py | 226 +++++++- .../client/__generated__/v1/.gitignore | 2 + .../client/__generated__/v1/__init__.py | 41 +- .../phoenix/client/__generated__/v1/models.py | 546 ++++++++++++++++++ .../client/helpers/sdk/anthropic/messages.py | 7 +- .../google_generativeai/generate_content.py | 7 +- .../phoenix/client/helpers/sdk/openai/chat.py | 75 ++- .../client/resources/prompts/__init__.py | 55 +- .../src/phoenix/client/utils/config.py | 7 + .../src/phoenix/client/utils/prompt.py | 10 +- .../client/utils/template_formatters.py | 4 +- .../tests/canary/sdk/openai/test_chat.py | 94 ++- schemas/openapi.json | 270 ++++++++- src/phoenix/db/models.py | 44 +- src/phoenix/server/api/openapi/schema.py | 1 + src/phoenix/server/api/routers/v1/models.py | 5 +- src/phoenix/server/api/routers/v1/prompts.py | 85 ++- tests/integration/prompts/test_prompts.py | 25 +- tox.ini | 28 +- .../internal/prompts/hallucination_eval.ipynb | 123 +++- 22 files changed, 1632 insertions(+), 119 deletions(-) create mode 100644 packages/phoenix-client/src/phoenix/client/__generated__/v1/.gitignore create mode 100644 packages/phoenix-client/src/phoenix/client/__generated__/v1/models.py diff --git a/js/packages/phoenix-client/src/__generated__/api/v1.ts b/js/packages/phoenix-client/src/__generated__/api/v1.ts index 436d6d16b1..da29799009 100644 --- a/js/packages/phoenix-client/src/__generated__/api/v1.ts +++ b/js/packages/phoenix-client/src/__generated__/api/v1.ts @@ -221,7 +221,8 @@ export interface paths { /** Get all prompts */ get: operations["getPrompts"]; put?: never; - post?: never; + /** Create a prompt version */ + post: operations["postPromptVersion"]; delete?: never; options?: never; head?: never; @@ -346,6 +347,15 @@ export interface components { CreateExperimentResponseBody: { data: components["schemas"]["Experiment"]; }; + /** CreatePromptRequestBody */ + CreatePromptRequestBody: { + prompt: components["schemas"]["PromptData"]; + version: components["schemas"]["PromptVersionData"]; + }; + /** CreatePromptResponseBody */ + CreatePromptResponseBody: { + data: components["schemas"]["PromptVersion"]; + }; /** Dataset */ Dataset: { /** Id */ @@ -492,6 +502,8 @@ export interface components { /** Detail */ detail?: components["schemas"]["ValidationError"][]; }; + /** Identifier */ + Identifier: string; /** InsertedSpanAnnotation */ InsertedSpanAnnotation: { /** @@ -549,14 +561,13 @@ export interface components { ModelProvider: "OPENAI" | "AZURE_OPENAI" | "ANTHROPIC" | "GEMINI"; /** Prompt */ Prompt: { + name: components["schemas"]["Identifier"]; + /** Description */ + description?: string | null; + /** Source Prompt Id */ + source_prompt_id?: string | null; /** Id */ id: string; - /** Source Prompt Id */ - source_prompt_id: string | null; - /** Name */ - name: string; - /** Description */ - description: string | null; }; /** PromptAnthropicInvocationParameters */ PromptAnthropicInvocationParameters: { @@ -617,6 +628,14 @@ export interface components { /** Messages */ messages: components["schemas"]["PromptMessage"][]; }; + /** PromptData */ + PromptData: { + name: components["schemas"]["Identifier"]; + /** Description */ + description?: string | null; + /** Source Prompt 
Id */ + source_prompt_id?: string | null; + }; /** PromptFunctionTool */ PromptFunctionTool: { /** @@ -785,10 +804,27 @@ export interface components { }; /** PromptVersion */ PromptVersion: { + /** Description */ + description?: string | null; + model_provider: components["schemas"]["ModelProvider"]; + /** Model Name */ + model_name: string; + /** Template */ + template: components["schemas"]["PromptChatTemplate"] | components["schemas"]["PromptStringTemplate"]; + template_type: components["schemas"]["PromptTemplateType"]; + template_format: components["schemas"]["PromptTemplateFormat"]; + /** Invocation Parameters */ + invocation_parameters: components["schemas"]["PromptOpenAIInvocationParameters"] | components["schemas"]["PromptAzureOpenAIInvocationParameters"] | components["schemas"]["PromptAnthropicInvocationParameters"] | components["schemas"]["PromptGeminiInvocationParameters"]; + tools?: components["schemas"]["PromptTools"] | null; + /** Response Format */ + response_format?: components["schemas"]["PromptResponseFormatJSONSchema"] | null; /** Id */ id: string; + }; + /** PromptVersionData */ + PromptVersionData: { /** Description */ - description: string; + description?: string | null; model_provider: components["schemas"]["ModelProvider"]; /** Model Name */ model_name: string; @@ -1719,6 +1755,48 @@ export interface operations { }; }; }; + postPromptVersion: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody: { + content: { + "application/json": components["schemas"]["CreatePromptRequestBody"]; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["CreatePromptResponseBody"]; + }; + }; + /** @description Forbidden */ + 403: { + headers: { + [name: string]: unknown; + }; + content: { + "text/plain": string; + }; + }; + /** @description Unprocessable Entity */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "text/plain": string; + }; + }; + }; + }; listPromptVersions: { parameters: { query?: { diff --git a/packages/phoenix-client/pyproject.toml b/packages/phoenix-client/pyproject.toml index 194b295794..e9ed68a411 100644 --- a/packages/phoenix-client/pyproject.toml +++ b/packages/phoenix-client/pyproject.toml @@ -74,7 +74,7 @@ force-single-line = false [tool.ruff.lint.per-file-ignores] "*.ipynb" = ["E402", "E501"] -"src/phoenix/client/__generated__/v1/__init__.py" = ["F401", "E501"] +"src/phoenix/client/__generated__/v1/*.py" = ["E501"] [tool.ruff.format] line-ending = "native" diff --git a/packages/phoenix-client/scripts/codegen/transform.py b/packages/phoenix-client/scripts/codegen/transform.py index 379e0ef530..bc153a78ea 100644 --- a/packages/phoenix-client/scripts/codegen/transform.py +++ b/packages/phoenix-client/scripts/codegen/transform.py @@ -1,5 +1,7 @@ import ast import sys +from pathlib import Path +from typing import Callable class ConvertDataClassToTypedDict(ast.NodeTransformer): @@ -77,13 +79,195 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign: return node -if __name__ == "__main__": - if len(sys.argv) != 2: - print("Usage: python transform.py ") - sys.exit(1) - file_path = sys.argv[1] - with open(file_path, "r") as file: - code = file.read() +class PydanticModels(ast.NodeTransformer): + def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: + # Recursively transform all child nodes in the class body. 
+ new_body = [self.visit(child) for child in node.body] + + # Check if the class inherits from BaseModel. + # We look at the bases of the class, filtering out those that are simple names. + base_names = [base.id for base in node.bases if isinstance(base, ast.Name)] + if "BaseModel" not in base_names: + # This class is not a Pydantic model; return it unchanged. + return node + + # Check if a model_config assignment already exists in the class body. + has_model_config = any( + isinstance(stmt, ast.Assign) + and any( + isinstance(target, ast.Name) and target.id == "model_config" + for target in stmt.targets + ) + for stmt in new_body + ) + + if not has_model_config: + # Create an assignment: + # model_config = ConfigDict(strict=True, validate_assignment=True) + model_config_assign = ast.Assign( + targets=[ast.Name(id="model_config", ctx=ast.Store())], + value=ast.Call( + func=ast.Name(id="ConfigDict", ctx=ast.Load()), + args=[], + keywords=[ + ast.keyword(arg="strict", value=ast.Constant(value=True)), + ast.keyword(arg="validate_assignment", value=ast.Constant(value=True)), + ], + ), + ) + # Insert the new assignment at the beginning of the class body. + new_body.insert(0, model_config_assign) + + # Return a new ClassDef node with the updated body. + # We preserve the original bases, keywords, and decorators. + return ast.ClassDef( + name=node.name, + bases=node.bases, + keywords=node.keywords, + body=new_body, + decorator_list=node.decorator_list, + ) + + def visit_Call(self, node) -> ast.Call: + if isinstance(node.func, ast.Name) and node.func.id == "ConfigDict": + kwargs = {kw.arg: kw for kw in node.keywords} + kwargs["extra"] = ast.keyword(arg="extra", value=ast.Constant(value="forbid")) + kwargs["strict"] = ast.keyword(arg="strict", value=ast.Constant(value=True)) + kwargs["validate_assignment"] = ast.keyword( + arg="validate_assignment", value=ast.Constant(value=True) + ) + return ast.Call( + func=node.func, + args=node.args, + keywords=list(kwargs.values()), + ) + return node + + def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AST: + """ + Transform annotated assignments that use a Sequence with a discriminator. + + Specifically, convert annotations of the form: + + Annotated[Sequence[T], Field(discriminator="type", ...)] + + into: + + Annotated[ + Sequence[Annotated[T, Field(discriminator="type")]], + Field(... without discriminator ...) + ] + + This transformation is needed to avoid runtime errors with discriminated unions in Pydantic. + + Note: T is usually a Union[...], but being one isn't required. + """ + # Check that the annotation is an Annotated[...] type. + if not ( + isinstance(node.annotation, ast.Subscript) + and isinstance(node.annotation.value, ast.Name) + and node.annotation.value.id == "Annotated" + ): + return node + + # Ensure the Annotated[...] has at least two parts. + if not ( + isinstance(node.annotation.slice, ast.Tuple) and len(node.annotation.slice.elts) >= 2 + ): + return node + + # Unpack the two parts: the type part and the Field part. + annotated_parts = node.annotation.slice.elts + seq_annotation = annotated_parts[0] # Expecting Sequence[T] + field_call = annotated_parts[1] # Expecting Field(...) + + # Check that the type part is Sequence[T]. + if not ( + isinstance(seq_annotation, ast.Subscript) + and isinstance(seq_annotation.value, ast.Name) + and seq_annotation.value.id == "Sequence" + ): + return node + + # Check that the second part is a Field(...) call. 
+ if not ( + isinstance(field_call, ast.Call) + and isinstance(field_call.func, ast.Name) + and field_call.func.id == "Field" + ): + return node + + # Verify that the Field call includes a discriminator with the value "type". + has_discriminator = any( + kw.arg == "discriminator" + and isinstance(kw.value, ast.Constant) + and kw.value.value == "type" + for kw in field_call.keywords + ) + if not has_discriminator: + return node + + # --- Begin transformation --- + + # Extract the inner type T from Sequence[T]. + inner_type = seq_annotation.slice + + # Create a Field call for the inner Annotated with only the discriminator. + inner_field = ast.Call( + func=ast.Name(id="Field"), + args=[], + keywords=[ast.keyword(arg="discriminator", value=ast.Constant(value="type"))], + ) + + # Build the inner Annotated[T, Field(discriminator="type")]. + inner_annotated = ast.Subscript( + value=ast.Name(id="Annotated"), + slice=ast.Tuple(elts=[inner_type, inner_field]), + ) + + # Wrap the inner Annotated in a Sequence, i.e. Sequence[Annotated[T, Field(...)]] + new_seq_annotation = ast.Subscript( + value=ast.Name(id="Sequence"), + slice=ast.Tuple(elts=[inner_annotated]), + ) + + # Prepare a new outer Field by copying all keywords except the discriminator. + outer_field_keywords = [kw for kw in field_call.keywords if kw.arg != "discriminator"] + + # If there are extra keywords, create an outer Field call. + if outer_field_keywords: + outer_field = ast.Call( + func=ast.Name(id="Field"), + args=[], + keywords=outer_field_keywords, + ) + # Combine the new Sequence and the outer Field in a new Annotated. + new_annotation = ast.Subscript( + value=ast.Name(id="Annotated"), + slice=ast.Tuple(elts=[new_seq_annotation, outer_field]), + ) + else: + # Otherwise, just use the new Sequence annotation. + new_annotation = new_seq_annotation + + # Return a new annotated assignment with the transformed annotation. 
+ return ast.AnnAssign( + target=node.target, + annotation=new_annotation, + value=node.value, + simple=node.simple, + ) + + +def _transform_pydantic(code: str) -> ast.AST: + parsed = ast.parse(code) + transformed = PydanticModels().visit(parsed) + return transformed + + +def _transform_dataclass( + code: str, +) -> ast.AST: parsed = ast.parse(code) for i, node in enumerate(parsed.body): if isinstance(node, ast.ClassDef): @@ -101,8 +285,28 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign: ) break transformed = ConvertDataClassToTypedDict().visit(parsed) + return transformed + + +def _rewrite( + path: Path, + in_file: str, + out_file: str, + transform: Callable[[str], ast.AST], +) -> None: + with open(path / in_file, "r") as f: + code = f.read() + transformed = ast.fix_missing_locations(transform(code)) unparsed = ast.unparse(transformed) - with open(file_path, "w") as file: - file.write("# pyright: reportUnusedImport=false\n") - file.write('"""Do not edit"""\n\n') - file.write(unparsed) + with open(path / out_file, "w") as f: + f.write('"""Do not edit"""\n\n') + f.write(unparsed) + + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python transform.py ") + sys.exit(1) + path = Path(sys.argv[1]) + _rewrite(path, ".pydantic.txt", "models.py", _transform_pydantic) + _rewrite(path, ".dataclass.txt", "__init__.py", _transform_dataclass) diff --git a/packages/phoenix-client/src/phoenix/client/__generated__/v1/.gitignore b/packages/phoenix-client/src/phoenix/client/__generated__/v1/.gitignore new file mode 100644 index 0000000000..a51ab9efd8 --- /dev/null +++ b/packages/phoenix-client/src/phoenix/client/__generated__/v1/.gitignore @@ -0,0 +1,2 @@ +/.pydantic.txt +/.dataclass.txt diff --git a/packages/phoenix-client/src/phoenix/client/__generated__/v1/__init__.py b/packages/phoenix-client/src/phoenix/client/__generated__/v1/__init__.py index e950473a12..44e2716105 100644 --- a/packages/phoenix-client/src/phoenix/client/__generated__/v1/__init__.py +++ b/packages/phoenix-client/src/phoenix/client/__generated__/v1/__init__.py @@ -1,4 +1,3 @@ -# pyright: reportUnusedImport=false """Do not edit""" from __future__ import annotations @@ -103,10 +102,10 @@ class ListExperimentsResponseBody(TypedDict): class Prompt(TypedDict): - id: str - source_prompt_id: Optional[str] name: str - description: Optional[str] + id: str + description: NotRequired[str] + source_prompt_id: NotRequired[str] class PromptAnthropicInvocationParametersContent(TypedDict): @@ -126,6 +125,12 @@ class PromptAzureOpenAIInvocationParametersContent(TypedDict): reasoning_effort: NotRequired[Literal["low", "medium", "high"]] +class PromptData(TypedDict): + name: str + description: NotRequired[str] + source_prompt_id: NotRequired[str] + + class PromptFunctionTool(TypedDict): name: str type: Literal["function-tool"] @@ -312,8 +317,24 @@ class PromptChatTemplate(TypedDict): class PromptVersion(TypedDict): + model_provider: Literal["OPENAI", "AZURE_OPENAI", "ANTHROPIC", "GEMINI"] + model_name: str + template: Union[PromptChatTemplate, PromptStringTemplate] + template_type: Literal["STR", "CHAT"] + template_format: Literal["MUSTACHE", "FSTRING", "NONE"] + invocation_parameters: Union[ + PromptOpenAIInvocationParameters, + PromptAzureOpenAIInvocationParameters, + PromptAnthropicInvocationParameters, + PromptGeminiInvocationParameters, + ] id: str - description: str + description: NotRequired[str] + tools: NotRequired[PromptTools] + response_format: NotRequired[PromptResponseFormatJSONSchema] + + +class 
PromptVersionData(TypedDict): model_provider: Literal["OPENAI", "AZURE_OPENAI", "ANTHROPIC", "GEMINI"] model_name: str template: Union[PromptChatTemplate, PromptStringTemplate] @@ -325,10 +346,20 @@ class PromptVersion(TypedDict): PromptAnthropicInvocationParameters, PromptGeminiInvocationParameters, ] + description: NotRequired[str] tools: NotRequired[PromptTools] response_format: NotRequired[PromptResponseFormatJSONSchema] +class CreatePromptRequestBody(TypedDict): + prompt: PromptData + version: PromptVersionData + + +class CreatePromptResponseBody(TypedDict): + data: PromptVersion + + class GetPromptResponseBody(TypedDict): data: PromptVersion diff --git a/packages/phoenix-client/src/phoenix/client/__generated__/v1/models.py b/packages/phoenix-client/src/phoenix/client/__generated__/v1/models.py new file mode 100644 index 0000000000..10f71fcb94 --- /dev/null +++ b/packages/phoenix-client/src/phoenix/client/__generated__/v1/models.py @@ -0,0 +1,546 @@ +"""Do not edit""" + +from __future__ import annotations + +from datetime import datetime +from typing import Annotated, Any, Literal, Mapping, Optional, Sequence, Union + +from pydantic import BaseModel, ConfigDict, Field + + +class CreateExperimentRequestBody(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + name: Annotated[ + Optional[str], + Field( + description="Name of the experiment (if omitted, a random name will be generated)", + title="Name", + ), + ] = None + description: Annotated[ + Optional[str], + Field(description="An optional description of the experiment", title="Description"), + ] = None + metadata: Annotated[ + Optional[Mapping[str, Any]], + Field(description="Metadata for the experiment", title="Metadata"), + ] = None + version_id: Annotated[ + Optional[str], + Field( + description="ID of the dataset version over which the experiment will be run (if omitted, the latest version will be used)", + title="Version Id", + ), + ] = None + repetitions: Annotated[ + Optional[int], + Field( + description="Number of times the experiment should be repeated for each example", + title="Repetitions", + ), + ] = 1 + + +class Dataset(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + id: Annotated[str, Field(title="Id")] + name: Annotated[str, Field(title="Name")] + description: Annotated[Optional[str], Field(title="Description")] = None + metadata: Annotated[Mapping[str, Any], Field(title="Metadata")] + created_at: Annotated[datetime, Field(title="Created At")] + updated_at: Annotated[datetime, Field(title="Updated At")] + + +class DatasetExample(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + id: Annotated[str, Field(title="Id")] + input: Annotated[Mapping[str, Any], Field(title="Input")] + output: Annotated[Mapping[str, Any], Field(title="Output")] + metadata: Annotated[Mapping[str, Any], Field(title="Metadata")] + updated_at: Annotated[datetime, Field(title="Updated At")] + + +class DatasetVersion(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + version_id: Annotated[str, Field(title="Version Id")] + description: Annotated[Optional[str], Field(title="Description")] = None + metadata: Annotated[Mapping[str, Any], Field(title="Metadata")] + created_at: Annotated[datetime, Field(title="Created At")] + + +class DatasetWithExampleCount(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + id: Annotated[str, Field(title="Id")] + name: Annotated[str, Field(title="Name")] + description: 
Annotated[Optional[str], Field(title="Description")] = None + metadata: Annotated[Mapping[str, Any], Field(title="Metadata")] + created_at: Annotated[datetime, Field(title="Created At")] + updated_at: Annotated[datetime, Field(title="Updated At")] + example_count: Annotated[int, Field(title="Example Count")] + + +class Experiment(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + id: Annotated[str, Field(description="The ID of the experiment", title="Id")] + dataset_id: Annotated[ + str, + Field( + description="The ID of the dataset associated with the experiment", title="Dataset Id" + ), + ] + dataset_version_id: Annotated[ + str, + Field( + description="The ID of the dataset version associated with the experiment", + title="Dataset Version Id", + ), + ] + repetitions: Annotated[ + int, Field(description="Number of times the experiment is repeated", title="Repetitions") + ] + metadata: Annotated[ + Mapping[str, Any], Field(description="Metadata of the experiment", title="Metadata") + ] + project_name: Annotated[ + Optional[str], + Field( + description="The name of the project associated with the experiment", + title="Project Name", + ), + ] = None + created_at: Annotated[ + datetime, Field(description="The creation timestamp of the experiment", title="Created At") + ] + updated_at: Annotated[ + datetime, + Field(description="The last update timestamp of the experiment", title="Updated At"), + ] + + +class GetDatasetResponseBody(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + data: DatasetWithExampleCount + + +class GetExperimentResponseBody(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + data: Experiment + + +class InsertedSpanAnnotation(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + id: Annotated[str, Field(description="The ID of the inserted span annotation", title="Id")] + + +class JSONSchemaDraft7ObjectSchema(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + type: Annotated[Literal["json-schema-draft-7-object-schema"], Field(title="Type")] + json_: Annotated[Mapping[str, Any], Field(alias="json", title="Json")] + + +class ListDatasetExamplesData(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + dataset_id: Annotated[str, Field(title="Dataset Id")] + version_id: Annotated[str, Field(title="Version Id")] + examples: Annotated[Sequence[DatasetExample], Field(title="Examples")] + + +class ListDatasetExamplesResponseBody(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + data: ListDatasetExamplesData + + +class ListDatasetVersionsResponseBody(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + data: Annotated[Sequence[DatasetVersion], Field(title="Data")] + next_cursor: Annotated[Optional[str], Field(title="Next Cursor")] = None + + +class ListDatasetsResponseBody(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + data: Annotated[Sequence[Dataset], Field(title="Data")] + next_cursor: Annotated[Optional[str], Field(title="Next Cursor")] = None + + +class ListExperimentsResponseBody(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + data: Annotated[Sequence[Experiment], Field(title="Data")] + + +class Prompt(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + name: Annotated[str, Field(pattern="^[a-z0-9]([_a-z0-9-]*[a-z0-9])?$", title="Identifier")] + 
description: Annotated[Optional[str], Field(title="Description")] = None + source_prompt_id: Annotated[Optional[str], Field(title="Source Prompt Id")] = None + id: Annotated[str, Field(title="Id")] + + +class PromptAnthropicInvocationParametersContent(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + max_tokens: Annotated[int, Field(title="Max Tokens")] + temperature: Annotated[Optional[float], Field(title="Temperature")] = None + top_p: Annotated[Optional[float], Field(title="Top P")] = None + stop_sequences: Annotated[Optional[Sequence[str]], Field(title="Stop Sequences")] = None + + +class PromptAzureOpenAIInvocationParametersContent(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + temperature: Annotated[Optional[float], Field(title="Temperature")] = None + max_tokens: Annotated[Optional[int], Field(title="Max Tokens")] = None + frequency_penalty: Annotated[Optional[float], Field(title="Frequency Penalty")] = None + presence_penalty: Annotated[Optional[float], Field(title="Presence Penalty")] = None + top_p: Annotated[Optional[float], Field(title="Top P")] = None + seed: Annotated[Optional[int], Field(title="Seed")] = None + reasoning_effort: Annotated[ + Optional[Literal["low", "medium", "high"]], Field(title="Reasoning Effort") + ] = None + + +class PromptData(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + name: Annotated[str, Field(pattern="^[a-z0-9]([_a-z0-9-]*[a-z0-9])?$", title="Identifier")] + description: Annotated[Optional[str], Field(title="Description")] = None + source_prompt_id: Annotated[Optional[str], Field(title="Source Prompt Id")] = None + + +class PromptFunctionTool(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + type: Annotated[Literal["function-tool"], Field(title="Type")] + name: Annotated[str, Field(title="Name")] + description: Annotated[Optional[str], Field(title="Description")] = None + schema_: Annotated[ + Optional[JSONSchemaDraft7ObjectSchema], + Field(alias="schema", discriminator="type", title="Schema"), + ] = None + extra_parameters: Annotated[Optional[Mapping[str, Any]], Field(title="Extra Parameters")] = None + + +class PromptGeminiInvocationParametersContent(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + temperature: Annotated[Optional[float], Field(title="Temperature")] = None + max_output_tokens: Annotated[Optional[int], Field(title="Max Output Tokens")] = None + stop_sequences: Annotated[Optional[Sequence[str]], Field(title="Stop Sequences")] = None + presence_penalty: Annotated[Optional[float], Field(title="Presence Penalty")] = None + frequency_penalty: Annotated[Optional[float], Field(title="Frequency Penalty")] = None + top_p: Annotated[Optional[float], Field(title="Top P")] = None + top_k: Annotated[Optional[int], Field(title="Top K")] = None + + +class PromptOpenAIInvocationParametersContent(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + temperature: Annotated[Optional[float], Field(title="Temperature")] = None + max_tokens: Annotated[Optional[int], Field(title="Max Tokens")] = None + frequency_penalty: Annotated[Optional[float], Field(title="Frequency Penalty")] = None + presence_penalty: Annotated[Optional[float], Field(title="Presence Penalty")] = None + top_p: Annotated[Optional[float], Field(title="Top P")] = None + seed: Annotated[Optional[int], 
Field(title="Seed")] = None + reasoning_effort: Annotated[ + Optional[Literal["low", "medium", "high"]], Field(title="Reasoning Effort") + ] = None + + +class PromptResponseFormatJSONSchema(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + type: Annotated[Literal["response-format-json-schema"], Field(title="Type")] + name: Annotated[str, Field(title="Name")] + description: Annotated[Optional[str], Field(title="Description")] = None + schema_: Annotated[ + JSONSchemaDraft7ObjectSchema, Field(alias="schema", discriminator="type", title="Schema") + ] + extra_parameters: Annotated[Mapping[str, Any], Field(title="Extra Parameters")] + + +class PromptStringTemplate(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + type: Annotated[Literal["string"], Field(title="Type")] + template: Annotated[str, Field(title="Template")] + + +class PromptToolChoiceNone(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + type: Annotated[Literal["none"], Field(title="Type")] + + +class PromptToolChoiceOneOrMore(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + type: Annotated[Literal["one-or-more"], Field(title="Type")] + + +class PromptToolChoiceSpecificFunctionTool(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + type: Annotated[Literal["specific-function-tool"], Field(title="Type")] + function_name: Annotated[str, Field(title="Function Name")] + + +class PromptToolChoiceZeroOrMore(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + type: Annotated[Literal["zero-or-more"], Field(title="Type")] + + +class PromptTools(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + type: Annotated[Literal["tools"], Field(title="Type")] + tools: Annotated[ + Sequence[Annotated[PromptFunctionTool, Field(discriminator="type")],], + Field(min_length=1, title="Tools"), + ] + tool_choice: Annotated[ + Optional[ + Union[ + PromptToolChoiceNone, + PromptToolChoiceZeroOrMore, + PromptToolChoiceOneOrMore, + PromptToolChoiceSpecificFunctionTool, + ] + ], + Field(discriminator="type", title="Tool Choice"), + ] = None + disable_parallel_tool_calls: Annotated[ + Optional[bool], Field(title="Disable Parallel Tool Calls") + ] = None + + +class SpanAnnotationResult(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + label: Annotated[ + Optional[str], Field(description="The label assigned by the annotation", title="Label") + ] = None + score: Annotated[ + Optional[float], Field(description="The score assigned by the annotation", title="Score") + ] = None + explanation: Annotated[ + Optional[str], + Field(description="Explanation of the annotation result", title="Explanation"), + ] = None + + +class TextContentValue(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + text: Annotated[str, Field(title="Text")] + + +class ToolCallFunction(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + type: Annotated[Literal["function"], Field(title="Type")] + name: Annotated[str, Field(title="Name")] + arguments: Annotated[str, Field(title="Arguments")] + + +class ToolResultContentValue(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + tool_call_id: Annotated[str, Field(title="Tool Call Id")] + result: Annotated[ + 
Optional[Union[bool, int, float, str, Mapping[str, Any], Sequence[Any]]], + Field(title="Result"), + ] = None + + +class UploadDatasetData(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + dataset_id: Annotated[str, Field(title="Dataset Id")] + + +class UploadDatasetResponseBody(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + data: UploadDatasetData + + +class ValidationError(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + loc: Annotated[Sequence[Union[str, int]], Field(title="Location")] + msg: Annotated[str, Field(title="Message")] + type: Annotated[str, Field(title="Error Type")] + + +class AnnotateSpansResponseBody(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + data: Annotated[Sequence[InsertedSpanAnnotation], Field(title="Data")] + + +class CreateExperimentResponseBody(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + data: Experiment + + +class GetPromptsResponseBody(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + data: Annotated[Sequence[Prompt], Field(title="Data")] + + +class HTTPValidationError(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + detail: Annotated[Optional[Sequence[ValidationError]], Field(title="Detail")] = None + + +class PromptAnthropicInvocationParameters(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + type: Annotated[Literal["anthropic"], Field(title="Type")] + anthropic: PromptAnthropicInvocationParametersContent + + +class PromptAzureOpenAIInvocationParameters(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + type: Annotated[Literal["azure_openai"], Field(title="Type")] + azure_openai: PromptAzureOpenAIInvocationParametersContent + + +class PromptGeminiInvocationParameters(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + type: Annotated[Literal["gemini"], Field(title="Type")] + gemini: PromptGeminiInvocationParametersContent + + +class PromptOpenAIInvocationParameters(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + type: Annotated[Literal["openai"], Field(title="Type")] + openai: PromptOpenAIInvocationParametersContent + + +class SpanAnnotation(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + span_id: Annotated[ + str, Field(description="OpenTelemetry Span ID (hex format w/o 0x prefix)", title="Span Id") + ] + name: Annotated[str, Field(description="The name of the annotation", title="Name")] + annotator_kind: Annotated[ + Literal["LLM", "HUMAN"], + Field(description="The kind of annotator used for the annotation", title="Annotator Kind"), + ] + result: Annotated[ + Optional[SpanAnnotationResult], Field(description="The result of the annotation") + ] = None + metadata: Annotated[ + Optional[Mapping[str, Any]], + Field(description="Metadata for the annotation", title="Metadata"), + ] = None + + +class TextContentPart(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + type: Annotated[Literal["text"], Field(title="Type")] + text: TextContentValue + + +class ToolCallContentValue(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + tool_call_id: Annotated[str, Field(title="Tool Call Id")] + tool_call: ToolCallFunction + + +class 
ToolResultContentPart(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + type: Annotated[Literal["tool_result"], Field(title="Type")] + tool_result: ToolResultContentValue + + +class AnnotateSpansRequestBody(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + data: Annotated[Sequence[SpanAnnotation], Field(title="Data")] + + +class ToolCallContentPart(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + type: Annotated[Literal["tool_call"], Field(title="Type")] + tool_call: ToolCallContentValue + + +class PromptMessage(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + role: Annotated[Literal["USER", "SYSTEM", "AI", "TOOL"], Field(title="PromptMessageRole")] + content: Annotated[ + Sequence[ + Annotated[ + Union[TextContentPart, ToolCallContentPart, ToolResultContentPart], + Field(discriminator="type"), + ], + ], + Field(min_length=1, title="Content"), + ] + + +class PromptChatTemplate(BaseModel): + model_config = ConfigDict(extra="forbid", strict=True, validate_assignment=True) + type: Annotated[Literal["chat"], Field(title="Type")] + messages: Annotated[Sequence[PromptMessage], Field(title="Messages")] + + +class PromptVersion(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + description: Annotated[Optional[str], Field(title="Description")] = None + model_provider: Annotated[ + Literal["OPENAI", "AZURE_OPENAI", "ANTHROPIC", "GEMINI"], Field(title="ModelProvider") + ] + model_name: Annotated[str, Field(title="Model Name")] + template: Annotated[ + Union[PromptChatTemplate, PromptStringTemplate], + Field(discriminator="type", title="Template"), + ] + template_type: Annotated[Literal["STR", "CHAT"], Field(title="PromptTemplateType")] + template_format: Annotated[ + Literal["MUSTACHE", "FSTRING", "NONE"], Field(title="PromptTemplateFormat") + ] + invocation_parameters: Annotated[ + Union[ + PromptOpenAIInvocationParameters, + PromptAzureOpenAIInvocationParameters, + PromptAnthropicInvocationParameters, + PromptGeminiInvocationParameters, + ], + Field(discriminator="type", title="Invocation Parameters"), + ] + tools: Optional[PromptTools] = None + response_format: Annotated[ + Optional[PromptResponseFormatJSONSchema], Field(title="Response Format") + ] = None + id: Annotated[str, Field(title="Id")] + + +class PromptVersionData(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + description: Annotated[Optional[str], Field(title="Description")] = None + model_provider: Annotated[ + Literal["OPENAI", "AZURE_OPENAI", "ANTHROPIC", "GEMINI"], Field(title="ModelProvider") + ] + model_name: Annotated[str, Field(title="Model Name")] + template: Annotated[ + Union[PromptChatTemplate, PromptStringTemplate], + Field(discriminator="type", title="Template"), + ] + template_type: Annotated[Literal["STR", "CHAT"], Field(title="PromptTemplateType")] + template_format: Annotated[ + Literal["MUSTACHE", "FSTRING", "NONE"], Field(title="PromptTemplateFormat") + ] + invocation_parameters: Annotated[ + Union[ + PromptOpenAIInvocationParameters, + PromptAzureOpenAIInvocationParameters, + PromptAnthropicInvocationParameters, + PromptGeminiInvocationParameters, + ], + Field(discriminator="type", title="Invocation Parameters"), + ] + tools: Optional[PromptTools] = None + response_format: Annotated[ + Optional[PromptResponseFormatJSONSchema], Field(title="Response Format") + ] = None + + +class 
CreatePromptRequestBody(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + prompt: PromptData + version: PromptVersionData + + +class CreatePromptResponseBody(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + data: PromptVersion + + +class GetPromptResponseBody(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + data: PromptVersion + + +class GetPromptVersionsResponseBody(BaseModel): + model_config = ConfigDict(strict=True, validate_assignment=True) + data: Annotated[Sequence[PromptVersion], Field(title="Data")] diff --git a/packages/phoenix-client/src/phoenix/client/helpers/sdk/anthropic/messages.py b/packages/phoenix-client/src/phoenix/client/helpers/sdk/anthropic/messages.py index fa84663a92..441e4cdea8 100644 --- a/packages/phoenix-client/src/phoenix/client/helpers/sdk/anthropic/messages.py +++ b/packages/phoenix-client/src/phoenix/client/helpers/sdk/anthropic/messages.py @@ -57,7 +57,7 @@ v1.ToolResultContentPart, ] - def _(obj: v1.PromptVersion) -> None: + def _(obj: v1.PromptVersionData) -> None: messages, kwargs = to_chat_messages_and_kwargs(obj) Anthropic().messages.create(messages=messages, **kwargs) @@ -85,12 +85,11 @@ class _ModelKwargs(_ToolKwargs, TypedDict, total=False): def to_chat_messages_and_kwargs( - obj: v1.PromptVersion, + obj: v1.PromptVersionData, /, *, variables: Mapping[str, str] = MappingProxyType({}), formatter: Optional[TemplateFormatter] = None, - **_: Any, ) -> tuple[list[MessageParam], _ModelKwargs]: formatter = formatter or to_formatter(obj) assert formatter is not None @@ -124,7 +123,7 @@ def to_chat_messages_and_kwargs( class _ModelKwargsConversion: @staticmethod def to_anthropic( - obj: v1.PromptVersion, + obj: v1.PromptVersionData, ) -> _ModelKwargs: parameters: v1.PromptAnthropicInvocationParametersContent = ( obj["invocation_parameters"]["anthropic"] diff --git a/packages/phoenix-client/src/phoenix/client/helpers/sdk/google_generativeai/generate_content.py b/packages/phoenix-client/src/phoenix/client/helpers/sdk/google_generativeai/generate_content.py index 3f7bd04dda..386b16a5d9 100644 --- a/packages/phoenix-client/src/phoenix/client/helpers/sdk/google_generativeai/generate_content.py +++ b/packages/phoenix-client/src/phoenix/client/helpers/sdk/google_generativeai/generate_content.py @@ -32,7 +32,7 @@ v1.ToolResultContentPart, ] - def _(obj: v1.PromptVersion) -> None: + def _(obj: v1.PromptVersionData) -> None: messages, kwargs = to_chat_messages_and_kwargs(obj) GenerativeModel(**kwargs) _: Iterable[protos.Content] = messages @@ -57,12 +57,11 @@ class _ModelKwargs(_ToolKwargs, TypedDict, total=False): def to_chat_messages_and_kwargs( - obj: v1.PromptVersion, + obj: v1.PromptVersionData, /, *, variables: Mapping[str, str] = MappingProxyType({}), formatter: Optional[TemplateFormatter] = None, - **_: Any, ) -> tuple[list[protos.Content], _ModelKwargs]: formatter = formatter or to_formatter(obj) assert formatter is not None @@ -93,7 +92,7 @@ def to_chat_messages_and_kwargs( def _to_model_kwargs( - obj: v1.PromptVersion, + obj: v1.PromptVersionData, /, ) -> _ModelKwargs: invocation_parameters: v1.PromptGeminiInvocationParametersContent = ( diff --git a/packages/phoenix-client/src/phoenix/client/helpers/sdk/openai/chat.py b/packages/phoenix-client/src/phoenix/client/helpers/sdk/openai/chat.py index 659ef0de64..e319c7da23 100644 --- a/packages/phoenix-client/src/phoenix/client/helpers/sdk/openai/chat.py +++ 
b/packages/phoenix-client/src/phoenix/client/helpers/sdk/openai/chat.py @@ -44,7 +44,10 @@ ) from openai.types.chat.chat_completion_assistant_message_param import ContentArrayOfContentPart from openai.types.chat.chat_completion_named_tool_choice_param import Function - from openai.types.chat.completion_create_params import ResponseFormat + from openai.types.chat.completion_create_params import ( + CompletionCreateParamsBase, + ResponseFormat, + ) from openai.types.shared_params import FunctionDefinition, ResponseFormatJSONSchema from openai.types.shared_params.response_format_json_schema import JSONSchema @@ -52,6 +55,9 @@ def _(obj: v1.PromptVersion) -> None: messages, kwargs = to_chat_messages_and_kwargs(obj) OpenAI().chat.completions.create(messages=messages, **kwargs) + def __(obj: CompletionCreateParamsBase) -> None: + create_prompt_version_from_openai_chat(obj) + class _ToolKwargs(TypedDict, total=False): parallel_tool_calls: bool @@ -62,7 +68,7 @@ class _ToolKwargs(TypedDict, total=False): class _ModelKwargs(_ToolKwargs, TypedDict, total=False): model: Required[str] frequency_penalty: float - max_tokens: int + max_completion_tokens: int presence_penalty: float reasoning_effort: ChatCompletionReasoningEffort response_format: ResponseFormat @@ -79,19 +85,70 @@ class _ModelKwargs(_ToolKwargs, TypedDict, total=False): ] __all__ = [ + "create_prompt_version_from_openai_chat", "to_chat_messages_and_kwargs", ] logger = logging.getLogger(__name__) +def create_prompt_version_from_openai_chat( + obj: CompletionCreateParamsBase, + /, + *, + description: Optional[str] = None, + template_format: Literal["FSTRING", "MUSTACHE", "NONE"] = "MUSTACHE", +) -> v1.PromptVersionData: + messages: list[ChatCompletionMessageParam] = list(obj["messages"]) + template = v1.PromptChatTemplate( + type="chat", + messages=[_MessageConversion.from_openai(m) for m in messages], + ) + params = v1.PromptOpenAIInvocationParametersContent() + if "max_completion_tokens" in obj and obj["max_completion_tokens"] is not None: + params["max_tokens"] = int(obj["max_completion_tokens"]) + if "temperature" in obj and obj["temperature"] is not None: + params["temperature"] = float(obj["temperature"]) + if "top_p" in obj and obj["top_p"] is not None: + params["top_p"] = float(obj["top_p"]) + if "presence_penalty" in obj and obj["presence_penalty"] is not None: + params["presence_penalty"] = float(obj["presence_penalty"]) + if "frequency_penalty" in obj and obj["frequency_penalty"] is not None: + params["frequency_penalty"] = float(obj["frequency_penalty"]) + if "seed" in obj and obj["seed"] is not None: + params["seed"] = int(obj["seed"]) + if "reasoning_effort" in obj and obj["reasoning_effort"] is not None: + params["reasoning_effort"] = obj["reasoning_effort"] + ans = v1.PromptVersionData( + model_provider="OPENAI", + model_name=obj["model"], + template=template, + template_type="CHAT", + template_format=template_format, + invocation_parameters=v1.PromptOpenAIInvocationParameters(type="openai", openai=params), + ) + tool_kwargs: _ToolKwargs = {} + if "tools" in obj: + tool_kwargs["tools"] = list(obj["tools"]) + if "tool_choice" in obj: + tool_kwargs["tool_choice"] = obj["tool_choice"] + if "parallel_tool_calls" in obj: + tool_kwargs["parallel_tool_calls"] = obj["parallel_tool_calls"] + if (tools := _ToolKwargsConversion.from_openai(tool_kwargs)) is not None: + ans["tools"] = tools + if "response_format" in obj: + ans["response_format"] = _ResponseFormatConversion.from_openai(obj["response_format"]) + if description: + 
ans["description"] = description + return ans + + def to_chat_messages_and_kwargs( - obj: v1.PromptVersion, + obj: v1.PromptVersionData, /, *, variables: Mapping[str, str] = MappingProxyType({}), formatter: Optional[TemplateFormatter] = None, - **_: Any, ) -> tuple[list[ChatCompletionMessageParam], _ModelKwargs]: return ( list(_to_chat_completion_messages(obj, variables, formatter)), @@ -100,7 +157,7 @@ def to_chat_messages_and_kwargs( def _to_model_kwargs( - obj: v1.PromptVersion, + obj: v1.PromptVersionData, ) -> _ModelKwargs: parameters: v1.PromptOpenAIInvocationParametersContent = ( obj["invocation_parameters"]["openai"] @@ -110,6 +167,8 @@ def _to_model_kwargs( ans: _ModelKwargs = { "model": obj["model_name"], } + if "max_tokens" in parameters: + ans["max_completion_tokens"] = parameters["max_tokens"] if "temperature" in parameters: ans["temperature"] = parameters["temperature"] if "top_p" in parameters: @@ -133,14 +192,14 @@ def _to_model_kwargs( if "response_format" in obj: response_format = obj["response_format"] if response_format["type"] == "response-format-json-schema": - ans["response_format"] = _ResponseFormatJSONSchemaConversion.to_openai(response_format) + ans["response_format"] = _ResponseFormatConversion.to_openai(response_format) elif TYPE_CHECKING: assert_never(response_format) return ans def _to_chat_completion_messages( - obj: v1.PromptVersion, + obj: v1.PromptVersionData, variables: Mapping[str, str], formatter: Optional[TemplateFormatter] = None, /, @@ -291,7 +350,7 @@ def from_openai( return function -class _ResponseFormatJSONSchemaConversion: +class _ResponseFormatConversion: @staticmethod def to_openai( obj: v1.PromptResponseFormatJSONSchema, diff --git a/packages/phoenix-client/src/phoenix/client/resources/prompts/__init__.py b/packages/phoenix-client/src/phoenix/client/resources/prompts/__init__.py index edb190a190..75a87eda9e 100644 --- a/packages/phoenix-client/src/phoenix/client/resources/prompts/__init__.py +++ b/packages/phoenix-client/src/phoenix/client/resources/prompts/__init__.py @@ -3,7 +3,8 @@ import httpx -from phoenix.client.__generated__.v1 import GetPromptResponseBody, PromptVersion +from phoenix.client.__generated__ import v1 +from phoenix.client.utils.config import _PYDANTIC_VERSION # pyright: ignore[reportPrivateUsage] class Prompts: @@ -16,11 +17,55 @@ def get( prompt_version_id: Optional[str] = None, prompt_identifier: Optional[str] = None, tag: Optional[str] = None, - ) -> PromptVersion: + ) -> v1.PromptVersion: url = _url(prompt_version_id, prompt_identifier, tag) response = self._client.get(url) response.raise_for_status() - return cast(GetPromptResponseBody, response.json())["data"] + return cast(v1.GetPromptResponseBody, response.json())["data"] + + def create( + self, + *, + version: v1.PromptVersion, + name: str, + prompt_description: Optional[str] = None, + ) -> v1.PromptVersion: + """ + Creates a new version for the prompt under the name specified. The prompt will + be created if it doesn't already exist. + + Args: + version (v1.PromptVersion): The version of the prompt to create. + name (str): The identifier for the prompt. It can contain alphanumeric + characters, hyphens and underscores, but must begin with an + alphanumeric character. + prompt_description (Optional[str]): An optional description for the prompt. + If prompt already exists, this value is ignored by the server. + + Returns: + v1.PromptVersion: The created prompt version data. 
+ """ + url = "v1/prompts" + prompt = v1.PromptData(name=name) + if prompt_description: + prompt["description"] = prompt_description + if _PYDANTIC_VERSION.startswith("2"): + import phoenix.client.__generated__.v1.models as m1 + + json_ = cast( + v1.CreatePromptRequestBody, + m1.CreatePromptRequestBody.model_validate( + { + "prompt": {"name": name, "description": prompt_description}, + "version": version, + } + ).model_dump(exclude_unset=True, exclude_defaults=True, by_alias=True), + ) + else: + json_ = v1.CreatePromptRequestBody(prompt=prompt, version=version) + response = self._client.post(url=url, json=json_) + response.raise_for_status() + return cast(v1.CreatePromptResponseBody, response.json())["data"] class AsyncPrompts: @@ -33,11 +78,11 @@ async def get( prompt_version_id: Optional[str] = None, prompt_identifier: Optional[str] = None, tag: Optional[str] = None, - ) -> PromptVersion: + ) -> v1.PromptVersion: url = _url(prompt_version_id, prompt_identifier, tag) response = await self._client.get(url) response.raise_for_status() - return cast(GetPromptResponseBody, response.json())["data"] + return cast(v1.GetPromptResponseBody, response.json())["data"] def _url( diff --git a/packages/phoenix-client/src/phoenix/client/utils/config.py b/packages/phoenix-client/src/phoenix/client/utils/config.py index 73744e856a..1ff09dac67 100644 --- a/packages/phoenix-client/src/phoenix/client/utils/config.py +++ b/packages/phoenix-client/src/phoenix/client/utils/config.py @@ -1,4 +1,5 @@ import os +from importlib.metadata import version from typing import Optional from phoenix.client.constants import ( @@ -68,3 +69,9 @@ def get_base_url() -> str: host = "127.0.0.1" base_url = get_env_collector_endpoint() or f"http://{host}:{get_env_port()}" return base_url if base_url.endswith("/") else base_url + "/" + + +try: + _PYDANTIC_VERSION = version("pydantic") +except Exception: + _PYDANTIC_VERSION = "" # pyright: ignore[reportConstantRedefinition] diff --git a/packages/phoenix-client/src/phoenix/client/utils/prompt.py b/packages/phoenix-client/src/phoenix/client/utils/prompt.py index e64a8a68f4..ebce4b9186 100644 --- a/packages/phoenix-client/src/phoenix/client/utils/prompt.py +++ b/packages/phoenix-client/src/phoenix/client/utils/prompt.py @@ -13,10 +13,18 @@ to_chat_messages_and_kwargs as to_messages_google_generativeai, # pyright: ignore[reportUnknownVariableType] ) from phoenix.client.helpers.sdk.openai.chat import ( - to_chat_messages_and_kwargs as to_messages_openai, # pyright: ignore[reportUnknownVariableType] + create_prompt_version_from_openai_chat, # pyright: ignore[reportUnknownVariableType] +) +from phoenix.client.helpers.sdk.openai.chat import ( + to_chat_messages_and_kwargs as to_messages_openai, ) from phoenix.client.utils.template_formatters import TemplateFormatter +__all__ = [ + "to_chat_messages_and_kwargs", + "create_prompt_version_from_openai_chat", +] + SDK: TypeAlias = Literal[ "anthropic", # https://pypi.org/project/anthropic/ "google_generativeai", # https://pypi.org/project/google-generativeai/ diff --git a/packages/phoenix-client/src/phoenix/client/utils/template_formatters.py b/packages/phoenix-client/src/phoenix/client/utils/template_formatters.py index 380b5cbadc..3c840c2e6c 100644 --- a/packages/phoenix-client/src/phoenix/client/utils/template_formatters.py +++ b/packages/phoenix-client/src/phoenix/client/utils/template_formatters.py @@ -7,7 +7,7 @@ from typing_extensions import assert_never -from phoenix.client.__generated__.v1 import PromptVersion +from 
phoenix.client.__generated__ import v1 class TemplateFormatter(Protocol): @@ -142,7 +142,7 @@ class TemplateFormatterError(Exception): NO_OP_FORMATTER = NoOpFormatterBase() -def to_formatter(obj: PromptVersion) -> BaseTemplateFormatter: +def to_formatter(obj: v1.PromptVersionData) -> BaseTemplateFormatter: if ( "template_format" not in obj or not obj["template_format"] diff --git a/packages/phoenix-client/tests/canary/sdk/openai/test_chat.py b/packages/phoenix-client/tests/canary/sdk/openai/test_chat.py index 13fb016759..19cd7e6727 100644 --- a/packages/phoenix-client/tests/canary/sdk/openai/test_chat.py +++ b/packages/phoenix-client/tests/canary/sdk/openai/test_chat.py @@ -2,12 +2,15 @@ import json from enum import Enum +from random import randint, random +from secrets import token_hex from typing import Any, Iterable, Mapping, Optional, Union, cast import pytest from deepdiff.diff import DeepDiff from faker import Faker from openai.lib._parsing import type_to_response_format_param +from openai.lib._tools import pydantic_function_tool from openai.types.chat import ( ChatCompletionAssistantMessageParam, ChatCompletionContentPartParam, @@ -21,7 +24,7 @@ ) from openai.types.chat.chat_completion_assistant_message_param import ContentArrayOfContentPart from openai.types.chat.chat_completion_message_tool_call_param import Function -from openai.types.chat.completion_create_params import ResponseFormat +from openai.types.chat.completion_create_params import CompletionCreateParamsBase, ResponseFormat from openai.types.shared_params import FunctionDefinition from pydantic import BaseModel, create_model @@ -29,11 +32,13 @@ from phoenix.client.helpers.sdk.openai.chat import ( _FunctionToolConversion, _MessageConversion, - _ResponseFormatJSONSchemaConversion, + _ResponseFormatConversion, _TextContentPartConversion, _ToolCallContentPartConversion, _ToolKwargs, _ToolKwargsConversion, + create_prompt_version_from_openai_chat, + to_chat_messages_and_kwargs, ) from phoenix.client.utils.template_formatters import NO_OP_FORMATTER @@ -212,6 +217,24 @@ def test_formatter(self) -> None: assert ans["text"] == formatter.format(x["text"]["text"], variables=variables) +class _GetWeather(BaseModel): + city: str + + +class _GetPopulation(BaseModel): + country: str + year: Optional[int] = None + + +_TOOLS = [ + cast( + ChatCompletionToolParam, + json.loads(json.dumps(pydantic_function_tool(t))), + ) + for t in cast(Iterable[type[BaseModel]], [_GetWeather, _GetPopulation]) +] + + class TestToolCallContentPartConversion: def test_round_trip(self) -> None: obj: ChatCompletionMessageToolCallParam = _tool_call() @@ -255,8 +278,8 @@ class TestResponseFormatJSONSchemaConversion: ) def test_round_trip(self, type_: type[BaseModel]) -> None: obj = cast(ResponseFormat, type_to_response_format_param(type_)) - x: v1.PromptResponseFormatJSONSchema = _ResponseFormatJSONSchemaConversion.from_openai(obj) - new_obj = _ResponseFormatJSONSchemaConversion.to_openai(x) + x: v1.PromptResponseFormatJSONSchema = _ResponseFormatConversion.from_openai(obj) + new_obj = _ResponseFormatConversion.to_openai(x) assert not DeepDiff(obj, new_obj) @@ -318,6 +341,69 @@ def test_round_trip(self, obj: _ToolKwargs) -> None: assert not DeepDiff(obj, new_obj) +class TestCompletionCreateParamsBase: + @pytest.mark.parametrize( + "obj", + [ + CompletionCreateParamsBase( + model=token_hex(8), + messages=[ + { + "role": "system", + "content": "You will be provided with statements, and your task is" + "to convert them to standard English.", + }, + {"role": "user", 
"content": "{{ statement }}"}, + ], + temperature=random(), + max_completion_tokens=randint(1, 256), + top_p=random(), + ), + CompletionCreateParamsBase( + model=token_hex(8), + messages=[ + { + "role": "system", + "content": "You are a UI generator. Convert the user input into a UI.", + }, + { + "role": "user", + "content": "Make a form for {{ feature }}.", + }, + ], + response_format=cast( + ResponseFormat, + type_to_response_format_param( + create_model("Response", ui=(_UI, ...)), + ), + ), + temperature=random(), + max_completion_tokens=randint(1, 256), + top_p=random(), + ), + CompletionCreateParamsBase( + model=token_hex(8), + messages=[ + { + "role": "user", + "content": "What is the latest population estimate for {{ location }}?", + } + ], + tools=_TOOLS, + tool_choice="required", + temperature=random(), + max_completion_tokens=randint(1, 256), + top_p=random(), + ), + ], + ) + def test_round_trip(self, obj: CompletionCreateParamsBase) -> None: + pv: v1.PromptVersionData = create_prompt_version_from_openai_chat(obj) + messages, kwargs = to_chat_messages_and_kwargs(pv, formatter=NO_OP_FORMATTER) + new_obj = CompletionCreateParamsBase(messages=messages, **kwargs) # type: ignore[typeddict-item] + assert not DeepDiff(obj, new_obj) + + class _MockFormatter: def format(self, _: str, /, *, variables: Mapping[str, str]) -> str: return json.dumps(variables) diff --git a/schemas/openapi.json b/schemas/openapi.json index a904a0c27a..5a0cbd3803 100644 --- a/schemas/openapi.json +++ b/schemas/openapi.json @@ -1264,6 +1264,55 @@ "description": "Unprocessable Entity" } } + }, + "post": { + "tags": [ + "prompts" + ], + "summary": "Create a prompt version", + "operationId": "postPromptVersion", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreatePromptRequestBody" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreatePromptResponseBody" + } + } + } + }, + "403": { + "content": { + "text/plain": { + "schema": { + "type": "string" + } + } + }, + "description": "Forbidden" + }, + "422": { + "content": { + "text/plain": { + "schema": { + "type": "string" + } + } + }, + "description": "Unprocessable Entity" + } + } } }, "/v1/prompts/{prompt_identifier}/versions": { @@ -1662,6 +1711,34 @@ ], "title": "CreateExperimentResponseBody" }, + "CreatePromptRequestBody": { + "properties": { + "prompt": { + "$ref": "#/components/schemas/PromptData" + }, + "version": { + "$ref": "#/components/schemas/PromptVersionData" + } + }, + "type": "object", + "required": [ + "prompt", + "version" + ], + "title": "CreatePromptRequestBody" + }, + "CreatePromptResponseBody": { + "properties": { + "data": { + "$ref": "#/components/schemas/PromptVersion" + } + }, + "type": "object", + "required": [ + "data" + ], + "title": "CreatePromptResponseBody" + }, "Dataset": { "properties": { "id": { @@ -1977,6 +2054,11 @@ "type": "object", "title": "HTTPValidationError" }, + "Identifier": { + "type": "string", + "pattern": "^[a-z0-9]([_a-z0-9-]*[a-z0-9])?$", + "title": "Identifier" + }, "InsertedSpanAnnotation": { "properties": { "id": { @@ -2132,11 +2214,10 @@ }, "Prompt": { "properties": { - "id": { - "type": "string", - "title": "Id" + "name": { + "$ref": "#/components/schemas/Identifier" }, - "source_prompt_id": { + "description": { "anyOf": [ { "type": "string" @@ -2145,13 +2226,9 @@ "type": "null" } ], - "title": "Source Prompt Id" - }, 
- "name": { - "type": "string", - "title": "Name" + "title": "Description" }, - "description": { + "source_prompt_id": { "anyOf": [ { "type": "string" @@ -2160,15 +2237,17 @@ "type": "null" } ], - "title": "Description" + "title": "Source Prompt Id" + }, + "id": { + "type": "string", + "title": "Id" } }, "type": "object", "required": [ - "id", - "source_prompt_id", "name", - "description" + "id" ], "title": "Prompt" }, @@ -2302,6 +2381,40 @@ ], "title": "PromptChatTemplate" }, + "PromptData": { + "properties": { + "name": { + "$ref": "#/components/schemas/Identifier" + }, + "description": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Description" + }, + "source_prompt_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Source Prompt Id" + } + }, + "type": "object", + "required": [ + "name" + ], + "title": "PromptData" + }, "PromptFunctionTool": { "properties": { "type": { @@ -2722,12 +2835,133 @@ }, "PromptVersion": { "properties": { + "description": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Description" + }, + "model_provider": { + "$ref": "#/components/schemas/ModelProvider" + }, + "model_name": { + "type": "string", + "title": "Model Name" + }, + "template": { + "oneOf": [ + { + "$ref": "#/components/schemas/PromptChatTemplate" + }, + { + "$ref": "#/components/schemas/PromptStringTemplate" + } + ], + "title": "Template", + "discriminator": { + "propertyName": "type", + "mapping": { + "chat": "#/components/schemas/PromptChatTemplate", + "string": "#/components/schemas/PromptStringTemplate" + } + } + }, + "template_type": { + "$ref": "#/components/schemas/PromptTemplateType" + }, + "template_format": { + "$ref": "#/components/schemas/PromptTemplateFormat" + }, + "invocation_parameters": { + "oneOf": [ + { + "$ref": "#/components/schemas/PromptOpenAIInvocationParameters" + }, + { + "$ref": "#/components/schemas/PromptAzureOpenAIInvocationParameters" + }, + { + "$ref": "#/components/schemas/PromptAnthropicInvocationParameters" + }, + { + "$ref": "#/components/schemas/PromptGeminiInvocationParameters" + } + ], + "title": "Invocation Parameters", + "discriminator": { + "propertyName": "type", + "mapping": { + "anthropic": "#/components/schemas/PromptAnthropicInvocationParameters", + "azure_openai": "#/components/schemas/PromptAzureOpenAIInvocationParameters", + "gemini": "#/components/schemas/PromptGeminiInvocationParameters", + "openai": "#/components/schemas/PromptOpenAIInvocationParameters" + } + } + }, + "tools": { + "anyOf": [ + { + "$ref": "#/components/schemas/PromptTools" + }, + { + "type": "null" + } + ] + }, + "response_format": { + "anyOf": [ + { + "oneOf": [ + { + "$ref": "#/components/schemas/PromptResponseFormatJSONSchema" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "response-format-json-schema": "#/components/schemas/PromptResponseFormatJSONSchema" + } + } + }, + { + "type": "null" + } + ], + "title": "Response Format" + }, "id": { "type": "string", "title": "Id" - }, + } + }, + "type": "object", + "required": [ + "model_provider", + "model_name", + "template", + "template_type", + "template_format", + "invocation_parameters", + "id" + ], + "title": "PromptVersion" + }, + "PromptVersionData": { + "properties": { "description": { - "type": "string", + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], "title": "Description" }, "model_provider": { @@ -2821,8 +3055,6 @@ }, "type": "object", "required": [ - "id", - 
"description", "model_provider", "model_name", "template", @@ -2830,7 +3062,7 @@ "template_format", "invocation_parameters" ], - "title": "PromptVersion" + "title": "PromptVersionData" }, "SpanAnnotation": { "properties": { diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index 972d114046..5d23eb7c51 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -47,7 +47,9 @@ PromptResponseFormat, PromptResponseFormatRootModel, PromptTemplate, + PromptTemplateFormat, PromptTemplateRootModel, + PromptTemplateType, PromptTools, is_prompt_invocation_parameters, is_prompt_template, @@ -212,6 +214,40 @@ def process_result_value( ) +class _PromptTemplateType(TypeDecorator[PromptTemplateType]): + # See # See https://docs.sqlalchemy.org/en/20/core/custom_types.html + cache_ok = True + impl = String + + def process_bind_param(self, value: Optional[PromptTemplateType], _: Dialect) -> Optional[str]: + if isinstance(value, str): + return PromptTemplateType(value).value + return None if value is None else value.value + + def process_result_value( + self, value: Optional[str], _: Dialect + ) -> Optional[PromptTemplateType]: + return None if value is None else PromptTemplateType(value) + + +class _PromptTemplateFormat(TypeDecorator[PromptTemplateFormat]): + # See # See https://docs.sqlalchemy.org/en/20/core/custom_types.html + cache_ok = True + impl = String + + def process_bind_param( + self, value: Optional[PromptTemplateFormat], _: Dialect + ) -> Optional[str]: + if isinstance(value, str): + return PromptTemplateFormat(value).value + return None if value is None else value.value + + def process_result_value( + self, value: Optional[str], _: Dialect + ) -> Optional[PromptTemplateFormat]: + return None if value is None else PromptTemplateFormat(value) + + class ExperimentRunOutput(TypedDict, total=False): task_output: Any @@ -1013,13 +1049,13 @@ class PromptVersion(Base): index=True, nullable=True, ) - template_type: Mapped[str] = mapped_column( - String, + template_type: Mapped[PromptTemplateType] = mapped_column( + _PromptTemplateType, CheckConstraint("template_type IN ('CHAT', 'STR')", name="template_type"), nullable=False, ) - template_format: Mapped[str] = mapped_column( - String, + template_format: Mapped[PromptTemplateFormat] = mapped_column( + _PromptTemplateFormat, CheckConstraint( "template_format IN ('FSTRING', 'MUSTACHE', 'NONE')", name="template_format" ), diff --git a/src/phoenix/server/api/openapi/schema.py b/src/phoenix/server/api/openapi/schema.py index df9aa35e85..bf25797a11 100644 --- a/src/phoenix/server/api/openapi/schema.py +++ b/src/phoenix/server/api/openapi/schema.py @@ -13,4 +13,5 @@ def get_openapi_schema() -> dict[str, Any]: openapi_version="3.1.0", description="Schema for Arize-Phoenix REST API", routes=v1_router.routes, + separate_input_output_schemas=False, ) diff --git a/src/phoenix/server/api/routers/v1/models.py b/src/phoenix/server/api/routers/v1/models.py index e430676ea7..eb80802768 100644 --- a/src/phoenix/server/api/routers/v1/models.py +++ b/src/phoenix/server/api/routers/v1/models.py @@ -36,4 +36,7 @@ def datetime_encoder(dt: datetime) -> str: class V1RoutesBaseModel(BaseModel): - model_config = ConfigDict({"json_encoders": {datetime: datetime_encoder}}) + model_config = ConfigDict( + json_encoders={datetime: datetime_encoder}, + validate_assignment=True, + ) diff --git a/src/phoenix/server/api/routers/v1/prompts.py b/src/phoenix/server/api/routers/v1/prompts.py index d5bca9d0dc..bd641a6592 100644 --- 
a/src/phoenix/server/api/routers/v1/prompts.py +++ b/src/phoenix/server/api/routers/v1/prompts.py @@ -30,16 +30,18 @@ logger = logging.getLogger(__name__) -class Prompt(V1RoutesBaseModel): - id: str - source_prompt_id: Optional[str] - name: str - description: Optional[str] +class PromptData(V1RoutesBaseModel): + name: Identifier + description: Optional[str] = None + source_prompt_id: Optional[str] = None -class PromptVersion(V1RoutesBaseModel): +class Prompt(PromptData): id: str - description: str + + +class PromptVersionData(V1RoutesBaseModel): + description: Optional[str] = None model_provider: ModelProvider model_name: str template: PromptTemplate @@ -50,6 +52,10 @@ class PromptVersion(V1RoutesBaseModel): response_format: Optional[PromptResponseFormat] = None +class PromptVersion(PromptVersionData): + id: str + + class GetPromptResponseBody(ResponseBody[PromptVersion]): pass @@ -62,6 +68,15 @@ class GetPromptVersionsResponseBody(ResponseBody[list[PromptVersion]]): pass +class CreatePromptRequestBody(V1RoutesBaseModel): + prompt: PromptData + version: PromptVersionData + + +class CreatePromptResponseBody(ResponseBody[PromptVersion]): + pass + + router = APIRouter(tags=["prompts"]) @@ -266,6 +281,58 @@ async def get_prompt_version_by_latest( return GetPromptResponseBody(data=data) +@router.post( + "/prompts", + operation_id="postPromptVersion", + summary="Create a prompt version", + responses=add_errors_to_responses( + [ + HTTP_422_UNPROCESSABLE_ENTITY, + ] + ), + response_model_by_alias=True, + response_model_exclude_defaults=True, + response_model_exclude_unset=True, +) +async def create_prompt( + request: Request, + request_body: CreatePromptRequestBody, +) -> CreatePromptResponseBody: + prompt = request_body.prompt + try: + name = Identifier.model_validate(prompt.name) + except ValidationError as e: + raise HTTPException( + HTTP_422_UNPROCESSABLE_ENTITY, + "Invalid name identifier for prompt: " + e.errors()[0]["msg"], + ) + version = request_body.version + async with request.app.state.db() as session: + if not (prompt_id := await session.scalar(select(models.Prompt.id).filter_by(name=name))): + prompt_orm = models.Prompt( + name=name, + description=prompt.description, + ) + session.add(prompt_orm) + await session.flush() + prompt_id = prompt_orm.id + version_orm = models.PromptVersion( + prompt_id=prompt_id, + description=version.description, + model_provider=version.model_provider, + model_name=version.model_name, + template_type=version.template_type, + template_format=version.template_format, + template=version.template, + invocation_parameters=version.invocation_parameters, + tools=version.tools, + response_format=version.response_format, + ) + session.add(version_orm) + data = _prompt_version_from_orm_version(version_orm) + return CreatePromptResponseBody(data=data) + + class _PromptId(int): ... 
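For orientation, a minimal client-side sketch of what this new endpoint enables, mirroring the calls used in the integration tests and the tutorial notebook later in this patch; the prompt name and message content below are placeholder values, and the Phoenix base URL is whatever the client is configured with:

    # Sketch: create a prompt version via the Python client, which POSTs
    # CreatePromptRequestBody (prompt + version) to /v1/prompts.
    from openai.types.chat.completion_create_params import CompletionCreateParamsBase
    from phoenix.client import Client
    from phoenix.client.helpers.sdk.openai.chat import create_prompt_version_from_openai_chat

    created = Client().prompts.create(
        # Placeholder name; must satisfy the Identifier pattern ^[a-z0-9]([_a-z0-9-]*[a-z0-9])?$
        name="my-eval-prompt",
        version=create_prompt_version_from_openai_chat(
            CompletionCreateParamsBase(
                model="gpt-4o-mini",
                messages=[{"role": "user", "content": "Summarize: {{ text }}"}],
            )
        ),
    )
    # The response body mirrors CreatePromptResponseBody: a PromptVersion with a server-assigned id.
    print(created["id"])
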
@@ -321,13 +388,13 @@ def _prompt_version_from_orm_version( ) -def _prompt_from_orm_prompt(orm_prompt: models.Prompt) -> Prompt: +def _prompt_from_orm_prompt(orm_prompt: models.Prompt) -> PromptData: source_prompt_id = ( str(GlobalID(PromptNodeType.__name__, str(orm_prompt.source_prompt_id))) if orm_prompt.source_prompt_id else None ) - return Prompt( + return PromptData( id=str(GlobalID(PromptNodeType.__name__, str(orm_prompt.id))), source_prompt_id=source_prompt_id, name=orm_prompt.name, diff --git a/tests/integration/prompts/test_prompts.py b/tests/integration/prompts/test_prompts.py index e098f65729..46d652f588 100644 --- a/tests/integration/prompts/test_prompts.py +++ b/tests/integration/prompts/test_prompts.py @@ -24,7 +24,7 @@ ChatCompletionToolParam, ) from openai.types.shared_params import ResponseFormatJSONSchema -from phoenix.client.__generated__.v1 import PromptVersion +from phoenix.client.__generated__ import v1 from phoenix.client.utils import to_chat_messages_and_kwargs from pydantic import BaseModel, create_model @@ -54,6 +54,7 @@ def test_user_message( prompt = _create_chat_prompt(u, template_format="FSTRING") messages, _ = to_chat_messages_and_kwargs(prompt, variables={"x": x}) assert not DeepDiff(expected, messages) + _can_recreate_under_new_identifier(prompt) class _GetWeather(BaseModel): @@ -97,6 +98,7 @@ def test_openai( if t["type"] == "function" and "parameters" in t["function"] } assert not DeepDiff(expected, actual) + _can_recreate_under_new_identifier(prompt) @pytest.mark.parametrize( "types_", @@ -132,6 +134,7 @@ def test_anthropic( assert not DeepDiff(expected, actual) assert "max_tokens" in kwargs assert kwargs["max_tokens"] == 1024 + _can_recreate_under_new_identifier(prompt) class TestToolChoice: @@ -162,6 +165,7 @@ def test_openai( assert "tool_choice" in kwargs actual = kwargs["tool_choice"] assert not DeepDiff(expected, actual) + _can_recreate_under_new_identifier(prompt) @pytest.mark.parametrize( "expected", @@ -198,6 +202,7 @@ def test_anthropic( assert not DeepDiff(expected, actual) assert "max_tokens" in kwargs assert kwargs["max_tokens"] == 1024 + _can_recreate_under_new_identifier(prompt) class _UIType(str, Enum): @@ -246,6 +251,22 @@ def test_openai( assert "response_format" in kwargs actual = kwargs["response_format"] assert not DeepDiff(expected, actual) + _can_recreate_under_new_identifier(prompt) + + +def _can_recreate_under_new_identifier(version: v1.PromptVersion) -> None: + new_name = token_hex(8) + a = px.Client().prompts.create(name=new_name, version=version) + assert version["id"] != a["id"] + expected = {**version, "id": ""} + assert not DeepDiff(expected, {**a, "id": ""}) + b = px.Client().prompts.get(prompt_identifier=new_name) + assert a["id"] == b["id"] + assert not DeepDiff(expected, {**b, "id": ""}) + same_name = new_name + c = px.Client().prompts.create(name=same_name, version=version) + assert a["id"] != c["id"] + assert not DeepDiff(expected, {**c, "id": ""}) def _create_chat_prompt( @@ -259,7 +280,7 @@ def _create_chat_prompt( tools: Sequence[ToolDefinitionInput] = (), invocation_parameters: Mapping[str, Any] = MappingProxyType({}), template_format: Literal["FSTRING", "MUSTACHE", "NONE"] = "NONE", -) -> PromptVersion: +) -> v1.PromptVersion: messages = list(messages) or [ PromptMessageInput( role="USER", diff --git a/tox.ini b/tox.ini index 31e148bff4..6f76c4dd1c 100644 --- a/tox.ini +++ b/tox.ini @@ -295,13 +295,15 @@ commands = [testenv:openapi_codegen_for_python_client] description = Generate data models from OpenAPI schema 
for Python client changedir = packages/phoenix-client/src/phoenix/client/__generated__/ +commands_pre = + uv tool install ruff@0.8.6 commands = uv pip list -v python -c "import pathlib; pathlib.Path('v1/__init__.py').unlink(missing_ok=True)" uv tool run --from datamodel-code-generator datamodel-codegen \ --input {toxinidir}/schemas/openapi.json \ --input-file-type openapi \ - --output v1/__init__.py \ + --output v1/.dataclass.txt \ --output-model-type dataclasses.dataclass \ --collapse-root-models \ --enum-field-as-literal all \ @@ -311,11 +313,31 @@ commands = --use-generic-container-types \ --wrap-string-literal \ --disable-timestamp - python -c "import re; file = 'v1/__init__.py'; lines = [re.sub(r'\\bSequence]', 'Sequence[Any]]', line) for line in open(file).readlines()]; open(file, 'w').writelines(lines)" - python {toxinidir}/packages/phoenix-client/scripts/codegen/transform.py v1/__init__.py + uv tool run --from datamodel-code-generator datamodel-codegen \ + --input {toxinidir}/schemas/openapi.json \ + --input-file-type openapi \ + --output v1/.pydantic.txt \ + --output-model-type pydantic_v2.BaseModel \ + --collapse-root-models \ + --enum-field-as-literal all \ + --target-python-version 3.9 \ + --use-annotated \ + --use-default-kwarg \ + --use-double-quotes \ + --use-generic-container-types \ + --wrap-string-literal \ + --disable-timestamp + python -c "import re; file = 'v1/.pydantic.txt'; lines = [re.sub(r'\\bSequence]', 'Sequence[Any]]', line) for line in open(file).readlines()]; open(file, 'w').writelines(lines)" + python -c "import re; file = 'v1/.dataclass.txt'; lines = [re.sub(r'\\bSequence]', 'Sequence[Any]]', line) for line in open(file).readlines()]; open(file, 'w').writelines(lines)" + python {toxinidir}/packages/phoenix-client/scripts/codegen/transform.py v1 uv pip install --strict --reinstall-package arize-phoenix-client {toxinidir}/packages/phoenix-client + uv pip uninstall pydantic uv pip list -v python -c "import phoenix.client.__generated__.v1" + uv pip install --strict -U pydantic + python -c "import phoenix.client.__generated__.v1.models" + uv tool run ruff format v1 + uv tool run ruff check --fix v1 [testenv:graphql_codegen_for_python_tests] description = Generate data models from GraphQL schema for Python tests diff --git a/tutorials/internal/prompts/hallucination_eval.ipynb b/tutorials/internal/prompts/hallucination_eval.ipynb index 6528cf6c10..4e0945afae 100644 --- a/tutorials/internal/prompts/hallucination_eval.ipynb +++ b/tutorials/internal/prompts/hallucination_eval.ipynb @@ -15,12 +15,14 @@ "import openai\n", "import pandas as pd\n", "from dotenv import load_dotenv\n", + "from openai.types.chat.completion_create_params import CompletionCreateParamsBase\n", "from openinference.instrumentation.groq import GroqInstrumentor\n", "from openinference.instrumentation.openai import OpenAIInstrumentor\n", "from sklearn.metrics import accuracy_score\n", "\n", "import phoenix as px\n", "from phoenix.client import Client\n", + "from phoenix.client.helpers.sdk.openai.chat import create_prompt_version_from_openai_chat\n", "from phoenix.client.utils import to_chat_messages_and_kwargs\n", "from phoenix.experiments import run_experiment\n", "from phoenix.otel import register\n", @@ -74,92 +76,157 @@ }, { "cell_type": "markdown", - "id": "88dc3cc5", + "id": "d0ea3a27", "metadata": {}, "source": [ - "# Get Prompt" + "# Upload Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1cdbed73", + "metadata": {}, + "outputs": [], + "source": [ + 
"dataset_name = \"hallu-eval-\" + token_hex(4) # adding a random suffix for demo purposes\n", + "\n", + "ds = px.Client().upload_dataset(\n", + " dataframe=df,\n", + " dataset_name=dataset_name,\n", + " input_keys=[\"question\", \"knowledge\", \"answer\"],\n", + " output_keys=[\"true_label\"],\n", + ")" ] }, { "cell_type": "markdown", - "id": "14b09dc0", + "id": "eed1a954b3859891", "metadata": {}, "source": [ - "https://github.com/Arize-ai/phoenix/blob/390cfaa42c5b2c28d3f9f83fbf7c694b8c2beeff/packages/phoenix-evals/src/phoenix/evals/default_templates.py#L56" + "# Create Prompt\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "593ec84e", + "id": "e7c810e0b6cdf90", "metadata": {}, "outputs": [], "source": [ - "prompt = Client().prompts.get(prompt_identifier=\"hallu-eval\")" + "prompt_name = f\"hallu-eval-{token_hex(4)}\" # adding a random suffix for demo purposes" ] }, { "cell_type": "markdown", - "id": "cb0e0cba", + "id": "b16fb9c96b819d4e", "metadata": {}, "source": [ - "# GPT 4o Mini" + "Send this [prompt](https://github.com/Arize-ai/phoenix/blob/390cfaa42c5b2c28d3f9f83fbf7c694b8c2beeff/packages/phoenix-evals/src/phoenix/evals/default_templates.py#L56) to Phoenix." ] }, { "cell_type": "code", "execution_count": null, - "id": "d0c38fc6", + "id": "af22e707bedaff64", "metadata": {}, "outputs": [], "source": [ - "def openai_eval(input):\n", - " messages, kwargs = to_chat_messages_and_kwargs(prompt, variables=dict(input))\n", - " response = openai.OpenAI().chat.completions.create(messages=messages, **kwargs)\n", - " return {\"label\": response.choices[0].message.content}" + "content = \"\"\"\\\n", + "In this task, you will be presented with a query, a reference text and an answer. The answer is\n", + "generated to the question based on the reference text. The answer may contain false information. You\n", + "must use the reference text to determine if the answer to the question contains false information,\n", + "if the answer is a hallucination of facts. Your objective is to determine whether the answer text\n", + "contains factual information and is not a hallucination. A 'hallucination' refers to\n", + "an answer that is not based on the reference text or assumes information that is not available in\n", + "the reference text. Your response should be a single word: either \"factual\" or \"hallucinated\", and\n", + "it should not include any other text or characters. \"hallucinated\" indicates that the answer\n", + "provides factually inaccurate information to the query based on the reference text. \"factual\"\n", + "indicates that the answer to the question is correct relative to the reference text, and does not\n", + "contain made up information. 
Please read the query and reference text carefully before determining\n", + "your response.\n", + "\n", + "[BEGIN DATA]\n", + "************\n", + "[Query]: {{ question }}\n", + "************\n", + "[Reference text]: {{ knowledge }}\n", + "************\n", + "[Answer]: {{ answer }}\n", + "************\n", + "[END DATA]\n", + "\n", + "Is the answer above factual or hallucinated based on the query and reference text?\n", + "\"\"\"\n", + "_ = Client().prompts.create(\n", + " name=prompt_name,\n", + " version=create_prompt_version_from_openai_chat(\n", + " CompletionCreateParamsBase(\n", + " messages=[{\"role\": \"user\", \"content\": content}],\n", + " model=\"gpt-4o-mini\",\n", + " temperature=1,\n", + " )\n", + " ),\n", + ")" ] }, { "cell_type": "markdown", - "id": "4b137880", + "id": "88dc3cc5", "metadata": {}, "source": [ - "### DataFrame Apply" + "# Get Prompt" ] }, { "cell_type": "code", "execution_count": null, - "id": "bd9bdca3", + "id": "593ec84e", "metadata": {}, "outputs": [], "source": [ - "gpt_result = pd.concat([pd.json_normalize(df.apply(openai_eval, axis=1)), df.true_label], axis=1)\n", - "print(f\"Accuracy: {accuracy_score(gpt_result.true_label, gpt_result.label) * 100:.0f}%\")\n", - "gpt_result" + "prompt = Client().prompts.get(prompt_identifier=prompt_name)" ] }, { "cell_type": "markdown", - "id": "d0ea3a27", + "id": "cb0e0cba", "metadata": {}, "source": [ - "# Upload Dataset" + "# GPT 4o Mini" ] }, { "cell_type": "code", "execution_count": null, - "id": "1cdbed73", + "id": "d0c38fc6", "metadata": {}, "outputs": [], "source": [ - "ds = px.Client().upload_dataset(\n", - " dataframe=df,\n", - " dataset_name=\"hallu-eval-\" + token_hex(),\n", - " input_keys=[\"question\", \"knowledge\", \"answer\"],\n", - " output_keys=[\"true_label\"],\n", - ")" + "def openai_eval(input):\n", + " messages, kwargs = to_chat_messages_and_kwargs(prompt, variables=dict(input))\n", + " response = openai.OpenAI().chat.completions.create(messages=messages, **kwargs)\n", + " return {\"label\": response.choices[0].message.content}" + ] + }, + { + "cell_type": "markdown", + "id": "4b137880", + "metadata": {}, + "source": [ + "### DataFrame Apply" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bd9bdca3", + "metadata": {}, + "outputs": [], + "source": [ + "gpt_result = pd.concat([pd.json_normalize(df.apply(openai_eval, axis=1)), df.true_label], axis=1)\n", + "print(f\"Accuracy: {accuracy_score(gpt_result.true_label, gpt_result.label) * 100:.0f}%\")\n", + "gpt_result" ] }, {