feat: gemini tool calling support (#726)
Adding Gemini Tool Calling support
ssonal authored Aug 31, 2024
1 parent c1bbfa5 commit b96e9a3
Showing 15 changed files with 714 additions and 599 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -133,9 +133,10 @@ assert resp.age == 25
### Using Gemini Models

Make sure you [install](https://ai.google.dev/api/python/google/generativeai#setup) the Google AI Python SDK. You should set a `GOOGLE_API_KEY` environment variable with your API key.
Gemini tool calling also requires `jsonref` to be installed.

```
pip install google-generativeai
pip install google-generativeai jsonref
```

21 changes: 21 additions & 0 deletions docs/concepts/patching.md
@@ -21,6 +21,27 @@ client = instructor.from_openai(OpenAI(), mode=instructor.Mode.TOOLS)

### Gemini Tool Calling

Gemini supports tool calling for structured data extraction. Tool calling requires `jsonref` to be installed.

!!! warning "Limitations"

    Gemini tool calling has some known limitations:

    - `strict` Pydantic validation can fail for integer/float and enum fields
    - Due to API limitations, Gemini tool calling is incompatible with Pydantic schema customizations such as examples, and may raise errors
    - Gemini can sometimes call the wrong function name, producing malformed or invalid JSON
    - Gemini tool calling may fail with `enum` and `Literal` field types

```python
import instructor
import google.generativeai as genai

client = instructor.from_gemini(
    genai.GenerativeModel(), mode=instructor.Mode.GEMINI_TOOLS
)
```
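
Once patched, the client is used like any other instructor client; a minimal sketch (the prompt and the `client.messages.create` call follow the usual instructor conventions and are illustrative, not part of this diff):

```python
from pydantic import BaseModel


class User(BaseModel):
    name: str
    age: int


resp = client.messages.create(
    messages=[{"role": "user", "content": "Jason is 25 years old"}],
    response_model=User,
)
assert resp.name == "Jason"
assert resp.age == 25
```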

### Gemini Vertex AI Tool Calling

This method allows us to get structured output from Gemini via tool calling with the Vertex AI SDK.

**Note:** Gemini tool calling is in preview and has some limitations; you can learn more in the [Vertex AI examples notebook](../hub/vertexai.md).
7 changes: 4 additions & 3 deletions instructor/client_gemini.py
@@ -32,9 +32,10 @@ def from_gemini(
use_async: bool = False,
**kwargs: Any,
) -> instructor.Instructor | instructor.AsyncInstructor:
assert (
mode == instructor.Mode.GEMINI_JSON
), "Mode must be instructor.Mode.GEMINI_JSON"
assert mode in {
instructor.Mode.GEMINI_JSON,
instructor.Mode.GEMINI_TOOLS,
}, "Mode must be one of {instructor.Mode.GEMINI_JSON, instructor.Mode.GEMINI_TOOLS}"

assert isinstance(
client,
9 changes: 8 additions & 1 deletion instructor/dsl/iterable.py
@@ -17,7 +17,7 @@ def from_streaming_response(
) -> Generator[BaseModel, None, None]: # noqa: ARG003
json_chunks = cls.extract_json(completion, mode)

if mode == Mode.MD_JSON:
if mode in {Mode.MD_JSON, Mode.GEMINI_TOOLS}:
json_chunks = extract_json_from_stream(json_chunks)

yield from cls.tasks_from_chunks(json_chunks, **kwargs)
@@ -86,6 +86,13 @@ def extract_json(
yield chunk.delta.partial_json
if mode == Mode.GEMINI_JSON:
yield chunk.text
if mode == Mode.GEMINI_TOOLS:
# Gemini seems to return the entire function_call and not a chunk?
import json

resp = chunk.candidates[0].content.parts[0].function_call

yield json.dumps(type(resp).to_dict(resp)["args"]) # type:ignore
elif chunk.choices:
if mode == Mode.FUNCTIONS:
Mode.warn_mode_functions_deprecation()
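For context, a sketch of what the new `GEMINI_TOOLS` branch does, written as a standalone helper (the response shape follows the Google AI SDK; the helper name is hypothetical):

```python
import json


def function_call_args_as_json(chunk) -> str:
    # A Gemini tool-calling "chunk" carries the complete function_call
    # rather than a text fragment, so its args are re-serialized as one
    # JSON string for the downstream stream parser.
    fc = chunk.candidates[0].content.parts[0].function_call
    # function_call is a proto-plus message; its class provides to_dict()
    return json.dumps(type(fc).to_dict(fc)["args"])
```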
16 changes: 9 additions & 7 deletions instructor/dsl/partial.py
@@ -105,7 +105,7 @@ def from_streaming_response(
) -> Generator[T_Model, None, None]:
json_chunks = cls.extract_json(completion, mode)

if mode == Mode.MD_JSON:
if mode in {Mode.MD_JSON, Mode.GEMINI_TOOLS}:
json_chunks = extract_json_from_stream(json_chunks)

yield from cls.model_from_chunks(json_chunks, **kwargs)
@@ -129,9 +129,7 @@ def model_from_chunks(
partial_model = cls.get_partial_model()
for chunk in json_chunks:
potential_object += chunk
obj = from_json(
(potential_object or "{}").encode(), partial_mode="on"
)
obj = from_json((potential_object or "{}").encode(), partial_mode="on")
obj = partial_model.model_validate(obj, strict=None, **kwargs)
yield obj

@@ -143,9 +141,7 @@ async def model_from_chunks_async(
partial_model = cls.get_partial_model()
async for chunk in json_chunks:
potential_object += chunk
obj = from_json(
(potential_object or "{}").encode(), partial_mode="on"
)
obj = from_json((potential_object or "{}").encode(), partial_mode="on")
obj = partial_model.model_validate(obj, strict=None, **kwargs)
yield obj

@@ -162,6 +158,12 @@ def extract_json(
yield chunk.delta.partial_json
if mode == Mode.GEMINI_JSON:
yield chunk.text
if mode == Mode.GEMINI_TOOLS:
# Gemini seems to return the entire function_call and not a chunk?
import json

resp = chunk.candidates[0].content.parts[0].function_call
yield json.dumps(type(resp).to_dict(resp)["args"]) # type:ignore
elif chunk.choices:
if mode == Mode.FUNCTIONS:
Mode.warn_mode_functions_deprecation()
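With the same branch in `Partial.extract_json`, partial streaming should also work in tools mode; a hedged usage sketch (standard instructor `create_partial` API with an illustrative prompt — and since Gemini returns the whole function call at once, expect few intermediate objects):

```python
from pydantic import BaseModel


class User(BaseModel):
    name: str
    age: int


# Each yielded object is a progressively more complete User
for partial_user in client.chat.completions.create_partial(
    messages=[{"role": "user", "content": "Jason is 25 years old"}],
    response_model=User,
):
    print(partial_user)
```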
45 changes: 21 additions & 24 deletions instructor/function_calls.py
@@ -15,7 +15,11 @@

from instructor.exceptions import IncompleteOutputException
from instructor.mode import Mode
from instructor.utils import classproperty, extract_json_from_codeblock
from instructor.utils import (
classproperty,
extract_json_from_codeblock,
map_to_gemini_function_schema,
)


T = TypeVar("T")
@@ -77,25 +81,16 @@ def anthropic_schema(cls) -> dict[str, Any]:
"input_schema": cls.model_json_schema(),
}

def has_async_validators(self):
    has_validators = (
        len(self.__class__.get_async_validators()) > 0
        or len(self.get_async_model_validators()) > 0
    )

    for _, attribute_value in self.__dict__.items():
        if isinstance(attribute_value, OpenAISchema):
            has_validators = (
                has_validators or attribute_value.has_async_validators()
            )

        # List of items too
        if isinstance(attribute_value, (list, set, tuple)):
            for item in attribute_value:
                if isinstance(item, OpenAISchema):
                    has_validators = has_validators or item.has_async_validators()

    return has_validators

@classproperty
def gemini_schema(cls) -> Any:
    import google.generativeai.types as genai_types

    function = genai_types.FunctionDeclaration(
        name=cls.openai_schema["name"],
        description=cls.openai_schema["description"],
        parameters=map_to_gemini_function_schema(cls.openai_schema["parameters"]),
    )
    return function

@classmethod
def from_response(
@@ -123,8 +118,8 @@ def from_response(
if mode == Mode.ANTHROPIC_JSON:
return cls.parse_anthropic_json(completion, validation_context, strict)

if mode == Mode.VERTEXAI_TOOLS:
return cls.parse_vertexai_tools(completion, validation_context, strict)
if mode in {Mode.VERTEXAI_TOOLS, Mode.GEMINI_TOOLS}:
return cls.parse_vertexai_tools(completion, validation_context)

if mode == Mode.VERTEXAI_JSON:
return cls.parse_vertexai_json(completion, validation_context, strict)
@@ -135,6 +130,9 @@
if mode == Mode.GEMINI_JSON:
return cls.parse_gemini_json(completion, validation_context, strict)

if mode == Mode.GEMINI_TOOLS:
return cls.parse_gemini_tools(completion, validation_context, strict)

if mode == Mode.COHERE_JSON_SCHEMA:
return cls.parse_cohere_json_schema(completion, validation_context, strict)

@@ -254,14 +252,13 @@ def parse_vertexai_tools(
cls: type[BaseModel],
completion: ChatCompletion,
validation_context: Optional[dict[str, Any]] = None,
strict: Optional[bool] = None,
) -> BaseModel:
strict = False
tool_call = completion.candidates[0].content.parts[0].function_call.args # type: ignore
model = {}
for field in tool_call: # type: ignore
model[field] = tool_call[field]
return cls.model_validate(model, context=validation_context, strict=strict)
# We use strict=False because the protobuf -> dict conversion often casts ints to floats; model_validate only accepts those coerced values with strict mode disabled.
return cls.model_validate(model, context=validation_context, strict=False)

@classmethod
def parse_vertexai_json(
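The need for `strict=False` is easy to see in isolation; a small sketch of the coercion involved (values illustrative):

```python
from pydantic import BaseModel


class User(BaseModel):
    name: str
    age: int


# The protobuf -> dict conversion can surface integers as floats
data = {"name": "Jason", "age": 25.0}

User.model_validate(data)  # lax mode coerces 25.0 -> 25
# User.model_validate(data, strict=True) would raise a ValidationError
```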
1 change: 1 addition & 0 deletions instructor/mode.py
@@ -18,6 +18,7 @@ class Mode(enum.Enum):
VERTEXAI_TOOLS = "vertexai_tools"
VERTEXAI_JSON = "vertexai_json"
GEMINI_JSON = "gemini_json"
GEMINI_TOOLS = "gemini_tools"
COHERE_JSON_SCHEMA = "json_object"
TOOLS_STRICT = "tools_strict"

45 changes: 18 additions & 27 deletions instructor/process_response.py
@@ -27,7 +27,6 @@

from instructor.mode import Mode

from .utils import transform_to_gemini_prompt

logger = logging.getLogger("instructor")

@@ -432,6 +431,9 @@ def handle_response_model(
assert (
"model" not in new_kwargs
), "Gemini `model` must be set while patching the client, not passed as a parameter to the create method"

from .utils import update_gemini_kwargs

message = dedent(
f"""
As a genius expert, your task is to understand the content and provide
@@ -461,34 +463,23 @@
"generation_config", {}
) | {"response_mime_type": "application/json"}

map_openai_args_to_gemini = {
"max_tokens": "max_output_tokens",
"temperature": "temperature",
"n": "candidate_count",
"top_p": "top_p",
"stop": "stop_sequences",
}

# update gemini config if any params are set
for k, v in map_openai_args_to_gemini.items():
val = new_kwargs.pop(k, None)
if val == None:
continue
new_kwargs["generation_config"][v] = val

# gemini has a different prompt format and params from other providers
new_kwargs["contents"] = transform_to_gemini_prompt(
new_kwargs.pop("messages")
)
new_kwargs = update_gemini_kwargs(new_kwargs)

# minimize gemini safety related errors - model is highly prone to false alarms
from google.generativeai.types import HarmCategory, HarmBlockThreshold

new_kwargs["safety_settings"] = new_kwargs.get("safety_settings", {}) | {
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
elif mode == Mode.GEMINI_TOOLS:
assert (
"model" not in new_kwargs
), "Gemini `model` must be set while patching the client, not passed as a parameter to the create method"
from .utils import update_gemini_kwargs

new_kwargs["tools"] = [response_model.gemini_schema]
new_kwargs["tool_config"] = {
"function_calling_config": {
"mode": "ANY",
"allowed_function_names": [response_model.__name__],
},
}

new_kwargs = update_gemini_kwargs(new_kwargs)
elif mode == Mode.VERTEXAI_TOOLS:
from instructor.client_vertexai import vertexai_process_response

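To make the new `GEMINI_TOOLS` branch concrete, roughly what the handler assembles for a `User` response model (a sketch; the `tool_config` field names come from the Google AI SDK's function-calling config, and the values are illustrative):

```python
# tools carries the FunctionDeclaration built by gemini_schema, and
# tool_config forces the model to call exactly that function
kwargs = {
    "tools": [User.gemini_schema],
    "tool_config": {
        "function_calling_config": {
            "mode": "ANY",  # always respond with a function call
            "allowed_function_names": ["User"],
        },
    },
}
```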
20 changes: 20 additions & 0 deletions instructor/retry.py
@@ -80,6 +80,25 @@ def reask_messages(response: ChatCompletion, mode: Mode, exception: Exception):
if mode == Mode.COHERE_TOOLS or mode == Mode.COHERE_JSON_SCHEMA:
yield f"Correct the following JSON response, based on the errors given below:\n\nJSON:\n{response.text}\n\nExceptions:\n{exception}"
return
if mode == Mode.GEMINI_TOOLS:
from google.ai import generativelanguage as glm

yield {
"role": "function",
"parts": [
glm.Part(
function_response=glm.FunctionResponse(
name=response.parts[0].function_call.name,
response={"error": f"Validation Error(s) found:\n{exception}"},
)
),
],
}
yield {
"role": "user",
"parts": [f"Recall the function arguments correctly and fix the errors"],
}
return
if mode == Mode.GEMINI_JSON:
yield {
"role": "user",
@@ -173,6 +192,7 @@ def retry_sync(
logger.debug(f"Error response: {response}")
if mode in {
Mode.GEMINI_JSON,
Mode.GEMINI_TOOLS,
Mode.VERTEXAI_TOOLS,
Mode.VERTEXAI_JSON,
}:
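A hedged sketch of how these reask messages get exercised: a deliberately failing validator forces a retry, and the `function` response plus user reask parts above are appended before the next attempt (prompt and expected output are illustrative):

```python
from pydantic import BaseModel, field_validator


class User(BaseModel):
    name: str
    age: int

    @field_validator("name")
    @classmethod
    def name_must_be_uppercase(cls, v: str) -> str:
        # Fails on the first attempt, triggering the reask flow above
        if not v.isupper():
            raise ValueError("name must be uppercase")
        return v


resp = client.messages.create(
    messages=[{"role": "user", "content": "Extract: jason is 25 years old"}],
    response_model=User,
    max_retries=2,  # each retry appends the reask parts shown above
)
print(resp)
```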