Skip to content

Commit

Permalink
Merge pull request bohning#228 from bohning/status-filter
Browse files Browse the repository at this point in the history
Filtering and sorting by status
  • Loading branch information
RumovZ authored Feb 13, 2024
2 parents 154fd5b + 2e64724 commit a567dfa
Show file tree
Hide file tree
Showing 11 changed files with 135 additions and 70 deletions.
84 changes: 63 additions & 21 deletions src/usdb_syncer/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import threading
import time
from pathlib import Path
from typing import Generator, Iterable, Iterator
from typing import Generator, Iterable, Iterator, assert_never, cast

import attrs

Expand All @@ -23,9 +23,11 @@ class _SqlCache:
_cache: dict[str, str] = {}

@classmethod
def get(cls, name: str) -> str:
def get(cls, name: str, cache: bool = True) -> str:
if (stmt := cls._cache.get(name)) is None:
cls._cache[name] = stmt = AppPaths.sql.joinpath(name).read_text("utf8")
stmt = AppPaths.sql.joinpath(name).read_text("utf8")
if cache:
cls._cache[name] = stmt
return stmt


Expand Down Expand Up @@ -75,7 +77,7 @@ def _validate_schema(connection: sqlite3.Connection) -> None:
"SELECT 1 FROM sqlite_schema WHERE type = 'table' AND name = 'meta'"
).fetchone()
if meta_table is None:
connection.executescript(_SqlCache.get("setup_script.sql"))
connection.executescript(_SqlCache.get("setup_script.sql", cache=False))
connection.execute(
"INSERT INTO meta (id, version, ctime) VALUES (1, ?, ?)",
(SCHEMA_VERSION, int(time.time() * 1_000_000)),
Expand All @@ -84,7 +86,7 @@ def _validate_schema(connection: sqlite3.Connection) -> None:
row = connection.execute("SELECT version FROM meta").fetchone()
if not row or row[0] != SCHEMA_VERSION:
raise errors.UnknownSchemaError
connection.execute("PRAGMA foreign_keys = ON")
connection.executescript(_SqlCache.get("setup_session_script.sql", cache=False))


def connect(db_path: Path | str, trace: bool = False) -> None:
Expand All @@ -95,6 +97,34 @@ def close() -> None:
_DbState.close()


class DownloadStatus(enum.IntEnum):
"""Status of song in download queue."""

NONE = 0
PENDING = enum.auto()
DOWNLOADING = enum.auto()
FAILED = enum.auto()

def __str__(self) -> str:
match self:
case DownloadStatus.NONE:
return ""
case DownloadStatus.PENDING:
return "Pending"
case DownloadStatus.DOWNLOADING:
return "Downloading"
case DownloadStatus.FAILED:
return "Failed"
case _ as unreachable:
assert_never(unreachable)

def can_be_downloaded(self) -> bool:
return self in (DownloadStatus.NONE, DownloadStatus.FAILED)

def can_be_aborted(self) -> bool:
return self in (DownloadStatus.PENDING, DownloadStatus.DOWNLOADING)


class SongOrder(enum.Enum):
"""Attributes songs can be sorted by."""

Expand All @@ -117,7 +147,8 @@ class SongOrder(enum.Enum):
VIDEO = "video.sync_meta_id IS NULL"
COVER = "cover.sync_meta_id IS NULL"
BACKGROUND = "background.sync_meta_id IS NULL"
SYNC_TIME = "sync_meta.mtime"
# max integer in SQLite
STATUS = "coalesce(usdb_song_status.status, sync_meta.mtime, 9223372036854775807)"


@attrs.define
Expand All @@ -130,10 +161,11 @@ class SearchBuilder:
artists: list[str] = attrs.field(factory=list)
titles: list[str] = attrs.field(factory=list)
editions: list[str] = attrs.field(factory=list)
languages: list[str] = attrs.field(factory=list)
golden_notes: bool | None = None
ratings: list[int] = attrs.field(factory=list)
statuses: list[DownloadStatus] = attrs.field(factory=list)
languages: list[str] = attrs.field(factory=list)
views: list[tuple[int, int | None]] = attrs.field(factory=list)
golden_notes: bool | None = None
downloaded: bool | None = None

def _filters(self) -> Iterator[str]:
Expand All @@ -142,23 +174,24 @@ def _filters(self) -> Iterator[str]:
"usdb_song.song_id IN (SELECT rowid FROM fts_usdb_song WHERE"
" fts_usdb_song MATCH ?)"
)
if self.artists:
yield _in_values_clause("usdb_song.artist", self.artists)
if self.titles:
yield _in_values_clause("usdb_song.title", self.titles)
if self.editions:
yield _in_values_clause("usdb_song.edition", self.editions)
for vals, col in (
(self.artists, "usdb_song.artist"),
(self.titles, "usdb_song.title"),
(self.editions, "usdb_song.edition"),
(self.ratings, "usdb_song.rating"),
(self.statuses, "usdb_song_status.status"),
):
if vals:
yield _in_values_clause(col, cast(list, vals))
if self.languages:
yield (
"usdb_song.song_id IN (SELECT song_id FROM usdb_song_language WHERE"
f" {_in_values_clause('language', self.languages)})"
)
if self.golden_notes is not None:
yield "usdb_song.golden_notes = ?"
if self.ratings:
yield _in_values_clause("usdb_song.rating", self.ratings)
if self.views:
yield _in_ranges_clause("usdb_song.views", self.views)
if self.golden_notes is not None:
yield "usdb_song.golden_notes = ?"
if self.downloaded is not None:
yield f"sync_meta.sync_meta_id IS {'NOT ' if self.downloaded else ''}NULL"

Expand All @@ -177,14 +210,15 @@ def parameters(self) -> Iterator[str | int | bool]:
yield from self.artists
yield from self.titles
yield from self.editions
yield from self.languages
if self.golden_notes is not None:
yield self.golden_notes
yield from self.ratings
yield from self.statuses
yield from self.languages
for min_views, max_views in self.views:
yield min_views
if max_views is not None:
yield max_views
if self.golden_notes is not None:
yield self.golden_notes

def statement(self) -> str:
select_from = _SqlCache.get("select_song_id.sql")
Expand Down Expand Up @@ -242,11 +276,19 @@ class UsdbSongParams:
genre: str
creator: str
tags: str
status: DownloadStatus


def upsert_usdb_song(params: UsdbSongParams) -> None:
stmt = _SqlCache.get("upsert_usdb_song.sql")
_DbState.connection().execute(stmt, params.__dict__)
if params.status is DownloadStatus.NONE:
_DbState.connection().execute(
"DELETE FROM usdb_song_status WHERE song_id = ?", (params.song_id,)
)
else:
stmt = _SqlCache.get("upsert_usdb_song_status.sql")
_DbState.connection().execute(stmt, params.__dict__)


def upsert_usdb_songs(params: Iterable[UsdbSongParams]) -> None:
Expand Down
1 change: 1 addition & 0 deletions src/usdb_syncer/db/sql/select_song_id.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ SELECT
usdb_song.song_id
FROM
usdb_song
LEFT JOIN usdb_song_status ON usdb_song.song_id = usdb_song_status.song_id
LEFT JOIN active_sync_meta ON usdb_song.song_id = active_sync_meta.song_id
AND active_sync_meta.rank = 1
LEFT JOIN sync_meta ON sync_meta.sync_meta_id = active_sync_meta.sync_meta_id
Expand Down
2 changes: 2 additions & 0 deletions src/usdb_syncer/db/sql/select_usdb_song.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ SELECT
usdb_song.genre,
usdb_song.creator,
usdb_song.tags,
coalesce(usdb_song_status.status, 0),
sync_meta.sync_meta_id,
sync_meta.song_id,
sync_meta.path,
Expand All @@ -34,6 +35,7 @@ SELECT
background.resource
FROM
usdb_song
LEFT JOIN usdb_song_status ON usdb_song.song_id = usdb_song_status.song_id
LEFT JOIN active_sync_meta ON usdb_song.song_id = active_sync_meta.song_id
AND active_sync_meta.rank = 1
LEFT JOIN sync_meta ON sync_meta.sync_meta_id = active_sync_meta.sync_meta_id
Expand Down
12 changes: 12 additions & 0 deletions src/usdb_syncer/db/sql/setup_session_script.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
BEGIN;

CREATE TEMPORARY TABLE usdb_song_status (
song_id INTEGER NOT NULL,
status INTEGER NOT NULL,
PRIMARY KEY (song_id),
FOREIGN KEY (song_id) REFERENCES usdb_song (song_id) ON DELETE CASCADE
);

PRAGMA foreign_keys = ON;

COMMIT;
7 changes: 7 additions & 0 deletions src/usdb_syncer/db/sql/upsert_usdb_song_status.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
INSERT INTO
usdb_song_status
VALUES
(:song_id, :status) ON CONFLICT (song_id) DO
UPDATE
SET
status = :status
5 changes: 3 additions & 2 deletions src/usdb_syncer/gui/mw.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ def _setup_statusbar(self) -> None:
self._status_label = QLabel(self)
self.statusbar.addWidget(self._status_label)

def on_count_changed(shown_count: int) -> None:
def on_count_changed(rows: int, selected: int) -> None:
total = db.usdb_song_count()
self._status_label.setText(
f"{shown_count} out of {db.usdb_song_count()} songs shown."
f"{rows} out of {total} songs shown, {selected} selected."
)

self.table.connect_row_count_changed(on_count_changed)
Expand Down
24 changes: 20 additions & 4 deletions src/usdb_syncer/gui/search_tree/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,21 +277,37 @@ class StatusVariant(SongMatch, enum.Enum):

NONE = enum.auto()
DOWNLOADED = enum.auto()
IN_PROGRESS = enum.auto()
FAILED = enum.auto()

def __str__(self) -> str:
match self:
case StatusVariant.NONE:
return "Not downloaded"
case StatusVariant.DOWNLOADED:
return "Downloaded"
case StatusVariant.IN_PROGRESS:
return "In progress"
case StatusVariant.FAILED:
return "Failed"
case _ as unreachable:
assert_never(unreachable)

def build_search(self, search: db.SearchBuilder) -> None:
if search.downloaded is None:
search.downloaded = self is StatusVariant.DOWNLOADED
else:
search.downloaded = None
match self:
case StatusVariant.IN_PROGRESS:
search.statuses.append(db.DownloadStatus.PENDING)
search.statuses.append(db.DownloadStatus.DOWNLOADING)
case StatusVariant.FAILED:
search.statuses.append(db.DownloadStatus.FAILED)
case StatusVariant.NONE | StatusVariant.DOWNLOADED:
search.downloaded = (
self is StatusVariant.DOWNLOADED
if search.downloaded is None
else None
)
case unreachable:
assert_never(unreachable)


class RatingVariant(SongMatch, enum.Enum):
Expand Down
2 changes: 1 addition & 1 deletion src/usdb_syncer/gui/song_table/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,6 @@ def song_order(self) -> db.SongOrder:
case Column.BACKGROUND:
return db.SongOrder.BACKGROUND
case Column.DOWNLOAD_STATUS:
return db.SongOrder.SYNC_TIME
return db.SongOrder.STATUS
case unreachable:
assert_never(unreachable)
11 changes: 8 additions & 3 deletions src/usdb_syncer/gui/song_table/song_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def _download_inner(self, rows: Iterable[int]) -> None:
continue
if song.status.can_be_downloaded():
song.status = DownloadStatus.PENDING
with db.transaction():
song.upsert()
events.SongChanged(song.song_id).post()
to_download.append(song)
if to_download:
Expand Down Expand Up @@ -209,15 +211,18 @@ def set_selection_to_indices(self, rows: Iterable[QModelIndex]) -> None:

### sorting and filtering

def connect_row_count_changed(self, func: Callable[[int], None]) -> None:
"""Calls `func` with the new row count."""
def connect_row_count_changed(self, func: Callable[[int, int], None]) -> None:
"""Calls `func` with the new table row and selection counts."""

def wrapped(*_: Any) -> None:
func(self._model.rowCount())
func(
self._model.rowCount(), len(self._view.selectionModel().selectedRows())
)

self._model.modelReset.connect(wrapped)
self._model.rowsInserted.connect(wrapped)
self._model.rowsRemoved.connect(wrapped)
self._view.selectionModel().selectionChanged.connect(wrapped)

def _setup_search_timer(self) -> None:
self._search_timer = QTimer(self.mw)
Expand Down
19 changes: 12 additions & 7 deletions src/usdb_syncer/song_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def abort(cls, songs: Iterable[SongId]) -> None:
if cls._threadpool().tryTake(job):
job.logger.info("Download aborted by user request.")
job.song.status = DownloadStatus.NONE
with db.transaction():
job.song.upsert()
events.SongChanged(job.song_id).post()
events.DownloadFinished(job.song_id).post()
else:
Expand Down Expand Up @@ -304,9 +306,8 @@ def __init__(self, song: UsdbSong, options: download_options.Options) -> None:
self.logger = get_logger(__file__, self.song_id)

def run(self) -> None:
change_event: events.SubscriptableEvent = events.SongChanged(self.song_id)
try:
updated_song = self._run_inner()
self.song = self._run_inner()
except errors.AbortError:
self.logger.info("Download aborted by user request.")
self.song.status = DownloadStatus.NONE
Expand All @@ -317,7 +318,9 @@ def run(self) -> None:
self.logger.error("Song has been deleted from USDB.")
with db.transaction():
self.song.delete()
change_event = events.SongDeleted(self.song_id)
events.SongDeleted(self.song_id).post()
events.DownloadFinished(self.song_id).post()
return
except Exception: # pylint: disable=broad-except
self.logger.debug(traceback.format_exc())
self.logger.error(
Expand All @@ -326,16 +329,18 @@ def run(self) -> None:
)
self.song.status = DownloadStatus.FAILED
else:
updated_song.status = DownloadStatus.NONE
with db.transaction():
updated_song.upsert()
self.song.status = DownloadStatus.NONE
self.logger.info("All done!")
change_event.post()
with db.transaction():
self.song.upsert()
events.SongChanged(self.song_id).post()
events.DownloadFinished(self.song_id).post()

def _run_inner(self) -> UsdbSong:
self._check_flags()
self.song.status = DownloadStatus.DOWNLOADING
with db.transaction():
self.song.upsert()
events.SongChanged(self.song_id).post()
with tempfile.TemporaryDirectory() as tempdir:
ctx = _Context.new(self.song, self.options, Path(tempdir), self.logger)
Expand Down
Loading

0 comments on commit a567dfa

Please sign in to comment.