From e4133c3dbf6f3d6afb9799768a475a55eff90d08 Mon Sep 17 00:00:00 2001 From: William Price <82848178+william-price01@users.noreply.github.com> Date: Wed, 2 Oct 2024 19:28:57 -0600 Subject: [PATCH] Fixed issue with integ tests (#1219) --- griptape/drivers/web_search/exa_web_search_driver.py | 4 ++-- .../drivers/web_search/test_exa_web_search_driver.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/griptape/drivers/web_search/exa_web_search_driver.py b/griptape/drivers/web_search/exa_web_search_driver.py index ca4219eac..c5ef3abe7 100644 --- a/griptape/drivers/web_search/exa_web_search_driver.py +++ b/griptape/drivers/web_search/exa_web_search_driver.py @@ -17,7 +17,7 @@ class ExaWebSearchDriver(BaseWebSearchDriver): api_key: str = field(kw_only=True, default=None) highlights: bool = field(default=False, kw_only=True) - use_auto_prompt: bool = field(default=False, kw_only=True) + use_autoprompt: bool = field(default=False, kw_only=True) params: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True}) _client: Exa = field(default=None, kw_only=True, alias="client") @@ -28,7 +28,7 @@ def client(self) -> Exa: def search(self, query: str, **kwargs) -> ListArtifact[JsonArtifact]: response = self.client.search_and_contents( highlights=self.highlights, - use_auto_prompt=self.use_auto_prompt, + use_autoprompt=self.use_autoprompt, query=query, num_results=self.results_count, text=True, diff --git a/tests/unit/drivers/web_search/test_exa_web_search_driver.py b/tests/unit/drivers/web_search/test_exa_web_search_driver.py index 66fae8a01..7456fd35e 100644 --- a/tests/unit/drivers/web_search/test_exa_web_search_driver.py +++ b/tests/unit/drivers/web_search/test_exa_web_search_driver.py @@ -17,7 +17,7 @@ def driver(self, mock_exa_client, mocker): mock_response = mocker.Mock() mock_response.results = [self.mock_data(mocker), self.mock_data(mocker)] # Make sure results is iterable mock_exa_client.return_value.search_and_contents.return_value = mock_response - return ExaWebSearchDriver(api_key="test", highlights=True, use_auto_prompt=True) + return ExaWebSearchDriver(api_key="test", highlights=True, use_autoprompt=True) def test_search_returns_results(self, driver, mock_exa_client): results = driver.search("test") @@ -29,16 +29,16 @@ def test_search_returns_results(self, driver, mock_exa_client): assert output[0]["highlights"] == "baz" assert output[0]["text"] == "qux" mock_exa_client.return_value.search_and_contents.assert_called_once_with( - query="test", num_results=5, text=True, highlights=True, use_auto_prompt=True + query="test", num_results=5, text=True, highlights=True, use_autoprompt=True ) def test_search_raises_error(self, driver, mock_exa_client): mock_exa_client.return_value.search_and_contents.side_effect = Exception("test_error") - driver = ExaWebSearchDriver(api_key="test", highlights=True, use_auto_prompt=True) + driver = ExaWebSearchDriver(api_key="test", highlights=True, use_autoprompt=True) with pytest.raises(Exception, match="test_error"): driver.search("test") mock_exa_client.return_value.search_and_contents.assert_called_once_with( - query="test", num_results=5, text=True, highlights=True, use_auto_prompt=True + query="test", num_results=5, text=True, highlights=True, use_autoprompt=True ) def test_search_with_params(self, driver, mock_exa_client): @@ -50,7 +50,7 @@ def test_search_with_params(self, driver, mock_exa_client): num_results=5, text=True, highlights=True, - use_auto_prompt=True, + use_autoprompt=True, custom_param="value", additional_param="extra", )