diff --git a/watchtower-plugin/src/constants.rs b/watchtower-plugin/src/constants.rs index 9e224b94..22ae22af 100644 --- a/watchtower-plugin/src/constants.rs +++ b/watchtower-plugin/src/constants.rs @@ -3,13 +3,6 @@ pub const TOWERS_DATA_DIR: &str = "TOWERS_DATA_DIR"; pub const DEFAULT_TOWERS_DATA_DIR: &str = ".watchtower"; /// Collections of plugin option names, default values and descriptions -pub const DEFAULT_SUBSCRIPTION_START: Option = None; -pub const SUBSCRIPTION_START: &str = "subscription-start"; -pub const SUBSCRIPTION_START_DESC: &str = "subscription-start time"; -pub const DEFAULT_SUBSCRIPTION_EXPIRY: Option = None; -pub const SUBSCRIPTION_EXPIRY: &str = "subscription-expiry"; -pub const SUBSCRIPTION_EXPIRY_DESC: &str = "subscription-expiry time"; - pub const WT_PORT: &str = "watchtower-port"; pub const DEFAULT_WT_PORT: i64 = 9814; pub const WT_PORT_DESC: &str = "tower API port"; diff --git a/watchtower-plugin/src/convert.rs b/watchtower-plugin/src/convert.rs index ec1b9d01..4cf1031f 100644 --- a/watchtower-plugin/src/convert.rs +++ b/watchtower-plugin/src/convert.rs @@ -199,8 +199,8 @@ impl TryFrom for GetAppointmentParams { let param_count = a.len(); if param_count != 2 { Err(GetAppointmentError::InvalidFormat(format!( - "Unexpected request format. The request needs 2 parameter. Received: {param_count}" - ))) + "Unexpected request format. The request needs 2 parameter. Received: {param_count}" + ))) } else { let tower_id = if let Some(s) = a.get(0).unwrap().as_str() { TowerId::from_str(s).map_err(|_| { @@ -258,6 +258,81 @@ impl TryFrom for GetAppointmentParams { } } +// Errors related to `getregistrationreceipt` command +#[derive(Debug)] +pub enum GetRegistrationReceiptError { + InvalidId(String), + InvalidFormat(String), +} + +impl std::fmt::Display for GetRegistrationReceiptError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + GetRegistrationReceiptError::InvalidId(x) => write!(f, "{x}"), + GetRegistrationReceiptError::InvalidFormat(x) => write!(f, "{x}"), + } + } +} + +// Parameters related to the `getregistrationreceipt` command +#[derive(Debug)] +pub struct GetRegistrationReceiptParams { + pub tower_id: TowerId, + pub subscription_start: Option, + pub subscription_expiry: Option, +} + +impl TryFrom for GetRegistrationReceiptParams { + type Error = GetRegistrationReceiptError; + + fn try_from(value: serde_json::Value) -> Result { + match value { + serde_json::Value::Array(a) => { + let tower_id = if let Some(s) = a.get(0).unwrap().as_str() { + TowerId::from_str(s).map_err(|_| { + GetRegistrationReceiptError::InvalidId("Invalid tower id".to_owned()) + }) + } else { + Err(GetRegistrationReceiptError::InvalidId( + "tower_id must be a hex encoded string".to_owned(), + )) + }?; + let subscription_start = if let Some(start) = a.get(1).and_then(|v| v.as_i64()) { + if start >= 0 { + Some(start as u32) + } else { + return Err(GetRegistrationReceiptError::InvalidFormat( + "Subscription-start must be a positive integer".to_owned(), + )); + } + } else { + None + }; + let subscription_expiry = if let Some(expire) = a.get(2).and_then(|v| v.as_i64()) { + if expire >= 0 { + Some(expire as u32) + } else { + return Err(GetRegistrationReceiptError::InvalidFormat( + "Subscription-expire must be a positive integer".to_owned(), + )); + } + } else { + None + }; + Ok(Self { + tower_id, + subscription_start, + subscription_expiry, + }) + } + _ => Err(GetRegistrationReceiptError::InvalidFormat(format!( + "Unexpected request format. Expected: tower_id and optional arguments subscription_start & subscription_expire. Received: '{value}'" + ))), + } + } +} + + /// Data associated with a commitment revocation. Represents the data sent by CoreLN through the `commitment_revocation` hook. #[derive(Debug, Serialize, Deserialize)] pub struct CommitmentRevocation { diff --git a/watchtower-plugin/src/dbm.rs b/watchtower-plugin/src/dbm.rs index 4e77dc86..a3d04ba1 100755 --- a/watchtower-plugin/src/dbm.rs +++ b/watchtower-plugin/src/dbm.rs @@ -226,6 +226,8 @@ impl DBM { if let Some(start) = subscription_start { query.push_str(" AND subscription_start >= ?2"); params.push(start.to_be_bytes().to_vec()); + } else { + query.push_str(" AND subscription_expiry = (SELECT MAX(subscription_expiry) FROM registration_receipts WHERE tower_id = ?1)"); } if let Some(expiry) = subscription_expiry { @@ -233,7 +235,7 @@ impl DBM { params.push(expiry.to_be_bytes().to_vec()); } - query.push_str(" ORDER BY subscription_expiry DESC LIMIT 1"); + //query.push_str(" ORDER BY subscription_expiry DESC LIMIT 1"); let mut stmt = self.connection.prepare(&query).unwrap(); let params: Vec<&dyn ToSql> = params.iter().map(|v| v as &dyn ToSql).collect(); @@ -247,8 +249,8 @@ impl DBM { Ok(RegistrationReceipt::with_signature( user_id, slots, start, expiry, signature, )) - }) - .ok() + }). + ok() } /// Removes a tower record from the database. diff --git a/watchtower-plugin/src/main.rs b/watchtower-plugin/src/main.rs index 81786ff8..d0548e9c 100755 --- a/watchtower-plugin/src/main.rs +++ b/watchtower-plugin/src/main.rs @@ -18,7 +18,9 @@ use teos_common::protos as common_msgs; use teos_common::TowerId; use teos_common::{cryptography, errors}; -use watchtower_plugin::convert::{CommitmentRevocation, GetAppointmentParams, RegisterParams}; +use watchtower_plugin::convert::{ + CommitmentRevocation, GetAppointmentParams, GetRegistrationReceiptParams, RegisterParams, +}; use watchtower_plugin::net::http::{ self, post_request, process_post_response, AddAppointmentError, ApiResponse, RequestError, }; @@ -133,10 +135,10 @@ async fn get_registration_receipt( plugin: Plugin>>, v: serde_json::Value, ) -> Result { - let tower_id = TowerId::try_from(v.clone()).map_err(|x| anyhow!(x))?; - let subscription_start = v["subscription-start"].as_u64().map(|v| v as u32); - let subscription_expiry = v["subscription-expiry"].as_u64().map(|v| v as u32); - + let params = GetRegistrationReceiptParams::try_from(v).map_err(|x| anyhow!(x))?; + let tower_id = params.tower_id; + let subscription_start = params.subscription_start; + let subscription_expiry = params.subscription_expiry; let state = plugin.state().lock().unwrap(); if let Some(response) = @@ -506,16 +508,6 @@ async fn main() -> Result<(), Error> { }; let builder = Builder::new(stdin(), stdout()) - .option(ConfigOption::new( - constants::SUBSCRIPTION_START, - Value::OptInteger, - constants::SUBSCRIPTION_START_DESC, - )) - .option(ConfigOption::new( - constants::SUBSCRIPTION_EXPIRY, - Value::OptInteger, - constants::SUBSCRIPTION_EXPIRY_DESC, - )) .option(ConfigOption::new( constants::WT_PORT, Value::Integer(constants::DEFAULT_WT_PORT),