diff --git a/docs/user_guides/infer/configuration.md b/docs/user_guides/infer/configuration.md index 2a25a84c7..d3e37eaf7 100644 --- a/docs/user_guides/infer/configuration.md +++ b/docs/user_guides/infer/configuration.md @@ -143,6 +143,7 @@ The `engine` parameter specifies which inference engine to use. Available option - `PARASAIL`: Use Parasail API via {py:obj}`~oumi.inference.ParasailInferenceEngine` - `REMOTE_VLLM`: Use external vLLM server via {py:obj}`~oumi.inference.RemoteVLLMInferenceEngine` - `REMOTE`: Use any OpenAI-compatible API via {py:obj}`~oumi.inference.RemoteInferenceEngine` +- `SAMBANOVA`: Use SambaNova API via {py:obj}`~oumi.inference.SambanovaInferenceEngine` - `SGLANG`: Use SGLang inference engine via {py:obj}`~oumi.inference.SGLangInferenceEngine` - `TOGETHER`: Use Together API via {py:obj}`~oumi.inference.TogetherInferenceEngine` - `VLLM`: Use vLLM for optimized local inference via {py:obj}`~oumi.inference.VLLMInferenceEngine` diff --git a/docs/user_guides/infer/infer.md b/docs/user_guides/infer/infer.md index 4afc00591..7488a9a31 100644 --- a/docs/user_guides/infer/infer.md +++ b/docs/user_guides/infer/infer.md @@ -17,7 +17,7 @@ Oumi Infer provides a unified interface for running models, whether you're deplo Running models in production environments presents several challenges that Oumi helps address: -- **Universal Model Support**: Run models locally (vLLM, LlamaCPP, Transformers) or connect to hosted APIs (Anthropic, Gemini, OpenAI, Together, Parasail, Vertex AI) through a single, consistent interface +- **Universal Model Support**: Run models locally (vLLM, LlamaCPP, Transformers) or connect to hosted APIs (Anthropic, Gemini, OpenAI, Together, Parasail, Vertex AI, SambaNova) through a single, consistent interface - **Production-Ready**: Support for batching, retries, error-handling, structured outputs, and high-performance inference via multi-threading to hit a target throughput. - **Scalable Architecture**: Deploy anywhere from a single GPU to distributed systems without code changes - **Unified Configuration**: Control all aspects of model execution through a single config file @@ -94,7 +94,7 @@ Our engines are broken into two categories: local inference vs remote inference. Generally, the answer is simple: if you have sufficient resources to run the model locally without OOMing, then use a local engine like {py:obj}`~oumi.inference.VLLMInferenceEngine`, {py:obj}`~oumi.inference.NativeTextInferenceEngine`, or {py:obj}`~oumi.inference.LlamaCppInferenceEngine`. -If you don't have enough local compute resources, then the model must be hosted elsewhere. Our remote inference engines assume that your model is hosted behind a remote API. You can use {py:obj}`~oumi.inference.AnthropicInferenceEngine`, {py:obj}`~oumi.inference.GoogleGeminiInferenceEngine`, or {py:obj}`~oumi.inference.GoogleVertexInferenceEngine` to call their respective APIs. You can also use {py:obj}`~oumi.inference.RemoteInferenceEngine` to call any API implementing the OpenAI Chat API format (including OpenAI's native API), or use {py:obj}`~oumi.inference.SGLangInferenceEngine` or {py:obj}`~oumi.inference.RemoteVLLMInferenceEngine` to call external SGLang or vLLM servers started remotely or locally outside of Oumi. +If you don't have enough local compute resources, then the model must be hosted elsewhere. Our remote inference engines assume that your model is hosted behind a remote API. 
You can use {py:obj}`~oumi.inference.AnthropicInferenceEngine`, {py:obj}`~oumi.inference.GoogleGeminiInferenceEngine`, {py:obj}`~oumi.inference.GoogleVertexInferenceEngine`, or {py:obj}`~oumi.inference.SambanovaInferenceEngine` to call their respective APIs. You can also use {py:obj}`~oumi.inference.RemoteInferenceEngine` to call any API implementing the OpenAI Chat API format (including OpenAI's native API), or use {py:obj}`~oumi.inference.SGLangInferenceEngine` or {py:obj}`~oumi.inference.RemoteVLLMInferenceEngine` to call external SGLang or vLLM servers started remotely or locally outside of Oumi.
 
 For a comprehensive list of engines, see the [Supported Engines](#supported-engines) section below.
 
@@ -112,6 +112,8 @@ See {py:obj}`~oumi.inference.NativeTextInferenceEngine` for an example of a loca
 
 See {py:obj}`~oumi.inference.AnthropicInferenceEngine` for an example of an inference engine that requires a remote API.
 
+See {py:obj}`~oumi.inference.SambanovaInferenceEngine` for another example of an inference engine that requires a remote API.
+
 ```python
 from oumi.inference import VLLMInferenceEngine
 from oumi.core.configs import InferenceConfig, ModelParams
diff --git a/docs/user_guides/infer/inference_engines.md b/docs/user_guides/infer/inference_engines.md
index 4a9d104f1..87052572f 100644
--- a/docs/user_guides/infer/inference_engines.md
+++ b/docs/user_guides/infer/inference_engines.md
@@ -508,6 +508,42 @@ The DeepSeek models available via this API as of late Jan'2025 are listed below.
 | DeepSeek-V3 | deepseek-chat |
 | DeepSeek-R1 (reasoning with CoT) | deepseek-reasoner |
 
+### SambaNova
+
+[SambaNova](https://www.sambanova.ai/) offers a high-speed inference platform on cloud infrastructure with a wide variety of models.
+
+This service is particularly useful when you need to run open-source models in a managed environment.
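+
+You can also select this engine from an Oumi inference config. The following is a minimal sketch (the model name is illustrative; by default the API key is read from the `SAMBANOVA_API_KEY` environment variable):
+
+```yaml
+model:
+  model_name: "Meta-Llama-3.1-405B-Instruct"
+
+engine: SAMBANOVA
+
+generation:
+  max_new_tokens: 128
+```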
+
+**Basic Usage**
+
+```{testcode}
+from oumi.inference import SambanovaInferenceEngine
+from oumi.core.configs import ModelParams, RemoteParams
+
+engine = SambanovaInferenceEngine(
+    model_params=ModelParams(
+        model_name="Meta-Llama-3.1-405B-Instruct"
+    ),
+    remote_params=RemoteParams(
+        api_key_env_varname="SAMBANOVA_API_KEY",
+    )
+)
+```
+
+**References**
+- [SambaNova's Documentation](https://docs.sambanova.ai/cloud/docs/get-started/overview)
 
 ### Parasail.io
 
diff --git a/src/oumi/builders/inference_engines.py b/src/oumi/builders/inference_engines.py
index 155d35a17..1bc0b74f9 100644
--- a/src/oumi/builders/inference_engines.py
+++ b/src/oumi/builders/inference_engines.py
@@ -33,6 +33,7 @@
     ParasailInferenceEngine,
     RemoteInferenceEngine,
     RemoteVLLMInferenceEngine,
+    SambanovaInferenceEngine,
     SGLangInferenceEngine,
     TogetherInferenceEngine,
     VLLMInferenceEngine,
@@ -51,6 +52,7 @@
     InferenceEngineType.PARASAIL: ParasailInferenceEngine,
     InferenceEngineType.REMOTE_VLLM: RemoteVLLMInferenceEngine,
     InferenceEngineType.REMOTE: RemoteInferenceEngine,
+    InferenceEngineType.SAMBANOVA: SambanovaInferenceEngine,
     InferenceEngineType.SGLANG: SGLangInferenceEngine,
     InferenceEngineType.TOGETHER: TogetherInferenceEngine,
     InferenceEngineType.VLLM: VLLMInferenceEngine,
diff --git a/src/oumi/core/configs/inference_engine_type.py b/src/oumi/core/configs/inference_engine_type.py
index 3e3211960..7fa678b49 100644
--- a/src/oumi/core/configs/inference_engine_type.py
+++ b/src/oumi/core/configs/inference_engine_type.py
@@ -56,3 +56,6 @@ class InferenceEngineType(str, Enum):
 
     OPENAI = "OPENAI"
     """The inference engine for OpenAI API."""
+
+    SAMBANOVA = "SAMBANOVA"
+    """The inference engine for SambaNova API."""
diff --git a/src/oumi/inference/__init__.py b/src/oumi/inference/__init__.py
index 97c04e772..3963e0f76 100644
--- a/src/oumi/inference/__init__.py
+++ b/src/oumi/inference/__init__.py
@@ -27,6 +27,7 @@
 from oumi.inference.parasail_inference_engine import ParasailInferenceEngine
 from oumi.inference.remote_inference_engine import RemoteInferenceEngine
 from oumi.inference.remote_vllm_inference_engine import RemoteVLLMInferenceEngine
+from oumi.inference.sambanova_inference_engine import SambanovaInferenceEngine
 from oumi.inference.sglang_inference_engine import SGLangInferenceEngine
 from oumi.inference.together_inference_engine import TogetherInferenceEngine
 from oumi.inference.vllm_inference_engine import VLLMInferenceEngine
@@ -42,7 +43,8 @@
     "ParasailInferenceEngine",
     "RemoteInferenceEngine",
     "RemoteVLLMInferenceEngine",
+    "SambanovaInferenceEngine",
     "SGLangInferenceEngine",
     "TogetherInferenceEngine",
     "VLLMInferenceEngine",
 ]
diff --git a/src/oumi/inference/sambanova_inference_engine.py b/src/oumi/inference/sambanova_inference_engine.py
new file mode 100644
index 000000000..145aacfa1
--- /dev/null
+++ b/src/oumi/inference/sambanova_inference_engine.py
@@ -0,0 +1,145 @@
+# Copyright 2025 - Oumi
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from typing import Any, Optional + +from typing_extensions import override + +from oumi.core.configs import GenerationParams, RemoteParams +from oumi.core.types.conversation import Conversation, Message, Role +from oumi.inference.remote_inference_engine import RemoteInferenceEngine + + +class SambanovaInferenceEngine(RemoteInferenceEngine): + """Engine for running inference against the SambaNova API. + + This class extends RemoteInferenceEngine to provide specific functionality + for interacting with SambaNova's language models via their API. It handles + the conversion of Oumi's Conversation objects to SambaNova's expected input + format, as well as parsing the API responses back into Conversation objects. + """ + + @property + @override + def base_url(self) -> Optional[str]: + """Return the default base URL for the SambaNova API.""" + return "https://api.sambanova.ai/v1/chat/completions" + + @property + @override + def api_key_env_varname(self) -> Optional[str]: + """Return the default environment variable name for the SambaNova API key.""" + return "SAMBANOVA_API_KEY" + + @override + def _convert_conversation_to_api_input( + self, conversation: Conversation, generation_params: GenerationParams + ) -> dict[str, Any]: + """Converts a conversation to a SambaNova API input. + + This method transforms an Oumi Conversation object into a format + suitable for the SambaNova API. It handles the conversion of messages + and generation parameters according to the API specification. + + Args: + conversation: The Oumi Conversation object to convert. + generation_params: Parameters for text generation. + + Returns: + Dict[str, Any]: A dictionary containing the formatted input for the + SambaNova API, including the model, messages, and generation parameters. + """ + # Build request body according to SambaNova API spec + body = { + "model": self._model, + "messages": self._get_list_of_message_json_dicts( + conversation.messages, group_adjacent_same_role_turns=False + ), + "max_tokens": generation_params.max_new_tokens, + "temperature": generation_params.temperature, + "top_p": generation_params.top_p, + "stream": False, # We don't support streaming yet + } + + if generation_params.stop_strings: + body["stop"] = generation_params.stop_strings + + return body + + @override + def _convert_api_output_to_conversation( + self, response: dict[str, Any], original_conversation: Conversation + ) -> Conversation: + """Converts a SambaNova API response to a conversation. + + Args: + response: The API response to convert. + original_conversation: The original conversation. + + Returns: + Conversation: The conversation including the generated response. + """ + choices = response.get("choices", []) + if not choices: + raise RuntimeError("No choices found in API response") + if len(choices) != 1: + raise RuntimeError( + "Sambanova API only supports one choice per response. " + f"Got: {len(choices)}" + ) + + message = choices[0].get("message", {}) + if not message: + raise RuntimeError("No message found in API response") + + new_message = Message( + content=message.get("content", ""), + role=Role.ASSISTANT, + ) + + return Conversation( + messages=[*original_conversation.messages, new_message], + metadata=original_conversation.metadata, + conversation_id=original_conversation.conversation_id, + ) + + @override + def _get_request_headers(self, remote_params: RemoteParams) -> dict[str, str]: + """Get headers for the API request. + + Args: + remote_params: Remote server parameters. 
+
+        Returns:
+            Dict[str, str]: Headers for the API request.
+        """
+        headers = {
+            "Content-Type": "application/json",
+        }
+
+        api_key = self._get_api_key(remote_params)
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+
+        return headers
+
+    @override
+    def get_supported_params(self) -> set[str]:
+        """Returns a set of supported generation parameters for this engine."""
+        return {
+            "max_new_tokens",
+            "stop_strings",
+            "temperature",
+            "top_p",
+        }
diff --git a/tests/e2e/sambanova_infer_tutorial.yaml b/tests/e2e/sambanova_infer_tutorial.yaml
new file mode 100644
index 000000000..4b3add400
--- /dev/null
+++ b/tests/e2e/sambanova_infer_tutorial.yaml
@@ -0,0 +1,10 @@
+
+model:
+  model_name: "Meta-Llama-3.1-405B-Instruct"
+  model_max_length: 512
+  torch_dtype_str: "bfloat16"
+  trust_remote_code: True
+
+generation:
+  max_new_tokens: 128
+  batch_size: 1
diff --git a/tests/e2e/test_sambanova_inference.py b/tests/e2e/test_sambanova_inference.py
new file mode 100644
index 000000000..cc60fa3e6
--- /dev/null
+++ b/tests/e2e/test_sambanova_inference.py
@@ -0,0 +1,89 @@
+import os
+from pathlib import Path
+
+import pytest
+
+from oumi.core.configs import (
+    GenerationParams,
+    InferenceConfig,
+    ModelParams,
+    RemoteParams,
+)
+from oumi.core.types import Conversation, Message, Role
+from oumi.inference import SambanovaInferenceEngine
+
+
+def load_config(config_path):
+    """Load the inference configuration from a YAML file."""
+    return InferenceConfig.from_yaml(config_path)
+
+
+def initialize_engine(api_key, model_name):
+    """Initialize the SambaNova inference engine."""
+    return SambanovaInferenceEngine(
+        model_params=ModelParams(model_name=model_name),
+        generation_params=GenerationParams(
+            batch_size=1,
+        ),
+        remote_params=RemoteParams(
+            api_key=api_key,
+        ),
+    )
+
+
+def create_conversation():
+    """Create a conversation for inference."""
+    return [
+        Conversation(
+            messages=[
+                Message(
+                    role=Role.SYSTEM,
+                    content="Answer the question in a couple sentences.",
+                ),
+                Message(
+                    role=Role.USER, content="What is the strength of SambaNova Systems?"
+ ), + ] + ), + ] + + +def perform_inference(engine, conversations, config): + """Perform inference using the SambaNova engine.""" + try: + generations = engine.infer( + input=conversations, + inference_config=config, + ) + return generations + except Exception as e: + print("An error occurred during inference:", str(e)) + return None + + +@pytest.mark.e2e +def test_sambanova_inference(): + if "SAMBANOVA_API_KEY" not in os.environ: + pytest.skip("SAMBANOVA_API_KEY is not set") + + # Set the path to the configuration file using pathlib + config_path = Path(__file__).parent / "sambanova_infer_tutorial.yaml" + print(config_path) + + # Load the configuration + config = load_config(config_path) + + # Initialize the engine + api_key = os.getenv("SAMBANOVA_API_KEY") + model_name = "Meta-Llama-3.1-405B-Instruct" + engine = initialize_engine(api_key, model_name) + + # Create the conversation + conversations = create_conversation() + + # Perform inference + generations = perform_inference(engine, conversations, config) + + # Print the results + if generations: + print(generations) diff --git a/tests/unit/inference/test_generation_params.py b/tests/unit/inference/test_generation_params.py index b4b95c8c0..e54d819c7 100644 --- a/tests/unit/inference/test_generation_params.py +++ b/tests/unit/inference/test_generation_params.py @@ -20,6 +20,7 @@ NativeTextInferenceEngine, RemoteInferenceEngine, RemoteVLLMInferenceEngine, + SambanovaInferenceEngine, SGLangInferenceEngine, VLLMInferenceEngine, ) @@ -32,6 +33,7 @@ AnthropicInferenceEngine, LlamaCppInferenceEngine, NativeTextInferenceEngine, + SambanovaInferenceEngine, SGLangInferenceEngine, VLLMInferenceEngine, RemoteVLLMInferenceEngine, diff --git a/tests/unit/inference/test_inference_engine_init.py b/tests/unit/inference/test_inference_engine_init.py index 8d0d2a810..cd7861c3d 100644 --- a/tests/unit/inference/test_inference_engine_init.py +++ b/tests/unit/inference/test_inference_engine_init.py @@ -23,6 +23,7 @@ ParasailInferenceEngine, RemoteInferenceEngine, RemoteVLLMInferenceEngine, + SambanovaInferenceEngine, SGLangInferenceEngine, TogetherInferenceEngine, VLLMInferenceEngine, @@ -63,6 +64,7 @@ GoogleVertexInferenceEngine, OpenAIInferenceEngine, ParasailInferenceEngine, + SambanovaInferenceEngine, TogetherInferenceEngine, ] diff --git a/tests/unit/inference/test_sambanova_inference_engine.py b/tests/unit/inference/test_sambanova_inference_engine.py new file mode 100644 index 000000000..80b3d87d2 --- /dev/null +++ b/tests/unit/inference/test_sambanova_inference_engine.py @@ -0,0 +1,122 @@ +from unittest.mock import patch + +import pytest + +from oumi.core.configs import GenerationParams, ModelParams, RemoteParams +from oumi.core.types.conversation import Conversation, Message, Role +from oumi.inference.sambanova_inference_engine import SambanovaInferenceEngine + + +@pytest.fixture +def sambanova_engine(): + return SambanovaInferenceEngine( + model_params=ModelParams(model_name="Meta-Llama-3.1-8B-Instruct"), + remote_params=RemoteParams(api_key="test_api_key", api_url=""), + ) + + +def test_convert_conversation_to_api_input(sambanova_engine): + """Test conversion of conversation to SambaNova API input format.""" + conversation = Conversation( + messages=[ + Message(content="System message", role=Role.SYSTEM), + Message(content="User message", role=Role.USER), + Message(content="Assistant message", role=Role.ASSISTANT), + ] + ) + generation_params = GenerationParams( + max_new_tokens=100, + temperature=0.7, + top_p=0.9, + stop_strings=["stop"], + ) + 
+ result = sambanova_engine._convert_conversation_to_api_input( + conversation, generation_params + ) + + # Verify the API input format + assert result["model"] == "Meta-Llama-3.1-8B-Instruct" + assert len(result["messages"]) == 3 + assert result["messages"][0]["content"] == "System message" + assert result["messages"][0]["role"] == "system" + assert result["messages"][1]["content"] == "User message" + assert result["messages"][1]["role"] == "user" + assert result["messages"][2]["content"] == "Assistant message" + assert result["messages"][2]["role"] == "assistant" + assert result["max_tokens"] == 100 + assert result["temperature"] == 0.7 + assert result["top_p"] == 0.9 + assert result["stop"] == ["stop"] + assert result["stream"] is False + + +def test_convert_api_output_to_conversation(sambanova_engine): + """Test conversion of SambaNova API output to conversation.""" + original_conversation = Conversation( + messages=[ + Message(content="User message", role=Role.USER), + ], + metadata={"key": "value"}, + conversation_id="test_id", + ) + api_response = { + "choices": [{"message": {"content": "Assistant response", "role": "assistant"}}] + } + + result = sambanova_engine._convert_api_output_to_conversation( + api_response, original_conversation + ) + + assert len(result.messages) == 2 + assert result.messages[0].content == "User message" + assert result.messages[1].content == "Assistant response" + assert result.messages[1].role == Role.ASSISTANT + assert result.metadata == {"key": "value"} + assert result.conversation_id == "test_id" + + +def test_convert_api_output_to_conversation_error_handling(sambanova_engine): + """Test error handling in API output conversion.""" + original_conversation = Conversation( + messages=[Message(content="User message", role=Role.USER)] + ) + + # Test empty choices + with pytest.raises(RuntimeError, match="No choices found in API response"): + sambanova_engine._convert_api_output_to_conversation( + {"choices": []}, original_conversation + ) + + # Test missing message + with pytest.raises(RuntimeError, match="No message found in API response"): + sambanova_engine._convert_api_output_to_conversation( + {"choices": [{}]}, original_conversation + ) + + +def test_get_request_headers(sambanova_engine): + """Test generation of request headers.""" + remote_params = RemoteParams(api_key="test_api_key", api_url="") + + with patch.object( + SambanovaInferenceEngine, + "_get_api_key", + return_value="test_api_key", + ): + result = sambanova_engine._get_request_headers(remote_params) + + assert result["Content-Type"] == "application/json" + assert result["Authorization"] == "Bearer test_api_key" + + +def test_get_supported_params(sambanova_engine): + """Test supported generation parameters.""" + supported_params = sambanova_engine.get_supported_params() + + assert supported_params == { + "max_new_tokens", + "stop_strings", + "temperature", + "top_p", + }