Skip to content

Commit

Permalink
fix: add and enable OpenAI strict mode (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
undo76 authored Dec 5, 2024
1 parent f6023f5 commit 9e48b20
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 32 deletions.
21 changes: 14 additions & 7 deletions src/raglite/_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
import pandas as pd
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from sqlmodel import Session, func, select
from tqdm.auto import tqdm, trange

Expand All @@ -25,10 +25,11 @@ def insert_evals( # noqa: C901
class QuestionResponse(BaseModel):
"""A specific question about the content of a set of document contexts."""

model_config = ConfigDict(
extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode.
)
question: str = Field(
...,
description="A specific question about the content of a set of document contexts.",
min_length=1,
..., description="A specific question about the content of a set of document contexts."
)
system_prompt: ClassVar[str] = """
You are given a set of contexts extracted from a document.
Expand Down Expand Up @@ -85,7 +86,7 @@ def validate_question(cls, value: str) -> str:
# Extract a question from the seed chunk's related chunks.
try:
question_response = extract_with_llm(
QuestionResponse, related_chunks, config=config
QuestionResponse, related_chunks, strict=True, config=config
)
except ValueError:
continue
Expand All @@ -101,6 +102,9 @@ def validate_question(cls, value: str) -> str:
class ContextEvalResponse(BaseModel):
"""Indicate whether the provided context can be used to answer a given question."""

model_config = ConfigDict(
extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode.
)
hit: bool = Field(
...,
description="True if the provided context contains (a part of) the answer to the given question, false otherwise.",
Expand All @@ -118,7 +122,7 @@ class ContextEvalResponse(BaseModel):
):
try:
context_eval_response = extract_with_llm(
ContextEvalResponse, str(candidate_chunk), config=config
ContextEvalResponse, str(candidate_chunk), strict=True, config=config
)
except ValueError: # noqa: PERF203
pass
Expand All @@ -132,10 +136,12 @@ class ContextEvalResponse(BaseModel):
class AnswerResponse(BaseModel):
"""Answer a question using the provided context."""

model_config = ConfigDict(
extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode.
)
answer: str = Field(
...,
description="A complete answer to the given question using the provided context.",
min_length=1,
)
system_prompt: ClassVar[str] = f"""
You are given a set of contexts extracted from a document.
Expand All @@ -152,6 +158,7 @@ class AnswerResponse(BaseModel):
answer_response = extract_with_llm(
AnswerResponse,
[str(relevant_chunk) for relevant_chunk in relevant_chunks],
strict=True,
config=config,
)
except ValueError:
Expand Down
31 changes: 14 additions & 17 deletions src/raglite/_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from typing import Any, TypeVar

import litellm
from litellm import completion, get_supported_openai_params # type: ignore[attr-defined]
from pydantic import BaseModel, ValidationError

Expand All @@ -14,6 +13,7 @@
def extract_with_llm(
return_type: type[T],
user_prompt: str | list[str],
strict: bool = False, # noqa: FBT001,FBT002
config: RAGLiteConfig | None = None,
**kwargs: Any,
) -> T:
Expand All @@ -33,29 +33,31 @@ class MyNameResponse(BaseModel):
"""
# Load the default config if not provided.
config = config or RAGLiteConfig()
# Update the system prompt with the JSON schema of the return type to help the LLM.
system_prompt = "\n".join(
(
return_type.system_prompt.strip(), # type: ignore[attr-defined]
"Format your response according to this JSON schema:",
str(return_type.model_json_schema()),
)
# Check if the LLM supports the response format.
llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
llm_supports_response_format = "response_format" in (
get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or []
)
    # Constrain the response format to the JSON schema if it's supported by the LLM [1].
# Update the system prompt with the JSON schema of the return type to help the LLM.
system_prompt = getattr(return_type, "system_prompt", "").strip()
if not llm_supports_response_format or llm_provider == "llama-cpp-python":
system_prompt += f"\n\nFormat your response according to this JSON schema:\n{return_type.model_json_schema()!s}"
    # Constrain the response format to the JSON schema if it's supported by the LLM [1]. Strict mode
# is disabled by default because it only supports a subset of JSON schema features [2].
# [1] https://docs.litellm.ai/docs/completion/json_mode
# [2] https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported
# TODO: Fall back to {"type": "json_object"} if JSON schema is not supported by the LLM.
llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
response_format: dict[str, Any] | None = (
{
"type": "json_schema",
"json_schema": {
"name": return_type.__name__,
"description": return_type.__doc__ or "",
"schema": return_type.model_json_schema(),
"strict": strict,
},
}
if "response_format"
in (get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or [])
if llm_supports_response_format
else None
)
# Concatenate the user prompt if it is a list of strings.
Expand All @@ -64,9 +66,6 @@ class MyNameResponse(BaseModel):
f'<context index="{i + 1}">\n{chunk.strip()}\n</context>'
for i, chunk in enumerate(user_prompt)
)
# Enable JSON schema validation.
enable_json_schema_validation = litellm.enable_json_schema_validation
litellm.enable_json_schema_validation = True
# Extract structured data from the unstructured input.
for _ in range(config.llm_max_tries):
response = completion(
Expand All @@ -89,6 +88,4 @@ class MyNameResponse(BaseModel):
else:
error_message = f"Failed to extract {return_type} from input {user_prompt}."
raise ValueError(error_message) from last_exception
# Restore the previous JSON schema validation setting.
litellm.enable_json_schema_validation = enable_json_schema_validation
return instance
23 changes: 15 additions & 8 deletions tests/test_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import ClassVar

import pytest
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

from raglite import RAGLiteConfig
from raglite._extract import extract_with_llm
Expand All @@ -13,29 +13,36 @@
params=[
pytest.param(RAGLiteConfig().llm, id="llama_cpp_python"),
pytest.param("gpt-4o-mini", id="openai"),
],
]
)
def llm(
request: pytest.FixtureRequest,
) -> str:
def llm(request: pytest.FixtureRequest) -> str:
"""Get an LLM to test RAGLite with."""
llm: str = request.param
return llm


def test_extract(llm: str) -> None:
@pytest.mark.parametrize(
"strict", [pytest.param(False, id="strict=False"), pytest.param(True, id="strict=True")]
)
def test_extract(llm: str, strict: bool) -> None: # noqa: FBT001
"""Test extracting structured data."""
# Set the LLM.
config = RAGLiteConfig(llm=llm)

# Extract structured data.
# Define the JSON schema of the response.
class LoginResponse(BaseModel):
"""The response to a login request."""

model_config = ConfigDict(extra="forbid" if strict else "allow")
username: str = Field(..., description="The username.")
password: str = Field(..., description="The password.")
system_prompt: ClassVar[str] = "Extract the username and password from the input."

# Extract structured data.
username, password = "cypher", "steak"
login_response = extract_with_llm(LoginResponse, f"{username} // {password}", config=config)
login_response = extract_with_llm(
LoginResponse, f"{username} // {password}", strict=strict, config=config
)
# Validate the response.
assert isinstance(login_response, LoginResponse)
assert login_response.username == username
Expand Down

0 comments on commit 9e48b20

Please sign in to comment.