Skip to content

Commit

Permalink
support replay_proxy in async mode
Browse files Browse the repository at this point in the history
  • Loading branch information
zrquan committed Sep 22, 2024
1 parent bbdb60d commit e179332
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 9 deletions.
27 changes: 23 additions & 4 deletions lib/connection/requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ssl import SSLError
import threading
import time
from typing import Generator
from typing import Generator, Optional
from urllib.parse import urlparse

import httpx
Expand Down Expand Up @@ -308,6 +308,7 @@ def __init__(self):
mounts={"all://": transport},
timeout=httpx.Timeout(options["timeout"]),
)
self.replay_session = None

def parse_proxy(self, proxy: str) -> str:
if not proxy:
Expand Down Expand Up @@ -339,8 +340,25 @@ def set_auth(self, type: str, credential: str) -> None:
else:
self.session.auth = HttpxNtlmAuth(user, password)

async def replay_request(self, path: str, proxy: str):
if self.replay_session is None:
transport = httpx.AsyncHTTPTransport(
verify=False,
cert=self._cert,
limits=httpx.Limits(max_connections=options["thread_count"]),
proxy=self.parse_proxy(proxy),
socket_options=self._socket_options,
)
self.replay_session = httpx.AsyncClient(
mounts={"all://": transport},
timeout=httpx.Timeout(options["timeout"]),
)
return await self.request(path, self.replay_session)

# :path: is expected not to start with "/"
async def request(self, path: str) -> AsyncResponse:
async def request(
self, path: str, session: Optional[httpx.AsyncClient] = None
) -> AsyncResponse:
while self.is_rate_exceeded():
await asyncio.sleep(0.1)

Expand All @@ -352,13 +370,14 @@ async def request(self, path: str) -> AsyncResponse:
url = safequote(self._url + path if self._url else path)
parsed_url = urlparse(url)

session = session or self.session
for _ in range(options["max_retries"] + 1):
try:
if self.agents:
self.set_header("user-agent", random.choice(self.agents))

# Use "target" extension to avoid the URL path from being normalized
request = self.session.build_request(
request = session.build_request(
options["http_method"],
url,
headers=self.headers,
Expand All @@ -367,7 +386,7 @@ async def request(self, path: str) -> AsyncResponse:
if p := parsed_url.path:
request.extensions = {"target": p.encode()}

xresponse = await self.session.send(
xresponse = await session.send(
request,
stream=True,
follow_redirects=options["follow_redirects"],
Expand Down
7 changes: 5 additions & 2 deletions lib/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,9 +504,12 @@ def match_callback(self, response):
if added_to_queue:
interface.new_directories(added_to_queue)

if options["replay_proxy"] and not options["async_mode"]:
if options["replay_proxy"]:
# Replay the request with new proxy
self.requester.request(response.full_path, proxy=options["replay_proxy"])
if options["async_mode"]:
self.loop.create_task(self.requester.replay_request(response.full_path, proxy=options["replay_proxy"]))
else:
self.requester.request(response.full_path, proxy=options["replay_proxy"])

if self.report:
self.results.append(response)
Expand Down
3 changes: 0 additions & 3 deletions lib/core/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,6 @@ def parse_options():
)
exit(1)

if opt.async_mode and opt.replay_proxy:
print("WARNING: --replay-proxy doesn't work in asynchronous mode")

return vars(opt)


Expand Down

0 comments on commit e179332

Please sign in to comment.