Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proper type annotation #409

Merged
merged 9 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions cuenca/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@
'get_balance',
]

from typing import cast

from . import http
from .resources import (
Account,
Expand Down Expand Up @@ -96,5 +94,5 @@


def get_balance(session: http.Session = session) -> int:
balance_entry = cast('BalanceEntry', BalanceEntry.first(session=session))
balance_entry = BalanceEntry.first(session=session)
return balance_entry.rolling_balance if balance_entry else 0
9 changes: 4 additions & 5 deletions cuenca/resources/api_keys.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime as dt
from typing import ClassVar, Optional, cast
from typing import ClassVar, Optional

from cuenca_validations.types import ApiKeyQuery, ApiKeyUpdateRequest

Expand Down Expand Up @@ -36,7 +36,7 @@ def active(self) -> bool:

@classmethod
def create(cls, *, session: Session = global_session) -> 'ApiKey':
return cast('ApiKey', cls._create(session=session))
return cls._create(session=session)

@classmethod
def deactivate(
Expand All @@ -55,7 +55,7 @@ def deactivate(
"""
url = cls._resource + f'/{api_key_id}'
resp = session.delete(url, dict(minutes=minutes))
return cast('ApiKey', cls._from_dict(resp))
return cls(**resp)

@classmethod
def update(
Expand All @@ -74,5 +74,4 @@ def update(
req = ApiKeyUpdateRequest(
metadata=metadata, user_id=user_id, platform_id=platform_id
)
resp = cls._update(api_key_id, **req.dict(), session=session)
return cast('ApiKey', resp)
return cls._update(api_key_id, **req.dict(), session=session)
4 changes: 2 additions & 2 deletions cuenca/resources/arpc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime as dt
from typing import ClassVar, Optional, cast
from typing import ClassVar, Optional

from cuenca_validations.types.requests import ARPCRequest

Expand Down Expand Up @@ -52,4 +52,4 @@ def create(
unique_number=unique_number,
track_data_method=track_data_method,
)
return cast('Arpc', cls._create(session=session, **req.dict()))
return cls._create(session=session, **req.dict())
117 changes: 66 additions & 51 deletions cuenca/resources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import datetime as dt
import json
from io import BytesIO
from typing import ClassVar, Dict, Generator, Optional, Union
from typing import Any, ClassVar, Generator, Optional, Type, TypeVar, cast
from urllib.parse import urlencode

from cuenca_validations.types import (
Expand All @@ -12,34 +12,21 @@
TransactionQuery,
TransactionStatus,
)
from pydantic import BaseModel
from pydantic import BaseModel, Extra

from ..exc import MultipleResultsFound, NoResultFound
from ..http import Session, session as global_session

R_co = TypeVar('R_co', bound='Resource', covariant=True)


class Resource(BaseModel):
_resource: ClassVar[str]

id: str

@classmethod
def _from_dict(cls, obj_dict: Dict[str, Union[str, int]]) -> 'Resource':
cls._filter_excess_fields(obj_dict)
return cls(**obj_dict)

@classmethod
def _filter_excess_fields(cls, obj_dict):
"""
dataclasses don't allow __init__ to be called with excess fields. This
method allows the API to add fields in the response body without
breaking the client
"""
excess = set(obj_dict.keys()) - set(
cls.schema().get("properties").keys()
)
for f in excess:
del obj_dict[f]
class Config:
extra = Extra.ignore

def to_dict(self):
return SantizedDict(self.dict())
Expand All @@ -48,22 +35,30 @@ def to_dict(self):
class Retrievable(Resource):
@classmethod
def retrieve(
cls, id: str, *, session: Session = global_session
) -> Resource:
cls: Type[R_co],
id: str,
*,
session: Session = global_session,
) -> R_co:
resp = session.get(f'/{cls._resource}/{id}')
return cls._from_dict(resp)
return cls(**resp)

def refresh(self, *, session: Session = global_session):
def refresh(self, *, session: Session = global_session) -> None:
new = self.retrieve(self.id, session=session)
for attr, value in new.__dict__.items():
setattr(self, attr, value)


class Creatable(Resource):
@classmethod
def _create(cls, *, session: Session = global_session, **data) -> Resource:
def _create(
cls: Type[R_co],
*,
session: Session = global_session,
**data: Any,
) -> R_co:
resp = session.post(cls._resource, data)
return cls._from_dict(resp)
return cls(**resp)


class Updateable(Resource):
Expand All @@ -72,31 +67,39 @@ class Updateable(Resource):

@classmethod
def _update(
cls, id: str, *, session: Session = global_session, **data
) -> Resource:
cls: Type[R_co],
id: str,
*,
session: Session = global_session,
**data: Any,
) -> R_co:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specify a more precise type for **data in _update method

**data: Any uses Any, which is disallowed (ANN401). Please provide a more specific type for data to improve type safety and comply with code standards.

resp = session.patch(f'/{cls._resource}/{id}', data)
return cls._from_dict(resp)
return cls(**resp)


class Deactivable(Resource):
deactivated_at: Optional[dt.datetime]

@classmethod
def deactivate(
cls, id: str, *, session: Session = global_session, **data
) -> Resource:
cls: Type[R_co],
id: str,
*,
session: Session = global_session,
**data: Any,
) -> R_co:
resp = session.delete(f'/{cls._resource}/{id}', data)
return cls._from_dict(resp)
return cls(**resp)

@property
def is_active(self):
def is_active(self) -> bool:
return not self.deactivated_at


class Downloadable(Resource):
@classmethod
def download(
cls,
cls: Type[R_co],
id: str,
file_format: FileFormat = FileFormat.any,
*,
Expand All @@ -121,13 +124,13 @@ def xml(self) -> bytes:
class Uploadable(Resource):
@classmethod
def _upload(
cls,
cls: Type[R_co],
file: bytes,
user_id: str,
*,
session: Session = global_session,
**data,
) -> Resource:
**data: Any,
) -> R_co:
encoded_file = base64.b64encode(file)
resp = session.request(
'post',
Expand All @@ -138,7 +141,7 @@ def _upload(
**{k: (None, v) for k, v in data.items()},
),
)
return cls._from_dict(json.loads(resp))
return cls(**json.loads(resp))


class Queryable(Resource):
Expand All @@ -148,50 +151,62 @@ class Queryable(Resource):

@classmethod
def one(
cls, *, session: Session = global_session, **query_params
) -> Resource:
q = cls._query_params(limit=2, **query_params)
cls: Type[R_co],
*,
session: Session = global_session,
**query_params: Any,
) -> R_co:
q = cast(Queryable, cls)._query_params(limit=2, **query_params)
resp = session.get(cls._resource, q.dict())
items = resp['items']
len_items = len(items)
if not len_items:
raise NoResultFound
if len_items > 1:
raise MultipleResultsFound
return cls._from_dict(items[0])
return cls(**items[0])

@classmethod
def first(
cls, *, session: Session = global_session, **query_params
) -> Optional[Resource]:
q = cls._query_params(limit=1, **query_params)
cls: Type[R_co],
*,
session: Session = global_session,
**query_params: Any,
) -> Optional[R_co]:
q = cast(Queryable, cls)._query_params(limit=1, **query_params)
resp = session.get(cls._resource, q.dict())
try:
item = resp['items'][0]
except IndexError:
rv = None
else:
rv = cls._from_dict(item)
rv = cls(**item)
return rv

@classmethod
def count(
cls, *, session: Session = global_session, **query_params
cls: Type[R_co],
*,
session: Session = global_session,
**query_params: Any,
) -> int:
q = cls._query_params(count=True, **query_params)
q = cast(Queryable, cls)._query_params(count=True, **query_params)
resp = session.get(cls._resource, q.dict())
return resp['count']

@classmethod
def all(
cls, *, session: Session = global_session, **query_params
) -> Generator[Resource, None, None]:
cls: Type[R_co],
*,
session: Session = global_session,
**query_params: Any,
) -> Generator[R_co, None, None]:
session = session or global_session
q = cls._query_params(**query_params)
q = cast(Queryable, cls)._query_params(**query_params)
next_page_uri = f'{cls._resource}?{urlencode(q.dict())}'
while next_page_uri:
page = session.get(next_page_uri)
yield from (cls._from_dict(item) for item in page['items'])
yield from (cls(**item) for item in page['items'])
next_page_uri = page['next_page_uri']


Expand Down
4 changes: 1 addition & 3 deletions cuenca/resources/card_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ def create(
exp_year=exp_year,
cvv2=cvv2,
)
return cast(
'CardActivation', cls._create(session=session, **req.dict())
)
return cls._create(session=session, **req.dict())

@property
def card(self) -> Optional[Card]:
Expand Down
4 changes: 1 addition & 3 deletions cuenca/resources/card_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ def create(
pin_block=pin_block,
pin_attempts_exceeded=pin_attempts_exceeded,
)
return cast(
'CardValidation', cls._create(session=session, **req.dict())
)
return cls._create(session=session, **req.dict())

@property
def card(self) -> Card:
Expand Down
9 changes: 4 additions & 5 deletions cuenca/resources/cards.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime as dt
from typing import ClassVar, Optional, cast
from typing import ClassVar, Optional

from cuenca_validations.types import (
CardFundingType,
Expand Down Expand Up @@ -81,7 +81,7 @@ def create(
card_holder_user_id=card_holder_user_id,
is_dynamic_cvv=is_dynamic_cvv,
)
return cast('Card', cls._create(session=session, **req.dict()))
return cls._create(session=session, **req.dict())

@classmethod
def update(
Expand All @@ -106,8 +106,7 @@ def update(
req = CardUpdateRequest(
status=status, pin_block=pin_block, is_dynamic_cvv=is_dynamic_cvv
)
resp = cls._update(card_id, session=session, **req.dict())
return cast('Card', resp)
return cls._update(card_id, session=session, **req.dict())

@classmethod
def deactivate(
Expand All @@ -118,4 +117,4 @@ def deactivate(
"""
url = f'{cls._resource}/{card_id}'
resp = session.delete(url)
return cast('Card', cls._from_dict(resp))
return cls(**resp)
4 changes: 2 additions & 2 deletions cuenca/resources/clabes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, cast
from typing import ClassVar

from ..http import Session, session as global_session
from .base import Creatable, Queryable, Retrievable
Expand All @@ -11,4 +11,4 @@ class Clabe(Creatable, Queryable, Retrievable):

@classmethod
def create(cls, session: Session = global_session):
return cast('Clabe', cls._create(session=session))
return cls._create(session=session)
7 changes: 2 additions & 5 deletions cuenca/resources/curp_validations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime as dt
from typing import ClassVar, Optional, cast
from typing import ClassVar, Optional

from cuenca_validations.types import (
Country,
Expand Down Expand Up @@ -98,7 +98,4 @@ def create(
gender=gender,
manual_curp=manual_curp,
)
return cast(
'CurpValidation',
cls._create(session=session, **req.dict()),
)
return cls._create(session=session, **req.dict())
Loading
Loading