Commit a9801f8

Showing 6 changed files with 38 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -350,6 +350,7 @@ The following environment variables can be set.
| MAGENTIC_LITELLM_MAX_TOKENS | LiteLLM max number of generated tokens | 1024 |
| MAGENTIC_LITELLM_TEMPERATURE | LiteLLM temperature | 0.5 |
| MAGENTIC_OPENAI_MODEL | OpenAI model | gpt-4 |
| MAGENTIC_OPENAI_API_KEY | OpenAI API key to be used by magentic | sk-... |
| MAGENTIC_OPENAI_API_TYPE | Allowed options: "openai", "azure" | azure |
| MAGENTIC_OPENAI_BASE_URL | Base URL for an OpenAI-compatible API | http://localhost:8080 |
| MAGENTIC_OPENAI_MAX_TOKENS | OpenAI max number of generated tokens | 1024 |
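The new MAGENTIC_OPENAI_API_KEY row lets the key be supplied through magentic's own settings rather than only through the OpenAI client's OPENAI_API_KEY variable. A minimal sketch of how it might be used, assuming the @prompt decorator documented elsewhere in the README and a placeholder key value:

```python
import os

# Placeholder values for illustration only; substitute a real OpenAI key.
os.environ["MAGENTIC_BACKEND"] = "openai"
os.environ["MAGENTIC_OPENAI_API_KEY"] = "sk-..."

from magentic import prompt


@prompt("Say hello to {name}!")
def greet(name: str) -> str: ...


print(greet("world"))
```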
1 change: 1 addition & 0 deletions src/magentic/backend.py
@@ -23,6 +23,7 @@ def get_chat_model() -> ChatModel:

return OpenaiChatModel(
model=settings.openai_model,
api_key=settings.openai_api_key,
api_type=settings.openai_api_type,
base_url=settings.openai_base_url,
max_tokens=settings.openai_max_tokens,
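With this change, get_chat_model() forwards the configured key into OpenaiChatModel. A rough sketch of the flow, assuming the MAGENTIC_* environment variables from the README table (the key value is a placeholder):

```python
import os

# Placeholder values; in practice these come from the deployment environment.
os.environ["MAGENTIC_BACKEND"] = "openai"
os.environ["MAGENTIC_OPENAI_API_KEY"] = "sk-..."

from magentic.backend import get_chat_model

chat_model = get_chat_model()
print(chat_model.api_key)  # -> "sk-..."
```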
18 changes: 16 additions & 2 deletions src/magentic/chat_model/openai_chat_model.py
@@ -85,6 +85,7 @@ def message_to_openai_message(message: Message[Any]) -> ChatCompletionMessagePar


def openai_chatcompletion_create(
api_key: str | None,
api_type: Literal["openai", "azure"],
base_url: str | None,
model: str,
@@ -95,7 +96,9 @@ def openai_chatcompletion_create(
functions: list[dict[str, Any]] | None = None,
function_call: Literal["auto", "none"] | dict[str, Any] | None = None,
) -> Iterator[ChatCompletionChunk]:
client_kwargs: dict[str, Any] = {}
client_kwargs: dict[str, Any] = {
"api_key": api_key,
}
if api_type == "openai" and base_url:
client_kwargs["base_url"] = base_url

@@ -126,6 +129,7 @@ def openai_chatcompletion_create(


async def openai_chatcompletion_acreate(
api_key: str | None,
api_type: Literal["openai", "azure"],
base_url: str | None,
model: str,
@@ -136,7 +140,9 @@ async def openai_chatcompletion_acreate(
functions: list[dict[str, Any]] | None = None,
function_call: Literal["auto", "none"] | dict[str, Any] | None = None,
) -> AsyncIterator[ChatCompletionChunk]:
client_kwargs: dict[str, Any] = {}
client_kwargs: dict[str, Any] = {
"api_key": api_key,
}
if api_type == "openai" and base_url:
client_kwargs["base_url"] = base_url

@@ -176,13 +182,15 @@ def __init__(
self,
model: str,
*,
api_key: str | None = None,
api_type: Literal["openai", "azure"] = "openai",
base_url: str | None = None,
max_tokens: int | None = None,
seed: int | None = None,
temperature: float | None = None,
):
self._model = model
self._api_key = api_key
self._api_type = api_type
self._base_url = base_url
self._max_tokens = max_tokens
@@ -193,6 +201,10 @@ def __init__(
def model(self) -> str:
return self._model

@property
def api_key(self) -> str | None:
return self._api_key

@property
def api_type(self) -> Literal["openai", "azure"]:
return self._api_type
@@ -277,6 +289,7 @@ def complete(

openai_functions = [schema.dict() for schema in function_schemas]
response = openai_chatcompletion_create(
api_key=self.api_key,
api_type=self.api_type,
base_url=self.base_url,
model=self.model,
@@ -405,6 +418,7 @@ async def acomplete(

openai_functions = [schema.dict() for schema in function_schemas]
response = await openai_chatcompletion_acreate(
api_key=self.api_key,
api_type=self.api_type,
base_url=self.base_url,
model=self.model,
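OpenaiChatModel now accepts api_key in its constructor, exposes it as a read-only property, and passes it through to both the sync and async chat-completion calls. When it is left as None, the underlying OpenAI client falls back to the OPENAI_API_KEY environment variable, which is what the new test below relies on. A short usage sketch with a placeholder key:

```python
from magentic.chat_model.message import UserMessage
from magentic.chat_model.openai_chat_model import OpenaiChatModel

# An explicitly passed key is used even if OPENAI_API_KEY is not set in the environment.
chat_model = OpenaiChatModel("gpt-3.5-turbo", api_key="sk-...")  # placeholder key
message = chat_model.complete(messages=[UserMessage("Say hello!")])
print(message.content)
```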
1 change: 1 addition & 0 deletions src/magentic/settings.py
@@ -18,6 +18,7 @@ class Settings(BaseSettings):
litellm_max_tokens: int | None = None
litellm_temperature: float | None = None
openai_model: str = "gpt-3.5-turbo"
openai_api_key: str | None = None
openai_api_type: Literal["openai", "azure"] = "openai"
openai_base_url: str | None = None
openai_max_tokens: int | None = None
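The new openai_api_key field defaults to None, so nothing changes for users who keep relying on OPENAI_API_KEY. A small sketch of how it is picked up, assuming the existing Settings class maps fields to MAGENTIC_*-prefixed environment variables as the README table indicates (that wiring is not shown in this diff):

```python
import os

from magentic.settings import Settings

os.environ["MAGENTIC_OPENAI_API_KEY"] = "sk-test"  # placeholder value

settings = Settings()
assert settings.openai_api_key == "sk-test"
```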
17 changes: 17 additions & 0 deletions tests/chat_model/test_openai_chat_model.py
@@ -1,9 +1,26 @@
import os

import openai
import pytest

from magentic.chat_model.message import UserMessage
from magentic.chat_model.openai_chat_model import OpenaiChatModel


@pytest.mark.openai
def test_openai_chat_model_api_key(monkeypatch):
openai_api_key = os.environ["OPENAI_API_KEY"]
monkeypatch.delenv("OPENAI_API_KEY")

chat_model = OpenaiChatModel("gpt-3.5-turbo")
with pytest.raises(openai.OpenAIError):
chat_model.complete(messages=[UserMessage("Say hello!")])

chat_model = OpenaiChatModel("gpt-3.5-turbo", api_key=openai_api_key)
message = chat_model.complete(messages=[UserMessage("Say hello!")])
assert isinstance(message.content, str)


@pytest.mark.openai
def test_openai_chat_model_complete_base_url():
chat_model = OpenaiChatModel("gpt-3.5-turbo", base_url="https://api.openai.com/v1")
2 changes: 2 additions & 0 deletions tests/test_backend.py
@@ -11,6 +11,7 @@
def test_backend_openai_chat_model(monkeypatch):
monkeypatch.setenv("MAGENTIC_BACKEND", "openai")
monkeypatch.setenv("MAGENTIC_OPENAI_MODEL", "gpt-4")
monkeypatch.setenv("MAGENTIC_OPENAI_API_KEY", "sk-1234567890")
monkeypatch.setenv("MAGENTIC_OPENAI_API_TYPE", "azure")
monkeypatch.setenv("MAGENTIC_OPENAI_BASE_URL", "http://localhost:8080")
monkeypatch.setenv("MAGENTIC_OPENAI_MAX_TOKENS", "1024")
@@ -19,6 +20,7 @@ def test_backend_openai_chat_model(monkeypatch):
chat_model = get_chat_model()
assert isinstance(chat_model, OpenaiChatModel)
assert chat_model.model == "gpt-4"
assert chat_model.api_key == "sk-1234567890"
assert chat_model.api_type == "azure"
assert chat_model.base_url == "http://localhost:8080"
assert chat_model.max_tokens == 1024
