diff --git a/cozepy/__init__.py b/cozepy/__init__.py index c05395a..2481abf 100644 --- a/cozepy/__init__.py +++ b/cozepy/__init__.py @@ -45,7 +45,7 @@ ) from .conversations import Conversation from .coze import AsyncCoze, Coze -from .exception import CozeAPIError, CozeError, CozeEventError, CozePKCEAuthError +from .exception import CozeAPIError, CozeError, CozeEventError, CozePKCEAuthError, CozePKCEAuthErrorType from .files import File from .knowledge.documents import ( Document, @@ -158,6 +158,7 @@ "CozeAPIError", "CozeEventError", "CozePKCEAuthError", + "CozePKCEAuthErrorType", # model "AsyncStream", "LastIDPaged", diff --git a/cozepy/auth/__init__.py b/cozepy/auth/__init__.py index e7e8380..2e85861 100644 --- a/cozepy/auth/__init__.py +++ b/cozepy/auth/__init__.py @@ -7,7 +7,7 @@ from typing_extensions import Literal from cozepy.config import COZE_CN_BASE_URL, COZE_COM_BASE_URL -from cozepy.exception import CozePKCEAuthError +from cozepy.exception import CozePKCEAuthError, CozePKCEAuthErrorType from cozepy.model import CozeModel from cozepy.request import Requester from cozepy.util import gen_s256_code_challenge, random_hex @@ -483,10 +483,10 @@ def get_access_token(self, device_code: str, poll: bool = False) -> OAuthToken: try: return self._get_access_token(device_code) except CozePKCEAuthError as e: - if e.error == "authorization_pending": + if e.error == CozePKCEAuthErrorType.AUTHORIZATION_PENDING: time.sleep(interval) continue - elif e.error == "slow_down": + elif e.error == CozePKCEAuthErrorType.SLOW_DOWN: if interval < 30: interval += 5 time.sleep(interval) @@ -565,10 +565,10 @@ async def get_access_token(self, device_code: str, poll: bool = False) -> OAuthT try: return await self._get_access_token(device_code) except CozePKCEAuthError as e: - if e.error == "authorization_pending": + if e.error == CozePKCEAuthErrorType.AUTHORIZATION_PENDING: time.sleep(interval) continue - elif e.error == "slow_down": + elif e.error == CozePKCEAuthErrorType.SLOW_DOWN: if interval < 30: interval += 5 time.sleep(interval) diff --git a/cozepy/exception.py b/cozepy/exception.py index 8324438..6ffe3bc 100644 --- a/cozepy/exception.py +++ b/cozepy/exception.py @@ -1,4 +1,4 @@ -from typing_extensions import Literal +from enum import Enum class CozeError(Exception): @@ -24,15 +24,23 @@ def __init__(self, code: int = None, msg: str = "", logid: str = None): super().__init__(f"msg: {msg}, logid: {logid}") +class CozePKCEAuthErrorType(str, Enum): + AUTHORIZATION_PENDING = "authorization_pending" + SLOW_DOWN = "slow_down" + ACCESS_DENIED = "access_denied" + EXPIRED_TOKEN = "expired_token" + + +COZE_PKCE_AUTH_ERROR_TYPE_ENUMS = set(e.value for e in CozePKCEAuthErrorType) + + class CozePKCEAuthError(CozeError): """ base class for all pkce auth errors """ - def __init__( - self, error: Literal["authorization_pending", "slow_down", "access_denied", "expired_token"], logid: str = None - ): - super().__init__(f"pkce auth error: {error}") + def __init__(self, error: CozePKCEAuthErrorType, logid: str = None): + super().__init__(f"pkce auth error: {error.value}") self.error = error self.logid = logid diff --git a/cozepy/request.py b/cozepy/request.py index 1b5b0dd..58c172a 100644 --- a/cozepy/request.py +++ b/cozepy/request.py @@ -18,7 +18,7 @@ from typing_extensions import get_args, get_origin from cozepy.config import DEFAULT_CONNECTION_LIMITS, DEFAULT_TIMEOUT -from cozepy.exception import CozeAPIError, CozePKCEAuthError +from cozepy.exception import COZE_PKCE_AUTH_ERROR_TYPE_ENUMS, CozeAPIError, CozePKCEAuthError, CozePKCEAuthErrorType from cozepy.log import log_debug, log_warning from cozepy.version import user_agent @@ -174,8 +174,8 @@ def _parse_response( raise CozeAPIError(code, msg, logid) elif code is None and msg != "": log_warning("request %s#%s failed, logid=%s, msg=%s", method, url, logid, msg) - if msg in ["authorization_pending", "slow_down", "access_denied", "expired_token"]: - raise CozePKCEAuthError(msg, logid) + if msg in COZE_PKCE_AUTH_ERROR_TYPE_ENUMS: + raise CozePKCEAuthError(CozePKCEAuthErrorType(msg), logid) raise CozeAPIError(code, msg, logid) if get_origin(model) is list: item_model = get_args(model)[0] @@ -201,12 +201,7 @@ def _parse_requests_code_msg( if "code" in body and "msg" in body and int(body["code"]) > 0: return int(body["code"]), body["msg"], body.get(data_field) - if "error_code" in body and body["error_code"] in [ - "authorization_pending", - "slow_down", - "access_denied", - "expired_token", - ]: + if "error_code" in body and body["error_code"] in COZE_PKCE_AUTH_ERROR_TYPE_ENUMS: return None, body["error_code"], None if "error_message" in body and body["error_message"] != "": return None, body["error_message"], None diff --git a/tests/test_auth.py b/tests/test_auth.py index a40644f..15366d5 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -9,6 +9,7 @@ AsyncJWTOAuthApp, AsyncPKCEOAuthApp, AsyncWebOAuthApp, + CozePKCEAuthErrorType, DeviceAuthCode, DeviceOAuthApp, JWTAuth, @@ -377,7 +378,7 @@ def test_get_access_token_poll(self, respx_mock): mock_token = random_hex(20) respx_mock.post("/api/permission/oauth2/token").mock( - httpx.Response(200, json={"error_code": "authorization_pending"}) + httpx.Response(200, json={"error_code": CozePKCEAuthErrorType.AUTHORIZATION_PENDING}) ).mock( httpx.Response( 200, content=OAuthToken(access_token=mock_token, expires_in=int(time.time())).model_dump_json() @@ -462,7 +463,7 @@ async def test_get_access_token_poll(self, respx_mock): mock_token = random_hex(20) respx_mock.post("/api/permission/oauth2/token").mock( - httpx.Response(200, json={"error_code": "authorization_pending"}) + httpx.Response(200, json={"error_code": CozePKCEAuthErrorType.AUTHORIZATION_PENDING}) ).mock( httpx.Response( 200, content=OAuthToken(access_token=mock_token, expires_in=int(time.time())).model_dump_json() diff --git a/tests/test_exception.py b/tests/test_exception.py index 1017c74..156ed99 100644 --- a/tests/test_exception.py +++ b/tests/test_exception.py @@ -1,4 +1,4 @@ -from cozepy import CozeAPIError, CozeEventError, CozePKCEAuthError +from cozepy import CozeAPIError, CozeEventError, CozePKCEAuthError, CozePKCEAuthErrorType def test_coze_error(): @@ -14,7 +14,7 @@ def test_coze_error(): assert err.logid == "logid" assert str(err) == "msg: msg, logid: logid" - err = CozePKCEAuthError("authorization_pending") + err = CozePKCEAuthError(CozePKCEAuthErrorType.AUTHORIZATION_PENDING) assert err.error == "authorization_pending" err = CozeEventError("event", "xxx", "logid")