Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
ShubhamBhut committed Mar 9, 2023
1 parent 24d0910 commit 0c95fa0
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 27 deletions.
7 changes: 0 additions & 7 deletions watchtower-plugin/src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,6 @@ pub const TOWERS_DATA_DIR: &str = "TOWERS_DATA_DIR";
pub const DEFAULT_TOWERS_DATA_DIR: &str = ".watchtower";

/// Collections of plugin option names, default values and descriptions
pub const DEFAULT_SUBSCRIPTION_START: Option<i64> = None;
pub const SUBSCRIPTION_START: &str = "subscription-start";
pub const SUBSCRIPTION_START_DESC: &str = "subscription-start time";
pub const DEFAULT_SUBSCRIPTION_EXPIRY: Option<i64> = None;
pub const SUBSCRIPTION_EXPIRY: &str = "subscription-expiry";
pub const SUBSCRIPTION_EXPIRY_DESC: &str = "subscription-expiry time";

pub const WT_PORT: &str = "watchtower-port";
pub const DEFAULT_WT_PORT: i64 = 9814;
pub const WT_PORT_DESC: &str = "tower API port";
Expand Down
79 changes: 77 additions & 2 deletions watchtower-plugin/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ impl TryFrom<serde_json::Value> for GetAppointmentParams {
let param_count = a.len();
if param_count != 2 {
Err(GetAppointmentError::InvalidFormat(format!(
"Unexpected request format. The request needs 2 parameter. Received: {param_count}"
)))
"Unexpected request format. The request needs 2 parameter. Received: {param_count}"
)))
} else {
let tower_id = if let Some(s) = a.get(0).unwrap().as_str() {
TowerId::from_str(s).map_err(|_| {
Expand Down Expand Up @@ -258,6 +258,81 @@ impl TryFrom<serde_json::Value> for GetAppointmentParams {
}
}

// Errors related to `getregistrationreceipt` command
#[derive(Debug)]
pub enum GetRegistrationReceiptError {
InvalidId(String),
InvalidFormat(String),
}

impl std::fmt::Display for GetRegistrationReceiptError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
GetRegistrationReceiptError::InvalidId(x) => write!(f, "{x}"),
GetRegistrationReceiptError::InvalidFormat(x) => write!(f, "{x}"),
}
}
}

// Parameters related to the `getregistrationreceipt` command
#[derive(Debug)]
pub struct GetRegistrationReceiptParams {
pub tower_id: TowerId,
pub subscription_start: Option<u32>,
pub subscription_expiry: Option<u32>,
}

impl TryFrom<serde_json::Value> for GetRegistrationReceiptParams {
type Error = GetRegistrationReceiptError;

fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
match value {
serde_json::Value::Array(a) => {
let tower_id = if let Some(s) = a.get(0).unwrap().as_str() {
TowerId::from_str(s).map_err(|_| {
GetRegistrationReceiptError::InvalidId("Invalid tower id".to_owned())
})
} else {
Err(GetRegistrationReceiptError::InvalidId(
"tower_id must be a hex encoded string".to_owned(),
))
}?;
let subscription_start = if let Some(start) = a.get(1).and_then(|v| v.as_i64()) {
if start >= 0 {
Some(start as u32)
} else {
return Err(GetRegistrationReceiptError::InvalidFormat(
"Subscription-start must be a positive integer".to_owned(),
));
}
} else {
None
};
let subscription_expiry = if let Some(expire) = a.get(2).and_then(|v| v.as_i64()) {
if expire >= 0 {
Some(expire as u32)
} else {
return Err(GetRegistrationReceiptError::InvalidFormat(
"Subscription-expire must be a positive integer".to_owned(),
));
}
} else {
None
};
Ok(Self {
tower_id,
subscription_start,
subscription_expiry,
})
}
_ => Err(GetRegistrationReceiptError::InvalidFormat(format!(
"Unexpected request format. Expected: tower_id and optional arguments subscription_start & subscription_expire. Received: '{value}'"
))),
}
}
}


/// Data associated with a commitment revocation. Represents the data sent by CoreLN through the `commitment_revocation` hook.
#[derive(Debug, Serialize, Deserialize)]
pub struct CommitmentRevocation {
Expand Down
8 changes: 5 additions & 3 deletions watchtower-plugin/src/dbm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,14 +226,16 @@ impl DBM {
if let Some(start) = subscription_start {
query.push_str(" AND subscription_start >= ?2");
params.push(start.to_be_bytes().to_vec());
} else {
query.push_str(" AND subscription_expiry = (SELECT MAX(subscription_expiry) FROM registration_receipts WHERE tower_id = ?1)");
}

if let Some(expiry) = subscription_expiry {
query.push_str(" AND subscription_expiry <= ?3");
params.push(expiry.to_be_bytes().to_vec());
}

query.push_str(" ORDER BY subscription_expiry DESC LIMIT 1");
//query.push_str(" ORDER BY subscription_expiry DESC LIMIT 1");

let mut stmt = self.connection.prepare(&query).unwrap();
let params: Vec<&dyn ToSql> = params.iter().map(|v| v as &dyn ToSql).collect();
Expand All @@ -247,8 +249,8 @@ impl DBM {
Ok(RegistrationReceipt::with_signature(
user_id, slots, start, expiry, signature,
))
})
.ok()
}).
ok()
}

/// Removes a tower record from the database.
Expand Down
22 changes: 7 additions & 15 deletions watchtower-plugin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ use teos_common::protos as common_msgs;
use teos_common::TowerId;
use teos_common::{cryptography, errors};

use watchtower_plugin::convert::{CommitmentRevocation, GetAppointmentParams, RegisterParams};
use watchtower_plugin::convert::{
CommitmentRevocation, GetAppointmentParams, GetRegistrationReceiptParams, RegisterParams,
};
use watchtower_plugin::net::http::{
self, post_request, process_post_response, AddAppointmentError, ApiResponse, RequestError,
};
Expand Down Expand Up @@ -133,10 +135,10 @@ async fn get_registration_receipt(
plugin: Plugin<Arc<Mutex<WTClient>>>,
v: serde_json::Value,
) -> Result<serde_json::Value, Error> {
let tower_id = TowerId::try_from(v.clone()).map_err(|x| anyhow!(x))?;
let subscription_start = v["subscription-start"].as_u64().map(|v| v as u32);
let subscription_expiry = v["subscription-expiry"].as_u64().map(|v| v as u32);

let params = GetRegistrationReceiptParams::try_from(v).map_err(|x| anyhow!(x))?;
let tower_id = params.tower_id;
let subscription_start = params.subscription_start;
let subscription_expiry = params.subscription_expiry;
let state = plugin.state().lock().unwrap();

if let Some(response) =
Expand Down Expand Up @@ -506,16 +508,6 @@ async fn main() -> Result<(), Error> {
};

let builder = Builder::new(stdin(), stdout())
.option(ConfigOption::new(
constants::SUBSCRIPTION_START,
Value::OptInteger,
constants::SUBSCRIPTION_START_DESC,
))
.option(ConfigOption::new(
constants::SUBSCRIPTION_EXPIRY,
Value::OptInteger,
constants::SUBSCRIPTION_EXPIRY_DESC,
))
.option(ConfigOption::new(
constants::WT_PORT,
Value::Integer(constants::DEFAULT_WT_PORT),
Expand Down

0 comments on commit 0c95fa0

Please sign in to comment.