-
Notifications
You must be signed in to change notification settings - Fork 865
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add sambanova provider * add documentation * update comments in sambanova provider for review * fix client constructor to pass in the entire config * fix linting error
- Loading branch information
1 parent
7af3f48
commit 7eecd6d
Showing
5 changed files
with
124 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,3 +25,6 @@ FIREWORKS_API_KEY= | |
|
||
# Together AI | ||
TOGETHER_API_KEY= | ||
|
||
# Sambanova | ||
SAMBANOVA_API_KEY= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import os | ||
from aisuite.provider import Provider | ||
from openai import OpenAI | ||
|
||
|
||
class SambanovaProvider(Provider): | ||
def __init__(self, **config): | ||
""" | ||
Initialize the SambaNova provider with the given configuration. | ||
Pass the entire configuration dictionary to the OpenAI client constructor. | ||
""" | ||
# Ensure API key is provided either in config or via environment variable | ||
config.setdefault("api_key", os.getenv("SAMBANOVA_API_KEY")) | ||
if not config["api_key"]: | ||
raise ValueError( | ||
"Sambanova API key is missing. Please provide it in the config or set the SAMBANOVA_API_KEY environment variable." | ||
) | ||
|
||
config["base_url"] = "https://api.sambanova.ai/v1/" | ||
# Pass the entire config to the OpenAI client constructor | ||
self.client = OpenAI(**config) | ||
|
||
def chat_completions_create(self, model, messages, **kwargs): | ||
# Any exception raised by Sambanova will be returned to the caller. | ||
# Maybe we should catch them and raise a custom LLMError. | ||
return self.client.chat.completions.create( | ||
model=model, | ||
messages=messages, | ||
**kwargs # Pass any additional arguments to the Sambanova API | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Sambanova | ||
|
||
To use Sambanova with `aisuite`, you’ll need a [Sambanova Cloud](https://cloud.sambanova.ai/) account. After logging in, go to the [API](https://cloud.sambanova.ai/apis) section and generate a new key. Once you have your key, add it to your environment as follows: | ||
|
||
```shell | ||
export SAMBANOVA_API_KEY="your-sambanova-api-key" | ||
``` | ||
|
||
## Create a Chat Completion | ||
|
||
Install the `openai` Python client: | ||
|
||
Example with pip: | ||
```shell | ||
pip install openai | ||
``` | ||
|
||
Example with poetry: | ||
```shell | ||
poetry add openai | ||
``` | ||
|
||
In your code: | ||
```python | ||
import aisuite as ai | ||
client = ai.Client() | ||
|
||
provider = "sambanova" | ||
model_id = "Meta-Llama-3.1-405B-Instruct" | ||
|
||
messages = [ | ||
{"role": "system", "content": "You are a helpful assistant."}, | ||
{"role": "user", "content": "What’s the weather like in San Francisco?"}, | ||
] | ||
|
||
response = client.chat.completions.create( | ||
model=f"{provider}:{model_id}", | ||
messages=messages, | ||
) | ||
|
||
print(response.choices[0].message.content) | ||
``` | ||
|
||
Happy coding! If you’d like to contribute, please read our [Contributing Guide](CONTRIBUTING.md). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from unittest.mock import MagicMock, patch | ||
|
||
import pytest | ||
|
||
from aisuite.providers.sambanova_provider import SambanovaProvider | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def set_api_key_env_var(monkeypatch): | ||
"""Fixture to set environment variables for tests.""" | ||
monkeypatch.setenv("SAMBANOVA_API_KEY", "test-api-key") | ||
|
||
|
||
def test_sambanova_provider(): | ||
"""High-level test that the provider is initialized and chat completions are requested successfully.""" | ||
|
||
user_greeting = "Hello!" | ||
message_history = [{"role": "user", "content": user_greeting}] | ||
selected_model = "our-favorite-model" | ||
chosen_temperature = 0.75 | ||
response_text_content = "mocked-text-response-from-model" | ||
|
||
provider = SambanovaProvider() | ||
mock_response = MagicMock() | ||
mock_response.choices = [MagicMock()] | ||
mock_response.choices[0].message = MagicMock() | ||
mock_response.choices[0].message.content = response_text_content | ||
|
||
with patch.object( | ||
provider.client.chat.completions, | ||
"create", | ||
return_value=mock_response, | ||
) as mock_create: | ||
response = provider.chat_completions_create( | ||
messages=message_history, | ||
model=selected_model, | ||
temperature=chosen_temperature, | ||
) | ||
|
||
mock_create.assert_called_with( | ||
messages=message_history, | ||
model=selected_model, | ||
temperature=chosen_temperature, | ||
) | ||
|
||
assert response.choices[0].message.content == response_text_content |