diff --git a/docs/components/embedders/models/azure_openai.mdx b/docs/components/embedders/models/azure_openai.mdx
index 42e22854d0..fda2d32e8a 100644
--- a/docs/components/embedders/models/azure_openai.mdx
+++ b/docs/components/embedders/models/azure_openai.mdx
@@ -21,11 +21,14 @@ config = {
         "provider": "azure_openai",
         "config": {
             "model": "text-embedding-3-large"
-            "azure_kwargs" : {
-                "api_version" : "",
-                "azure_deployment" : "",
-                "azure_endpoint" : "",
-                "api_key": ""
+            "azure_kwargs": {
+                "api_version": "",
+                "azure_deployment": "",
+                "azure_endpoint": "",
+                "api_key": "",
+                "default_headers": {
+                    "CustomHeader": "your-custom-header",
+                }
             }
         }
     }
diff --git a/docs/components/llms/models/azure_openai.mdx b/docs/components/llms/models/azure_openai.mdx
index ddb99322b8..ea32b67304 100644
--- a/docs/components/llms/models/azure_openai.mdx
+++ b/docs/components/llms/models/azure_openai.mdx
@@ -22,11 +22,14 @@ config = {
             "model": "your-deployment-name",
             "temperature": 0.1,
             "max_tokens": 2000,
-            "azure_kwargs" : {
-                "azure_deployment" : "",
-                "api_version" : "",
-                "azure_endpoint" : "",
-                "api_key" : ""
+            "azure_kwargs": {
+                "azure_deployment": "",
+                "api_version": "",
+                "azure_endpoint": "",
+                "api_key": "",
+                "default_headers": {
+                    "CustomHeader": "your-custom-header",
+                }
             }
         }
     }
@@ -54,11 +57,14 @@ config = {
             "model": "your-deployment-name",
             "temperature": 0.1,
             "max_tokens": 2000,
-            "azure_kwargs" : {
-                "azure_deployment" : "",
-                "api_version" : "",
-                "azure_endpoint" : "",
-                "api_key" : ""
+            "azure_kwargs": {
+                "azure_deployment": "",
+                "api_version": "",
+                "azure_endpoint": "",
+                "api_key": "",
+                "default_headers": {
+                    "CustomHeader": "your-custom-header",
+                }
             }
         }
     }
diff --git a/mem0/configs/base.py b/mem0/configs/base.py
index 55e09f2757..c9293c2503 100644
--- a/mem0/configs/base.py
+++ b/mem0/configs/base.py
@@ -63,6 +63,7 @@ class AzureConfig(BaseModel):
         azure_deployment (str): The name of the Azure deployment.
         azure_endpoint (str): The endpoint URL for the Azure service.
         api_version (str): The version of the Azure API being used.
+        default_headers (Dict[str, str]): Headers to include in requests to the Azure API.
     """
 
     api_key: str = Field(
@@ -72,3 +73,4 @@ class AzureConfig(BaseModel):
     azure_deployment: str = Field(description="The name of the Azure deployment.", default=None)
     azure_endpoint: str = Field(description="The endpoint URL for the Azure service.", default=None)
     api_version: str = Field(description="The version of the Azure API being used.", default=None)
+    default_headers: Optional[Dict[str, str]] = Field(description="Headers to include in requests to the Azure API.", default=None)
diff --git a/mem0/embeddings/azure_openai.py b/mem0/embeddings/azure_openai.py
index d25cc00e45..4b33b1a2f5 100644
--- a/mem0/embeddings/azure_openai.py
+++ b/mem0/embeddings/azure_openai.py
@@ -15,6 +15,7 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("EMBEDDING_AZURE_DEPLOYMENT")
         azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("EMBEDDING_AZURE_ENDPOINT")
         api_version = self.config.azure_kwargs.api_version or os.getenv("EMBEDDING_AZURE_API_VERSION")
+        default_headers = self.config.azure_kwargs.default_headers
 
         self.client = AzureOpenAI(
             azure_deployment=azure_deployment,
@@ -22,6 +23,7 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):
             api_version=api_version,
             api_key=api_key,
             http_client=self.config.http_client,
+            default_headers=default_headers,
         )
 
     def embed(self, text):
diff --git a/mem0/llms/azure_openai.py b/mem0/llms/azure_openai.py
index f1fe6863a7..3400b38236 100644
--- a/mem0/llms/azure_openai.py
+++ b/mem0/llms/azure_openai.py
@@ -20,6 +20,7 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
         azure_deployment = self.config.azure_kwargs.azure_deployment or os.getenv("LLM_AZURE_DEPLOYMENT")
         azure_endpoint = self.config.azure_kwargs.azure_endpoint or os.getenv("LLM_AZURE_ENDPOINT")
         api_version = self.config.azure_kwargs.api_version or os.getenv("LLM_AZURE_API_VERSION")
+        default_headers = self.config.azure_kwargs.default_headers
 
         self.client = AzureOpenAI(
             azure_deployment=azure_deployment,
@@ -27,6 +28,7 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
             api_version=api_version,
             api_key=api_key,
             http_client=self.config.http_client,
+            default_headers=default_headers,
         )
 
     def _parse_response(self, response, tools):
diff --git a/mem0/llms/azure_openai_structured.py b/mem0/llms/azure_openai_structured.py
index 729523d85d..2d7103f614 100644
--- a/mem0/llms/azure_openai_structured.py
+++ b/mem0/llms/azure_openai_structured.py
@@ -20,14 +20,16 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
         azure_deployment = os.getenv("LLM_AZURE_DEPLOYMENT") or self.config.azure_kwargs.azure_deployment
         azure_endpoint = os.getenv("LLM_AZURE_ENDPOINT") or self.config.azure_kwargs.azure_endpoint
         api_version = os.getenv("LLM_AZURE_API_VERSION") or self.config.azure_kwargs.api_version
-        # Can display a warning if API version is of model and api-version
+        default_headers = self.config.azure_kwargs.default_headers
+        # Can display a warning if API version is of model and api-version
 
         self.client = AzureOpenAI(
             azure_deployment=azure_deployment,
             azure_endpoint=azure_endpoint,
             api_version=api_version,
             api_key=api_key,
             http_client=self.config.http_client,
+            default_headers=default_headers,
         )
 
     def _parse_response(self, response, tools):
diff --git a/tests/embeddings/test_azure_openai_embeddings.py b/tests/embeddings/test_azure_openai_embeddings.py
index 3425ea48a2..f674dc2a45 100644
--- a/tests/embeddings/test_azure_openai_embeddings.py
+++ b/tests/embeddings/test_azure_openai_embeddings.py
@@ -1,7 +1,9 @@
-import pytest
 from unittest.mock import Mock, patch
-from mem0.embeddings.azure_openai import AzureOpenAIEmbedding
+
+import pytest
+
 from mem0.configs.embeddings.base import BaseEmbedderConfig
+from mem0.embeddings.azure_openai import AzureOpenAIEmbedding
 
 
 @pytest.fixture
@@ -29,18 +31,26 @@ def test_embed_text(mock_openai_client):
     assert embedding == [0.1, 0.2, 0.3]
 
 
-def test_embed_text_with_newlines(mock_openai_client):
-    config = BaseEmbedderConfig(model="text-embedding-ada-002")
-    embedder = AzureOpenAIEmbedding(config)
-
-    mock_embedding_response = Mock()
-    mock_embedding_response.data = [Mock(embedding=[0.4, 0.5, 0.6])]
-    mock_openai_client.embeddings.create.return_value = mock_embedding_response
-
-    text = "Hello,\nthis is a test\nwith newlines."
-    embedding = embedder.embed(text)
-
-    mock_openai_client.embeddings.create.assert_called_once_with(
-        input=["Hello, this is a test with newlines."], model="text-embedding-ada-002"
+@pytest.mark.parametrize(
+    "default_headers, expected_header",
+    [
+        (None, None),
+        ({"Test": "test_value"}, "test_value"),
+        ({}, None)
+    ],
+)
+def test_embed_text_with_default_headers(default_headers, expected_header):
+    config = BaseEmbedderConfig(
+        model="text-embedding-ada-002",
+        azure_kwargs={
+            "api_key": "test",
+            "api_version": "test_version",
+            "azure_endpoint": "test_endpoint",
+            "azure_deployment": "test_deployment",
+            "default_headers": default_headers
+        }
     )
-    assert embedding == [0.4, 0.5, 0.6]
+    embedder = AzureOpenAIEmbedding(config)
+    assert embedder.client.api_key == "test"
+    assert embedder.client._api_version == "test_version"
+    assert embedder.client.default_headers.get("Test") == expected_header
diff --git a/tests/llms/test_azure_openai.py b/tests/llms/test_azure_openai.py
index e54d244fbd..42b3ba37c4 100644
--- a/tests/llms/test_azure_openai.py
+++ b/tests/llms/test_azure_openai.py
@@ -92,10 +92,17 @@ def test_generate_response_with_tools(mock_openai_client):
     assert response["tool_calls"][0]["arguments"] == {"data": "Today is a sunny day."}
 
 
-def test_generate_with_http_proxies():
+@pytest.mark.parametrize(
+    "default_headers",
+    [None, {"Firstkey": "FirstVal", "SecondKey": "SecondVal"}],
+)
+def test_generate_with_http_proxies(default_headers):
     mock_http_client = Mock(spec=httpx.Client)
     mock_http_client_instance = Mock(spec=httpx.Client)
     mock_http_client.return_value = mock_http_client_instance
+    azure_kwargs = {"api_key": "test"}
+    if default_headers:
+        azure_kwargs["default_headers"] = default_headers
 
     with (
         patch("mem0.llms.azure_openai.AzureOpenAI") as mock_azure_openai,
@@ -108,7 +115,7 @@ def test_generate_with_http_proxies():
             top_p=TOP_P,
             api_key="test",
             http_client_proxies="http://testproxy.mem0.net:8000",
-            azure_kwargs={"api_key": "test"},
+            azure_kwargs=azure_kwargs,
         )
 
         _ = AzureOpenAILLM(config)
@@ -119,5 +126,6 @@ def test_generate_with_http_proxies():
             azure_deployment=None,
             azure_endpoint=None,
             api_version=None,
+            default_headers=default_headers,
         )
         mock_http_client.assert_called_once_with(proxies="http://testproxy.mem0.net:8000")
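
Taken together, these changes let callers attach custom HTTP headers to every Azure OpenAI request through `azure_kwargs.default_headers`. Below is a minimal usage sketch, assuming the `Memory.from_config` entry point and placeholder endpoint, deployment, and API-version values; only the `azure_kwargs` keys themselves come from the diff above.

```python
from mem0 import Memory

# Placeholder Azure details; substitute your own resource values.
azure_kwargs = {
    "api_key": "your-api-key",
    "api_version": "2024-02-01",
    "azure_endpoint": "https://your-resource.openai.azure.com",
    "azure_deployment": "your-deployment-name",
    # New option in this change: forwarded to AzureOpenAI(default_headers=...),
    # so the headers are sent with every LLM and embedding request.
    "default_headers": {"CustomHeader": "your-custom-header"},
}

config = {
    "llm": {
        "provider": "azure_openai",
        "config": {"model": "your-deployment-name", "azure_kwargs": azure_kwargs},
    },
    "embedder": {
        "provider": "azure_openai",
        "config": {"model": "text-embedding-3-large", "azure_kwargs": azure_kwargs},
    },
}

memory = Memory.from_config(config)
```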