
Commit

Merge pull request #1079 from guardrails-ai/custom-llms
Fix Custom LLM Support
zsimjee authored Sep 19, 2024
2 parents fd0f83b + 513b3e3 commit eb212ba
Showing 21 changed files with 779 additions and 180 deletions.
46 changes: 46 additions & 0 deletions docs/how_to_guides/using_llms.md
@@ -289,3 +289,49 @@ for chunk in stream_chunk_generator
## Other LLMs

See LiteLLM's documentation [here](https://docs.litellm.ai/docs/providers) for details on many other LLMs.
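
As a minimal sketch (assuming `litellm` is installed, the relevant provider API key such as `OPENAI_API_KEY` is set in your environment, and your Guardrails version routes bare model strings through LiteLLM), you can call one of these providers by passing its LiteLLM model string to the guard. The model name below is only illustrative:

```python
from guardrails import Guard
from guardrails.hub import ProfanityFree

guard = Guard().use(ProfanityFree())

# The model string is handed off to LiteLLM, which selects the provider;
# swap in any model string that LiteLLM supports.
result = guard(
    model="gpt-4o-mini",
    prompt="Name three household objects.",
)

print(result.validated_output)
```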

## Custom LLM Wrappers
If you're using an LLM that isn't natively supported by Guardrails and you don't want to go through LiteLLM, you can build a custom LLM API wrapper. To do so, create a function that takes the prompt as a positional string argument, accepts anything else you want to pass to the LLM API as keyword arguments, and returns the LLM output as a string.

```python
from typing import Dict, List, Optional

from guardrails import Guard
from guardrails.hub import ProfanityFree

# Create a Guard class
guard = Guard().use(ProfanityFree())

# Function that takes the prompt as a string and returns the LLM output as a string
def my_llm_api(
    prompt: Optional[str] = None,
    *,
    instructions: Optional[str] = None,
    msg_history: Optional[List[Dict[str, str]]] = None,
    **kwargs,
) -> str:
    """Custom LLM API wrapper.

    At least one of prompt, instructions, or msg_history should be provided.

    Args:
        prompt (str): The prompt to be passed to the LLM API
        instructions (str): The instructions to be passed to the LLM API
        msg_history (list[dict]): The message history to be passed to the LLM API
        **kwargs: Any additional arguments to be passed to the LLM API

    Returns:
        str: The output of the LLM API
    """
    # Call your LLM API here.
    # What you pass to the LLM will depend on which arguments it accepts;
    # some_llm is a placeholder for your own client call.
    llm_output = some_llm(prompt, instructions, msg_history, **kwargs)

    return llm_output


# Wrap your LLM API call
validated_response = guard(
    my_llm_api,
    prompt="Can you generate a list of 10 things that are not food?",
    **kwargs,  # placeholder: any extra keyword args are forwarded to my_llm_api
)
```
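
If your LLM client is asynchronous, the same pattern applies: write an `async` function with the same signature and pass it to an `AsyncGuard`. The sketch below is illustrative only; it assumes `AsyncGuard` is available in your installed version, and `some_async_llm` is a placeholder for your own async client call.

```python
import asyncio
from typing import Dict, List, Optional

from guardrails import AsyncGuard
from guardrails.hub import ProfanityFree

guard = AsyncGuard().use(ProfanityFree())


async def my_async_llm_api(
    prompt: Optional[str] = None,
    *,
    instructions: Optional[str] = None,
    msg_history: Optional[List[Dict[str, str]]] = None,
    **kwargs,
) -> str:
    # some_async_llm is a placeholder for your own async client call.
    return await some_async_llm(prompt, instructions, msg_history, **kwargs)


async def main():
    validated_response = await guard(
        my_async_llm_api,
        prompt="Can you generate a list of 10 things that are not food?",
    )
    print(validated_response.validated_output)


asyncio.run(main())
```

As with the synchronous wrapper, the custom callable must accept `**kwargs`, and declaring `instructions` and `msg_history` as keyword-only arguments keeps them from being unintentionally forwarded to your client through `**kwargs`.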
4 changes: 2 additions & 2 deletions guardrails/applications/text2sql.py
@@ -1,14 +1,14 @@
import asyncio
import json
import os
import openai
from string import Template
from typing import Callable, Dict, Optional, Type, cast

from guardrails.classes import ValidationOutcome
from guardrails.document_store import DocumentStoreBase, EphemeralDocumentStore
from guardrails.embedding import EmbeddingBase, OpenAIEmbedding
from guardrails.guard import Guard
from guardrails.utils.openai_utils import get_static_openai_create_func
from guardrails.utils.sql_utils import create_sql_driver
from guardrails.vectordb import Faiss, VectorDBBase

@@ -89,7 +89,7 @@ def __init__(
            reask_prompt: Prompt to use for reasking. Defaults to REASK_PROMPT.
        """
        if llm_api is None:
            llm_api = get_static_openai_create_func()
            llm_api = openai.completions.create

        self.example_formatter = example_formatter
        self.llm_api = llm_api
34 changes: 25 additions & 9 deletions guardrails/formatters/json_formatter.py
@@ -1,5 +1,5 @@
import json
from typing import Optional, Union
from typing import Dict, List, Optional, Union

from guardrails.formatters.base_formatter import BaseFormatter
from guardrails.llm_providers import (
@@ -99,32 +99,48 @@ def wrap_callable(self, llm_callable) -> ArbitraryCallable:

        if isinstance(llm_callable, HuggingFacePipelineCallable):
            model = llm_callable.init_kwargs["pipeline"]
            return ArbitraryCallable(
                lambda p: json.dumps(

            def fn(
                prompt: str,
                *args,
                instructions: Optional[str] = None,
                msg_history: Optional[List[Dict[str, str]]] = None,
                **kwargs,
            ) -> str:
                return json.dumps(
                    Jsonformer(
                        model=model.model,
                        tokenizer=model.tokenizer,
                        json_schema=self.output_schema,
                        prompt=p,
                        prompt=prompt,
                    )()
                )
            )

            return ArbitraryCallable(fn)
        elif isinstance(llm_callable, HuggingFaceModelCallable):
            # This will not work because 'model_generate' is the .gen method.
            # model = self.api.init_kwargs["model_generate"]
            # Use the __self__ to grab the base mode for passing into JF.
            model = llm_callable.init_kwargs["model_generate"].__self__
            tokenizer = llm_callable.init_kwargs["tokenizer"]
            return ArbitraryCallable(
                lambda p: json.dumps(

            def fn(
                prompt: str,
                *args,
                instructions: Optional[str] = None,
                msg_history: Optional[List[Dict[str, str]]] = None,
                **kwargs,
            ) -> str:
                return json.dumps(
                    Jsonformer(
                        model=model,
                        tokenizer=tokenizer,
                        json_schema=self.output_schema,
                        prompt=p,
                        prompt=prompt,
                    )()
                )
            )

            return ArbitraryCallable(fn)
        else:
            raise ValueError(
                "JsonFormatter can only be used with HuggingFace*Callable."
77 changes: 61 additions & 16 deletions guardrails/llm_providers.py
@@ -1,5 +1,6 @@
import asyncio

import inspect
from typing import (
    Any,
    Awaitable,
@@ -27,10 +28,10 @@
from guardrails.utils.openai_utils import (
    AsyncOpenAIClient,
    OpenAIClient,
    get_static_openai_acreate_func,
    get_static_openai_chat_acreate_func,
    get_static_openai_chat_create_func,
    get_static_openai_create_func,
    is_static_openai_acreate_func,
    is_static_openai_chat_acreate_func,
    is_static_openai_chat_create_func,
    is_static_openai_create_func,
)
from guardrails.utils.pydantic_utils import convert_pydantic_model_to_openai_fn
from guardrails.utils.safe_get import safe_get
@@ -711,6 +712,26 @@ def _invoke_llm(self, prompt: str, pipeline: Any, *args, **kwargs) -> LLMRespons

class ArbitraryCallable(PromptCallableBase):
    def __init__(self, llm_api: Optional[Callable] = None, *args, **kwargs):
        llm_api_args = inspect.getfullargspec(llm_api)
        if not llm_api_args.args:
            raise ValueError(
                "Custom LLM callables must accept"
                " at least one positional argument for prompt!"
            )
        if not llm_api_args.varkw:
            raise ValueError("Custom LLM callables must accept **kwargs!")
        if (
            not llm_api_args.kwonlyargs
            or "instructions" not in llm_api_args.kwonlyargs
            or "msg_history" not in llm_api_args.kwonlyargs
        ):
            warnings.warn(
                "We recommend including 'instructions' and 'msg_history'"
                " as keyword-only arguments for custom LLM callables."
                " Doing so ensures these arguments are not unintentionally"
                " passed through to other calls via **kwargs.",
                UserWarning,
            )
        self.llm_api = llm_api
        super().__init__(*args, **kwargs)

@@ -784,9 +805,9 @@ def get_llm_ask(
    except ImportError:
        pass

    if llm_api == get_static_openai_create_func():
    if is_static_openai_create_func(llm_api):
        return OpenAICallable(*args, **kwargs)
    if llm_api == get_static_openai_chat_create_func():
    if is_static_openai_chat_create_func(llm_api):
        return OpenAIChatCallable(*args, **kwargs)

    try:
@@ -1190,6 +1211,26 @@ async def invoke_llm(

class AsyncArbitraryCallable(AsyncPromptCallableBase):
    def __init__(self, llm_api: Callable, *args, **kwargs):
        llm_api_args = inspect.getfullargspec(llm_api)
        if not llm_api_args.args:
            raise ValueError(
                "Custom LLM callables must accept"
                " at least one positional argument for prompt!"
            )
        if not llm_api_args.varkw:
            raise ValueError("Custom LLM callables must accept **kwargs!")
        if (
            not llm_api_args.kwonlyargs
            or "instructions" not in llm_api_args.kwonlyargs
            or "msg_history" not in llm_api_args.kwonlyargs
        ):
            warnings.warn(
                "We recommend including 'instructions' and 'msg_history'"
                " as keyword-only arguments for custom LLM callables."
                " Doing so ensures these arguments are not unintentionally"
                " passed through to other calls via **kwargs.",
                UserWarning,
            )
        self.llm_api = llm_api
        super().__init__(*args, **kwargs)

@@ -1241,7 +1282,7 @@ async def invoke_llm(self, *args, **kwargs) -> LLMResponse:


def get_async_llm_ask(
    llm_api: Callable[[Any], Awaitable[Any]], *args, **kwargs
    llm_api: Callable[..., Awaitable[Any]], *args, **kwargs
) -> AsyncPromptCallableBase:
    try:
        import litellm
@@ -1252,9 +1293,12 @@ def get_async_llm_ask(
        pass

    # these only work with openai v0 (None otherwise)
    if llm_api == get_static_openai_acreate_func():
    # We no longer support OpenAI v0
    # We should drop these checks or update the logic to support
    # OpenAI v1 clients instead of just static methods
    if is_static_openai_acreate_func(llm_api):
        return AsyncOpenAICallable(*args, **kwargs)
    if llm_api == get_static_openai_chat_acreate_func():
    if is_static_openai_chat_acreate_func(llm_api):
        return AsyncOpenAIChatCallable(*args, **kwargs)

    try:
@@ -1265,11 +1309,12 @@ def get_async_llm_ask(
    except ImportError:
        pass

    return AsyncArbitraryCallable(*args, llm_api=llm_api, **kwargs)
    if llm_api is not None:
        return AsyncArbitraryCallable(*args, llm_api=llm_api, **kwargs)


def model_is_supported_server_side(
    llm_api: Optional[Union[Callable, Callable[[Any], Awaitable[Any]]]] = None,
    llm_api: Optional[Union[Callable, Callable[..., Awaitable[Any]]]] = None,
    *args,
    **kwargs,
) -> bool:
@@ -1289,17 +1334,17 @@ def model_is_supported_server_side(

# CONTINUOUS FIXME: Update with newly supported LLMs
def get_llm_api_enum(
    llm_api: Callable[[Any], Awaitable[Any]], *args, **kwargs
    llm_api: Callable[..., Awaitable[Any]], *args, **kwargs
) -> Optional[LLMResource]:
    # TODO: Distinguish between v1 and v2
    model = get_llm_ask(llm_api, *args, **kwargs)
    if llm_api == get_static_openai_create_func():
    if is_static_openai_create_func(llm_api):
        return LLMResource.OPENAI_DOT_COMPLETION_DOT_CREATE
    elif llm_api == get_static_openai_chat_create_func():
    elif is_static_openai_chat_create_func(llm_api):
        return LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_CREATE
    elif llm_api == get_static_openai_acreate_func():
    elif is_static_openai_acreate_func(llm_api):  # This is always False
        return LLMResource.OPENAI_DOT_COMPLETION_DOT_ACREATE
    elif llm_api == get_static_openai_chat_acreate_func():
    elif is_static_openai_chat_acreate_func(llm_api):  # This is always False
        return LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_ACREATE
    elif isinstance(model, LiteLLMCallable):
        return LLMResource.LITELLM_DOT_COMPLETION
16 changes: 8 additions & 8 deletions guardrails/utils/openai_utils/__init__.py
@@ -2,18 +2,18 @@
from .v1 import OpenAIClientV1 as OpenAIClient
from .v1 import (
    OpenAIServiceUnavailableError,
    get_static_openai_acreate_func,
    get_static_openai_chat_acreate_func,
    get_static_openai_chat_create_func,
    get_static_openai_create_func,
    is_static_openai_acreate_func,
    is_static_openai_chat_acreate_func,
    is_static_openai_chat_create_func,
    is_static_openai_create_func,
)

__all__ = [
    "AsyncOpenAIClient",
    "OpenAIClient",
    "get_static_openai_create_func",
    "get_static_openai_chat_create_func",
    "get_static_openai_acreate_func",
    "get_static_openai_chat_acreate_func",
    "is_static_openai_create_func",
    "is_static_openai_chat_create_func",
    "is_static_openai_acreate_func",
    "is_static_openai_chat_acreate_func",
    "OpenAIServiceUnavailableError",
]
28 changes: 19 additions & 9 deletions guardrails/utils/openai_utils/v1.py
@@ -1,4 +1,4 @@
from typing import Any, AsyncIterable, Dict, Iterable, List, cast
from typing import Any, AsyncIterable, Callable, Dict, Iterable, List, Optional, cast

import openai

@@ -12,20 +12,30 @@
from guardrails.telemetry import trace_llm_call, trace_operation


def get_static_openai_create_func():
    return openai.completions.create
def is_static_openai_create_func(llm_api: Optional[Callable]) -> bool:
    try:
        return llm_api == openai.completions.create
    except openai.OpenAIError:
        return False


def get_static_openai_chat_create_func():
    return openai.chat.completions.create
def is_static_openai_chat_create_func(llm_api: Optional[Callable]) -> bool:
    try:
        return llm_api == openai.chat.completions.create
    except openai.OpenAIError:
        return False


def get_static_openai_acreate_func():
    return None
def is_static_openai_acreate_func(llm_api: Optional[Callable]) -> bool:
    # Because the static version of this does not exist in OpenAI 1.x
    # Can we just drop these checks?
    return False


def get_static_openai_chat_acreate_func():
    return None
def is_static_openai_chat_acreate_func(llm_api: Optional[Callable]) -> bool:
    # Because the static version of this does not exist in OpenAI 1.x
    # Can we just drop these checks?
    return False


OpenAIServiceUnavailableError = openai.APIError
21 changes: 21 additions & 0 deletions tests/integration_tests/test_assets/custom_llm.py
@@ -0,0 +1,21 @@
from typing import Dict, List, Optional


def mock_llm(
    prompt: Optional[str] = None,
    *args,
    instructions: Optional[str] = None,
    msg_history: Optional[List[Dict[str, str]]] = None,
    **kwargs,
) -> str:
    return ""


async def mock_async_llm(
    prompt: Optional[str] = None,
    *args,
    instructions: Optional[str] = None,
    msg_history: Optional[List[Dict[str, str]]] = None,
    **kwargs,
) -> str:
    return ""
