diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs index 96b68057a..87fb406df 100644 --- a/xmtp_mls/src/groups/mls_sync.rs +++ b/xmtp_mls/src/groups/mls_sync.rs @@ -368,7 +368,7 @@ where id: ref msg_id, .. } = *envelope; - let mut locked_openmls_group = openmls_group.lock(); + let mut locked_openmls_group = openmls_group.lock().await; if intent.state == IntentState::Committed { return Ok(IntentState::Committed); @@ -501,7 +501,7 @@ where id: ref msg_id, .. } = *envelope; - let mut locked_openmls_group = openmls_group.lock(); + let mut locked_openmls_group = openmls_group.lock().await; let decrypted_message = locked_openmls_group.process_message(provider, message)?; let (sender_inbox_id, sender_installation_id) = extract_message_sender( diff --git a/xmtp_mls/src/groups/serial.rs b/xmtp_mls/src/groups/serial.rs index f40868e45..6982f2842 100644 --- a/xmtp_mls/src/groups/serial.rs +++ b/xmtp_mls/src/groups/serial.rs @@ -46,11 +46,34 @@ impl<'a> DerefMut for SerialOpenMlsGroup<'a> { } pub(crate) trait OpenMlsLock { - fn lock<'a>(&'a mut self) -> SerialOpenMlsGroup<'a>; + fn lock_blocking<'a>(&'a mut self) -> SerialOpenMlsGroup<'a>; + async fn lock<'a>(&'a mut self) -> SerialOpenMlsGroup<'a>; } impl OpenMlsLock for OpenMlsGroup { - fn lock<'a>(&'a mut self) -> SerialOpenMlsGroup<'a> { + async fn lock<'a>(&'a mut self) -> SerialOpenMlsGroup<'a> { + // .clone() is important here so that the outer lock gets dropped + let mutex = MLS_COMMIT_LOCK + .lock() + .entry(self.group_id().to_vec()) + .or_default() + .clone(); + + // this may block + let lock = mutex.lock().await; + let lock = unsafe { + // let the borrow checker know that this guard's mutex is going to be owned by the struct it's returning + std::mem::transmute::, MutexGuard<'a, ()>>(lock) + }; + + SerialOpenMlsGroup { + group: self, + lock, + _mutex: mutex, + } + } + + fn lock_blocking<'a>(&'a mut self) -> SerialOpenMlsGroup<'a> { // .clone() is important here so that the outer lock gets dropped let mutex = MLS_COMMIT_LOCK .lock()