diff --git a/src/db/mod.rs b/src/db/mod.rs index 6364a6c1..8250166f 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -30,6 +30,7 @@ pub enum PixivDownloaderDbError { AnyHow(anyhow::Error), #[cfg(feature = "db_sqlite")] Sqlite(SqliteError), + Str(String), } impl PixivDownloaderDbError { diff --git a/src/db/sqlite/db.rs b/src/db/sqlite/db.rs index 793b9a12..205666ca 100644 --- a/src/db/sqlite/db.rs +++ b/src/db/sqlite/db.rs @@ -9,6 +9,8 @@ use super::super::{PushConfig, PushTask, PushTaskConfig}; #[cfg(feature = "server")] use super::super::{Token, User}; use super::SqliteError; +#[cfg(feature = "server")] +use crate::tmp_cache::TmpCacheEntry; use bytes::BytesMut; use chrono::{DateTime, Utc}; use flagset::FlagSet; @@ -76,6 +78,12 @@ id INT, lang TEXT, translated TEXT );"; +const TMP_CACHE_TABLE: &'static str = "CREATE TABLE tmp_cache ( +url TEXT, +path TEXT, +last_used DATETIME, +PRIMARY KEY (url) +);"; const TOKEN_TABLE: &'static str = "CREATE TABLE token ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INT, @@ -98,7 +106,7 @@ v3 INT, v4 INT, PRIMARY KEY (id) );"; -const VERSION: [u8; 4] = [1, 0, 0, 8]; +const VERSION: [u8; 4] = [1, 0, 0, 9]; pub struct PixivDownloaderSqlite { db: Mutex, @@ -260,6 +268,9 @@ impl PixivDownloaderSqlite { if db_version < [1, 0, 0, 8] { tx.execute(PUSH_TASK_DATA_TABLE, [])?; } + if db_version < [1, 0, 0, 9] { + tx.execute(TMP_CACHE_TABLE, [])?; + } self._write_version(&tx)?; tx.commit()?; } @@ -313,10 +324,19 @@ impl PixivDownloaderSqlite { if !tables.contains_key("push_task_data") { t.execute(PUSH_TASK_DATA_TABLE, [])?; } + if !tables.contains_key("tmp_cache") { + t.execute(TMP_CACHE_TABLE, [])?; + } t.commit()?; Ok(()) } + #[cfg(feature = "server")] + fn _delete_tmp_cache(tx: &Transaction, url: &str) -> Result<(), SqliteError> { + tx.execute("DELETE FROM tmp_cache WHERE url = ?;", [url])?; + Ok(()) + } + #[cfg(feature = "server")] fn _delete_token(tx: &Transaction, id: u64) -> Result<(), SqliteError> { tx.execute("DELETE FROM token 
WHERE id = ?;", [id])?; @@ -450,6 +470,44 @@ impl PixivDownloaderSqlite { .optional()?) } + #[cfg(feature = "server")] + async fn get_tmp_cache( + &self, + url: &str, + ) -> Result, PixivDownloaderDbError> { + let con = self.db.lock().await; + Ok(con + .query_row("SELECT * FROM tmp_cache WHERE url = ?;", [url], |row| { + Ok(TmpCacheEntry { + url: row.get(0)?, + path: row.get(1)?, + last_used: row.get(2)?, + }) + }) + .optional()?) + } + + #[cfg(feature = "server")] + async fn get_tmp_caches(&self, ttl: i64) -> Result, PixivDownloaderDbError> { + let t = Utc::now() + .checked_sub_signed(chrono::TimeDelta::seconds(ttl)) + .ok_or(PixivDownloaderDbError::Str(String::from( + "Failed to calculate expired time by ttl.", + )))?; + let con = self.db.lock().await; + let mut stmt = con.prepare("SELECT * FROM tmp_cache WHERE last_used < ?;")?; + let mut rows = stmt.query([t])?; + let mut entries = Vec::new(); + while let Some(row) = rows.next()? { + entries.push(TmpCacheEntry { + url: row.get(0)?, + path: row.get(1)?, + last_used: row.get(2)?, + }); + } + Ok(entries) + } + #[cfg(feature = "server")] async fn get_token(&self, id: u64) -> Result, SqliteError> { let con = self.db.lock().await; @@ -573,6 +631,16 @@ impl PixivDownloaderSqlite { } } + #[cfg(feature = "server")] + fn _put_tmp_cache(ts: &Transaction, url: &str, path: &str) -> Result<(), SqliteError> { + let t = Utc::now(); + ts.execute( + "INSERT INTO tmp_cache (url, path, last_used) VALUES (?, ?, ?);", + (url, path, t), + )?; + Ok(()) + } + #[cfg(feature = "server")] fn _revoke_expired_tokens(ts: &Transaction) -> Result { let now = Utc::now(); @@ -725,6 +793,16 @@ impl PixivDownloaderSqlite { Ok(()) } + #[cfg(feature = "server")] + fn _update_tmp_cache(tx: &Transaction, url: &str) -> Result<(), PixivDownloaderDbError> { + let now = Utc::now(); + tx.execute( + "UPDATE tmp_cache SET last_used = ? 
WHERE url = ?;", + (now, url), + )?; + Ok(()) + } + #[cfg(feature = "server")] async fn _update_user( &self, @@ -923,6 +1001,15 @@ impl PixivDownloaderDb for PixivDownloaderSqlite { .expect("User not found:")) } + #[cfg(feature = "server")] + async fn delete_tmp_cache(&self, url: &str) -> Result<(), PixivDownloaderDbError> { + let mut db = self.db.lock().await; + let mut tx = db.transaction()?; + Self::_delete_tmp_cache(&mut tx, url)?; + tx.commit()?; + Ok(()) + } + #[cfg(feature = "server")] async fn delete_token(&self, id: u64) -> Result<(), PixivDownloaderDbError> { let mut db = self.db.lock().await; @@ -1006,6 +1093,19 @@ impl PixivDownloaderDb for PixivDownloaderSqlite { Ok(self.get_push_task_data(id).await?) } + #[cfg(feature = "server")] + async fn get_tmp_cache( + &self, + url: &str, + ) -> Result, PixivDownloaderDbError> { + Ok(self.get_tmp_cache(url).await?) + } + + #[cfg(feature = "server")] + async fn get_tmp_caches(&self, ttl: i64) -> Result, PixivDownloaderDbError> { + Ok(self.get_tmp_caches(ttl).await?) + } + #[cfg(feature = "server")] async fn get_token(&self, id: u64) -> Result, PixivDownloaderDbError> { Ok(self.get_token(id).await?) @@ -1049,6 +1149,15 @@ impl PixivDownloaderDb for PixivDownloaderSqlite { Ok(self._list_users_id(offset, count).await?) 
} + #[cfg(feature = "server")] + async fn put_tmp_cache(&self, url: &str, path: &str) -> Result<(), PixivDownloaderDbError> { + let mut db = self.db.lock().await; + let mut tx = db.transaction()?; + let size = Self::_put_tmp_cache(&mut tx, url, path)?; + tx.commit()?; + Ok(size) + } + #[cfg(feature = "server")] async fn revoke_expired_tokens(&self) -> Result { let mut db = self.db.lock().await; @@ -1123,6 +1232,15 @@ impl PixivDownloaderDb for PixivDownloaderSqlite { Ok(()) } + #[cfg(feature = "server")] + async fn update_tmp_cache(&self, url: &str) -> Result<(), PixivDownloaderDbError> { + let mut db = self.db.lock().await; + let mut tx = db.transaction()?; + let size = Self::_update_tmp_cache(&mut tx, url)?; + tx.commit()?; + Ok(size) + } + #[cfg(feature = "server")] async fn update_user( &self, diff --git a/src/db/sqlite/error.rs b/src/db/sqlite/error.rs index b62753f2..f4f8dc50 100644 --- a/src/db/sqlite/error.rs +++ b/src/db/sqlite/error.rs @@ -4,4 +4,11 @@ pub enum SqliteError { DatabaseVersionTooNew, UserNameAlreadyExists, SerdeError(serde_json::Error), + Str(String), +} + +impl From<&str> for SqliteError { + fn from(value: &str) -> Self { + Self::Str(String::from(value)) + } } diff --git a/src/db/traits.rs b/src/db/traits.rs index ea223829..252229da 100644 --- a/src/db/traits.rs +++ b/src/db/traits.rs @@ -5,6 +5,8 @@ use super::{PixivArtwork, PixivArtworkLock}; use super::{PushConfig, PushTask, PushTaskConfig}; #[cfg(feature = "server")] use super::{Token, User}; +#[cfg(feature = "server")] +use crate::tmp_cache::TmpCacheEntry; use chrono::{DateTime, Utc}; use flagset::FlagSet; @@ -90,6 +92,10 @@ pub trait PixivDownloaderDb { is_admin: bool, ) -> Result; #[cfg(feature = "server")] + /// Delete tmp cache entry + /// * `url` - URL + async fn delete_tmp_cache(&self, url: &str) -> Result<(), PixivDownloaderDbError>; + #[cfg(feature = "server")] /// Delete a token /// * `id` - The token ID async fn delete_token(&self, id: u64) -> Result<(), 
PixivDownloaderDbError>; @@ -150,6 +156,17 @@ pub trait PixivDownloaderDb { /// * `id` - The task's ID async fn get_push_task_data(&self, id: u64) -> Result, PixivDownloaderDbError>; #[cfg(feature = "server")] + /// Get tmp cache entry via url + /// * `url` - URL + async fn get_tmp_cache( + &self, + url: &str, + ) -> Result, PixivDownloaderDbError>; + #[cfg(feature = "server")] + /// Get tmp cache entries should deleted + /// * `ttl` - Time to live in seconds + async fn get_tmp_caches(&self, ttl: i64) -> Result, PixivDownloaderDbError>; + #[cfg(feature = "server")] /// Get token by ID /// * `id` - The token ID async fn get_token(&self, id: u64) -> Result, PixivDownloaderDbError>; @@ -185,6 +202,11 @@ pub trait PixivDownloaderDb { count: u64, ) -> Result, PixivDownloaderDbError>; #[cfg(feature = "server")] + /// Put new tmp cache + /// * `url` - URL + /// * `path` - Path + async fn put_tmp_cache(&self, url: &str, path: &str) -> Result<(), PixivDownloaderDbError>; + #[cfg(feature = "server")] /// Remove all expired tokens /// Return the number of removed tokens async fn revoke_expired_tokens(&self) -> Result; @@ -238,6 +260,10 @@ pub trait PixivDownloaderDb { last_updated: &DateTime, ) -> Result<(), PixivDownloaderDbError>; #[cfg(feature = "server")] + /// Update tmp cache last used time + /// * `url` - URL + async fn update_tmp_cache(&self, url: &str) -> Result<(), PixivDownloaderDbError>; + #[cfg(feature = "server")] /// Update a user's information /// * `id`: The user's ID /// * `name`: The user's name diff --git a/src/server/context.rs b/src/server/context.rs index be11b3b8..fb6775ed 100644 --- a/src/server/context.rs +++ b/src/server/context.rs @@ -11,6 +11,7 @@ use crate::get_helper; use crate::gettext; use crate::pixiv_app::PixivAppClient; use crate::pixiv_web::PixivWebClient; +use crate::tmp_cache::TmpCache; use crate::utils::get_file_name_from_url; use futures_util::lock::Mutex; use hyper::{http::response::Builder, Body, Request, Response}; @@ -24,19 +25,22 
@@ pub struct ServerContext { pub cors: CorsContext, pub db: Arc>, pub rsa_key: Mutex>, + pub tmp_cache: Arc, pub _pixiv_app_client: Mutex>, pub _pixiv_web_client: Mutex>>, } impl ServerContext { pub async fn default() -> Self { + let db = match open_and_init_database(get_helper().db()).await { + Ok(db) => Arc::new(db), + Err(e) => panic!("{} {}", gettext("Failed to open database:"), e), + }; Self { cors: CorsContext::default(), - db: match open_and_init_database(get_helper().db()).await { - Ok(db) => Arc::new(db), - Err(e) => panic!("{} {}", gettext("Failed to open database:"), e), - }, + db: db.clone(), rsa_key: Mutex::new(None), + tmp_cache: Arc::new(TmpCache::new(db)), _pixiv_app_client: Mutex::new(None), _pixiv_web_client: Mutex::new(None), } diff --git a/src/server/push/task/pixiv_send_message.rs b/src/server/push/task/pixiv_send_message.rs index 424b54dd..2356fe3a 100644 --- a/src/server/push/task/pixiv_send_message.rs +++ b/src/server/push/task/pixiv_send_message.rs @@ -4,6 +4,7 @@ use crate::db::push_task::{ TelegramPushConfig, }; use crate::error::PixivDownloaderError; +use crate::formdata::FormDataPartBuilder; use crate::opt::author_name_filter::AuthorFiler; use crate::parser::description::convert_description_to_tg_html; use crate::parser::description::DescriptionParser; @@ -193,7 +194,27 @@ impl RunContext { ) -> Result, PixivDownloaderError> { if download_media { match self._get_image_url(index) { - Some(u) => Ok(Some(InputFile::URL(u))), + Some(u) => match self + .ctx + .tmp_cache + .get_cache(&u, json::object! 
{ "referer": "https://www.pixiv.net/" }) + .await + { + Ok(p) => { + let name = p + .file_name() + .map(|a| a.to_str().unwrap_or("")) + .unwrap_or("") + .to_owned(); + let f = FormDataPartBuilder::default() + .body(p) + .filename(name) + .build() + .map_err(|_| "Failed to create FormDataPart.")?; + Ok(Some(InputFile::Content(f))) + } + Err(e) => Err(e), + }, None => Ok(None), } } else { @@ -679,6 +700,7 @@ impl RunContext { text += "\n"; if cfg.author_locations.contains(&AuthorLocation::Top) { if let Some(a) = &author { + text.push_str(gettext("by ")); text += a; text.push('\n'); } @@ -692,6 +714,7 @@ impl RunContext { text.push('\n'); if cfg.author_locations.contains(&AuthorLocation::Bottom) { if let Some(a) = &author { + text.push_str(gettext("by ")); text += a; text.push('\n'); } diff --git a/src/server/timer.rs b/src/server/timer.rs index 9d1655a8..6bb4ce4d 100644 --- a/src/server/timer.rs +++ b/src/server/timer.rs @@ -1,9 +1,14 @@ use super::auth::*; use super::context::ServerContext; +use crate::error::PixivDownloaderError; use crate::task_manager::{MaxCount, TaskManager}; use std::sync::Arc; use tokio::time::{interval_at, Duration, Instant}; +pub async fn remove_tmp_cache(ctx: Arc) -> Result<(), PixivDownloaderError> { + ctx.tmp_cache.remove_expired_cache().await +} + pub async fn start_timer(ctx: Arc) { let mut interval = interval_at(Instant::now(), Duration::from_secs(60)); let task_count = Arc::new(futures_util::lock::Mutex::new(0usize)); @@ -18,6 +23,7 @@ pub async fn start_timer(ctx: Arc) { Ok(()) }) .await; + tasks.add_task(remove_tmp_cache(ctx.clone())).await; tasks.join().await; for task in tasks.take_finished_tasks() { let re = task.await; diff --git a/src/server/unittest/mod.rs b/src/server/unittest/mod.rs index 856f515d..2d31b5fa 100644 --- a/src/server/unittest/mod.rs +++ b/src/server/unittest/mod.rs @@ -7,6 +7,7 @@ use super::preclude::HttpBodyType; use super::route::ServerRoutes; use crate::db::{open_and_init_database, PixivDownloaderDbConfig}; 
use crate::error::PixivDownloaderError; +use crate::tmp_cache::TmpCache; use futures_util::lock::Mutex; use hyper::{Body, Request, Response}; use json::JsonValue; @@ -25,21 +26,23 @@ pub struct UnitTestContext { impl UnitTestContext { pub async fn new() -> Self { + let db = Arc::new( + open_and_init_database( + PixivDownloaderDbConfig::new(&json::object! { + "type": "sqlite", + "path": "test/server.db", + }) + .unwrap(), + ) + .await + .unwrap(), + ); Self { ctx: Arc::new(ServerContext { cors: CorsContext::new(true, vec![], vec![]), - db: Arc::new( - open_and_init_database( - PixivDownloaderDbConfig::new(&json::object! { - "type": "sqlite", - "path": "test/server.db", - }) - .unwrap(), - ) - .await - .unwrap(), - ), + db: db.clone(), rsa_key: Mutex::new(None), + tmp_cache: Arc::new(TmpCache::new(db.clone())), _pixiv_app_client: Mutex::new(None), _pixiv_web_client: Mutex::new(None), }), diff --git a/src/tmp_cache/mod.rs b/src/tmp_cache/mod.rs index be57a850..9b6affaf 100644 --- a/src/tmp_cache/mod.rs +++ b/src/tmp_cache/mod.rs @@ -1,3 +1,149 @@ +use crate::concat_pixiv_downloader_error; +use crate::db::PixivDownloaderDb; +use crate::downloader::{DownloaderHelper, DownloaderResult}; +use crate::error::PixivDownloaderError; use crate::get_helper; +use crate::utils::get_file_name_from_url; +use crate::webclient::ToHeaders; +use chrono::{DateTime, Utc}; +use futures_util::lock::Mutex; +use std::collections::hash_map::DefaultHasher; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::path::PathBuf; +use std::sync::Arc; -pub struct TmpCache {} +#[derive(Debug)] +pub struct TmpCacheEntry { + pub url: String, + pub path: String, + pub last_used: DateTime, +} + +pub struct TmpCache { + used: Mutex>, + db: Arc>, + in_cleaning: Mutex<()>, +} + +impl TmpCache { + pub fn new(db: Arc>) -> Self { + Self { + used: Mutex::new(HashMap::new()), + db, + in_cleaning: Mutex::new(()), + } + } + + async fn _get_cache( + &self, + url: &str, + headers: H, + ) -> Result 
{ + match self.db.get_tmp_cache(url).await? { + Some(ent) => { + if tokio::fs::try_exists(&ent.path).await.unwrap_or(false) { + return Ok(PathBuf::from(&ent.path)); + } + // Stale entry: the cached file is gone, so drop the row now (put_tmp_cache below uses plain INSERT and would conflict on the url primary key otherwise). + let _ = self.db.delete_tmp_cache(url).await; + } + None => {} + } + let mut tmp_dir = get_helper().temp_dir(); + let u = get_file_name_from_url(url).unwrap_or_else(|| { + let mut hasher = DefaultHasher::new(); + url.hash(&mut hasher); + hasher.finish().to_string() + }); + let dh = DownloaderHelper::builder(url)? + .file_name(&u) + .headers(headers) + .build(); + let d = dh.download_local(Some(true), &tmp_dir)?; + match d { + DownloaderResult::Ok(d) => { + d.disable_progress_bar(); + d.download(); + d.join().await?; + } + DownloaderResult::Canceled => { + return Err(PixivDownloaderError::from("Download canceled.")); + } + } + tmp_dir.push(u); + match self + .db + .put_tmp_cache(url, tmp_dir.to_string_lossy().trim()) + .await + { + Ok(()) => {} + Err(e) => { + log::warn!(target: "tmp_cache", "Failed to write cache {} to database: {}", url, e); + } + } + Ok(tmp_dir) + } + + pub async fn get_cache( + &self, + url: &str, + headers: H, + ) -> Result { + self.wait_for_url(url).await; + let re = self._get_cache(url, headers).await; + self.remove_for_url(url).await; + re + } + + pub async fn remove_cache_entry(&self, ent: TmpCacheEntry) -> Result<(), PixivDownloaderError> { + let t = self.in_cleaning.try_lock(); + if t.is_none() { + return Ok(()); + } + self.wait_for_url(&ent.url).await; + match tokio::fs::remove_file(&ent.path).await { + Ok(_) => {} + Err(e) => { + log::warn!(target: "tmp_cache", "Failed to remove cache {}: {}", ent.path, e); + } + } + match self.db.delete_tmp_cache(&ent.url).await { + Ok(_) => {} + Err(e) => { + self.remove_for_url(&ent.url).await; + Err(e)?; + } + } + self.remove_for_url(&ent.url).await; + Ok(()) + } + + pub async fn remove_expired_cache(&self) -> Result<(), PixivDownloaderError> { + let entries = self.db.get_tmp_caches(3600).await?; + let mut err = Ok(()); + 
for ent in entries { + let e = self.remove_cache_entry(ent).await; + concat_pixiv_downloader_error!(err, e); + } + err + } + + async fn remove_for_url(&self, url: &str) { + let mut m = self.used.lock().await; + m.remove(url); + } + + async fn wait_for_url(&self, url: &str) { + loop { + { + let mut m = self.used.lock().await; + if !m.contains_key(url) { + m.insert(url.to_owned(), ()); + break; + } + } + tokio::time::sleep(std::time::Duration::new(0, 100_000_000)).await; + } + } +}