Skip to content

Commit

Permalink
Add SambaNova Provider (#54)
Browse files Browse the repository at this point in the history
* add sambanova provider

* add documentation

* update comments in sambanova provider for review

* fix client constructor to pass in the entire config

* fix linting error
  • Loading branch information
snova-zoltanc authored Dec 5, 2024
1 parent 7af3f48 commit 7eecd6d
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ FIREWORKS_API_KEY=

# Together AI
TOGETHER_API_KEY=

# Sambanova
SAMBANOVA_API_KEY=
30 changes: 30 additions & 0 deletions aisuite/providers/sambanova_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os
from aisuite.provider import Provider
from openai import OpenAI


class SambanovaProvider(Provider):
def __init__(self, **config):
"""
Initialize the SambaNova provider with the given configuration.
Pass the entire configuration dictionary to the OpenAI client constructor.
"""
# Ensure API key is provided either in config or via environment variable
config.setdefault("api_key", os.getenv("SAMBANOVA_API_KEY"))
if not config["api_key"]:
raise ValueError(
"Sambanova API key is missing. Please provide it in the config or set the SAMBANOVA_API_KEY environment variable."
)

config["base_url"] = "https://api.sambanova.ai/v1/"
# Pass the entire config to the OpenAI client constructor
self.client = OpenAI(**config)

def chat_completions_create(self, model, messages, **kwargs):
# Any exception raised by Sambanova will be returned to the caller.
# Maybe we should catch them and raise a custom LLMError.
return self.client.chat.completions.create(
model=model,
messages=messages,
**kwargs # Pass any additional arguments to the Sambanova API
)
1 change: 1 addition & 0 deletions guides/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Here're the instructions for:
- [Google](google.md)
- [Hugging Face](huggingface.md)
- [OpenAI](openai.md)
- [SambaNova](sambanova.md)

Unless otherwise stated, these guides have not been endorsed by the providers.

Expand Down
44 changes: 44 additions & 0 deletions guides/sambanova.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Sambanova

To use Sambanova with `aisuite`, you’ll need a [Sambanova Cloud](https://cloud.sambanova.ai/) account. After logging in, go to the [API](https://cloud.sambanova.ai/apis) section and generate a new key. Once you have your key, add it to your environment as follows:

```shell
export SAMBANOVA_API_KEY="your-sambanova-api-key"
```

## Create a Chat Completion

Install the `openai` Python client:

Example with pip:
```shell
pip install openai
```

Example with poetry:
```shell
poetry add openai
```

In your code:
```python
import aisuite as ai
client = ai.Client()

provider = "sambanova"
model_id = "Meta-Llama-3.1-405B-Instruct"

messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What’s the weather like in San Francisco?"},
]

response = client.chat.completions.create(
model=f"{provider}:{model_id}",
messages=messages,
)

print(response.choices[0].message.content)
```

Happy coding! If you’d like to contribute, please read our [Contributing Guide](CONTRIBUTING.md).
46 changes: 46 additions & 0 deletions tests/providers/test_sambanova_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from unittest.mock import MagicMock, patch

import pytest

from aisuite.providers.sambanova_provider import SambanovaProvider


@pytest.fixture(autouse=True)
def set_api_key_env_var(monkeypatch):
"""Fixture to set environment variables for tests."""
monkeypatch.setenv("SAMBANOVA_API_KEY", "test-api-key")


def test_sambanova_provider():
"""High-level test that the provider is initialized and chat completions are requested successfully."""

user_greeting = "Hello!"
message_history = [{"role": "user", "content": user_greeting}]
selected_model = "our-favorite-model"
chosen_temperature = 0.75
response_text_content = "mocked-text-response-from-model"

provider = SambanovaProvider()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message = MagicMock()
mock_response.choices[0].message.content = response_text_content

with patch.object(
provider.client.chat.completions,
"create",
return_value=mock_response,
) as mock_create:
response = provider.chat_completions_create(
messages=message_history,
model=selected_model,
temperature=chosen_temperature,
)

mock_create.assert_called_with(
messages=message_history,
model=selected_model,
temperature=chosen_temperature,
)

assert response.choices[0].message.content == response_text_content

0 comments on commit 7eecd6d

Please sign in to comment.