feat: Experiment and dataset improvements #6163

Merged (15 commits) on Feb 3, 2025
Changes from 10 commits
16 changes: 16 additions & 0 deletions src/phoenix/server/api/helpers/dataset_helpers.py
@@ -7,6 +7,7 @@
OpenInferenceMimeTypeValues,
OpenInferenceSpanKindValues,
SpanAttributes,
ToolAttributes,
ToolCallAttributes,
)

@@ -27,12 +28,18 @@ def get_dataset_example_input(span: Span) -> dict[str, Any]:
input_mime_type = get_attribute_value(attributes, INPUT_MIME_TYPE)
prompt_template_variables = get_attribute_value(attributes, LLM_PROMPT_TEMPLATE_VARIABLES)
input_messages = get_attribute_value(attributes, LLM_INPUT_MESSAGES)
tool_definitions = []
if tools := get_attribute_value(attributes, LLM_TOOLS):
for tool in tools:
if definition := get_attribute_value(tool, TOOL_DEFINITION):
tool_definitions.append(definition)
if span_kind == LLM:
return _get_llm_span_input(
input_messages=input_messages,
input_value=input_value,
input_mime_type=input_mime_type,
prompt_template_variables=prompt_template_variables,
tool_definitions=tool_definitions,
Contributor: Can we double-check that this works with the OpenAI fine-tuning format? Also, it might just be simpler if the key were tools.

)
return _get_generic_io_value(io_value=input_value, mime_type=input_mime_type, kind="input")
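For context, here is a minimal runnable sketch of what the tool-definition extraction above does. The nested attribute layout (llm.tools as a list of objects each carrying tool.json_schema) and the simplified get_attribute_value helper are assumptions for illustration, not the exact Phoenix internals:

```python
# Hypothetical span attributes for an LLM span traced with two tools. The nested
# layout ("llm.tools" -> list of {"tool": {"json_schema": ...}}) is an assumption
# based on the OpenInference attribute names used in this diff.
attributes = {
    "llm": {
        "tools": [
            {"tool": {"json_schema": '{"type": "function", "function": {"name": "get_weather"}}'}},
            {"tool": {"json_schema": '{"type": "function", "function": {"name": "get_time"}}'}},
        ],
    },
}


def get_attribute_value(mapping, dotted_key):
    # Simplified stand-in for Phoenix's get_attribute_value: walk a dotted key
    # through nested dictionaries, returning None if any segment is missing.
    for part in dotted_key.split("."):
        if not isinstance(mapping, dict) or part not in mapping:
            return None
        mapping = mapping[part]
    return mapping


tool_definitions = []
if tools := get_attribute_value(attributes, "llm.tools"):  # LLM_TOOLS
    for tool in tools:
        if definition := get_attribute_value(tool, "tool.json_schema"):  # TOOL_DEFINITION
            tool_definitions.append(definition)

print(tool_definitions)  # two JSON strings, later decoded by _safely_json_decode
```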

@@ -71,6 +78,7 @@ def _get_llm_span_input(
input_value: Any,
input_mime_type: Optional[str],
prompt_template_variables: Any,
tool_definitions: Any,
) -> dict[str, Any]:
"""
Extracts the input value from an LLM span and returns it as a dictionary.
@@ -84,6 +92,10 @@
input = _get_generic_io_value(io_value=input_value, mime_type=input_mime_type, kind="input")
if prompt_template_variables_data := _safely_json_decode(prompt_template_variables):
input["prompt_template_variables"] = prompt_template_variables_data
if tool_definitions_data := [
_safely_json_decode(tool_definition) for tool_definition in tool_definitions
]:
input["tool_definitions"] = tool_definitions_data
Contributor:
Suggested change
input["tool_definitions"] = tool_definitions_data
input["tools"] = tool_definitions_data

Kinda leaning in this direction. Does this work with the OpenAI fine-tuning / evals format?

Contributor Author: Looks like the tool definitions are a key alongside the input:

https://platform.openai.com/docs/guides/fine-tuning#preparing-your-dataset-for-dpo
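For reference, the shape being discussed looks roughly like the following, with tools sitting next to messages inside the input object. The field names below are paraphrased from the linked OpenAI DPO fine-tuning docs and are an assumption, not something verified against this PR:

```python
# Approximate shape of a single DPO fine-tuning example (assumption, paraphrased
# from the linked OpenAI documentation).
example = {
    "input": {
        "messages": [{"role": "user", "content": "What's the weather in SF?"}],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                    },
                },
            }
        ],
    },
    "preferred_output": [{"role": "assistant", "content": "Checking the weather now."}],
    "non_preferred_output": [{"role": "assistant", "content": "I can't help with that."}],
}
```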

return input


@@ -215,3 +227,7 @@ def _safely_json_decode(value: Any) -> Any:
# ToolCallAttributes
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME

# ToolAttributes
LLM_TOOLS = SpanAttributes.LLM_TOOLS
TOOL_DEFINITION = ToolAttributes.TOOL_JSON_SCHEMA
Contributor: Maybe we can rename this to TOOL_JSON_SCHEMA for consistency with the name of the convention.

2 changes: 1 addition & 1 deletion src/phoenix/server/api/helpers/playground_clients.py
@@ -795,7 +795,7 @@ async def chat_completion_create(
elif isinstance(event, anthropic_streaming.InputJsonEvent):
raise NotImplementedError
else:
assert_never(event)
assert_never(event) # type: ignore
Contributor: There's a fix for this that @RogerHYang put out; it's the Anthropic citations event. Can you cherry-pick that and remove this type ignore?


def _build_anthropic_messages(
self,
8 changes: 1 addition & 7 deletions src/phoenix/server/api/mutations/chat_mutations.py
@@ -49,7 +49,6 @@
from phoenix.server.api.input_types.TemplateOptions import TemplateOptions
from phoenix.server.api.subscriptions import (
_default_playground_experiment_description,
_default_playground_experiment_metadata,
_default_playground_experiment_name,
)
from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
@@ -183,12 +182,7 @@ async def chat_completion_over_dataset(
description=input.experiment_description
or _default_playground_experiment_description(dataset_name=dataset.name),
repetitions=1,
metadata_=input.experiment_metadata
or _default_playground_experiment_metadata(
dataset_name=dataset.name,
dataset_id=input.dataset_id,
version_id=GlobalID(DatasetVersion.__name__, str(resolved_version_id)),
),
metadata_=input.experiment_metadata or dict(),
project_name=PLAYGROUND_PROJECT_NAME,
)
session.add(experiment)
21 changes: 21 additions & 0 deletions src/phoenix/server/api/mutations/dataset_mutations.py
@@ -4,7 +4,11 @@

import strawberry
from openinference.semconv.trace import (
MessageAttributes,
MessageContentAttributes,
SpanAttributes,
ToolAttributes,
ToolCallAttributes,
)
from sqlalchemy import and_, delete, distinct, func, insert, select, update
from strawberry import UNSET
@@ -181,6 +185,17 @@ async def add_spans_to_dataset(
assert all(map(lambda id: isinstance(id, int), dataset_example_rowids))
DatasetExampleRevision = models.DatasetExampleRevision

all_span_attributes = {
**SpanAttributes.__dict__,
**MessageAttributes.__dict__,
**MessageContentAttributes.__dict__,
**ToolCallAttributes.__dict__,
**ToolAttributes.__dict__,
}
nonprivate_span_attributes = {
k: v for k, v in all_span_attributes.items() if not k.startswith("_")
}

await session.execute(
insert(DatasetExampleRevision),
[
@@ -190,6 +205,12 @@
DatasetExampleRevision.input.key: get_dataset_example_input(span),
DatasetExampleRevision.output.key: get_dataset_example_output(span),
DatasetExampleRevision.metadata_.key: {
**(span.attributes.get(SpanAttributes.METADATA) or dict()),
**{
k: v
for k, v in span.attributes.items()
if k in nonprivate_span_attributes
},
"span_kind": span.span_kind,
**(
{"annotations": annotations}
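Putting the two additions in this file together, here is a minimal sketch of the filtering step; it runs on its own if openinference-semconv is installed, and the trailing comment about how the allow-list is applied is paraphrased from the diff above rather than verified behavior:

```python
# Build an allow-list of public OpenInference attribute names by merging the
# semantic-convention classes and dropping private/dunder entries.
from openinference.semconv.trace import (
    MessageAttributes,
    MessageContentAttributes,
    SpanAttributes,
    ToolAttributes,
    ToolCallAttributes,
)

all_span_attributes = {
    **SpanAttributes.__dict__,
    **MessageAttributes.__dict__,
    **MessageContentAttributes.__dict__,
    **ToolCallAttributes.__dict__,
    **ToolAttributes.__dict__,
}
nonprivate_span_attributes = {
    k: v for k, v in all_span_attributes.items() if not k.startswith("_")
}

# Entries map constant names to attribute keys, e.g. "LLM_MODEL_NAME" -> "llm.model_name".
print(sorted(nonprivate_span_attributes)[:5])

# In the mutation above, span attributes that appear in this allow-list are merged
# into the dataset example revision's metadata, together with the span's own
# "metadata" attribute and a "span_kind" entry.
```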
17 changes: 1 addition & 16 deletions src/phoenix/server/api/subscriptions.py
@@ -278,12 +278,7 @@ async def chat_completion_over_dataset(
description=input.experiment_description
or _default_playground_experiment_description(dataset_name=dataset.name),
repetitions=1,
metadata_=input.experiment_metadata
or _default_playground_experiment_metadata(
dataset_name=dataset.name,
dataset_id=input.dataset_id,
version_id=GlobalID(DatasetVersion.__name__, str(resolved_version_id)),
),
metadata_=input.experiment_metadata or dict(),
project_name=PLAYGROUND_PROJECT_NAME,
)
session.add(experiment)
@@ -581,16 +576,6 @@ def _default_playground_experiment_description(dataset_name: str) -> str:
return f'Playground experiment for dataset "{dataset_name}"'


def _default_playground_experiment_metadata(
dataset_name: str, dataset_id: GlobalID, version_id: GlobalID
) -> dict[str, Any]:
return {
"dataset_name": dataset_name,
"dataset_id": str(dataset_id),
"dataset_version_id": str(version_id),
}


LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
6 changes: 1 addition & 5 deletions tests/unit/server/api/test_subscriptions.py
@@ -1206,12 +1206,8 @@ async def test_emits_expected_payloads_and_records_expected_spans_and_experiment
assert experiment.pop("name") == "playground-experiment"
assert isinstance(experiment_description := experiment.pop("description"), str)
assert "dataset-name" in experiment_description
assert experiment.pop("metadata") == {
"dataset_name": "dataset-name",
"dataset_id": str(dataset_id),
"dataset_version_id": str(version_id),
}
assert experiment.pop("projectName") == "playground"
assert experiment.pop("metadata") == {}
assert isinstance(created_at := experiment.pop("createdAt"), str)
assert isinstance(updated_at := experiment.pop("updatedAt"), str)
assert created_at == updated_at