Skip to content

Commit

Permalink
Add type annotations (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
jarekwg authored Nov 25, 2024
1 parent 16bb749 commit 21b941c
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 66 deletions.
12 changes: 9 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ packages = ["src/myob"]
[tool.ruff]
target-version = "py312"
line-length = 100
exclude = [
"build",
"dist",
"tests"
]

[tool.ruff.lint]
select = [
Expand All @@ -49,17 +54,18 @@ select = [
"I", # isort
"N", # pep8-naming
"UP", # pyupgrade
# "ANN", # flake8-annotations
"ANN", # flake8-annotations
"B", # flake8-bugbear
"S", # flake8-bandit
"T10", # debugger
"TID", # flake8-tidy-imports
]
ignore = [
"E501"
"E501",
"ANN401",
]

[tool.ruff.lint.isort]
extra-standard-library = [
"requests",
]
]
30 changes: 16 additions & 14 deletions src/myob/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from .credentials import PartnerCredentials
from .endpoints import ALL, ENDPOINTS, GET
from .managers import Manager
Expand All @@ -6,7 +8,7 @@
class Myob:
"""An ORM-like interface to the MYOB API."""

def __init__(self, credentials):
def __init__(self, credentials: PartnerCredentials) -> None:
if not isinstance(credentials, PartnerCredentials):
raise TypeError(f"Expected a Credentials instance, got {type(credentials).__name__}.")
self.credentials = credentials
Expand All @@ -23,16 +25,16 @@ def __init__(self, credentials):
],
)

def info(self):
return self._manager.info()
def info(self) -> str:
return self._manager.info() # type: ignore[attr-defined]

def __repr__(self):
def __repr__(self) -> str:
options = "\n ".join(["companyfiles", "info"])
return f"Myob:\n {options}"


class CompanyFiles:
def __init__(self, credentials):
def __init__(self, credentials: PartnerCredentials) -> None:
self.credentials = credentials
self._manager = Manager(
"",
Expand All @@ -44,42 +46,42 @@ def __init__(self, credentials):
)
self._manager.name = "CompanyFile"

def all(self):
raw_companyfiles = self._manager.all()
def all(self) -> list["CompanyFile"]:
raw_companyfiles = self._manager.all() # type: ignore[attr-defined]
return [
CompanyFile(raw_companyfile, self.credentials) for raw_companyfile in raw_companyfiles
]

def get(self, id, call=True):
def get(self, id: str, call: bool = True) -> "CompanyFile":
if call:
# raw_companyfile = self._manager.get(id=id)['CompanyFile']
# NOTE: Annoyingly, we need to pass company_id to the manager, else we won't have permission
# on the GET endpoint. The only way we currently allow passing company_id is by setting it on the manager,
# and we can't do that on init, as this is a manager for company files plural..
# Reluctant to change manager code, as it would add confusion if the inner method let you override the company_id.
manager = Manager("", self.credentials, raw_endpoints=[(GET, "", "")], company_id=id)
raw_companyfile = manager.get()["CompanyFile"]
raw_companyfile = manager.get()["CompanyFile"] # type: ignore[attr-defined]
else:
raw_companyfile = {"Id": id}
return CompanyFile(raw_companyfile, self.credentials)

def __repr__(self):
def __repr__(self) -> str:
return self._manager.__repr__()


class CompanyFile:
def __init__(self, raw, credentials):
def __init__(self, raw: dict[str, Any], credentials: PartnerCredentials) -> None:
self.id = raw["Id"]
self.name = raw.get("Name")
self.data = raw # Dump remaining raw data here.
self.credentials = credentials
for k, v in ENDPOINTS.items():
setattr(
self,
v["name"],
v["name"], # type: ignore[arg-type]
Manager(k, credentials, endpoints=v["methods"], company_id=self.id),
)

def __repr__(self):
options = "\n ".join(sorted(v["name"] for v in ENDPOINTS.values()))
def __repr__(self) -> str:
options = "\n ".join(sorted(v["name"] for v in ENDPOINTS.values())) # type: ignore[misc]
return f"CompanyFile:\n {options}"
45 changes: 23 additions & 22 deletions src/myob/credentials.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import datetime
from datetime import datetime, timedelta
from typing import Any

from requests_oauthlib import OAuth2Session

Expand All @@ -11,17 +12,17 @@ class PartnerCredentials:

def __init__(
self,
consumer_key,
consumer_secret,
callback_uri,
verified=False,
companyfile_credentials={}, # noqa: B006
oauth_token=None,
refresh_token=None,
oauth_expires_at=None,
scope=None,
state=None,
):
consumer_key: str,
consumer_secret: str,
callback_uri: str,
verified: bool = False,
companyfile_credentials: dict[str, str] = {}, # noqa: B006
oauth_token: str | None = None,
refresh_token: str | None = None,
oauth_expires_at: datetime | None = None,
scope: None = None, # TODO: Review if used.
state: str | None = None,
) -> None:
self.consumer_key = consumer_key
self.consumer_secret = consumer_secret
self.callback_uri = callback_uri
Expand All @@ -32,7 +33,7 @@ def __init__(
self.refresh_token = refresh_token

if oauth_expires_at is not None:
if not isinstance(oauth_expires_at, datetime.datetime):
if not isinstance(oauth_expires_at, datetime):
raise ValueError("'oauth_expires_at' must be a datetime instance.")
self.oauth_expires_at = oauth_expires_at

Expand All @@ -42,13 +43,13 @@ def __init__(

# TODO: Add `verify` kwarg here, which will quickly throw the provided credentials at a
# protected endpoint to ensure they are valid. If not, raise appropriate error.
def authenticate_companyfile(self, company_id, username, password):
def authenticate_companyfile(self, company_id: str, username: str, password: str) -> None:
"""Store hashed username-password for logging into company file."""
userpass = base64.b64encode(bytes(f"{username}:{password}", "utf-8")).decode("utf-8")
self.companyfile_credentials[company_id] = userpass

@property
def state(self):
def state(self) -> dict[str, Any]:
"""Get a representation of this credentials object from which it can be reconstructed."""
return {
attr: getattr(self, attr)
Expand All @@ -65,7 +66,7 @@ def state(self):
if getattr(self, attr) is not None
}

def expired(self, now=None):
def expired(self, now: datetime | None = None) -> bool:
"""Determine whether the current access token has expired."""
# Expiry might be unset if the user hasn't finished authenticating.
if self.oauth_expires_at is None:
Expand All @@ -76,10 +77,10 @@ def expired(self, now=None):
# they can use self.oauth_expires_at
CONSERVATIVE_SECONDS = 30 # noqa: N806

now = now or datetime.datetime.now()
return self.oauth_expires_at <= (now + datetime.timedelta(seconds=CONSERVATIVE_SECONDS))
now = now or datetime.now()
return self.oauth_expires_at <= (now + timedelta(seconds=CONSERVATIVE_SECONDS))

def verify(self, code):
def verify(self, code: str) -> None:
"""Verify an OAuth session, retrieving an access token."""
token = self._oauth.fetch_token(
MYOB_PARTNER_BASE_URL + ACCESS_TOKEN_URL,
Expand All @@ -89,7 +90,7 @@ def verify(self, code):
)
self.save_token(token)

def refresh(self):
def refresh(self) -> None:
"""Refresh an expired token."""
token = self._oauth.refresh_token(
MYOB_PARTNER_BASE_URL + ACCESS_TOKEN_URL,
Expand All @@ -99,9 +100,9 @@ def refresh(self):
)
self.save_token(token)

def save_token(self, token):
def save_token(self, token: dict) -> None:
self.oauth_token = token.get("access_token")
self.refresh_token = token.get("refresh_token")

self.oauth_expires_at = datetime.datetime.fromtimestamp(token.get("expires_at"))
self.oauth_expires_at = datetime.fromtimestamp(token.get("expires_at")) # type: ignore[arg-type]
self.verified = True
13 changes: 7 additions & 6 deletions src/myob/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from .types import Method
from .utils import pluralise

ALL = "ALL"
GET = "GET" # this method expects a UID as a keyword
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
ALL: Method = "ALL"
GET: Method = "GET" # this method expects a UID as a keyword
POST: Method = "POST"
PUT: Method = "PUT"
DELETE: Method = "DELETE"
CRUD = "CRUD" # shorthand for creating the ALL|GET|POST|PUT|DELETE endpoints in one swoop

METHOD_ORDER = [ALL, GET, POST, PUT, DELETE]
METHOD_ORDER: list[Method] = [ALL, GET, POST, PUT, DELETE]

ENDPOINTS = {
"Banking/": {
Expand Down
5 changes: 4 additions & 1 deletion src/myob/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from requests import Response


class MyobException(Exception): # noqa: N818
def __init__(self, response, msg=None):
def __init__(self, response: Response, msg: str | None = None) -> None:
self.response = response
try:
self.errors = response.json()["Errors"]
Expand Down
48 changes: 29 additions & 19 deletions src/myob/managers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import re
import requests
from datetime import date
from typing import Any

from .constants import DEFAULT_PAGE_SIZE, MYOB_BASE_URL
from .endpoints import CRUD, METHOD_MAPPING, METHOD_ORDER
from .credentials import PartnerCredentials
from .endpoints import ALL, CRUD, GET, METHOD_MAPPING, METHOD_ORDER, POST, PUT, Method
from .exceptions import (
MyobBadRequest,
MyobConflict,
Expand All @@ -15,18 +17,26 @@
MyobRateLimitExceeded,
MyobUnauthorized,
)
from .types import MethodDetails


class Manager:
def __init__(self, name, credentials, company_id=None, endpoints=[], raw_endpoints=[]): # noqa: B006
def __init__(
self,
name: str,
credentials: PartnerCredentials,
company_id: str | None = None,
endpoints: list = [], # noqa: B006
raw_endpoints: list = [], # noqa: B006
) -> None:
self.credentials = credentials
self.name = "_".join(p for p in name.rstrip("/").split("/") if "[" not in p)
self.base_url = MYOB_BASE_URL
if company_id is not None:
self.base_url += company_id + "/"
if name:
self.base_url += name
self.method_details = {}
self.method_details: dict[str, MethodDetails] = {}
self.company_id = company_id

# Build ORM methods from given url endpoints.
Expand All @@ -48,16 +58,16 @@ def __init__(self, name, credentials, company_id=None, endpoints=[], raw_endpoin
for method, endpoint, hint in raw_endpoints:
self.build_method(method, endpoint, hint)

def build_method(self, method, endpoint, hint):
def build_method(self, method: Method, endpoint: str, hint: str) -> None:
full_endpoint = self.base_url + endpoint
url_keys = re.findall(r"\[([^\]]*)\]", full_endpoint)
template = full_endpoint.replace("[", "{").replace("]", "}")

required_kwargs = url_keys.copy()
if method in ("PUT", "POST"):
if method in (PUT, POST):
required_kwargs.append("data")

def inner(*args, timeout=None, **kwargs):
def inner(*args: Any, timeout: int | None = None, **kwargs: Any) -> str | dict:
if args:
raise AttributeError("Unnamed args provided. Only keyword args accepted.")

Expand All @@ -78,7 +88,7 @@ def inner(*args, timeout=None, **kwargs):
request_kwargs_raw[k] = v

# Determine request method.
request_method = "GET" if method == "ALL" else method
request_method = GET if method == ALL else method

# Build url.
url = template.format(**url_kwargs)
Expand Down Expand Up @@ -130,13 +140,13 @@ def inner(*args, timeout=None, **kwargs):
# If it already exists, prepend with method to disambiguate.
elif hasattr(self, method_name):
method_name = f"{method.lower()}_{method_name}"
self.method_details[method_name] = {
"kwargs": required_kwargs,
"hint": hint,
}
self.method_details[method_name] = MethodDetails(
kwargs=required_kwargs,
hint=hint,
)
setattr(self, method_name, inner)

def build_request_kwargs(self, method, data=None, **kwargs):
def build_request_kwargs(self, method: Method, data: dict | None = None, **kwargs: Any) -> dict:
request_kwargs = {}

# Build headers.
Expand Down Expand Up @@ -166,7 +176,7 @@ def build_request_kwargs(self, method, data=None, **kwargs):
request_kwargs["params"] = {}
filters = []

def build_value(value):
def build_value(value: Any) -> str:
if issubclass(type(value), date):
return f"datetime'{value}'"
if isinstance(value, bool):
Expand Down Expand Up @@ -205,10 +215,10 @@ def build_value(value):
page_size = DEFAULT_PAGE_SIZE
if "limit" in kwargs:
page_size = int(kwargs["limit"])
request_kwargs["params"]["$top"] = page_size
request_kwargs["params"]["$top"] = page_size # type: ignore[assignment]

if "page" in kwargs:
request_kwargs["params"]["$skip"] = (int(kwargs["page"]) - 1) * page_size
request_kwargs["params"]["$skip"] = (int(kwargs["page"]) - 1) * page_size # type: ignore[assignment]

if "format" in kwargs:
request_kwargs["params"]["format"] = kwargs["format"]
Expand All @@ -225,16 +235,16 @@ def build_value(value):

return request_kwargs

def __repr__(self):
def _get_signature(name, kwargs):
def __repr__(self) -> str:
def _get_signature(name: str, kwargs: list[str]) -> str:
return f"{name}({', '.join(kwargs)})"

def print_method(name, kwargs, hint, offset):
def _print_method(name: str, kwargs: list[str], hint: str, offset: int) -> str:
return f"{_get_signature(name, kwargs):>{offset}} - {hint}"

offset = max(len(_get_signature(k, v["kwargs"])) for k, v in self.method_details.items())
options = "\n ".join(
print_method(k, v["kwargs"], v["hint"], offset)
_print_method(k, v["kwargs"], v["hint"], offset)
for k, v in sorted(self.method_details.items())
)
return f"{self.name}{self.__class__.__name__}:\n {options}"
Loading

0 comments on commit 21b941c

Please sign in to comment.