From 9e48b205f13bb0edf62252a85fe553346365a895 Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Thu, 5 Dec 2024 21:48:36 +0100 Subject: [PATCH] fix: add and enable OpenAI strict mode (#55) --- src/raglite/_eval.py | 21 ++++++++++++++------- src/raglite/_extract.py | 31 ++++++++++++++----------------- tests/test_extract.py | 23 +++++++++++++++-------- 3 files changed, 43 insertions(+), 32 deletions(-) diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py index cec5a66..6af9723 100644 --- a/src/raglite/_eval.py +++ b/src/raglite/_eval.py @@ -5,7 +5,7 @@ import numpy as np import pandas as pd -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from sqlmodel import Session, func, select from tqdm.auto import tqdm, trange @@ -25,10 +25,11 @@ def insert_evals( # noqa: C901 class QuestionResponse(BaseModel): """A specific question about the content of a set of document contexts.""" + model_config = ConfigDict( + extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode. + ) question: str = Field( - ..., - description="A specific question about the content of a set of document contexts.", - min_length=1, + ..., description="A specific question about the content of a set of document contexts." ) system_prompt: ClassVar[str] = """ You are given a set of contexts extracted from a document. @@ -85,7 +86,7 @@ def validate_question(cls, value: str) -> str: # Extract a question from the seed chunk's related chunks. 
try: question_response = extract_with_llm( - QuestionResponse, related_chunks, config=config + QuestionResponse, related_chunks, strict=True, config=config ) except ValueError: continue @@ -101,6 +102,9 @@ def validate_question(cls, value: str) -> str: class ContextEvalResponse(BaseModel): """Indicate whether the provided context can be used to answer a given question.""" + model_config = ConfigDict( + extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode. + ) hit: bool = Field( ..., description="True if the provided context contains (a part of) the answer to the given question, false otherwise.", @@ -118,7 +122,7 @@ class ContextEvalResponse(BaseModel): ): try: context_eval_response = extract_with_llm( - ContextEvalResponse, str(candidate_chunk), config=config + ContextEvalResponse, str(candidate_chunk), strict=True, config=config ) except ValueError: # noqa: PERF203 pass @@ -132,10 +136,12 @@ class ContextEvalResponse(BaseModel): class AnswerResponse(BaseModel): """Answer a question using the provided context.""" + model_config = ConfigDict( + extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode. + ) answer: str = Field( ..., description="A complete answer to the given question using the provided context.", - min_length=1, ) system_prompt: ClassVar[str] = f""" You are given a set of contexts extracted from a document. 
@@ -152,6 +158,7 @@ class AnswerResponse(BaseModel): answer_response = extract_with_llm( AnswerResponse, [str(relevant_chunk) for relevant_chunk in relevant_chunks], + strict=True, config=config, ) except ValueError: diff --git a/src/raglite/_extract.py b/src/raglite/_extract.py index f3d73ff..bd85d47 100644 --- a/src/raglite/_extract.py +++ b/src/raglite/_extract.py @@ -2,7 +2,6 @@ from typing import Any, TypeVar -import litellm from litellm import completion, get_supported_openai_params # type: ignore[attr-defined] from pydantic import BaseModel, ValidationError @@ -14,6 +13,7 @@ def extract_with_llm( return_type: type[T], user_prompt: str | list[str], + strict: bool = False, # noqa: FBT001,FBT002 config: RAGLiteConfig | None = None, **kwargs: Any, ) -> T: @@ -33,18 +33,20 @@ class MyNameResponse(BaseModel): """ # Load the default config if not provided. config = config or RAGLiteConfig() - # Update the system prompt with the JSON schema of the return type to help the LLM. - system_prompt = "\n".join( - ( - return_type.system_prompt.strip(), # type: ignore[attr-defined] - "Format your response according to this JSON schema:", - str(return_type.model_json_schema()), - ) + # Check if the LLM supports the response format. + llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None + llm_supports_response_format = "response_format" in ( + get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or [] ) - # Constrain the reponse format to the JSON schema if it's supported by the LLM [1]. + # Update the system prompt with the JSON schema of the return type to help the LLM. + system_prompt = getattr(return_type, "system_prompt", "").strip() + if not llm_supports_response_format or llm_provider == "llama-cpp-python": + system_prompt += f"\n\nFormat your response according to this JSON schema:\n{return_type.model_json_schema()!s}" + # Constrain the response format to the JSON schema if it's supported by the LLM [1]. 
Strict mode + # is disabled by default because it only supports a subset of JSON schema features [2]. # [1] https://docs.litellm.ai/docs/completion/json_mode + # [2] https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported # TODO: Fall back to {"type": "json_object"} if JSON schema is not supported by the LLM. - llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None response_format: dict[str, Any] | None = ( { "type": "json_schema", @@ -52,10 +54,10 @@ class MyNameResponse(BaseModel): "name": return_type.__name__, "description": return_type.__doc__ or "", "schema": return_type.model_json_schema(), + "strict": strict, }, } - if "response_format" - in (get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or []) + if llm_supports_response_format else None ) # Concatenate the user prompt if it is a list of strings. @@ -64,9 +66,6 @@ class MyNameResponse(BaseModel): f'\n{chunk.strip()}\n' for i, chunk in enumerate(user_prompt) ) - # Enable JSON schema validation. - enable_json_schema_validation = litellm.enable_json_schema_validation - litellm.enable_json_schema_validation = True # Extract structured data from the unstructured input. for _ in range(config.llm_max_tries): response = completion( @@ -89,6 +88,4 @@ class MyNameResponse(BaseModel): else: error_message = f"Failed to extract {return_type} from input {user_prompt}." raise ValueError(error_message) from last_exception - # Restore the previous JSON schema validation setting. 
- litellm.enable_json_schema_validation = enable_json_schema_validation return instance diff --git a/tests/test_extract.py b/tests/test_extract.py index 3ff2a85..33ef6e0 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -3,7 +3,7 @@ from typing import ClassVar import pytest -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from raglite import RAGLiteConfig from raglite._extract import extract_with_llm @@ -13,29 +13,36 @@ params=[ pytest.param(RAGLiteConfig().llm, id="llama_cpp_python"), pytest.param("gpt-4o-mini", id="openai"), - ], + ] ) -def llm( - request: pytest.FixtureRequest, -) -> str: +def llm(request: pytest.FixtureRequest) -> str: """Get an LLM to test RAGLite with.""" llm: str = request.param return llm -def test_extract(llm: str) -> None: +@pytest.mark.parametrize( + "strict", [pytest.param(False, id="strict=False"), pytest.param(True, id="strict=True")] +) +def test_extract(llm: str, strict: bool) -> None: # noqa: FBT001 """Test extracting structured data.""" # Set the LLM. config = RAGLiteConfig(llm=llm) - # Extract structured data. + # Define the JSON schema of the response. class LoginResponse(BaseModel): + """The response to a login request.""" + + model_config = ConfigDict(extra="forbid" if strict else "allow") username: str = Field(..., description="The username.") password: str = Field(..., description="The password.") system_prompt: ClassVar[str] = "Extract the username and password from the input." + # Extract structured data. username, password = "cypher", "steak" - login_response = extract_with_llm(LoginResponse, f"{username} // {password}", config=config) + login_response = extract_with_llm( + LoginResponse, f"{username} // {password}", strict=strict, config=config + ) # Validate the response. assert isinstance(login_response, LoginResponse) assert login_response.username == username