Skip to content

Commit

Permalink
add type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
zrquan committed Oct 12, 2024
1 parent ca61f77 commit da7e840
Show file tree
Hide file tree
Showing 23 changed files with 251 additions and 164 deletions.
9 changes: 6 additions & 3 deletions lib/connection/dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@
#
# Author: Mauro Soria

from __future__ import annotations

from socket import getaddrinfo
from typing import Any

_dns_cache = {}
_dns_cache: dict[tuple[str, int], list[Any]] = {}


def cache_dns(domain, port, addr):
def cache_dns(domain: str, port: int, addr: str) -> None:
_dns_cache[domain, port] = getaddrinfo(addr, port)


def cached_getaddrinfo(*args, **kwargs):
def cached_getaddrinfo(*args: Any, **kwargs: int) -> list[Any]:
"""
Replacement for socket.getaddrinfo, they are the same but this function
does cache the answer to improve the performance
Expand Down
42 changes: 22 additions & 20 deletions lib/connection/requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#
# Author: Mauro Soria

from __future__ import annotations

import asyncio
import http.client
import random
Expand All @@ -24,7 +26,7 @@
from ssl import SSLError
import threading
import time
from typing import Generator, Optional
from typing import Any, Generator
from urllib.parse import urlparse

import httpx
Expand Down Expand Up @@ -59,12 +61,12 @@


class BaseRequester:
def __init__(self):
self._url = None
self._proxy_cred = None
def __init__(self) -> None:
self._url: str = ""
self._proxy_cred: str = ""
self._rate = 0
self.headers = CaseInsensitiveDict(options["headers"])
self.agents = []
self.agents: list[str] = []
self.session = None

self._cert = None
Expand Down Expand Up @@ -117,27 +119,27 @@ def set_proxy(self, proxy: str) -> None:
def set_proxy_auth(self, credential: str) -> None:
self._proxy_cred = credential

def is_rate_exceeded(self):
def is_rate_exceeded(self) -> bool:
return self._rate >= options["max_rate"] > 0

def decrease_rate(self):
def decrease_rate(self) -> None:
self._rate -= 1

def increase_rate(self):
def increase_rate(self) -> None:
self._rate += 1
threading.Timer(1, self.decrease_rate).start()

@property
@cached(RATE_UPDATE_DELAY)
def rate(self):
def rate(self) -> int:
return self._rate


class HTTPBearerAuth(AuthBase):
def __init__(self, token):
def __init__(self, token: str) -> None:
self.token = token

def __call__(self, request):
def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
request.headers["Authorization"] = f"Bearer {self.token}"
return request

Expand All @@ -160,7 +162,7 @@ def __init__(self):
),
)

def set_auth(self, type, credential):
def set_auth(self, type: str, credential: str) -> None:
if type in ("bearer", "jwt"):
self.session.auth = HTTPBearerAuth(credential)
else:
Expand All @@ -178,7 +180,7 @@ def set_auth(self, type, credential):
self.session.auth = HttpNtlmAuth(user, password)

# :path: is expected not to start with "/"
def request(self, path, proxy=None):
def request(self, path: str, proxy: str | None = None) -> Response:
# Pause if the request rate exceeded the maximum
while self.is_rate_exceeded():
time.sleep(0.1)
Expand Down Expand Up @@ -213,13 +215,13 @@ def request(self, path, proxy=None):
prepped = self.session.prepare_request(request)
prepped.url = url

response = self.session.send(
origin_response = self.session.send(
prepped,
allow_redirects=options["follow_redirects"],
timeout=options["timeout"],
stream=True,
)
response = Response(response)
response = Response(origin_response)

log_msg = f'"{options["http_method"]} {response.url}" {response.status} - {response.length}B'

Expand Down Expand Up @@ -270,13 +272,13 @@ class HTTPXBearerAuth(httpx.Auth):
def __init__(self, token: str) -> None:
self.token = token

def auth_flow(self, request: httpx.Request) -> Generator:
def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, None, None]:
request.headers["Authorization"] = f"Bearer {self.token}"
yield request


class ProxyRoatingTransport(httpx.AsyncBaseTransport):
def __init__(self, proxies, **kwargs) -> None:
def __init__(self, proxies: list[str], **kwargs: Any) -> None:
self._transports = [
httpx.AsyncHTTPTransport(proxy=proxy, **kwargs) for proxy in proxies
]
Expand All @@ -287,7 +289,7 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response:


class AsyncRequester(BaseRequester):
def __init__(self):
def __init__(self) -> None:
super().__init__()

tpargs = {
Expand Down Expand Up @@ -340,7 +342,7 @@ def set_auth(self, type: str, credential: str) -> None:
else:
self.session.auth = HttpxNtlmAuth(user, password)

async def replay_request(self, path: str, proxy: str):
async def replay_request(self, path: str, proxy: str) -> AsyncResponse:
if self.replay_session is None:
transport = httpx.AsyncHTTPTransport(
verify=False,
Expand All @@ -357,7 +359,7 @@ async def replay_request(self, path: str, proxy: str):

# :path: is expected not to start with "/"
async def request(
self, path: str, session: Optional[httpx.AsyncClient] = None
self, path: str, session: httpx.AsyncClient | None = None
) -> AsyncResponse:
while self.is_rate_exceeded():
await asyncio.sleep(0.1)
Expand Down
31 changes: 18 additions & 13 deletions lib/connection/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
#
# Author: Mauro Soria

from __future__ import annotations

from typing import Any

import httpx
import requests

from lib.core.settings import (
DEFAULT_ENCODING,
Expand All @@ -29,7 +34,7 @@


class BaseResponse:
def __init__(self, response):
def __init__(self, response: requests.Response | httpx.Response) -> None:
self.url = str(response.url)
self.full_path = parse_path(self.url)
self.path = clean_path(self.full_path)
Expand All @@ -41,23 +46,23 @@ def __init__(self, response):
self.body = b""

@property
def type(self):
if "content-type" in self.headers:
return self.headers.get("content-type").split(";")[0]
def type(self) -> str:
if ct := self.headers.get("content-type"):
return ct.split(";")[0]

return UNKNOWN

@property
def length(self):
try:
return int(self.headers.get("content-length"))
except TypeError:
return len(self.body)
def length(self) -> int:
if cl := self.headers.get("content-length"):
return int(cl)

return len(self.body)

def __hash__(self):
def __hash__(self) -> int:
return hash(self.body)

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
return (self.status, self.body, self.redirect) == (
other.status,
other.body,
Expand All @@ -66,7 +71,7 @@ def __eq__(self, other):


class Response(BaseResponse):
def __init__(self, response):
def __init__(self, response: requests.Response) -> None:
super().__init__(response)

for chunk in response.iter_content(chunk_size=ITER_CHUNK_SIZE):
Expand All @@ -88,7 +93,7 @@ def __init__(self, response):

class AsyncResponse(BaseResponse):
@classmethod
async def create(cls, response: httpx.Response) -> "AsyncResponse":
async def create(cls, response: httpx.Response) -> AsyncResponse:
self = cls(response)
async for chunk in response.aiter_bytes(chunk_size=ITER_CHUNK_SIZE):
self.body += chunk
Expand Down
Loading

0 comments on commit da7e840

Please sign in to comment.