[SambaNova] Integrate SambaNova Systems into oumi inference #1415

Merged · 9 commits · Feb 11, 2025
1 change: 1 addition & 0 deletions docs/user_guides/infer/configuration.md
@@ -143,6 +143,7 @@ The `engine` parameter specifies which inference engine to use. Available options:
- `PARASAIL`: Use Parasail API via {py:obj}`~oumi.inference.ParasailInferenceEngine`
- `REMOTE_VLLM`: Use external vLLM server via {py:obj}`~oumi.inference.RemoteVLLMInferenceEngine`
- `REMOTE`: Use any OpenAI-compatible API via {py:obj}`~oumi.inference.RemoteInferenceEngine`
- `SAMBANOVA`: Use SambaNova API via {py:obj}`~oumi.inference.SambanovaInferenceEngine`
- `SGLANG`: Use SGLang inference engine via {py:obj}`~oumi.inference.SGLangInferenceEngine`
- `TOGETHER`: Use Together API via {py:obj}`~oumi.inference.TogetherInferenceEngine`
- `VLLM`: Use vLLM for optimized local inference via {py:obj}`~oumi.inference.VLLMInferenceEngine`
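
The same selection can be made programmatically; a minimal sketch, assuming the `engine` field on `InferenceConfig` (the model name is illustrative):

```python
from oumi.core.configs import InferenceConfig, InferenceEngineType, ModelParams

# A sketch: select the SambaNova engine via the `engine` parameter.
# The equivalent YAML config would set `engine: SAMBANOVA`.
config = InferenceConfig(
    model=ModelParams(model_name="Meta-Llama-3.1-405B-Instruct"),
    engine=InferenceEngineType.SAMBANOVA,
)
```
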
5 changes: 3 additions & 2 deletions docs/user_guides/infer/infer.md
@@ -17,7 +17,7 @@ Oumi Infer provides a unified interface for running models, whether you're deploying…

Running models in production environments presents several challenges that Oumi helps address:

- **Universal Model Support**: Run models locally (vLLM, LlamaCPP, Transformers) or connect to hosted APIs (Anthropic, Gemini, OpenAI, Together, Parasail, Vertex AI, SambaNova) through a single, consistent interface
- **Production-Ready**: Support for batching, retries, error handling, structured outputs, and high-performance multi-threaded inference to hit a target throughput.
- **Scalable Architecture**: Deploy anywhere from a single GPU to distributed systems without code changes
- **Unified Configuration**: Control all aspects of model execution through a single config file
@@ -94,7 +94,7 @@ Our engines are broken into two categories: local inference vs remote inference.

Generally, the answer is simple: if you have sufficient resources to run the model locally without OOMing, then use a local engine like {py:obj}`~oumi.inference.VLLMInferenceEngine`, {py:obj}`~oumi.inference.NativeTextInferenceEngine`, or {py:obj}`~oumi.inference.LlamaCppInferenceEngine`.

If you don't have enough local compute resources, then the model must be hosted elsewhere. Our remote inference engines assume that your model is hosted behind a remote API. You can use {py:obj}`~oumi.inference.AnthropicInferenceEngine`, {py:obj}`~oumi.inference.GoogleGeminiInferenceEngine`, {py:obj}`~oumi.inference.GoogleVertexInferenceEngine`, or {py:obj}`~oumi.inference.SambanovaInferenceEngine` to call their respective APIs. You can also use {py:obj}`~oumi.inference.RemoteInferenceEngine` to call any API implementing the OpenAI Chat API format (including OpenAI's native API), or use {py:obj}`~oumi.inference.SGLangInferenceEngine` or {py:obj}`~oumi.inference.RemoteVLLMInferenceEngine` to call external SGLang or vLLM servers started remotely or locally outside of Oumi.
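
In practice, switching between hosted providers mostly means swapping the engine class; a minimal sketch using the SambaNova engine (the model name below is illustrative):

```python
from oumi.core.configs import ModelParams, RemoteParams
from oumi.inference import SambanovaInferenceEngine

# Construct a remote engine; the API key is read from the named environment variable.
engine = SambanovaInferenceEngine(
    model_params=ModelParams(model_name="Meta-Llama-3.1-405B-Instruct"),
    remote_params=RemoteParams(api_key_env_varname="SAMBANOVA_API_KEY"),
)
```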

For a comprehensive list of engines, see the [Supported Engines](#supported-engines) section below.

@@ -112,6 +112,7 @@ See {py:obj}`~oumi.inference.NativeTextInferenceEngine` for an example of a local inference engine.

See {py:obj}`~oumi.inference.AnthropicInferenceEngine` or {py:obj}`~oumi.inference.SambanovaInferenceEngine` for examples of inference engines that require a remote API.

```python
from oumi.inference import VLLMInferenceEngine
from oumi.core.configs import InferenceConfig, ModelParams
24 changes: 24 additions & 0 deletions docs/user_guides/infer/inference_engines.md
@@ -508,6 +508,30 @@ The DeepSeek models available via this API as of late Jan'2025 are listed below.
| DeepSeek-V3 | deepseek-chat |
| DeepSeek-R1 (reasoning with CoT) | deepseek-reasoner |

### SambaNova

[SambaNova](https://www.sambanova.ai/) offers a high-speed inference platform on cloud infrastructure with a wide variety of models.

This service is particularly useful when you need to run open-source models in a managed environment.

**Basic Usage**

```{testcode}
from oumi.inference import SambanovaInferenceEngine
from oumi.core.configs import ModelParams, RemoteParams

engine = SambanovaInferenceEngine(
    model_params=ModelParams(
        model_name="Meta-Llama-3.1-405B-Instruct",
    ),
    remote_params=RemoteParams(
        api_key_env_varname="SAMBANOVA_API_KEY",
    ),
)
```
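
Once constructed, the engine is used like any other Oumi inference engine; a minimal sketch of a single request (assumes `SAMBANOVA_API_KEY` is set and default `InferenceConfig` values are acceptable):

```python
from oumi.core.configs import InferenceConfig
from oumi.core.types.conversation import Conversation, Message, Role

conversation = Conversation(
    messages=[Message(role=Role.USER, content="Say hello in one sentence.")]
)

# `infer` returns the input conversations extended with the assistant's reply.
generations = engine.infer(input=[conversation], inference_config=InferenceConfig())
print(generations[0].messages[-1].content)
```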

**References**
- [SambaNova's Documentation](https://docs.sambanova.ai/cloud/docs/get-started/overview)

### Parasail.io

2 changes: 2 additions & 0 deletions src/oumi/builders/inference_engines.py
@@ -33,6 +33,7 @@
    ParasailInferenceEngine,
    RemoteInferenceEngine,
    RemoteVLLMInferenceEngine,
    SambanovaInferenceEngine,
    SGLangInferenceEngine,
    TogetherInferenceEngine,
    VLLMInferenceEngine,
@@ -51,6 +52,7 @@
    InferenceEngineType.PARASAIL: ParasailInferenceEngine,
    InferenceEngineType.REMOTE_VLLM: RemoteVLLMInferenceEngine,
    InferenceEngineType.REMOTE: RemoteInferenceEngine,
    InferenceEngineType.SAMBANOVA: SambanovaInferenceEngine,
    InferenceEngineType.SGLANG: SGLangInferenceEngine,
    InferenceEngineType.TOGETHER: TogetherInferenceEngine,
    InferenceEngineType.VLLM: VLLMInferenceEngine,
3 changes: 3 additions & 0 deletions src/oumi/core/configs/inference_engine_type.py
@@ -56,3 +56,6 @@ class InferenceEngineType(str, Enum):

OPENAI = "OPENAI"
"""The inference engine for OpenAI API."""

SAMBANOVA = "SAMBANOVA"
"""The inference engine for SambaNova API."""
3 changes: 3 additions & 0 deletions src/oumi/inference/__init__.py
@@ -27,6 +27,7 @@
from oumi.inference.parasail_inference_engine import ParasailInferenceEngine
from oumi.inference.remote_inference_engine import RemoteInferenceEngine
from oumi.inference.remote_vllm_inference_engine import RemoteVLLMInferenceEngine
from oumi.inference.sambanova_inference_engine import SambanovaInferenceEngine
from oumi.inference.sglang_inference_engine import SGLangInferenceEngine
from oumi.inference.together_inference_engine import TogetherInferenceEngine
from oumi.inference.vllm_inference_engine import VLLMInferenceEngine
@@ -42,7 +43,9 @@
"ParasailInferenceEngine",
"RemoteInferenceEngine",
"RemoteVLLMInferenceEngine",
"SambanovaInferenceEngine",
"SGLangInferenceEngine",
"TogetherInferenceEngine",
"VLLMInferenceEngine",
"SambanovaInferenceEngine",
]
145 changes: 145 additions & 0 deletions src/oumi/inference/sambanova_inference_engine.py
@@ -0,0 +1,145 @@
# Copyright 2025 - Oumi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

from typing_extensions import override

from oumi.core.configs import GenerationParams, RemoteParams
from oumi.core.types.conversation import Conversation, Message, Role
from oumi.inference.remote_inference_engine import RemoteInferenceEngine


class SambanovaInferenceEngine(RemoteInferenceEngine):
    """Engine for running inference against the SambaNova API.

    This class extends RemoteInferenceEngine to provide specific functionality
    for interacting with SambaNova's language models via their API. It handles
    the conversion of Oumi's Conversation objects to SambaNova's expected input
    format, as well as parsing the API responses back into Conversation objects.
    """

    @property
    @override
    def base_url(self) -> Optional[str]:
        """Return the default base URL for the SambaNova API."""
        return "https://api.sambanova.ai/v1/chat/completions"

    @property
    @override
    def api_key_env_varname(self) -> Optional[str]:
        """Return the default environment variable name for the SambaNova API key."""
        return "SAMBANOVA_API_KEY"

    @override
    def _convert_conversation_to_api_input(
        self, conversation: Conversation, generation_params: GenerationParams
    ) -> dict[str, Any]:
        """Converts a conversation to a SambaNova API input.

        This method transforms an Oumi Conversation object into a format
        suitable for the SambaNova API. It handles the conversion of messages
        and generation parameters according to the API specification.

        Args:
            conversation: The Oumi Conversation object to convert.
            generation_params: Parameters for text generation.

        Returns:
            Dict[str, Any]: A dictionary containing the formatted input for the
            SambaNova API, including the model, messages, and generation parameters.
        """
        # Build request body according to SambaNova API spec
        body = {
            "model": self._model,
            "messages": self._get_list_of_message_json_dicts(
                conversation.messages, group_adjacent_same_role_turns=False
            ),
            "max_tokens": generation_params.max_new_tokens,
            "temperature": generation_params.temperature,
            "top_p": generation_params.top_p,
            "stream": False,  # We don't support streaming yet
        }

        if generation_params.stop_strings:
            body["stop"] = generation_params.stop_strings

        return body

    @override
    def _convert_api_output_to_conversation(
        self, response: dict[str, Any], original_conversation: Conversation
    ) -> Conversation:
        """Converts a SambaNova API response to a conversation.

        Args:
            response: The API response to convert.
            original_conversation: The original conversation.

        Returns:
            Conversation: The conversation including the generated response.
        """
        choices = response.get("choices", [])
        if not choices:
            raise RuntimeError("No choices found in API response")
        if len(choices) != 1:
            raise RuntimeError(
                "Sambanova API only supports one choice per response. "
                f"Got: {len(choices)}"
            )

        message = choices[0].get("message", {})
        if not message:
            raise RuntimeError("No message found in API response")

        new_message = Message(
            content=message.get("content", ""),
            role=Role.ASSISTANT,
        )

        return Conversation(
            messages=[*original_conversation.messages, new_message],
            metadata=original_conversation.metadata,
            conversation_id=original_conversation.conversation_id,
        )

    @override
    def _get_request_headers(self, remote_params: RemoteParams) -> dict[str, str]:
        """Get headers for the API request.

        Args:
            remote_params: Remote server parameters.

        Returns:
            Dict[str, str]: Headers for the API request.
        """
        headers = {
            "Content-Type": "application/json",
        }

        api_key = self._get_api_key(remote_params)
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        return headers

    @override
    def get_supported_params(self) -> set[str]:
        """Returns a set of supported generation parameters for this engine."""
        return {
            "max_new_tokens",
            "stop_strings",
            "temperature",
            "top_p",
        }
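
For reference, the request body produced by `_convert_conversation_to_api_input` follows the standard chat-completions shape; a sketch with illustrative values:

```python
# Illustrative body POSTed to https://api.sambanova.ai/v1/chat/completions.
body = {
    "model": "Meta-Llama-3.1-405B-Instruct",
    "messages": [
        {"role": "system", "content": "Answer the question in a couple sentences."},
        {"role": "user", "content": "What is the strength of SambaNova Systems?"},
    ],
    "max_tokens": 128,
    "temperature": 0.0,
    "top_p": 1.0,
    "stream": False,
}
```
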
10 changes: 10 additions & 0 deletions tests/e2e/sambanova_infer_tutorial.yaml
@@ -0,0 +1,10 @@

model:
  model_name: "Meta-Llama-3.1-405B-Instruct"
  model_max_length: 512
  torch_dtype_str: "bfloat16"
  trust_remote_code: True

generation:
  max_new_tokens: 128
  batch_size: 1
89 changes: 89 additions & 0 deletions tests/e2e/test_sambanova_inference.py
@@ -0,0 +1,89 @@
import os
from pathlib import Path

import pytest

from oumi.core.configs import (
    GenerationParams,
    InferenceConfig,
    ModelParams,
    RemoteParams,
)
from oumi.core.types import Conversation, Message, Role
from oumi.inference import SambanovaInferenceEngine


def load_config(config_path):
    """Load the inference configuration from a YAML file."""
    return InferenceConfig.from_yaml(config_path)


def initialize_engine(api_key, model_name):
    """Initialize the SambaNova inference engine."""
    return SambanovaInferenceEngine(
        model_params=ModelParams(model_name=model_name),
        generation_params=GenerationParams(
            batch_size=1,
        ),
        remote_params=RemoteParams(
            api_key=api_key,
        ),
    )


def create_conversation():
    """Create a conversation for inference."""
    return [
        Conversation(
            messages=[
                Message(
                    role=Role.SYSTEM,
                    content="Answer the question in a couple sentences.",
                ),
                Message(
                    role=Role.USER,
                    content="What is the strength of SambaNova Systems?",
                ),
            ]
        ),
    ]


def perform_inference(engine, conversations, config):
    """Perform inference using the SambaNova engine."""
    try:
        generations = engine.infer(
            input=conversations,
            inference_config=config,
        )
        return generations
    except Exception as e:
        print("An error occurred during inference:", str(e))
        return None


@pytest.mark.e2e
def test_sambanova_inference():
    if "SAMBANOVA_API_KEY" not in os.environ:
        pytest.skip("SAMBANOVA_API_KEY is not set")

    # Set the path to the configuration file using pathlib
    config_path = Path(__file__).parent / "sambanova_infer_tutorial.yaml"
    print(config_path)

    # Load the configuration
    config = load_config(config_path)

    # Initialize the engine
    api_key = os.getenv("SAMBANOVA_API_KEY")
    model_name = "Meta-Llama-3.1-405B-Instruct"
    engine = initialize_engine(api_key, model_name)

    # Create the conversation
    conversations = create_conversation()

    # Perform inference
    generations = perform_inference(engine, conversations, config)

    # Print the results
    if generations:
        print(generations)
2 changes: 2 additions & 0 deletions tests/unit/inference/test_generation_params.py
@@ -20,6 +20,7 @@
    NativeTextInferenceEngine,
    RemoteInferenceEngine,
    RemoteVLLMInferenceEngine,
    SambanovaInferenceEngine,
    SGLangInferenceEngine,
    VLLMInferenceEngine,
)
@@ -32,6 +33,7 @@
    AnthropicInferenceEngine,
    LlamaCppInferenceEngine,
    NativeTextInferenceEngine,
    SambanovaInferenceEngine,
    SGLangInferenceEngine,
    VLLMInferenceEngine,
    RemoteVLLMInferenceEngine,
2 changes: 2 additions & 0 deletions tests/unit/inference/test_inference_engine_init.py
@@ -23,6 +23,7 @@
    ParasailInferenceEngine,
    RemoteInferenceEngine,
    RemoteVLLMInferenceEngine,
    SambanovaInferenceEngine,
    SGLangInferenceEngine,
    TogetherInferenceEngine,
    VLLMInferenceEngine,
@@ -63,6 +64,7 @@
    GoogleVertexInferenceEngine,
    OpenAIInferenceEngine,
    ParasailInferenceEngine,
    SambanovaInferenceEngine,
    TogetherInferenceEngine,
]
