Skip to content

Commit

Permalink
Add optional boundaries to getregistrationreceipt
Browse files Browse the repository at this point in the history
  • Loading branch information
ShubhamBhut committed Mar 4, 2023
1 parent 076c972 commit 24d0910
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 24 deletions.
6 changes: 6 additions & 0 deletions watchtower-plugin/src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ pub const TOWERS_DATA_DIR: &str = "TOWERS_DATA_DIR";
pub const DEFAULT_TOWERS_DATA_DIR: &str = ".watchtower";

/// Collections of plugin option names, default values and descriptions
pub const DEFAULT_SUBSCRIPTION_START: Option<i64> = None;
pub const SUBSCRIPTION_START: &str = "subscription-start";
pub const SUBSCRIPTION_START_DESC: &str = "subscription-start time";
pub const DEFAULT_SUBSCRIPTION_EXPIRY: Option<i64> = None;
pub const SUBSCRIPTION_EXPIRY: &str = "subscription-expiry";
pub const SUBSCRIPTION_EXPIRY_DESC: &str = "subscription-expiry time";

pub const WT_PORT: &str = "watchtower-port";
pub const DEFAULT_WT_PORT: i64 = 9814;
Expand Down
74 changes: 54 additions & 20 deletions watchtower-plugin/src/dbm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::iter::FromIterator;
use std::path::PathBuf;
use std::str::FromStr;

use rusqlite::{params, Connection, Error as SqliteError};
use rusqlite::{params, Connection, Error as SqliteError, ToSql};

use bitcoin::secp256k1::SecretKey;

Expand Down Expand Up @@ -216,19 +216,29 @@ impl DBM {
&self,
tower_id: TowerId,
user_id: UserId,
subscription_start: Option<u32>,
subscription_expiry: Option<u32>,
) -> Option<RegistrationReceipt> {
let mut stmt = self
.connection
.prepare(
"SELECT available_slots, subscription_start, subscription_expiry, signature
FROM registration_receipts
WHERE tower_id = ?1 AND subscription_expiry = (SELECT MAX(subscription_expiry)
FROM registration_receipts
WHERE tower_id = ?1)",
)
.unwrap();
let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1".to_string();

let mut params = vec![tower_id.to_vec()];

if let Some(start) = subscription_start {
query.push_str(" AND subscription_start >= ?2");
params.push(start.to_be_bytes().to_vec());
}

if let Some(expiry) = subscription_expiry {
query.push_str(" AND subscription_expiry <= ?3");
params.push(expiry.to_be_bytes().to_vec());
}

query.push_str(" ORDER BY subscription_expiry DESC LIMIT 1");

let mut stmt = self.connection.prepare(&query).unwrap();
let params: Vec<&dyn ToSql> = params.iter().map(|v| v as &dyn ToSql).collect();

stmt.query_row([tower_id.to_vec()], |row| {
stmt.query_row(params.as_slice(), |row| {
let slots: u32 = row.get(0).unwrap();
let start: u32 = row.get(1).unwrap();
let expiry: u32 = row.get(2).unwrap();
Expand Down Expand Up @@ -725,13 +735,20 @@ mod tests {
let tower_id = get_random_user_id();
let net_addr = "talaia.watch";
let receipt = get_random_registration_receipt();
let subscription_start = None;
let subscription_expiry = None;

// Check the receipt was stored
dbm.store_tower_record(tower_id, net_addr, &receipt)
.unwrap();
assert_eq!(
dbm.load_registration_receipt(tower_id, receipt.user_id())
.unwrap(),
dbm.load_registration_receipt(
tower_id,
receipt.user_id(),
subscription_start,
subscription_expiry
)
.unwrap(),
receipt
);

Expand All @@ -742,17 +759,27 @@ mod tests {
dbm.store_tower_record(tower_id, net_addr, &latest_receipt)
.unwrap();
assert_eq!(
dbm.load_registration_receipt(tower_id, latest_receipt.user_id())
.unwrap(),
dbm.load_registration_receipt(
tower_id,
latest_receipt.user_id(),
subscription_start,
subscription_expiry
)
.unwrap(),
latest_receipt
);

// Add a final one with a lower expiry and check the last is still loaded
dbm.store_tower_record(tower_id, net_addr, &middle_receipt)
.unwrap();
assert_eq!(
dbm.load_registration_receipt(tower_id, latest_receipt.user_id())
.unwrap(),
dbm.load_registration_receipt(
tower_id,
latest_receipt.user_id(),
subscription_start,
subscription_expiry
)
.unwrap(),
latest_receipt
);
}
Expand All @@ -765,13 +792,20 @@ mod tests {
let tower_id = get_random_user_id();
let net_addr = "talaia.watch";
let receipt = get_random_registration_receipt();
let subscription_start = None;
let subscription_expiry = None;

// Store it once
dbm.store_tower_record(tower_id, net_addr, &receipt)
.unwrap();
assert_eq!(
dbm.load_registration_receipt(tower_id, receipt.user_id())
.unwrap(),
dbm.load_registration_receipt(
tower_id,
receipt.user_id(),
subscription_start,
subscription_expiry
)
.unwrap(),
receipt
);

Expand Down
19 changes: 17 additions & 2 deletions watchtower-plugin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,15 @@ async fn get_registration_receipt(
plugin: Plugin<Arc<Mutex<WTClient>>>,
v: serde_json::Value,
) -> Result<serde_json::Value, Error> {
let tower_id = TowerId::try_from(v).map_err(|x| anyhow!(x))?;
let tower_id = TowerId::try_from(v.clone()).map_err(|x| anyhow!(x))?;
let subscription_start = v["subscription-start"].as_u64().map(|v| v as u32);
let subscription_expiry = v["subscription-expiry"].as_u64().map(|v| v as u32);

let state = plugin.state().lock().unwrap();

if let Some(response) = state.get_registration_receipt(tower_id) {
if let Some(response) =
state.get_registration_receipt(tower_id, subscription_start, subscription_expiry)
{
Ok(json!(response))
} else {
Err(anyhow!(
Expand Down Expand Up @@ -501,6 +506,16 @@ async fn main() -> Result<(), Error> {
};

let builder = Builder::new(stdin(), stdout())
.option(ConfigOption::new(
constants::SUBSCRIPTION_START,
Value::OptInteger,
constants::SUBSCRIPTION_START_DESC,
))
.option(ConfigOption::new(
constants::SUBSCRIPTION_EXPIRY,
Value::OptInteger,
constants::SUBSCRIPTION_EXPIRY_DESC,
))
.option(ConfigOption::new(
constants::WT_PORT,
Value::Integer(constants::DEFAULT_WT_PORT),
Expand Down
14 changes: 12 additions & 2 deletions watchtower-plugin/src/wt_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,18 @@ impl WTClient {
}

/// Gets the latest registration receipt of a given tower.
pub fn get_registration_receipt(&self, tower_id: TowerId) -> Option<RegistrationReceipt> {
self.dbm.load_registration_receipt(tower_id, self.user_id)
pub fn get_registration_receipt(
&self,
tower_id: TowerId,
subscription_start: Option<u32>,
subscription_expiry: Option<u32>,
) -> Option<RegistrationReceipt> {
self.dbm.load_registration_receipt(
tower_id,
self.user_id,
subscription_start,
subscription_expiry,
)
}

/// Loads a tower record from the database.
Expand Down

0 comments on commit 24d0910

Please sign in to comment.