Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Handle CouchDB HTTP 403 on all routes #56

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions aiocouch/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,16 @@ class ExpectationFailedError(ValueError):
pass


class ClientResponseError(aiohttp.ClientResponseError):
def __init__(self, reason, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.reason = reason


def raise_for_endpoint(
endpoint: Endpoint,
message: str,
exception: aiohttp.ClientResponseError,
exception: ClientResponseError,
exception_type: Optional[Type[Exception]] = None,
) -> NoReturn:
if exception_type is None:
Expand Down Expand Up @@ -143,6 +149,9 @@ def raise_for_endpoint(

message_input = {}

with suppress(AttributeError):
message_input["reason"] = exception.reason
message_input["reason"] = message_input.get("reason", exception.message)
with suppress(AttributeError):
message_input["id"] = endpoint.id
message_input["endpoint"] = endpoint.endpoint
Expand All @@ -165,7 +174,7 @@ def decorator_raises(func: FuncT) -> FuncT:
async def wrapper(endpoint: Endpoint, *args: Any, **kwargs: Any) -> Any:
try:
return await func(endpoint, *args, **kwargs)
except aiohttp.ClientResponseError as exception:
except ClientResponseError as exception:
if status == exception.status:
raise_for_endpoint(endpoint, message, exception, exception_type)
raise exception
Expand All @@ -186,7 +195,7 @@ async def wrapper(
try:
async for data in func(endpoint, *args, **kwargs):
yield data
except aiohttp.ClientResponseError as exception:
except ClientResponseError as exception:
if status == exception.status:
raise_for_endpoint(endpoint, message, exception, exception_type)
raise exception
Expand Down
84 changes: 60 additions & 24 deletions aiocouch/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import aiohttp

from . import database, document
from .exception import NotFoundError, generator_raises, raises
from .exception import NotFoundError, generator_raises, raises, ClientResponseError
from .typing import JsonDict


Expand Down Expand Up @@ -160,7 +160,21 @@ async def _request(
async with self._http_session.request(
method, url=f"{self._server}{path}", **kwargs
) as resp:
resp.raise_for_status()
if not resp.ok:
reason = None
with suppress(Exception):
reason = (await resp.json())["reason"]
# Copied from aiohttp v3.9.5 raise_for_status():
assert resp.reason is not None
resp.release()
raise ClientResponseError(
reason,
resp.request_info,
resp.history,
status=resp.status,
message=resp.reason,
headers=resp.headers,
)
return (
HTTPResponse(resp),
await resp.json() if return_json else await resp.read(),
Expand All @@ -179,14 +193,29 @@ async def _streamed_request(
async with self._http_session.request(
method, url=f"{self._server}{path}", **kwargs
) as resp:
resp.raise_for_status()
if not resp.ok:
reason = None
with suppress(Exception):
reason = (await resp.json())["reason"]
# Copied from aiohttp v3.9.5 raise_for_status():
assert resp.reason is not None
resp.release()
raise ClientResponseError(
reason,
resp.request_info,
resp.history,
status=resp.status,
message=resp.reason,
headers=resp.headers,
)

async for line in resp.content:
# this should only happen for empty lines
with suppress(json.JSONDecodeError):
yield json.loads(line)

@raises(401, "Invalid credentials")
@raises(403, "Access forbidden: {reason}")
async def _all_dbs(self, **params: Any) -> List[str]:
_, json = await self._get("/_all_dbs", params)
assert not isinstance(json, bytes)
Expand All @@ -203,12 +232,14 @@ async def close(self) -> None:
await asyncio.sleep(0.250 if has_ssl_conn else 0)

@raises(401, "Invalid credentials")
@raises(403, "Access forbidden: {reason}")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only manually tested this one with a real CouchDB, not the other routes.

async def _info(self) -> JsonDict:
_, json = await self._get("/")
assert not isinstance(json, bytes)
return json

@raises(401, "Authentication failed, check provided credentials.")
@raises(403, "Access forbidden: {reason}")
async def _check_session(self) -> RequestResult:
return await self._get("/_session")

Expand All @@ -223,19 +254,19 @@ def endpoint(self) -> str:
return f"/{_quote_id(self.id)}"

@raises(401, "Invalid credentials")
@raises(403, "Read permission required")
@raises(403, "Access forbidden: {reason}")
async def _exists(self) -> bool:
try:
await self._remote._head(self.endpoint)
return True
except aiohttp.ClientResponseError as e:
except ClientResponseError as e:
if e.status == 404:
return False
else:
raise e

@raises(401, "Invalid credentials")
@raises(403, "Read permission required")
@raises(403, "Access forbidden: {reason}")
@raises(404, "Requested database not found ({id})")
async def _get(self) -> JsonDict:
_, json = await self._remote._get(self.endpoint)
Expand All @@ -244,6 +275,7 @@ async def _get(self) -> JsonDict:

@raises(400, "Invalid database name")
@raises(401, "CouchDB Server Administrator privileges required")
@raises(403, "Access forbidden: {reason}")
@raises(412, "Database already exists")
async def _put(self, **params: Any) -> JsonDict:
_, json = await self._remote._put(self.endpoint, params=params)
Expand All @@ -252,13 +284,14 @@ async def _put(self, **params: Any) -> JsonDict:

@raises(400, "Invalid database name or forgotten document id by accident")
@raises(401, "CouchDB Server Administrator privileges required")
@raises(403, "Access forbidden: {reason}")
@raises(404, "Database doesn't exist or invalid database name ({id})")
async def _delete(self) -> None:
await self._remote._delete(self.endpoint)

@raises(400, "The request provided invalid JSON data or invalid query parameter")
@raises(401, "Read permission required")
@raises(403, "Read permission required")
@raises(403, "Access forbidden: {reason}")
@raises(404, "Invalid database name")
@raises(415, "Bad Content-Type value")
async def _bulk_get(self, docs: List[str], **params: Any) -> JsonDict:
Expand All @@ -270,7 +303,7 @@ async def _bulk_get(self, docs: List[str], **params: Any) -> JsonDict:

@raises(400, "The request provided invalid JSON data")
@raises(401, "Invalid credentials")
@raises(403, "Write permission required")
@raises(403, "Access forbidden: {reason}")
@raises(417, "At least one document was rejected by the validation function")
async def _bulk_docs(self, docs: List[JsonDict], **data: Any) -> JsonDict:
data["docs"] = docs
Expand All @@ -280,7 +313,7 @@ async def _bulk_docs(self, docs: List[JsonDict], **data: Any) -> JsonDict:

@raises(400, "Invalid request")
@raises(401, "Read privilege required for document '{id}'")
@raises(403, "Read permission required")
@raises(403, "Access forbidden: {reason}")
@raises(500, "Query execution failed", RuntimeError)
async def _find(self, selector: Any, **data: Any) -> JsonDict:
data["selector"] = selector
Expand All @@ -290,6 +323,7 @@ async def _find(self, selector: Any, **data: Any) -> JsonDict:

@raises(400, "Invalid request")
@raises(401, "Admin permission required")
@raises(403, "Access forbidden: {reason}")
@raises(404, "Database not found")
@raises(500, "Execution error")
async def _index(self, index: JsonDict, **data: Any) -> JsonDict:
Expand All @@ -299,20 +333,21 @@ async def _index(self, index: JsonDict, **data: Any) -> JsonDict:
return json

@raises(401, "Invalid credentials")
@raises(403, "Permission required")
@raises(403, "Access forbidden: {reason}")
async def _get_security(self) -> JsonDict:
_, json = await self._remote._get(f"{self.endpoint}/_security")
assert not isinstance(json, bytes)
return json

@raises(401, "Invalid credentials")
@raises(403, "Permission required")
@raises(403, "Access forbidden: {reason}")
async def _put_security(self, doc: JsonDict) -> JsonDict:
_, json = await self._remote._put(f"{self.endpoint}/_security", doc)
assert not isinstance(json, bytes)
return json

@generator_raises(400, "Invalid request")
@generator_raises(403, "Access forbidden: {reason}")
async def _changes(self, **params: Any) -> AsyncGenerator[JsonDict, None]:
if "feed" in params and params["feed"] == "continuous":
params.setdefault("heartbeat", True)
Expand All @@ -329,6 +364,7 @@ async def _changes(self, **params: Any) -> AsyncGenerator[JsonDict, None]:
yield result

@raises(400, "Invalid database or JSON payload")
@raises(403, "Access forbidden: {reason}")
@raises(415, "Bad Content-Type header value")
@raises(500, "Internal server error or timeout")
async def _purge(self, docs: JsonDict, **params: Any) -> JsonDict:
Expand All @@ -350,13 +386,13 @@ def endpoint(self) -> str:
return f"{self._database.endpoint}/{_quote_id(self.id)}"

@raises(401, "Read privilege required for document '{id}'")
@raises(403, "Read privilege required for document '{id}'")
@raises(403, "Access forbidden: {reason}")
@raises(404, "Document {id} was not found")
async def _head(self) -> None:
await self._database._remote._head(self.endpoint)

@raises(401, "Read privilege required for document '{id}'")
@raises(403, "Read privilege required for document '{id}'")
@raises(403, "Access forbidden: {reason}")
@raises(404, "Document {id} was not found")
async def _info(self) -> JsonDict:
response, _ = await self._database._remote._head(self.endpoint)
Expand All @@ -376,7 +412,7 @@ async def _exists(self) -> bool:

@raises(400, "The format of the request or revision was invalid")
@raises(401, "Read privilege required for document '{id}'")
@raises(403, "Read privilege required for document '{id}'")
@raises(403, "Access forbidden: {reason}")
@raises(404, "Document {id} was not found")
async def _get(self, **params: Any) -> JsonDict:
_, json = await self._database._remote._get(self.endpoint, params)
Expand All @@ -385,7 +421,7 @@ async def _get(self, **params: Any) -> JsonDict:

@raises(400, "The format of the request or revision was invalid")
@raises(401, "Write privilege required for document '{id}'")
@raises(403, "Write privilege required for document '{id}'")
@raises(403, "Access forbidden: {reason}")
@raises(404, "Specified database or document ID doesn't exists ({endpoint})")
@raises(
409,
Expand All @@ -401,7 +437,7 @@ async def _put(

@raises(400, "Invalid request body or parameters")
@raises(401, "Write privilege required for document '{id}'")
@raises(403, "Write privilege required for document '{id}'")
@raises(403, "Access forbidden: {reason}")
@raises(404, "Specified database or document ID doesn't exists ({endpoint})")
@raises(
409, "Specified revision ({rev}) is not the latest for target document '{id}'"
Expand All @@ -414,7 +450,7 @@ async def _delete(self, rev: str, **params: Any) -> Tuple[HTTPResponse, JsonDict

@raises(400, "Invalid request body or parameters")
@raises(401, "Read or write privileges required")
@raises(403, "Read or write privileges required")
@raises(403, "Access forbidden: {reason}")
@raises(
404, "Specified database, document ID or revision doesn't exists ({endpoint})"
)
Expand Down Expand Up @@ -444,23 +480,23 @@ def endpoint(self) -> str:
return f"{self._document.endpoint}/{_quote_id(self.id)}"

@raises(401, "Read privilege required for document '{document_id}'")
@raises(403, "Read privilege required for document '{document_id}'")
@raises(403, "Access forbidden: {reason}")
async def _exists(self) -> bool:
try:
response, _ = await self._document._database._remote._head(
self.endpoint, return_json=False
)
self.content_type = response.headers["Content-Type"]
return True
except aiohttp.ClientResponseError as e:
except ClientResponseError as e:
if e.status == 404:
return False
else:
raise e

@raises(400, "Invalid request parameters")
@raises(401, "Read privilege required for document '{document_id}'")
@raises(403, "Read privilege required for document '{document_id}'")
@raises(403, "Access forbidden: {reason}")
@raises(404, "Document '{document_id}' or attachment '{id}' doesn't exists")
async def _get(self, **params: Any) -> bytes:
response, data = await self._document._database._remote._get_bytes(
Expand All @@ -472,7 +508,7 @@ async def _get(self, **params: Any) -> bytes:

@raises(400, "Invalid request body or parameters")
@raises(401, "Write privilege required for document '{document_id}'")
@raises(403, "Write privilege required for document '{document_id}'")
@raises(403, "Access forbidden: {reason}")
@raises(404, "Document '{document_id}' doesn't exists")
@raises(
409, "Specified revision {document_rev} is not the latest for target document"
Expand All @@ -490,7 +526,7 @@ async def _put(

@raises(400, "Invalid request body or parameters")
@raises(401, "Write privilege required for document '{document_id}'")
@raises(403, "Write privilege required for document '{document_id}'")
@raises(403, "Access forbidden: {reason}")
@raises(404, "Specified database or document ID doesn't exists ({endpoint})")
@raises(
409, "Specified revision {document_rev} is not the latest for target document"
Expand Down Expand Up @@ -519,7 +555,7 @@ def endpoint(self) -> str:

@raises(400, "Invalid request")
@raises(401, "Read privileges required")
@raises(403, "Read privileges required")
@raises(403, "Access forbidden: {reason}")
@raises(404, "Specified database, design document or view is missing")
async def _get(self, **params: Any) -> JsonDict:
_, json = await self._database._remote._get(self.endpoint, params)
Expand All @@ -528,7 +564,7 @@ async def _get(self, **params: Any) -> JsonDict:

@raises(400, "Invalid request")
@raises(401, "Write privileges required")
@raises(403, "Write privileges required")
@raises(403, "Access forbidden: {reason}")
@raises(404, "Specified database, design document or view is missing")
async def _post(self, keys: List[str], **params: Any) -> JsonDict:
_, json = await self._database._remote._post(
Expand Down
Loading
Loading