Skip to content

Commit

Permalink
Merge pull request #31 from tb1337/develop
Browse files Browse the repository at this point in the history
Add better request/response handling
  • Loading branch information
tb1337 authored Dec 21, 2023
2 parents 47c0d89 + 35c8399 commit 8217817
Show file tree
Hide file tree
Showing 20 changed files with 111 additions and 70 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pr-labels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ jobs:
uses: ludeeus/[email protected]
with:
labels: >-
breaking change, bug, chore, enhancement, refactor
breaking change, bug, chore, enhancement, refactor, documentation
62 changes: 50 additions & 12 deletions pypaperless/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""PyPaperless."""

import logging
from collections.abc import Generator
from contextlib import asynccontextmanager
from typing import Any

import aiohttp
Expand All @@ -21,7 +23,7 @@
TasksEndpoint,
UsersEndpoint,
)
from .errors import BadRequestException
from .errors import BadRequestException, DataNotExpectedException
from .models.shared import ResourceType


Expand Down Expand Up @@ -145,7 +147,7 @@ async def initialize(self):
"""Initialize the connection to the api and fetch the endpoints."""
self.logger.info("Fetching api endpoints.")

res = await self.request("get", "")
res = await self.request_json("get", "")

self._consumption_templates = ConsumptionTemplatesEndpoint(
self, res.pop(ResourceType.CONSUMPTION_TEMPLATES)
Expand All @@ -165,24 +167,30 @@ async def initialize(self):
self._users = UsersEndpoint(self, res.pop(ResourceType.USERS))

self._initialized = True

if len(res) > 0:
self.logger.debug("Unused endpoints: %s", ", ".join(res))
self.logger.info("Initialized.")
self.logger.debug("Unused endpoints: %s", ", ".join(res))

async def close(self):
"""Clean up connection."""
if self._session:
await self._session.close()
self.logger.info("Closed.")

async def request(self, method: str, endpoint: str, **kwargs):
"""Make request on the api and return response data."""
@asynccontextmanager
async def generate_request(
self,
method: str,
endpoint: str,
**kwargs,
) -> Generator[aiohttp.ClientResponse, None, None]:
"""Create a client response object for further use."""
if not isinstance(self._session, aiohttp.ClientSession):
self._session = aiohttp.ClientSession()

url = endpoint if endpoint.startswith("http") else f"http://{self.host}/api/{endpoint}"

# check and add trailing slash
url = url.rstrip("/") + "/"
url = url.rstrip("/") + "/" # check and add trailing slash

kwargs.update(self._request_opts)

Expand All @@ -197,16 +205,46 @@ async def request(self, method: str, endpoint: str, **kwargs):
)

async with self._session.request(method, url, **kwargs) as res:
self.logger.debug("Request %s (%d): %s", method, res.status, res.url)
yield res

async def request_json(
self,
method: str,
endpoint: str,
**kwargs,
) -> dict[str, Any]:
"""Make a request to the api and parse response json to dict."""
async with self.generate_request(method, endpoint, **kwargs) as res:
self.logger.debug("Json-Request %s (%d): %s", method.upper(), res.status, res.url)

# bad request
if res.status == 400:
raise BadRequestException(f"{await res.text()}")
# no content
# no content, in most cases on DELETE method
if res.status == 204:
return {}
res.raise_for_status()
if res.content_type == "application/json":
return await res.json()

if res.content_type != "application/json":
raise DataNotExpectedException(f"Content-type is not json! {res.content_type}")

return await res.json()

async def request_file(
self,
method: str,
endpoint: str,
**kwargs,
) -> bytes:
"""Make a request to the api and return response as bytes."""
async with self.generate_request(method, endpoint, **kwargs) as res:
self.logger.debug("File-Request %s (%d): %s", method.upper(), res.status, res.url)

# bad request
if res.status == 400:
raise BadRequestException(f"{await res.text()}")
res.raise_for_status()

return await res.read()

async def __aenter__(self) -> "Paperless":
Expand Down
12 changes: 6 additions & 6 deletions pypaperless/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def endpoint(self):

async def list(self) -> list[int]:
"""Return a list of all entity ids, if applicable."""
res = await self._paperless.request("get", self.endpoint)
res = await self._paperless.request_json("get", self.endpoint)
if "all" in res:
return res["all"]

Expand All @@ -70,7 +70,7 @@ async def get(
if "page_size" not in kwargs:
kwargs["page_size"] = self.request_page_size

res = await self._paperless.request("get", self.endpoint, params=kwargs)
res = await self._paperless.request_json("get", self.endpoint, params=kwargs)
return PaginatedResult(
kwargs["page"],
kwargs["page"] + 1 if res["next"] else None,
Expand All @@ -91,7 +91,7 @@ async def iterate(self, **kwargs: dict[str, Any]) -> Generator[RT, None, None]:
async def one(self, idx: int) -> RT:
"""Request exactly one entity by id."""
url = f"{self.endpoint}/{idx}"
res = await self._paperless.request("get", url)
res = await self._paperless.request_json("get", url)
return dataclass_from_dict(self.endpoint_cls, res)


Expand All @@ -100,7 +100,7 @@ class BaseEndpointCrudMixin:

async def create(self: BaseEndpoint, obj: PaperlessPost) -> RT:
"""Create a new entity. Raise on failure."""
res = await self._paperless.request(
res = await self._paperless.request_json(
"post",
self.endpoint,
json=dataclass_to_dict(obj),
Expand All @@ -110,15 +110,15 @@ async def create(self: BaseEndpoint, obj: PaperlessPost) -> RT:
async def update(self: BaseEndpoint, obj: RT) -> RT:
"""Update an existing entity. Raise on failure."""
url = f"{self.endpoint}/{obj.id}"
res = await self._paperless.request(
res = await self._paperless.request_json(
"put", url, json=dataclass_to_dict(obj, skip_none=False)
)
return dataclass_from_dict(self.endpoint_cls, res)

async def delete(self: BaseEndpoint, obj: RT) -> bool:
"""Delete an existing entity. Raise on failure."""
url = f"{self.endpoint}/{obj.id}"
await self._paperless.request("delete", url)
await self._paperless.request_json("delete", url)
return True


Expand Down
12 changes: 6 additions & 6 deletions pypaperless/api/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def get(self, obj: DocumentOrIdType) -> list[DocumentNote]:
"""Request document notes of given document."""
idx = _get_document_id_helper(obj)
url = f"{self.endpoint}/{idx}/notes"
res = await self._paperless.request("get", url)
res = await self._paperless.request_json("get", url)

# We have to transform data here slightly.
# There are two major differences in the data depending on which endpoint is requested.
Expand All @@ -54,15 +54,15 @@ async def get(self, obj: DocumentOrIdType) -> list[DocumentNote]:
async def create(self, obj: DocumentNotePost) -> None:
"""Create a new document note. Raises on failure."""
url = f"{self.endpoint}/{obj.document}/notes"
await self._paperless.request("post", url, json=dataclass_to_dict(obj))
await self._paperless.request_json("post", url, json=dataclass_to_dict(obj))

async def delete(self, obj: DocumentNote) -> None:
"""Delete an existing document note. Raises on failure."""
url = f"{self.endpoint}/{obj.document}/notes"
params = {
"id": obj.id,
}
await self._paperless.request("delete", url, params=params)
await self._paperless.request_json("delete", url, params=params)


class DocumentFilesService(BaseService):
Expand All @@ -75,7 +75,7 @@ async def _get_data(
) -> bytes:
"""Request a child endpoint."""
url = f"{self.endpoint}/{idx}/{path}"
return await self._paperless.request("get", url)
return await self._paperless.request_file("get", url)

async def download(self, obj: DocumentOrIdType) -> bytes:
"""Request document endpoint for downloading the actual file."""
Expand Down Expand Up @@ -126,12 +126,12 @@ async def create(self, obj: DocumentPost) -> str:
form.add_field("tags", f"{tag}")

url = f"{self.endpoint}/post_document/"
res = await self._paperless.request("post", url, data=form)
res = await self._paperless.request_json("post", url, data=form)
return str(res)

async def meta(self, obj: DocumentOrIdType) -> RT:
"""Request document metadata of given document."""
idx = _get_document_id_helper(obj)
url = f"{self.endpoint}/{idx}/metadata"
res = await self._paperless.request("get", url)
res = await self._paperless.request_json("get", url)
return dataclass_from_dict(DocumentMetaInformation, res)
4 changes: 2 additions & 2 deletions pypaperless/api/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def get(
**kwargs: dict[str, Any],
) -> list[RT]:
"""Request entities."""
res = await self._paperless.request("get", self.endpoint, params=kwargs)
res = await self._paperless.request_json("get", self.endpoint, params=kwargs)
return [dataclass_from_dict(self.endpoint_cls, item) for item in res]

async def iterate(
Expand All @@ -38,5 +38,5 @@ async def one(self, idx: str) -> RT:
params = {
"task_id": idx,
}
res = await self._paperless.request("get", self.endpoint, params=params)
res = await self._paperless.request_json("get", self.endpoint, params=params)
return dataclass_from_dict(self.endpoint_cls, res.pop())
8 changes: 6 additions & 2 deletions pypaperless/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@


class PaperlessException(Exception):
"""Base exception for paperless."""
"""Base exception for PyPaperless."""


class BadRequestException(PaperlessException):
"""Raised when requesting wrong data."""
"""Raise when requesting wrong data."""


class DataNotExpectedException(PaperlessException):
"""Raise when expecting a type and receiving something else."""
3 changes: 1 addition & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ def endpoints_data():
d = load_fixture_data("data.json")
return d["endpoints"]

with patch.object(api, "request", return_value=endpoints_data()):
with patch.object(api, "request_json", return_value=endpoints_data()):
await api.initialize()

yield api
await api.close()
6 changes: 3 additions & 3 deletions tests/pypaperless/test_consumption_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async def test_endpoint(paperless: Paperless) -> None:

async def test_list_and_get(paperless: Paperless, data):
"""Test list."""
with patch.object(paperless, "request", return_value=data["consumption_templates"]):
with patch.object(paperless, "request_json", return_value=data["consumption_templates"]):
result = await paperless.consumption_templates.list()

assert isinstance(result, list)
Expand All @@ -34,15 +34,15 @@ async def test_list_and_get(paperless: Paperless, data):

async def test_iterate(paperless: Paperless, data):
"""Test iterate."""
with patch.object(paperless, "request", return_value=data["consumption_templates"]):
with patch.object(paperless, "request_json", return_value=data["consumption_templates"]):
async for item in paperless.consumption_templates.iterate():
assert isinstance(item, ConsumptionTemplate)


async def test_one(paperless: Paperless, data):
"""Test one."""
with patch.object(
paperless, "request", return_value=data["consumption_templates"]["results"][0]
paperless, "request_json", return_value=data["consumption_templates"]["results"][0]
):
item = await paperless.consumption_templates.one(72)

Expand Down
6 changes: 3 additions & 3 deletions tests/pypaperless/test_correspondents.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async def test_endpoint(paperless: Paperless) -> None:

async def test_list_and_get(paperless: Paperless, data):
"""Test list."""
with patch.object(paperless, "request", return_value=data["correspondents"]):
with patch.object(paperless, "request_json", return_value=data["correspondents"]):
result = await paperless.correspondents.list()

assert isinstance(result, list)
Expand All @@ -34,14 +34,14 @@ async def test_list_and_get(paperless: Paperless, data):

async def test_iterate(paperless: Paperless, data):
"""Test iterate."""
with patch.object(paperless, "request", return_value=data["correspondents"]):
with patch.object(paperless, "request_json", return_value=data["correspondents"]):
async for item in paperless.correspondents.iterate():
assert isinstance(item, Correspondent)


async def test_one(paperless: Paperless, data):
"""Test one."""
with patch.object(paperless, "request", return_value=data["correspondents"]["results"][0]):
with patch.object(paperless, "request_json", return_value=data["correspondents"]["results"][0]):
item = await paperless.correspondents.one(72)

assert isinstance(item, Correspondent)
Expand Down
6 changes: 3 additions & 3 deletions tests/pypaperless/test_custom_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async def test_endpoint(paperless: Paperless) -> None:

async def test_list_and_get(paperless: Paperless, data):
"""Test list."""
with patch.object(paperless, "request", return_value=data["custom_fields"]):
with patch.object(paperless, "request_json", return_value=data["custom_fields"]):
result = await paperless.custom_fields.list()

assert isinstance(result, list)
Expand All @@ -34,14 +34,14 @@ async def test_list_and_get(paperless: Paperless, data):

async def test_iterate(paperless: Paperless, data):
"""Test iterate."""
with patch.object(paperless, "request", return_value=data["custom_fields"]):
with patch.object(paperless, "request_json", return_value=data["custom_fields"]):
async for item in paperless.custom_fields.iterate():
assert isinstance(item, CustomField)


async def test_one(paperless: Paperless, data):
"""Test one."""
with patch.object(paperless, "request", return_value=data["custom_fields"]["results"][0]):
with patch.object(paperless, "request_json", return_value=data["custom_fields"]["results"][0]):
item = await paperless.custom_fields.one(72)

assert isinstance(item, CustomField)
Expand Down
6 changes: 3 additions & 3 deletions tests/pypaperless/test_document_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async def test_endpoint(paperless: Paperless) -> None:

async def test_list_and_get(paperless: Paperless, data):
"""Test list."""
with patch.object(paperless, "request", return_value=data["document_types"]):
with patch.object(paperless, "request_json", return_value=data["document_types"]):
result = await paperless.document_types.list()

assert isinstance(result, list)
Expand All @@ -34,14 +34,14 @@ async def test_list_and_get(paperless: Paperless, data):

async def test_iterate(paperless: Paperless, data):
"""Test iterate."""
with patch.object(paperless, "request", return_value=data["document_types"]):
with patch.object(paperless, "request_json", return_value=data["document_types"]):
async for item in paperless.document_types.iterate():
assert isinstance(item, DocumentType)


async def test_one(paperless: Paperless, data):
"""Test one."""
with patch.object(paperless, "request", return_value=data["document_types"]["results"][0]):
with patch.object(paperless, "request_json", return_value=data["document_types"]["results"][0]):
item = await paperless.document_types.one(72)

assert isinstance(item, DocumentType)
Expand Down
6 changes: 3 additions & 3 deletions tests/pypaperless/test_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async def test_endpoint(paperless: Paperless) -> None:

async def test_list_and_get(paperless: Paperless, data):
"""Test list."""
with patch.object(paperless, "request", return_value=data["documents"]):
with patch.object(paperless, "request_json", return_value=data["documents"]):
result = await paperless.documents.list()

assert isinstance(result, list)
Expand All @@ -34,14 +34,14 @@ async def test_list_and_get(paperless: Paperless, data):

async def test_iterate(paperless: Paperless, data):
"""Test iterate."""
with patch.object(paperless, "request", return_value=data["documents"]):
with patch.object(paperless, "request_json", return_value=data["documents"]):
async for item in paperless.documents.iterate():
assert isinstance(item, Document)


async def test_one(paperless: Paperless, data):
"""Test one."""
with patch.object(paperless, "request", return_value=data["documents"]["results"][0]):
with patch.object(paperless, "request_json", return_value=data["documents"]["results"][0]):
item = await paperless.documents.one(72)

assert isinstance(item, Document)
Expand Down
Loading

0 comments on commit 8217817

Please sign in to comment.