Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Refactor CozePKCEAuthError to use CozePKCEAuthErrorType Enum #58

Merged
merged 1 commit into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cozepy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
)
from .conversations import Conversation
from .coze import AsyncCoze, Coze
from .exception import CozeAPIError, CozeError, CozeEventError, CozePKCEAuthError
from .exception import CozeAPIError, CozeError, CozeEventError, CozePKCEAuthError, CozePKCEAuthErrorType
from .files import File
from .knowledge.documents import (
Document,
Expand Down Expand Up @@ -158,6 +158,7 @@
"CozeAPIError",
"CozeEventError",
"CozePKCEAuthError",
"CozePKCEAuthErrorType",
# model
"AsyncStream",
"LastIDPaged",
Expand Down
10 changes: 5 additions & 5 deletions cozepy/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing_extensions import Literal

from cozepy.config import COZE_CN_BASE_URL, COZE_COM_BASE_URL
from cozepy.exception import CozePKCEAuthError
from cozepy.exception import CozePKCEAuthError, CozePKCEAuthErrorType
from cozepy.model import CozeModel
from cozepy.request import Requester
from cozepy.util import gen_s256_code_challenge, random_hex
Expand Down Expand Up @@ -483,10 +483,10 @@
try:
return self._get_access_token(device_code)
except CozePKCEAuthError as e:
if e.error == "authorization_pending":
if e.error == CozePKCEAuthErrorType.AUTHORIZATION_PENDING:

Check warning on line 486 in cozepy/auth/__init__.py

View check run for this annotation

Codecov / codecov/patch

cozepy/auth/__init__.py#L486

Added line #L486 was not covered by tests
time.sleep(interval)
continue
elif e.error == "slow_down":
elif e.error == CozePKCEAuthErrorType.SLOW_DOWN:

Check warning on line 489 in cozepy/auth/__init__.py

View check run for this annotation

Codecov / codecov/patch

cozepy/auth/__init__.py#L489

Added line #L489 was not covered by tests
if interval < 30:
interval += 5
time.sleep(interval)
Expand Down Expand Up @@ -565,10 +565,10 @@
try:
return await self._get_access_token(device_code)
except CozePKCEAuthError as e:
if e.error == "authorization_pending":
if e.error == CozePKCEAuthErrorType.AUTHORIZATION_PENDING:

Check warning on line 568 in cozepy/auth/__init__.py

View check run for this annotation

Codecov / codecov/patch

cozepy/auth/__init__.py#L568

Added line #L568 was not covered by tests
time.sleep(interval)
continue
elif e.error == "slow_down":
elif e.error == CozePKCEAuthErrorType.SLOW_DOWN:

Check warning on line 571 in cozepy/auth/__init__.py

View check run for this annotation

Codecov / codecov/patch

cozepy/auth/__init__.py#L571

Added line #L571 was not covered by tests
if interval < 30:
interval += 5
time.sleep(interval)
Expand Down
18 changes: 13 additions & 5 deletions cozepy/exception.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing_extensions import Literal
from enum import Enum


class CozeError(Exception):
Expand All @@ -24,15 +24,23 @@ def __init__(self, code: int = None, msg: str = "", logid: str = None):
super().__init__(f"msg: {msg}, logid: {logid}")


class CozePKCEAuthErrorType(str, Enum):
AUTHORIZATION_PENDING = "authorization_pending"
SLOW_DOWN = "slow_down"
ACCESS_DENIED = "access_denied"
EXPIRED_TOKEN = "expired_token"


COZE_PKCE_AUTH_ERROR_TYPE_ENUMS = set(e.value for e in CozePKCEAuthErrorType)


class CozePKCEAuthError(CozeError):
"""
base class for all pkce auth errors
"""

def __init__(
self, error: Literal["authorization_pending", "slow_down", "access_denied", "expired_token"], logid: str = None
):
super().__init__(f"pkce auth error: {error}")
def __init__(self, error: CozePKCEAuthErrorType, logid: str = None):
super().__init__(f"pkce auth error: {error.value}")
self.error = error
self.logid = logid

Expand Down
13 changes: 4 additions & 9 deletions cozepy/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing_extensions import get_args, get_origin

from cozepy.config import DEFAULT_CONNECTION_LIMITS, DEFAULT_TIMEOUT
from cozepy.exception import CozeAPIError, CozePKCEAuthError
from cozepy.exception import COZE_PKCE_AUTH_ERROR_TYPE_ENUMS, CozeAPIError, CozePKCEAuthError, CozePKCEAuthErrorType
from cozepy.log import log_debug, log_warning
from cozepy.version import user_agent

Expand Down Expand Up @@ -174,8 +174,8 @@ def _parse_response(
raise CozeAPIError(code, msg, logid)
elif code is None and msg != "":
log_warning("request %s#%s failed, logid=%s, msg=%s", method, url, logid, msg)
if msg in ["authorization_pending", "slow_down", "access_denied", "expired_token"]:
raise CozePKCEAuthError(msg, logid)
if msg in COZE_PKCE_AUTH_ERROR_TYPE_ENUMS:
raise CozePKCEAuthError(CozePKCEAuthErrorType(msg), logid)
raise CozeAPIError(code, msg, logid)
if get_origin(model) is list:
item_model = get_args(model)[0]
Expand All @@ -201,12 +201,7 @@ def _parse_requests_code_msg(

if "code" in body and "msg" in body and int(body["code"]) > 0:
return int(body["code"]), body["msg"], body.get(data_field)
if "error_code" in body and body["error_code"] in [
"authorization_pending",
"slow_down",
"access_denied",
"expired_token",
]:
if "error_code" in body and body["error_code"] in COZE_PKCE_AUTH_ERROR_TYPE_ENUMS:
return None, body["error_code"], None
if "error_message" in body and body["error_message"] != "":
return None, body["error_message"], None
Expand Down
5 changes: 3 additions & 2 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
AsyncJWTOAuthApp,
AsyncPKCEOAuthApp,
AsyncWebOAuthApp,
CozePKCEAuthErrorType,
DeviceAuthCode,
DeviceOAuthApp,
JWTAuth,
Expand Down Expand Up @@ -377,7 +378,7 @@ def test_get_access_token_poll(self, respx_mock):
mock_token = random_hex(20)

respx_mock.post("/api/permission/oauth2/token").mock(
httpx.Response(200, json={"error_code": "authorization_pending"})
httpx.Response(200, json={"error_code": CozePKCEAuthErrorType.AUTHORIZATION_PENDING})
).mock(
httpx.Response(
200, content=OAuthToken(access_token=mock_token, expires_in=int(time.time())).model_dump_json()
Expand Down Expand Up @@ -462,7 +463,7 @@ async def test_get_access_token_poll(self, respx_mock):
mock_token = random_hex(20)

respx_mock.post("/api/permission/oauth2/token").mock(
httpx.Response(200, json={"error_code": "authorization_pending"})
httpx.Response(200, json={"error_code": CozePKCEAuthErrorType.AUTHORIZATION_PENDING})
).mock(
httpx.Response(
200, content=OAuthToken(access_token=mock_token, expires_in=int(time.time())).model_dump_json()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_exception.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from cozepy import CozeAPIError, CozeEventError, CozePKCEAuthError
from cozepy import CozeAPIError, CozeEventError, CozePKCEAuthError, CozePKCEAuthErrorType


def test_coze_error():
Expand All @@ -14,7 +14,7 @@ def test_coze_error():
assert err.logid == "logid"
assert str(err) == "msg: msg, logid: logid"

err = CozePKCEAuthError("authorization_pending")
err = CozePKCEAuthError(CozePKCEAuthErrorType.AUTHORIZATION_PENDING)
assert err.error == "authorization_pending"

err = CozeEventError("event", "xxx", "logid")
Expand Down
Loading