diff --git a/asyncord/client/http/client.py b/asyncord/client/http/client.py index 050d283..11219ad 100644 --- a/asyncord/client/http/client.py +++ b/asyncord/client/http/client.py @@ -4,6 +4,7 @@ import asyncio import json +import logging from collections.abc import Mapping, Sequence from http import HTTPStatus from types import MappingProxyType, TracebackType @@ -22,23 +23,39 @@ from typing import Self -AttachedFile = tuple[str, str, BinaryIO | bytes] -"""Type alias for a file to be attached to a request. +MAX_NEXT_RETRY_SEC = 10 +"""Maximum number of seconds to wait before retrying a request.""" -The tuple contains the filename, the content type, and the file object. -""" +logger = logging.getLogger(__name__) -MAX_NEXT_RETRY_SEC = 10 -"""Maximum number of seconds to wait before retrying a request.""" +class AttachedFile(NamedTuple): + """Type alias for a file to be attached to a request. + + The tuple contains the filename, the content type, and the file object. + """ + + filename: str + """Name of the file.""" + + content_type: str + """Content type of the file.""" + + file: BinaryIO + """File object.""" class Response(NamedTuple): """Response structure for the HTTP client.""" status: int + """Response status code.""" + headers: Mapping[str, str] + """Response headers.""" + body: Any + """Response body.""" class RateLimitBody(BaseModel): @@ -50,7 +67,7 @@ class RateLimitBody(BaseModel): retry_after: float """Number of seconds to wait before submitting another request.""" - global_: bool = Field(alias='global') + is_global: bool = Field(alias='global') """Whether this is a global rate limit.""" @@ -219,76 +236,93 @@ async def _request( # noqa: PLR0913 ServerError: If the response status code is in the 500 range. RateLimitError: If the response status code is 429 and the retry_after is greater than 10. """ - if headers is None: - headers = self._headers - else: - headers = {**self._headers, **headers} + headers = {**self._headers, **(headers or {})} async with self._make_raw_request(method, url, payload, files, headers) as resp: - body, message = await self._extract_body_and_message(resp) - - match resp.status: - case status if status < HTTPStatus.BAD_REQUEST: - return Response( - status=resp.status, - headers=MappingProxyType(dict(resp.headers.items())), - body=body, - ) - - case HTTPStatus.TOO_MANY_REQUESTS: - # FIXME: It's a simple hack for now. Potentially 'endless' recursion - ratelimit = RateLimitBody(**body) - if ratelimit.retry_after > MAX_NEXT_RETRY_SEC: - raise errors.RateLimitError( - message=message or 'Unknown error', - resp=resp, - retry_after=ratelimit.retry_after or None, - ) - # FIXME: Move to decorator - await asyncio.sleep(ratelimit.retry_after + 0.1) - return await self._request( - method=method, - url=url, + body = await self._extract_body(resp) + status = resp.status + + if resp.status < HTTPStatus.BAD_REQUEST: + return Response( + status=resp.status, + headers=MappingProxyType(dict(resp.headers.items())), + body=body, + ) + + if not isinstance(body, dict): + raise errors.ServerError( + message='Expected JSON body', + payload=payload, + headers=headers, + resp=resp, + body=body, + ) + + if status == HTTPStatus.TOO_MANY_REQUESTS: + # FIXME: It's a simple hack for now. Potentially 'endless' recursion + ratelimit = RateLimitBody.model_validate(body) + logger.warning(f'Rate limited: {ratelimit.message} (retry after {ratelimit.retry_after})') + + if ratelimit.retry_after > MAX_NEXT_RETRY_SEC: + raise errors.RateLimitError( + message=ratelimit.message, payload=payload, - files=files, headers=headers, - ) - - case status if HTTPStatus.BAD_REQUEST <= status < HTTPStatus.INTERNAL_SERVER_ERROR: - # TODO: #8 Add more specific errors for 400 range - raise errors.ClientError( - message=message or 'Unknown error', - resp=resp, - code=body.get('code'), - ) - - case _: - raise errors.ServerError( - message=message or 'Unknown error', resp=resp, - status_code=resp.status, + retry_after=ratelimit.retry_after, ) - async def _extract_body_and_message(self, resp: ClientResponse) -> tuple[Any, str | None]: - """Extract the body and message from the response. + # FIXME: Move to decorator + await asyncio.sleep(ratelimit.retry_after + 0.1) + return await self._request( + method=method, + url=url, + payload=payload, + files=files, + headers=headers, + ) + + error_body = errors.RequestErrorBody.model_validate(body) + if HTTPStatus.BAD_REQUEST <= status < HTTPStatus.INTERNAL_SERVER_ERROR: + raise errors.ClientError( + message=error_body.message, + payload=payload, + headers=headers, + resp=resp, + body=error_body, + ) + + raise errors.ServerError( + message=error_body.message, + payload=payload, + headers=headers, + resp=resp, + body=error_body, + ) + + async def _extract_body(self, resp: ClientResponse) -> dict[str, Any] | str: + """Extract the body. Args: resp: Request response. Returns: - Body and message from the response. + Body of the response. """ if resp.status == HTTPStatus.NO_CONTENT: - body = {} - message = None - elif resp.headers.get('Content-Type') == JSON_CONTENT_TYPE: - body = await resp.json() - message = body.get('message') if isinstance(body, Mapping) else None - else: - body = {} - message = await resp.text() - - return body, message + return {} + + if resp.headers.get('Content-Type') == JSON_CONTENT_TYPE: + try: + return await resp.json() + except json.JSONDecodeError: + body = await resp.text() + logger.warning(f'Failed to decode JSON body: {body}') + if body: + return body + return {} + + return {} def _make_raw_request( # noqa: PLR0913 self, diff --git a/asyncord/client/http/errors.py b/asyncord/client/http/errors.py index 1e99dff..d94c943 100644 --- a/asyncord/client/http/errors.py +++ b/asyncord/client/http/errors.py @@ -1,51 +1,214 @@ +from __future__ import annotations + +from typing import Any, Mapping + from aiohttp import ClientResponse +from pydantic import BaseModel class BaseDiscordError(Exception): """Base class for all Discord errors.""" - def __init__(self, message: str, resp: ClientResponse) -> None: + def __init__(self, message: str) -> None: self.message = message + + def __str__(self) -> str: + return self.message + + +class DiscordHTTPError(BaseDiscordError): + """Base class for all Discord HTTP errors.""" + + def __init__( + self, + message: str, + payload: Any, + headers: Mapping[str, str], + resp: ClientResponse, + ) -> None: + super().__init__(message) + self.payload = payload + self.headers = headers self.resp = resp + self.status = resp.status def __str__(self) -> str: - return f'{self.message}' + if self.resp.reason: + return "HTTP {0.status} ({0.reason}): {1}".format(self.resp, self.message) + return "HTTP {0.status}: {1}".format(self.resp, self.message) -class ClientError(BaseDiscordError): +class ClientError(DiscordHTTPError): """Error raised when the client encounters an error. Usually this is due to a bad request (4xx). """ - def __init__(self, message: str, resp: ClientResponse, code: int | None = None) -> None: - super().__init__(message, resp) - self.code = code + def __init__( + self, + message: str, + payload: Any, + headers: Mapping[str, str], + resp: ClientResponse, + body: RequestErrorBody, + ) -> None: + """Initialize the ClientError. + + Args: + message (str): The error message. + payload (Any): The payload of the request. + headers (Mapping[str, str]): The headers of the request. + resp (ClientResponse): The response of the request. + body (RequestErrorBody): The body of the request. + """ + super().__init__(message, payload, headers, resp) + self.body = body def __str__(self) -> str: - return f'({self.code}) {self.message}' + """Format the error.""" + if self.resp.reason: + exc_str = 'HTTP {0.status} ({0.reason})'.format(self.resp) + else: + exc_str = 'HTTP {0.status}'.format(self.resp) + + if not self.body: + return exc_str + + if isinstance(self.body, str): + return f'{exc_str}\n{self.body}' + + exc_str = f'{exc_str}\nCode {self.body.code}: {self.body.message}' + if not self.body.errors: + return exc_str + + return f'{exc_str}\n' + self._format_errors('', self.body.errors) + + def _format_errors(self, path: str, errors: ErrorBlock | ObjectErrorType | ArrayErrorType) -> str: + """Get all errors from an error block. + Args: + path (str): Path to the error block. + errors (ErrorBlock | ObjectErrorType | ArrayErrorType): The error block to get the errors from. -class RateLimitError(BaseDiscordError): + Returns: + list[ErrorItem]: The errors. + """ + if isinstance(errors, ErrorBlock): + return path + '\n'.join(f'\t-> {error.code}: {error.message}' for error in errors._errors) + + return '\n'.join(self._format_errors(f'{path}.{key}', value) for key, value in errors.items()) + + +class RateLimitError(DiscordHTTPError): """Error raised when the client encounters a rate limit. This is usually due to too many requests (429). """ - def __init__(self, message: str, resp: ClientResponse, retry_after: float | None = None) -> None: - super().__init__(message, resp) + def __init__( + self, + message: str, + payload: Any, + headers: Mapping[str, str], + resp: ClientResponse, + retry_after: float, + ) -> None: + super().__init__(message, payload, headers, resp) self.retry_after = retry_after def __str__(self) -> str: return f'{self.message} (retry after {self.retry_after})' -class ServerError(BaseDiscordError): +class ServerError(DiscordHTTPError): """Error raised when the server return status code >= 500.""" - def __init__(self, message: str, resp: ClientResponse, status_code: int) -> None: - super().__init__(message, resp) - self.status_code = status_code + def __init__( + self, + message: str, + payload: Any, + headers: Mapping[str, str], + resp: ClientResponse, + body: RequestErrorBody | str | None = None, + ) -> None: + """Initialize the ServerError. + + Args: + message (str): The error message. + payload (Any): The payload of the request. + headers (Mapping[str, str]): The headers of the request. + resp (ClientResponse): The response of the request. + body (RequestErrorBody): The body of the request. + """ + super().__init__(message, payload, headers, resp) + self.body = body def __str__(self) -> str: - return f'{self.status_code}: {self.message}' + """Format the error.""" + if self.resp.reason: + exc_str = 'HTTP {0.status} ({0.reason})'.format(self.resp) + else: + exc_str = 'HTTP {0.status}'.format(self.resp) + + if not self.body: + return exc_str + + if isinstance(self.body, str): + return f'{exc_str}\n{self.body}' + + exc_str = f'{exc_str} - {self.body.code}: {self.body.message}' + if not self.body.errors: + return exc_str + + return f'{exc_str}\n' + self._format_errors('', self.body.errors) + + def _format_errors(self, path: str, errors: ErrorBlock | ObjectErrorType | ArrayErrorType) -> str: + """Get all errors from an error block. + + Args: + path (str): Path to the error block. + errors (ErrorBlock | ObjectErrorType | ArrayErrorType): The error block to get the errors from. + + Returns: + list[ErrorItem]: The errors. + """ + if isinstance(errors, ErrorBlock): + return path + '\n'.join(f'\t-> {error.code}: {error.message}' for error in errors._errors) + + return '\n'.join(self._format_errors(f'{path}.{key}', value) for key, value in errors.items()) + + +class ErrorItem(BaseModel): + """Represents an error item.""" + + code: int + """Error code.""" + + message: str + """Error message.""" + + +class ErrorBlock(BaseModel): + """Represents an object error.""" + + _errors: list[ErrorItem] + """List of errors.""" + + +type ObjectErrorType = dict[str, ErrorBlock | ObjectErrorType | ArrayErrorType] +"""Type hint for an object error.""" + +type ArrayErrorType = dict[int, ObjectErrorType] +"""Type hint for an array error.""" + + +class RequestErrorBody(BaseModel): + """Represents a body of a request error.""" + + code: int + """Error code.""" + + message: str + """Error message.""" + + errors: ErrorBlock | ObjectErrorType | None = None