From 21b941c46cb7c13258777b5dd677465824f556f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jarek=20G=C5=82owacki?= Date: Mon, 25 Nov 2024 13:44:03 +1100 Subject: [PATCH] Add type annotations (#107) --- pyproject.toml | 12 ++++++++--- src/myob/api.py | 30 ++++++++++++++------------ src/myob/credentials.py | 45 +++++++++++++++++++------------------- src/myob/endpoints.py | 13 +++++------ src/myob/exceptions.py | 5 ++++- src/myob/managers.py | 48 +++++++++++++++++++++++++---------------- src/myob/types.py | 9 ++++++++ src/myob/utils.py | 2 +- 8 files changed, 98 insertions(+), 66 deletions(-) create mode 100644 src/myob/types.py diff --git a/pyproject.toml b/pyproject.toml index a3e7bb9..0a76c95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,11 @@ packages = ["src/myob"] [tool.ruff] target-version = "py312" line-length = 100 +exclude = [ + "build", + "dist", + "tests" +] [tool.ruff.lint] select = [ @@ -49,17 +54,18 @@ select = [ "I", # isort "N", # pep8-naming "UP", # pyupgrade - # "ANN", # flake8-annotations + "ANN", # flake8-annotations "B", # flake8-bugbear "S", # flake8-bandit "T10", # debugger "TID", # flake8-tidy-imports ] ignore = [ - "E501" + "E501", + "ANN401", ] [tool.ruff.lint.isort] extra-standard-library = [ "requests", -] \ No newline at end of file +] diff --git a/src/myob/api.py b/src/myob/api.py index d8cc9a6..6acf520 100755 --- a/src/myob/api.py +++ b/src/myob/api.py @@ -1,3 +1,5 @@ +from typing import Any + from .credentials import PartnerCredentials from .endpoints import ALL, ENDPOINTS, GET from .managers import Manager @@ -6,7 +8,7 @@ class Myob: """An ORM-like interface to the MYOB API.""" - def __init__(self, credentials): + def __init__(self, credentials: PartnerCredentials) -> None: if not isinstance(credentials, PartnerCredentials): raise TypeError(f"Expected a Credentials instance, got {type(credentials).__name__}.") self.credentials = credentials @@ -23,16 +25,16 @@ def __init__(self, credentials): ], ) - def info(self): - return self._manager.info() + def info(self) -> str: + return self._manager.info() # type: ignore[attr-defined] - def __repr__(self): + def __repr__(self) -> str: options = "\n ".join(["companyfiles", "info"]) return f"Myob:\n {options}" class CompanyFiles: - def __init__(self, credentials): + def __init__(self, credentials: PartnerCredentials) -> None: self.credentials = credentials self._manager = Manager( "", @@ -44,13 +46,13 @@ def __init__(self, credentials): ) self._manager.name = "CompanyFile" - def all(self): - raw_companyfiles = self._manager.all() + def all(self) -> list["CompanyFile"]: + raw_companyfiles = self._manager.all() # type: ignore[attr-defined] return [ CompanyFile(raw_companyfile, self.credentials) for raw_companyfile in raw_companyfiles ] - def get(self, id, call=True): + def get(self, id: str, call: bool = True) -> "CompanyFile": if call: # raw_companyfile = self._manager.get(id=id)['CompanyFile'] # NOTE: Annoyingly, we need to pass company_id to the manager, else we won't have permission @@ -58,17 +60,17 @@ def get(self, id, call=True): # and we can't do that on init, as this is a manager for company files plural.. # Reluctant to change manager code, as it would add confusion if the inner method let you override the company_id. manager = Manager("", self.credentials, raw_endpoints=[(GET, "", "")], company_id=id) - raw_companyfile = manager.get()["CompanyFile"] + raw_companyfile = manager.get()["CompanyFile"] # type: ignore[attr-defined] else: raw_companyfile = {"Id": id} return CompanyFile(raw_companyfile, self.credentials) - def __repr__(self): + def __repr__(self) -> str: return self._manager.__repr__() class CompanyFile: - def __init__(self, raw, credentials): + def __init__(self, raw: dict[str, Any], credentials: PartnerCredentials) -> None: self.id = raw["Id"] self.name = raw.get("Name") self.data = raw # Dump remaining raw data here. @@ -76,10 +78,10 @@ def __init__(self, raw, credentials): for k, v in ENDPOINTS.items(): setattr( self, - v["name"], + v["name"], # type: ignore[arg-type] Manager(k, credentials, endpoints=v["methods"], company_id=self.id), ) - def __repr__(self): - options = "\n ".join(sorted(v["name"] for v in ENDPOINTS.values())) + def __repr__(self) -> str: + options = "\n ".join(sorted(v["name"] for v in ENDPOINTS.values())) # type: ignore[misc] return f"CompanyFile:\n {options}" diff --git a/src/myob/credentials.py b/src/myob/credentials.py index 8f96534..02cfbe5 100755 --- a/src/myob/credentials.py +++ b/src/myob/credentials.py @@ -1,5 +1,6 @@ import base64 -import datetime +from datetime import datetime, timedelta +from typing import Any from requests_oauthlib import OAuth2Session @@ -11,17 +12,17 @@ class PartnerCredentials: def __init__( self, - consumer_key, - consumer_secret, - callback_uri, - verified=False, - companyfile_credentials={}, # noqa: B006 - oauth_token=None, - refresh_token=None, - oauth_expires_at=None, - scope=None, - state=None, - ): + consumer_key: str, + consumer_secret: str, + callback_uri: str, + verified: bool = False, + companyfile_credentials: dict[str, str] = {}, # noqa: B006 + oauth_token: str | None = None, + refresh_token: str | None = None, + oauth_expires_at: datetime | None = None, + scope: None = None, # TODO: Review if used. + state: str | None = None, + ) -> None: self.consumer_key = consumer_key self.consumer_secret = consumer_secret self.callback_uri = callback_uri @@ -32,7 +33,7 @@ def __init__( self.refresh_token = refresh_token if oauth_expires_at is not None: - if not isinstance(oauth_expires_at, datetime.datetime): + if not isinstance(oauth_expires_at, datetime): raise ValueError("'oauth_expires_at' must be a datetime instance.") self.oauth_expires_at = oauth_expires_at @@ -42,13 +43,13 @@ def __init__( # TODO: Add `verify` kwarg here, which will quickly throw the provided credentials at a # protected endpoint to ensure they are valid. If not, raise appropriate error. - def authenticate_companyfile(self, company_id, username, password): + def authenticate_companyfile(self, company_id: str, username: str, password: str) -> None: """Store hashed username-password for logging into company file.""" userpass = base64.b64encode(bytes(f"{username}:{password}", "utf-8")).decode("utf-8") self.companyfile_credentials[company_id] = userpass @property - def state(self): + def state(self) -> dict[str, Any]: """Get a representation of this credentials object from which it can be reconstructed.""" return { attr: getattr(self, attr) @@ -65,7 +66,7 @@ def state(self): if getattr(self, attr) is not None } - def expired(self, now=None): + def expired(self, now: datetime | None = None) -> bool: """Determine whether the current access token has expired.""" # Expiry might be unset if the user hasn't finished authenticating. if self.oauth_expires_at is None: @@ -76,10 +77,10 @@ def expired(self, now=None): # they can use self.oauth_expires_at CONSERVATIVE_SECONDS = 30 # noqa: N806 - now = now or datetime.datetime.now() - return self.oauth_expires_at <= (now + datetime.timedelta(seconds=CONSERVATIVE_SECONDS)) + now = now or datetime.now() + return self.oauth_expires_at <= (now + timedelta(seconds=CONSERVATIVE_SECONDS)) - def verify(self, code): + def verify(self, code: str) -> None: """Verify an OAuth session, retrieving an access token.""" token = self._oauth.fetch_token( MYOB_PARTNER_BASE_URL + ACCESS_TOKEN_URL, @@ -89,7 +90,7 @@ def verify(self, code): ) self.save_token(token) - def refresh(self): + def refresh(self) -> None: """Refresh an expired token.""" token = self._oauth.refresh_token( MYOB_PARTNER_BASE_URL + ACCESS_TOKEN_URL, @@ -99,9 +100,9 @@ def refresh(self): ) self.save_token(token) - def save_token(self, token): + def save_token(self, token: dict) -> None: self.oauth_token = token.get("access_token") self.refresh_token = token.get("refresh_token") - self.oauth_expires_at = datetime.datetime.fromtimestamp(token.get("expires_at")) + self.oauth_expires_at = datetime.fromtimestamp(token.get("expires_at")) # type: ignore[arg-type] self.verified = True diff --git a/src/myob/endpoints.py b/src/myob/endpoints.py index d8e4c0b..19f8321 100755 --- a/src/myob/endpoints.py +++ b/src/myob/endpoints.py @@ -1,13 +1,14 @@ +from .types import Method from .utils import pluralise -ALL = "ALL" -GET = "GET" # this method expects a UID as a keyword -POST = "POST" -PUT = "PUT" -DELETE = "DELETE" +ALL: Method = "ALL" +GET: Method = "GET" # this method expects a UID as a keyword +POST: Method = "POST" +PUT: Method = "PUT" +DELETE: Method = "DELETE" CRUD = "CRUD" # shorthand for creating the ALL|GET|POST|PUT|DELETE endpoints in one swoop -METHOD_ORDER = [ALL, GET, POST, PUT, DELETE] +METHOD_ORDER: list[Method] = [ALL, GET, POST, PUT, DELETE] ENDPOINTS = { "Banking/": { diff --git a/src/myob/exceptions.py b/src/myob/exceptions.py index 889021c..c348be2 100644 --- a/src/myob/exceptions.py +++ b/src/myob/exceptions.py @@ -1,5 +1,8 @@ +from requests import Response + + class MyobException(Exception): # noqa: N818 - def __init__(self, response, msg=None): + def __init__(self, response: Response, msg: str | None = None) -> None: self.response = response try: self.errors = response.json()["Errors"] diff --git a/src/myob/managers.py b/src/myob/managers.py index fc7338d..1c5d41d 100755 --- a/src/myob/managers.py +++ b/src/myob/managers.py @@ -1,9 +1,11 @@ import re import requests from datetime import date +from typing import Any from .constants import DEFAULT_PAGE_SIZE, MYOB_BASE_URL -from .endpoints import CRUD, METHOD_MAPPING, METHOD_ORDER +from .credentials import PartnerCredentials +from .endpoints import ALL, CRUD, GET, METHOD_MAPPING, METHOD_ORDER, POST, PUT, Method from .exceptions import ( MyobBadRequest, MyobConflict, @@ -15,10 +17,18 @@ MyobRateLimitExceeded, MyobUnauthorized, ) +from .types import MethodDetails class Manager: - def __init__(self, name, credentials, company_id=None, endpoints=[], raw_endpoints=[]): # noqa: B006 + def __init__( + self, + name: str, + credentials: PartnerCredentials, + company_id: str | None = None, + endpoints: list = [], # noqa: B006 + raw_endpoints: list = [], # noqa: B006 + ) -> None: self.credentials = credentials self.name = "_".join(p for p in name.rstrip("/").split("/") if "[" not in p) self.base_url = MYOB_BASE_URL @@ -26,7 +36,7 @@ def __init__(self, name, credentials, company_id=None, endpoints=[], raw_endpoin self.base_url += company_id + "/" if name: self.base_url += name - self.method_details = {} + self.method_details: dict[str, MethodDetails] = {} self.company_id = company_id # Build ORM methods from given url endpoints. @@ -48,16 +58,16 @@ def __init__(self, name, credentials, company_id=None, endpoints=[], raw_endpoin for method, endpoint, hint in raw_endpoints: self.build_method(method, endpoint, hint) - def build_method(self, method, endpoint, hint): + def build_method(self, method: Method, endpoint: str, hint: str) -> None: full_endpoint = self.base_url + endpoint url_keys = re.findall(r"\[([^\]]*)\]", full_endpoint) template = full_endpoint.replace("[", "{").replace("]", "}") required_kwargs = url_keys.copy() - if method in ("PUT", "POST"): + if method in (PUT, POST): required_kwargs.append("data") - def inner(*args, timeout=None, **kwargs): + def inner(*args: Any, timeout: int | None = None, **kwargs: Any) -> str | dict: if args: raise AttributeError("Unnamed args provided. Only keyword args accepted.") @@ -78,7 +88,7 @@ def inner(*args, timeout=None, **kwargs): request_kwargs_raw[k] = v # Determine request method. - request_method = "GET" if method == "ALL" else method + request_method = GET if method == ALL else method # Build url. url = template.format(**url_kwargs) @@ -130,13 +140,13 @@ def inner(*args, timeout=None, **kwargs): # If it already exists, prepend with method to disambiguate. elif hasattr(self, method_name): method_name = f"{method.lower()}_{method_name}" - self.method_details[method_name] = { - "kwargs": required_kwargs, - "hint": hint, - } + self.method_details[method_name] = MethodDetails( + kwargs=required_kwargs, + hint=hint, + ) setattr(self, method_name, inner) - def build_request_kwargs(self, method, data=None, **kwargs): + def build_request_kwargs(self, method: Method, data: dict | None = None, **kwargs: Any) -> dict: request_kwargs = {} # Build headers. @@ -166,7 +176,7 @@ def build_request_kwargs(self, method, data=None, **kwargs): request_kwargs["params"] = {} filters = [] - def build_value(value): + def build_value(value: Any) -> str: if issubclass(type(value), date): return f"datetime'{value}'" if isinstance(value, bool): @@ -205,10 +215,10 @@ def build_value(value): page_size = DEFAULT_PAGE_SIZE if "limit" in kwargs: page_size = int(kwargs["limit"]) - request_kwargs["params"]["$top"] = page_size + request_kwargs["params"]["$top"] = page_size # type: ignore[assignment] if "page" in kwargs: - request_kwargs["params"]["$skip"] = (int(kwargs["page"]) - 1) * page_size + request_kwargs["params"]["$skip"] = (int(kwargs["page"]) - 1) * page_size # type: ignore[assignment] if "format" in kwargs: request_kwargs["params"]["format"] = kwargs["format"] @@ -225,16 +235,16 @@ def build_value(value): return request_kwargs - def __repr__(self): - def _get_signature(name, kwargs): + def __repr__(self) -> str: + def _get_signature(name: str, kwargs: list[str]) -> str: return f"{name}({', '.join(kwargs)})" - def print_method(name, kwargs, hint, offset): + def _print_method(name: str, kwargs: list[str], hint: str, offset: int) -> str: return f"{_get_signature(name, kwargs):>{offset}} - {hint}" offset = max(len(_get_signature(k, v["kwargs"])) for k, v in self.method_details.items()) options = "\n ".join( - print_method(k, v["kwargs"], v["hint"], offset) + _print_method(k, v["kwargs"], v["hint"], offset) for k, v in sorted(self.method_details.items()) ) return f"{self.name}{self.__class__.__name__}:\n {options}" diff --git a/src/myob/types.py b/src/myob/types.py new file mode 100644 index 0000000..7d13005 --- /dev/null +++ b/src/myob/types.py @@ -0,0 +1,9 @@ +from typing import Literal, TypedDict + +# TODO: This could probs do better as an enum.. +Method = Literal["ALL", "GET", "POST", "PUT", "DELETE"] + + +class MethodDetails(TypedDict): + kwargs: list[str] + hint: str diff --git a/src/myob/utils.py b/src/myob/utils.py index 1d98f50..b6f1842 100644 --- a/src/myob/utils.py +++ b/src/myob/utils.py @@ -1,4 +1,4 @@ -def pluralise(s): +def pluralise(s: str) -> str: if s.endswith("y"): return s[:-1] + "ies" elif s.endswith("rix"):