Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Experiment and dataset improvements #6163

Merged
merged 15 commits
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ pg = [
"psycopg[binary,pool]",
]
container = [
"anthropic",
"anthropic>=0.45.2",
"google-generativeai",
"prometheus-client",
"openai>=1.0.0",
Expand Down
2 changes: 1 addition & 1 deletion requirements/dev.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
asyncpg
openai
anthropic
anthropic>=0.45.2
google-generativeai
psycopg[binary,pool]
uvloop; platform_system != 'Windows'
Expand Down
2 changes: 1 addition & 1 deletion requirements/type-check.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-r ci.txt
anthropic
anthropic>=0.45.2
asyncpg
grpcio
litellm>=1.0.3
Expand Down
2 changes: 1 addition & 1 deletion requirements/unit-tests.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-r ci.txt
anthropic
anthropic>=0.45.2
Faker>=30.1.0
arize
asgi-lifespan
Expand Down
14 changes: 14 additions & 0 deletions src/phoenix/server/api/helpers/dataset_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
OpenInferenceMimeTypeValues,
OpenInferenceSpanKindValues,
SpanAttributes,
ToolAttributes,
ToolCallAttributes,
)

Expand All @@ -27,12 +28,18 @@ def get_dataset_example_input(span: Span) -> dict[str, Any]:
input_mime_type = get_attribute_value(attributes, INPUT_MIME_TYPE)
prompt_template_variables = get_attribute_value(attributes, LLM_PROMPT_TEMPLATE_VARIABLES)
input_messages = get_attribute_value(attributes, LLM_INPUT_MESSAGES)
tool_definitions = []
if tools := get_attribute_value(attributes, LLM_TOOLS):
for tool in tools:
if definition := get_attribute_value(tool, TOOL_DEFINITION):
tool_definitions.append(definition)
if span_kind == LLM:
return _get_llm_span_input(
input_messages=input_messages,
input_value=input_value,
input_mime_type=input_mime_type,
prompt_template_variables=prompt_template_variables,
tools=tool_definitions,
)
return _get_generic_io_value(io_value=input_value, mime_type=input_mime_type, kind="input")

Expand Down Expand Up @@ -71,6 +78,7 @@ def _get_llm_span_input(
input_value: Any,
input_mime_type: Optional[str],
prompt_template_variables: Any,
tools: Any,
) -> dict[str, Any]:
"""
Extracts the input value from an LLM span and returns it as a dictionary.
Expand All @@ -84,6 +92,8 @@ def _get_llm_span_input(
input = _get_generic_io_value(io_value=input_value, mime_type=input_mime_type, kind="input")
if prompt_template_variables_data := _safely_json_decode(prompt_template_variables):
input["prompt_template_variables"] = prompt_template_variables_data
if tool_definitions_data := [_safely_json_decode(tool_definition) for tool_definition in tools]:
input["tools"] = tool_definitions_data
return input


Expand Down Expand Up @@ -215,3 +225,7 @@ def _safely_json_decode(value: Any) -> Any:
# ToolCallAttributes
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME

# ToolAttributes
LLM_TOOLS = SpanAttributes.LLM_TOOLS
TOOL_DEFINITION = ToolAttributes.TOOL_JSON_SCHEMA
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can rename this to TOOL_JSON_SCHEMA for consistency with the attribute's name in the semantic convention.

2 changes: 2 additions & 0 deletions src/phoenix/server/api/helpers/playground_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,8 @@ async def chat_completion_create(
pass
elif isinstance(event, anthropic_streaming.InputJsonEvent):
raise NotImplementedError
elif isinstance(event, anthropic_streaming._types.CitationEvent):
raise NotImplementedError
else:
assert_never(event)

Expand Down
8 changes: 1 addition & 7 deletions src/phoenix/server/api/mutations/chat_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
from phoenix.server.api.input_types.TemplateOptions import TemplateOptions
from phoenix.server.api.subscriptions import (
_default_playground_experiment_description,
_default_playground_experiment_metadata,
_default_playground_experiment_name,
)
from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
Expand Down Expand Up @@ -183,12 +182,7 @@ async def chat_completion_over_dataset(
description=input.experiment_description
or _default_playground_experiment_description(dataset_name=dataset.name),
repetitions=1,
metadata_=input.experiment_metadata
or _default_playground_experiment_metadata(
dataset_name=dataset.name,
dataset_id=input.dataset_id,
version_id=GlobalID(DatasetVersion.__name__, str(resolved_version_id)),
),
metadata_=input.experiment_metadata or dict(),
project_name=PLAYGROUND_PROJECT_NAME,
)
session.add(experiment)
Expand Down
21 changes: 21 additions & 0 deletions src/phoenix/server/api/mutations/dataset_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

import strawberry
from openinference.semconv.trace import (
MessageAttributes,
MessageContentAttributes,
SpanAttributes,
ToolAttributes,
ToolCallAttributes,
)
from sqlalchemy import and_, delete, distinct, func, insert, select, update
from strawberry import UNSET
Expand Down Expand Up @@ -181,6 +185,17 @@ async def add_spans_to_dataset(
assert all(map(lambda id: isinstance(id, int), dataset_example_rowids))
DatasetExampleRevision = models.DatasetExampleRevision

all_span_attributes = {
**SpanAttributes.__dict__,
**MessageAttributes.__dict__,
**MessageContentAttributes.__dict__,
**ToolCallAttributes.__dict__,
**ToolAttributes.__dict__,
}
nonprivate_span_attributes = {
k: v for k, v in all_span_attributes.items() if not k.startswith("_")
}

await session.execute(
insert(DatasetExampleRevision),
[
Expand All @@ -190,6 +205,12 @@ async def add_spans_to_dataset(
DatasetExampleRevision.input.key: get_dataset_example_input(span),
DatasetExampleRevision.output.key: get_dataset_example_output(span),
DatasetExampleRevision.metadata_.key: {
**(span.attributes.get(SpanAttributes.METADATA) or dict()),
**{
k: v
for k, v in span.attributes.items()
if k in nonprivate_span_attributes
},
"span_kind": span.span_kind,
**(
{"annotations": annotations}
Expand Down
31 changes: 17 additions & 14 deletions src/phoenix/server/api/routers/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,20 +919,23 @@ def _get_content_csv(examples: list[models.DatasetExampleRevision]) -> bytes:
def _get_content_jsonl_openai_ft(examples: list[models.DatasetExampleRevision]) -> bytes:
records = io.BytesIO()
for ex in examples:
records.write(
(
json.dumps(
{
"messages": (
ims if isinstance(ims := ex.input.get("messages"), list) else []
)
+ (oms if isinstance(oms := ex.output.get("messages"), list) else [])
},
ensure_ascii=False,
)
+ "\n"
).encode()
)
input_messages = ex.input.get("messages", [])
if not isinstance(input_messages, list):
input_messages = []
output_messages = ex.output.get("messages", [])
if not isinstance(output_messages, list):
output_messages = []

record_dict = {
"messages": input_messages + output_messages,
}

tools = ex.input.get("tools", [])
if tools:
record_dict["tools"] = tools

records.write((json.dumps(record_dict, ensure_ascii=False) + "\n").encode())

records.seek(0)
return records.read()

Expand Down
17 changes: 1 addition & 16 deletions src/phoenix/server/api/subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,7 @@ async def chat_completion_over_dataset(
description=input.experiment_description
or _default_playground_experiment_description(dataset_name=dataset.name),
repetitions=1,
metadata_=input.experiment_metadata
or _default_playground_experiment_metadata(
dataset_name=dataset.name,
dataset_id=input.dataset_id,
version_id=GlobalID(DatasetVersion.__name__, str(resolved_version_id)),
),
metadata_=input.experiment_metadata or dict(),
project_name=PLAYGROUND_PROJECT_NAME,
)
session.add(experiment)
Expand Down Expand Up @@ -581,16 +576,6 @@ def _default_playground_experiment_description(dataset_name: str) -> str:
return f'Playground experiment for dataset "{dataset_name}"'


def _default_playground_experiment_metadata(
dataset_name: str, dataset_id: GlobalID, version_id: GlobalID
) -> dict[str, Any]:
return {
"dataset_name": dataset_name,
"dataset_id": str(dataset_id),
"dataset_version_id": str(version_id),
}


LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
Expand Down
6 changes: 1 addition & 5 deletions tests/unit/server/api/test_subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1206,12 +1206,8 @@ async def test_emits_expected_payloads_and_records_expected_spans_and_experiment
assert experiment.pop("name") == "playground-experiment"
assert isinstance(experiment_description := experiment.pop("description"), str)
assert "dataset-name" in experiment_description
assert experiment.pop("metadata") == {
"dataset_name": "dataset-name",
"dataset_id": str(dataset_id),
"dataset_version_id": str(version_id),
}
assert experiment.pop("projectName") == "playground"
assert experiment.pop("metadata") == {}
assert isinstance(created_at := experiment.pop("createdAt"), str)
assert isinstance(updated_at := experiment.pop("updatedAt"), str)
assert created_at == updated_at
Expand Down