diff --git a/src/comms/comms.rs b/src/comms/comms.rs index 54b79f3..52cd883 100644 --- a/src/comms/comms.rs +++ b/src/comms/comms.rs @@ -86,7 +86,7 @@ impl CommsAccess { } pub(crate) async fn register_peer_message_tx( - &mut self, + &self, trade_uuid: Uuid, tx: mpsc::Sender, ) -> Result<(), N3xbError> { @@ -101,7 +101,7 @@ impl CommsAccess { } pub(crate) async fn unregister_peer_message_tx( - &mut self, + &self, trade_uuid: Uuid, ) -> Result<(), N3xbError> { let (rsp_tx, rsp_rx) = oneshot::channel::>(); @@ -111,7 +111,7 @@ impl CommsAccess { } pub(crate) async fn register_peer_message_fallback_tx( - &mut self, + &self, tx: mpsc::Sender, ) -> Result<(), N3xbError> { let (rsp_tx, rsp_rx) = oneshot::channel::>(); @@ -120,7 +120,7 @@ impl CommsAccess { rsp_rx.await.unwrap() } - pub(crate) async fn unregister_peer_message_fallback_tx(&mut self) -> Result<(), N3xbError> { + pub(crate) async fn unregister_peer_message_fallback_tx(&self) -> Result<(), N3xbError> { let (rsp_tx, rsp_rx) = oneshot::channel::>(); let request = CommsRequest::UnregisterFallbackTx { rsp_tx }; self.tx.send(request).await.unwrap(); diff --git a/src/maker/data.rs b/src/maker/data.rs index 0c4a27a..c5cbf73 100644 --- a/src/maker/data.rs +++ b/src/maker/data.rs @@ -7,7 +7,10 @@ use std::{ use uuid::Uuid; use serde::{Deserialize, Serialize}; -use tokio::sync::{mpsc, RwLock}; +use tokio::{ + select, + sync::{mpsc, RwLock}, +}; use url::Url; use crate::{ @@ -49,10 +52,16 @@ impl MakerActorDataStore { } } +enum MakerActorDataMsg { + Persist, + Close, +} + pub(crate) struct MakerActorData { pub(crate) trade_uuid: Uuid, - persist_tx: mpsc::Sender<()>, + persist_tx: mpsc::Sender, store: Arc>, + task_handle: tokio::task::JoinHandle<()>, } impl MakerActorData { @@ -74,11 +83,14 @@ impl MakerActorData { reject_invalid_offers_silently, }; let store = Arc::new(RwLock::new(store)); + let (persist_tx, task_handle) = + Self::setup_persistance(store.clone(), trade_uuid, &dir_path); Self { - persist_tx: Self::setup_persistance(store.clone(), trade_uuid, &dir_path), + persist_tx, trade_uuid, store, + task_handle, } } @@ -87,10 +99,16 @@ impl MakerActorData { let trade_uuid = store.order.trade_uuid; let store = Arc::new(RwLock::new(store)); + let dir_path = data_path.as_ref().parent().unwrap(); + + let (persist_tx, task_handle) = + Self::setup_persistance(store.clone(), trade_uuid, &dir_path); + let data = Self { - persist_tx: Self::setup_persistance(store.clone(), trade_uuid, &data_path), + persist_tx, trade_uuid, store, + task_handle, }; Ok((trade_uuid, data)) } @@ -99,32 +117,41 @@ impl MakerActorData { store: Arc>, trade_uuid: Uuid, dir_path: impl AsRef, - ) -> mpsc::Sender<()> { + ) -> (mpsc::Sender, tokio::task::JoinHandle<()>) { // No more than 1 persistance request is allowed nor needed. // This is essentilaly a debounce mechanism let (persist_tx, mut persist_rx) = mpsc::channel(1); let dir_path_buf = dir_path.as_ref().to_path_buf(); - tokio::spawn(async move { + let task_handle = tokio::spawn(async move { let dir_path_buf = dir_path_buf.clone(); loop { - persist_rx.recv().await; - match store.read().await.persist(&dir_path_buf).await { - Ok(_) => {} - Err(err) => { - error!( - "Maker w/ TradeUUID {} - Error persisting data: {}", - trade_uuid, err - ); - } + select! { + Some(msg) = persist_rx.recv() => { + match msg { + MakerActorDataMsg::Persist => { + if let Some(err) = store.read().await.persist(&dir_path_buf).await.err() { + error!( + "Maker w/ TradeUUID {} - Error persisting data: {}", + trade_uuid, err + ); + } + } + MakerActorDataMsg::Close => { + break; + } + } + + }, + else => break, } } }); - persist_tx + (persist_tx, task_handle) } fn queue_persistance(&self) { - match self.persist_tx.try_send(()) { + match self.persist_tx.try_send(MakerActorDataMsg::Persist) { Ok(_) => {} Err(error) => match error { mpsc::error::TrySendError::Full(_) => { @@ -240,4 +267,10 @@ impl MakerActorData { self.store.write().await.reject_invalid_offers_silently = reject_invalid_offers_silently; self.queue_persistance(); } + + pub(crate) async fn terminate(self) -> Result<(), N3xbError> { + self.persist_tx.send(MakerActorDataMsg::Close).await?; + self.task_handle.await?; + Ok(()) + } } diff --git a/src/maker/maker.rs b/src/maker/maker.rs index b10f6c3..342f96b 100644 --- a/src/maker/maker.rs +++ b/src/maker/maker.rs @@ -222,7 +222,7 @@ impl MakerActor { Ok((trade_uuid, actor)) } - async fn run(&mut self) { + async fn run(mut self) { let (tx, mut rx) = mpsc::channel::(20); if let Some(error) = self @@ -252,6 +252,7 @@ impl MakerActor { } } info!("Maker w/ TradeUUID {} terminating", self.data.trade_uuid); + self.data.terminate().await.unwrap(); } // Top-down Request Handling diff --git a/src/taker/data.rs b/src/taker/data.rs index 8ab3297..251a43c 100644 --- a/src/taker/data.rs +++ b/src/taker/data.rs @@ -1,8 +1,11 @@ +use log::{error, trace}; use std::{path::Path, sync::Arc}; -use log::{error, trace}; use serde::{Deserialize, Serialize}; -use tokio::sync::{mpsc, RwLock}; +use tokio::{ + select, + sync::{mpsc, RwLock}, +}; use uuid::Uuid; use crate::{ @@ -38,10 +41,16 @@ impl TakerActorDataStore { } } +enum TakerActorDataMsg { + Persist, + Close, +} + pub(crate) struct TakerActorData { pub(crate) trade_uuid: Uuid, - persist_tx: mpsc::Sender<()>, + persist_tx: mpsc::Sender, store: Arc>, + task_handle: tokio::task::JoinHandle<()>, } impl TakerActorData { @@ -59,11 +68,14 @@ impl TakerActorData { trade_completed: false, }; let store = Arc::new(RwLock::new(store)); + let (persist_tx, task_handle) = + Self::setup_persistance(store.clone(), trade_uuid, &dir_path); Self { - persist_tx: Self::setup_persistance(store.clone(), trade_uuid, &dir_path), + persist_tx, trade_uuid, store, + task_handle, } } @@ -72,10 +84,16 @@ impl TakerActorData { let trade_uuid = store.order_envelope.order.trade_uuid; let store = Arc::new(RwLock::new(store)); + let dir_path = data_path.as_ref().parent().unwrap(); + + let (persist_tx, task_handle) = + Self::setup_persistance(store.clone(), trade_uuid, &dir_path); + let data = Self { - persist_tx: Self::setup_persistance(store.clone(), trade_uuid, &data_path), + persist_tx, trade_uuid, store, + task_handle, }; Ok((trade_uuid, data)) @@ -85,32 +103,41 @@ impl TakerActorData { store: Arc>, trade_uuid: Uuid, dir_path: impl AsRef, - ) -> mpsc::Sender<()> { + ) -> (mpsc::Sender, tokio::task::JoinHandle<()>) { // No more than 1 persistance request is allowed nor needed. // This is essentilaly a debounce mechanism let (persist_tx, mut persist_rx) = mpsc::channel(1); let dir_path_buf = dir_path.as_ref().to_path_buf(); - tokio::spawn(async move { + let task_handle = tokio::spawn(async move { let dir_path_buf = dir_path_buf.clone(); loop { - persist_rx.recv().await; - match store.read().await.persist(&dir_path_buf).await { - Ok(_) => {} - Err(err) => { - error!( - "Taker w/ Trade UUID {} - Error persisting data: {}", - trade_uuid, err - ); - } + select! { + Some(msg) = persist_rx.recv() => { + match msg { + TakerActorDataMsg::Persist => { + if let Some(err) = store.read().await.persist(&dir_path_buf).await.err() { + error!( + "Taker w/ TradeUUID {} - Error persisting data: {}", + trade_uuid, err + ); + } + } + TakerActorDataMsg::Close => { + break; + } + } + + }, + else => break, } } }); - persist_tx + (persist_tx, task_handle) } fn queue_persistance(&self) { - match self.persist_tx.try_send(()) { + match self.persist_tx.try_send(TakerActorDataMsg::Persist) { Ok(_) => {} Err(error) => match error { mpsc::error::TrySendError::Full(_) => { @@ -167,4 +194,10 @@ impl TakerActorData { self.store.write().await.trade_completed = trade_completed; self.queue_persistance(); } + + pub(crate) async fn terminate(self) -> Result<(), N3xbError> { + self.persist_tx.send(TakerActorDataMsg::Close).await?; + self.task_handle.await?; + Ok(()) + } } diff --git a/src/taker/taker.rs b/src/taker/taker.rs index db07fc8..c0b97dd 100644 --- a/src/taker/taker.rs +++ b/src/taker/taker.rs @@ -183,7 +183,7 @@ impl TakerActor { Ok((trade_uuid, actor)) } - async fn run(&mut self) { + async fn run(mut self) { let (tx, mut rx) = mpsc::channel::(20); if let Some(error) = self @@ -213,7 +213,8 @@ impl TakerActor { } } - info!("Taker w/ TradeUUID {} terminating", self.data.trade_uuid) + info!("Taker w/ TradeUUID {} terminating", self.data.trade_uuid); + self.data.terminate().await.unwrap(); } // Top-down Requests Handling