From b347c8b53ba6e91af00d23906a8b13c9956f569a Mon Sep 17 00:00:00 2001 From: Calum Pinder Date: Thu, 31 Oct 2024 08:31:43 +0000 Subject: [PATCH 1/5] Add access tokens to db --- backend/src/database/crud/user.py | 16 ++++++++++++++-- ...-table.py => m_240909_add_tracks_table.py} | 10 +++++----- .../m_241030_add_access_tokens_table.py | 19 +++++++++++++++++++ .../migrations/{init.py => m_init.py} | 10 +++++----- backend/src/database/models.py | 14 +++++++++++--- backend/src/spotify.py | 16 +++++++++++++--- 6 files changed, 67 insertions(+), 18 deletions(-) rename backend/src/database/migrations/{240909-add-tracks-table.py => m_240909_add_tracks_table.py} (66%) create mode 100644 backend/src/database/migrations/m_241030_add_access_tokens_table.py rename backend/src/database/migrations/{init.py => m_init.py} (83%) diff --git a/backend/src/database/crud/user.py b/backend/src/database/crud/user.py index 9ff6e41..05c5ab6 100644 --- a/backend/src/database/crud/user.py +++ b/backend/src/database/crud/user.py @@ -1,9 +1,9 @@ -from src.database.models import DbUser +from src.database.models import DbAccessToken, DbUser from src.dataclasses.user import User def get_user_by_id(id: str): - return DbUser.get( + return DbUser.get_or_none( DbUser.id == id, ) @@ -26,3 +26,15 @@ def get_or_create_user(user: User): "uri": user.uri, }, ) + + +def upsert_user_tokens(user_id: str, access_token: str, refresh_token: str): + DbAccessToken.insert( + user=user_id, access_token=access_token, refresh_token=refresh_token + ).on_conflict( + conflict_target=[DbAccessToken.user], + update={ + DbAccessToken.access_token: access_token, + DbAccessToken.refresh_token: refresh_token, + }, + ).execute() diff --git a/backend/src/database/migrations/240909-add-tracks-table.py b/backend/src/database/migrations/m_240909_add_tracks_table.py similarity index 66% rename from backend/src/database/migrations/240909-add-tracks-table.py rename to backend/src/database/migrations/m_240909_add_tracks_table.py index 1b642bb..9a3d100 100644 --- a/backend/src/database/migrations/240909-add-tracks-table.py +++ b/backend/src/database/migrations/m_240909_add_tracks_table.py @@ -1,13 +1,13 @@ from src.database.models import ( DbTrack, TrackArtistRelationship, - database, + db_wrapper, ) def up(): - with database: - database.create_tables( + with db_wrapper.database: + db_wrapper.database.create_tables( [ DbTrack, TrackArtistRelationship, @@ -16,8 +16,8 @@ def up(): def down(): - with database: - database.drop_tables( + with db_wrapper.database: + db_wrapper.database.drop_tables( [ DbTrack, TrackArtistRelationship, diff --git a/backend/src/database/migrations/m_241030_add_access_tokens_table.py b/backend/src/database/migrations/m_241030_add_access_tokens_table.py new file mode 100644 index 0000000..0a6b2b3 --- /dev/null +++ b/backend/src/database/migrations/m_241030_add_access_tokens_table.py @@ -0,0 +1,19 @@ +from src.database.models import DbAccessToken, db_wrapper + + +def up(): + with db_wrapper.database: + db_wrapper.database.create_tables( + [ + DbAccessToken, + ] + ) + + +def down(): + with db_wrapper.database: + db_wrapper.database.drop_tables( + [ + DbAccessToken, + ] + ) diff --git a/backend/src/database/migrations/init.py b/backend/src/database/migrations/m_init.py similarity index 83% rename from backend/src/database/migrations/init.py rename to backend/src/database/migrations/m_init.py index 9d8437d..3c28e11 100644 --- a/backend/src/database/migrations/init.py +++ b/backend/src/database/migrations/m_init.py @@ -7,13 +7,13 @@ DbPlaylist, PlaylistAlbumRelationship, DbUser, - database, + db_wrapper, ) def up(): - with database: - database.create_tables( + with db_wrapper.database: + db_wrapper.database.create_tables( [ DbUser, DbPlaylist, @@ -28,8 +28,8 @@ def up(): def down(): - with database: - database.drop_tables( + with db_wrapper.database: + db_wrapper.database.drop_tables( [ DbUser, DbPlaylist, diff --git a/backend/src/database/models.py b/backend/src/database/models.py index 8f94a8a..4e7e2d4 100644 --- a/backend/src/database/models.py +++ b/backend/src/database/models.py @@ -1,12 +1,9 @@ from peewee import ( - PostgresqlDatabase, - Model, CharField, IntegerField, DateField, ForeignKeyField, ) -from src.flask_config import Config from playhouse.flask_utils import FlaskDB db_wrapper = FlaskDB() @@ -116,3 +113,14 @@ class TrackArtistRelationship(db_wrapper.Model): class Meta: indexes = ((("track", "artist"), True),) + + +class DbAccessToken(db_wrapper.Model): + user = ForeignKeyField( + DbUser, backref="owner", to_field="id", on_delete="CASCADE", unique=True + ) + access_token = CharField(max_length=400) + refresh_token = CharField(max_length=200) + + class Meta: + db_table = "access_token" diff --git a/backend/src/spotify.py b/backend/src/spotify.py index 4c418d4..4ec1928 100644 --- a/backend/src/spotify.py +++ b/backend/src/spotify.py @@ -6,6 +6,7 @@ from typing import List, Optional from flask import Response, make_response, redirect from src.database.crud.album import get_album_genres +from src.database.crud.user import upsert_user_tokens from src.dataclasses.album import Album from src.dataclasses.playback_info import PlaybackInfo, PlaylistProgression from src.dataclasses.playback_request import ( @@ -102,6 +103,11 @@ def refresh_access_token(self, refresh_token): token_response = TokenResponse.model_validate(api_response) access_token = token_response.access_token user_info = self.get_current_user(access_token) + upsert_user_tokens( + user_info.id, + access_token=token_response.access_token, + refresh_token=token_response.refresh_token, + ) resp = add_cookies_to_response( make_response(), {"spotify_access_token": access_token, "user_id": user_info.id}, @@ -123,12 +129,16 @@ def request_access_token(self, code): ) api_response = self.response_handler(response) token_response = TokenResponse.model_validate(api_response) - access_token = token_response.access_token - user_info = self.get_current_user(access_token) + user_info = self.get_current_user(token_response.access_token) + upsert_user_tokens( + user_info.id, + access_token=token_response.access_token, + refresh_token=token_response.refresh_token, + ) resp = add_cookies_to_response( make_response(redirect(f"{Config().FRONTEND_URL}/")), { - "spotify_access_token": access_token, + "spotify_access_token": token_response.access_token, "spotify_refresh_token": token_response.refresh_token, "user_id": user_info.id, }, From e4ba9e43986aee49573e0bacf9adb030fddf31d5 Mon Sep 17 00:00:00 2001 From: Calum Pinder Date: Tue, 19 Nov 2024 14:16:24 +0000 Subject: [PATCH 2/5] Move all uses of access and refresh tokens to be managed by backend and database. Fix login flow (mostly unnecessary now) --- backend/poetry.lock | 17 +- backend/pyproject.toml | 1 + backend/src/app.py | 8 +- backend/src/controllers/auth.py | 11 +- backend/src/controllers/database.py | 16 +- backend/src/controllers/music_data.py | 30 ++- backend/src/controllers/spotify.py | 58 +++-- backend/src/database/crud/user.py | 4 + backend/src/database/models.py | 4 +- backend/src/spotify.py | 220 +++++++++++------- frontend/src/components/ButtonAsync.tsx | 7 +- .../AlbumList/AlbumActions.tsx | 5 +- 12 files changed, 224 insertions(+), 157 deletions(-) diff --git a/backend/poetry.lock b/backend/poetry.lock index 5402447..bb99451 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -598,6 +598,21 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "tenacity" +version = "9.0.0" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"}, + {file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"}, +] + +[package.extras] +doc = ["reno", "sphinx"] +test = ["pytest", "tornado (>=4.5)", "typeguard"] + [[package]] name = "tomli" version = "2.0.1" @@ -657,4 +672,4 @@ watchdog = ["watchdog (>=2.3)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "38c06b14a3b5eac97e022c6efe4f9dc4611dc8a7c862a1c97d30121acb89ce27" +content-hash = "4f92276065527cfe2e0ca7f00519dbb952a4bb588e1c8ea1fc973f13c8963a35" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 2b50b0e..0cebf20 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -20,6 +20,7 @@ gunicorn = "^22.0.0" peewee = "^3.17.6" asyncio = "^3.4.3" psycopg2 = "^2.9.9" +tenacity = "^9.0.0" [tool.poetry.group.dev.dependencies] pytest = "^8.1.1" diff --git a/backend/src/app.py b/backend/src/app.py index daf753e..601366d 100644 --- a/backend/src/app.py +++ b/backend/src/app.py @@ -1,4 +1,4 @@ -from flask import Flask, make_response +from flask import Flask, make_response, redirect from flask_cors import CORS from src.controllers.database import database_controller from src.controllers.spotify import spotify_controller @@ -43,11 +43,7 @@ def create_app(): @app.errorhandler(UnauthorizedException) def handle_unauthorized_exception(_): - resp = make_response( - "Spotify access token invalid or missing. Please re-authenticate.", 401 - ) - resp.delete_cookie("spotify_access_token") - resp.delete_cookie("user_id") + resp = redirect("/login", 401) return resp app.register_blueprint(auth_controller(spotify=spotify)) diff --git a/backend/src/controllers/auth.py b/backend/src/controllers/auth.py index 57298ff..7c53134 100644 --- a/backend/src/controllers/auth.py +++ b/backend/src/controllers/auth.py @@ -2,6 +2,7 @@ from flask import Blueprint, make_response, redirect, request, session from src.flask_config import Config from src.spotify import SpotifyClient +from src.utils.response_creator import add_cookies_to_response def auth_controller(spotify: SpotifyClient): @@ -36,7 +37,13 @@ def auth_redirect(): @auth_controller.route("refresh-user-code") def auth_refresh(): - refresh_token = request.cookies.get("spotify_refresh_token") - return spotify.refresh_access_token(refresh_token=refresh_token) + user_id = request.cookies.get("user_id") + (user_id, _, _) = spotify.refresh_access_token(user_id=user_id) + return add_cookies_to_response( + make_response(), + { + "user_id": user_id, + }, + ) return auth_controller diff --git a/backend/src/controllers/database.py b/backend/src/controllers/database.py index 5b98e63..6f75ed5 100644 --- a/backend/src/controllers/database.py +++ b/backend/src/controllers/database.py @@ -31,7 +31,7 @@ def database_controller( @database_controller.route("populate_user", methods=["GET"]) def populate_user(): access_token = request.cookies.get("spotify_access_token") - user = spotify.get_current_user(access_token) + user = spotify.get_user_by_id(access_token) (db_user, _) = get_or_create_user(user) simplified_playlists = spotify.get_all_playlists( user_id=user.id, access_token=access_token @@ -49,7 +49,9 @@ def populate_user(): if db_playlist is not None: delete_playlist(db_playlist.id) playlist = spotify.get_playlist( - access_token=access_token, id=simplified_playlist.id + user_id=user.id, + access_token=access_token, + id=simplified_playlist.id, ) create_playlist(playlist, db_user) @@ -58,12 +60,14 @@ def populate_user(): @database_controller.route("populate_playlist/", methods=["GET"]) def populate_playlist(id): access_token = request.cookies.get("spotify_access_token") - user = spotify.get_current_user(access_token) + user = spotify.get_user_by_id(access_token) (db_user, _) = get_or_create_user(user) db_playlist = get_playlist_by_id_or_none(id) if db_playlist is not None: delete_playlist(db_playlist.id) - playlist = spotify.get_playlist(access_token=access_token, id=id) + playlist = spotify.get_playlist( + user_id=user.id, access_token=access_token, id=id + ) create_playlist(playlist, db_user) albums = get_playlist_albums(playlist.id) batch_albums = split_list(albums, 20) @@ -83,7 +87,7 @@ def populate_playlist(id): @database_controller.route("populate_additional_album_details", methods=["GET"]) def populate_additional_album_details(): access_token = request.cookies.get("spotify_access_token") - user = spotify.get_current_user(access_token) + user = spotify.get_user_by_id(access_token) albums = get_user_albums_with_no_artists(user.id) batch_albums = split_list(albums, 20) for album_chunk in batch_albums: @@ -108,7 +112,7 @@ def populate_universal_genre_list(): @database_controller.route("populate_user_album_genres", methods=["GET"]) def populate_user_album_genres(): access_token = request.cookies.get("spotify_access_token") - user = spotify.get_current_user(access_token) + user = spotify.get_user_by_id(access_token) populate_album_genres_by_user_id(user.id, musicbrainz) return make_response("User album genres populated", 201) diff --git a/backend/src/controllers/music_data.py b/backend/src/controllers/music_data.py index 3e7132e..cd05969 100644 --- a/backend/src/controllers/music_data.py +++ b/backend/src/controllers/music_data.py @@ -79,18 +79,18 @@ def get_playlist(id): if db_playlist is not None: return make_response(jsonify(db_playlist.__data__), 200) else: - access_token = request.cookies.get("spotify_access_token") - playlist = spotify.get_playlist(access_token=access_token, id=id) - return make_response(jsonify(playlist.to_dict()), 200) + user_id = request.cookies.get("user_id") + playlist = spotify.get_playlist(user_id=user_id, id=id) + return make_response(jsonify(playlist.model_dump()), 200) @music_controller.route("playlist/", methods=["POST"]) def post_edit_playlist(id): - access_token = request.cookies.get("spotify_access_token") + user_id = request.cookies.get("user_id") name = request.json.get("name") description = request.json.get("description") update_playlist_info(id=id, name=name, description=description) spotify.update_playlist( - access_token=access_token, + user_id=user_id, id=id, name=name, description=description, @@ -104,12 +104,10 @@ def get_playlist_album_info(id): album_info_list = get_playlist_albums_with_genres(id) return jsonify(album_info_list) else: - access_token = request.cookies.get("spotify_access_token") + user_id = request.cookies.get("user_id") return [ album.model_dump() - for album in spotify.get_playlist_album_info( - access_token=access_token, id=id - ) + for album in spotify.get_playlist_album_info(user_id=user_id, id=id) ] @music_controller.route("playlist//tracks", methods=["GET"]) @@ -119,12 +117,10 @@ def get_playlist_tracks(id): track_list = get_playlist_track_list(id) return jsonify(track_list) else: - access_token = request.cookies.get("spotify_access_token") + user_id = request.cookies.get("user_id") return [ album.model_dump() - for album in spotify.get_playlist_album_info( - access_token=access_token, id=id - ) + for album in spotify.get_playlist_album_info(user_id=user_id, id=id) ] @music_controller.route("playlist/search", methods=["POST"]) @@ -135,7 +131,7 @@ def find_associated_playlists(): @music_controller.route("add_album_to_playlist", methods=["POST"]) def add_album_to_playlist(): - access_token = request.cookies.get("spotify_access_token") + user_id = request.cookies.get("user_id") request_body = request.json playlist_id = request_body["playlistId"] album_id = request_body["albumId"] @@ -145,13 +141,13 @@ def add_album_to_playlist(): ) create_playlist_album_relationship(playlist_id=playlist_id, album_id=album_id) return spotify.add_album_to_playlist( - access_token=access_token, playlist_id=playlist_id, album_id=album_id + user_id=user_id, playlist_id=playlist_id, album_id=album_id ) @music_controller.route("playback", methods=["GET"]) def get_playback_info(): - access_token = request.cookies.get("spotify_access_token") - playback_info = spotify.get_my_current_playback(access_token=access_token) + user_id = request.cookies.get("user_id") + playback_info = spotify.get_my_current_playback(user_id=user_id) if playback_info is None: return ("", 204) if playback_info.playlist_id is not None: diff --git a/backend/src/controllers/spotify.py b/backend/src/controllers/spotify.py index f10f137..2dac042 100644 --- a/backend/src/controllers/spotify.py +++ b/backend/src/controllers/spotify.py @@ -11,19 +11,16 @@ def spotify_controller(spotify: SpotifyClient): @spotify_controller.route("current-user") def get_current_user(): - access_token = request.cookies.get("spotify_access_token") - user = spotify.get_current_user(access_token=access_token) + user_id = request.cookies.get("user_id") + user = spotify.get_user_by_id(user_id=user_id) return user.model_dump() @spotify_controller.route("playlists") def index(): user_id = request.cookies.get("user_id") - access_token = request.cookies.get("spotify_access_token") limit = request.args.get("limit") offset = request.args.get("offset") - playlists = spotify.get_playlists( - user_id=user_id, access_token=access_token, limit=limit, offset=offset - ) + playlists = spotify.get_playlists(user_id=user_id, limit=limit, offset=offset) sort_by = request.args.get("sort_by") desc = request.args.get("desc") == "True" if sort_by is not None: @@ -33,12 +30,10 @@ def index(): @spotify_controller.route("create-playlist", methods=["POST"]) def create_playlist(): user_id = request.cookies.get("user_id") - access_token = request.cookies.get("spotify_access_token") name = request.json.get("name") description = request.json.get("description") spotify.create_playlist( user_id=user_id, - access_token=access_token, name=name, description=description, ) @@ -46,23 +41,23 @@ def create_playlist(): @spotify_controller.route("delete-playlist/", methods=["POST"]) def delete_playlist_by_id(id): - access_token = request.cookies.get("spotify_access_token") - spotify.delete_playlist(access_token=access_token, id=id) + user_id = request.cookies.get("user_id") + spotify.delete_playlist(user_id=user_id, id=id) return make_response("playlist deleted", 200) @spotify_controller.route("playlist/", methods=["GET"]) def get_edit_playlist(id): - access_token = request.cookies.get("spotify_access_token") - playlist = spotify.get_playlist(access_token=access_token, id=id) + user_id = request.cookies.get("user_id") + playlist = spotify.get_playlist(user_id=user_id, id=id) return playlist.model_dump() @spotify_controller.route("edit-playlist/", methods=["POST"]) def post_edit_playlist(id): - access_token = request.cookies.get("spotify_access_token") + user_id = request.cookies.get("user_id") name = request.json.get("name") description = request.json.get("description") spotify.update_playlist( - access_token=access_token, + user_id=user_id, id=id, name=name, description=description, @@ -71,28 +66,26 @@ def post_edit_playlist(id): @spotify_controller.route("playlist//albums", methods=["GET"]) def get_playlist_album_info(id): - access_token = request.cookies.get("spotify_access_token") + user_id = request.cookies.get("user_id") return [ album.model_dump() - for album in spotify.get_playlist_album_info( - access_token=access_token, id=id - ) + for album in spotify.get_playlist_album_info(user_id=user_id, id=id) ] @spotify_controller.route("playback", methods=["GET"]) def get_playback_info(): - access_token = request.cookies.get("spotify_access_token") - playback_info = spotify.get_my_current_playback(access_token=access_token) + user_id = request.cookies.get("user_id") + playback_info = spotify.get_my_current_playback(user_id=user_id) if playback_info is None: return ("", 204) return playback_info.model_dump_json() @spotify_controller.route("playlist_progress", methods=["POST"]) def get_playlist_progress(): - access_token = request.cookies.get("spotify_access_token") + user_id = request.cookies.get("user_id") api_playback = PlaybackInfo.model_validate(request.json) playlist_progression = spotify.get_playlist_progression( - access_token=access_token, api_playback=api_playback + user_id=user_id, api_playback=api_playback ) if playlist_progression is None: return ("", 204) @@ -100,10 +93,9 @@ def get_playlist_progress(): @spotify_controller.route("find_associated_playlists/", methods=["GET"]) def find_associated_playlists(playlist_id): - access_token = request.cookies.get("spotify_access_token") user_id = request.cookies.get("user_id") associated_playlists = spotify.find_associated_playlists( - user_id=user_id, access_token=access_token, playlist_id=playlist_id + user_id=user_id, playlist_id=playlist_id ) return [ associated_playlist.model_dump() @@ -112,7 +104,7 @@ def find_associated_playlists(playlist_id): @spotify_controller.route("add_album_to_playlist", methods=["POST"]) def add_album_to_playlist(): - access_token = request.cookies.get("spotify_access_token") + user_id = request.cookies.get("user_id") request_body = request.json playlist_id = request_body["playlistId"] album_id = request_body["albumId"] @@ -121,26 +113,28 @@ def add_album_to_playlist(): "Invalid request payload. Expected playlistId and albumId.", 400 ) return spotify.add_album_to_playlist( - access_token=access_token, playlist_id=playlist_id, album_id=album_id + user_id=user_id, playlist_id=playlist_id, album_id=album_id ) @spotify_controller.route("pause_playback", methods=["PUT"]) def pause_playback(): - access_token = request.cookies.get("spotify_access_token") - return spotify.pause_playback(access_token) + user_id = request.cookies.get("user_id") + return spotify.pause_playback(user_id) @spotify_controller.route("start_playback", methods=["PUT"]) def start_playback(): - access_token = request.cookies.get("spotify_access_token") + user_id = request.cookies.get("user_id") request_body = request.json start_playback_request_body = ( StartPlaybackRequest.model_validate(request_body) if request_body else None ) - return spotify.start_playback(access_token, start_playback_request_body) + return spotify.start_playback( + user_id=user_id, start_playback_request_body=start_playback_request_body + ) @spotify_controller.route("pause_or_start_playback", methods=["PUT"]) def pause_or_start_playback(): - access_token = request.cookies.get("spotify_access_token") - return spotify.pause_or_start_playback(access_token) + user_id = request.cookies.get("user_id") + return spotify.pause_or_start_playback(user_id=user_id) return spotify_controller diff --git a/backend/src/database/crud/user.py b/backend/src/database/crud/user.py index 05c5ab6..57867e9 100644 --- a/backend/src/database/crud/user.py +++ b/backend/src/database/crud/user.py @@ -38,3 +38,7 @@ def upsert_user_tokens(user_id: str, access_token: str, refresh_token: str): DbAccessToken.refresh_token: refresh_token, }, ).execute() + + +def get_user_tokens(user_id: str): + return DbAccessToken.get_or_none(DbAccessToken.user == user_id) diff --git a/backend/src/database/models.py b/backend/src/database/models.py index 4e7e2d4..0b8a3e1 100644 --- a/backend/src/database/models.py +++ b/backend/src/database/models.py @@ -119,8 +119,8 @@ class DbAccessToken(db_wrapper.Model): user = ForeignKeyField( DbUser, backref="owner", to_field="id", on_delete="CASCADE", unique=True ) - access_token = CharField(max_length=400) - refresh_token = CharField(max_length=200) + access_token = CharField(max_length=400, null=True) + refresh_token = CharField(max_length=200, null=True) class Meta: db_table = "access_token" diff --git a/backend/src/spotify.py b/backend/src/spotify.py index 4ec1928..74d897a 100644 --- a/backend/src/spotify.py +++ b/backend/src/spotify.py @@ -3,10 +3,11 @@ import requests import os import urllib.parse +from tenacity import RetryError, Retrying, stop_after_attempt, wait_fixed from typing import List, Optional from flask import Response, make_response, redirect from src.database.crud.album import get_album_genres -from src.database.crud.user import upsert_user_tokens +from src.database.crud.user import get_user_tokens, upsert_user_tokens from src.dataclasses.album import Album from src.dataclasses.playback_info import PlaybackInfo, PlaylistProgression from src.dataclasses.playback_request import ( @@ -60,7 +61,6 @@ def __init__(self): def response_handler(self, response: requests.Response, jsonify=True): if response.status_code == 401: - print(response.reason) raise UnauthorizedException else: if jsonify: @@ -82,7 +82,8 @@ def get_login_query_string(self, state): } ) - def refresh_access_token(self, refresh_token): + def refresh_access_token(self, user_id): + refresh_token = get_user_tokens(user_id).refresh_token if not refresh_token: raise UnauthorizedException response = requests.post( @@ -102,17 +103,17 @@ def refresh_access_token(self, refresh_token): api_response = self.response_handler(response) token_response = TokenResponse.model_validate(api_response) access_token = token_response.access_token - user_info = self.get_current_user(access_token) + refresh_token = ( + refresh_token + if token_response.refresh_token is None + else token_response.refresh_token + ) upsert_user_tokens( - user_info.id, + user_id=user_id, access_token=token_response.access_token, - refresh_token=token_response.refresh_token, - ) - resp = add_cookies_to_response( - make_response(), - {"spotify_access_token": access_token, "user_id": user_info.id}, + refresh_token=refresh_token, ) - return resp + return (user_id, access_token, refresh_token) def request_access_token(self, code): response = requests.post( @@ -138,48 +139,51 @@ def request_access_token(self, code): resp = add_cookies_to_response( make_response(redirect(f"{Config().FRONTEND_URL}/")), { - "spotify_access_token": token_response.access_token, - "spotify_refresh_token": token_response.refresh_token, "user_id": user_info.id, }, ) return resp - def get_playlists( - self, user_id, access_token, limit=10, offset=0 - ) -> CurrentUserPlaylists: - response = requests.get( - url=f"https://api.spotify.com/v1/users/{user_id}/playlists", - params={ - "limit": limit, - "offset": offset, - }, - auth=BearerAuth(access_token), - ) - api_playlists = self.response_handler(response) + def get_playlists(self, user_id, limit=10, offset=0) -> CurrentUserPlaylists: + try: + for attempt in Retrying( + wait=wait_fixed(2), + after=self.refresh_user_access_tokens(user_id=user_id), + ): + with attempt: + access_token = get_user_tokens(user_id=user_id).access_token + response = requests.get( + url=f"https://api.spotify.com/v1/users/{user_id}/playlists", + params={ + "limit": limit, + "offset": offset, + }, + auth=BearerAuth(access_token), + ) + api_playlists = self.response_handler(response) + except RetryError: + pass playlists = CurrentUserPlaylists.model_validate(api_playlists) return playlists - def get_all_playlists(self, user_id, access_token) -> List[SimplifiedPlaylist]: + def get_all_playlists(self, user_id) -> List[SimplifiedPlaylist]: playlists: List[SimplifiedPlaylist] = [] offset = 0 limit = 50 - api_playlists = self.get_playlists( - access_token=access_token, user_id=user_id, limit=limit, offset=offset - ) + api_playlists = self.get_playlists(user_id=user_id, limit=limit, offset=offset) while True: playlists += api_playlists.items if not api_playlists.next: return playlists offset += limit api_playlists = self.get_playlists( - access_token=access_token, user_id=user_id, limit=limit, offset=offset + user_id=user_id, limit=limit, offset=offset ) - def find_associated_playlists(self, user_id, access_token, playlist_id: str): + def find_associated_playlists(self, user_id, playlist_id: str): [playlist_name, user_playlists] = [ - self.get_playlist(access_token=access_token, id=playlist_id).name, - self.get_all_playlists(user_id=user_id, access_token=access_token), + self.get_playlist(id=playlist_id, user_id=user_id).name, + self.get_all_playlists(user_id=user_id), ] associated_playlists = [ matchingPlaylist @@ -189,6 +193,25 @@ def find_associated_playlists(self, user_id, access_token, playlist_id: str): ] return associated_playlists + def get_user_by_id(self, user_id): + try: + for attempt in Retrying( + wait=wait_fixed(1), + after=self.refresh_user_access_tokens(user_id=user_id), + stop=stop_after_attempt(2), + ): + access_token = get_user_tokens(user_id=user_id).access_token + with attempt: + response = requests.get( + url="https://api.spotify.com/v1/me", + auth=BearerAuth(access_token), + ) + api_current_user = self.response_handler(response) + current_user = User.model_validate(api_current_user) + except RetryError: + pass + return current_user + def get_current_user(self, access_token): response = requests.get( url="https://api.spotify.com/v1/me", @@ -198,28 +221,38 @@ def get_current_user(self, access_token): current_user = User.model_validate(api_current_user) return current_user - def get_playlist(self, access_token, id: str, fields=None): - response = requests.get( - url=f"https://api.spotify.com/v1/playlists/{id}", - params={ - "playlist_id": id, - "fields": fields, - }, - headers={ - "content-type": "application/json", - }, - auth=BearerAuth(access_token), - ) - api_playlist = self.response_handler(response) + def get_playlist(self, user_id, id: str, fields=None): + try: + for attempt in Retrying( + wait=wait_fixed(2), + after=self.refresh_user_access_tokens(user_id=user_id), + ): + with attempt: + access_token = get_user_tokens(user_id=user_id).access_token + response = requests.get( + url=f"https://api.spotify.com/v1/playlists/{id}", + params={ + "playlist_id": id, + "fields": fields, + }, + headers={ + "content-type": "application/json", + }, + auth=BearerAuth(access_token), + ) + api_playlist = self.response_handler(response) + except RetryError: + pass playlist = Playlist.model_validate(api_playlist) if playlist.tracks.next: playlist.tracks.items = self.get_playlist_tracks( - access_token=access_token, id=playlist.id + user_id=user_id, id=playlist.id ) return playlist - def create_playlist(self, user_id, access_token, name, description): + def create_playlist(self, user_id, name, description): description = None if description == "" else description + access_token = get_user_tokens(user_id=user_id).access_token response = requests.post( url=f"https://api.spotify.com/v1/users/{user_id}/playlists", data=json.dumps( @@ -236,8 +269,9 @@ def create_playlist(self, user_id, access_token, name, description): return self.response_handler(response, jsonify=False) def update_playlist( - self, access_token, id: str, name, description + self, user_id: str, id: str, name, description ): # ToDo: Figure out how to set description to empty string + access_token = get_user_tokens(user_id=user_id).access_token response = requests.put( url=f"https://api.spotify.com/v1/playlists/{id}", data=json.dumps({"name": name, "description": description, "public": True}), @@ -248,14 +282,16 @@ def update_playlist( ) return self.response_handler(response, jsonify=False) - def delete_playlist(self, access_token, id: str): + def delete_playlist(self, user_id, id: str): + access_token = get_user_tokens(user_id=user_id).access_token response = requests.delete( url=f"https://api.spotify.com/v1/playlists/{id}/followers", auth=BearerAuth(access_token), ) return self.response_handler(response, jsonify=False) - def get_playlist_items(self, access_token, id, limit, offset): + def get_playlist_items(self, user_id, id, limit, offset): + access_token = get_user_tokens(user_id=user_id).access_token response = requests.get( url=f"https://api.spotify.com/v1/playlists/{id}/tracks", params={ @@ -271,8 +307,8 @@ def get_playlist_items(self, access_token, id, limit, offset): playlist_tracks = PlaylistTracks.model_validate(api_playlist_tracks) return playlist_tracks - def get_playlist_album_info(self, access_token, id) -> List[Album]: - playlist_tracks = self.get_playlist_tracks(access_token, id) + def get_playlist_album_info(self, user_id, id) -> List[Album]: + playlist_tracks = self.get_playlist_tracks(user_id=user_id, id=id) playlist_albums: List[Album] = [] for track in playlist_tracks: if track.track.album not in playlist_albums: @@ -281,12 +317,12 @@ def get_playlist_album_info(self, access_token, id) -> List[Album]: album.genres = [genre.name for genre in get_album_genres(album.id)] return playlist_albums - def get_playlist_tracks(self, access_token, id: str): + def get_playlist_tracks(self, user_id, id: str): playlist_tracks: List[PlaylistTrackObject] = [] offset = 0 limit = 100 api_tracks_object = self.get_playlist_items( - access_token=access_token, id=id, limit=limit, offset=offset + user_id=user_id, id=id, limit=limit, offset=offset ) while True: sleep(0.5) @@ -295,10 +331,11 @@ def get_playlist_tracks(self, access_token, id: str): return playlist_tracks offset += limit api_tracks_object = self.get_playlist_items( - access_token, id, limit=limit, offset=offset + user_id=user_id, id=id, limit=limit, offset=offset ) - def get_album(self, access_token, id): + def get_album(self, user_id, id): + access_token = get_user_tokens(user_id=user_id).access_token response = requests.get( f"https://api.spotify.com/v1/albums/{id}", headers={ @@ -310,7 +347,8 @@ def get_album(self, access_token, id): album = Album.model_validate(api_album) return album - def get_multiple_albums(self, access_token, ids: List[str]) -> List[Album]: + def get_multiple_albums(self, user_id, ids: List[str]) -> List[Album]: + access_token = get_user_tokens(user_id=user_id).access_token encoded_ids = urllib.parse.quote_plus(",".join(ids)) response = requests.get( f"https://api.spotify.com/v1/albums?ids={encoded_ids}", @@ -323,9 +361,10 @@ def get_multiple_albums(self, access_token, ids: List[str]) -> List[Album]: albums = [Album.model_validate(api_album) for api_album in api_albums["albums"]] return albums - def get_current_playback(self, access_token) -> PlaybackState | None: + def get_current_playback(self, user_id) -> PlaybackState | None: + access_token = get_user_tokens(user_id=user_id).access_token response = requests.get( - f"https://api.spotify.com/v1/me/player", + "https://api.spotify.com/v1/me/player", auth=BearerAuth(access_token), ) api_current_playback = self.response_handler(response) @@ -336,8 +375,8 @@ def get_current_playback(self, access_token) -> PlaybackState | None: ) return current_playback - def get_my_current_playback(self, access_token) -> PlaybackInfo | None: - api_playback = self.get_current_playback(access_token=access_token) + def get_my_current_playback(self, user_id) -> PlaybackInfo | None: + api_playback = self.get_current_playback(user_id=user_id) if api_playback is None: return None @@ -346,7 +385,7 @@ def get_my_current_playback(self, access_token) -> PlaybackInfo | None: playlist_id = context.uri.replace("spotify:playlist:", "") else: playlist_id = None - album = self.get_album(access_token=access_token, id=api_playback.item.album.id) + album = self.get_album(user_id=user_id, id=api_playback.item.album.id) album_duration = sum([track.duration_ms for track in album.tracks.items]) album_progress = ( sum( @@ -379,13 +418,11 @@ def get_my_current_playback(self, access_token) -> PlaybackInfo | None: } ) - def get_playlist_progression(self, access_token, api_playback: PlaybackInfo): + def get_playlist_progression(self, user_id, api_playback: PlaybackInfo): playlist_tracks = self.get_playlist_tracks( - access_token=access_token, id=api_playback.playlist_id - ) - playlist_info = self.get_playlist( - access_token=access_token, id=api_playback.playlist_id + user_id=user_id, id=api_playback.playlist_id ) + playlist_info = self.get_playlist(user_id=user_id, id=api_playback.playlist_id) playlist_progress = get_playlist_progress(api_playback, playlist_tracks) playlist_duration = get_playlist_duration(playlist_tracks) return PlaylistProgression.model_validate( @@ -397,9 +434,9 @@ def get_playlist_progression(self, access_token, api_playback: PlaybackInfo): } ) - def search_albums( - self, access_token, search=None, offset=0, limit=50 - ) -> List[Album]: + def search_albums(self, user_id, search=None, offset=0, limit=50) -> List[Album]: + access_token = get_user_tokens(user_id=user_id).access_token + if search: response = requests.get( f"https://api.spotify.com/v1/albums/{id}", @@ -440,7 +477,9 @@ def search_albums( if x["album_type"] == "album" ] - def save_albums_to_library(self, access_token, album_ids: List[str]) -> Response: + def save_albums_to_library(self, user_id, album_ids: List[str]) -> Response: + access_token = get_user_tokens(user_id=user_id).access_token + response = requests.put( url="https://api.spotify.com/v1/me/albums", data=json.dumps( @@ -455,12 +494,13 @@ def save_albums_to_library(self, access_token, album_ids: List[str]) -> Response ) return self.response_handler(response, jsonify=False) - def add_album_to_playlist(self, access_token, playlist_id, album_id) -> Response: - album = self.get_album(access_token=access_token, id=album_id) - self.save_albums_to_library(access_token=access_token, album_ids=[album.id]) + def add_album_to_playlist(self, user_id, playlist_id, album_id) -> Response: + access_token = get_user_tokens(user_id=user_id).access_token + album = self.get_album(user_id=user_id, id=album_id) + self.save_albums_to_library(user_id=user_id, album_ids=[album.id]) track_uris = [item.uri for item in album.tracks.items] if self.is_album_in_playlist( - access_token=access_token, album=album, playlist_id=playlist_id + user_id=user_id, album=album, playlist_id=playlist_id ): return make_response("Album already present in playlist", 403) response = requests.post( @@ -481,15 +521,14 @@ def add_album_to_playlist(self, access_token, playlist_id, album_id) -> Response else: return make_response("Failed to add album to playlist", 400) - def is_album_in_playlist(self, access_token, playlist_id, album: Album) -> bool: - playlist_tracks = self.get_playlist_tracks( - access_token=access_token, id=playlist_id - ) + def is_album_in_playlist(self, user_id, playlist_id, album: Album) -> bool: + playlist_tracks = self.get_playlist_tracks(user_id=user_id, id=playlist_id) playlist_track_ids = [track.track.id for track in playlist_tracks] album_track_ids = [track.id for track in album.tracks.items] return all(e in playlist_track_ids for e in album_track_ids) - def pause_playback(self, access_token) -> Response: + def pause_playback(self, user_id) -> Response: + access_token = get_user_tokens(user_id=user_id).access_token response = requests.put( url="https://api.spotify.com/v1/me/player/pause", headers={ @@ -503,9 +542,11 @@ def pause_playback(self, access_token) -> Response: def start_playback( self, - access_token, + user_id, start_playback_request_body: Optional[StartPlaybackRequest] = None, ) -> Response: + access_token = get_user_tokens(user_id=user_id).access_token + if not start_playback_request_body: data = None else: @@ -514,7 +555,7 @@ def start_playback( { "uri": ( self.get_album( - access_token=access_token, + user_id=user_id, id=start_playback_request_body.offset.album_id, ) .tracks.items[0] @@ -538,12 +579,17 @@ def start_playback( make_response("", response.status_code), jsonify=False ) - def pause_or_start_playback(self, access_token) -> Response: - is_playing = self.get_current_playback(access_token).is_playing + def pause_or_start_playback(self, user_id) -> Response: + is_playing = self.get_current_playback(user_id=user_id).is_playing if is_playing: - return self.pause_playback(access_token) + return self.pause_playback(user_id=user_id) else: - return self.start_playback(access_token) + return self.start_playback(user_id=user_id) + + def refresh_user_access_tokens(self, user_id): + if not user_id: + raise UnauthorizedException + self.refresh_access_token(user_id=user_id) def get_playlist_duration(playlist_info: List[PlaylistTrackObject]) -> int: diff --git a/frontend/src/components/ButtonAsync.tsx b/frontend/src/components/ButtonAsync.tsx index 15c9118..1668e98 100644 --- a/frontend/src/components/ButtonAsync.tsx +++ b/frontend/src/components/ButtonAsync.tsx @@ -1,11 +1,12 @@ import React, { FC, useState, MouseEvent } from "react"; +import LoadingSpinner from "./LoadingSpinner"; const ButtonAsync: FC< React.DetailedHTMLProps< React.ButtonHTMLAttributes, HTMLButtonElement > -> = ({className, onClick, ...props}) => { +> = ({className, onClick,children, ...props}) => { const [isLoading, setIsLoading] = useState(false); const handleClick = async (event: MouseEvent): Promise => { @@ -21,7 +22,9 @@ const ButtonAsync: FC< disabled={isLoading} onClick={handleClick} className={`bg-primary rounded p-2 cursor-pointer hover:bg-primary-lighter active:bg-primary-lighter disabled:bg-background-offset ${className}`} - /> + > + {isLoading ? : children} + ); }; diff --git a/frontend/src/playlistExplorer/AlbumList/AlbumActions.tsx b/frontend/src/playlistExplorer/AlbumList/AlbumActions.tsx index acab3ab..4328988 100644 --- a/frontend/src/playlistExplorer/AlbumList/AlbumActions.tsx +++ b/frontend/src/playlistExplorer/AlbumList/AlbumActions.tsx @@ -1,5 +1,6 @@ import React, { FC } from "react"; import { Album } from "../../interfaces/Album"; +import ButtonAsync from "../../components/ButtonAsync"; import Button from "../../components/Button"; import { Playlist } from "../../interfaces/Playlist"; import { addAlbumToPlaylist, startPlayback } from "../../api"; @@ -14,12 +15,12 @@ const AlbumActions: FC = ({album, associatedPlaylists, contex return (
{associatedPlaylists.map((associatedPlaylist) => ( - + ))}