Commit
fix: improve structured output extraction and query adapter updates (#34)
emilradix authored Nov 22, 2024
1 parent 714c68f commit 6b49ced
Showing 4 changed files with 96 additions and 13 deletions.
37 changes: 31 additions & 6 deletions src/raglite/_extract.py
@@ -2,7 +2,8 @@
 
 from typing import Any, TypeVar
 
-from litellm import completion
+import litellm
+from litellm import completion, get_supported_openai_params  # type: ignore[attr-defined]
 from pydantic import BaseModel, ValidationError
 
 from raglite._config import RAGLiteConfig
@@ -33,17 +34,39 @@ class MyNameResponse(BaseModel):
     # Load the default config if not provided.
     config = config or RAGLiteConfig()
     # Update the system prompt with the JSON schema of the return type to help the LLM.
-    system_prompt = (
-        return_type.system_prompt.strip() + "\n",  # type: ignore[attr-defined]
-        "Format your response according to this JSON schema:\n",
-        return_type.model_json_schema(),
+    system_prompt = "\n".join(
+        (
+            return_type.system_prompt.strip(),  # type: ignore[attr-defined]
+            "Format your response according to this JSON schema:",
+            str(return_type.model_json_schema()),
+        )
     )
+    # Constrain the response format to the JSON schema if it's supported by the LLM [1].
+    # [1] https://docs.litellm.ai/docs/completion/json_mode
+    # TODO: Fall back to {"type": "json_object"} if JSON schema is not supported by the LLM.
+    llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
+    response_format: dict[str, Any] | None = (
+        {
+            "type": "json_schema",
+            "json_schema": {
+                "name": return_type.__name__,
+                "description": return_type.__doc__ or "",
+                "schema": return_type.model_json_schema(),
+            },
+        }
+        if "response_format"
+        in (get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or [])
+        else None
+    )
     # Concatenate the user prompt if it is a list of strings.
     if isinstance(user_prompt, list):
         user_prompt = "\n\n".join(
             f'<context index="{i}">\n{chunk.strip()}\n</context>'
             for i, chunk in enumerate(user_prompt)
         )
+    # Enable JSON schema validation.
+    enable_json_schema_validation = litellm.enable_json_schema_validation
+    litellm.enable_json_schema_validation = True
     # Extract structured data from the unstructured input.
     for _ in range(config.llm_max_tries):
         response = completion(
@@ -52,7 +75,7 @@ class MyNameResponse(BaseModel):
                 {"role": "system", "content": system_prompt},
                 {"role": "user", "content": user_prompt},
             ],
-            response_format={"type": "json_object", "schema": return_type.model_json_schema()},
+            response_format=response_format,
             **kwargs,
         )
         try:
@@ -66,4 +89,6 @@
     else:
         error_message = f"Failed to extract {return_type} from input {user_prompt}."
         raise ValueError(error_message) from last_exception
+    # Restore the previous JSON schema validation setting.
+    litellm.enable_json_schema_validation = enable_json_schema_validation
     return instance
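
Note: the capability check above only sends a JSON-schema response format when the provider reports support for the response_format parameter. Below is a minimal standalone sketch of the same pattern; the model name "gpt-4o-mini" is an arbitrary example, MyNameResponse is a fresh illustrative model named after the docstring's example, and actual schema adherence still depends on the provider and model.

from litellm import completion, get_supported_openai_params
from pydantic import BaseModel

class MyNameResponse(BaseModel):
    """A response containing my name."""

    my_name: str

# Only pass a JSON-schema response_format if the provider supports the parameter at all.
supported_params = get_supported_openai_params(model="gpt-4o-mini") or []
response_format = (
    {
        "type": "json_schema",
        "json_schema": {
            "name": MyNameResponse.__name__,
            "description": MyNameResponse.__doc__ or "",
            "schema": MyNameResponse.model_json_schema(),
        },
    }
    if "response_format" in supported_params
    else None
)
response = completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "My name is Emil."}],
    response_format=response_format,
)
# Validate the (hopefully schema-conforming) JSON string returned by the model.
my_name_response = MyNameResponse.model_validate_json(response["choices"][0]["message"]["content"])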
26 changes: 20 additions & 6 deletions src/raglite/_litellm.py
@@ -129,6 +129,24 @@ def llm(model: str, **kwargs: Any) -> Llama:
         )
         return llm
 
+    def _translate_openai_params(self, optional_params: dict[str, Any]) -> dict[str, Any]:
+        # Filter out unsupported OpenAI parameters.
+        llama_cpp_python_params = {
+            k: v for k, v in optional_params.items() if k in self.supported_openai_params
+        }
+        # Translate OpenAI's response_format [1] to llama-cpp-python's response_format [2].
+        # [1] https://platform.openai.com/docs/guides/structured-outputs
+        # [2] https://github.com/abetlen/llama-cpp-python#json-schema-mode
+        if (
+            "response_format" in llama_cpp_python_params
+            and "json_schema" in llama_cpp_python_params["response_format"]
+        ):
+            llama_cpp_python_params["response_format"] = {
+                "type": "json_object",
+                "schema": llama_cpp_python_params["response_format"]["json_schema"]["schema"],
+            }
+        return llama_cpp_python_params
+
     def completion(  # noqa: PLR0913
         self,
         model: str,
@@ -149,9 +167,7 @@ def completion(  # noqa: PLR0913
         client: HTTPHandler | None = None,
     ) -> ModelResponse:
         llm = self.llm(model)
-        llama_cpp_python_params = {
-            k: v for k, v in optional_params.items() if k in self.supported_openai_params
-        }
+        llama_cpp_python_params = self._translate_openai_params(optional_params)
         response = cast(
             CreateChatCompletionResponse,
             llm.create_chat_completion(messages=messages, **llama_cpp_python_params),
@@ -184,9 +200,7 @@ def streaming(  # noqa: PLR0913
         client: HTTPHandler | None = None,
     ) -> Iterator[GenericStreamingChunk]:
         llm = self.llm(model)
-        llama_cpp_python_params = {
-            k: v for k, v in optional_params.items() if k in self.supported_openai_params
-        }
+        llama_cpp_python_params = self._translate_openai_params(optional_params)
         stream = cast(
             Iterator[CreateChatCompletionStreamResponse],
             llm.create_chat_completion(messages=messages, **llama_cpp_python_params, stream=True),
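
For readers skimming the diff: _translate_openai_params rewrites an OpenAI-style structured-output request into the shape llama-cpp-python expects. Roughly, with a hypothetical LoginResponse schema as a stand-in:

# OpenAI-style request: the JSON schema is nested under a "json_schema" key.
openai_response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "LoginResponse",
        "description": "A username and password.",
        "schema": {
            "type": "object",
            "properties": {"username": {"type": "string"}, "password": {"type": "string"}},
            "required": ["username", "password"],
        },
    },
}
# llama-cpp-python equivalent: the bare schema under a "json_object" response format.
llama_cpp_response_format = {
    "type": "json_object",
    "schema": openai_response_format["json_schema"]["schema"],
}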
4 changes: 3 additions & 1 deletion src/raglite/_query_adapter.py
@@ -1,6 +1,7 @@
 """Compute and update an optimal query adapter."""
 
 import numpy as np
+from sqlalchemy.orm.attributes import flag_modified
 from sqlmodel import Session, col, select
 from tqdm.auto import tqdm

@@ -157,6 +158,7 @@ def update_query_adapter(  # noqa: PLR0915, C901
             raise ValueError(error_message)
         # Store the optimal query adapter in the database.
         index_metadata = session.get(IndexMetadata, "default") or IndexMetadata(id="default")
-        index_metadata.metadata_ = {**index_metadata.metadata_, "query_adapter": A_star}
+        index_metadata.metadata_["query_adapter"] = A_star
+        flag_modified(index_metadata, "metadata_")
         session.add(index_metadata)
         session.commit()
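
The added flag_modified call matters because SQLAlchemy does not detect in-place mutation of a plain JSON column; only attribute reassignment or an explicit flag marks the row dirty. A minimal sketch of the pattern, using a simplified stand-in for RAGLite's IndexMetadata table rather than its actual schema:

from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import JSON, Column, Field, Session, SQLModel, create_engine

class IndexMetadata(SQLModel, table=True):
    id: str = Field(primary_key=True)
    metadata_: dict = Field(default_factory=dict, sa_column=Column(JSON))

engine = create_engine("sqlite:///:memory:")
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
    index_metadata = IndexMetadata(id="default")
    session.add(index_metadata)
    session.commit()
    # Mutating the JSON dict in place is invisible to SQLAlchemy's change tracking...
    index_metadata.metadata_["query_adapter"] = [[1.0, 0.0], [0.0, 1.0]]
    # ...so explicitly flag the column as modified before committing.
    flag_modified(index_metadata, "metadata_")
    session.add(index_metadata)
    session.commit()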
42 changes: 42 additions & 0 deletions tests/test_extract.py
@@ -0,0 +1,42 @@
"""Test RAGLite's structured output extraction."""

from typing import ClassVar

import pytest
from pydantic import BaseModel, Field

from raglite import RAGLiteConfig
from raglite._extract import extract_with_llm


@pytest.fixture(
params=[
pytest.param(RAGLiteConfig().llm, id="llama_cpp_python"),
pytest.param("gpt-4o-mini", id="openai"),
],
)
def llm(
request: pytest.FixtureRequest,
) -> str:
"""Get an LLM to test RAGLite with."""
llm: str = request.param
return llm


def test_extract(llm: str) -> None:
"""Test extracting structured data."""
# Set the LLM.
config = RAGLiteConfig(llm=llm)

# Extract structured data.
class LoginResponse(BaseModel):
username: str = Field(..., description="The username.")
password: str = Field(..., description="The password.")
system_prompt: ClassVar[str] = "Extract the username and password from the input."

username, password = "cypher", "steak"
login_response = extract_with_llm(LoginResponse, f"{username} // {password}", config=config)
# Validate the response.
assert isinstance(login_response, LoginResponse)
assert login_response.username == username
assert login_response.password == password
