From 7084245e8f416316dde17d48236e1754fad54d57 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Wed, 18 Sep 2024 12:00:21 -0500 Subject: [PATCH 1/5] safeguard static openai funciton checks; update async tests to use custom llm --- guardrails/applications/text2sql.py | 4 +- guardrails/llm_providers.py | 27 ++-- guardrails/utils/openai_utils/__init__.py | 16 +-- guardrails/utils/openai_utils/v1.py | 28 ++-- .../test_assets/entity_extraction/__init__.py | 4 + .../non_openai_compiled_prompt.txt | 131 ++++++++++++++++++ .../non_openai_compiled_prompt_reask.txt | 41 ++++++ tests/integration_tests/test_async.py | 57 +++----- tests/integration_tests/test_guard.py | 49 +++---- tests/integration_tests/test_multi_reask.py | 4 +- tests/integration_tests/test_parsing.py | 5 +- tests/integration_tests/test_pydantic.py | 10 +- tests/integration_tests/test_python_rail.py | 9 +- tests/unit_tests/test_validator_base.py | 20 +-- 14 files changed, 279 insertions(+), 126 deletions(-) create mode 100644 tests/integration_tests/test_assets/entity_extraction/non_openai_compiled_prompt.txt create mode 100644 tests/integration_tests/test_assets/entity_extraction/non_openai_compiled_prompt_reask.txt diff --git a/guardrails/applications/text2sql.py b/guardrails/applications/text2sql.py index b2fe86333..8c97b7d12 100644 --- a/guardrails/applications/text2sql.py +++ b/guardrails/applications/text2sql.py @@ -1,6 +1,7 @@ import asyncio import json import os +import openai from string import Template from typing import Callable, Dict, Optional, Type, cast @@ -8,7 +9,6 @@ from guardrails.document_store import DocumentStoreBase, EphemeralDocumentStore from guardrails.embedding import EmbeddingBase, OpenAIEmbedding from guardrails.guard import Guard -from guardrails.utils.openai_utils import get_static_openai_create_func from guardrails.utils.sql_utils import create_sql_driver from guardrails.vectordb import Faiss, VectorDBBase @@ -89,7 +89,7 @@ def __init__( reask_prompt: Prompt to use for reasking. Defaults to REASK_PROMPT. """ if llm_api is None: - llm_api = get_static_openai_create_func() + llm_api = openai.completions.create self.example_formatter = example_formatter self.llm_api = llm_api diff --git a/guardrails/llm_providers.py b/guardrails/llm_providers.py index d601d228a..f72ccd374 100644 --- a/guardrails/llm_providers.py +++ b/guardrails/llm_providers.py @@ -27,10 +27,10 @@ from guardrails.utils.openai_utils import ( AsyncOpenAIClient, OpenAIClient, - get_static_openai_acreate_func, - get_static_openai_chat_acreate_func, - get_static_openai_chat_create_func, - get_static_openai_create_func, + is_static_openai_acreate_func, + is_static_openai_chat_acreate_func, + is_static_openai_chat_create_func, + is_static_openai_create_func, ) from guardrails.utils.pydantic_utils import convert_pydantic_model_to_openai_fn from guardrails.utils.safe_get import safe_get @@ -784,9 +784,9 @@ def get_llm_ask( except ImportError: pass - if llm_api == get_static_openai_create_func(): + if is_static_openai_create_func(llm_api): return OpenAICallable(*args, **kwargs) - if llm_api == get_static_openai_chat_create_func(): + if is_static_openai_chat_create_func(llm_api): return OpenAIChatCallable(*args, **kwargs) try: @@ -1252,9 +1252,12 @@ def get_async_llm_ask( pass # these only work with openai v0 (None otherwise) - if llm_api == get_static_openai_acreate_func(): + # We no longer support OpenAI v0 + # We should drop these checks or update the logic to support + # OpenAI v1 clients instead of just static methods + if is_static_openai_acreate_func(llm_api): return AsyncOpenAICallable(*args, **kwargs) - if llm_api == get_static_openai_chat_acreate_func(): + if is_static_openai_chat_acreate_func(llm_api): return AsyncOpenAIChatCallable(*args, **kwargs) try: @@ -1293,13 +1296,13 @@ def get_llm_api_enum( ) -> Optional[LLMResource]: # TODO: Distinguish between v1 and v2 model = get_llm_ask(llm_api, *args, **kwargs) - if llm_api == get_static_openai_create_func(): + if is_static_openai_create_func(llm_api): return LLMResource.OPENAI_DOT_COMPLETION_DOT_CREATE - elif llm_api == get_static_openai_chat_create_func(): + elif is_static_openai_chat_create_func(llm_api): return LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_CREATE - elif llm_api == get_static_openai_acreate_func(): + elif is_static_openai_acreate_func(llm_api): # This is always False return LLMResource.OPENAI_DOT_COMPLETION_DOT_ACREATE - elif llm_api == get_static_openai_chat_acreate_func(): + elif is_static_openai_chat_acreate_func(llm_api): # This is always False return LLMResource.OPENAI_DOT_CHAT_COMPLETION_DOT_ACREATE elif isinstance(model, LiteLLMCallable): return LLMResource.LITELLM_DOT_COMPLETION diff --git a/guardrails/utils/openai_utils/__init__.py b/guardrails/utils/openai_utils/__init__.py index 7f489812a..3785f6448 100644 --- a/guardrails/utils/openai_utils/__init__.py +++ b/guardrails/utils/openai_utils/__init__.py @@ -2,18 +2,18 @@ from .v1 import OpenAIClientV1 as OpenAIClient from .v1 import ( OpenAIServiceUnavailableError, - get_static_openai_acreate_func, - get_static_openai_chat_acreate_func, - get_static_openai_chat_create_func, - get_static_openai_create_func, + is_static_openai_acreate_func, + is_static_openai_chat_acreate_func, + is_static_openai_chat_create_func, + is_static_openai_create_func, ) __all__ = [ "AsyncOpenAIClient", "OpenAIClient", - "get_static_openai_create_func", - "get_static_openai_chat_create_func", - "get_static_openai_acreate_func", - "get_static_openai_chat_acreate_func", + "is_static_openai_create_func", + "is_static_openai_chat_create_func", + "is_static_openai_acreate_func", + "is_static_openai_chat_acreate_func", "OpenAIServiceUnavailableError", ] diff --git a/guardrails/utils/openai_utils/v1.py b/guardrails/utils/openai_utils/v1.py index 1d1bb5bac..93e24042c 100644 --- a/guardrails/utils/openai_utils/v1.py +++ b/guardrails/utils/openai_utils/v1.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncIterable, Dict, Iterable, List, cast +from typing import Any, AsyncIterable, Callable, Dict, Iterable, List, Optional, cast import openai @@ -12,20 +12,30 @@ from guardrails.telemetry import trace_llm_call, trace_operation -def get_static_openai_create_func(): - return openai.completions.create +def is_static_openai_create_func(llm_api: Optional[Callable]) -> bool: + try: + return llm_api == openai.completions.create + except openai.OpenAIError: + return False -def get_static_openai_chat_create_func(): - return openai.chat.completions.create +def is_static_openai_chat_create_func(llm_api: Optional[Callable]) -> bool: + try: + return llm_api == openai.chat.completions.create + except openai.OpenAIError: + return False -def get_static_openai_acreate_func(): - return None +def is_static_openai_acreate_func(llm_api: Optional[Callable]) -> bool: + # Because the static version of this does not exist in OpenAI 1.x + # Can we just drop these checks? + return False -def get_static_openai_chat_acreate_func(): - return None +def is_static_openai_chat_acreate_func(llm_api: Optional[Callable]) -> bool: + # Because the static version of this does not exist in OpenAI 1.x + # Can we just drop these checks? + return False OpenAIServiceUnavailableError = openai.APIError diff --git a/tests/integration_tests/test_assets/entity_extraction/__init__.py b/tests/integration_tests/test_assets/entity_extraction/__init__.py index 61b4dc3b8..e9cc75457 100644 --- a/tests/integration_tests/test_assets/entity_extraction/__init__.py +++ b/tests/integration_tests/test_assets/entity_extraction/__init__.py @@ -33,10 +33,12 @@ # Compiled prompts COMPILED_PROMPT = reader("compiled_prompt.txt") +NON_OPENAI_COMPILED_PROMPT = reader("non_openai_compiled_prompt.txt") COMPILED_PROMPT_WITHOUT_INSTRUCTIONS = reader( "compiled_prompt_without_instructions.txt" ) COMPILED_PROMPT_REASK = reader("compiled_prompt_reask.txt") +NON_OPENAI_COMPILED_PROMPT_REASK = reader("non_openai_compiled_prompt_reask.txt") COMPILED_PROMPT_REASK_WITHOUT_INSTRUCTIONS = reader( "compiled_prompt_reask_without_instructions.txt" ) @@ -82,8 +84,10 @@ __all__ = [ "COMPILED_PROMPT", + "NON_OPENAI_COMPILED_PROMPT", "COMPILED_PROMPT_WITHOUT_INSTRUCTIONS", "COMPILED_PROMPT_REASK", + "NON_OPENAI_COMPILED_PROMPT_REASK", "COMPILED_PROMPT_REASK_WITHOUT_INSTRUCTIONS", "COMPILED_INSTRUCTIONS", "COMPILED_INSTRUCTIONS_REASK", diff --git a/tests/integration_tests/test_assets/entity_extraction/non_openai_compiled_prompt.txt b/tests/integration_tests/test_assets/entity_extraction/non_openai_compiled_prompt.txt new file mode 100644 index 000000000..1ee8e898d --- /dev/null +++ b/tests/integration_tests/test_assets/entity_extraction/non_openai_compiled_prompt.txt @@ -0,0 +1,131 @@ + +Given the following document, answer the following questions. If the answer doesn't exist in the document, enter 'None'. + +2/25/23, 7:59 PM about:blank +about:blank 1/4 +PRICING INFORMATION +INTEREST RATES AND INTEREST CHARGES +Purchase Annual +Percentage Rate (APR) 0% Intro APR for the first 18 months that your Account is open. +After that, 19.49%. This APR will vary with the market based on the Prime +Rate. +a +My Chase Loan +SM APR 19.49%. This APR will vary with the market based on the Prime Rate. +a +Promotional offers with fixed APRs and varying durations may be available from +time to time on some accounts. +Balance Transfer APR 0% Intro APR for the first 18 months that your Account is open. +After that, 19.49%. This APR will vary with the market based on the Prime +Rate. +a +Cash Advance APR 29.49%. This APR will vary with the market based on the Prime Rate. +b +Penalty APR and When +It Applies +Up to 29.99%. This APR will vary with the market based on the Prime Rate. +c +We may apply the Penalty APR to your account if you: +fail to make a Minimum Payment by the date and time that it is due; or +make a payment to us that is returned unpaid. +How Long Will the Penalty APR Apply?: If we apply the Penalty APR for +either of these reasons, the Penalty APR could potentially remain in effect +indefinitely. +How to Avoid Paying +Interest on Purchases +Your due date will be a minimum of 21 days after the close of each billing cycle. +We will not charge you interest on new purchases if you pay your entire balance +or Interest Saving Balance by the due date each month. We will begin charging +interest on balance transfers and cash advances on the transaction date. +Minimum Interest +Charge +None +Credit Card Tips from +the Consumer Financial +Protection Bureau +To learn more about factors to consider when applying for or using a credit card, +visit the website of the Consumer Financial Protection Bureau at +http://www.consumerfinance.gov/learnmore. +FEES +Annual Membership +Fee +None +My Chase Plan +SM Fee +(fixed finance charge) +Monthly fee of 0% of the amount of each eligible purchase transaction or +amount selected to create a My Chase Plan while in the 0% Intro Purchase +APR period. +After that, monthly fee of 1.72% of the amount of each eligible purchase +transaction or amount selected to create a My Chase Plan. The My Chase Plan +Fee will be determined at the time each My Chase Plan is created and will +remain the same until the My Chase Plan is paid in full. +d +Transaction Fees +Balance Transfers Intro fee of either $5 or 3% of the amount of each transfer, whichever is greater, +on transfers made within 60 days of account opening. After that: Either $5 or 5% +of the amount of each transfer, whichever is greater. +Cash Advances Either $10 or 5% of the amount of each transaction, whichever is greater. +2/25/23, 7:59 PM about:blank +about:blank 2/4 +Foreign Transactions 3% of the amount of each transaction in U.S. dollars. +Penalty Fees +Late Payment Up to $40. +Over-the-Credit-Limit None +Return Payment Up to $40. +Return Check None +Note: This account may not be eligible for balance transfers. +Loss of Intro APR: We will end your introductory APR if any required Minimum Payment is 60 days late, and +apply the Penalty APR. +How We Will Calculate Your Balance: We use the daily balance method (including new transactions). +Prime Rate: Variable APRs are based on the 7.75% Prime Rate as of 2/7/2023. +aWe add 11.74% to the Prime Rate to determine the Purchase/My Chase Loan/Balance Transfer APR. +Maximum APR 29.99%. +bWe add 21.74% to the Prime Rate to determine the Cash Advance APR. Maximum APR 29.99%. +cWe add up to 26.99% to the Prime Rate to determine the Penalty APR. Maximum APR 29.99%. +dMy Chase Plan Fee: The My Chase Plan Fee is calculated at the time each plan is created and is based on +the amount of each purchase transaction or amount selected to create the plan, the number of billing periods +you choose to pay the balance in full, and other factors. The monthly and aggregate dollar amount of your My +Chase Plan Fee will be disclosed during the activation of each My Chase Plan. +MILITARY LENDING ACT NOTICE: Federal law provides important protections to members of the Armed +Forces and their dependents relating to extensions of consumer credit. In general, the cost of consumer credit +to a member of the Armed Forces and his or her dependent may not exceed an annual percentage rate of 36 +percent. This rate must include, as applicable to the credit transaction or account: the costs associated with +credit insurance premiums; fees for ancillary products sold in connection with the credit transaction; any +application fee charged (other than certain application fees for specified credit transactions or accounts); and +any participation fee charged (other than certain participation fees for a credit card account). To receive this +information and a description of your payment obligation verbally, please call 1-800-235-9978. +TERMS & CONDITIONS +Authorization: When you respond to this credit card offer from JPMorgan Chase Bank, N.A., Member FDIC, a +subsidiary of JPMorgan Chase & Co. ("Chase", "we", or "us"), you agree to the following: +1. You authorize us to obtain credit bureau reports, employment, and income information about you that we +will use when considering your application for credit. We may obtain and use information about your +accounts with us and others such as Checking, Deposit, Investment, and Utility accounts from credit +bureaus and other entities. You also authorize us to obtain credit bureau reports and any other +information about you in connection with: 1) extensions of credit on your account; 2) the administration, +review or collection of your account; and 3) offering you enhanced or additional products and services. If +you ask, we will tell you the name and address of the credit bureau from which we obtained a report +about you. +2. If an account is opened, you will receive a Cardmember Agreement with your card(s). You agree to the +terms of this agreement by: using the account or any card, authorizing their use, or making any payment +on the account. +3. By providing your mobile ph + + +Given below is XML that describes the information to extract from this document and the tags to extract it into. + + + + + + + + + + + + + + + +ONLY return a valid JSON object (no other text is necessary). The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. diff --git a/tests/integration_tests/test_assets/entity_extraction/non_openai_compiled_prompt_reask.txt b/tests/integration_tests/test_assets/entity_extraction/non_openai_compiled_prompt_reask.txt new file mode 100644 index 000000000..7398047fb --- /dev/null +++ b/tests/integration_tests/test_assets/entity_extraction/non_openai_compiled_prompt_reask.txt @@ -0,0 +1,41 @@ + +I was given the following JSON response, which had problems due to incorrect values. + +{ + "fees": [ + { + "name": { + "incorrect_value": "my chase plan", + "error_messages": [ + "must be exactly two words" + ] + } + }, + { + "name": { + "incorrect_value": "over-the-credit-limit", + "error_messages": [ + "must be exactly two words" + ] + } + } + ] +} + +Help me correct the incorrect values based on the given error messages. + +Given below is XML that describes the information to extract from this document and the tags to extract it into. + + + + + + + + + + + + + +ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`. diff --git a/tests/integration_tests/test_async.py b/tests/integration_tests/test_async.py index ddab76016..01627ed44 100644 --- a/tests/integration_tests/test_async.py +++ b/tests/integration_tests/test_async.py @@ -17,11 +17,10 @@ async def mock_llm(*args, **kwargs): @pytest.mark.asyncio -@pytest.mark.parametrize("multiprocessing_validators", (True, False)) -async def test_entity_extraction_with_reask(mocker, multiprocessing_validators: bool): +async def test_entity_extraction_with_reask(mocker): """Test that the entity extraction works with re-asking.""" mock_invoke_llm = mocker.patch( - "guardrails.llm_providers.AsyncOpenAICallable.invoke_llm", + "guardrails.llm_providers.AsyncArbitraryCallable.invoke_llm", ) mock_invoke_llm.side_effect = [ LLMResponse( @@ -37,13 +36,6 @@ async def test_entity_extraction_with_reask(mocker, multiprocessing_validators: response_token_count=1234, ), ] - mocker.patch( - "guardrails.llm_providers.get_static_openai_acreate_func", return_value=mock_llm - ) - mocker.patch( - "guardrails.validators.Validator.run_in_separate_process", - new=multiprocessing_validators, - ) content = docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = AsyncGuard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_REASK) @@ -75,9 +67,9 @@ async def test_entity_extraction_with_reask(mocker, multiprocessing_validators: # For orginal prompt and output first = call.iterations.first - assert first.inputs.prompt == Prompt(entity_extraction.COMPILED_PROMPT) + assert first.inputs.prompt == Prompt(entity_extraction.NON_OPENAI_COMPILED_PROMPT) # Same as above - assert call.compiled_prompt == entity_extraction.COMPILED_PROMPT + assert call.compiled_prompt == entity_extraction.NON_OPENAI_COMPILED_PROMPT assert first.prompt_tokens_consumed == 123 assert first.completion_tokens_consumed == 1234 assert first.raw_output == entity_extraction.LLM_OUTPUT @@ -85,9 +77,11 @@ async def test_entity_extraction_with_reask(mocker, multiprocessing_validators: # For re-asked prompt and output final = call.iterations.last - assert final.inputs.prompt == Prompt(entity_extraction.COMPILED_PROMPT_REASK) + assert final.inputs.prompt == Prompt( + entity_extraction.NON_OPENAI_COMPILED_PROMPT_REASK + ) # Same as above - assert call.reask_prompts.last == entity_extraction.COMPILED_PROMPT_REASK + assert call.reask_prompts.last == entity_extraction.NON_OPENAI_COMPILED_PROMPT_REASK # TODO: Re-enable once field level reasking is supported # assert final.raw_output == entity_extraction.LLM_OUTPUT_REASK @@ -98,7 +92,7 @@ async def test_entity_extraction_with_reask(mocker, multiprocessing_validators: @pytest.mark.asyncio async def test_entity_extraction_with_noop(mocker): mock_invoke_llm = mocker.patch( - "guardrails.llm_providers.AsyncOpenAICallable.invoke_llm", + "guardrails.llm_providers.AsyncArbitraryCallable.invoke_llm", ) mock_invoke_llm.side_effect = [ LLMResponse( @@ -107,9 +101,6 @@ async def test_entity_extraction_with_noop(mocker): response_token_count=1234, ) ] - mocker.patch( - "guardrails.llm_providers.get_static_openai_acreate_func", return_value=mock_llm - ) content = docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = AsyncGuard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_NOOP) final_output = await guard( @@ -137,7 +128,7 @@ async def test_entity_extraction_with_noop(mocker): assert call.iterations.length == 1 # For orginal prompt and output - assert call.compiled_prompt == entity_extraction.COMPILED_PROMPT + assert call.compiled_prompt == entity_extraction.NON_OPENAI_COMPILED_PROMPT assert call.raw_outputs.last == entity_extraction.LLM_OUTPUT assert call.validation_response == entity_extraction.VALIDATED_OUTPUT_NOOP @@ -145,7 +136,7 @@ async def test_entity_extraction_with_noop(mocker): @pytest.mark.asyncio async def test_entity_extraction_with_noop_pydantic(mocker): mock_invoke_llm = mocker.patch( - "guardrails.llm_providers.AsyncOpenAICallable.invoke_llm", + "guardrails.llm_providers.AsyncArbitraryCallable.invoke_llm", ) mock_invoke_llm.side_effect = [ LLMResponse( @@ -154,9 +145,6 @@ async def test_entity_extraction_with_noop_pydantic(mocker): response_token_count=1234, ) ] - mocker.patch( - "guardrails.llm_providers.get_static_openai_acreate_func", return_value=mock_llm - ) content = docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = AsyncGuard.from_pydantic( entity_extraction.PYDANTIC_RAIL_WITH_NOOP, @@ -182,7 +170,7 @@ async def test_entity_extraction_with_noop_pydantic(mocker): assert call.iterations.length == 1 # For orginal prompt and output - assert call.compiled_prompt == entity_extraction.COMPILED_PROMPT + assert call.compiled_prompt == entity_extraction.NON_OPENAI_COMPILED_PROMPT assert call.raw_outputs.last == entity_extraction.LLM_OUTPUT assert call.validation_response == entity_extraction.VALIDATED_OUTPUT_NOOP @@ -191,7 +179,7 @@ async def test_entity_extraction_with_noop_pydantic(mocker): async def test_entity_extraction_with_filter(mocker): """Test that the entity extraction works with re-asking.""" mock_invoke_llm = mocker.patch( - "guardrails.llm_providers.AsyncOpenAICallable.invoke_llm", + "guardrails.llm_providers.AsyncArbitraryCallable.invoke_llm", ) mock_invoke_llm.side_effect = [ LLMResponse( @@ -200,9 +188,6 @@ async def test_entity_extraction_with_filter(mocker): response_token_count=1234, ) ] - mocker.patch( - "guardrails.llm_providers.get_static_openai_acreate_func", return_value=mock_llm - ) content = docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = AsyncGuard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_FILTER) @@ -223,7 +208,7 @@ async def test_entity_extraction_with_filter(mocker): assert call.iterations.length == 1 # For orginal prompt and output - assert call.compiled_prompt == entity_extraction.COMPILED_PROMPT + assert call.compiled_prompt == entity_extraction.NON_OPENAI_COMPILED_PROMPT assert call.raw_outputs.last == entity_extraction.LLM_OUTPUT assert call.validation_response == entity_extraction.VALIDATED_OUTPUT_FILTER assert call.guarded_output is None @@ -234,7 +219,7 @@ async def test_entity_extraction_with_filter(mocker): async def test_entity_extraction_with_fix(mocker): """Test that the entity extraction works with re-asking.""" mock_invoke_llm = mocker.patch( - "guardrails.llm_providers.AsyncOpenAICallable.invoke_llm", + "guardrails.llm_providers.AsyncArbitraryCallable.invoke_llm", ) mock_invoke_llm.side_effect = [ LLMResponse( @@ -243,9 +228,6 @@ async def test_entity_extraction_with_fix(mocker): response_token_count=1234, ) ] - mocker.patch( - "guardrails.llm_providers.get_static_openai_acreate_func", return_value=mock_llm - ) content = docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = AsyncGuard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_FIX) @@ -265,7 +247,7 @@ async def test_entity_extraction_with_fix(mocker): assert guard.history.length == 1 # For orginal prompt and output - assert call.compiled_prompt == entity_extraction.COMPILED_PROMPT + assert call.compiled_prompt == entity_extraction.NON_OPENAI_COMPILED_PROMPT assert call.raw_outputs.last == entity_extraction.LLM_OUTPUT assert call.guarded_output == entity_extraction.VALIDATED_OUTPUT_FIX @@ -274,7 +256,7 @@ async def test_entity_extraction_with_fix(mocker): async def test_entity_extraction_with_refrain(mocker): """Test that the entity extraction works with re-asking.""" mock_invoke_llm = mocker.patch( - "guardrails.llm_providers.AsyncOpenAICallable.invoke_llm", + "guardrails.llm_providers.AsyncArbitraryCallable.invoke_llm", ) mock_invoke_llm.side_effect = [ LLMResponse( @@ -283,9 +265,6 @@ async def test_entity_extraction_with_refrain(mocker): response_token_count=1234, ) ] - mocker.patch( - "guardrails.llm_providers.get_static_openai_acreate_func", return_value=mock_llm - ) content = docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = AsyncGuard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_REFRAIN) @@ -305,7 +284,7 @@ async def test_entity_extraction_with_refrain(mocker): assert guard.history.length == 1 # For orginal prompt and output - assert call.compiled_prompt == entity_extraction.COMPILED_PROMPT + assert call.compiled_prompt == entity_extraction.NON_OPENAI_COMPILED_PROMPT assert call.raw_outputs.last == entity_extraction.LLM_OUTPUT assert call.guarded_output == entity_extraction.VALIDATED_OUTPUT_REFRAIN diff --git a/tests/integration_tests/test_guard.py b/tests/integration_tests/test_guard.py index 533dc5ebc..61598c7f4 100644 --- a/tests/integration_tests/test_guard.py +++ b/tests/integration_tests/test_guard.py @@ -2,6 +2,7 @@ import importlib import json import os +import openai from typing import List, Optional, Union import pytest @@ -15,10 +16,6 @@ from guardrails.classes.validation_outcome import ValidationOutcome from guardrails.classes.validation.validation_result import FailResult from guardrails.guard import Guard -from guardrails.utils.openai_utils import ( - get_static_openai_chat_create_func, - get_static_openai_create_func, -) from guardrails.actions.reask import FieldReAsk from tests.integration_tests.test_assets.validators import ( RegexMatch, @@ -182,7 +179,7 @@ def test_entity_extraction_with_reask( guard = guard_initializer(rail, prompt) final_output: ValidationOutcome = guard( - llm_api=get_static_openai_create_func(), + llm_api=openai.completions.create, prompt_params={"document": content[:6000]}, num_reasks=1, max_tokens=2000, @@ -267,7 +264,7 @@ def test_entity_extraction_with_noop(mocker, rail, prompt): content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = guard_initializer(rail, prompt) final_output = guard( - llm_api=get_static_openai_create_func(), + llm_api=openai.completions.create, prompt_params={"document": content[:6000]}, num_reasks=1, ) @@ -313,7 +310,7 @@ def test_entity_extraction_with_filter(mocker, rail, prompt): content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = guard_initializer(rail, prompt) final_output = guard( - llm_api=get_static_openai_create_func(), + llm_api=openai.completions.create, prompt_params={"document": content[:6000]}, num_reasks=1, ) @@ -348,7 +345,7 @@ def test_entity_extraction_with_fix(mocker, rail, prompt): content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = guard_initializer(rail, prompt) final_output = guard( - llm_api=get_static_openai_create_func(), + llm_api=openai.completions.create, prompt_params={"document": content[:6000]}, num_reasks=1, ) @@ -384,7 +381,7 @@ def test_entity_extraction_with_refrain(mocker, rail, prompt): content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = guard_initializer(rail, prompt) final_output = guard( - llm_api=get_static_openai_create_func(), + llm_api=openai.completions.create, prompt_params={"document": content[:6000]}, num_reasks=1, ) @@ -430,7 +427,7 @@ def test_entity_extraction_with_fix_chat_models(mocker, rail, prompt, instructio content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = guard_initializer(rail, prompt, instructions) final_output = guard( - llm_api=get_static_openai_chat_create_func(), + llm_api=openai.chat.completions.create, prompt_params={"document": content[:6000]}, num_reasks=1, ) @@ -460,7 +457,7 @@ def test_entity_extraction_with_fix_chat_models(mocker, rail, prompt, instructio guard = gd.Guard.from_rail_string(string.RAIL_SPEC_FOR_LIST) _, final_output, *rest = guard( - llm_api=get_static_openai_create_func(), + llm_api=openai.completions.create, num_reasks=1, ) assert final_output == string.LIST_LLM_OUTPUT @@ -487,7 +484,7 @@ def test_entity_extraction_with_fix_chat_models(mocker, rail, prompt, instructio entity_extraction.OPTIONAL_PROMPT_COMPLETION_MODEL, None, None, - get_static_openai_create_func(), + openai.completions.create, entity_extraction.COMPILED_PROMPT, None, entity_extraction.COMPILED_PROMPT_REASK, @@ -504,7 +501,7 @@ def test_entity_extraction_with_fix_chat_models(mocker, rail, prompt, instructio entity_extraction.OPTIONAL_PROMPT_CHAT_MODEL, entity_extraction.OPTIONAL_INSTRUCTIONS_CHAT_MODEL, None, - get_static_openai_chat_create_func(), + openai.chat.completions.create, entity_extraction.COMPILED_PROMPT_WITHOUT_INSTRUCTIONS, entity_extraction.COMPILED_INSTRUCTIONS, entity_extraction.COMPILED_PROMPT_REASK_WITHOUT_INSTRUCTIONS, @@ -521,7 +518,7 @@ def test_entity_extraction_with_fix_chat_models(mocker, rail, prompt, instructio None, None, entity_extraction.OPTIONAL_MSG_HISTORY, - get_static_openai_chat_create_func(), + openai.chat.completions.create, None, None, entity_extraction.COMPILED_PROMPT_REASK_WITHOUT_INSTRUCTIONS, @@ -558,7 +555,7 @@ def test_entity_extraction_with_reask_with_optional_prompts( for o in llm_outputs ] mock_openai_invoke_llm = None - if llm_api == get_static_openai_create_func(): + if llm_api == openai.completions.create: mock_openai_invoke_llm = mocker.patch( "guardrails.llm_providers.OpenAICallable._invoke_llm" ) @@ -669,7 +666,7 @@ def test_skeleton_reask(mocker): entity_extraction.RAIL_SPEC_WITH_SKELETON_REASK ) final_output = guard( - llm_api=get_static_openai_create_func(), + llm_api=openai.completions.create, prompt_params={"document": content[:6000]}, max_tokens=1000, num_reasks=1, @@ -713,7 +710,7 @@ def test_string_with_message_history_reask(mocker): guard = gd.Guard.from_rail_string(string.RAIL_SPEC_FOR_MSG_HISTORY) final_output = guard( - llm_api=get_static_openai_chat_create_func(), + llm_api=openai.chat.completions.create, msg_history=string.MOVIE_MSG_HISTORY, temperature=0.0, model="gpt-3.5-turbo", @@ -769,7 +766,7 @@ def test_pydantic_with_message_history_reask(mocker): guard = gd.Guard.from_pydantic(output_class=pydantic.WITH_MSG_HISTORY) final_output = guard( - llm_api=get_static_openai_chat_create_func(), + llm_api=openai.chat.completions.create, msg_history=string.MOVIE_MSG_HISTORY, temperature=0.0, model="gpt-3.5-turbo", @@ -818,7 +815,7 @@ def test_sequential_validator_log_is_not_duplicated(mocker): ) guard( - llm_api=get_static_openai_create_func(), + llm_api=openai.completions.create, prompt_params={"document": content[:6000]}, num_reasks=1, ) @@ -861,7 +858,7 @@ def test_in_memory_validator_log_is_not_duplicated(mocker): ) guard( - llm_api=get_static_openai_create_func(), + llm_api=openai.completions.create, prompt_params={"document": content[:6000]}, num_reasks=1, ) @@ -933,7 +930,7 @@ def test_guard_with_top_level_list_return_type(mocker, rail, prompt): guard = guard_initializer(rail, prompt=prompt) - output = guard(llm_api=get_static_openai_create_func()) + output = guard(llm_api=openai.completions.create) # Validate the output assert output.validated_output == [ @@ -989,7 +986,7 @@ def test_string_output(mocker): guard = gd.Guard.from_rail_string(string.RAIL_SPEC_FOR_STRING) final_output = guard( - llm_api=get_static_openai_create_func(), + llm_api=openai.completions.create, prompt_params={"ingredients": "tomato, cheese, sour cream"}, num_reasks=1, ) @@ -1071,7 +1068,7 @@ class Tasks(BaseModel): ] final_output = guard( - llm_api=get_static_openai_chat_create_func(), + llm_api=openai.chat.completions.create, msg_history=[ { "role": "user", @@ -1125,7 +1122,7 @@ def test_string_reask(mocker): guard = gd.Guard.from_rail_string(string.RAIL_SPEC_FOR_STRING_REASK) final_output = guard( - llm_api=get_static_openai_create_func(), + llm_api=openai.completions.create, prompt_params={"ingredients": "tomato, cheese, sour cream"}, num_reasks=1, max_tokens=100, @@ -1442,3 +1439,7 @@ def test_guard_use_many_same_instance_on_two_guards(self, mocker): guard_2.parse("some-other-name") assert init_spy.call_count == 1 + + +class TestCustomLLMApi: + pass diff --git a/tests/integration_tests/test_multi_reask.py b/tests/integration_tests/test_multi_reask.py index d1a6d3d42..e3037b93d 100644 --- a/tests/integration_tests/test_multi_reask.py +++ b/tests/integration_tests/test_multi_reask.py @@ -1,6 +1,6 @@ +import openai import guardrails as gd from guardrails.classes.llm.llm_response import LLMResponse -from guardrails.utils.openai_utils import get_static_openai_create_func import tests.integration_tests.test_assets.validators # noqa @@ -33,7 +33,7 @@ def test_multi_reask(mocker): guard = gd.Guard.from_rail_string(python_rail.RAIL_SPEC_WITH_VALIDATOR_PARALLELISM) guard( - llm_api=get_static_openai_create_func(), + llm_api=openai.completions.create, engine="text-davinci-003", num_reasks=5, ) diff --git a/tests/integration_tests/test_parsing.py b/tests/integration_tests/test_parsing.py index 3d63c2e2a..c185d4d67 100644 --- a/tests/integration_tests/test_parsing.py +++ b/tests/integration_tests/test_parsing.py @@ -1,11 +1,10 @@ from typing import Dict - +import openai import pytest import guardrails as gd from guardrails import register_validator from guardrails.classes.llm.llm_response import LLMResponse -from guardrails.utils.openai_utils import get_static_openai_chat_create_func from guardrails.validator_base import OnFailAction from guardrails.classes.validation.validation_result import FailResult, ValidationResult @@ -149,7 +148,7 @@ def always_fail(value: str, metadata: Dict) -> ValidationResult: guard.parse( llm_output="Tomato Cheese Pizza", - llm_api=get_static_openai_chat_create_func(), + llm_api=openai.chat.completions.create, msg_history=[ {"role": "system", "content": "Some content"}, {"role": "user", "content": "Some prompt"}, diff --git a/tests/integration_tests/test_pydantic.py b/tests/integration_tests/test_pydantic.py index df3f1e1b5..9bc60105a 100644 --- a/tests/integration_tests/test_pydantic.py +++ b/tests/integration_tests/test_pydantic.py @@ -1,6 +1,6 @@ import json from typing import Dict, List - +import openai import pytest from pydantic import BaseModel @@ -8,10 +8,6 @@ from guardrails.classes.generic.stack import Stack from guardrails.classes.history.call import Call from guardrails.classes.llm.llm_response import LLMResponse -from guardrails.utils.openai_utils import ( - get_static_openai_chat_create_func, - get_static_openai_create_func, -) from .mock_llm_outputs import pydantic from .test_assets.pydantic import VALIDATED_RESPONSE_REASK_PROMPT, ListOfPeople @@ -42,7 +38,7 @@ def test_pydantic_with_reask(mocker): guard = gd.Guard.from_pydantic(ListOfPeople, prompt=VALIDATED_RESPONSE_REASK_PROMPT) final_output = guard( - get_static_openai_create_func(), + openai.completions.create, engine="text-davinci-003", max_tokens=512, temperature=0.5, @@ -128,7 +124,7 @@ def test_pydantic_with_full_schema_reask(mocker): guard = gd.Guard.from_pydantic(ListOfPeople, prompt=VALIDATED_RESPONSE_REASK_PROMPT) final_output = guard( - get_static_openai_chat_create_func(), + openai.chat.completions.create, model="gpt-3.5-turbo", max_tokens=512, temperature=0.5, diff --git a/tests/integration_tests/test_python_rail.py b/tests/integration_tests/test_python_rail.py index 2dea75602..eecfa4738 100644 --- a/tests/integration_tests/test_python_rail.py +++ b/tests/integration_tests/test_python_rail.py @@ -2,16 +2,13 @@ from datetime import date, time from typing import List, Literal, Union +import openai import pytest from pydantic import BaseModel, Field, field_validator, model_validator import guardrails as gd from guardrails import Validator, register_validator from guardrails.classes.llm.llm_response import LLMResponse -from guardrails.utils.openai_utils import ( - get_static_openai_chat_create_func, - get_static_openai_create_func, -) from guardrails.types import OnFailAction from guardrails.classes.validation.validation_result import ( FailResult, @@ -131,7 +128,7 @@ def test_python_rail(mocker): # Guardrails runs validation and fixes the first failing output through reasking final_output = guard( - get_static_openai_chat_create_func(), + openai.chat.completions.create, prompt_params={"director": "Christopher Nolan"}, num_reasks=2, full_schema_reask=False, @@ -202,7 +199,7 @@ def test_python_string(mocker): instructions=instructions, ) final_output = guard( - llm_api=get_static_openai_create_func(), + llm_api=openai.completions.create, prompt_params={"ingredients": "tomato, cheese, sour cream"}, num_reasks=1, max_tokens=100, diff --git a/tests/unit_tests/test_validator_base.py b/tests/unit_tests/test_validator_base.py index 78069c349..74feeb325 100644 --- a/tests/unit_tests/test_validator_base.py +++ b/tests/unit_tests/test_validator_base.py @@ -1,16 +1,13 @@ import json import re from typing import Any, Dict, List - +import openai import pytest from pydantic import BaseModel, Field from guardrails import Guard, Validator, register_validator from guardrails.async_guard import AsyncGuard from guardrails.errors import ValidationError -from guardrails.utils.openai_utils import ( - get_static_openai_create_func, -) from guardrails.actions.reask import FieldReAsk from guardrails.actions.refrain import Refrain from guardrails.actions.filter import Filter @@ -771,10 +768,10 @@ def custom_llm(*args, **kwargs): [ ( OnFailAction.REASK, - "Prompt validation failed: incorrect_value='What kind of pet should I get?\\n\\nJson Output:\\n\\n' fail_results=[FailResult(outcome='fail', error_message='must be exactly two words', fix_value='What kind', error_spans=None, metadata=None, validated_chunk=None)] additional_properties={} path=None", # noqa + "Prompt validation failed: incorrect_value='What kind of pet should I get?' fail_results=[FailResult(outcome='fail', error_message='must be exactly two words', fix_value='What kind', error_spans=None, metadata=None, validated_chunk=None)] additional_properties={} path=None", # noqa "Instructions validation failed: incorrect_value='What kind of pet should I get?' fail_results=[FailResult(outcome='fail', error_message='must be exactly two words', fix_value='What kind', error_spans=None, metadata=None, validated_chunk=None)] additional_properties={} path=None", # noqa "Message history validation failed: incorrect_value='What kind of pet should I get?' fail_results=[FailResult(outcome='fail', error_message='must be exactly two words', fix_value='What kind', error_spans=None, metadata=None, validated_chunk=None)] additional_properties={} path=None", # noqa - "Prompt validation failed: incorrect_value='\\nThis is not two words\\n\\n\\nString Output:\\n\\n' fail_results=[FailResult(outcome='fail', error_message='must be exactly two words', fix_value='This is', error_spans=None, metadata=None, validated_chunk=None)] additional_properties={} path=None", # noqa + "Prompt validation failed: incorrect_value='\\nThis is not two words\\n' fail_results=[FailResult(outcome='fail', error_message='must be exactly two words', fix_value='This is', error_spans=None, metadata=None, validated_chunk=None)] additional_properties={} path=None", # noqa "Instructions validation failed: incorrect_value='\\nThis also is not two words\\n' fail_results=[FailResult(outcome='fail', error_message='must be exactly two words', fix_value='This also', error_spans=None, metadata=None, validated_chunk=None)] additional_properties={} path=None", # noqa ), ( @@ -819,11 +816,6 @@ async def custom_llm(*args, **kwargs): "Input Validation did not raise as expected!" ) - mocker.patch( - "guardrails.llm_providers.get_static_openai_acreate_func", - return_value=custom_llm, - ) - # with_prompt_validation guard = AsyncGuard.from_pydantic(output_class=Pet) guard.use(TwoWords(on_fail=on_fail), on="prompt") @@ -944,7 +936,7 @@ def test_input_validation_mismatch_raise(): with pytest.raises(ValueError): guard( - get_static_openai_create_func(), + openai.completions.create, msg_history=[ { "role": "user", @@ -959,7 +951,7 @@ def test_input_validation_mismatch_raise(): with pytest.raises(ValueError): guard( - get_static_openai_create_func(), + openai.completions.create, msg_history=[ { "role": "user", @@ -974,6 +966,6 @@ def test_input_validation_mismatch_raise(): with pytest.raises(ValueError): guard( - get_static_openai_create_func(), + openai.completions.create, prompt="What kind of pet should I get?", ) From 259d9173a68569c934ca88c1e1c637313e733b9a Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Wed, 18 Sep 2024 15:22:26 -0500 Subject: [PATCH 2/5] enforce method signature for custom llm callables --- guardrails/formatters/json_formatter.py | 34 ++- guardrails/llm_providers.py | 50 +++- .../test_assets/custom_llm.py | 21 ++ tests/integration_tests/test_async.py | 21 +- tests/integration_tests/test_guard.py | 216 +++++++++++++++++- tests/integration_tests/test_parsing.py | 11 +- tests/unit_tests/classes/history/test_call.py | 2 +- tests/unit_tests/mocks/mock_custom_llm.py | 24 +- tests/unit_tests/test_async_guard.py | 8 +- tests/unit_tests/test_llm_providers.py | 111 ++++++++- tests/unit_tests/test_validator_base.py | 12 +- 11 files changed, 455 insertions(+), 55 deletions(-) create mode 100644 tests/integration_tests/test_assets/custom_llm.py diff --git a/guardrails/formatters/json_formatter.py b/guardrails/formatters/json_formatter.py index 791b823f1..491999f91 100644 --- a/guardrails/formatters/json_formatter.py +++ b/guardrails/formatters/json_formatter.py @@ -1,5 +1,5 @@ import json -from typing import Optional, Union +from typing import Dict, List, Optional, Union from guardrails.formatters.base_formatter import BaseFormatter from guardrails.llm_providers import ( @@ -99,32 +99,48 @@ def wrap_callable(self, llm_callable) -> ArbitraryCallable: if isinstance(llm_callable, HuggingFacePipelineCallable): model = llm_callable.init_kwargs["pipeline"] - return ArbitraryCallable( - lambda p: json.dumps( + + def fn( + prompt: str, + *args, + instructions: Optional[str] = None, + msg_history: Optional[List[Dict[str, str]]] = None, + **kwargs, + ) -> str: + return json.dumps( Jsonformer( model=model.model, tokenizer=model.tokenizer, json_schema=self.output_schema, - prompt=p, + prompt=prompt, )() ) - ) + + return ArbitraryCallable(fn) elif isinstance(llm_callable, HuggingFaceModelCallable): # This will not work because 'model_generate' is the .gen method. # model = self.api.init_kwargs["model_generate"] # Use the __self__ to grab the base mode for passing into JF. model = llm_callable.init_kwargs["model_generate"].__self__ tokenizer = llm_callable.init_kwargs["tokenizer"] - return ArbitraryCallable( - lambda p: json.dumps( + + def fn( + prompt: str, + *args, + instructions: Optional[str] = None, + msg_history: Optional[List[Dict[str, str]]] = None, + **kwargs, + ) -> str: + return json.dumps( Jsonformer( model=model, tokenizer=tokenizer, json_schema=self.output_schema, - prompt=p, + prompt=prompt, )() ) - ) + + return ArbitraryCallable(fn) else: raise ValueError( "JsonFormatter can only be used with HuggingFace*Callable." diff --git a/guardrails/llm_providers.py b/guardrails/llm_providers.py index f72ccd374..b48bf031f 100644 --- a/guardrails/llm_providers.py +++ b/guardrails/llm_providers.py @@ -1,5 +1,6 @@ import asyncio +import inspect from typing import ( Any, Awaitable, @@ -711,6 +712,26 @@ def _invoke_llm(self, prompt: str, pipeline: Any, *args, **kwargs) -> LLMRespons class ArbitraryCallable(PromptCallableBase): def __init__(self, llm_api: Optional[Callable] = None, *args, **kwargs): + llm_api_args = inspect.getfullargspec(llm_api) + if not llm_api_args.args: + raise ValueError( + "Custom LLM callables must accept" + " at least one positional argument for prompt!" + ) + if not llm_api_args.varkw: + raise ValueError("Custom LLM callables must accept **kwargs!") + if ( + not llm_api_args.kwonlyargs + or "instructions" not in llm_api_args.kwonlyargs + or "msg_history" not in llm_api_args.kwonlyargs + ): + warnings.warn( + "We recommend including 'instructions' and 'msg_history'" + " as keyword-only arguments for custom LLM callables." + " Doing so ensures these arguments are not uninentionally" + " passed through to other calls via **kwargs.", + UserWarning, + ) self.llm_api = llm_api super().__init__(*args, **kwargs) @@ -1190,6 +1211,26 @@ async def invoke_llm( class AsyncArbitraryCallable(AsyncPromptCallableBase): def __init__(self, llm_api: Callable, *args, **kwargs): + llm_api_args = inspect.getfullargspec(llm_api) + if not llm_api_args.args: + raise ValueError( + "Custom LLM callables must accept" + " at least one positional argument for prompt!" + ) + if not llm_api_args.varkw: + raise ValueError("Custom LLM callables must accept **kwargs!") + if ( + not llm_api_args.kwonlyargs + or "instructions" not in llm_api_args.kwonlyargs + or "msg_history" not in llm_api_args.kwonlyargs + ): + warnings.warn( + "We recommend including 'instructions' and 'msg_history'" + " as keyword-only arguments for custom LLM callables." + " Doing so ensures these arguments are not uninentionally" + " passed through to other calls via **kwargs.", + UserWarning, + ) self.llm_api = llm_api super().__init__(*args, **kwargs) @@ -1241,7 +1282,7 @@ async def invoke_llm(self, *args, **kwargs) -> LLMResponse: def get_async_llm_ask( - llm_api: Callable[[Any], Awaitable[Any]], *args, **kwargs + llm_api: Callable[..., Awaitable[Any]], *args, **kwargs ) -> AsyncPromptCallableBase: try: import litellm @@ -1268,11 +1309,12 @@ def get_async_llm_ask( except ImportError: pass - return AsyncArbitraryCallable(*args, llm_api=llm_api, **kwargs) + if llm_api is not None: + return AsyncArbitraryCallable(*args, llm_api=llm_api, **kwargs) def model_is_supported_server_side( - llm_api: Optional[Union[Callable, Callable[[Any], Awaitable[Any]]]] = None, + llm_api: Optional[Union[Callable, Callable[..., Awaitable[Any]]]] = None, *args, **kwargs, ) -> bool: @@ -1292,7 +1334,7 @@ def model_is_supported_server_side( # CONTINUOUS FIXME: Update with newly supported LLMs def get_llm_api_enum( - llm_api: Callable[[Any], Awaitable[Any]], *args, **kwargs + llm_api: Callable[..., Awaitable[Any]], *args, **kwargs ) -> Optional[LLMResource]: # TODO: Distinguish between v1 and v2 model = get_llm_ask(llm_api, *args, **kwargs) diff --git a/tests/integration_tests/test_assets/custom_llm.py b/tests/integration_tests/test_assets/custom_llm.py new file mode 100644 index 000000000..643d0278f --- /dev/null +++ b/tests/integration_tests/test_assets/custom_llm.py @@ -0,0 +1,21 @@ +from typing import Dict, List, Optional + + +def mock_llm( + prompt: Optional[str] = None, + *args, + instructions: Optional[str] = None, + msg_history: Optional[List[Dict[str, str]]] = None, + **kwargs, +) -> str: + return "" + + +async def mock_async_llm( + prompt: Optional[str] = None, + *args, + instructions: Optional[str] = None, + msg_history: Optional[List[Dict[str, str]]] = None, + **kwargs, +) -> str: + return "" diff --git a/tests/integration_tests/test_async.py b/tests/integration_tests/test_async.py index 01627ed44..9766a1322 100644 --- a/tests/integration_tests/test_async.py +++ b/tests/integration_tests/test_async.py @@ -3,6 +3,7 @@ from guardrails import AsyncGuard, Prompt from guardrails.utils import docs_utils from guardrails.classes.llm.llm_response import LLMResponse +from tests.integration_tests.test_assets.custom_llm import mock_async_llm from tests.integration_tests.test_assets.fixtures import ( # noqa fixture_llm_output, fixture_rail_spec, @@ -12,10 +13,6 @@ from .mock_llm_outputs import entity_extraction -async def mock_llm(*args, **kwargs): - return "" - - @pytest.mark.asyncio async def test_entity_extraction_with_reask(mocker): """Test that the entity extraction works with re-asking.""" @@ -45,7 +42,7 @@ async def test_entity_extraction_with_reask(mocker): preprocess_prompt_spy = mocker.spy(async_runner, "preprocess_prompt") final_output = await guard( - llm_api=mock_llm, + llm_api=mock_async_llm, prompt_params={"document": content[:6000]}, num_reasks=1, ) @@ -104,7 +101,7 @@ async def test_entity_extraction_with_noop(mocker): content = docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = AsyncGuard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_NOOP) final_output = await guard( - llm_api=mock_llm, + llm_api=mock_async_llm, prompt_params={"document": content[:6000]}, num_reasks=1, ) @@ -151,7 +148,7 @@ async def test_entity_extraction_with_noop_pydantic(mocker): prompt=entity_extraction.PYDANTIC_PROMPT, ) final_output = await guard( - llm_api=mock_llm, + llm_api=mock_async_llm, prompt_params={"document": content[:6000]}, num_reasks=1, ) @@ -192,7 +189,7 @@ async def test_entity_extraction_with_filter(mocker): content = docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = AsyncGuard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_FILTER) final_output = await guard( - llm_api=mock_llm, + llm_api=mock_async_llm, prompt_params={"document": content[:6000]}, num_reasks=1, ) @@ -232,7 +229,7 @@ async def test_entity_extraction_with_fix(mocker): content = docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = AsyncGuard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_FIX) final_output = await guard( - llm_api=mock_llm, + llm_api=mock_async_llm, prompt_params={"document": content[:6000]}, num_reasks=1, ) @@ -269,7 +266,7 @@ async def test_entity_extraction_with_refrain(mocker): content = docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf") guard = AsyncGuard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_REFRAIN) final_output = await guard( - llm_api=mock_llm, + llm_api=mock_async_llm, prompt_params={"document": content[:6000]}, num_reasks=1, ) @@ -295,7 +292,7 @@ async def test_rail_spec_output_parse(rail_spec, llm_output, validated_output): guard = AsyncGuard.from_rail_string(rail_spec) output = await guard.parse( llm_output, - llm_api=mock_llm, + llm_api=mock_async_llm, ) assert output.validated_output == validated_output @@ -334,7 +331,7 @@ async def test_string_rail_spec_output_parse( guard: AsyncGuard = AsyncGuard.from_rail_string(string_rail_spec) output = await guard.parse( string_llm_output, - llm_api=mock_llm, + llm_api=mock_async_llm, num_reasks=0, ) assert output.validated_output == validated_string_output diff --git a/tests/integration_tests/test_guard.py b/tests/integration_tests/test_guard.py index 61598c7f4..4038ae18d 100644 --- a/tests/integration_tests/test_guard.py +++ b/tests/integration_tests/test_guard.py @@ -3,7 +3,7 @@ import json import os import openai -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import pytest from pydantic import BaseModel, Field @@ -877,7 +877,7 @@ def test_in_memory_validator_log_is_not_duplicated(mocker): OneLine.run_in_separate_process = separate_proc_bak -def test_enum_datatype(): +def test_enum_datatype(mocker): class TaskStatus(enum.Enum): not_started = "not started" on_hold = "on hold" @@ -886,16 +886,29 @@ class TaskStatus(enum.Enum): class Task(BaseModel): status: TaskStatus + return_value = pydantic.LLM_OUTPUT_ENUM + + def custom_llm( + prompt: Optional[str] = None, + *args, + instructions: Optional[str] = None, + msg_history: Optional[List[Dict[str, str]]] = None, + **kwargs, + ) -> str: + nonlocal return_value + return return_value + guard = gd.Guard.from_pydantic(Task) _, dict_o, *rest = guard( - lambda *args, **kwargs: pydantic.LLM_OUTPUT_ENUM, + custom_llm, prompt="What is the status of this task?", ) assert dict_o == {"status": "not started"} + return_value = pydantic.LLM_OUTPUT_ENUM_2 guard = gd.Guard.from_pydantic(Task) result = guard( - lambda *args, **kwargs: pydantic.LLM_OUTPUT_ENUM_2, + custom_llm, prompt="What is the status of this task REALLY?", num_reasks=0, ) @@ -1441,5 +1454,198 @@ def test_guard_use_many_same_instance_on_two_guards(self, mocker): assert init_spy.call_count == 1 +# These tests are descriptive not prescriptive. +# The method signature for custom LLM APIs needs to be updated to make more sense. +# With 0.6.0 we can drop the baggage of +# the prompt and instructions and just pass in the messages. class TestCustomLLMApi: - pass + def test_with_prompt(self, mocker): + mock_llm = mocker.Mock() + + def custom_llm( + prompt: Optional[str] = None, + *args, + instructions: Optional[str] = None, + msg_history: Optional[List[Dict[str, str]]] = None, + **kwargs, + ) -> str: + mock_llm( + prompt, + *args, + instructions=instructions, + msg_history=msg_history, + **kwargs, + ) + return "Not really, no. I'm just a static function." + + guard = Guard().use( + ValidLength(1, 100), + ) + output = guard( + llm_api=custom_llm, + prompt="Can you generate a list of 10 things that are not food?", + ) + + assert output.validation_passed is True + assert output.validated_output == "Not really, no. I'm just a static function." + mock_llm.assert_called_once_with( + "Can you generate a list of 10 things that are not food?", + instructions=None, + msg_history=None, + temperature=0, + ) + + def test_with_prompt_and_instructions(self, mocker): + mock_llm = mocker.Mock() + + def custom_llm( + prompt: Optional[str] = None, + *args, + instructions: Optional[str] = None, + msg_history: Optional[List[Dict[str, str]]] = None, + **kwargs, + ) -> str: + mock_llm( + prompt, + *args, + instructions=instructions, + msg_history=msg_history, + **kwargs, + ) + return "Not really, no. I'm just a static function." + + guard = Guard().use( + ValidLength(1, 100), + ) + output = guard( + llm_api=custom_llm, + prompt="Can you generate a list of 10 things that are not food?", + instructions="You are a list generator. You can generate a list of things that are not food.", # noqa + ) + + assert output.validation_passed is True + assert output.validated_output == "Not really, no. I'm just a static function." + mock_llm.assert_called_once_with( + "Can you generate a list of 10 things that are not food?", + instructions="You are a list generator. You can generate a list of things that are not food.", # noqa + msg_history=None, + temperature=0, + ) + + def test_with_msg_history(self, mocker): + mock_llm = mocker.Mock() + + def custom_llm( + prompt: Optional[str] = None, + *args, + instructions: Optional[str] = None, + msg_history: Optional[List[Dict[str, str]]] = None, + **kwargs, + ) -> str: + mock_llm( + prompt, + *args, + instructions=instructions, + msg_history=msg_history, + **kwargs, + ) + return "Not really, no. I'm just a static function." + + guard = Guard().use( + ValidLength(1, 100), + ) + output = guard( + llm_api=custom_llm, + msg_history=[ + { + "role": "system", + "content": "You are a list generator. You can generate a list of things that are not food.", # noqa + }, + { + "role": "user", + "content": "Can you generate a list of 10 things that are not food?", # noqa + }, + ], + ) + + assert output.validation_passed is True + assert output.validated_output == "Not really, no. I'm just a static function." + mock_llm.assert_called_once_with( + None, + instructions=None, + msg_history=[ + { + "role": "system", + "content": "You are a list generator. You can generate a list of things that are not food.", # noqa + }, + { + "role": "user", + "content": "Can you generate a list of 10 things that are not food?", # noqa + }, + ], + temperature=0, + ) + + def test_with_messages(self, mocker): + mock_llm = mocker.Mock() + + def custom_llm( + prompt: Optional[str] = None, + *args, + instructions: Optional[str] = None, + msg_history: Optional[List[Dict[str, str]]] = None, + **kwargs, + ) -> str: + mock_llm( + prompt, + *args, + instructions=instructions, + msg_history=msg_history, + **kwargs, + ) + return "Not really, no. I'm just a static function." + + guard = Guard().use( + ValidLength(1, 100), + ) + output = guard( + llm_api=custom_llm, + messages=[ + { + "role": "system", + "content": "You are a list generator. You can generate a list of things that are not food.", # noqa + }, + { + "role": "user", + "content": "Can you generate a list of 10 things that are not food?", # noqa + }, + ], + ) + + assert output.validation_passed is True + assert output.validated_output == "Not really, no. I'm just a static function." + mock_llm.assert_called_once_with( + None, + instructions=None, + msg_history=[ + { + "role": "system", + "content": "You are a list generator. You can generate a list of things that are not food.", # noqa + }, + { + "role": "user", + "content": "Can you generate a list of 10 things that are not food?", # noqa + }, + ], + messages=[ + { + "role": "system", + "content": "You are a list generator. You can generate a list of things that are not food.", # noqa + }, + { + "role": "user", + "content": "Can you generate a list of 10 things that are not food?", # noqa + }, + ], + temperature=0, + ) diff --git a/tests/integration_tests/test_parsing.py b/tests/integration_tests/test_parsing.py index c185d4d67..22eb4dd23 100644 --- a/tests/integration_tests/test_parsing.py +++ b/tests/integration_tests/test_parsing.py @@ -7,6 +7,7 @@ from guardrails.classes.llm.llm_response import LLMResponse from guardrails.validator_base import OnFailAction from guardrails.classes.validation.validation_result import FailResult, ValidationResult +from tests.integration_tests.test_assets.custom_llm import mock_async_llm, mock_llm from .test_assets import pydantic, string @@ -33,11 +34,8 @@ def test_parsing_reask(mocker): output_class=pydantic.PersonalDetails, prompt=pydantic.PARSING_INITIAL_PROMPT ) - def mock_callable(prompt: str): - return - final_output = guard( - llm_api=mock_callable, + llm_api=mock_llm, prompt_params={"document": pydantic.PARSING_DOCUMENT}, num_reasks=1, ) @@ -87,11 +85,8 @@ async def test_async_parsing_reask(mocker): output_class=pydantic.PersonalDetails, prompt=pydantic.PARSING_INITIAL_PROMPT ) - async def mock_async_callable(prompt: str): - return - final_output = await guard( - llm_api=mock_async_callable, + llm_api=mock_async_llm, prompt_params={"document": pydantic.PARSING_DOCUMENT}, num_reasks=1, ) diff --git a/tests/unit_tests/classes/history/test_call.py b/tests/unit_tests/classes/history/test_call.py index f1f2eb034..8002ed65b 100644 --- a/tests/unit_tests/classes/history/test_call.py +++ b/tests/unit_tests/classes/history/test_call.py @@ -47,7 +47,7 @@ def test_empty_initialization(): def test_non_empty_initialization(): # Call input - def custom_llm(): + def custom_llm(prompt, *args, instructions=None, msg_history=None, **kwargs): return "Hello there!" llm_api = custom_llm diff --git a/tests/unit_tests/mocks/mock_custom_llm.py b/tests/unit_tests/mocks/mock_custom_llm.py index a35c8cb43..3f62044bf 100644 --- a/tests/unit_tests/mocks/mock_custom_llm.py +++ b/tests/unit_tests/mocks/mock_custom_llm.py @@ -6,16 +6,22 @@ def __init__(self, times_called=0, response="Hello world!"): self.times_called = times_called self.response = response - def fail_retryable(self, prompt: str, *args, **kwargs) -> str: + def fail_retryable( + self, prompt: str, *args, instructions=None, msg_history=None, **kwargs + ) -> str: if self.times_called == 0: self.times_called = self.times_called + 1 raise OpenAIServiceUnavailableError("ServiceUnavailableError") return self.response - def fail_non_retryable(self, prompt: str, *args, **kwargs) -> str: + def fail_non_retryable( + self, prompt: str, *args, instructions=None, msg_history=None, **kwargs + ) -> str: raise Exception("Non-Retryable Error!") - def succeed(self, prompt: str, *args, **kwargs) -> str: + def succeed( + self, prompt: str, *args, instructions=None, msg_history=None, **kwargs + ) -> str: return self.response @@ -24,14 +30,20 @@ def __init__(self, times_called=0, response="Hello world!"): self.times_called = times_called self.response = response - async def fail_retryable(self, prompt: str, *args, **kwargs) -> str: + async def fail_retryable( + self, prompt: str, *args, instructions=None, msg_history=None, **kwargs + ) -> str: if self.times_called == 0: self.times_called = self.times_called + 1 raise OpenAIServiceUnavailableError("ServiceUnavailableError") return self.response - async def fail_non_retryable(self, prompt: str, *args, **kwargs) -> str: + async def fail_non_retryable( + self, prompt: str, *args, instructions=None, msg_history=None, **kwargs + ) -> str: raise Exception("Non-Retryable Error!") - async def succeed(self, prompt: str, *args, **kwargs) -> str: + async def succeed( + self, prompt: str, *args, instructions=None, msg_history=None, **kwargs + ) -> str: return self.response diff --git a/tests/unit_tests/test_async_guard.py b/tests/unit_tests/test_async_guard.py index d8b331b8d..0fca512da 100644 --- a/tests/unit_tests/test_async_guard.py +++ b/tests/unit_tests/test_async_guard.py @@ -6,6 +6,7 @@ from guardrails.utils import args, kwargs, on_fail from guardrails.utils.validator_utils import verify_metadata_requirements from guardrails.types import OnFailAction +from tests.integration_tests.test_assets.custom_llm import mock_async_llm from tests.integration_tests.test_assets.validators import ( EndsWith, LowerCase, @@ -96,17 +97,14 @@ async def test_required_metadata(spec, metadata, error_message): not_missing_keys = verify_metadata_requirements(metadata, guard._validators) assert not_missing_keys == [] - async def mock_llm(*args, **kwargs): - return "" - # test async guard with pytest.raises(ValueError) as excinfo: await guard.parse("{}") - await guard.parse("{}", llm_api=mock_llm, num_reasks=0) + await guard.parse("{}", llm_api=mock_async_llm, num_reasks=0) assert str(excinfo.value) == error_message response = await guard.parse( - "{}", metadata=metadata, llm_api=mock_llm, num_reasks=0 + "{}", metadata=metadata, llm_api=mock_async_llm, num_reasks=0 ) assert response.error is None diff --git a/tests/unit_tests/test_llm_providers.py b/tests/unit_tests/test_llm_providers.py index 6bd2e022e..9d3d453ee 100644 --- a/tests/unit_tests/test_llm_providers.py +++ b/tests/unit_tests/test_llm_providers.py @@ -13,6 +13,7 @@ LLMResponse, PromptCallableException, chat_prompt, + get_async_llm_ask, get_llm_ask, ) from guardrails.utils.safe_get import safe_get_with_brackets @@ -401,7 +402,9 @@ def create() -> MockResponse: class ReturnTempCallable(Callable): - def __call__(*args, **kwargs) -> Any: + def __call__( + self, prompt: str, *args, instructions=None, msg_history=None, **kwargs + ) -> Any: return "" @@ -591,6 +594,112 @@ def test_get_llm_ask_litellm(): assert isinstance(prompt_callable, LiteLLMCallable) +def test_get_llm_ask_custom_llm(): + from guardrails.llm_providers import ArbitraryCallable + + def my_llm(prompt: str, *, instructions=None, msg_history=None, **kwargs) -> str: + return f"Hello {prompt}!" + + prompt_callable = get_llm_ask(my_llm) + + assert isinstance(prompt_callable, ArbitraryCallable) + + +def test_get_llm_ask_custom_llm_warning(): + from guardrails.llm_providers import ArbitraryCallable + + def my_llm(prompt: str, **kwargs) -> str: + return f"Hello {prompt}!" + + with pytest.warns( + UserWarning, + match=( + "We recommend including 'instructions' and 'msg_history'" + " as keyword-only arguments for custom LLM callables." + " Doing so ensures these arguments are not uninentionally" + " passed through to other calls via \\*\\*kwargs." + ), + ): + prompt_callable = get_llm_ask(my_llm) + + assert isinstance(prompt_callable, ArbitraryCallable) + + +def test_get_llm_ask_custom_llm_must_accept_prompt(): + def my_llm() -> str: + return "Hello!" + + with pytest.raises( + ValueError, + match="Custom LLM callables must accept at least one positional argument for prompt!", # noqa + ): + get_llm_ask(my_llm) + + +def test_get_llm_ask_custom_llm_must_accept_kwargs(): + def my_llm(prompt: str) -> str: + return f"Hello {prompt}!" + + with pytest.raises( + ValueError, match="Custom LLM callables must accept \\*\\*kwargs!" + ): + get_llm_ask(my_llm) + + +def test_get_async_llm_ask_custom_llm(): + from guardrails.llm_providers import AsyncArbitraryCallable + + async def my_llm( + prompt: str, *, instructions=None, msg_history=None, **kwargs + ) -> str: + return f"Hello {prompt}!" + + prompt_callable = get_async_llm_ask(my_llm) + + assert isinstance(prompt_callable, AsyncArbitraryCallable) + + +def test_get_async_llm_ask_custom_llm_warning(): + from guardrails.llm_providers import AsyncArbitraryCallable + + async def my_llm(prompt: str, **kwargs) -> str: + return f"Hello {prompt}!" + + with pytest.warns( + UserWarning, + match=( + "We recommend including 'instructions' and 'msg_history'" + " as keyword-only arguments for custom LLM callables." + " Doing so ensures these arguments are not uninentionally" + " passed through to other calls via \\*\\*kwargs." + ), + ): + prompt_callable = get_async_llm_ask(my_llm) + + assert isinstance(prompt_callable, AsyncArbitraryCallable) + + +def test_get_async_llm_ask_custom_llm_must_accept_prompt(): + async def my_llm() -> str: + return "Hello!" + + with pytest.raises( + ValueError, + match="Custom LLM callables must accept at least one positional argument for prompt!", # noqa + ): + get_async_llm_ask(my_llm) + + +def test_get_async_llm_ask_custom_llm_must_accept_kwargs(): + def my_llm(prompt: str) -> str: + return f"Hello {prompt}!" + + with pytest.raises( + ValueError, match="Custom LLM callables must accept \\*\\*kwargs!" + ): + get_async_llm_ask(my_llm) + + def test_chat_prompt(): # raises when neither msg_history or prompt are provided with pytest.raises(PromptCallableException): diff --git a/tests/unit_tests/test_validator_base.py b/tests/unit_tests/test_validator_base.py index 74feeb325..6a8459822 100644 --- a/tests/unit_tests/test_validator_base.py +++ b/tests/unit_tests/test_validator_base.py @@ -421,7 +421,7 @@ class Pet(BaseModel): def test_input_validation_fix(mocker): - def mock_llm_api(*args, **kwargs): + def mock_llm_api(prompt, *args, instructions=None, msg_history=None, **kwargs): return json.dumps({"name": "Fluffy"}) # fix returns an amended value for prompt/instructions validation, @@ -514,7 +514,9 @@ def mock_llm_api(*args, **kwargs): @pytest.mark.asyncio async def test_async_input_validation_fix(mocker): - async def mock_llm_api(*args, **kwargs): + async def mock_llm_api( + prompt, *args, instructions=None, msg_history=None, **kwargs + ) -> str: return json.dumps({"name": "Fluffy"}) # fix returns an amended value for prompt/instructions validation, @@ -660,7 +662,7 @@ def test_input_validation_fail( guard = Guard.from_pydantic(output_class=Pet) guard.use(TwoWords(on_fail=on_fail), on="prompt") - def custom_llm(*args, **kwargs): + def custom_llm(prompt, *args, instructions=None, msg_history=None, **kwargs): raise Exception( "LLM was called when it should not have been!" "Input Validation did not raise as expected!" @@ -810,7 +812,9 @@ async def test_input_validation_fail_async( unstructured_prompt_error, unstructured_instructions_error, ): - async def custom_llm(*args, **kwargs): + async def custom_llm( + prompt, *args, instructions=None, msg_history=None, **kwargs + ) -> str: raise Exception( "LLM was called when it should not have been!" "Input Validation did not raise as expected!" From 78aef23e148e13fe31a01f5dd6651522db576830 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Wed, 18 Sep 2024 16:14:07 -0500 Subject: [PATCH 3/5] custom llm wrappers --- docs/how_to_guides/using_llms.md | 46 ++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/docs/how_to_guides/using_llms.md b/docs/how_to_guides/using_llms.md index 434e48975..18b0eb773 100644 --- a/docs/how_to_guides/using_llms.md +++ b/docs/how_to_guides/using_llms.md @@ -289,3 +289,49 @@ for chunk in stream_chunk_generator ## Other LLMs See LiteLLM’s documentation [here](https://docs.litellm.ai/docs/providers) for details on many other llms. + +## Custom LLM Wrappers +In case you're using an LLM that isn't natively supported by Guardrails and you don't want to use LiteLLM, you can build a custom LLM API wrapper. In order to use a custom LLM, create a function that takes accepts a prompt as a string and any other arguments that you want to pass to the LLM API as keyword args. The function should return the output of the LLM API as a string. + +```python +from guardrails import Guard +from guardrails.hub import ProfanityFree + +# Create a Guard class +guard = Guard().use(ProfanityFree()) + +# Function that takes the prompt as a string and returns the LLM output as string +def my_llm_api( + prompt: Optional[str] = None, + *, + instruction: Optional[str] = None, + msg_history: Optional[list[dict]] = None, + **kwargs +) -> str: + """Custom LLM API wrapper. + + At least one of prompt, instruction or msg_history should be provided. + + Args: + prompt (str): The prompt to be passed to the LLM API + instruction (str): The instruction to be passed to the LLM API + msg_history (list[dict]): The message history to be passed to the LLM API + **kwargs: Any additional arguments to be passed to the LLM API + + Returns: + str: The output of the LLM API + """ + + # Call your LLM API here + # What you pass to the llm will depend on what arguments it accepts. + llm_output = some_llm(prompt, instruction, msg_history, **kwargs) + + return llm_output + +# Wrap your LLM API call +validated_response = guard( + my_llm_api, + prompt="Can you generate a list of 10 things that are not food?", + **kwargs, +) +``` \ No newline at end of file From fa2272ad9ab6739b96d7ba29dffbc207a0f4290c Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Wed, 18 Sep 2024 16:14:33 -0500 Subject: [PATCH 4/5] instruction -> instructions --- docs/how_to_guides/using_llms.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/how_to_guides/using_llms.md b/docs/how_to_guides/using_llms.md index 18b0eb773..422995ae5 100644 --- a/docs/how_to_guides/using_llms.md +++ b/docs/how_to_guides/using_llms.md @@ -304,7 +304,7 @@ guard = Guard().use(ProfanityFree()) def my_llm_api( prompt: Optional[str] = None, *, - instruction: Optional[str] = None, + instructions: Optional[str] = None, msg_history: Optional[list[dict]] = None, **kwargs ) -> str: @@ -324,7 +324,7 @@ def my_llm_api( # Call your LLM API here # What you pass to the llm will depend on what arguments it accepts. - llm_output = some_llm(prompt, instruction, msg_history, **kwargs) + llm_output = some_llm(prompt, instructions, msg_history, **kwargs) return llm_output From 513b3e3a45f7b225cb69bf460795531bd4342c49 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Wed, 18 Sep 2024 16:17:49 -0500 Subject: [PATCH 5/5] clarify prompt parameter --- docs/how_to_guides/using_llms.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/how_to_guides/using_llms.md b/docs/how_to_guides/using_llms.md index 422995ae5..6ec4641da 100644 --- a/docs/how_to_guides/using_llms.md +++ b/docs/how_to_guides/using_llms.md @@ -291,7 +291,7 @@ for chunk in stream_chunk_generator See LiteLLM’s documentation [here](https://docs.litellm.ai/docs/providers) for details on many other llms. ## Custom LLM Wrappers -In case you're using an LLM that isn't natively supported by Guardrails and you don't want to use LiteLLM, you can build a custom LLM API wrapper. In order to use a custom LLM, create a function that takes accepts a prompt as a string and any other arguments that you want to pass to the LLM API as keyword args. The function should return the output of the LLM API as a string. +In case you're using an LLM that isn't natively supported by Guardrails and you don't want to use LiteLLM, you can build a custom LLM API wrapper. In order to use a custom LLM, create a function that accepts a positional argument for the prompt as a string and any other arguments that you want to pass to the LLM API as keyword args. The function should return the output of the LLM API as a string. ```python from guardrails import Guard