diff --git a/lib/core/src/sync/mod.rs b/lib/core/src/sync/mod.rs index 15eb8900d..de53e8f65 100644 --- a/lib/core/src/sync/mod.rs +++ b/lib/core/src/sync/mod.rs @@ -319,13 +319,16 @@ impl SyncService { #[cfg(test)] mod tests { use anyhow::{anyhow, Result}; - use std::sync::Arc; - use tokio::sync::mpsc; + use std::{collections::HashMap, sync::Arc}; + use tokio::sync::{mpsc, Mutex}; use crate::{ - prelude::Signer, + persist::Persister, + prelude::{Direction, PaymentState, Signer}, + sync::model::SyncState, test_utils::{ - persist::new_persister, + chain_swap::new_chain_swap, + persist::{new_persister, new_receive_swap, new_send_swap}, sync::{ new_chain_sync_data, new_receive_sync_data, new_send_sync_data, MockSyncerClient, }, @@ -334,7 +337,7 @@ mod tests { }; use super::{ - model::{data::SyncData, sync::Record}, + model::{data::SyncData, sync::Record, RecordType}, SyncService, }; @@ -357,7 +360,8 @@ mod tests { ]; let (incoming_tx, incoming_rx) = mpsc::channel::(10); - let client = Box::new(MockSyncerClient::new(incoming_rx)); + let outgoing_records = Arc::new(Mutex::new(HashMap::new())); + let client = Box::new(MockSyncerClient::new(incoming_rx, outgoing_records.clone())); let sync_service = SyncService::new("".to_string(), persister.clone(), signer.clone(), client); @@ -389,7 +393,6 @@ mod tests { let new_preimage = Some("preimage".to_string()); let new_accept_zero_conf = false; - let new_server_lockup_tx_id = Some("server_lockup_tx_id".to_string()); let sync_data = vec![ SyncData::Send(new_send_sync_data(new_preimage.clone())), SyncData::Chain(new_chain_sync_data(Some(new_accept_zero_conf))), @@ -412,11 +415,138 @@ mod tests { } if let Some(chain_swap) = persister.fetch_chain_swap_by_id(&sync_data[2].id())? { assert_eq!(chain_swap.accept_zero_conf, new_accept_zero_conf); - assert_eq!(chain_swap.server_lockup_tx_id, new_server_lockup_tx_id); } else { return Err(anyhow!("Chain swap not found")); } Ok(()) } + + fn get_outgoing_record<'a, 'b>( + persister: Arc, + outgoing: &'a HashMap, + data_id: &'b str, + record_type: RecordType, + ) -> Result<&'a Record> { + let record_id = Record::get_id_from_record_type(record_type, data_id); + let sync_state = persister + .get_sync_state_by_record_id(&record_id)? + .ok_or(anyhow::anyhow!("Expected existing swap state"))?; + let Some(record) = outgoing.get(&sync_state.record_id) else { + return Err(anyhow::anyhow!( + "Expecting existing record in client's outgoing list" + )); + }; + Ok(record) + } + + #[tokio::test] + async fn test_outgoing_sync() -> Result<()> { + let (_temp_dir, persister) = new_persister()?; + let persister = Arc::new(persister); + + let signer: Arc> = Arc::new(Box::new(MockSigner::new())); + + let (_incoming_tx, incoming_rx) = mpsc::channel::(10); + let outgoing_records = Arc::new(Mutex::new(HashMap::new())); + let client = Box::new(MockSyncerClient::new(incoming_rx, outgoing_records.clone())); + let sync_service = + SyncService::new("".to_string(), persister.clone(), signer.clone(), client); + + // Test insert + persister.insert_receive_swap(&new_receive_swap(None))?; + persister.insert_send_swap(&new_send_swap(None))?; + persister.insert_chain_swap(&new_chain_swap(Direction::Incoming, None, true, None))?; + + sync_service.push().await?; + + let outgoing = outgoing_records.lock().await; + assert_eq!(outgoing.len(), 3); + drop(outgoing); + + // Test conflict + let swap = new_receive_swap(None); + persister.insert_receive_swap(&swap)?; + + sync_service.push().await?; + + let outgoing = outgoing_records.lock().await; + assert_eq!(outgoing.len(), 4); + let record = + get_outgoing_record(persister.clone(), &outgoing, &swap.id, RecordType::Receive)?; + persister.set_sync_state(SyncState { + data_id: swap.id.clone(), + record_id: record.id.clone(), + record_revision: 90, // Set a wrong record revision + is_local: true, + })?; + drop(outgoing); + + sync_service.push().await?; + + let outgoing = outgoing_records.lock().await; + assert_eq!(outgoing.len(), 4); // No records were added + drop(outgoing); + + // Test update before push + let swap = new_send_swap(None); + persister.insert_send_swap(&swap)?; + let new_preimage = Some("new-preimage"); + persister.try_handle_send_swap_update( + &swap.id, + PaymentState::Pending, + new_preimage.clone(), + None, + None, + )?; + + sync_service.push().await?; + + let outgoing = outgoing_records.lock().await; + + let record = get_outgoing_record(persister.clone(), &outgoing, &swap.id, RecordType::Send)?; + let decrypted_record = record.clone().decrypt(signer.clone())?; + assert_eq!(decrypted_record.data.id(), &swap.id); + match decrypted_record.data { + SyncData::Send(data) => { + assert_eq!(data.preimage, new_preimage.map(|p| p.to_string())); + } + _ => { + return Err(anyhow::anyhow!("Unexpected sync data type received.")); + } + } + drop(outgoing); + + // Test update after push + let swap = new_send_swap(None); + persister.insert_send_swap(&swap)?; + + sync_service.push().await?; + + let new_preimage = Some("new-preimage"); + persister.try_handle_send_swap_update( + &swap.id, + PaymentState::Pending, + new_preimage.clone(), + None, + None, + )?; + + sync_service.push().await?; + + let outgoing = outgoing_records.lock().await; + let record = get_outgoing_record(persister.clone(), &outgoing, &swap.id, RecordType::Send)?; + let decrypted_record = record.clone().decrypt(signer.clone())?; + assert_eq!(decrypted_record.data.id(), &swap.id); + match decrypted_record.data { + SyncData::Send(data) => { + assert_eq!(data.preimage, new_preimage.map(|p| p.to_string()),); + } + _ => { + return Err(anyhow::anyhow!("Unexpected sync data type received.")); + } + } + + Ok(()) + } } diff --git a/lib/core/src/test_utils/sync.rs b/lib/core/src/test_utils/sync.rs index aaa5d8027..a451ced46 100644 --- a/lib/core/src/test_utils/sync.rs +++ b/lib/core/src/test_utils/sync.rs @@ -1,5 +1,7 @@ #![cfg(test)] +use std::{collections::HashMap, sync::Arc}; + use crate::{ prelude::Direction, sync::{ @@ -8,6 +10,7 @@ use crate::{ data::{ChainSyncData, ReceiveSyncData, SendSyncData}, sync::{ ListChangesReply, ListChangesRequest, Record, SetRecordReply, SetRecordRequest, + SetRecordStatus, }, }, }, @@ -18,12 +21,17 @@ use tokio::sync::{mpsc::Receiver, Mutex}; pub(crate) struct MockSyncerClient { pub(crate) incoming_rx: Mutex>, + pub(crate) outgoing_records: Arc>>, } impl MockSyncerClient { - pub(crate) fn new(incoming_rx: Receiver) -> Self { + pub(crate) fn new( + incoming_rx: Receiver, + outgoing_records: Arc>>, + ) -> Self { Self { incoming_rx: Mutex::new(incoming_rx), + outgoing_records, } } } @@ -34,8 +42,30 @@ impl SyncerClient for MockSyncerClient { todo!() } - async fn push(&self, _req: SetRecordRequest) -> Result { - todo!() + async fn push(&self, req: SetRecordRequest) -> Result { + if let Some(mut record) = req.record { + let mut outgoing_records = self.outgoing_records.lock().await; + + if let Some(existing_record) = outgoing_records.get(&record.id) { + if existing_record.revision != record.revision { + return Ok(SetRecordReply { + status: SetRecordStatus::Conflict as i32, + new_revision: 0, + }); + } + } + + record.revision = outgoing_records.len() as u64 + 1; + let record_revision = record.revision; + + outgoing_records.insert(record.id.clone(), record); + return Ok(SetRecordReply { + status: SetRecordStatus::Success as i32, + new_revision: record_revision, + }); + } + + return Err(anyhow::anyhow!("No record was sent")); } async fn pull(&self, _req: ListChangesRequest) -> Result {