Skip to content

Commit

Permalink
Add Fireworks Client (#1073)
Browse files Browse the repository at this point in the history
Co-authored-by: Gabriel Chua <[email protected]>
Co-authored-by: gabriel chua <[email protected]>
Co-authored-by: Jason Liu <[email protected]>
  • Loading branch information
4 people authored Oct 17, 2024
1 parent 747780b commit 4b63d4e
Show file tree
Hide file tree
Showing 18 changed files with 803 additions and 245 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
- name: Generate coverage report
if: matrix.python-version == '3.11'
run: |
poetry run coverage run -m pytest tests/ -k "not docs and not anthropic and not gemini and not cohere and not vertexai"
poetry run coverage run -m pytest tests/ -k "not docs and not anthropic and not gemini and not cohere and not vertexai and not fireworks"
poetry run coverage report
poetry run coverage html
env:
Expand Down
36 changes: 36 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,42 @@ print(resp)
#> name='Jason' age=25
```

### Using Fireworks

For those who want to use the Fireworks models, you can use the `from_fireworks` method to patch the client. You can see their list of models [here](https://fireworks.ai/models).

```python
from fireworks.client import Fireworks
import instructor
from pydantic import BaseModel
import os

client = Fireworks(
api_key=os.environ.get("FIREWORKS_API_KEY"),
)
client = instructor.from_fireworks(client)


class User(BaseModel):
name: str
age: int


resp = client.chat.completions.create(
model="accounts/fireworks/models/llama-v3p2-1b-instruct",
response_model=User,
messages=[
{
"role": "user",
"content": "Extract Jason is 25 years old.",
}
],
)

print(resp)
#> name='Jason' age=25
```

## Correct Typing

This was the dream of instructor but due to the patching of openai, it wasn't possible for me to get typing to work well. Now, with the new client, we can get typing to work well! We've also added a few `create_*` methods to make it easier to create iterables and partials, and to access the original completion.
Expand Down
5 changes: 5 additions & 0 deletions instructor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@

__all__ += ["from_gemini"]

# Expose the Fireworks integration only when the optional `fireworks` SDK is
# installed, mirroring the guarded imports for the other optional providers.
if importlib.util.find_spec("fireworks") is not None:
    from .client_fireworks import from_fireworks

    __all__ += ["from_fireworks"]

if importlib.util.find_spec("cerebras") is not None:
from .client_cerebras import from_cerebras

Expand Down
67 changes: 67 additions & 0 deletions instructor/client_fireworks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

from typing import Any, overload

import instructor
from instructor.client import AsyncInstructor, Instructor


from fireworks.client import Fireworks, AsyncFireworks # type:ignore


# Overload: a synchronous Fireworks client yields a synchronous Instructor.
@overload
def from_fireworks(
    client: Fireworks,
    mode: instructor.Mode = instructor.Mode.FIREWORKS_JSON,
    **kwargs: Any,
) -> Instructor: ...


# Overload: an AsyncFireworks client yields an AsyncInstructor.
@overload
def from_fireworks(
    client: AsyncFireworks,
    mode: instructor.Mode = instructor.Mode.FIREWORKS_JSON,
    **kwargs: Any,
) -> AsyncInstructor: ...


def from_fireworks(
    client: Fireworks | AsyncFireworks,
    mode: instructor.Mode = instructor.Mode.FIREWORKS_JSON,
    **kwargs: Any,
) -> Instructor | AsyncInstructor:
    """Patch a Fireworks client so ``create`` accepts ``response_model``.

    Args:
        client: A synchronous ``Fireworks`` or asynchronous ``AsyncFireworks``
            client instance.
        mode: Structured-output strategy; must be ``FIREWORKS_TOOLS`` or
            ``FIREWORKS_JSON`` (the default).
        **kwargs: Forwarded to the ``Instructor`` / ``AsyncInstructor``
            constructor.

    Returns:
        ``AsyncInstructor`` when given an ``AsyncFireworks`` client,
        ``Instructor`` otherwise.
    """
    valid_modes = {instructor.Mode.FIREWORKS_TOOLS, instructor.Mode.FIREWORKS_JSON}
    # Fix: the original message was a plain (non-f) string, so the braces and
    # mode names were printed literally instead of being interpolated.
    assert mode in valid_modes, f"Mode must be one of {valid_modes}"

    assert isinstance(
        client, (AsyncFireworks, Fireworks)
    ), "Client must be an instance of Fireworks or AsyncFireworks"

    if isinstance(client, AsyncFireworks):

        async def async_wrapper(*args: Any, **kwargs: Any):  # type:ignore
            # With stream=True, ``acreate`` returns an async iterator directly;
            # it must be handed back un-awaited so the caller can iterate it.
            if "stream" in kwargs and kwargs["stream"] is True:
                return client.chat.completions.acreate(*args, **kwargs)  # type:ignore
            return await client.chat.completions.acreate(*args, **kwargs)  # type:ignore

        return AsyncInstructor(
            client=client,
            create=instructor.patch(create=async_wrapper, mode=mode),
            provider=instructor.Provider.FIREWORKS,
            mode=mode,
            **kwargs,
        )

    # The isinstance assert above guarantees we have a sync Fireworks client
    # here; returning unconditionally removes the implicit-None fall-through
    # that the original trailing ``if isinstance(...)`` left open.
    return Instructor(
        client=client,
        create=instructor.patch(create=client.chat.completions.create, mode=mode),  # type: ignore
        provider=instructor.Provider.FIREWORKS,
        mode=mode,
        **kwargs,
    )
12 changes: 8 additions & 4 deletions instructor/dsl/iterable.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,14 @@ def extract_json(
Mode.MD_JSON,
Mode.JSON_SCHEMA,
Mode.CEREBRAS_JSON,
Mode.FIREWORKS_JSON,
}:
if json_chunk := chunk.choices[0].delta.content:
yield json_chunk
elif mode in {Mode.TOOLS, Mode.TOOLS_STRICT}:
elif mode in {Mode.TOOLS, Mode.TOOLS_STRICT, Mode.FIREWORKS_TOOLS}:
if json_chunk := chunk.choices[0].delta.tool_calls:
yield json_chunk[0].function.arguments
if json_chunk[0].function.arguments is not None:
yield json_chunk[0].function.arguments
else:
raise NotImplementedError(
f"Mode {mode} is not supported for MultiTask streaming"
Expand Down Expand Up @@ -139,12 +141,14 @@ async def extract_json_async(
Mode.MD_JSON,
Mode.JSON_SCHEMA,
Mode.CEREBRAS_JSON,
Mode.FIREWORKS_JSON,
}:
if json_chunk := chunk.choices[0].delta.content:
yield json_chunk
elif mode in {Mode.TOOLS, Mode.TOOLS_STRICT}:
elif mode in {Mode.TOOLS, Mode.TOOLS_STRICT, Mode.FIREWORKS_TOOLS}:
if json_chunk := chunk.choices[0].delta.tool_calls:
yield json_chunk[0].function.arguments
if json_chunk[0].function.arguments is not None:
yield json_chunk[0].function.arguments
else:
raise NotImplementedError(
f"Mode {mode} is not supported for MultiTask streaming"
Expand Down
12 changes: 8 additions & 4 deletions instructor/dsl/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,14 @@ def extract_json(
Mode.MD_JSON,
Mode.JSON_SCHEMA,
Mode.CEREBRAS_JSON,
Mode.FIREWORKS_JSON,
}:
if json_chunk := chunk.choices[0].delta.content:
yield json_chunk
elif mode in {Mode.TOOLS, Mode.TOOLS_STRICT}:
elif mode in {Mode.TOOLS, Mode.TOOLS_STRICT, Mode.FIREWORKS_TOOLS}:
if json_chunk := chunk.choices[0].delta.tool_calls:
yield json_chunk[0].function.arguments
if json_chunk[0].function.arguments:
yield json_chunk[0].function.arguments
else:
raise NotImplementedError(
f"Mode {mode} is not supported for MultiTask streaming"
Expand Down Expand Up @@ -210,12 +212,14 @@ async def extract_json_async(
Mode.MD_JSON,
Mode.JSON_SCHEMA,
Mode.CEREBRAS_JSON,
Mode.FIREWORKS_JSON,
}:
if json_chunk := chunk.choices[0].delta.content:
yield json_chunk
elif mode in {Mode.TOOLS, Mode.TOOLS_STRICT}:
elif mode in {Mode.TOOLS, Mode.TOOLS_STRICT, Mode.FIREWORKS_TOOLS}:
if json_chunk := chunk.choices[0].delta.tool_calls:
yield json_chunk[0].function.arguments
if json_chunk[0].function.arguments:
yield json_chunk[0].function.arguments
else:
raise NotImplementedError(
f"Mode {mode} is not supported for MultiTask streaming"
Expand Down
2 changes: 2 additions & 0 deletions instructor/function_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def from_response(
Mode.MISTRAL_TOOLS,
Mode.TOOLS_STRICT,
Mode.CEREBRAS_TOOLS,
Mode.FIREWORKS_TOOLS,
}:
return cls.parse_tools(completion, validation_context, strict)

Expand All @@ -157,6 +158,7 @@ def from_response(
Mode.MD_JSON,
Mode.JSON_O1,
Mode.CEREBRAS_JSON,
Mode.FIREWORKS_JSON,
}:
return cls.parse_json(completion, validation_context, strict)

Expand Down
2 changes: 2 additions & 0 deletions instructor/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class Mode(enum.Enum):
TOOLS_STRICT = "tools_strict"
CEREBRAS_TOOLS = "cerebras_tools"
CEREBRAS_JSON = "cerebras_json"
FIREWORKS_TOOLS = "fireworks_tools"
FIREWORKS_JSON = "fireworks_json"

@classmethod
def warn_mode_functions_deprecation(cls):
Expand Down
33 changes: 33 additions & 0 deletions instructor/process_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,37 @@ def handle_cohere_modes(new_kwargs: dict[str, Any]) -> tuple[None, dict[str, Any
return None, new_kwargs


def handle_fireworks_tools(
    response_model: type[T], new_kwargs: dict[str, Any]
) -> tuple[type[T], dict[str, Any]]:
    """Prepare request kwargs for Fireworks tool-calling mode.

    Defaults ``stream`` to False, advertises the response model as the single
    available tool, and forces the model to call it via ``tool_choice``.
    Returns the (unchanged) response model and the mutated kwargs.
    """
    new_kwargs.setdefault("stream", False)
    schema = response_model.openai_schema
    new_kwargs["tools"] = [{"type": "function", "function": schema}]
    new_kwargs["tool_choice"] = {
        "type": "function",
        "function": {"name": schema["name"]},
    }
    return response_model, new_kwargs


def handle_fireworks_json(
    response_model: type[T], new_kwargs: dict[str, Any]
) -> tuple[type[T], dict[str, Any]]:
    """Prepare request kwargs for Fireworks JSON mode.

    Defaults ``stream`` to False and sets ``response_format`` to a JSON-object
    constraint carrying the response model's JSON schema. Returns the
    (unchanged) response model and the mutated kwargs.
    """
    new_kwargs.setdefault("stream", False)
    new_kwargs["response_format"] = {
        "type": "json_object",
        "schema": response_model.model_json_schema(),
    }
    return response_model, new_kwargs


def handle_gemini_json(
response_model: type[T], new_kwargs: dict[str, Any]
) -> tuple[type[T], dict[str, Any]]:
Expand Down Expand Up @@ -657,6 +688,8 @@ def handle_response_model(
Mode.VERTEXAI_JSON: handle_vertexai_json,
Mode.CEREBRAS_JSON: handle_cerebras_json,
Mode.CEREBRAS_TOOLS: handle_cerebras_tools,
Mode.FIREWORKS_JSON: handle_fireworks_json,
Mode.FIREWORKS_TOOLS: handle_fireworks_tools,
}

if mode in mode_handlers:
Expand Down
37 changes: 37 additions & 0 deletions instructor/reask.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,41 @@ def reask_default(
return kwargs


def reask_fireworks_tools(kwargs: dict[str, Any], response: Any, exception: Exception):
    """Build the retry payload for FIREWORKS_TOOLS mode.

    Echoes the assistant message, then answers each of its tool calls with a
    validation-error notice so the model can re-issue corrected arguments.
    Returns a shallow copy of ``kwargs`` with the new messages appended.
    """
    updated = kwargs.copy()
    correction = (
        f"Validation Error found:\n{exception}\nRecall the function correctly, fix the errors"
    )
    follow_ups: list[Any] = [dump_message(response.choices[0].message)]
    follow_ups += [
        {
            "role": "tool",  # type: ignore
            "tool_call_id": call.id,
            "name": call.function.name,
            "content": correction,
        }
        for call in response.choices[0].message.tool_calls
    ]
    updated["messages"].extend(follow_ups)
    return updated


def reask_fireworks_json(
    kwargs: dict[str, Any],
    response: Any,
    exception: Exception,
):
    """Build the retry payload for FIREWORKS_JSON mode.

    Echoes the assistant message and appends a user turn asking the model to
    correct its JSON output based on the validation errors. Returns a shallow
    copy of ``kwargs`` with the new messages appended.
    """
    updated = kwargs.copy()
    updated["messages"].extend(
        [
            dump_message(response.choices[0].message),
            {
                "role": "user",
                "content": f"Correct your JSON ONLY RESPONSE, based on the following errors:\n{exception}",
            },
        ]
    )
    return updated


def handle_reask_kwargs(
kwargs: dict[str, Any],
mode: Mode,
Expand All @@ -284,6 +319,8 @@ def handle_reask_kwargs(
Mode.TOOLS_STRICT: reask_tools,
Mode.CEREBRAS_TOOLS: reask_cerebras_tools,
Mode.MD_JSON: reask_md_json,
Mode.FIREWORKS_TOOLS: reask_fireworks_tools,
Mode.FIREWORKS_JSON: reask_fireworks_json,
}
reask_function = functions.get(mode, reask_default)
return reask_function(kwargs=kwargs, response=response, exception=exception)
3 changes: 3 additions & 0 deletions instructor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class Provider(Enum):
GEMINI = "gemini"
DATABRICKS = "databricks"
CEREBRAS = "cerebras"
FIREWORKS = "fireworks"
UNKNOWN = "unknown"


Expand All @@ -61,6 +62,8 @@ def get_provider(base_url: str) -> Provider:
return Provider.ANTHROPIC
elif "cerebras" in str(base_url):
return Provider.CEREBRAS
elif "fireworks" in str(base_url):
return Provider.FIREWORKS
elif "groq" in str(base_url):
return Provider.GROQ
elif "openai" in str(base_url):
Expand Down
Loading

0 comments on commit 4b63d4e

Please sign in to comment.