From 33c1fa90e46a20c70aacc355969daf527954dfc2 Mon Sep 17 00:00:00 2001 From: Chiung-Yi Date: Sat, 8 Feb 2025 14:10:03 -0800 Subject: [PATCH 1/8] [SambaNova] Integrate SambaNova Systems to oumi inference --- src/oumi/builders/inference_engines.py | 2 + src/oumi/inference/__init__.py | 4 + .../inference/sambanova_inference_engine.py | 141 ++++++++++++++++++ .../unit/inference/test_generation_params.py | 2 + .../inference/test_inference_engine_init.py | 2 + .../test_sambanova_inference_engine.py | 129 ++++++++++++++++ 6 files changed, 280 insertions(+) create mode 100644 src/oumi/inference/sambanova_inference_engine.py create mode 100644 tests/unit/inference/test_sambanova_inference_engine.py diff --git a/src/oumi/builders/inference_engines.py b/src/oumi/builders/inference_engines.py index 155d35a17..1bc0b74f9 100644 --- a/src/oumi/builders/inference_engines.py +++ b/src/oumi/builders/inference_engines.py @@ -33,6 +33,7 @@ ParasailInferenceEngine, RemoteInferenceEngine, RemoteVLLMInferenceEngine, + SambanovaInferenceEngine, SGLangInferenceEngine, TogetherInferenceEngine, VLLMInferenceEngine, @@ -51,6 +52,7 @@ InferenceEngineType.PARASAIL: ParasailInferenceEngine, InferenceEngineType.REMOTE_VLLM: RemoteVLLMInferenceEngine, InferenceEngineType.REMOTE: RemoteInferenceEngine, + InferenceEngineType.SAMBANOVA: SambanovaInferenceEngine, InferenceEngineType.SGLANG: SGLangInferenceEngine, InferenceEngineType.TOGETHER: TogetherInferenceEngine, InferenceEngineType.VLLM: VLLMInferenceEngine, diff --git a/src/oumi/inference/__init__.py b/src/oumi/inference/__init__.py index 97c04e772..129d41db2 100644 --- a/src/oumi/inference/__init__.py +++ b/src/oumi/inference/__init__.py @@ -27,9 +27,11 @@ from oumi.inference.parasail_inference_engine import ParasailInferenceEngine from oumi.inference.remote_inference_engine import RemoteInferenceEngine from oumi.inference.remote_vllm_inference_engine import RemoteVLLMInferenceEngine +from oumi.inference.sambanova_inference_engine import SambanovaInferenceEngine from oumi.inference.sglang_inference_engine import SGLangInferenceEngine from oumi.inference.together_inference_engine import TogetherInferenceEngine from oumi.inference.vllm_inference_engine import VLLMInferenceEngine +from oumi.inference.sambanova_inference_engine import SambanovaInferenceEngine __all__ = [ "AnthropicInferenceEngine", @@ -42,7 +44,9 @@ "ParasailInferenceEngine", "RemoteInferenceEngine", "RemoteVLLMInferenceEngine", + "SambanovaInferenceEngine", "SGLangInferenceEngine", "TogetherInferenceEngine", "VLLMInferenceEngine", + "SambanovaInferenceEngine", ] diff --git a/src/oumi/inference/sambanova_inference_engine.py b/src/oumi/inference/sambanova_inference_engine.py new file mode 100644 index 000000000..3477e2b50 --- /dev/null +++ b/src/oumi/inference/sambanova_inference_engine.py @@ -0,0 +1,141 @@ +# Copyright 2025 - Oumi +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Optional + +from typing_extensions import override + +from oumi.core.configs import GenerationParams, RemoteParams +from oumi.core.types.conversation import Conversation, Message, Role +from oumi.inference.remote_inference_engine import RemoteInferenceEngine +from oumi.utils.logging import logger + +class SambanovaInferenceEngine(RemoteInferenceEngine): + """Engine for running inference against the SambaNova API. + + This class extends RemoteInferenceEngine to provide specific functionality + for interacting with SambaNova's language models via their API. It handles + the conversion of Oumi's Conversation objects to SambaNova's expected input + format, as well as parsing the API responses back into Conversation objects. + """ + + @property + @override + def base_url(self) -> Optional[str]: + """Return the default base URL for the SambaNova API.""" + return "https://api.sambanova.ai/v1/chat/completions" + + @property + @override + def api_key_env_varname(self) -> Optional[str]: + """Return the default environment variable name for the SambaNova API key.""" + return "SAMBANOVA_API_KEY" + + @override + def _convert_conversation_to_api_input( + self, conversation: Conversation, generation_params: GenerationParams + ) -> dict[str, Any]: + """Converts a conversation to a SambaNova API input. + + This method transforms an Oumi Conversation object into a format + suitable for the SambaNova API. It handles the conversion of messages + and generation parameters according to the API specification. + + Args: + conversation: The Oumi Conversation object to convert. + generation_params: Parameters for text generation. + + Returns: + Dict[str, Any]: A dictionary containing the formatted input for the + SambaNova API, including the model, messages, and generation parameters. + """ + # Build request body according to SambaNova API spec + body = { + "model": self._model, + "messages": self._get_list_of_message_json_dicts( + conversation.messages, + group_adjacent_same_role_turns=False + ), + "max_tokens": generation_params.max_new_tokens, + "temperature": generation_params.temperature, + "top_p": generation_params.top_p, + "stream": False, # We don't support streaming yet + } + + if generation_params.stop_strings: + body["stop"] = generation_params.stop_strings + + return body + + @override + def _convert_api_output_to_conversation( + self, response: dict[str, Any], original_conversation: Conversation + ) -> Conversation: + """Converts a SambaNova API response to a conversation. + + Args: + response: The API response to convert. + original_conversation: The original conversation. + + Returns: + Conversation: The conversation including the generated response. + """ + choices = response.get("choices", []) + if not choices: + raise RuntimeError("No choices found in API response") + + message = choices[0].get("message", {}) + if not message: + raise RuntimeError("No message found in API response") + + new_message = Message( + content=message.get("content", ""), + role=Role.ASSISTANT, + ) + + return Conversation( + messages=[*original_conversation.messages, new_message], + metadata=original_conversation.metadata, + conversation_id=original_conversation.conversation_id, + ) + + @override + def _get_request_headers(self, remote_params: RemoteParams) -> dict[str, str]: + """Get headers for the API request. + + Args: + remote_params: Remote server parameters. + + Returns: + Dict[str, str]: Headers for the API request. 
+ """ + headers = { + "Content-Type": "application/json", + } + + api_key = self._get_api_key(remote_params) + if api_key: + headers["X-API-Key"] = api_key + + return headers + + @override + def get_supported_params(self) -> set[str]: + """Returns a set of supported generation parameters for this engine.""" + return { + "max_new_tokens", + "stop_strings", + "temperature", + "top_p", + } \ No newline at end of file diff --git a/tests/unit/inference/test_generation_params.py b/tests/unit/inference/test_generation_params.py index b4b95c8c0..e54d819c7 100644 --- a/tests/unit/inference/test_generation_params.py +++ b/tests/unit/inference/test_generation_params.py @@ -20,6 +20,7 @@ NativeTextInferenceEngine, RemoteInferenceEngine, RemoteVLLMInferenceEngine, + SambanovaInferenceEngine, SGLangInferenceEngine, VLLMInferenceEngine, ) @@ -32,6 +33,7 @@ AnthropicInferenceEngine, LlamaCppInferenceEngine, NativeTextInferenceEngine, + SambanovaInferenceEngine, SGLangInferenceEngine, VLLMInferenceEngine, RemoteVLLMInferenceEngine, diff --git a/tests/unit/inference/test_inference_engine_init.py b/tests/unit/inference/test_inference_engine_init.py index 8d0d2a810..cd7861c3d 100644 --- a/tests/unit/inference/test_inference_engine_init.py +++ b/tests/unit/inference/test_inference_engine_init.py @@ -23,6 +23,7 @@ ParasailInferenceEngine, RemoteInferenceEngine, RemoteVLLMInferenceEngine, + SambanovaInferenceEngine, SGLangInferenceEngine, TogetherInferenceEngine, VLLMInferenceEngine, @@ -63,6 +64,7 @@ GoogleVertexInferenceEngine, OpenAIInferenceEngine, ParasailInferenceEngine, + SambanovaInferenceEngine, TogetherInferenceEngine, ] diff --git a/tests/unit/inference/test_sambanova_inference_engine.py b/tests/unit/inference/test_sambanova_inference_engine.py new file mode 100644 index 000000000..b9ae5f0fe --- /dev/null +++ b/tests/unit/inference/test_sambanova_inference_engine.py @@ -0,0 +1,129 @@ +from unittest.mock import patch + +import pytest + +from oumi.core.configs import GenerationParams, ModelParams, RemoteParams +from oumi.core.types.conversation import Conversation, Message, Role +from oumi.inference.sambanova_inference_engine import SambanovaInferenceEngine + + +@pytest.fixture +def sambanova_engine(): + return SambanovaInferenceEngine( + model_params=ModelParams(model_name="Meta-Llama-3.1-8B-Instruct"), + remote_params=RemoteParams(api_key="test_api_key", api_url=""), + ) + + +def test_convert_conversation_to_api_input(sambanova_engine): + """Test conversion of conversation to SambaNova API input format.""" + conversation = Conversation( + messages=[ + Message(content="System message", role=Role.SYSTEM), + Message(content="User message", role=Role.USER), + Message(content="Assistant message", role=Role.ASSISTANT), + ] + ) + generation_params = GenerationParams( + max_new_tokens=100, + temperature=0.7, + top_p=0.9, + stop_strings=["stop"], + ) + + result = sambanova_engine._convert_conversation_to_api_input( + conversation, generation_params + ) + + # Verify the API input format + assert result["model"] == "Meta-Llama-3.1-8B-Instruct" + assert len(result["messages"]) == 3 + assert result["messages"][0]["content"] == "System message" + assert result["messages"][0]["role"] == "system" + assert result["messages"][1]["content"] == "User message" + assert result["messages"][1]["role"] == "user" + assert result["messages"][2]["content"] == "Assistant message" + assert result["messages"][2]["role"] == "assistant" + assert result["max_tokens"] == 100 + assert result["temperature"] == 0.7 + assert 
result["top_p"] == 0.9 + assert result["stop"] == ["stop"] + assert result["stream"] is False + + +def test_convert_api_output_to_conversation(sambanova_engine): + """Test conversion of SambaNova API output to conversation.""" + original_conversation = Conversation( + messages=[ + Message(content="User message", role=Role.USER), + ], + metadata={"key": "value"}, + conversation_id="test_id", + ) + api_response = { + "choices": [ + { + "message": { + "content": "Assistant response", + "role": "assistant" + } + } + ] + } + + result = sambanova_engine._convert_api_output_to_conversation( + api_response, original_conversation + ) + + assert len(result.messages) == 2 + assert result.messages[0].content == "User message" + assert result.messages[1].content == "Assistant response" + assert result.messages[1].role == Role.ASSISTANT + assert result.metadata == {"key": "value"} + assert result.conversation_id == "test_id" + + +def test_convert_api_output_to_conversation_error_handling(sambanova_engine): + """Test error handling in API output conversion.""" + original_conversation = Conversation( + messages=[Message(content="User message", role=Role.USER)] + ) + + # Test empty choices + with pytest.raises(RuntimeError, match="No choices found in API response"): + sambanova_engine._convert_api_output_to_conversation( + {"choices": []}, original_conversation + ) + + # Test missing message + with pytest.raises(RuntimeError, match="No message found in API response"): + sambanova_engine._convert_api_output_to_conversation( + {"choices": [{}]}, original_conversation + ) + + +def test_get_request_headers(sambanova_engine): + """Test generation of request headers.""" + remote_params = RemoteParams(api_key="test_api_key", api_url="") + + with patch.object( + SambanovaInferenceEngine, + "_get_api_key", + return_value="test_api_key", + ): + result = sambanova_engine._get_request_headers(remote_params) + + assert result["Content-Type"] == "application/json" + assert result["X-API-Key"] == "test_api_key" + + +def test_get_supported_params(sambanova_engine): + """Test supported generation parameters.""" + supported_params = sambanova_engine.get_supported_params() + + assert supported_params == { + "max_new_tokens", + "stop_strings", + "temperature", + "top_p", + } \ No newline at end of file From f03dbc4786638fa2b85385858495e9ce2e24eb61 Mon Sep 17 00:00:00 2001 From: Chiung-Yi Date: Sat, 8 Feb 2025 19:16:08 -0800 Subject: [PATCH 2/8] Add documentation. Unit test passed --- docs/user_guides/infer/configuration.md | 1 + docs/user_guides/infer/infer.md | 5 +++-- docs/user_guides/infer/inference_engines.md | 24 +++++++++++++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/docs/user_guides/infer/configuration.md b/docs/user_guides/infer/configuration.md index 2a25a84c7..d3e37eaf7 100644 --- a/docs/user_guides/infer/configuration.md +++ b/docs/user_guides/infer/configuration.md @@ -143,6 +143,7 @@ The `engine` parameter specifies which inference engine to use. 
Available option - `PARASAIL`: Use Parasail API via {py:obj}`~oumi.inference.ParasailInferenceEngine` - `REMOTE_VLLM`: Use external vLLM server via {py:obj}`~oumi.inference.RemoteVLLMInferenceEngine` - `REMOTE`: Use any OpenAI-compatible API via {py:obj}`~oumi.inference.RemoteInferenceEngine` +- `SAMBANOVA`: Use SambaNova API via {py:obj}`~oumi.inference.SambanovaInferenceEngine` - `SGLANG`: Use SGLang inference engine via {py:obj}`~oumi.inference.SGLangInferenceEngine` - `TOGETHER`: Use Together API via {py:obj}`~oumi.inference.TogetherInferenceEngine` - `VLLM`: Use vLLM for optimized local inference via {py:obj}`~oumi.inference.VLLMInferenceEngine` diff --git a/docs/user_guides/infer/infer.md b/docs/user_guides/infer/infer.md index 4afc00591..b59d5116a 100644 --- a/docs/user_guides/infer/infer.md +++ b/docs/user_guides/infer/infer.md @@ -17,7 +17,7 @@ Oumi Infer provides a unified interface for running models, whether you're deplo Running models in production environments presents several challenges that Oumi helps address: -- **Universal Model Support**: Run models locally (vLLM, LlamaCPP, Transformers) or connect to hosted APIs (Anthropic, Gemini, OpenAI, Together, Parasail, Vertex AI) through a single, consistent interface +- **Universal Model Support**: Run models locally (vLLM, LlamaCPP, Transformers) or connect to hosted APIs (Anthropic, Gemini, OpenAI, Together, Parasail, Vertex AI, SambaNova) through a single, consistent interface - **Production-Ready**: Support for batching, retries, error-handling, structured outputs, and high-performance inference via multi-threading to hit a target throughput. - **Scalable Architecture**: Deploy anywhere from a single GPU to distributed systems without code changes - **Unified Configuration**: Control all aspects of model execution through a single config file @@ -94,7 +94,7 @@ Our engines are broken into two categories: local inference vs remote inference. Generally, the answer is simple: if you have sufficient resources to run the model locally without OOMing, then use a local engine like {py:obj}`~oumi.inference.VLLMInferenceEngine`, {py:obj}`~oumi.inference.NativeTextInferenceEngine`, or {py:obj}`~oumi.inference.LlamaCppInferenceEngine`. -If you don't have enough local compute resources, then the model must be hosted elsewhere. Our remote inference engines assume that your model is hosted behind a remote API. You can use {py:obj}`~oumi.inference.AnthropicInferenceEngine`, {py:obj}`~oumi.inference.GoogleGeminiInferenceEngine`, or {py:obj}`~oumi.inference.GoogleVertexInferenceEngine` to call their respective APIs. You can also use {py:obj}`~oumi.inference.RemoteInferenceEngine` to call any API implementing the OpenAI Chat API format (including OpenAI's native API), or use {py:obj}`~oumi.inference.SGLangInferenceEngine` or {py:obj}`~oumi.inference.RemoteVLLMInferenceEngine` to call external SGLang or vLLM servers started remotely or locally outside of Oumi. +If you don't have enough local compute resources, then the model must be hosted elsewhere. Our remote inference engines assume that your model is hosted behind a remote API. You can use {py:obj}`~oumi.inference.AnthropicInferenceEngine`, {py:obj}`~oumi.inference.GoogleGeminiInferenceEngine`, or {py:obj}`~oumi.inference.GoogleVertexInferenceEngine`, {py:obj}`~oumi.inference.SambanovaInferenceEngine` to call their respective APIs. 
You can also use {py:obj}`~oumi.inference.RemoteInferenceEngine` to call any API implementing the OpenAI Chat API format (including OpenAI's native API), or use {py:obj}`~oumi.inference.SGLangInferenceEngine` or {py:obj}`~oumi.inference.RemoteVLLMInferenceEngine` to call external SGLang or vLLM servers started remotely or locally outside of Oumi.
 
 For a comprehensive list of engines, see the [Supported Engines](#supported-engines) section below.
 
@@ -112,6 +112,7 @@ See {py:obj}`~oumi.inference.NativeTextInferenceEngine` for an example of a loca
 
 See {py:obj}`~oumi.inference.AnthropicInferenceEngine` for an example of an inference engine that requires a remote API.
 
+See {py:obj}`~oumi.inference.SambanovaInferenceEngine` for an example of an inference engine that requires a remote API.
 ```python
 from oumi.inference import VLLMInferenceEngine
 from oumi.core.configs import InferenceConfig, ModelParams
diff --git a/docs/user_guides/infer/inference_engines.md b/docs/user_guides/infer/inference_engines.md
index 4a9d104f1..87052572f 100644
--- a/docs/user_guides/infer/inference_engines.md
+++ b/docs/user_guides/infer/inference_engines.md
@@ -508,6 +508,30 @@ The DeepSeek models available via this API as of late Jan'2025 are listed below.
 | DeepSeek-V3 | deepseek-chat |
 | DeepSeek-R1 (reasoning with CoT) | deepseek-reasoner |
 
+### SambaNova
+
+[SambaNova](https://www.sambanova.ai/) offers a high-speed inference platform on cloud infrastructure with a wide variety of models.
+
+This service is particularly useful when you need to run open-source models in a managed environment.
+
+**Basic Usage**
+
+```{testcode}
+from oumi.inference import SambanovaInferenceEngine
+from oumi.core.configs import ModelParams, RemoteParams
+
+engine = SambanovaInferenceEngine(
+    model_params=ModelParams(
+        model_name="Meta-Llama-3.1-405B-Instruct"
+    ),
+    remote_params=RemoteParams(
+        api_key_env_varname="SAMBANOVA_API_KEY",
+    )
+)
+```
+
+**Reference**
+- [SambaNova's Documentation](https://docs.sambanova.ai/cloud/docs/get-started/overview)
 ### Parasail.io

From fce0c25d9d0716867e5a0721a8311f4dbef76afe Mon Sep 17 00:00:00 2001
From: Chiung-Yi
Date: Sat, 8 Feb 2025 22:43:27 -0800
Subject: [PATCH 3/8] Corrected authorization header. Added e2e test. Able to run and pass the e2e test.

Still need to revise the notebook tutorial.
Revise the location of the configuration file.
Pre-commit auto fixes.
Add SAMBANOVA literal.
---
 .../core/configs/inference_engine_type.py    |  3 +
 src/oumi/inference/__init__.py               |  1 -
 .../inference/sambanova_inference_engine.py  | 12 ++-
 .../infer/test_sambanova_inference.py        | 87 +++++++++++++++++++
 .../test_sambanova_inference_engine.py       | 13 +--
 5 files changed, 98 insertions(+), 18 deletions(-)
 create mode 100644 tests/integration/infer/test_sambanova_inference.py

diff --git a/src/oumi/core/configs/inference_engine_type.py b/src/oumi/core/configs/inference_engine_type.py
index 3e3211960..7fa678b49 100644
--- a/src/oumi/core/configs/inference_engine_type.py
+++ b/src/oumi/core/configs/inference_engine_type.py
@@ -56,3 +56,6 @@ class InferenceEngineType(str, Enum):
 
     OPENAI = "OPENAI"
     """The inference engine for OpenAI API."""
+
+    SAMBANOVA = "SAMBANOVA"
+    """The inference engine for SambaNova API."""
diff --git a/src/oumi/inference/__init__.py b/src/oumi/inference/__init__.py
index 129d41db2..3963e0f76 100644
--- a/src/oumi/inference/__init__.py
+++ b/src/oumi/inference/__init__.py
@@ -31,7 +31,6 @@
 from oumi.inference.sglang_inference_engine import SGLangInferenceEngine
 from oumi.inference.together_inference_engine import TogetherInferenceEngine
 from oumi.inference.vllm_inference_engine import VLLMInferenceEngine
-from oumi.inference.sambanova_inference_engine import SambanovaInferenceEngine
 
 __all__ = [
     "AnthropicInferenceEngine",
diff --git a/src/oumi/inference/sambanova_inference_engine.py b/src/oumi/inference/sambanova_inference_engine.py
index 3477e2b50..b255a4a29 100644
--- a/src/oumi/inference/sambanova_inference_engine.py
+++ b/src/oumi/inference/sambanova_inference_engine.py
@@ -19,7 +19,7 @@
 from oumi.core.configs import GenerationParams, RemoteParams
 from oumi.core.types.conversation import Conversation, Message, Role
 from oumi.inference.remote_inference_engine import RemoteInferenceEngine
-from oumi.utils.logging import logger
+
 
 class SambanovaInferenceEngine(RemoteInferenceEngine):
     """Engine for running inference against the SambaNova API.
@@ -64,8 +64,7 @@ def _convert_conversation_to_api_input( body = { "model": self._model, "messages": self._get_list_of_message_json_dicts( - conversation.messages, - group_adjacent_same_role_turns=False + conversation.messages, group_adjacent_same_role_turns=False ), "max_tokens": generation_params.max_new_tokens, "temperature": generation_params.temperature, @@ -126,7 +125,7 @@ def _get_request_headers(self, remote_params: RemoteParams) -> dict[str, str]: api_key = self._get_api_key(remote_params) if api_key: - headers["X-API-Key"] = api_key + headers["Authorization"] = f"Bearer {api_key}" return headers @@ -134,8 +133,7 @@ def _get_request_headers(self, remote_params: RemoteParams) -> dict[str, str]: def get_supported_params(self) -> set[str]: """Returns a set of supported generation parameters for this engine.""" return { - "max_new_tokens", - "stop_strings", + "max_tokens", "temperature", "top_p", - } \ No newline at end of file + } diff --git a/tests/integration/infer/test_sambanova_inference.py b/tests/integration/infer/test_sambanova_inference.py new file mode 100644 index 000000000..286914d1c --- /dev/null +++ b/tests/integration/infer/test_sambanova_inference.py @@ -0,0 +1,87 @@ +import os + +from oumi.core.configs import ( + GenerationParams, + InferenceConfig, + ModelParams, + RemoteParams, +) +from oumi.core.types import Conversation, Message, Role +from oumi.inference import SambanovaInferenceEngine + + +def load_config(config_path): + """Load the inference configuration from a YAML file.""" + return InferenceConfig.from_yaml(config_path) + + +def initialize_engine(api_key, model_name): + """Initialize the SambaNova inference engine.""" + return SambanovaInferenceEngine( + model_params=ModelParams(model_name=model_name), + generation_params=GenerationParams( + batch_size=1, + ), + remote_params=RemoteParams( + api_key=api_key, + ), + ) + + +def create_conversation(): + """Create a conversation for inference.""" + return [ + Conversation( + messages=[ + Message( + role=Role.SYSTEM, + content="Answer the question in a couple sentences.", + ), + Message( + role=Role.USER, content="What is the strength of SambaNova Systems?" 
+ ), + ] + ), + ] + + +def perform_inference(engine, conversations, config): + """Perform inference using the SambaNova engine.""" + try: + generations = engine.infer( + input=conversations, + inference_config=config, + ) + return generations + except Exception as e: + print("An error occurred during inference:", str(e)) + return None + + +def main(): + # Set the path to the configuration file + config_path = os.path.join( + os.path.dirname(__file__), "sambanova_infer_tutorial.yaml" + ) + + # Load the configuration + config = load_config(config_path) + + # Initialize the engine + api_key = os.getenv("SAMBANOVA_API_KEY") + model_name = "Meta-Llama-3.1-405B-Instruct" + engine = initialize_engine(api_key, model_name) + + # Create the conversation + conversations = create_conversation() + + # Perform inference + generations = perform_inference(engine, conversations, config) + + # Print the results + if generations: + print(generations) + + +if __name__ == "__main__": + main() diff --git a/tests/unit/inference/test_sambanova_inference_engine.py b/tests/unit/inference/test_sambanova_inference_engine.py index b9ae5f0fe..7081199bb 100644 --- a/tests/unit/inference/test_sambanova_inference_engine.py +++ b/tests/unit/inference/test_sambanova_inference_engine.py @@ -61,14 +61,7 @@ def test_convert_api_output_to_conversation(sambanova_engine): conversation_id="test_id", ) api_response = { - "choices": [ - { - "message": { - "content": "Assistant response", - "role": "assistant" - } - } - ] + "choices": [{"message": {"content": "Assistant response", "role": "assistant"}}] } result = sambanova_engine._convert_api_output_to_conversation( @@ -120,10 +113,10 @@ def test_get_request_headers(sambanova_engine): def test_get_supported_params(sambanova_engine): """Test supported generation parameters.""" supported_params = sambanova_engine.get_supported_params() - + assert supported_params == { "max_new_tokens", "stop_strings", "temperature", "top_p", - } \ No newline at end of file + } From 5d28927397b6476cfd9511fc45f09f584d68a588 Mon Sep 17 00:00:00 2001 From: Chiung-Yi Date: Sun, 9 Feb 2025 18:54:46 -0800 Subject: [PATCH 4/8] Add test yaml file. Move the sambanova_infer_test to e2e. Add pytest skip if the API key is not present Move the sambanova_infer_test to e2e. 
Add pytest skip if the API key is not present precommit auto fix --- tests/e2e/sambanova_infer_tutorial.yaml | 10 ++++++++++ .../infer => e2e}/test_sambanova_inference.py | 20 ++++++++++--------- 2 files changed, 21 insertions(+), 9 deletions(-) create mode 100644 tests/e2e/sambanova_infer_tutorial.yaml rename tests/{integration/infer => e2e}/test_sambanova_inference.py (85%) diff --git a/tests/e2e/sambanova_infer_tutorial.yaml b/tests/e2e/sambanova_infer_tutorial.yaml new file mode 100644 index 000000000..4b3add400 --- /dev/null +++ b/tests/e2e/sambanova_infer_tutorial.yaml @@ -0,0 +1,10 @@ + +model: + model_name: "Meta-Llama-3.1-405B-Instruct" + model_max_length: 512 + torch_dtype_str: "bfloat16" + trust_remote_code: True + +generation: + max_new_tokens: 128 + batch_size: 1 diff --git a/tests/integration/infer/test_sambanova_inference.py b/tests/e2e/test_sambanova_inference.py similarity index 85% rename from tests/integration/infer/test_sambanova_inference.py rename to tests/e2e/test_sambanova_inference.py index 286914d1c..cc60fa3e6 100644 --- a/tests/integration/infer/test_sambanova_inference.py +++ b/tests/e2e/test_sambanova_inference.py @@ -1,4 +1,7 @@ import os +from pathlib import Path + +import pytest from oumi.core.configs import ( GenerationParams, @@ -58,11 +61,14 @@ def perform_inference(engine, conversations, config): return None -def main(): - # Set the path to the configuration file - config_path = os.path.join( - os.path.dirname(__file__), "sambanova_infer_tutorial.yaml" - ) +@pytest.mark.e2e +def test_sambanova_inference(): + if "SAMBANOVA_API_KEY" not in os.environ: + pytest.skip("SAMBANOVA_API_KEY is not set") + + # Set the path to the configuration file using pathlib + config_path = Path(__file__).parent / "sambanova_infer_tutorial.yaml" + print(config_path) # Load the configuration config = load_config(config_path) @@ -81,7 +87,3 @@ def main(): # Print the results if generations: print(generations) - - -if __name__ == "__main__": - main() From 023b8eddd36bddfbc14a00db15852e02a3fd5bbc Mon Sep 17 00:00:00 2001 From: Chiung-Yi Date: Mon, 10 Feb 2025 22:11:48 -0800 Subject: [PATCH 5/8] Address @xrdaukar's comments --- docs/user_guides/infer/infer.md | 2 +- src/oumi/inference/sambanova_inference_engine.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/user_guides/infer/infer.md b/docs/user_guides/infer/infer.md index b59d5116a..7488a9a31 100644 --- a/docs/user_guides/infer/infer.md +++ b/docs/user_guides/infer/infer.md @@ -94,7 +94,7 @@ Our engines are broken into two categories: local inference vs remote inference. Generally, the answer is simple: if you have sufficient resources to run the model locally without OOMing, then use a local engine like {py:obj}`~oumi.inference.VLLMInferenceEngine`, {py:obj}`~oumi.inference.NativeTextInferenceEngine`, or {py:obj}`~oumi.inference.LlamaCppInferenceEngine`. -If you don't have enough local compute resources, then the model must be hosted elsewhere. Our remote inference engines assume that your model is hosted behind a remote API. You can use {py:obj}`~oumi.inference.AnthropicInferenceEngine`, {py:obj}`~oumi.inference.GoogleGeminiInferenceEngine`, or {py:obj}`~oumi.inference.GoogleVertexInferenceEngine`, {py:obj}`~oumi.inference.SambanovaInferenceEngine` to call their respective APIs. 
You can also use {py:obj}`~oumi.inference.RemoteInferenceEngine` to call any API implementing the OpenAI Chat API format (including OpenAI's native API), or use {py:obj}`~oumi.inference.SGLangInferenceEngine` or {py:obj}`~oumi.inference.RemoteVLLMInferenceEngine` to call external SGLang or vLLM servers started remotely or locally outside of Oumi. +If you don't have enough local compute resources, then the model must be hosted elsewhere. Our remote inference engines assume that your model is hosted behind a remote API. You can use {py:obj}`~oumi.inference.AnthropicInferenceEngine`, {py:obj}`~oumi.inference.GoogleGeminiInferenceEngine`, or {py:obj}`~oumi.inference.GoogleVertexInferenceEngine`, {py:obj}`~oumi.inference.SambanovaInferenceEngine`..., to call their respective APIs. You can also use {py:obj}`~oumi.inference.RemoteInferenceEngine` to call any API implementing the OpenAI Chat API format (including OpenAI's native API), or use {py:obj}`~oumi.inference.SGLangInferenceEngine` or {py:obj}`~oumi.inference.RemoteVLLMInferenceEngine` to call external SGLang or vLLM servers started remotely or locally outside of Oumi. For a comprehensive list of engines, see the [Supported Engines](#supported-engines) section below. diff --git a/src/oumi/inference/sambanova_inference_engine.py b/src/oumi/inference/sambanova_inference_engine.py index b255a4a29..ec6f64301 100644 --- a/src/oumi/inference/sambanova_inference_engine.py +++ b/src/oumi/inference/sambanova_inference_engine.py @@ -93,6 +93,7 @@ def _convert_api_output_to_conversation( choices = response.get("choices", []) if not choices: raise RuntimeError("No choices found in API response") + assert len(choices) == 1, "Sambanova API only supports one choice per response" message = choices[0].get("message", {}) if not message: From 1d60647995d43b13c06bfaf89aa53ae7355e7e07 Mon Sep 17 00:00:00 2001 From: Chiung-Yi Date: Mon, 10 Feb 2025 22:23:03 -0800 Subject: [PATCH 6/8] Address comments. 
Fix unit test --- tests/unit/inference/test_sambanova_inference_engine.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/unit/inference/test_sambanova_inference_engine.py b/tests/unit/inference/test_sambanova_inference_engine.py index 7081199bb..6772f5854 100644 --- a/tests/unit/inference/test_sambanova_inference_engine.py +++ b/tests/unit/inference/test_sambanova_inference_engine.py @@ -107,7 +107,7 @@ def test_get_request_headers(sambanova_engine): result = sambanova_engine._get_request_headers(remote_params) assert result["Content-Type"] == "application/json" - assert result["X-API-Key"] == "test_api_key" + assert result["Authorization"] == "Bearer test_api_key" def test_get_supported_params(sambanova_engine): @@ -115,8 +115,7 @@ def test_get_supported_params(sambanova_engine): supported_params = sambanova_engine.get_supported_params() assert supported_params == { - "max_new_tokens", - "stop_strings", + "max_tokens", "temperature", "top_p", } From 14c4df6fe97c2d3b9b4384e9eba75540e878a07a Mon Sep 17 00:00:00 2001 From: xrdaukar Date: Tue, 11 Feb 2025 09:37:17 -0800 Subject: [PATCH 7/8] update unit tests - param names --- src/oumi/inference/sambanova_inference_engine.py | 3 ++- tests/unit/inference/test_sambanova_inference_engine.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/oumi/inference/sambanova_inference_engine.py b/src/oumi/inference/sambanova_inference_engine.py index ec6f64301..6616b1d77 100644 --- a/src/oumi/inference/sambanova_inference_engine.py +++ b/src/oumi/inference/sambanova_inference_engine.py @@ -134,7 +134,8 @@ def _get_request_headers(self, remote_params: RemoteParams) -> dict[str, str]: def get_supported_params(self) -> set[str]: """Returns a set of supported generation parameters for this engine.""" return { - "max_tokens", + "max_new_tokens", + "stop_strings", "temperature", "top_p", } diff --git a/tests/unit/inference/test_sambanova_inference_engine.py b/tests/unit/inference/test_sambanova_inference_engine.py index 6772f5854..80b3d87d2 100644 --- a/tests/unit/inference/test_sambanova_inference_engine.py +++ b/tests/unit/inference/test_sambanova_inference_engine.py @@ -115,7 +115,8 @@ def test_get_supported_params(sambanova_engine): supported_params = sambanova_engine.get_supported_params() assert supported_params == { - "max_tokens", + "max_new_tokens", + "stop_strings", "temperature", "top_p", } From 15aff47f34b01aba6a0345bc094d54b87a36cbf1 Mon Sep 17 00:00:00 2001 From: xrdaukar Date: Tue, 11 Feb 2025 09:40:42 -0800 Subject: [PATCH 8/8] Replace assert with an exception --- src/oumi/inference/sambanova_inference_engine.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/oumi/inference/sambanova_inference_engine.py b/src/oumi/inference/sambanova_inference_engine.py index 6616b1d77..145aacfa1 100644 --- a/src/oumi/inference/sambanova_inference_engine.py +++ b/src/oumi/inference/sambanova_inference_engine.py @@ -93,7 +93,11 @@ def _convert_api_output_to_conversation( choices = response.get("choices", []) if not choices: raise RuntimeError("No choices found in API response") - assert len(choices) == 1, "Sambanova API only supports one choice per response" + if len(choices) != 1: + raise RuntimeError( + "Sambanova API only supports one choice per response. " + f"Got: {len(choices)}" + ) message = choices[0].get("message", {}) if not message: