From c848e50b3bdb8d585880cb83abb163924d5f3a07 Mon Sep 17 00:00:00 2001 From: sigma67 Date: Sun, 31 Dec 2023 11:18:39 +0100 Subject: [PATCH] refactor authtypes --- tests/test.py | 9 +++++---- ytmusicapi/auth/headers.py | 12 ------------ ytmusicapi/auth/types.py | 25 ++++++++++++++++++++++++ ytmusicapi/mixins/uploads.py | 3 ++- ytmusicapi/ytmusic.py | 37 ++++++++++++++++++------------------ 5 files changed, 50 insertions(+), 36 deletions(-) delete mode 100644 ytmusicapi/auth/headers.py create mode 100644 ytmusicapi/auth/types.py diff --git a/tests/test.py b/tests/test.py index 99f6e79..d43e100 100644 --- a/tests/test.py +++ b/tests/test.py @@ -8,6 +8,7 @@ from requests import Response +from ytmusicapi.auth.types import AuthType from ytmusicapi.setup import main, setup # noqa: E402 from ytmusicapi.ytmusic import YTMusic, OAuthCredentials # noqa: E402 from ytmusicapi.constants import OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET @@ -82,7 +83,7 @@ def test_setup_oauth(self, session_mock, json_mock): # OAUTH ############### # 000 so test is run first and fresh token is available to others - def test_000_oauth_tokens(self): + def test_oauth_tokens(self): # ensure instance initialized token self.assertIsNotNone(self.yt_oauth._token) @@ -115,14 +116,14 @@ def test_000_oauth_tokens(self): # ensure token is updating local file self.assertNotEqual(first_json, second_json) - def test_alt_oauth(self): + def test_oauth_custom_client(self): # ensure client works/ignores alt if browser credentials passed as auth - self.assertFalse(self.yt_alt_oauth.is_alt_oauth) + self.assertNotEqual(self.yt_alt_oauth.auth_type, AuthType.OAUTH_CUSTOM_CLIENT) with open(oauth_filepath, 'r') as f: token_dict = json.load(f) # oauth token dict entry and alt self.yt_alt_oauth = YTMusic(token_dict, oauth_credentials=alt_oauth_creds) - self.assertTrue(self.yt_alt_oauth.is_alt_oauth) + self.assertEqual(self.yt_alt_oauth.auth_type, AuthType.OAUTH_CUSTOM_CLIENT) ############### # BROWSING diff --git a/ytmusicapi/auth/headers.py b/ytmusicapi/auth/headers.py deleted file mode 100644 index c3d8165..0000000 --- a/ytmusicapi/auth/headers.py +++ /dev/null @@ -1,12 +0,0 @@ -import json -import os -from typing import Dict - - -def load_headers_file(auth: str) -> Dict: - if os.path.isfile(auth): - with open(auth) as json_file: - input_json = json.load(json_file) - else: - input_json = json.loads(auth) - return input_json diff --git a/ytmusicapi/auth/types.py b/ytmusicapi/auth/types.py new file mode 100644 index 0000000..106c051 --- /dev/null +++ b/ytmusicapi/auth/types.py @@ -0,0 +1,25 @@ +"""enum representing types of authentication supported by this library""" + +from enum import Enum, auto +from typing import List + + +class AuthType(int, Enum): + """enum representing types of authentication supported by this library""" + + UNAUTHORIZED = auto() + + BROWSER = auto() + + #: client auth via OAuth token refreshing + OAUTH_DEFAULT = auto() + + #: YTM instance is using a non-default OAuth client (id & secret) + OAUTH_CUSTOM_CLIENT = auto() + + #: allows fully formed OAuth headers to ignore browser auth refresh flow + OAUTH_CUSTOM_FULL = auto() + + @classmethod + def oauth_types(cls) -> List["AuthType"]: + return [cls.OAUTH_DEFAULT, cls.OAUTH_CUSTOM_CLIENT, cls.OAUTH_CUSTOM_FULL] diff --git a/ytmusicapi/mixins/uploads.py b/ytmusicapi/mixins/uploads.py index e63bb19..6465a9b 100644 --- a/ytmusicapi/mixins/uploads.py +++ b/ytmusicapi/mixins/uploads.py @@ -10,6 +10,7 @@ from ytmusicapi.parsers.library import parse_library_albums, parse_library_artists, get_library_contents from ytmusicapi.parsers.albums import parse_album_header from ytmusicapi.parsers.uploads import parse_uploaded_items +from ..auth.types import AuthType class UploadsMixin: @@ -194,7 +195,7 @@ def upload_song(self, filepath: str) -> Union[str, requests.Response]: :return: Status String or full response """ self._check_auth() - if not self.is_browser_auth: + if not self.auth_type == AuthType.BROWSER: raise Exception("Please provide authentication before using this function") if not os.path.isfile(filepath): raise Exception("The provided file does not exist.") diff --git a/ytmusicapi/ytmusic.py b/ytmusicapi/ytmusic.py index 922cb2a..75bf940 100644 --- a/ytmusicapi/ytmusic.py +++ b/ytmusicapi/ytmusic.py @@ -17,9 +17,9 @@ from ytmusicapi.mixins.playlists import PlaylistsMixin from ytmusicapi.mixins.uploads import UploadsMixin -from .auth.headers import load_headers_file from .auth.oauth import OAuthCredentials, RefreshingToken, OAuthToken from .auth.oauth.base import Token +from .auth.types import AuthType class YTMusic(BrowsingMixin, SearchMixin, WatchMixin, ExploreMixin, LibraryMixin, PlaylistsMixin, @@ -82,11 +82,7 @@ def __init__(self, self.auth = auth #: raw auth self._input_dict = {} #: parsed auth arg value in dictionary format - # (?) may be better implemented as an auth_type attribute with a literal/enum value (?) - self.is_alt_oauth = False #: YTM instance is using a non-default OAuth client (id & secret) - self.is_oauth_auth = False #: client auth via OAuth token refreshing - self.is_browser_auth = False #: authorization via extracted browser headers, enables uploading capabilities - self.is_custom_oauth = False #: allows fully formed OAuth headers to ignore browser auth refresh flow + self.auth_type: AuthType = AuthType.UNAUTHORIZED self._token: Token #: OAuth credential handler self.oauth_credentials: OAuthCredentials #: Client used for OAuth refreshing @@ -103,14 +99,19 @@ def __init__(self, else: # Use the Requests API module as a "session". self._session = requests.api - self.oauth_credentials = oauth_credentials if oauth_credentials is not None else OAuthCredentials() - # see google cookie docs: https://policies.google.com/technologies/cookies # value from https://github.com/yt-dlp/yt-dlp/blob/master/yt_dlp/extractor/youtube.py#L502 self.cookies = {'SOCS': 'CAI'} if self.auth is not None: + self.oauth_credentials = oauth_credentials if oauth_credentials is not None else OAuthCredentials() + auth_filepath = None if isinstance(self.auth, str): - input_json = load_headers_file(self.auth) + if os.path.isfile(auth): + with open(auth) as json_file: + auth_filepath = auth + input_json = json.load(json_file) + else: + input_json = json.loads(auth) self._input_dict = CaseInsensitiveDict(input_json) else: @@ -118,10 +119,8 @@ def __init__(self, if OAuthToken.is_oauth(self._input_dict): base_token = OAuthToken(**self._input_dict) - self._token = RefreshingToken(base_token, self.oauth_credentials, - self._input_dict.get('filepath')) - self.is_oauth_auth = True - self.is_alt_oauth = oauth_credentials is not None + self._token = RefreshingToken(base_token, self.oauth_credentials, auth_filepath) + self.auth_type = AuthType.OAUTH_CUSTOM_CLIENT if oauth_credentials else AuthType.OAUTH_DEFAULT # prepare context self.context = initialize_context() @@ -152,13 +151,13 @@ def __init__(self, auth_headers = self._input_dict.get("authorization") if auth_headers: if "SAPISIDHASH" in auth_headers: - self.is_browser_auth = True + self.auth_type = AuthType.BROWSER elif auth_headers.startswith('Bearer'): - self.is_custom_oauth = True + self.auth_type = AuthType.OAUTH_CUSTOM_FULL # sapsid, origin, and params all set once during init self.params = YTM_PARAMS - if self.is_browser_auth: + if self.auth_type == AuthType.BROWSER: self.params += YTM_PARAMS_KEY try: cookie = self.base_headers.get('cookie') @@ -170,7 +169,7 @@ def __init__(self, @property def base_headers(self): if not self._base_headers: - if self.is_browser_auth or self.is_custom_oauth: + if self.auth_type == AuthType.BROWSER or self.auth_type == AuthType.OAUTH_CUSTOM_FULL: self._base_headers = self._input_dict else: self._base_headers = { @@ -191,10 +190,10 @@ def headers(self): self._headers = self.base_headers # keys updated each use, custom oauth implementations left untouched - if self.is_browser_auth: + if self.auth_type == AuthType.BROWSER: self._headers["authorization"] = get_authorization(self.sapisid + ' ' + self.origin) - elif self.is_oauth_auth: + elif self.auth_type in AuthType.oauth_types(): self._headers['authorization'] = self._token.as_auth() self._headers['X-Goog-Request-Time'] = str(int(time.time()))