diff --git a/docs/concepts/caching.md b/docs/concepts/caching.md
index 25b519cc0..af34c7f1f 100644
--- a/docs/concepts/caching.md
+++ b/docs/concepts/caching.md
@@ -33,12 +33,12 @@ def extract(data) -> UserDetail:
 start = time.perf_counter()  # (1)
 model = extract("Extract jason is 25 years old")
 print(f"Time taken: {time.perf_counter() - start}")
-#> Time taken: 0.8392175831831992
+#> Time taken: 0.7583793329977198
 
 start = time.perf_counter()
 model = extract("Extract jason is 25 years old")  # (2)
 print(f"Time taken: {time.perf_counter() - start}")
-#> Time taken: 8.33999365568161e-07
+#> Time taken: 4.3330073822289705e-06
 ```
 
 1. Using `time.perf_counter()` to measure the time taken to run the function is better than using `time.time()` because it's more accurate and less susceptible to system clock changes.
diff --git a/docs/concepts/lists.md b/docs/concepts/lists.md
index ad98ce7f8..658a5100d 100644
--- a/docs/concepts/lists.md
+++ b/docs/concepts/lists.md
@@ -157,8 +157,8 @@ async def print_iterable_results():
     )
     async for m in model:
         print(m)
-        #> name='John Smith' age=30
-        #> name='Mary Jane' age=28
+        #> name='John Doe' age=30
+        #> name='Jane Doe' age=28
 
 
 import asyncio
diff --git a/docs/concepts/maybe.md b/docs/concepts/maybe.md
index 183340b1d..f25f48f48 100644
--- a/docs/concepts/maybe.md
+++ b/docs/concepts/maybe.md
@@ -89,7 +89,7 @@ print(user2.model_dump_json(indent=2))
 {
   "result": null,
   "error": false,
-  "message": null
+  "message": "Unknown user"
 }
 """
 ```
diff --git a/docs/concepts/models.md b/docs/concepts/models.md
index 7a6d1e461..293eabe2e 100644
--- a/docs/concepts/models.md
+++ b/docs/concepts/models.md
@@ -150,7 +150,7 @@ class SearchQuery(BaseModel):
 
     def execute(self):
         print(f"Searching for {self.query} of type {self.query_type}")
-        #> Searching for cat pictures of type image
+        #> Searching for cat of type image
         return "Results for cat"
 
 
diff --git a/docs/concepts/parallel.md b/docs/concepts/parallel.md
index 2152533e0..4ff493d3a 100644
--- a/docs/concepts/parallel.md
+++ b/docs/concepts/parallel.md
@@ -44,9 +44,9 @@ function_calls = client.chat.completions.create(
 
 for fc in function_calls:
     print(fc)
-    #> location='Toronto' units='imperial'
+    #> location='Toronto' units='metric'
     #> location='Dallas' units='imperial'
-    #> query='who won the super bowl'
+    #> query='super bowl winner'
 ```
 
 1. Set the mode to `PARALLEL_TOOLS` to enable parallel function calling.
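For context on the parallel.md output above: those printed lines come from a setup along these lines — a minimal sketch, assuming the `Weather` and `GoogleSearch` models that the doc defines earlier (their fields are inferred here from the printed output):

```python
from typing import Iterable, Literal
from pydantic import BaseModel
from openai import OpenAI
import instructor


class Weather(BaseModel):
    location: str
    units: Literal["imperial", "metric"]


class GoogleSearch(BaseModel):
    query: str


# PARALLEL_TOOLS lets a single completion emit several tool calls at once
client = instructor.patch(OpenAI(), mode=instructor.Mode.PARALLEL_TOOLS)

function_calls = client.chat.completions.create(
    model="gpt-4-turbo-preview",
    messages=[
        {"role": "system", "content": "You must always use tools"},
        {
            "role": "user",
            "content": "What is the weather in toronto and dallas and who won the super bowl?",
        },
    ],
    response_model=Iterable[Weather | GoogleSearch],
)

for fc in function_calls:
    print(fc)  # Weather(...) or GoogleSearch(...) instances, in call order
```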
diff --git a/docs/concepts/partial.md b/docs/concepts/partial.md
index 7a543e750..c6dd4bbd3 100644
--- a/docs/concepts/partial.md
+++ b/docs/concepts/partial.md
@@ -119,10 +119,10 @@ print(extraction.model_dump_json(indent=2))
       "twitter": "@CodeMaster2023"
     }
   ],
-  "date": "2024-03-15",
+  "date": "March 15th, 2024",
   "location": "Grand Tech Arena located at 4521 Innovation Drive",
   "budget": 50000,
-  "deadline": "2024-02-20"
+  "deadline": "February 20th"
 }
 """
 ```
diff --git a/docs/concepts/raw_response.md b/docs/concepts/raw_response.md
index 7248282c0..b6c9d1c4e 100644
--- a/docs/concepts/raw_response.md
+++ b/docs/concepts/raw_response.md
@@ -25,7 +25,7 @@ user: UserExtract = client.chat.completions.create(
 print(user._raw_response)
 """
 ChatCompletion(
-    id='chatcmpl-8u9bsrmmf5YjZyfCtQymoZV8LK1qg',
+    id='chatcmpl-8zpltT9vXJdO5OE3AfDsOhAUr911A',
     choices=[
         Choice(
             finish_reason='stop',
@@ -37,7 +37,7 @@ ChatCompletion(
             function_call=None,
             tool_calls=[
                 ChatCompletionMessageToolCall(
-                    id='call_O5rpXf47YgXiYrYWv45yZUeM',
+                    id='call_vXI3foz7jqlzFILU9pwuYJZB',
                     function=Function(
                         arguments='{"name":"Jason","age":25}', name='UserExtract'
                     ),
@@ -47,10 +47,10 @@ ChatCompletion(
             ),
         )
     ],
-    created=1708394000,
+    created=1709747709,
     model='gpt-3.5-turbo-0125',
    object='chat.completion',
-    system_fingerprint='fp_69829325d0',
+    system_fingerprint='fp_2b778c6b35',
     usage=CompletionUsage(completion_tokens=9, prompt_tokens=82, total_tokens=91),
 )
 """
diff --git a/docs/concepts/reask_validation.md b/docs/concepts/reask_validation.md
index dc21efe49..891d9e2a3 100644
--- a/docs/concepts/reask_validation.md
+++ b/docs/concepts/reask_validation.md
@@ -91,7 +91,7 @@ except ValidationError as e:
     """
     1 validation error for QuestionAnswer
     answer
-      Assertion failed, The statement promotes objectionable behavior by encouraging evil and theft. [type=assertion_error, input_value='The meaning of life is to be evil and steal', input_type=str]
+      Assertion failed, The statement promotes objectionable behavior by encouraging evil and stealing, which goes against the rule of not saying objectionable things. [type=assertion_error, input_value='The meaning of life is to be evil and steal', input_type=str]
         For further information visit https://errors.pydantic.dev/2.6/v/assertion_error
     """
 ```
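The assertion error shown in the reask_validation.md hunk comes from an LLM-backed validator. Roughly the following setup, sketched from the docs this hunk belongs to (the rule string is the docs' own example):

```python
from typing_extensions import Annotated
from pydantic import BaseModel, BeforeValidator
from instructor import llm_validator


class QuestionAnswer(BaseModel):
    question: str
    # The validator asks an LLM to assert the rule; a failed assertion
    # surfaces as the pydantic assertion_error shown above.
    answer: Annotated[
        str, BeforeValidator(llm_validator("don't say objectionable things"))
    ]
```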
diff --git a/docs/hub/batch_classification_langsmith.md b/docs/hub/batch_classification_langsmith.md
index f26e017c1..225e643ab 100644
--- a/docs/hub/batch_classification_langsmith.md
+++ b/docs/hub/batch_classification_langsmith.md
@@ -47,6 +47,7 @@ client = instructor.patch(client, mode=instructor.Mode.TOOLS)
 # Rate limit the number of requests
 sem = asyncio.Semaphore(5)
 
+
 # Use an Enum to define the types of questions
 class QuestionType(Enum):
     CONTACT = "CONTACT"
diff --git a/docs/hub/groq.md b/docs/hub/groq.md
index 7e2c31d9f..c4217047e 100644
--- a/docs/hub/groq.md
+++ b/docs/hub/groq.md
@@ -51,9 +51,8 @@ client = Groq(
 )
 
 # By default, the patch function will patch the ChatCompletion.create and ChatCompletion.acreate methods to support the response_model parameter
-client = instructor.patch(
-    client, mode=instructor.Mode.MD_JSON
-)
+client = instructor.patch(client, mode=instructor.Mode.MD_JSON)
+
 # Now, we can use the response_model parameter using only a base model
 # rather than having to use the OpenAISchema class
 
diff --git a/docs/hub/mistral.md b/docs/hub/mistral.md
index 74b7ad97d..995c2b916 100644
--- a/docs/hub/mistral.md
+++ b/docs/hub/mistral.md
@@ -50,12 +50,10 @@ from mistralai.client import MistralClient
 
 # enables `response_model` in chat call
 client = MistralClient()
-patched_chat = instructor.patch(
-    create=client.chat,
-    mode=instructor.Mode.MISTRAL_TOOLS
-)
+patched_chat = instructor.patch(create=client.chat, mode=instructor.Mode.MISTRAL_TOOLS)
 
 if __name__ == "__main__":
+
     class UserDetails(BaseModel):
         name: str
         age: int
diff --git a/docs/hub/pandas_df.md b/docs/hub/pandas_df.md
index 3148f4967..aa826e950 100644
--- a/docs/hub/pandas_df.md
+++ b/docs/hub/pandas_df.md
@@ -108,13 +108,13 @@ if __name__ == "__main__":
     assert isinstance(df, pd.DataFrame)
     print(df)
     """
-                         Party Years Served
+                          Party Years Served
     President
-    Joe Biden      Democratic         2021-
-    Donald Trump   Republican     2017-2021
-    Barack Obama   Democratic     2009-2017
-    George W. Bush Republican     2001-2009
-    Bill Clinton   Democratic     1993-2001
+    Joe Biden       Democratic  2021-Current
+    Donald Trump    Republican     2017-2021
+    Barack Obama    Democratic     2009-2017
+    George W. Bush  Republican     2001-2009
+    Bill Clinton    Democratic     1993-2001
     """
 
     table = extract_table(
diff --git a/docs/hub/youtube_clips.md b/docs/hub/youtube_clips.md
index 7470524ae..2bcbaa121 100644
--- a/docs/hub/youtube_clips.md
+++ b/docs/hub/youtube_clips.md
@@ -15,7 +15,7 @@ instructor hub pull --slug youtube-clips --py > youtube_clips.py
 ```python
 from youtube_transcript_api import YouTubeTranscriptApi
 from pydantic import BaseModel, Field
-from typing import List, Dict, Generator, Iterable
+from typing import List, Generator, Iterable
 import instructor
 import openai
 
@@ -24,12 +24,12 @@ client = instructor.patch(openai.OpenAI())
 
 def extract_video_id(url: str) -> str | None:
     import re
+
     match = re.search(r"v=([a-zA-Z0-9_-]+)", url)
     if match:
         return match.group(1)
 
 
-
 class TranscriptSegment(BaseModel):
     source_id: int
     start: float
@@ -51,9 +51,7 @@ def get_transcript_with_timing(
 
 
 class YoutubeClip(BaseModel):
-    title: str = Field(
-        description="Specific and informative title for the clip."
-    )
+    title: str = Field(description="Specific and informative title for the clip.")
     description: str = Field(
         description="A detailed description of the clip, including notable quotes or phrases."
     )
@@ -98,7 +96,6 @@ if __name__ == "__main__":
     console = Console()
     url = Prompt.ask("Enter a YouTube URL")
 
-
     with console.status("[bold green]Processing YouTube URL...") as status:
         video_id = extract_video_id(url)
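The `extract_video_id` helper in the youtube_clips.md hunk is plain regex; a quick self-contained check of its behavior (the URLs are made-up examples):

```python
import re


def extract_video_id(url: str) -> str | None:
    # Capture the value of the `v=` query parameter, as in the hub example
    match = re.search(r"v=([a-zA-Z0-9_-]+)", url)
    if match:
        return match.group(1)


print(extract_video_id("https://www.youtube.com/watch?v=dQw4w9WgXcQ"))
#> dQw4w9WgXcQ
print(extract_video_id("https://example.com/no-video-here"))
#> None
```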
diff --git a/docs/index.md b/docs/index.md
index 9dd0e0f55..54c92dffb 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -115,7 +115,7 @@ print(response.model_dump_json(indent=2))
 print(user._raw_response.model_dump_json(indent=2))
 """
 {
-  "id": "chatcmpl-8u9e2TV3ehCgLsRxNLLeAbzpEmBuZ",
+  "id": "chatcmpl-8zplvRbNM8iKSVa3Ld9NmVICeXZZ9",
   "choices": [
     {
       "finish_reason": "stop",
@@ -127,7 +127,7 @@ print(response.model_dump_json(indent=2))
       "function_call": null,
       "tool_calls": [
         {
-          "id": "call_3ZuQhfteTLEy7CUokjwnLBHr",
+          "id": "call_V5FRMSXrHFFTTqTjpwA76h7t",
           "function": {
             "arguments": "{\"name\":\"Jason\",\"age\":25}",
             "name": "UserDetail"
@@ -138,10 +138,10 @@ print(response.model_dump_json(indent=2))
       }
     }
   ],
-  "created": 1708394134,
+  "created": 1709747711,
   "model": "gpt-3.5-turbo-0125",
   "object": "chat.completion",
-  "system_fingerprint": "fp_69829325d0",
+  "system_fingerprint": "fp_2b778c6b35",
   "usage": {
     "completion_tokens": 9,
     "prompt_tokens": 81,
diff --git a/instructor/dsl/iterable.py b/instructor/dsl/iterable.py
index 3027bb2e4..19c14da15 100644
--- a/instructor/dsl/iterable.py
+++ b/instructor/dsl/iterable.py
@@ -3,6 +3,7 @@
 from pydantic import BaseModel, Field, create_model
 
 from instructor.function_calls import OpenAISchema, Mode
+from instructor.utils import extract_json_from_stream, extract_json_from_stream_async
 
 
 class IterableBase:
@@ -13,6 +14,10 @@ def from_streaming_response(
         cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
     ) -> Generator[BaseModel, None, None]:  # noqa: ARG003
         json_chunks = cls.extract_json(completion, mode)
+
+        if mode == Mode.MD_JSON:
+            json_chunks = extract_json_from_stream(json_chunks)
+
         yield from cls.tasks_from_chunks(json_chunks, **kwargs)
 
     @classmethod
@@ -20,6 +25,10 @@ async def from_streaming_response_async(
         cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
     ) -> AsyncGenerator[BaseModel, None]:
         json_chunks = cls.extract_json_async(completion, mode)
+
+        if mode == Mode.MD_JSON:
+            json_chunks = extract_json_from_stream_async(json_chunks)
+
         return cls.tasks_from_chunks_async(json_chunks, **kwargs)
 
     @classmethod
@@ -110,13 +119,14 @@ async def extract_json_async(
 
     @staticmethod
     def get_object(s: str, stack: int) -> Tuple[Optional[str], str]:
+        start_index = s.find("{")
        for i, c in enumerate(s):
            if c == "{":
                stack += 1
            if c == "}":
                stack -= 1
                if stack == 0:
-                    return s[: i + 1], s[i + 2 :]
+                    return s[start_index : i + 1], s[i + 2 :]
        return None, s
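The one-line `start_index` change to `get_object` is easy to miss but meaningful: the returned slice previously started at index 0, so any prose the model emitted before the first `{` was glued onto the extracted object. A standalone sketch of the patched logic, with a hypothetical buffer:

```python
from typing import Optional, Tuple


def get_object(s: str, stack: int) -> Tuple[Optional[str], str]:
    # Mirrors the patched staticmethod: slice from the first brace,
    # not from the start of the buffer.
    start_index = s.find("{")
    for i, c in enumerate(s):
        if c == "{":
            stack += 1
        if c == "}":
            stack -= 1
            if stack == 0:
                return s[start_index : i + 1], s[i + 2 :]
    return None, s


buffer = 'Sure thing! {"name": "Jason"} {"name": "Sarah"}'
obj, rest = get_object(buffer, 0)
print(obj)   #> {"name": "Jason"}
print(rest)  #> {"name": "Sarah"}
```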
diff --git a/instructor/dsl/partial.py b/instructor/dsl/partial.py
index 9b4fb1ba7..81ae5d6b6 100644
--- a/instructor/dsl/partial.py
+++ b/instructor/dsl/partial.py
@@ -24,6 +24,7 @@
 from instructor.function_calls import Mode
 from instructor.dsl.partialjson import JSONParser
+from instructor.utils import extract_json_from_stream, extract_json_from_stream_async
 
 parser = JSONParser()
 T_Model = TypeVar("T_Model", bound=BaseModel)
@@ -35,6 +36,10 @@ def from_streaming_response(
         cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
     ) -> Generator[T_Model, None, None]:
         json_chunks = cls.extract_json(completion, mode)
+
+        if mode == Mode.MD_JSON:
+            json_chunks = extract_json_from_stream(json_chunks)
+
         yield from cls.model_from_chunks(json_chunks, **kwargs)
 
     @classmethod
@@ -42,6 +47,10 @@ async def from_streaming_response_async(
         cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
     ) -> AsyncGenerator[T_Model, None]:
         json_chunks = cls.extract_json_async(completion, mode)
+
+        if mode == Mode.MD_JSON:
+            json_chunks = extract_json_from_stream_async(json_chunks)
+
         return cls.model_from_chunks_async(json_chunks, **kwargs)
 
     @classmethod
diff --git a/instructor/function_calls.py b/instructor/function_calls.py
index eeebf0125..bc9cbd795 100644
--- a/instructor/function_calls.py
+++ b/instructor/function_calls.py
@@ -7,6 +7,7 @@
 import warnings
 import logging
 from openai.types.chat import ChatCompletion
+from instructor.utils import extract_json_from_codeblock
 
 T = TypeVar("T")
 
@@ -135,6 +136,9 @@ def from_response(
                 strict=strict,
             )
         elif mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}:
+            if mode == Mode.MD_JSON:
+                message.content = extract_json_from_codeblock(message.content or "")
+
             model_response = cls.model_validate_json(
                 message.content,  # type: ignore
                 context=validation_context,
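With this change, `MD_JSON` responses are trimmed to the fenced JSON before `model_validate_json` runs. `extract_json_from_codeblock` (added in `instructor/utils.py` below) simply slices from the first `{` to the last `}`; a quick sketch of that behavior on a hypothetical reply:

```python
from instructor.utils import extract_json_from_codeblock

# A typical MD_JSON-style reply: prose around a fenced JSON object
content = 'Sure! Here you go:\n\n```json\n{"name": "Jason", "age": 25}\n```\nAnything else?'
print(extract_json_from_codeblock(content))
#> {"name": "Jason", "age": 25}
```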
- """ - ret: ChatCompletionMessageParam = { - "role": message.role, - "content": message.content or "", - } - if hasattr(message, "tool_calls") and message.tool_calls is not None: - ret["tool_calls"] = message.model_dump()["tool_calls"] - if ( - hasattr(message, "function_call") - and message.function_call is not None - and ret["content"] - ): - ret["content"] += json.dumps(message.model_dump()["function_call"]) - return ret - - def handle_response_model( response_model: T, mode: Mode = Mode.TOOLS, **kwargs ) -> Union[Type[OpenAISchema], dict]: @@ -153,12 +122,12 @@ def handle_response_model( f""" As a genius expert, your task is to understand the content and provide the parsed objects in json that match the following json_schema:\n - {response_model.model_json_schema()['properties']} + + {response_model.model_json_schema()} + + Make sure to return an instance of the JSON, not the schema itself """ ) - # Check for nested models - if "$defs" in response_model.model_json_schema(): - message += f"\nHere are some more definitions to adhere too:\n{response_model.model_json_schema()['$defs']}" if mode == Mode.JSON: new_kwargs["response_format"] = {"type": "json_object"} @@ -172,11 +141,10 @@ def handle_response_model( elif mode == Mode.MD_JSON: new_kwargs["messages"].append( { - "role": "assistant", - "content": "Here is the perfectly correctly formatted JSON\n```json", + "role": "user", + "content": "Return the correct JSON response within a ```json codeblock. not the JSON_SCHEMA", }, ) - new_kwargs["stop"] = "```" # check that the first message is a system message # if it is not, add a system message to the beginning if new_kwargs["messages"][0]["role"] != "system": @@ -402,8 +370,8 @@ async def retry_async( if mode == Mode.MD_JSON: kwargs["messages"].append( { - "role": "assistant", - "content": "```json", + "role": "user", + "content": "Return the correct JSON response within a ```json codeblock. 
diff --git a/instructor/utils.py b/instructor/utils.py
new file mode 100644
index 000000000..2ed3b3322
--- /dev/null
+++ b/instructor/utils.py
@@ -0,0 +1,83 @@
+import json
+from typing import Generator, Iterable, AsyncGenerator
+
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionMessage,
+    ChatCompletionMessageParam,
+)
+
+
+def extract_json_from_codeblock(content: str) -> str:
+    first_paren = content.find("{")
+    last_paren = content.rfind("}")
+    return content[first_paren : last_paren + 1]
+
+
+def extract_json_from_stream(chunks: Iterable[str]) -> Generator[str, None, None]:
+    capturing = False
+    brace_count = 0
+    for chunk in chunks:
+        for char in chunk:
+            if char == "{":
+                capturing = True
+                brace_count += 1
+                yield char
+            elif char == "}" and capturing:
+                brace_count -= 1
+                yield char
+                if brace_count == 0:
+                    capturing = False
+                    break  # Cease yielding upon closing the current JSON object
+            elif capturing:
+                yield char
+
+
+async def extract_json_from_stream_async(
+    chunks: AsyncGenerator[str, None],
+) -> AsyncGenerator[str, None]:
+    capturing = False
+    brace_count = 0
+    async for chunk in chunks:
+        for char in chunk:
+            if char == "{":
+                capturing = True
+                brace_count += 1
+                yield char
+            elif char == "}" and capturing:
+                brace_count -= 1
+                yield char
+                if brace_count == 0:
+                    capturing = False
+                    break  # Cease yielding upon closing the current JSON object
+            elif capturing:
+                yield char
+
+
+def update_total_usage(response, total_usage):
+    if isinstance(response, ChatCompletion) and response.usage is not None:
+        total_usage.completion_tokens += response.usage.completion_tokens or 0
+        total_usage.prompt_tokens += response.usage.prompt_tokens or 0
+        total_usage.total_tokens += response.usage.total_tokens or 0
+        response.usage = total_usage  # Replace each response usage with the total usage
+    return response
+
+
+def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
+    """Dumps a message to a dict, to be returned to the OpenAI API.
+    Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
+    if it isn't used.
+    """
+    ret: ChatCompletionMessageParam = {
+        "role": message.role,
+        "content": message.content or "",
+    }
+    if hasattr(message, "tool_calls") and message.tool_calls is not None:
+        ret["tool_calls"] = message.model_dump()["tool_calls"]
+    if (
+        hasattr(message, "function_call")
+        and message.function_call is not None
+        and ret["content"]
+    ):
+        ret["content"] += json.dumps(message.model_dump()["function_call"])
+    return ret
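The two stream helpers do the same brace-counted scan, sync and async: nothing is yielded until the first `{`, and yielding stops once the braces balance, so fences and surrounding chatter never reach the JSON parser. A small sketch with hand-made chunks:

```python
from instructor.utils import extract_json_from_stream

chunks = [
    'Here you go!\n\n```json\n{"na',
    'me": "Jason", "age"',
    ': 25}\n```\nLet me know if you need anything else.',
]
print("".join(extract_json_from_stream(chunks)))
#> {"name": "Jason", "age": 25}
```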
+ """ + ret: ChatCompletionMessageParam = { + "role": message.role, + "content": message.content or "", + } + if hasattr(message, "tool_calls") and message.tool_calls is not None: + ret["tool_calls"] = message.model_dump()["tool_calls"] + if ( + hasattr(message, "function_call") + and message.function_call is not None + and ret["content"] + ): + ret["content"] += json.dumps(message.model_dump()["function_call"]) + return ret diff --git a/tests/openai/util.py b/tests/openai/util.py index b118e3e46..8bc658db2 100644 --- a/tests/openai/util.py +++ b/tests/openai/util.py @@ -2,6 +2,5 @@ models = ["gpt-4-turbo-preview"] modes = [ - instructor.Mode.JSON, - instructor.Mode.TOOLS, + instructor.Mode.MD_JSON, ] diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 000000000..7536167f3 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,127 @@ +import json +import pytest +from instructor.utils import ( + extract_json_from_codeblock, + extract_json_from_stream, + extract_json_from_stream_async, +) + + +def test_extract_json_from_codeblock(): + example = """ + Here is a response + + ```json + { + "key": "value" + } + ``` + """ + result = extract_json_from_codeblock(example) + assert json.loads(result) == {"key": "value"} + + +def test_extract_json_from_codeblock_no_end(): + example = """ + Here is a response + + ```json + { + "key": "value", + "another_key": [{"key": {"key": "value"}}] + } + """ + result = extract_json_from_codeblock(example) + assert json.loads(result) == { + "key": "value", + "another_key": [{"key": {"key": "value"}}], + } + + +def test_extract_json_from_codeblock_no_start(): + example = """ + Here is a response + + { + "key": "value", + "another_key": [{"key": {"key": "value"}}, {"key": "value"}] + } + """ + result = extract_json_from_codeblock(example) + assert json.loads(result) == { + "key": "value", + "another_key": [{"key": {"key": "value"}}, {"key": "value"}], + } + + +def test_stream_json(): + text = """here is the json for you! + + ```json + , here + { + "key": "value", + "another_key": [{"key": {"key": "value"}}] + } + ``` + + What do you think? + """ + + def batch_strings(chunks, n=2): + batch = "" + for chunk in chunks: + for char in chunk: + batch += char + if len(batch) == n: + yield batch + batch = "" + if batch: # Yield any remaining characters in the last batch + yield batch + + result = json.loads( + "".join(list(extract_json_from_stream(batch_strings(text, n=3)))) + ) + assert result == {"key": "value", "another_key": [{"key": {"key": "value"}}]} + + +@pytest.mark.asyncio +async def test_stream_json_async(): + text = """here is the json for you! + + ```json + , here + { + "key": "value", + "another_key": [{"key": {"key": "value"}}, {"key": "value"}] + } + ``` + + What do you think? + """ + + async def batch_strings_async(chunks, n=2): + batch = "" + for chunk in chunks: + for char in chunk: + batch += char + if len(batch) == n: + yield batch + batch = "" + if batch: # Yield any remaining characters in the last batch + yield batch + + result = json.loads( + "".join( + [ + chunk + async for chunk in extract_json_from_stream_async( + batch_strings_async(text, n=3) + ) + ] + ) + ) + assert result == { + "key": "value", + "another_key": [{"key": {"key": "value"}}, {"key": "value"}], + }