Updated llmclient with the newest ldp implementations #21
@@ -11,12 +11,14 @@
     Awaitable,
     Callable,
     Iterable,
+    Mapping,
 )
 from inspect import isasyncgenfunction, signature
 from typing import (
     Any,
     ClassVar,
     Self,
+    TypeAlias,
     TypeVar,
     cast,
 )
@@ -59,6 +61,10 @@
     config=ConfigDict(arbitrary_types_allowed=True),
 )
 
+# Yes, this is a hack, it mostly matches
+# https://github.com/python-jsonschema/referencing/blob/v0.35.1/referencing/jsonschema.py#L20-L21
+JSONSchema: TypeAlias = Mapping[str, Any]
+
 
 def sum_logprobs(choice: litellm.utils.Choices) -> float | None:
     """Calculate the sum of the log probabilities of an LLM completion (a Choices object).
@@ -84,13 +90,13 @@ def sum_logprobs(choice: litellm.utils.Choices) -> float | None:
 
 
 def validate_json_completion(
-    completion: litellm.ModelResponse, output_type: type[BaseModel]
+    completion: litellm.ModelResponse, output_type: type[BaseModel] | JSONSchema
 ) -> None:
     """Validate a completion against a JSON schema.
 
     Args:
         completion: The completion to validate.
-        output_type: The Pydantic model to validate the completion against.
+        output_type: A JSON schema or a Pydantic model to validate the completion.
     """
     try:
         for choice in completion.choices:
@@ -102,7 +108,12 @@ def validate_json_completion(
             choice.message.content = (
                 choice.message.content.split("```json")[-1].split("```")[0] or ""
             )
-            output_type.model_validate_json(choice.message.content)
+            if isinstance(output_type, Mapping):  # JSON schema
+                litellm.litellm_core_utils.json_validation_rule.validate_schema(
+                    schema=dict(output_type), response=choice.message.content
+                )
+            else:
+                output_type.model_validate_json(choice.message.content)
     except ValidationError as err:
         raise JSONSchemaValidationError(
             "The completion does not match the specified schema."
@@ -655,14 +666,20 @@ async def achat_iter(self, messages: Iterable[Message], **kwargs) -> AsyncGenera
         )
 
+    # SEE: https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
+    # > `none` means the model will not call any tool and instead generates a message.
+    # > `auto` means the model can pick between generating a message or calling one or more tools.
+    # > `required` means the model must call one or more tools.
     NO_TOOL_CHOICE: ClassVar[str] = "none"
     MODEL_CHOOSES_TOOL: ClassVar[str] = "auto"
     TOOL_CHOICE_REQUIRED: ClassVar[str] = "required"
+    # None means we won't provide a tool_choice to the LLM API
+    UNSPECIFIED_TOOL_CHOICE: ClassVar[None] = None
 
     async def call(  # noqa: C901, PLR0915
         self,
         messages: list[Message],
         callbacks: list[Callable] | None = None,
-        output_type: type[BaseModel] | None = None,
+        output_type: type[BaseModel] | JSONSchema | None = None,
         tools: list[Tool] | None = None,
         tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED,
         **chat_kwargs,
@@ -684,6 +701,9 @@ async def call(  # noqa: C901, PLR0915
         Raises:
             ValueError: If the number of completions (n) is invalid.
         """
+        # add static configuration to kwargs
+        chat_kwargs = self.config | chat_kwargs
+
         start_clock = asyncio.get_running_loop().time()
 
         # Deal with tools. Note OpenAI throws a 400 response if tools is empty:
@@ -705,7 +725,21 @@ async def call(  # noqa: C901, PLR0915
         )
 
         # deal with specifying output type
-        if output_type is not None:
+        if isinstance(output_type, Mapping):  # Use structured outputs
Review thread on this line:

Comment: @mskarlin had a question in the original PR here: Future-House/ldp#165 (comment), which was whether we can have a validation of the model name here (more or less). Since making that change, I learned of litellm.supports_response_schema. So, do we want:

    model_name: str = chat_kwargs.get("model", "")
    if not litellm.supports_response_schema(model_name, None):
        raise ValueError(f"Model {model_name} does not support JSON schema.")

Or, do you think we should just pass through to the LLM API? I am sort of in favor of the latter, which is our code here not doing validations, instead sort of "trusting the process" and letting callers get blown up with a trace terminating below. What do you think?

Reply: I think that allowing it to pass and blowing up on the LLM API will give us more issues in the future. Because we don't know how litellm treats this, the error can be uninformative and users will open issues here. Since litellm already maintains a list of models that support this schema, we can use that check here.

Reply: I guess there could be a situation where litellm's support list is outdated or wrong. However, a validation also does better define the system and adds confidence in it. I can see arguments for both sides; defer to you here.

Reply: Since we are using litellm as the main framework, I think it might be ok to trust it? We can also remove it later if it causes any problems.
+            model_name: str = chat_kwargs.get("model", "")
+            if not litellm.supports_response_schema(model_name, None):
+                raise ValueError(f"Model {model_name} does not support JSON schema.")
+
+            chat_kwargs["response_format"] = {
+                "type": "json_schema",
+                "json_schema": {
+                    "strict": True,
+                    # SEE: https://platform.openai.com/docs/guides/structured-outputs#additionalproperties-false-must-always-be-set-in-objects
+                    "schema": dict(output_type) | {"additionalProperties": False},
+                    "name": output_type["title"],  # Required by OpenAI as of 12/3/2024
+                },
+            }
+        elif output_type is not None:  # Use JSON mode
             schema = json.dumps(output_type.model_json_schema(mode="serialization"))
             schema_msg = f"Respond following this JSON schema:\n\n{schema}"
             # Get the system prompt and its index, or the index to add it
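For context on the thread above, here is a hedged caller-side sketch of the fail-fast guard that was adopted. The model name and schema below are made up for illustration; only litellm.supports_response_schema and the error message mirror the diff.

    import litellm

    # Made-up JSON schema a caller might pass as `output_type`; the diff reads
    # output_type["title"] to name the response_format, so "title" is set here.
    output_type = {
        "title": "Answer",
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    }

    model_name = "gpt-4o-mini"  # hypothetical model, for illustration only

    # Fail-fast option adopted in the diff: reject before any network call if
    # litellm's model metadata says the model cannot honor a JSON schema.
    if not litellm.supports_response_schema(model_name, None):
        raise ValueError(f"Model {model_name} does not support JSON schema.")

    # The "trust the process" alternative discussed above would skip this check
    # and let the provider's API reject an unsupported response_format at request time.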
@@ -724,8 +758,6 @@ async def call(  # noqa: C901, PLR0915
             ]
             chat_kwargs["response_format"] = {"type": "json_object"}
 
-        # add static configuration to kwargs
-        chat_kwargs = self.config | chat_kwargs
 
         n = chat_kwargs.get("n", 1)  # number of completions
         if n < 1:
             raise ValueError("Number of completions (n) must be >= 1.")

Review comment on the removed lines: This was moved above to validate if the model supports the schema.
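A minimal sketch of why the merge was moved up: Python's dict union lets per-call kwargs override the static config, and only the merged mapping exposes a model name for the supports_response_schema check. The values here are hypothetical, not taken from this repo.

    # Hypothetical values for illustration; `self.config` in the real class
    # plays the role of `config` here.
    config = {"model": "gpt-4o", "temperature": 0.1}  # static configuration
    chat_kwargs = {"n": 1}                            # per-call keyword arguments

    # dict union: the right-hand side wins on key conflicts, so per-call
    # kwargs override the static configuration
    chat_kwargs = config | chat_kwargs

    model_name = chat_kwargs.get("model", "")
    assert model_name == "gpt-4o"  # now visible to the schema-support check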
Review comment: To share, I used this import to work around us having to manage jsonschema in our own deps; instead we let litellm internals manage the implementation.
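As a rough illustration of that trade-off, the validator bundled inside litellm (the same call used in the diff above) stands in for a direct jsonschema call that would otherwise force this package to declare jsonschema as its own dependency. The schema and response string are made up, and the direct-import form is an assumption; the diff accesses the function as an attribute path on litellm.

    import json

    # Made-up schema and completion text, for illustration only.
    schema = {
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    }
    response = '{"answer": "42"}'

    # Path taken in this PR: use the validator bundled in litellm, so jsonschema
    # stays an indirect dependency that litellm manages.
    from litellm.litellm_core_utils.json_validation_rule import validate_schema

    validate_schema(schema=dict(schema), response=response)

    # Roughly equivalent direct call, which would require depending on jsonschema ourselves.
    import jsonschema

    jsonschema.validate(instance=json.loads(response), schema=schema)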