Skip to content

Commit

Permalink
feat: get oauth token by jwt flow (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
chyroc authored Sep 23, 2024
1 parent 779ffca commit a155601
Show file tree
Hide file tree
Showing 9 changed files with 301 additions and 23 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ jobs:
env:
COZE_TOKEN: ${{ secrets.COZE_TOKEN }}
SPACE_ID_1: ${{ secrets.SPACE_ID_1 }}
COZE_JWT_AUTH_CLIENT_ID: ${{ secrets.COZE_JWT_AUTH_CLIENT_ID }}
COZE_JWT_AUTH_PRIVATE_KEY: ${{ secrets.COZE_JWT_AUTH_PRIVATE_KEY }}
COZE_JWT_AUTH_KEY_ID: ${{ secrets.COZE_JWT_AUTH_KEY_ID }}
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3

Expand Down
9 changes: 6 additions & 3 deletions cozepy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from .auth import Auth, TokenAuth

from .auth import ApplicationOAuth, Auth, TokenAuth
from .config import COZE_COM_BASE_URL, COZE_CN_BASE_URL
from .coze import Coze

from .model import TokenPaged, NumberPaged

__all__ = [
'ApplicationOAuth',
'Auth',
'TokenAuth',

'COZE_COM_BASE_URL',
'COZE_CN_BASE_URL',

'Coze',

'TokenPaged',
Expand Down
84 changes: 84 additions & 0 deletions cozepy/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,88 @@
import abc
import random
import time
from urllib.parse import urlparse

from authlib.jose import jwt

from cozepy.model import CozeModel
from cozepy.request import Requester
from .config import COZE_COM_BASE_URL


def _random_hex(length):
hex_characters = '0123456789abcdef'
return ''.join(random.choice(hex_characters) for _ in range(length))


class OAuthToken(CozeModel):
# The requested access token. The app can use this token to authenticate to the Coze resource.
access_token: str
# How long the access token is valid, in seconds (UNIX timestamp)
expires_in: int
# An OAuth 2.0 refresh token. The app can use this token to acquire other access tokens after the current access token expires. Refresh tokens are long-lived.
refresh_token: str = ''
# fixed: Bearer
token_type: str = ''


class DeviceAuthCode(CozeModel):
# device code
device_code: str
# The user code
user_code: str
# The verification uri
verification_uri: str
# The interval of the polling request
interval: int = 5
# The expiration time of the device code
expires_in: int

@property
def verification_url(self):
return f'{self.verification_uri}?user_code={self.user_code}'


class ApplicationOAuth(object):
"""
App OAuth process to support obtaining token and refreshing token.
"""

def __init__(self, client_id: str, client_secret: str = '', base_url: str = COZE_COM_BASE_URL):
self._client_id = client_id
self._client_secret = client_secret
self._base_url = base_url
self._api_endpoint = urlparse(base_url).netloc
self._token = ''
self._requester = Requester()

def jwt_auth(self, private_key: str, kid: str, ttl: int):
"""
Get the token by jwt with jwt auth flow.
"""
jwt_token = self._gen_jwt(self._api_endpoint, private_key, self._client_id, kid, 3600)
url = f'{self._base_url}/api/permission/oauth2/token'
headers = {
'Authorization': f'Bearer {jwt_token}'
}
body = {
'duration_seconds': ttl,
'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer',
}
return self._requester.request('post', url, OAuthToken, headers=headers, body=body)

def _gen_jwt(self, api_endpoint: str, private_key: str, client_id: str, kid: str, ttl: int):
now = int(time.time())
header = {'alg': 'RS256', 'typ': 'JWT', 'kid': kid}
payload = {
"iss": client_id,
'aud': api_endpoint,
"iat": now,
"exp": now + ttl,
'jti': _random_hex(16),
}
s = jwt.encode(header, payload, private_key)
return s.decode('utf-8')


class Auth(abc.ABC):
Expand Down
2 changes: 2 additions & 0 deletions cozepy/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
COZE_COM_BASE_URL = 'https://api.coze.com'
COZE_CN_BASE_URL = 'https://api.coze.cn'
3 changes: 2 additions & 1 deletion cozepy/coze.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING

from cozepy.auth import Auth
from cozepy.config import COZE_COM_BASE_URL
from cozepy.request import Requester

if TYPE_CHECKING:
Expand All @@ -10,7 +11,7 @@
class Coze(object):
def __init__(self,
auth: Auth,
base_url: str = 'https://api.coze.com',
base_url: str = COZE_COM_BASE_URL,
):
self._auth = auth
self._base_url = base_url
Expand Down
47 changes: 29 additions & 18 deletions cozepy/request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Tuple, Optional

import requests
from requests import Response

if TYPE_CHECKING:
from cozepy.auth import Auth
Expand All @@ -25,39 +26,49 @@ class Requester(object):
"""

def __init__(self,
auth: Optional["Auth"]
auth: 'Auth' = None
):
self._auth = auth

def request(self, method: str, url: str, model: Type[T], params: dict = None, headers: dict = None) -> T:
def request(self, method: str, url: str, model: Type[T], params: dict = None, headers: dict = None,
body: dict = None, ) -> T:
"""
Send a request to the server.
"""
if headers is None:
headers = {}
self._auth.authentication(headers)
r = requests.request(method, url, params=params, headers=headers)
logid = r.headers.get('x-tt-logid')
if self._auth:
self._auth.authentication(headers)
r = requests.request(method, url, params=params, headers=headers, json=body)

try:
json = r.json()
code = json.get('code') or 0
msg = json.get('msg') or ''
data = json.get('data')
except:
r.raise_for_status()

code = 0
msg = ''
data = {}
code, msg, data = self.__parse_requests_code_msg(r)

if code > 0:
if code is not None and code > 0:
# TODO: Exception 自定义类型
logid = r.headers.get('x-tt-logid')
raise Exception(f'{code}: {msg}, logid:{logid}')
elif code is None and msg != "":
logid = r.headers.get('x-tt-logid')
raise Exception(f'{msg}, logid:{logid}')
return model.model_validate(data)

async def arequest(self, method: str, path: str, **kwargs) -> dict:
"""
Send a request to the server with asyncio.
"""
pass

def __parse_requests_code_msg(self, r: Response) -> Tuple[Optional[int], str, Optional[T]]:
try:
json = r.json()
except:
r.raise_for_status()
return

if 'code' in json and 'msg' in json and int(json['code']) > 0:
return int(json['code']), json['msg'], json['data']
if 'error_message' in json and json['error_message'] != '':
return None, json['error_message'], None
if 'data' in json:
return 0, '', json['data']
return 0, '', json
Loading

0 comments on commit a155601

Please sign in to comment.