Skip to content

Commit

Permalink
Merge pull request bohning#279 from bohning/sqlite-error
Browse files Browse the repository at this point in the history
Attempt to fix sqlite3.InterfaceError
  • Loading branch information
RumovZ authored Aug 25, 2024
2 parents 80991aa + 0deb594 commit e426ad5
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 81 deletions.
3 changes: 1 addition & 2 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
Pipfile text eol=lf
Pipfile.lock text eol=lf
text eol=lf
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ load-plugins = "pylint.extensions.mccabe"
good-names = ["mw", "p1", "p2", "closeEvent", "customEvent"]

[tool.pylint.messages_control]
extension-pkg-whitelist = ["PySide6"]
extension-pkg-whitelist = ["PySide6", "shiboken6"]
disable = [
"too-few-public-methods",
"logging-fstring-interpolation",
Expand Down
64 changes: 39 additions & 25 deletions src/usdb_syncer/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import contextlib
import enum
import itertools
import json
import os
import sqlite3
import threading
import time
Expand Down Expand Up @@ -40,45 +40,50 @@ def get(cls, name: str, cache: bool = True) -> str:
return stmt


class _LocalConnection(threading.local):
"""A thread-local database connection."""

connection: sqlite3.Connection | None = None


class _DbState:
"""Singleton for managing the global database connection."""

lock = threading.Lock()
_connection: sqlite3.Connection | None = None
_local: _LocalConnection = _LocalConnection()

@classmethod
def connect(cls, db_path: Path | str, trace: bool = False) -> None:
with cls.lock:
cls._connection = sqlite3.connect(
db_path, check_same_thread=False, isolation_level=None
)
if trace:
cls._connection.set_trace_callback(_logger.debug)
_validate_schema(cls._connection)
if cls._local.connection:
raise errors.DatabaseError("Already connected to database!")
cls._local.connection = sqlite3.connect(
db_path, check_same_thread=False, isolation_level=None
)
if trace:
cls._local.connection.set_trace_callback(_logger.debug)
_validate_schema(cls._local.connection)

@classmethod
def connection(cls) -> sqlite3.Connection:
if cls._connection is None:
if cls._local.connection is None:
raise errors.DatabaseError("Not connected to database!")
return cls._connection
return cls._local.connection

@classmethod
def close(cls) -> None:
if _DbState._connection is not None:
_DbState._connection.close()
_DbState._connection = None
if _DbState._local.connection is not None:
_DbState._local.connection.close()
_DbState._local.connection = None


@contextlib.contextmanager
def transaction() -> Generator[None, None, None]:
with _DbState.lock:
try:
_DbState.connection().execute("BEGIN")
yield None
except Exception: # pylint: disable=broad-except
_DbState.connection().rollback()
raise
_DbState.connection().commit()
try:
_DbState.connection().execute("BEGIN IMMEDIATE")
yield None
except Exception: # pylint: disable=broad-except
_DbState.connection().rollback()
raise
_DbState.connection().commit()


def _validate_schema(connection: sqlite3.Connection) -> None:
Expand All @@ -104,14 +109,23 @@ def _validate_schema(connection: sqlite3.Connection) -> None:
connection.executescript(_SqlCache.get("setup_session_script.sql", cache=False))


def connect(db_path: Path | str, trace: bool = False) -> None:
_DbState.connect(db_path, trace=trace)
def connect(db_path: Path | str) -> None:
_DbState.connect(db_path, trace=bool(os.environ.get("TRACESQL")))


def close() -> None:
_DbState.close()


@contextlib.contextmanager
def managed_connection(db_path: Path | str) -> Generator[None, None, None]:
try:
_DbState.connect(db_path)
yield None
finally:
_DbState.close()


class DownloadStatus(enum.IntEnum):
"""Status of song in download queue."""

Expand Down
4 changes: 3 additions & 1 deletion src/usdb_syncer/db/sql/setup_session_script.sql
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
PRAGMA journal_mode = WAL;

BEGIN;

CREATE TEMPORARY TABLE session_usdb_song (
CREATE TEMPORARY TABLE IF NOT EXISTS session_usdb_song (
song_id INTEGER NOT NULL,
status INTEGER NOT NULL DEFAULT 0,
is_playing BOOLEAN NOT NULL DEFAULT false,
Expand Down
4 changes: 3 additions & 1 deletion src/usdb_syncer/gui/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
settings,
song_routines,
sync_meta,
usdb_song,
utils,
)

Expand Down Expand Up @@ -87,11 +88,12 @@ def _load_main_window(mw: MainWindow) -> None:
QtWidgets.QApplication.processEvents()
splash.showMessage("Loading song database ...", color=Qt.GlobalColor.gray)
folder = settings.get_song_dir()
db.connect(utils.AppPaths.db, trace=bool(os.environ.get("TRACESQL")))
db.connect(utils.AppPaths.db)
with db.transaction():
song_routines.load_available_songs(force_reload=False)
song_routines.synchronize_sync_meta_folder(folder)
sync_meta.SyncMeta.reset_active(folder)
usdb_song.UsdbSong.clear_cache()
default_search = db.SavedSearch.get_default()
mw.tree.populate()
if default_search:
Expand Down
7 changes: 4 additions & 3 deletions src/usdb_syncer/gui/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import attrs
from PySide6 import QtCore, QtGui, QtWidgets

from usdb_syncer import logger
from usdb_syncer import db, logger, utils

_logger = logger.get_logger(__file__)
T = TypeVar("T")
Expand Down Expand Up @@ -74,14 +74,15 @@ def run_with_progress(
def wrapped_task() -> None:
nonlocal result
try:
result = Result(task())
with db.managed_connection(utils.AppPaths.db):
result = Result(task())
except Exception as exc: # pylint: disable=broad-exception-caught
result = Result(_Error(exc))
signal.result.emit()

def wrapped_on_done() -> None:
assert result
dialog.close()
dialog.deleteLater()
on_done(result)

signal.result.connect(wrapped_on_done)
Expand Down
56 changes: 29 additions & 27 deletions src/usdb_syncer/song_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import mutagen.oggopus
import mutagen.oggvorbis
import send2trash
import shiboken6
from mutagen import id3
from mutagen.flac import Picture
from PIL import Image
Expand Down Expand Up @@ -63,7 +64,7 @@ def download(cls, songs: Iterable[UsdbSong]) -> None:
@classmethod
def abort(cls, songs: Iterable[SongId]) -> None:
for song in songs:
if job := cls._jobs.get(song):
if (job := cls._jobs.get(song)) and shiboken6.isValid(job):
if cls._threadpool().tryTake(job):
job.logger.info("Download aborted by user request.")
job.song.status = DownloadStatus.NONE
Expand Down Expand Up @@ -335,33 +336,34 @@ def __init__(self, song: UsdbSong, options: download_options.Options) -> None:
self.logger = get_logger(__file__, self.song_id)

def run(self) -> None:
try:
self.song = self._run_inner()
except errors.AbortError:
self.logger.info("Download aborted by user request.")
self.song.status = DownloadStatus.NONE
except errors.UsdbLoginError:
self.logger.error("Aborted; download requires login.")
self.song.status = DownloadStatus.FAILED
except errors.UsdbNotFoundError:
self.logger.error("Song has been deleted from USDB.")
with db.managed_connection(utils.AppPaths.db):
try:
self.song = self._run_inner()
except errors.AbortError:
self.logger.info("Download aborted by user request.")
self.song.status = DownloadStatus.NONE
except errors.UsdbLoginError:
self.logger.error("Aborted; download requires login.")
self.song.status = DownloadStatus.FAILED
except errors.UsdbNotFoundError:
self.logger.error("Song has been deleted from USDB.")
with db.transaction():
self.song.delete()
events.SongDeleted(self.song_id).post()
events.DownloadFinished(self.song_id).post()
return
except Exception: # pylint: disable=broad-except
self.logger.debug(traceback.format_exc())
self.logger.error(
"Failed to finish download due to an unexpected error. "
"See debug log for more information."
)
self.song.status = DownloadStatus.FAILED
else:
self.song.status = DownloadStatus.NONE
self.logger.info("All done!")
with db.transaction():
self.song.delete()
events.SongDeleted(self.song_id).post()
events.DownloadFinished(self.song_id).post()
return
except Exception: # pylint: disable=broad-except
self.logger.debug(traceback.format_exc())
self.logger.error(
"Failed to finish download due to an unexpected error. "
"See debug log for more information."
)
self.song.status = DownloadStatus.FAILED
else:
self.song.status = DownloadStatus.NONE
self.logger.info("All done!")
with db.transaction():
self.song.upsert()
self.song.upsert()
events.SongChanged(self.song_id).post()
events.DownloadFinished(self.song_id).post()

Expand Down
6 changes: 3 additions & 3 deletions src/usdb_syncer/usdb_song.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def remove_sync_meta(self) -> None:
if self.sync_meta:
self.sync_meta.delete()
self.sync_meta = None
_UsdbSongCache.remove(self.song_id)
_UsdbSongCache.update(self)

@classmethod
def delete_all(cls) -> None:
Expand All @@ -128,7 +128,7 @@ def upsert(self) -> None:
db.upsert_usdb_songs_creators([(self.song_id, self.creators())])
if self.sync_meta:
self.sync_meta.upsert()
_UsdbSongCache.remove(self.song_id)
_UsdbSongCache.update(self)

@classmethod
def upsert_many(cls, songs: list[UsdbSong]) -> None:
Expand All @@ -138,7 +138,7 @@ def upsert_many(cls, songs: list[UsdbSong]) -> None:
db.upsert_usdb_songs_creators([(s.song_id, s.creators()) for s in songs])
SyncMeta.upsert_many([song.sync_meta for song in songs if song.sync_meta])
for song in songs:
_UsdbSongCache.remove(song.song_id)
_UsdbSongCache.update(song)

def db_params(self) -> db.UsdbSongParams:
return db.UsdbSongParams(
Expand Down
36 changes: 18 additions & 18 deletions tests/unit/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@


def test_persisting_usdb_song(song: UsdbSong) -> None:
db.connect(":memory:")
song.upsert()
db.reset_active_sync_metas(Path("C:"))
db_song = UsdbSong.get(song.song_id)
with db.managed_connection(":memory:"):
song.upsert()
db.reset_active_sync_metas(Path("C:"))
db_song = UsdbSong.get(song.song_id)

assert db_song
assert attrs.asdict(song) == attrs.asdict(db_song)
Expand All @@ -29,17 +29,17 @@ def test_persisting_saved_search() -> None:
years=[1990, 2000, 2010],
),
)
db.connect(":memory:")
search.insert()
saved = list(db.SavedSearch.load_saved_searches())
assert len(saved) == 1
assert search.name == "name"
assert saved[0] == search

search.insert()
assert search.name == "name (1)"
assert len(list(db.SavedSearch.load_saved_searches())) == 2

search.update(new_name="name")
assert search.name == "name (1)"
assert len(list(db.SavedSearch.load_saved_searches())) == 2
with db.managed_connection(":memory:"):
search.insert()
saved = list(db.SavedSearch.load_saved_searches())
assert len(saved) == 1
assert search.name == "name"
assert saved[0] == search

search.insert()
assert search.name == "name (1)"
assert len(list(db.SavedSearch.load_saved_searches())) == 2

search.update(new_name="name")
assert search.name == "name (1)"
assert len(list(db.SavedSearch.load_saved_searches())) == 2

0 comments on commit e426ad5

Please sign in to comment.