Skip to content

Commit

Permalink
use a trait
Browse files Browse the repository at this point in the history
  • Loading branch information
codabrink committed Dec 18, 2024
1 parent 13031f8 commit 95639d9
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 14 deletions.
2 changes: 1 addition & 1 deletion xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ where
target_inbox_id: String,
) -> Result<MlsGroup<Self>, ClientError> {
let conn = self.store().conn()?;
match conn.find_dm_group(&target_inbox_id)? {
match conn.find_dm_group(self.inbox_id(), &target_inbox_id)? {
Some(dm_group) => Ok(MlsGroup::new(
self.clone(),
dm_group.id,
Expand Down
7 changes: 5 additions & 2 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ use self::{
intents::IntentError,
validated_commit::CommitValidationError,
};
use crate::storage::StorageError;
use crate::storage::{
group::{DmId, DmIdExt},
StorageError,
};
use xmtp_common::time::now_ns;
use xmtp_proto::xmtp::mls::{
api::v1::{
Expand Down Expand Up @@ -481,7 +484,7 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
now_ns(),
membership_state,
context.inbox_id().to_string(),
Some(StoredGroup::dm_id([&dm_target_inbox_id, client.inbox_id()])),
Some(DmId::from_ids([&dm_target_inbox_id, client.inbox_id()])),
);

stored_group.store(provider.conn_ref())?;
Expand Down
52 changes: 41 additions & 11 deletions xmtp_mls/src/storage/encrypted_store/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub struct StoredGroup {
/// Enum, [`ConversationType`] signifies the group conversation type which extends to who can access it.
pub conversation_type: ConversationType,
/// The inbox_id of the DM target
pub dm_id: Option<String>,
pub dm_id: Option<DmId>,
/// Timestamp of when the last message was sent for this group (updated automatically in a trigger)
pub last_message_ns: i64,
}
Expand Down Expand Up @@ -120,14 +120,6 @@ impl StoredGroup {
last_message_ns: now_ns(),
}
}

pub fn dm_id(inbox_ids: [&str; 2]) -> String {
let inbox_ids: = inbox_ids
.into_iter()
.map(str::to_lowercase)
.collect::<Vec<_>>();
format!("dm:{}", inbox_ids.join(":"))
}
}

#[derive(Debug, Default)]
Expand Down Expand Up @@ -351,11 +343,14 @@ impl DbConnection {

pub fn find_dm_group(
&self,
inbox_id: &str,
target_inbox_id: &str,
) -> Result<Option<StoredGroup>, StorageError> {
let dm_id = DmId::from_ids([inbox_id, target_inbox_id]);

let query = dsl::groups
.order(dsl::created_at_ns.asc())
.filter(dsl::dm_inbox_id.eq(Some(target_inbox_id)));
.filter(dsl::dm_id.eq(Some(dm_id)));

let groups: Vec<StoredGroup> = self.raw_query(|conn| query.load(conn))?;
if groups.len() > 1 {
Expand Down Expand Up @@ -555,6 +550,38 @@ impl std::fmt::Display for ConversationType {
}
}

pub type DmId = String;

pub trait DmIdExt {
fn from_ids(inbox_ids: [&str; 2]) -> Self;
fn other_id(&self, other: &str) -> String;
}
impl DmIdExt for DmId {
fn from_ids(inbox_ids: [&str; 2]) -> Self {
let inbox_ids = inbox_ids
.into_iter()
.map(str::to_lowercase)
.collect::<Vec<_>>();

format!("dm:{}", inbox_ids.join(":"))
}

fn other_id(&self, id: &str) -> String {
// drop the "dm:"
let dm_id = &self[3..];

// If my id is the first half, return the second half, otherwise return first half
let target_inbox = if dm_id[..id.len()] == *id {
// + 1 because there is a colon (:)
&dm_id[(id.len() + 1)..]
} else {
&dm_id[..id.len()]
};

return target_inbox.to_string();
}
}

#[cfg(test)]
pub(crate) mod tests {
#[cfg(target_arch = "wasm32")]
Expand Down Expand Up @@ -726,7 +753,10 @@ pub(crate) mod tests {
assert_eq!(dm_results[2].id, test_group_3.id);

// test find_dm_group
let dm_result = conn.find_dm_group("placeholder_inbox_id").unwrap();

let dm_result = conn
.find_dm_group("placeholder_inbox_id", "placeholder_inbox_id")
.unwrap();
assert!(dm_result.is_some());

// test only dms are returned
Expand Down

0 comments on commit 95639d9

Please sign in to comment.