Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streaming bindings #399

Merged
merged 17 commits into from
Jan 8, 2024
Merged
4 changes: 3 additions & 1 deletion bindings_ffi/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ xmtp_mls = { path = "../xmtp_mls", features = ["grpc", "native"] }
xmtp_cryptography = { path = "../xmtp_cryptography" }
xmtp_api_grpc = { path = "../xmtp_api_grpc" }
xmtp_v2 = { path = "../xmtp_v2" }
futures = "0.3.28"
tokio = { version = "1.28.1", features = ["macros"] }

[build_dependencies]
uniffi = { git = "https://github.com/mozilla/uniffi-rs", rev = "cae8edc45ba5b56bfcbf35b60c1ab6a97d1bf9da", features = [
Expand All @@ -33,7 +35,7 @@ path = "src/bin.rs"
[dev-dependencies]
ethers = "2.0.4"
ethers-core = "2.0.4"
tokio = { version = "1.0", features = ["full"] }
tokio = { version = "1.28.1", features = ["full"] }
tempfile = "3.5.0"
uniffi = { git = "https://github.com/mozilla/uniffi-rs", rev = "cae8edc45ba5b56bfcbf35b60c1ab6a97d1bf9da", features = [
"bindgen-tests",
Expand Down
Binary file modified bindings_ffi/jniLibs/arm64-v8a/libuniffi_xmtpv3.so
Binary file not shown.
170 changes: 122 additions & 48 deletions bindings_ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ mod v2;

use std::convert::TryInto;

use futures::StreamExt;
use inbox_owner::FfiInboxOwner;
use logger::FfiLogger;
use std::error::Error;
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use tokio::sync::{oneshot, oneshot::Sender};

use xmtp_api_grpc::grpc_api_helper::Client as TonicApiClient;
use xmtp_mls::groups::MlsGroup;
Expand Down Expand Up @@ -239,6 +241,41 @@ impl FfiGroup {

Ok(())
}

pub async fn stream(
&self,
message_callback: Box<dyn FfiMessageCallback>,
) -> Result<Arc<FfiMessageStreamCloser>, GenericError> {
let inner_client = Arc::clone(&self.inner_client);
let group_id = self.group_id.clone();
let created_at_ns = self.created_at_ns;
let (close_sender, close_receiver) = oneshot::channel::<()>();

tokio::spawn(async move {
let client = inner_client.as_ref();
let group = MlsGroup::new(&client, group_id, created_at_ns);
let mut stream = group.stream().await.unwrap();
let mut close_receiver = close_receiver;
loop {
tokio::select! {
item = stream.next() => {
match item {
Some(message) => message_callback.on_message(message.into()),
None => break
}
}
_ = &mut close_receiver => {
break;
}
}
}
log::debug!("closing stream");
});

Ok(Arc::new(FfiMessageStreamCloser {
close_fn: Arc::new(Mutex::new(Some(close_sender))),
}))
}
}

#[uniffi::export]
Expand All @@ -248,14 +285,13 @@ impl FfiGroup {
}
}

#[derive(uniffi::Record)]
// #[derive(uniffi::Record)]
pub struct FfiMessage {
pub id: Vec<u8>,
pub sent_at_ns: i64,
pub convo_id: Vec<u8>,
pub addr_from: String,
pub content: Vec<u8>,
// TODO pub kind: GroupMessageKind,
}

impl From<StoredGroupMessage> for FfiMessage {
Expand All @@ -270,12 +306,39 @@ impl From<StoredGroupMessage> for FfiMessage {
}
}

#[derive(uniffi::Object)]
pub struct FfiMessageStreamCloser {
close_fn: Arc<Mutex<Option<Sender<()>>>>,
}

#[uniffi::export]
impl FfiMessageStreamCloser {
pub fn close(&self) {
match self.close_fn.lock() {
Ok(mut close_fn_option) => {
let _ = close_fn_option.take().map(|close_fn| close_fn.send(()));
}
_ => {
log::warn!("close_fn already closed");
}
}
}
}

pub trait FfiMessageCallback: Send + Sync {
fn on_message(&self, message: FfiMessage);
}

#[cfg(test)]
mod tests {
use std::{env, sync::Arc};
use std::{
env,
sync::{Arc, Mutex},
};

use crate::{
create_client, inbox_owner::SigningError, logger::FfiLogger, FfiInboxOwner, FfiXmtpClient,
create_client, inbox_owner::SigningError, logger::FfiLogger, FfiInboxOwner, FfiMessage,
FfiMessageCallback, FfiXmtpClient,
};
use ethers_core::rand::{
self,
Expand Down Expand Up @@ -317,6 +380,29 @@ mod tests {
fn log(&self, _level: u32, _level_label: String, _message: String) {}
}

#[derive(Clone)]
struct RustMessageCallback {
num_messages: Arc<Mutex<u32>>,
}

impl RustMessageCallback {
pub fn new() -> Self {
Self {
num_messages: Arc::new(Mutex::new(0)),
}
}

pub fn message_count(&self) -> u32 {
*self.num_messages.lock().unwrap()
}
}

impl FfiMessageCallback for RustMessageCallback {
fn on_message(&self, _: FfiMessage) {
*self.num_messages.lock().unwrap() += 1;
}
}

pub fn rand_string() -> String {
Alphanumeric.sample_string(&mut rand::thread_rng(), 24)
}
Expand Down Expand Up @@ -430,47 +516,35 @@ mod tests {
assert!(result_errored, "did not error on wrong encryption key")
}

// #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
// async fn test_conversation_list() {
// let client_a = new_test_client().await;
// let client_b = new_test_client().await;

// // Create a conversation between the two clients
// let conversation = client_a
// .conversations()
// .new_conversation(client_b.account_address())
// .await
// .unwrap();
// conversation.send(vec![1, 2, 3]).await.unwrap();
// let convos = client_b.conversations().list().await.unwrap();
// assert_eq!(convos.len(), 1);
// assert_eq!(
// convos.first().unwrap().peer_address,
// client_a.account_address()
// );
// }

// #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
// async fn test_send_and_list() {
// let alice = new_test_client().await;
// let bob = new_test_client().await;

// let alice_to_bob = alice
// .conversations()
// .new_conversation(bob.account_address())
// .await
// .unwrap();

// alice_to_bob.send(vec![1, 2, 3]).await.unwrap();
// let messages = alice_to_bob
// .list_messages(FfiListMessagesOptions {
// start_time_ns: None,
// end_time_ns: None,
// limit: None,
// })
// .await
// .unwrap();
// assert_eq!(messages.len(), 1);
// assert_eq!(messages[0].content, vec![1, 2, 3]);
// }
#[tokio::test(flavor = "multi_thread", worker_threads = 10)]
async fn test_streaming() {
let amal = new_test_client().await;
let bola = new_test_client().await;

let group = amal
.conversations()
.create_group(bola.account_address())
.await
.unwrap();

let message_callback = RustMessageCallback::new();
let stream_closer = group
.stream(Box::new(message_callback.clone()))
.await
.unwrap();

tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
group.send("hello".as_bytes().to_vec()).await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
group.send("goodbye".as_bytes().to_vec()).await.unwrap();
// Because of the event loop, I need to make the test give control
// back to the stream before it can process each message. Using sleep to do that.
// I think this will work fine in practice
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
assert_eq!(message_callback.message_count(), 2);

stream_closer.close();
// Make sure nothing panics calling `close` twice
stream_closer.close();
}
}
12 changes: 12 additions & 0 deletions bindings_ffi/src/xmtpv3.udl
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,15 @@ callback interface FfiInboxOwner {
callback interface FfiLogger {
void log(u32 level, string level_label, string message);
};

callback interface FfiMessageCallback {
void on_message(FfiMessage message);
};

dictionary FfiMessage {
bytes id;
i64 sent_at_ns;
bytes convo_id;
string addr_from;
bytes content;
};
2 changes: 0 additions & 2 deletions xmtp_api_grpc/src/grpc_api_helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,6 @@ impl XmtpApiClient for Client {
while let Some(result) = receiver.next().await {
yield result;
}

println!("stream closed")
};

let mut tonic_request = Request::new(input_stream);
Expand Down
55 changes: 54 additions & 1 deletion xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{collections::HashSet, mem::Discriminant};
use std::{collections::HashSet, mem::Discriminant, pin::Pin};

use futures::{Stream, StreamExt};
use log::debug;
use openmls::{
framing::{MlsMessageIn, MlsMessageInBody},
Expand Down Expand Up @@ -364,6 +365,40 @@

Ok(groups)
}

fn process_streamed_welcome(
&self,
envelope: Envelope,
) -> Result<MlsGroup<ApiClient>, ClientError> {
let welcome = extract_welcome(&envelope.message)?;
let conn = self.store.conn()?;
let provider = self.mls_provider(&conn);
Ok(MlsGroup::create_from_welcome(self, &provider, welcome)
.map_err(|e| ClientError::Generic(e.to_string()))?)

Check warning on line 377 in xmtp_mls/src/client.rs

View workflow job for this annotation

GitHub Actions / workspace

question mark operator is useless here

warning: question mark operator is useless here --> xmtp_mls/src/client.rs:376:9 | 376 | / Ok(MlsGroup::create_from_welcome(self, &provider, welcome) 377 | | .map_err(|e| ClientError::Generic(e.to_string()))?) | |_______________________________________________________________^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#needless_question_mark = note: `#[warn(clippy::needless_question_mark)]` on by default help: try removing question mark and `Ok()` | 376 ~ MlsGroup::create_from_welcome(self, &provider, welcome) 377 + .map_err(|e| ClientError::Generic(e.to_string())) |
}

pub async fn stream_conversations(
&'a self,
) -> Result<Pin<Box<dyn Stream<Item = MlsGroup<ApiClient>> + 'a>>, ClientError> {
let welcome_topic = get_welcome_topic(&self.installation_public_key());
let subscription = self.api_client.subscribe(vec![welcome_topic]).await?;
let stream = subscription
.map(|envelope_result| async {
let envelope = envelope_result?;
self.process_streamed_welcome(envelope)
})
.filter_map(|res| async {
match res.await {
Ok(group) => Some(group),
Err(err) => {
log::error!("Error processing stream entry: {:?}", err);
None
}
}
});

Ok(Box::pin(stream))
}
}

fn extract_welcome(welcome_bytes: &Vec<u8>) -> Result<Welcome, ClientError> {
Expand All @@ -382,6 +417,7 @@
use xmtp_cryptography::utils::generate_local_wallet;

use crate::{builder::ClientBuilder, InboxOwner};
use futures::StreamExt;

#[tokio::test]
async fn test_mls_error() {
Expand Down Expand Up @@ -442,4 +478,21 @@
let duplicate_received_groups = bob.sync_welcomes().await.unwrap();
assert_eq!(duplicate_received_groups.len(), 0);
}

#[tokio::test]
async fn test_stream_welcomes() {
let alice = ClientBuilder::new_test_client(generate_local_wallet().into()).await;
let bob = ClientBuilder::new_test_client(generate_local_wallet().into()).await;
bob.register_identity().await.unwrap();

let alice_bob_group = alice.create_group().unwrap();

let mut bob_stream = bob.stream_conversations().await.unwrap();
alice_bob_group
.add_members_by_installation_id(vec![bob.installation_public_key()])
.await
.unwrap();
let bob_received_groups = bob_stream.next().await.unwrap();
assert_eq!(bob_received_groups.group_id, alice_bob_group.group_id);
}
}
2 changes: 1 addition & 1 deletion xmtp_mls/src/groups/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ where

pub async fn stream(
&'c self,
) -> Result<Pin<Box<dyn Stream<Item = StoredGroupMessage> + 'c>>, GroupError> {
) -> Result<Pin<Box<dyn Stream<Item = StoredGroupMessage> + 'c + Send>>, GroupError> {
let subscription = self.client.api_client.subscribe(vec![self.topic()]).await?;
let stream = subscription
.map(|res| async {
Expand Down
6 changes: 3 additions & 3 deletions xmtp_proto/src/api_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@ pub trait XmtpApiSubscription {

#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait MutableApiSubscription: Stream<Item = Result<Envelope, Error>> {
pub trait MutableApiSubscription: Stream<Item = Result<Envelope, Error>> + Send {
async fn update(&mut self, req: SubscribeRequest) -> Result<(), Error>;
fn close(&self);
}

// Wasm futures don't have `Send` or `Sync` bounds.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait XmtpApiClient {
pub trait XmtpApiClient: Send + Sync {
type Subscription: XmtpApiSubscription;
type MutableSubscription: MutableApiSubscription;

Expand All @@ -126,7 +126,7 @@ pub trait XmtpApiClient {
// Wasm futures don't have `Send` or `Sync` bounds.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait XmtpMlsClient {
pub trait XmtpMlsClient: Send + Sync {
type Subscription: MutableApiSubscription;

async fn register_installation(
Expand Down
Loading