diff --git a/xmtp/src/builder.rs b/xmtp/src/builder.rs index 497e9aa0d..29edaa53f 100644 --- a/xmtp/src/builder.rs +++ b/xmtp/src/builder.rs @@ -170,11 +170,14 @@ where #[cfg(test)] AccountStrategy::ExternalAccount(a) => a, }; - store.insert_or_ignore_user(StoredUser { - user_address: account.addr(), - created_at: now(), - last_refreshed: 0, - })?; + store.insert_or_ignore_user( + &mut store.conn()?, + StoredUser { + user_address: account.addr(), + created_at: now(), + last_refreshed: 0, + }, + )?; Ok(Client::new(api_client, self.network, account, store)) } diff --git a/xmtp/src/client.rs b/xmtp/src/client.rs index 9d8d07583..8dccd4f29 100644 --- a/xmtp/src/client.rs +++ b/xmtp/src/client.rs @@ -183,7 +183,7 @@ where &self, user_address: &str, ) -> Result<(), ClientError> { - let user = self.store.get_user(user_address)?; + let user = self.store.get_user(&mut self.store.conn()?, user_address)?; if user.is_none() || user.unwrap().last_refreshed < now() - INSTALLATION_REFRESH_INTERVAL_NS { self.refresh_user_installations(user_address).await?; @@ -200,6 +200,7 @@ where let self_install_id = key_fingerprint(&self.account.identity_keys().curve25519); let contacts = self.get_contacts(user_address).await?; + let conn = &mut self.store.conn()?; debug!( "Fetched contacts for address {}: {:?}", user_address, contacts @@ -207,7 +208,7 @@ where let installation_map = self .store - .get_installations(&mut self.store.conn()?, user_address)? + .get_installations(conn, user_address)? .into_iter() .map(|v| (v.installation_id.clone(), v)) .collect::>(); @@ -223,37 +224,35 @@ where user_address, new_installs ); - self.store - .conn()? - .transaction(|transaction_manager| -> Result<(), ClientError> { - self.store.insert_or_ignore_user_with_conn( + conn.transaction(|transaction_manager| -> Result<(), ClientError> { + self.store.insert_or_ignore_user( + transaction_manager, + StoredUser { + user_address: user_address.to_string(), + created_at: now(), + last_refreshed: refresh_timestamp, + }, + )?; + for install in new_installs { + info!("Saving Install {}", install.installation_id); + let session = self.create_uninitialized_session(&install.get_contact()?)?; + + self.store + .insert_or_ignore_install(transaction_manager, install)?; + self.store.insert_or_ignore_session( transaction_manager, - StoredUser { - user_address: user_address.to_string(), - created_at: now(), - last_refreshed: refresh_timestamp, - }, + StoredSession::try_from(&session)?, )?; - for install in new_installs { - info!("Saving Install {}", install.installation_id); - let session = self.create_uninitialized_session(&install.get_contact()?)?; - - self.store - .insert_or_ignore_install(install, transaction_manager)?; - self.store.insert_or_ignore_session( - StoredSession::try_from(&session)?, - transaction_manager, - )?; - } + } - self.store.update_user_refresh_timestamp( - transaction_manager, - user_address, - refresh_timestamp, - )?; + self.store.update_user_refresh_timestamp( + transaction_manager, + user_address, + refresh_timestamp, + )?; - Ok(()) - })?; + Ok(()) + })?; Ok(()) } diff --git a/xmtp/src/conversation.rs b/xmtp/src/conversation.rs index ebbb91afe..b78db00b5 100644 --- a/xmtp/src/conversation.rs +++ b/xmtp/src/conversation.rs @@ -115,7 +115,7 @@ where ) -> Result<(), ConversationError> { let peer_address = peer_addr_from_convo_id(convo_id, &client.wallet_address())?; let created_at = now(); - client.store.insert_or_ignore_user_with_conn( + client.store.insert_or_ignore_user( conn, StoredUser { user_address: peer_address.clone(), @@ -124,7 +124,7 @@ where }, )?; - client.store.insert_or_ignore_conversation_with_conn( + client.store.insert_or_ignore_conversation( conn, StoredConversation { peer_address, @@ -203,11 +203,19 @@ mod tests { let client = gen_test_client().await; let peer_address = "0x000"; let convo_id = format!(":{}:{}", peer_address, client.wallet_address()); - assert!(client.store.get_conversation(&convo_id).unwrap().is_none()); + assert!(client + .store + .get_conversation(&mut client.store.conn().unwrap(), &convo_id) + .unwrap() + .is_none()); let conversation = gen_test_conversation(&client, peer_address).await; assert!(conversation.peer_address() == peer_address); - assert!(client.store.get_conversation(&convo_id).unwrap().is_some()); + assert!(client + .store + .get_conversation(&mut client.store.conn().unwrap(), &convo_id) + .unwrap() + .is_some()); } #[tokio::test] @@ -216,7 +224,10 @@ mod tests { let conversation = gen_test_conversation(&client, "0x000").await; conversation.send_text("Hello, world!").await.unwrap(); - let message = &client.store.get_unprocessed_messages().unwrap()[0]; + let message = &client + .store + .get_unprocessed_messages(&mut client.store.conn().unwrap()) + .unwrap()[0]; let content = EncodedContent::decode(&message.content[..]).unwrap(); assert!(TextCodec::decode(content).unwrap() == "Hello, world!"); } diff --git a/xmtp/src/conversations.rs b/xmtp/src/conversations.rs index 586ed72a6..e42844019 100644 --- a/xmtp/src/conversations.rs +++ b/xmtp/src/conversations.rs @@ -14,7 +14,7 @@ use xmtp_proto::xmtp::{ }; use crate::{ - conversation::{peer_addr_from_convo_id, ConversationError, Conversation}, + conversation::{peer_addr_from_convo_id, Conversation, ConversationError}, message::DecodedInboundMessage, session::SessionManager, storage::{ @@ -72,9 +72,10 @@ impl Conversations { pub fn save_inbound_messages(client: &Client) -> Result<(), ConversationError> { let inbound_topic = build_installation_message_topic(&client.installation_id()); - client - .store - .lock_refresh_job(RefreshJobKind::Message, |conn, job| { + client.store.lock_refresh_job( + &mut client.store.conn()?, + RefreshJobKind::Message, + |conn, job| { log::debug!( "Refresh messages start time: {}", Conversations::::get_start_time(&job).unsigned_abs() @@ -94,7 +95,8 @@ impl Conversations { } Ok(()) - })?; + }, + )?; Ok(()) } @@ -138,7 +140,7 @@ impl Conversations { let existing_sessions = client .store - .get_latest_sessions_for_installation(&payload.sender_installation_id, conn)?; + .get_latest_sessions_for_installation(conn, &payload.sender_installation_id)?; // Attempt to decrypt with existing sessions for raw_session in existing_sessions { @@ -294,12 +296,12 @@ impl Conversations { |transaction| -> Result<(), ConversationError> { let my_sessions = client .store - .get_latest_sessions(&client.wallet_address(), transaction)?; + .get_latest_sessions(transaction, &client.wallet_address())?; let their_user_addr = peer_addr_from_convo_id(&message.convo_id, &client.wallet_address())?; let their_sessions = client .store - .get_latest_sessions(&their_user_addr, transaction)?; + .get_latest_sessions(transaction, &their_user_addr)?; if their_sessions.is_empty() { return Err(ConversationError::NoSessions(their_user_addr)); } @@ -321,11 +323,11 @@ impl Conversations { } client.store.commit_outbound_payloads_for_message( + transaction, message.id, MessageState::LocallyCommitted, outbound_payloads, updated_sessions, - transaction, )?; Ok(()) }, @@ -339,7 +341,9 @@ impl Conversations { client .refresh_user_installations_if_stale(&client.wallet_address()) .await?; - let mut messages = client.store.get_unprocessed_messages()?; + let mut messages = client + .store + .get_unprocessed_messages(&mut client.store.conn()?)?; log::debug!("Processing {} messages", messages.len()); messages.sort_by(|a, b| a.created_at.cmp(&b.created_at)); for message in messages { @@ -359,6 +363,7 @@ impl Conversations { pub async fn publish_outbound_payloads(client: &Client) -> Result<(), ConversationError> { let unsent_payloads = client.store.fetch_and_lock_outbound_payloads( + &mut client.store.conn()?, OutboundPayloadState::Pending, Duration::from_secs(60).as_nanos() as i64, )?; @@ -387,6 +392,7 @@ impl Conversations { .map(|payload| payload.created_at_ns) .collect(); client.store.update_and_unlock_outbound_payloads( + &mut client.store.conn()?, payload_ids, OutboundPayloadState::ServerAcknowledged, )?; @@ -418,8 +424,7 @@ mod tests { let alice_client = gen_test_client().await; let bob_client = gen_test_client().await; let conversation = - Conversation::new(&alice_client, bob_client.wallet_address().to_string()) - .unwrap(); + Conversation::new(&alice_client, bob_client.wallet_address().to_string()).unwrap(); assert_eq!(conversation.peer_address(), bob_client.wallet_address()); } diff --git a/xmtp/src/storage/encrypted_store/mod.rs b/xmtp/src/storage/encrypted_store/mod.rs index 4d6dad41c..36cee910a 100644 --- a/xmtp/src/storage/encrypted_store/mod.rs +++ b/xmtp/src/storage/encrypted_store/mod.rs @@ -179,8 +179,8 @@ impl EncryptedMessageStore { pub fn get_latest_sessions_for_installation( &self, - installation_id: &str, conn: &mut DbConnection, + installation_id: &str, ) -> Result, StorageError> { use self::schema::sessions::dsl::*; @@ -193,8 +193,8 @@ impl EncryptedMessageStore { pub fn get_latest_sessions( &self, - user_address: &str, conn: &mut DbConnection, + user_address: &str, ) -> Result, StorageError> { if !is_wallet_address(user_address) { return Err(StorageError::Unknown( @@ -229,8 +229,8 @@ impl EncryptedMessageStore { pub fn session_exists_for_installation( &self, + conn: &mut DbConnection, installation_id: &str, - conn: &mut PooledConnection>, ) -> Result { use self::schema::sessions::dsl as schema; @@ -254,9 +254,11 @@ impl EncryptedMessageStore { Ok(installation_list) } - pub fn get_user(&self, address: &str) -> Result, StorageError> { - let conn = &mut self.conn()?; - + pub fn get_user( + &self, + conn: &mut DbConnection, + address: &str, + ) -> Result, StorageError> { let mut user_list = users::table .filter(users::user_address.eq(address)) .load::(conn)?; @@ -267,7 +269,7 @@ impl EncryptedMessageStore { pub fn update_user_refresh_timestamp( &self, - conn: &mut PooledConnection>, + conn: &mut DbConnection, user_address: &str, timestamp: i64, ) -> Result { @@ -277,12 +279,7 @@ impl EncryptedMessageStore { .map_err(|e| e.into()) } - pub fn insert_or_ignore_user(&self, user: StoredUser) -> Result<(), StorageError> { - let conn = &mut self.conn()?; - self.insert_or_ignore_user_with_conn(conn, user) - } - - pub fn insert_or_ignore_user_with_conn( + pub fn insert_or_ignore_user( &self, conn: &mut DbConnection, user: StoredUser, @@ -295,10 +292,9 @@ impl EncryptedMessageStore { pub fn get_conversation( &self, + conn: &mut DbConnection, convo_id: &str, ) -> Result, StorageError> { - let conn = &mut self.conn()?; - let mut convo_list = conversations::table .find(convo_id) .load::(conn)?; @@ -308,14 +304,6 @@ impl EncryptedMessageStore { } pub fn insert_or_ignore_conversation( - &self, - conversation: StoredConversation, - ) -> Result<(), StorageError> { - let conn = &mut self.conn()?; - self.insert_or_ignore_conversation_with_conn(conn, conversation) - } - - pub fn insert_or_ignore_conversation_with_conn( &self, conn: &mut DbConnection, conversation: StoredConversation, @@ -328,9 +316,9 @@ impl EncryptedMessageStore { pub fn get_contacts( &self, + conn: &mut DbConnection, user_address: &str, ) -> Result, StorageError> { - let conn = &mut self.conn()?; use self::schema::installations::dsl; let install_list = dsl::installations @@ -342,9 +330,10 @@ impl EncryptedMessageStore { Ok(install_list) } - pub fn get_unprocessed_messages(&self) -> Result, StorageError> { - let conn = &mut self.conn()?; - + pub fn get_unprocessed_messages( + &self, + conn: &mut DbConnection, + ) -> Result, StorageError> { let msg_list = messages::table .filter(messages::state.eq(MessageState::Unprocessed as i32)) .load::(conn)?; @@ -352,14 +341,18 @@ impl EncryptedMessageStore { Ok(msg_list) } - pub fn lock_refresh_job(&self, kind: RefreshJobKind, cb: F) -> Result<(), StorageError> + pub fn lock_refresh_job( + &self, + conn: &mut DbConnection, + kind: RefreshJobKind, + cb: F, + ) -> Result<(), StorageError> where F: FnOnce( &mut PooledConnection>, RefreshJob, ) -> Result<(), StorageError>, { - let conn = &mut self.conn()?; conn.transaction::<(), StorageError, _>(|connection| { let start_time = now(); let job: RefreshJob = refresh_jobs::table @@ -398,7 +391,7 @@ impl EncryptedMessageStore { } pub fn save_inbound_message( &self, - conn: &mut PooledConnection>, + conn: &mut DbConnection, message: InboundMessage, ) -> Result<(), StorageError> { use self::schema::inbound_messages::dsl as schema; @@ -438,8 +431,8 @@ impl EncryptedMessageStore { pub fn insert_or_ignore_install( &self, + conn: &mut DbConnection, install: StoredInstallation, - conn: &mut PooledConnection>, ) -> Result<(), StorageError> { diesel::insert_or_ignore_into(installations::table) .values(install) @@ -449,8 +442,8 @@ impl EncryptedMessageStore { pub fn insert_or_ignore_session( &self, + conn: &mut DbConnection, session: StoredSession, - conn: &mut PooledConnection>, ) -> Result<(), StorageError> { diesel::insert_or_ignore_into(schema::sessions::table) .values(session) @@ -460,7 +453,7 @@ impl EncryptedMessageStore { pub fn insert_or_ignore_message( &self, - conn: &mut PooledConnection>, + conn: &mut DbConnection, msg: NewStoredMessage, ) -> Result<(), StorageError> { diesel::insert_or_ignore_into(schema::messages::table) @@ -471,11 +464,11 @@ impl EncryptedMessageStore { pub fn commit_outbound_payloads_for_message( &self, + conn: &mut DbConnection, message_id: i32, updated_message_state: MessageState, new_outbound_payloads: Vec, updated_sessions: Vec, - conn: &mut PooledConnection>, ) -> Result<(), StorageError> { for session in updated_sessions { diesel::update(schema::sessions::table.find(session.session_id)) @@ -493,10 +486,10 @@ impl EncryptedMessageStore { pub fn fetch_and_lock_outbound_payloads( &self, + conn: &mut DbConnection, payload_state: OutboundPayloadState, lock_duration_ns: i64, ) -> Result, StorageError> { - let conn = &mut self.conn()?; use self::schema::outbound_payloads::dsl as schema; let now = now(); // Must happen atomically @@ -510,10 +503,10 @@ impl EncryptedMessageStore { pub fn update_and_unlock_outbound_payloads( &self, + conn: &mut DbConnection, payload_ids: Vec, new_payload_state: OutboundPayloadState, ) -> Result<(), StorageError> { - let conn = &mut self.conn()?; use self::schema::outbound_payloads::dsl::*; diesel::update(outbound_payloads) .filter(created_at_ns.eq_any(payload_ids)) @@ -973,38 +966,54 @@ mod tests { .unwrap(); store - .lock_refresh_job(RefreshJobKind::Message, |_, job| { - assert_eq!(job.id, RefreshJobKind::Message.to_string()); - assert_eq!(job.last_run, 0); - - Ok(()) - }) + .lock_refresh_job( + &mut store.conn().unwrap(), + RefreshJobKind::Message, + |_, job| { + assert_eq!(job.id, RefreshJobKind::Message.to_string()); + assert_eq!(job.last_run, 0); + + Ok(()) + }, + ) .unwrap(); store - .lock_refresh_job(RefreshJobKind::Message, |_, job| { - assert!(job.last_run > 0); - - Ok(()) - }) + .lock_refresh_job( + &mut store.conn().unwrap(), + RefreshJobKind::Message, + |_, job| { + assert!(job.last_run > 0); + + Ok(()) + }, + ) .unwrap(); - + let mut last_run = 0; - let res_expected_err = store.lock_refresh_job(RefreshJobKind::Message, |_, job| { - assert_eq!(job.id, RefreshJobKind::Message.to_string()); - last_run = job.last_run; + let res_expected_err = store.lock_refresh_job( + &mut store.conn().unwrap(), + RefreshJobKind::Message, + |_, job| { + assert_eq!(job.id, RefreshJobKind::Message.to_string()); + last_run = job.last_run; - Err(StorageError::Unknown(String::from("RefreshJob failed"))) - }); + Err(StorageError::Unknown(String::from("RefreshJob failed"))) + }, + ); assert!(res_expected_err.is_err()); store - .lock_refresh_job(RefreshJobKind::Message, |_, job| { - // Ensure that last run time does not change if the job fails - assert_eq!(job.last_run, last_run); - - Ok(()) - }) + .lock_refresh_job( + &mut store.conn().unwrap(), + RefreshJobKind::Message, + |_, job| { + // Ensure that last run time does not change if the job fails + assert_eq!(job.last_run, last_run); + + Ok(()) + }, + ) .unwrap(); }