feat: Improve MD_JSON mode (#490)
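MD_JSON mode now pulls JSON out of markdown code blocks instead of assuming the entire message is JSON: `from_response` routes `message.content` through a new `extract_json_from_codeblock` helper, the streaming paths in `IterableBase` and `PartialBase` filter chunks through `extract_json_from_stream` / `extract_json_from_stream_async`, and `get_object` no longer includes text preceding the first `{`. The docs examples were re-run, which accounts for the refreshed outputs below.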
jxnl authored Mar 6, 2024
1 parent 00bedcf commit 3e44a6b
Showing 21 changed files with 277 additions and 89 deletions.
4 changes: 2 additions & 2 deletions docs/concepts/caching.md
@@ -33,12 +33,12 @@ def extract(data) -> UserDetail:
start = time.perf_counter() # (1)
model = extract("Extract jason is 25 years old")
print(f"Time taken: {time.perf_counter() - start}")
-#> Time taken: 0.8392175831831992
+#> Time taken: 0.7583793329977198

start = time.perf_counter()
model = extract("Extract jason is 25 years old") # (2)
print(f"Time taken: {time.perf_counter() - start}")
-#> Time taken: 8.33999365568161e-07
+#> Time taken: 4.3330073822289705e-06
```

1. Using `time.perf_counter()` to measure the time taken to run the function is better than using `time.time()` because it's more accurate and less susceptible to system clock changes.
4 changes: 2 additions & 2 deletions docs/concepts/lists.md
@@ -157,8 +157,8 @@ async def print_iterable_results():
)
async for m in model:
print(m)
-#> name='John Smith' age=30
-#> name='Mary Jane' age=28
+#> name='John Doe' age=30
+#> name='Jane Doe' age=28


import asyncio
2 changes: 1 addition & 1 deletion docs/concepts/maybe.md
@@ -89,7 +89,7 @@ print(user2.model_dump_json(indent=2))
{
"result": null,
"error": false,
-"message": null
+"message": "Unknown user"
}
"""
```
2 changes: 1 addition & 1 deletion docs/concepts/models.md
@@ -150,7 +150,7 @@ class SearchQuery(BaseModel):

def execute(self):
print(f"Searching for {self.query} of type {self.query_type}")
-#> Searching for cat pictures of type image
+#> Searching for cat of type image
return "Results for cat"


4 changes: 2 additions & 2 deletions docs/concepts/parallel.md
@@ -44,9 +44,9 @@ function_calls = client.chat.completions.create(

for fc in function_calls:
print(fc)
-#> location='Toronto' units='imperial'
+#> location='Toronto' units='metric'
#> location='Dallas' units='imperial'
-#> query='who won the super bowl'
+#> query='super bowl winner'
```

1. Set the mode to `PARALLEL_TOOLS` to enable parallel function calling.
4 changes: 2 additions & 2 deletions docs/concepts/partial.md
@@ -119,10 +119,10 @@ print(extraction.model_dump_json(indent=2))
"twitter": "@CodeMaster2023"
}
],
-"date": "2024-03-15",
+"date": "March 15th, 2024",
"location": "Grand Tech Arena located at 4521 Innovation Drive",
"budget": 50000,
-"deadline": "2024-02-20"
+"deadline": "February 20th"
}
"""
```
8 changes: 4 additions & 4 deletions docs/concepts/raw_response.md
@@ -25,7 +25,7 @@ user: UserExtract = client.chat.completions.create(
print(user._raw_response)
"""
ChatCompletion(
-id='chatcmpl-8u9bsrmmf5YjZyfCtQymoZV8LK1qg',
+id='chatcmpl-8zpltT9vXJdO5OE3AfDsOhAUr911A',
choices=[
Choice(
finish_reason='stop',
@@ -37,7 +37,7 @@
function_call=None,
tool_calls=[
ChatCompletionMessageToolCall(
-id='call_O5rpXf47YgXiYrYWv45yZUeM',
+id='call_vXI3foz7jqlzFILU9pwuYJZB',
function=Function(
arguments='{"name":"Jason","age":25}', name='UserExtract'
),
@@ -47,10 +47,10 @@
),
)
],
-created=1708394000,
+created=1709747709,
model='gpt-3.5-turbo-0125',
object='chat.completion',
-system_fingerprint='fp_69829325d0',
+system_fingerprint='fp_2b778c6b35',
usage=CompletionUsage(completion_tokens=9, prompt_tokens=82, total_tokens=91),
)
"""
2 changes: 1 addition & 1 deletion docs/concepts/reask_validation.md
@@ -91,7 +91,7 @@ except ValidationError as e:
"""
1 validation error for QuestionAnswer
answer
-Assertion failed, The statement promotes objectionable behavior by encouraging evil and theft. [type=assertion_error, input_value='The meaning of life is to be evil and steal', input_type=str]
+Assertion failed, The statement promotes objectionable behavior by encouraging evil and stealing, which goes against the rule of not saying objectionable things. [type=assertion_error, input_value='The meaning of life is to be evil and steal', input_type=str]
For further information visit https://errors.pydantic.dev/2.6/v/assertion_error
"""
```
1 change: 1 addition & 0 deletions docs/hub/batch_classification_langsmith.md
@@ -47,6 +47,7 @@ client = instructor.patch(client, mode=instructor.Mode.TOOLS)
# Rate limit the number of requests
sem = asyncio.Semaphore(5)


# Use an Enum to define the types of questions
class QuestionType(Enum):
CONTACT = "CONTACT"
5 changes: 2 additions & 3 deletions docs/hub/groq.md
@@ -51,9 +51,8 @@ client = Groq(
)

# By default, the patch function will patch the ChatCompletion.create and ChatCompletion.create methods to support the response_model parameter
-client = instructor.patch(
-    client, mode=instructor.Mode.MD_JSON
-)
+client = instructor.patch(client, mode=instructor.Mode.MD_JSON)
+

# Now, we can use the response_model parameter using only a base model
# rather than having to use the OpenAISchema class
6 changes: 2 additions & 4 deletions docs/hub/mistral.md
@@ -50,12 +50,10 @@ from mistralai.client import MistralClient
# enables `response_model` in chat call
client = MistralClient()

-patched_chat = instructor.patch(
-    create=client.chat,
-    mode=instructor.Mode.MISTRAL_TOOLS
-)
+patched_chat = instructor.patch(create=client.chat, mode=instructor.Mode.MISTRAL_TOOLS)

if __name__ == "__main__":

class UserDetails(BaseModel):
name: str
age: int
12 changes: 6 additions & 6 deletions docs/hub/pandas_df.md
@@ -108,13 +108,13 @@ if __name__ == "__main__":
assert isinstance(df, pd.DataFrame)
print(df)
"""
-Party Years Served
+Party Years Served
President
-Joe Biden Democratic 2021-
-Donald Trump Republican 2017-2021
-Barack Obama Democratic 2009-2017
-George W. Bush Republican 2001-2009
-Bill Clinton Democratic 1993-2001
+Joe Biden Democratic 2021-Current
+Donald Trump Republican 2017-2021
+Barack Obama Democratic 2009-2017
+George W. Bush Republican 2001-2009
+Bill Clinton Democratic 1993-2001
"""

table = extract_table(
9 changes: 3 additions & 6 deletions docs/hub/youtube_clips.md
@@ -15,7 +15,7 @@ instructor hub pull --slug youtube-clips --py > youtube_clips.py
```python
from youtube_transcript_api import YouTubeTranscriptApi
from pydantic import BaseModel, Field
-from typing import List, Dict, Generator, Iterable
+from typing import List, Generator, Iterable
import instructor
import openai

@@ -24,12 +24,12 @@ client = instructor.patch(openai.OpenAI())

def extract_video_id(url: str) -> str | None:
import re

match = re.search(r"v=([a-zA-Z0-9_-]+)", url)
if match:
return match.group(1)



class TranscriptSegment(BaseModel):
source_id: int
start: float
@@ -51,9 +51,7 @@ def get_transcript_with_timing(


class YoutubeClip(BaseModel):
-title: str = Field(
-    description="Specific and informative title for the clip."
-)
+title: str = Field(description="Specific and informative title for the clip.")
description: str = Field(
description="A detailed description of the clip, including notable quotes or phrases."
)
@@ -98,7 +96,6 @@ if __name__ == "__main__":
console = Console()
url = Prompt.ask("Enter a YouTube URL")


with console.status("[bold green]Processing YouTube URL...") as status:
video_id = extract_video_id(url)

8 changes: 4 additions & 4 deletions docs/index.md
@@ -115,7 +115,7 @@ print(response.model_dump_json(indent=2))
print(user._raw_response.model_dump_json(indent=2))
"""
{
-"id": "chatcmpl-8u9e2TV3ehCgLsRxNLLeAbzpEmBuZ",
+"id": "chatcmpl-8zplvRbNM8iKSVa3Ld9NmVICeXZZ9",
"choices": [
{
"finish_reason": "stop",
@@ -127,7 +127,7 @@
"function_call": null,
"tool_calls": [
{
-"id": "call_3ZuQhfteTLEy7CUokjwnLBHr",
+"id": "call_V5FRMSXrHFFTTqTjpwA76h7t",
"function": {
"arguments": "{\"name\":\"Jason\",\"age\":25}",
"name": "UserDetail"
@@ -138,10 +138,10 @@
}
}
],
-"created": 1708394134,
+"created": 1709747711,
"model": "gpt-3.5-turbo-0125",
"object": "chat.completion",
-"system_fingerprint": "fp_69829325d0",
+"system_fingerprint": "fp_2b778c6b35",
"usage": {
"completion_tokens": 9,
"prompt_tokens": 81,
12 changes: 11 additions & 1 deletion instructor/dsl/iterable.py
@@ -3,6 +3,7 @@
from pydantic import BaseModel, Field, create_model

from instructor.function_calls import OpenAISchema, Mode
+from instructor.utils import extract_json_from_stream, extract_json_from_stream_async


class IterableBase:
@@ -13,13 +14,21 @@ def from_streaming_response(
cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
) -> Generator[BaseModel, None, None]: # noqa: ARG003
json_chunks = cls.extract_json(completion, mode)
+
+if mode == Mode.MD_JSON:
+    json_chunks = extract_json_from_stream(json_chunks)
+
yield from cls.tasks_from_chunks(json_chunks, **kwargs)

@classmethod
async def from_streaming_response_async(
cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
) -> AsyncGenerator[BaseModel, None]:
json_chunks = cls.extract_json_async(completion, mode)
+
+if mode == Mode.MD_JSON:
+    json_chunks = extract_json_from_stream_async(json_chunks)
+
return cls.tasks_from_chunks_async(json_chunks, **kwargs)

@classmethod
@@ -110,13 +119,14 @@ async def extract_json_async(

@staticmethod
def get_object(s: str, stack: int) -> Tuple[Optional[str], str]:
+start_index = s.find("{")
for i, c in enumerate(s):
if c == "{":
stack += 1
if c == "}":
stack -= 1
if stack == 0:
-return s[: i + 1], s[i + 2 :]
+return s[start_index : i + 1], s[i + 2 :]
return None, s
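Note the fix above: `get_object` previously returned `s[: i + 1]`, which kept any prose preceding the first `{` (exactly what MD_JSON responses tend to contain); slicing from `start_index` drops that prefix, so `get_object('sure! {"a": 1}, more', 0)` now yields `('{"a": 1}', ' more')` instead of `('sure! {"a": 1}', ' more')`.

The `extract_json_from_stream` / `extract_json_from_stream_async` helpers imported at the top of this file live in `instructor/utils.py`, which is collapsed in this view. As a rough illustration of the idea — the names match the imports, but the body below is an assumption, not the code from this commit — the synchronous variant could look like:

```python
from typing import Generator, Iterable


def extract_json_from_stream(chunks: Iterable[str]) -> Generator[str, None, None]:
    """Hypothetical sketch: yield only the characters inside a ```json fence.

    The real helper is defined in instructor/utils.py (not shown in this
    diff); this sketch assumes the JSON payload contains no backticks.
    """
    capturing = False
    seen = ""
    for chunk in chunks:
        for char in chunk:
            if not capturing:
                seen += char
                if seen.endswith("```json"):
                    capturing = True  # opening fence found, start yielding
            elif char == "`":
                return  # closing fence reached, stop the stream
            else:
                yield char
```

Either way, the streaming consumers above hand the filtered chunks to `tasks_from_chunks` / `tasks_from_chunks_async` exactly as the other JSON modes do.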


9 changes: 9 additions & 0 deletions instructor/dsl/partial.py
@@ -24,6 +24,7 @@

from instructor.function_calls import Mode
from instructor.dsl.partialjson import JSONParser
+from instructor.utils import extract_json_from_stream, extract_json_from_stream_async

parser = JSONParser()
T_Model = TypeVar("T_Model", bound=BaseModel)
@@ -35,13 +36,21 @@ def from_streaming_response(
cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
) -> Generator[T_Model, None, None]:
json_chunks = cls.extract_json(completion, mode)
+
+if mode == Mode.MD_JSON:
+    json_chunks = extract_json_from_stream(json_chunks)
+
yield from cls.model_from_chunks(json_chunks, **kwargs)

@classmethod
async def from_streaming_response_async(
cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
) -> AsyncGenerator[T_Model, None]:
json_chunks = cls.extract_json_async(completion, mode)
+
+if mode == Mode.MD_JSON:
+    json_chunks = extract_json_from_stream_async(json_chunks)
+
return cls.model_from_chunks_async(json_chunks, **kwargs)

@classmethod
4 changes: 4 additions & 0 deletions instructor/function_calls.py
@@ -7,6 +7,7 @@
import warnings
import logging
from openai.types.chat import ChatCompletion
+from instructor.utils import extract_json_from_codeblock

T = TypeVar("T")

@@ -135,6 +136,9 @@ def from_response(
strict=strict,
)
elif mode in {Mode.JSON, Mode.JSON_SCHEMA, Mode.MD_JSON}:
+if mode == Mode.MD_JSON:
+    message.content = extract_json_from_codeblock(message.content or "")
+
model_response = cls.model_validate_json(
message.content, # type: ignore
context=validation_context,
(4 more changed files not shown.)

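For the non-streaming path above, `extract_json_from_codeblock` likewise comes from `instructor/utils.py`, which is collapsed in this view. A minimal sketch, assuming the helper just slices out the fenced payload and falls back to the raw content (the actual implementation may differ):

```python
def extract_json_from_codeblock(content: str) -> str:
    """Hypothetical sketch: return the JSON between ```json fences.

    instructor/utils.py is not shown in this diff; treat this as an
    illustration of the behavior, not the committed code.
    """
    start = content.find("```json")
    if start != -1:
        end = content.rfind("```")
        if end > start:
            return content[start + len("```json") : end].strip()
    return content.strip()
```

With that change, MD_JSON mode can be used with clients that only emit markdown-wrapped JSON, along the lines of the groq.md snippet above:

```python
import instructor
import openai
from pydantic import BaseModel

# Patch the client in MD_JSON mode, as in the groq.md change above.
client = instructor.patch(openai.OpenAI(), mode=instructor.Mode.MD_JSON)


class UserDetail(BaseModel):
    name: str
    age: int


# The model may wrap its answer in a ```json block; the patched client unwraps it.
user = client.chat.completions.create(
    model="gpt-3.5-turbo",
    response_model=UserDetail,
    messages=[{"role": "user", "content": "Extract: Jason is 25 years old"}],
)
print(user.name, user.age)
```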