Updated llmclient with the newest ldp implementations #21

Merged: 3 commits, Dec 10, 2024
4 changes: 4 additions & 0 deletions llmclient/__init__.py
@@ -19,6 +19,8 @@
LiteLLMModel,
LLMModel,
MultipleCompletionLLMModel,
sum_logprobs,
validate_json_completion,
)
from .types import (
Chunk,
@@ -44,4 +46,6 @@
"SentenceTransformerEmbeddingModel",
"SparseEmbeddingModel",
"embedding_model_factory",
"sum_logprobs",
"validate_json_completion",
]
46 changes: 39 additions & 7 deletions llmclient/llms.py
@@ -11,12 +11,14 @@
Awaitable,
Callable,
Iterable,
Mapping,
)
from inspect import isasyncgenfunction, signature
from typing import (
Any,
ClassVar,
Self,
TypeAlias,
TypeVar,
cast,
)
@@ -59,6 +61,10 @@
config=ConfigDict(arbitrary_types_allowed=True),
)

# Yes, this is a hack, it mostly matches
# https://github.com/python-jsonschema/referencing/blob/v0.35.1/referencing/jsonschema.py#L20-L21
JSONSchema: TypeAlias = Mapping[str, Any]


def sum_logprobs(choice: litellm.utils.Choices) -> float | None:
"""Calculate the sum of the log probabilities of an LLM completion (a Choices object).
@@ -84,13 +90,13 @@ def sum_logprobs(choice: litellm.utils.Choices) -> float | None:


def validate_json_completion(
completion: litellm.ModelResponse, output_type: type[BaseModel]
completion: litellm.ModelResponse, output_type: type[BaseModel] | JSONSchema
) -> None:
"""Validate a completion against a JSON schema.

Args:
completion: The completion to validate.
output_type: The Pydantic model to validate the completion against.
output_type: A JSON schema or a Pydantic model to validate the completion.
"""
try:
for choice in completion.choices:
@@ -102,7 +108,12 @@ def validate_json_completion(
choice.message.content = (
choice.message.content.split("```json")[-1].split("```")[0] or ""
)
output_type.model_validate_json(choice.message.content)
if isinstance(output_type, Mapping): # JSON schema
litellm.litellm_core_utils.json_validation_rule.validate_schema(
Contributor comment on lines +111 to +112:
To share, I used this import to work around us having to manage jsonschema in our own deps; instead we let litellm internals manage the implementation.

schema=dict(output_type), response=choice.message.content
)
else:
output_type.model_validate_json(choice.message.content)
except ValidationError as err:
raise JSONSchemaValidationError(
"The completion does not match the specified schema."
@@ -655,14 +666,20 @@ async def achat_iter(self, messages: Iterable[Message], **kwargs) -> AsyncGenera
)

# SEE: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# > `none` means the model will not call any tool and instead generates a message.
# > `auto` means the model can pick between generating a message or calling one or more tools.
# > `required` means the model must call one or more tools.
NO_TOOL_CHOICE: ClassVar[str] = "none"
MODEL_CHOOSES_TOOL: ClassVar[str] = "auto"
TOOL_CHOICE_REQUIRED: ClassVar[str] = "required"
# None means we won't provide a tool_choice to the LLM API
UNSPECIFIED_TOOL_CHOICE: ClassVar[None] = None

async def call( # noqa: C901, PLR0915
self,
messages: list[Message],
callbacks: list[Callable] | None = None,
output_type: type[BaseModel] | None = None,
output_type: type[BaseModel] | JSONSchema | None = None,
tools: list[Tool] | None = None,
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
**chat_kwargs,
@@ -684,6 +701,9 @@ async def call( # noqa: C901, PLR0915
Raises:
ValueError: If the number of completions (n) is invalid.
"""
# add static configuration to kwargs
chat_kwargs = self.config | chat_kwargs

start_clock = asyncio.get_running_loop().time()

# Deal with tools. Note OpenAI throws a 400 response if tools is empty:
@@ -705,7 +725,21 @@ async def call( # noqa: C901, PLR0915
)

# deal with specifying output type
if output_type is not None:
if isinstance(output_type, Mapping): # Use structured outputs
Contributor:
@mskarlin had a question in the original PR here: Future-House/ldp#165 (comment), which was whether we can have a validation of the model name here (more or less). Since making that change, I learned of litellm.supports_response_schema("gpt-4o-2024-11-20", None), which can enable that check. What do you think of:

  1. Moving chat_kwargs = self.config | chat_kwargs above this part of the code
  2. Doing a brief validation here such as:

model_name: str = chat_kwargs.get("model", "")
if not litellm.supports_response_schema(model_name, None):
    raise ValueError(f"Model {model_name} does not support JSON schema.")

Or, do you think we should just pass through to litellm and either:

  • If litellm has validations, let them check
  • Otherwise let the LLM API blow up with something like a 400

I am sort of in favor of the latter, which is our code here not doing validations, instead sort of "trusting the process" and letting callers get blown up with a trace terminating below llmclient.

What do you think?

Collaborator Author:
I think that allowing it to pass and blowing up on the LLM API will give us more issues in the future. Because we don't know how litellm treats this, the error can be uninformative and users will open issues here.

Since litellm already maintains a list of models that support this schema and we can use supports_response_schema, wouldn't it be useful to add a brief check so we can return an informative error message?

Contributor:
I guess there could be a situation where litellm is wrong and/or out-of-date, and a validation may block us. However, a validation also better defines the system and adds confidence in it. I can see arguments for both sides; I defer to you here.

Collaborator Author:
Since we are using litellm as the main framework, I think it might be ok to trust it? I added the check. Let me know what you think. We can also remove it later if it causes any problems.

model_name: str = chat_kwargs.get("model", "")
if not litellm.supports_response_schema(model_name, None):
raise ValueError(f"Model {model_name} does not support JSON schema.")

chat_kwargs["response_format"] = {
"type": "json_schema",
"json_schema": {
"strict": True,
# SEE: https://platform.openai.com/docs/guides/structured-outputs#additionalproperties-false-must-always-be-set-in-objects
"schema": dict(output_type) | {"additionalProperties": False},
"name": output_type["title"], # Required by OpenAI as of 12/3/2024
},
}
elif output_type is not None: # Use JSON mode
schema = json.dumps(output_type.model_json_schema(mode="serialization"))
schema_msg = f"Respond following this JSON schema:\n\n{schema}"
# Get the system prompt and its index, or the index to add it
Expand All @@ -724,8 +758,6 @@ async def call( # noqa: C901, PLR0915
]
chat_kwargs["response_format"] = {"type": "json_object"}

# add static configuration to kwargs
chat_kwargs = self.config | chat_kwargs
Collaborator Author comment: This was moved above to validate if the model supports the schema.

n = chat_kwargs.get("n", 1) # number of completions
if n < 1:
raise ValueError("Number of completions (n) must be >= 1.")
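As a usage illustration of the new call() flow (a hedged sketch, not part of the diff): passing a JSON schema mapping as output_type takes the structured-outputs branch, which pre-checks litellm.supports_response_schema and sends a strict json_schema response_format, while passing a Pydantic model keeps the prior JSON-mode behavior. The model name, config contents, DummyOutputSchema, and the Message export location below are assumptions:

# Hedged sketch: model/config values, DummyOutputSchema, and the Message import
# location are illustrative assumptions, not code from this PR.
import asyncio

from pydantic import BaseModel

from llmclient import MultipleCompletionLLMModel
from llmclient import Message  # assumed export location for Message


class DummyOutputSchema(BaseModel):
    name: str
    age: int


async def main() -> None:
    # The tests construct models as MODEL_CLS(name=..., config=...); the config
    # contents here (model name, n) are assumptions.
    model = MultipleCompletionLLMModel(name="gpt-4o", config={"model": "gpt-4o", "n": 1})
    messages = [
        Message(content="My name is Claude and I am 1 year old. What is my name and age?")
    ]

    # Structured outputs: a JSON schema mapping. call() first checks
    # litellm.supports_response_schema(model_name, None), then sends
    # response_format={"type": "json_schema", "json_schema": {"strict": True, ...}}.
    structured_results = await model.call(
        messages, output_type=DummyOutputSchema.model_json_schema()
    )

    # JSON mode: a Pydantic model. call() appends the serialized schema to the
    # system prompt and uses response_format={"type": "json_object"}.
    json_mode_results = await model.call(messages, output_type=DummyOutputSchema)
    print(structured_results, json_mode_results)


if __name__ == "__main__":
    asyncio.run(main())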
17 changes: 14 additions & 3 deletions tests/test_llms.py
@@ -271,16 +271,27 @@ def play(move: int | None) -> None:

@pytest.mark.asyncio
@pytest.mark.vcr
async def test_output_schema(self) -> None:
model = self.MODEL_CLS(name="gpt-3.5-turbo", config=self.DEFAULT_CONFIG)
@pytest.mark.parametrize(
("model_name", "output_type"),
[
pytest.param("gpt-3.5-turbo", DummyOutputSchema, id="json-mode"),
pytest.param(
"gpt-4o", DummyOutputSchema.model_json_schema(), id="structured-outputs"
),
],
)
async def test_output_schema(
self, model_name: str, output_type: type[BaseModel] | dict[str, Any]
) -> None:
model = self.MODEL_CLS(name=model_name, config=self.DEFAULT_CONFIG)
messages = [
Message(
content=(
"My name is Claude and I am 1 year old. What is my name and age?"
)
),
]
results = await self.call_model(model, messages, output_type=DummyOutputSchema)
results = await self.call_model(model, messages, output_type=output_type)
assert len(results) == self.NUM_COMPLETIONS
for result in results:
assert result.messages