Skip to content

Commit

Permalink
FEAT: Pass arguments to http client (#554)
Browse files Browse the repository at this point in the history
Co-authored-by: Raja Sekhar Rao Dheekonda <[email protected]>
Co-authored-by: rdheekonda <[email protected]>
Co-authored-by: Roman Lutz <[email protected]>
Co-authored-by: Volkan Kutal <[email protected]>
Co-authored-by: rlundeen2 <[email protected]>
Co-authored-by: rlundeen2 <[email protected]>
  • Loading branch information
7 people authored Nov 21, 2024
1 parent 9bad63b commit 94dd4ec
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 3 deletions.
3 changes: 2 additions & 1 deletion doc/code/targets/7_http_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@

# For AOAI the response content is located in the path choices[0].message.content - for other responses this should be in the documentation or you can manually test the output to find the right path
parsing_function = get_http_target_json_response_callback_function(key="choices[0].message.content")
http_prompt_target = HTTPTarget(http_request=raw_http_request, callback_function=parsing_function)
# httpx AsyncClient parameters can be passed as kwargs to HTTPTarget, for example the timeout below
http_prompt_target = HTTPTarget(http_request=raw_http_request, callback_function=parsing_function, timeout=20.0)

# Note, a converter is used to format the prompt to be json safe without new lines/carriage returns, etc
with PromptSendingOrchestrator(
Expand Down
7 changes: 5 additions & 2 deletions pyrit/prompt_target/http_target/http_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json
import logging
import re
from typing import Callable
from typing import Callable, Optional, Any

from pyrit.models import construct_response_from_request, PromptRequestPiece, PromptRequestResponse
from pyrit.prompt_target import PromptTarget
Expand All @@ -26,6 +26,7 @@ class HTTPTarget(PromptTarget):
use_tls: (bool): whether to use TLS or not. Default is True
callback_function (function): function to parse HTTP response.
These are the customizable functions which determine how to parse the output
client_kwargs: (dict): additional keyword arguments to pass to the HTTP client
"""

def __init__(
Expand All @@ -34,12 +35,14 @@ def __init__(
prompt_regex_string: str = "{PROMPT}",
use_tls: bool = True,
callback_function: Callable = None,
**client_kwargs: Optional[Any],
) -> None:

self.http_request = http_request
self.callback_function = callback_function
self.prompt_regex_string = prompt_regex_string
self.use_tls = use_tls
self.client_kwargs = client_kwargs or {}

async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
"""
Expand All @@ -66,7 +69,7 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P
if http_version and "HTTP/2" in http_version:
http2_version = True

async with httpx.AsyncClient(http2=http2_version) as client:
async with httpx.AsyncClient(http2=http2_version, **self.client_kwargs) as client:
response = await client.request(
method=http_method,
url=url,
Expand Down
20 changes: 20 additions & 0 deletions tests/target/test_http_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,26 @@ async def test_send_prompt_async(mock_request, mock_http_target, mock_http_respo
)


@pytest.mark.asyncio
@patch("httpx.AsyncClient")
async def test_send_prompt_async_client_kwargs(mock_async_client):
# Create client_kwargs to test
client_kwargs = {"timeout": 10, "verify": False}
sample_request = "GET /test HTTP/1.1\nHost: example.com\n\n"
# Create instance of HTTPTarget with client_kwargs
# Use **client_kwargs to pass them as keyword arguments
http_target = HTTPTarget(http_request=sample_request, **client_kwargs)
prompt_request = MagicMock()
prompt_request.request_pieces = [MagicMock(converted_value="")]
mock_response = MagicMock()
mock_response.content = b"Response content"
instance = mock_async_client.return_value.__aenter__.return_value
instance.request.return_value = mock_response
await http_target.send_prompt_async(prompt_request=prompt_request)

mock_async_client.assert_called_with(http2=False, timeout=10, verify=False)


@pytest.mark.asyncio
async def test_send_prompt_async_validation(mock_http_target):
# Create an invalid prompt request (missing request_pieces)
Expand Down

0 comments on commit 94dd4ec

Please sign in to comment.