From 24d0910a7331e7e022164e679a50049bd4e0fdd6 Mon Sep 17 00:00:00 2001 From: Shubham Patel Date: Sat, 4 Mar 2023 11:04:25 +0530 Subject: [PATCH] Add optional boundaries to getregistrationreceipt --- watchtower-plugin/src/constants.rs | 6 +++ watchtower-plugin/src/dbm.rs | 74 ++++++++++++++++++++++-------- watchtower-plugin/src/main.rs | 19 +++++++- watchtower-plugin/src/wt_client.rs | 14 +++++- 4 files changed, 89 insertions(+), 24 deletions(-) diff --git a/watchtower-plugin/src/constants.rs b/watchtower-plugin/src/constants.rs index 1b14803d..9e224b94 100644 --- a/watchtower-plugin/src/constants.rs +++ b/watchtower-plugin/src/constants.rs @@ -3,6 +3,12 @@ pub const TOWERS_DATA_DIR: &str = "TOWERS_DATA_DIR"; pub const DEFAULT_TOWERS_DATA_DIR: &str = ".watchtower"; /// Collections of plugin option names, default values and descriptions +pub const DEFAULT_SUBSCRIPTION_START: Option = None; +pub const SUBSCRIPTION_START: &str = "subscription-start"; +pub const SUBSCRIPTION_START_DESC: &str = "subscription-start time"; +pub const DEFAULT_SUBSCRIPTION_EXPIRY: Option = None; +pub const SUBSCRIPTION_EXPIRY: &str = "subscription-expiry"; +pub const SUBSCRIPTION_EXPIRY_DESC: &str = "subscription-expiry time"; pub const WT_PORT: &str = "watchtower-port"; pub const DEFAULT_WT_PORT: i64 = 9814; diff --git a/watchtower-plugin/src/dbm.rs b/watchtower-plugin/src/dbm.rs index bf271bb8..4e77dc86 100755 --- a/watchtower-plugin/src/dbm.rs +++ b/watchtower-plugin/src/dbm.rs @@ -3,7 +3,7 @@ use std::iter::FromIterator; use std::path::PathBuf; use std::str::FromStr; -use rusqlite::{params, Connection, Error as SqliteError}; +use rusqlite::{params, Connection, Error as SqliteError, ToSql}; use bitcoin::secp256k1::SecretKey; @@ -216,19 +216,29 @@ impl DBM { &self, tower_id: TowerId, user_id: UserId, + subscription_start: Option, + subscription_expiry: Option, ) -> Option { - let mut stmt = self - .connection - .prepare( - "SELECT available_slots, subscription_start, subscription_expiry, signature - FROM registration_receipts - WHERE tower_id = ?1 AND subscription_expiry = (SELECT MAX(subscription_expiry) - FROM registration_receipts - WHERE tower_id = ?1)", - ) - .unwrap(); + let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1".to_string(); + + let mut params = vec![tower_id.to_vec()]; + + if let Some(start) = subscription_start { + query.push_str(" AND subscription_start >= ?2"); + params.push(start.to_be_bytes().to_vec()); + } + + if let Some(expiry) = subscription_expiry { + query.push_str(" AND subscription_expiry <= ?3"); + params.push(expiry.to_be_bytes().to_vec()); + } + + query.push_str(" ORDER BY subscription_expiry DESC LIMIT 1"); + + let mut stmt = self.connection.prepare(&query).unwrap(); + let params: Vec<&dyn ToSql> = params.iter().map(|v| v as &dyn ToSql).collect(); - stmt.query_row([tower_id.to_vec()], |row| { + stmt.query_row(params.as_slice(), |row| { let slots: u32 = row.get(0).unwrap(); let start: u32 = row.get(1).unwrap(); let expiry: u32 = row.get(2).unwrap(); @@ -725,13 +735,20 @@ mod tests { let tower_id = get_random_user_id(); let net_addr = "talaia.watch"; let receipt = get_random_registration_receipt(); + let subscription_start = None; + let subscription_expiry = None; // Check the receipt was stored dbm.store_tower_record(tower_id, net_addr, &receipt) .unwrap(); assert_eq!( - dbm.load_registration_receipt(tower_id, receipt.user_id()) - .unwrap(), + dbm.load_registration_receipt( + tower_id, + receipt.user_id(), + subscription_start, + subscription_expiry + ) + .unwrap(), receipt ); @@ -742,8 +759,13 @@ mod tests { dbm.store_tower_record(tower_id, net_addr, &latest_receipt) .unwrap(); assert_eq!( - dbm.load_registration_receipt(tower_id, latest_receipt.user_id()) - .unwrap(), + dbm.load_registration_receipt( + tower_id, + latest_receipt.user_id(), + subscription_start, + subscription_expiry + ) + .unwrap(), latest_receipt ); @@ -751,8 +773,13 @@ mod tests { dbm.store_tower_record(tower_id, net_addr, &middle_receipt) .unwrap(); assert_eq!( - dbm.load_registration_receipt(tower_id, latest_receipt.user_id()) - .unwrap(), + dbm.load_registration_receipt( + tower_id, + latest_receipt.user_id(), + subscription_start, + subscription_expiry + ) + .unwrap(), latest_receipt ); } @@ -765,13 +792,20 @@ mod tests { let tower_id = get_random_user_id(); let net_addr = "talaia.watch"; let receipt = get_random_registration_receipt(); + let subscription_start = None; + let subscription_expiry = None; // Store it once dbm.store_tower_record(tower_id, net_addr, &receipt) .unwrap(); assert_eq!( - dbm.load_registration_receipt(tower_id, receipt.user_id()) - .unwrap(), + dbm.load_registration_receipt( + tower_id, + receipt.user_id(), + subscription_start, + subscription_expiry + ) + .unwrap(), receipt ); diff --git a/watchtower-plugin/src/main.rs b/watchtower-plugin/src/main.rs index 3d4910ca..81786ff8 100755 --- a/watchtower-plugin/src/main.rs +++ b/watchtower-plugin/src/main.rs @@ -133,10 +133,15 @@ async fn get_registration_receipt( plugin: Plugin>>, v: serde_json::Value, ) -> Result { - let tower_id = TowerId::try_from(v).map_err(|x| anyhow!(x))?; + let tower_id = TowerId::try_from(v.clone()).map_err(|x| anyhow!(x))?; + let subscription_start = v["subscription-start"].as_u64().map(|v| v as u32); + let subscription_expiry = v["subscription-expiry"].as_u64().map(|v| v as u32); + let state = plugin.state().lock().unwrap(); - if let Some(response) = state.get_registration_receipt(tower_id) { + if let Some(response) = + state.get_registration_receipt(tower_id, subscription_start, subscription_expiry) + { Ok(json!(response)) } else { Err(anyhow!( @@ -501,6 +506,16 @@ async fn main() -> Result<(), Error> { }; let builder = Builder::new(stdin(), stdout()) + .option(ConfigOption::new( + constants::SUBSCRIPTION_START, + Value::OptInteger, + constants::SUBSCRIPTION_START_DESC, + )) + .option(ConfigOption::new( + constants::SUBSCRIPTION_EXPIRY, + Value::OptInteger, + constants::SUBSCRIPTION_EXPIRY_DESC, + )) .option(ConfigOption::new( constants::WT_PORT, Value::Integer(constants::DEFAULT_WT_PORT), diff --git a/watchtower-plugin/src/wt_client.rs b/watchtower-plugin/src/wt_client.rs index ecf26db0..454b073c 100644 --- a/watchtower-plugin/src/wt_client.rs +++ b/watchtower-plugin/src/wt_client.rs @@ -180,8 +180,18 @@ impl WTClient { } /// Gets the latest registration receipt of a given tower. - pub fn get_registration_receipt(&self, tower_id: TowerId) -> Option { - self.dbm.load_registration_receipt(tower_id, self.user_id) + pub fn get_registration_receipt( + &self, + tower_id: TowerId, + subscription_start: Option, + subscription_expiry: Option, + ) -> Option { + self.dbm.load_registration_receipt( + tower_id, + self.user_id, + subscription_start, + subscription_expiry, + ) } /// Loads a tower record from the database.