Skip to content

Commit

Permalink
Fixed issue with integ tests (#1219)
Browse files Browse the repository at this point in the history
  • Loading branch information
william-price01 authored Oct 3, 2024
1 parent 1736d53 commit e4133c3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions griptape/drivers/web_search/exa_web_search_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
class ExaWebSearchDriver(BaseWebSearchDriver):
api_key: str = field(kw_only=True, default=None)
highlights: bool = field(default=False, kw_only=True)
use_auto_prompt: bool = field(default=False, kw_only=True)
use_autoprompt: bool = field(default=False, kw_only=True)
params: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True})
_client: Exa = field(default=None, kw_only=True, alias="client")

Expand All @@ -28,7 +28,7 @@ def client(self) -> Exa:
def search(self, query: str, **kwargs) -> ListArtifact[JsonArtifact]:
response = self.client.search_and_contents(
highlights=self.highlights,
use_auto_prompt=self.use_auto_prompt,
use_autoprompt=self.use_autoprompt,
query=query,
num_results=self.results_count,
text=True,
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/drivers/web_search/test_exa_web_search_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def driver(self, mock_exa_client, mocker):
mock_response = mocker.Mock()
mock_response.results = [self.mock_data(mocker), self.mock_data(mocker)] # Make sure results is iterable
mock_exa_client.return_value.search_and_contents.return_value = mock_response
return ExaWebSearchDriver(api_key="test", highlights=True, use_auto_prompt=True)
return ExaWebSearchDriver(api_key="test", highlights=True, use_autoprompt=True)

def test_search_returns_results(self, driver, mock_exa_client):
results = driver.search("test")
Expand All @@ -29,16 +29,16 @@ def test_search_returns_results(self, driver, mock_exa_client):
assert output[0]["highlights"] == "baz"
assert output[0]["text"] == "qux"
mock_exa_client.return_value.search_and_contents.assert_called_once_with(
query="test", num_results=5, text=True, highlights=True, use_auto_prompt=True
query="test", num_results=5, text=True, highlights=True, use_autoprompt=True
)

def test_search_raises_error(self, driver, mock_exa_client):
mock_exa_client.return_value.search_and_contents.side_effect = Exception("test_error")
driver = ExaWebSearchDriver(api_key="test", highlights=True, use_auto_prompt=True)
driver = ExaWebSearchDriver(api_key="test", highlights=True, use_autoprompt=True)
with pytest.raises(Exception, match="test_error"):
driver.search("test")
mock_exa_client.return_value.search_and_contents.assert_called_once_with(
query="test", num_results=5, text=True, highlights=True, use_auto_prompt=True
query="test", num_results=5, text=True, highlights=True, use_autoprompt=True
)

def test_search_with_params(self, driver, mock_exa_client):
Expand All @@ -50,7 +50,7 @@ def test_search_with_params(self, driver, mock_exa_client):
num_results=5,
text=True,
highlights=True,
use_auto_prompt=True,
use_autoprompt=True,
custom_param="value",
additional_param="extra",
)

0 comments on commit e4133c3

Please sign in to comment.