diff --git a/Cargo.lock b/Cargo.lock index cd2bbbc3a..6e9bae001 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6905,6 +6905,7 @@ name = "xmtp_api_grpc" version = "0.1.0" dependencies = [ "async-stream", + "async-trait", "base64 0.22.1", "futures", "hex", @@ -6922,6 +6923,7 @@ name = "xmtp_api_http" version = "0.1.0" dependencies = [ "async-stream", + "async-trait", "futures", "reqwest 0.12.8", "serde", @@ -7079,6 +7081,7 @@ dependencies = [ name = "xmtp_proto" version = "0.1.0" dependencies = [ + "async-trait", "futures", "openmls", "pbjson", @@ -7086,7 +7089,6 @@ dependencies = [ "prost", "serde", "tonic", - "trait-variant", "wasm-bindgen-test", ] diff --git a/bindings_ffi/src/mls.rs b/bindings_ffi/src/mls.rs index a3f6de8ed..57ec0cf5a 100644 --- a/bindings_ffi/src/mls.rs +++ b/bindings_ffi/src/mls.rs @@ -153,7 +153,9 @@ pub async fn get_inbox_id_for_address( account_address: String, ) -> Result, GenericError> { let api_client = ApiClientWrapper::new( - TonicApiClient::create(host.clone(), is_secure).await?, + TonicApiClient::create(host.clone(), is_secure) + .await? + .into(), Retry::default(), ); diff --git a/bindings_node/src/inbox_id.rs b/bindings_node/src/inbox_id.rs index 094fa68fc..08875e461 100644 --- a/bindings_node/src/inbox_id.rs +++ b/bindings_node/src/inbox_id.rs @@ -16,7 +16,8 @@ pub async fn get_inbox_id_for_address( let api_client = ApiClientWrapper::new( TonicApiClient::create(host.clone(), is_secure) .await - .map_err(ErrorWrapper::from)?, + .map_err(ErrorWrapper::from)? + .into(), Retry::default(), ); diff --git a/bindings_wasm/src/inbox_id.rs b/bindings_wasm/src/inbox_id.rs index 93e5fc22b..359948266 100644 --- a/bindings_wasm/src/inbox_id.rs +++ b/bindings_wasm/src/inbox_id.rs @@ -11,7 +11,7 @@ pub async fn get_inbox_id_for_address( ) -> Result, JsError> { let account_address = account_address.to_lowercase(); let api_client = ApiClientWrapper::new( - XmtpHttpApiClient::new(host.clone()).unwrap(), + XmtpHttpApiClient::new(host.clone()).unwrap().into(), Retry::default(), ); diff --git a/examples/cli/cli-client.rs b/examples/cli/cli-client.rs index 79009cc07..4cb059a31 100755 --- a/examples/cli/cli-client.rs +++ b/examples/cli/cli-client.rs @@ -18,11 +18,13 @@ use ethers::signers::{coins_bip39::English, LocalWallet, MnemonicBuilder}; use futures::future::join_all; use kv_log_macro::{error, info}; use prost::Message; +use xmtp_api_grpc::replication_client::ClientV4; use xmtp_id::associations::unverified::{UnverifiedRecoverableEcdsaSignature, UnverifiedSignature}; use xmtp_mls::client::FindGroupParams; use xmtp_mls::groups::device_sync::DeviceSyncContent; use xmtp_mls::storage::group_message::{GroupMessageKind, MsgQueryArgs}; +use xmtp_mls::XmtpApi; use crate::{ json_logger::make_value, @@ -30,7 +32,7 @@ use crate::{ }; use serializable::maybe_get_text; use thiserror::Error; -use xmtp_api_grpc::grpc_api_helper::Client as ApiClient; +use xmtp_api_grpc::grpc_api_helper::Client as ClientV3; use xmtp_cryptography::{ signature::{RecoverableSignature, SignatureError}, utils::rng, @@ -49,8 +51,8 @@ use xmtp_mls::{ utils::time::now_ns, InboxOwner, }; -type Client = xmtp_mls::client::Client; -type ClientBuilder = xmtp_mls::builder::ClientBuilder; + +type Client = xmtp_mls::client::Client>; type MlsGroup = xmtp_mls::groups::MlsGroup; /// A fictional versioning CLI @@ -67,6 +69,8 @@ struct Cli { local: bool, #[clap(long, default_value_t = false)] json: bool, + #[clap(long, default_value_t = false)] + testnet: bool, } #[derive(ValueEnum, Debug, Copy, Clone)] @@ -179,14 +183,41 @@ async fn main() { } info!("Starting CLI Client...."); + let grpc = match (cli.testnet, cli.local) { + (true, true) => Box::new( + ClientV4::create("http://localhost:5050".into(), false) + .await + .unwrap(), + ) as Box, + (true, false) => Box::new( + ClientV4::create("https://grpc.testnet.xmtp.network:443".into(), true) + .await + .unwrap(), + ) as Box, + (false, true) => Box::new( + ClientV3::create("http://localhost:5556".into(), false) + .await + .unwrap(), + ) as Box, + (false, false) => Box::new( + ClientV3::create("https://grpc.dev.xmtp.network:443".into(), true) + .await + .unwrap(), + ) as Box, + }; + if let Commands::Register { seed_phrase } = &cli.command { info!("Register"); - if let Err(e) = register(&cli, seed_phrase.clone()).await { + if let Err(e) = register(&cli, seed_phrase.clone(), grpc).await { error!("Registration failed: {:?}", e) } return; } + let client = create_client(&cli, IdentityStrategy::CachedOnly, grpc) + .await + .unwrap(); + match &cli.command { #[allow(unused_variables)] Commands::Register { seed_phrase } => { @@ -194,19 +225,12 @@ async fn main() { } Commands::Info {} => { info!("Info"); - let client = create_client(&cli, IdentityStrategy::CachedOnly) - .await - .unwrap(); let installation_id = hex::encode(client.installation_public_key()); info!("identity info", { command_output: true, account_address: client.inbox_id(), installation_id: installation_id }); } Commands::ListGroups {} => { info!("List Groups"); - let client = create_client(&cli, IdentityStrategy::CachedOnly) - .await - .unwrap(); let conn = client.store().conn().unwrap(); - client .sync_welcomes(&conn) .await @@ -236,9 +260,6 @@ async fn main() { } Commands::Send { group_id, msg } => { info!("Sending message to group", { group_id: group_id, message: msg }); - let client = create_client(&cli, IdentityStrategy::CachedOnly) - .await - .unwrap(); info!("Inbox ID is: {}", client.inbox_id()); let group = get_group(&client, hex::decode(group_id).expect("group id decode")) .await @@ -248,9 +269,6 @@ async fn main() { } Commands::ListGroupMessages { group_id } => { info!("Recv"); - let client = create_client(&cli, IdentityStrategy::CachedOnly) - .await - .unwrap(); let group = get_group(&client, hex::decode(group_id).expect("group id decode")) .await @@ -277,10 +295,6 @@ async fn main() { group_id, account_addresses, } => { - let client = create_client(&cli, IdentityStrategy::CachedOnly) - .await - .unwrap(); - let group = get_group(&client, hex::decode(group_id).expect("group id decode")) .await .expect("failed to get group"); @@ -299,10 +313,6 @@ async fn main() { group_id, account_addresses, } => { - let client = create_client(&cli, IdentityStrategy::CachedOnly) - .await - .unwrap(); - let group = get_group(&client, hex::decode(group_id).expect("group id decode")) .await .expect("failed to get group"); @@ -324,10 +334,6 @@ async fn main() { xmtp_mls::groups::PreconfiguredPolicies::AdminsOnly } }; - let client = create_client(&cli, IdentityStrategy::CachedOnly) - .await - .unwrap(); - let group = client .create_group( Some(group_permissions.to_policy_set()), @@ -338,9 +344,6 @@ async fn main() { info!("Created group {}", group_id, { command_output: true, group_id: group_id}) } Commands::GroupInfo { group_id } => { - let client = create_client(&cli, IdentityStrategy::CachedOnly) - .await - .unwrap(); let group = &client .group(hex::decode(group_id).expect("bad group id")) .expect("group not found"); @@ -349,9 +352,6 @@ async fn main() { info!("Group {}", group_id, { command_output: true, group_id: group_id, group_info: make_value(&serializable) }) } Commands::RequestHistorySync {} => { - let client = create_client(&cli, IdentityStrategy::CachedOnly) - .await - .unwrap(); let conn = client.store().conn().unwrap(); let provider = client.mls_provider().unwrap(); client.sync_welcomes(&conn).await.unwrap(); @@ -361,9 +361,6 @@ async fn main() { info!("Sent history sync request in sync group {group_id_str}", { group_id: group_id_str}) } Commands::ReplyToHistorySyncRequest {} => { - let client = create_client(&cli, IdentityStrategy::CachedOnly) - .await - .unwrap(); let provider = client.mls_provider().unwrap(); let group = client.get_sync_group().unwrap(); let group_id_str = hex::encode(group.group_id); @@ -376,9 +373,6 @@ async fn main() { info!("Reply: {:?}", reply); } Commands::ProcessHistorySyncReply {} => { - let client = create_client(&cli, IdentityStrategy::CachedOnly) - .await - .unwrap(); let conn = client.store().conn().unwrap(); let provider = client.mls_provider().unwrap(); client.sync_welcomes(&conn).await.unwrap(); @@ -388,9 +382,6 @@ async fn main() { info!("History bundle downloaded and inserted into user DB", {}) } Commands::ProcessConsentSyncReply {} => { - let client = create_client(&cli, IdentityStrategy::CachedOnly) - .await - .unwrap(); let conn = client.store().conn().unwrap(); let provider = client.mls_provider().unwrap(); client.sync_welcomes(&conn).await.unwrap(); @@ -400,9 +391,6 @@ async fn main() { info!("Consent bundle downloaded and inserted into user DB", {}) } Commands::ListHistorySyncMessages {} => { - let client = create_client(&cli, IdentityStrategy::CachedOnly) - .await - .unwrap(); let conn = client.store().conn().unwrap(); let provider = client.mls_provider().unwrap(); client.sync_welcomes(&conn).await.unwrap(); @@ -437,6 +425,28 @@ async fn main() { } } +async fn create_client( + cli: &Cli, + account: IdentityStrategy, + grpc: C, +) -> Result, CliError> { + let msg_store = get_encrypted_store(&cli.db).await.unwrap(); + let mut builder = xmtp_mls::builder::ClientBuilder::::new(account).store(msg_store); + + builder = builder.api_client(grpc); + + if cli.local { + builder = builder.history_sync_url(MessageHistoryUrls::LOCAL_ADDRESS); + } else { + builder = builder.history_sync_url(MessageHistoryUrls::DEV_ADDRESS); + } + + let client = builder.build().await.map_err(CliError::ClientBuilder)?; + + Ok(client) +} + +/* async fn create_client(cli: &Cli, account: IdentityStrategy) -> Result { let msg_store = get_encrypted_store(&cli.db).await.unwrap(); let mut builder = ClientBuilder::new(account).store(msg_store); @@ -465,8 +475,16 @@ async fn create_client(cli: &Cli, account: IdentityStrategy) -> Result) -> Result<(), CliError> { +async fn register( + cli: &Cli, + maybe_seed_phrase: Option, + client: C, +) -> Result<(), CliError> +where + C: XmtpApi, +{ let w: Wallet = if let Some(seed_phrase) = maybe_seed_phrase { Wallet::LocalWallet( MnemonicBuilder::::default() @@ -483,6 +501,7 @@ async fn register(cli: &Cli, maybe_seed_phrase: Option) -> Result<(), Cl let client = create_client( cli, IdentityStrategy::CreateIfNotFound(inbox_id, w.get_address(), nonce, None), + client, ) .await?; let mut signature_request = client.identity().signature_request().unwrap(); diff --git a/examples/cli/serializable.rs b/examples/cli/serializable.rs index a8c8d6a6a..741482c89 100644 --- a/examples/cli/serializable.rs +++ b/examples/cli/serializable.rs @@ -22,9 +22,7 @@ pub struct SerializableGroup { } impl SerializableGroup { - pub async fn from( - group: &MlsGroup>, - ) -> Self { + pub async fn from(group: &MlsGroup>) -> Self { let group_id = hex::encode(group.group_id.clone()); let members = group .members() diff --git a/xmtp_api_grpc/Cargo.toml b/xmtp_api_grpc/Cargo.toml index e7ac81947..6aba948e3 100644 --- a/xmtp_api_grpc/Cargo.toml +++ b/xmtp_api_grpc/Cargo.toml @@ -14,6 +14,7 @@ tonic = { workspace = true, features = ["tls", "tls-native-roots", "tls-webpki-r tracing.workspace = true xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] } xmtp_v2 = { path = "../xmtp_v2" } +async-trait = "0.1" [dev-dependencies] uuid = { workspace = true, features = ["v4"] } diff --git a/xmtp_api_grpc/src/grpc_api_helper.rs b/xmtp_api_grpc/src/grpc_api_helper.rs index e45916424..880ec5eaf 100644 --- a/xmtp_api_grpc/src/grpc_api_helper.rs +++ b/xmtp_api_grpc/src/grpc_api_helper.rs @@ -135,6 +135,7 @@ impl ClientWithMetadata for Client { } } +#[async_trait::async_trait] impl XmtpApiClient for Client { type Subscription = Subscription; type MutableSubscription = GrpcMutableSubscription; @@ -267,6 +268,7 @@ impl Subscription { } } +#[async_trait::async_trait] impl XmtpApiSubscription for Subscription { fn is_closed(&self) -> bool { self.closed.load(Ordering::SeqCst) @@ -321,6 +323,7 @@ impl Stream for GrpcMutableSubscription { } } +#[async_trait::async_trait] impl MutableApiSubscription for GrpcMutableSubscription { async fn update(&mut self, req: SubscribeRequest) -> Result<(), Error> { self.update_channel @@ -336,6 +339,8 @@ impl MutableApiSubscription for GrpcMutableSubscription { self.update_channel.close_channel(); } } + +#[async_trait::async_trait] impl XmtpMlsClient for Client { #[tracing::instrument(level = "trace", skip_all)] async fn upload_key_package(&self, req: UploadKeyPackageRequest) -> Result<(), Error> { @@ -453,6 +458,7 @@ impl Stream for WelcomeMessageStream { } } +#[async_trait::async_trait] impl XmtpMlsStreams for Client { type GroupMessageStream<'a> = GroupMessageStream; type WelcomeMessageStream<'a> = WelcomeMessageStream; diff --git a/xmtp_api_grpc/src/identity.rs b/xmtp_api_grpc/src/identity.rs index 7a9dc0d06..d0011d6e2 100644 --- a/xmtp_api_grpc/src/identity.rs +++ b/xmtp_api_grpc/src/identity.rs @@ -9,6 +9,7 @@ use xmtp_proto::{ }, }; +#[async_trait::async_trait] impl XmtpIdentityClient for Client { #[tracing::instrument(level = "trace", skip_all)] async fn publish_identity_update( diff --git a/xmtp_api_grpc/src/lib.rs b/xmtp_api_grpc/src/lib.rs index 265f6f107..d6bfe96e1 100644 --- a/xmtp_api_grpc/src/lib.rs +++ b/xmtp_api_grpc/src/lib.rs @@ -1,6 +1,7 @@ pub mod auth_token; pub mod grpc_api_helper; mod identity; +pub mod replication_client; pub const LOCALHOST_ADDRESS: &str = "http://localhost:5556"; pub const DEV_ADDRESS: &str = "https://grpc.dev.xmtp.network:443"; @@ -12,6 +13,7 @@ mod utils { mod test { use xmtp_proto::api_client::XmtpTestClient; + #[async_trait::async_trait] impl XmtpTestClient for crate::Client { async fn create_local() -> Self { crate::Client::create("http://localhost:5556".into(), false) diff --git a/xmtp_api_grpc/src/replication_client.rs b/xmtp_api_grpc/src/replication_client.rs new file mode 100644 index 000000000..df8e582a9 --- /dev/null +++ b/xmtp_api_grpc/src/replication_client.rs @@ -0,0 +1,464 @@ +#![allow(unused)] +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; +// TODO switch to async mutexes +use std::time::Duration; + +use futures::stream::{AbortHandle, Abortable}; +use futures::{SinkExt, Stream, StreamExt, TryStreamExt}; +use tokio::sync::oneshot; +use tonic::transport::ClientTlsConfig; +use tonic::{metadata::MetadataValue, transport::Channel, Request, Streaming}; + +#[cfg(any(feature = "test-utils", test))] +use xmtp_proto::api_client::XmtpTestClient; +use xmtp_proto::api_client::{ClientWithMetadata, XmtpIdentityClient, XmtpMlsStreams}; +use xmtp_proto::xmtp::mls::api::v1::{GroupMessage, WelcomeMessage}; +use xmtp_proto::xmtp::xmtpv4::message_api::replication_api_client::ReplicationApiClient; +use xmtp_proto::{ + api_client::{ + Error, ErrorKind, MutableApiSubscription, XmtpApiClient, XmtpApiSubscription, XmtpMlsClient, + }, + xmtp::identity::api::v1::{ + get_inbox_ids_response, GetIdentityUpdatesRequest as GetIdentityUpdatesV2Request, + GetIdentityUpdatesResponse as GetIdentityUpdatesV2Response, GetInboxIdsRequest, + GetInboxIdsResponse, PublishIdentityUpdateRequest, PublishIdentityUpdateResponse, + VerifySmartContractWalletSignaturesRequest, VerifySmartContractWalletSignaturesResponse, + }, + xmtp::message_api::v1::{ + BatchQueryRequest, BatchQueryResponse, Envelope, PublishRequest, PublishResponse, + QueryRequest, QueryResponse, SubscribeRequest, + }, + xmtp::mls::api::v1::{ + FetchKeyPackagesRequest, FetchKeyPackagesResponse, QueryGroupMessagesRequest, + QueryGroupMessagesResponse, QueryWelcomeMessagesRequest, QueryWelcomeMessagesResponse, + SendGroupMessagesRequest, SendWelcomeMessagesRequest, SubscribeGroupMessagesRequest, + SubscribeWelcomeMessagesRequest, UploadKeyPackageRequest, + }, + xmtp::xmtpv4::message_api::{ + get_inbox_ids_request, GetInboxIdsRequest as GetInboxIdsRequestV4, + }, +}; + +async fn create_tls_channel(address: String) -> Result { + let channel = Channel::from_shared(address) + .map_err(|e| Error::new(ErrorKind::SetupCreateChannelError).with(e))? + // Purpose: This setting controls the size of the initial connection-level flow control window for HTTP/2, which is the underlying protocol for gRPC. + // Functionality: Flow control in HTTP/2 manages how much data can be in flight on the network. Setting the initial connection window size to (1 << 31) - 1 (the maximum possible value for a 32-bit integer, which is 2,147,483,647 bytes) essentially allows the client to receive a very large amount of data from the server before needing to acknowledge receipt and permit more data to be sent. This can be particularly useful in high-latency networks or when transferring large amounts of data. + // Impact: Increasing the window size can improve throughput by allowing more data to be in transit at a time, but it may also increase memory usage and can potentially lead to inefficient use of bandwidth if the network is unreliable. + .initial_connection_window_size(Some((1 << 31) - 1)) + // Purpose: Configures whether the client should send keep-alive pings to the server when the connection is idle. + // Functionality: When set to true, this option ensures that periodic pings are sent on an idle connection to keep it alive and detect if the server is still responsive. + // Impact: This helps maintain active connections, particularly through NATs, load balancers, and other middleboxes that might drop idle connections. It helps ensure that the connection is promptly usable when new requests need to be sent. + .keep_alive_while_idle(true) + // Purpose: Sets the maximum amount of time the client will wait for a connection to be established. + // Functionality: If a connection cannot be established within the specified duration, the attempt is aborted and an error is returned. + // Impact: This setting prevents the client from waiting indefinitely for a connection to be established, which is crucial in scenarios where rapid failure detection is necessary to maintain responsiveness or to quickly fallback to alternative services or retry logic. + .connect_timeout(Duration::from_secs(10)) + // Purpose: Configures the TCP keep-alive interval for the socket connection. + // Functionality: This setting tells the operating system to send TCP keep-alive probes periodically when no data has been transferred over the connection within the specified interval. + // Impact: Similar to the gRPC-level keep-alive, this helps keep the connection alive at the TCP layer and detect broken connections. It's particularly useful for detecting half-open connections and ensuring that resources are not wasted on unresponsive peers. + .tcp_keepalive(Some(Duration::from_secs(15))) + // Purpose: Sets a maximum duration for the client to wait for a response to a request. + // Functionality: If a response is not received within the specified timeout, the request is canceled and an error is returned. + // Impact: This is critical for bounding the wait time for operations, which can enhance the predictability and reliability of client interactions by avoiding indefinitely hanging requests. + .timeout(Duration::from_secs(120)) + // Purpose: Specifies how long the client will wait for a response to a keep-alive ping before considering the connection dead. + // Functionality: If a ping response is not received within this duration, the connection is presumed to be lost and is closed. + // Impact: This setting is crucial for quickly detecting unresponsive connections and freeing up resources associated with them. It ensures that the client has up-to-date information on the status of connections and can react accordingly. + .keep_alive_timeout(Duration::from_secs(25)) + .tls_config(ClientTlsConfig::new().with_enabled_roots()) + .map_err(|e| Error::new(ErrorKind::SetupTLSConfigError).with(e))? + .connect() + .await + .map_err(|e| Error::new(ErrorKind::SetupConnectionError).with(e))?; + + Ok(channel) +} + +#[derive(Debug, Clone)] +pub struct ClientV4 { + pub(crate) client: ReplicationApiClient, + pub(crate) app_version: MetadataValue, + pub(crate) libxmtp_version: MetadataValue, +} + +impl ClientV4 { + pub async fn create(host: String, is_secure: bool) -> Result { + let host = host.to_string(); + let app_version = MetadataValue::try_from(&String::from("0.0.0")) + .map_err(|e| Error::new(ErrorKind::MetadataError).with(e))?; + let libxmtp_version = MetadataValue::try_from(&String::from("0.0.0")) + .map_err(|e| Error::new(ErrorKind::MetadataError).with(e))?; + + let channel = match is_secure { + true => create_tls_channel(host).await?, + false => Channel::from_shared(host) + .map_err(|e| Error::new(ErrorKind::SetupCreateChannelError).with(e))? + .connect() + .await + .map_err(|e| Error::new(ErrorKind::SetupConnectionError).with(e))?, + }; + + let client = ReplicationApiClient::new(channel.clone()); + + Ok(Self { + client, + app_version, + libxmtp_version, + }) + } + + pub fn build_request(&self, request: RequestType) -> Request { + let mut req = Request::new(request); + req.metadata_mut() + .insert("x-app-version", self.app_version.clone()); + req.metadata_mut() + .insert("x-libxmtp-version", self.libxmtp_version.clone()); + + req + } +} + +impl ClientWithMetadata for ClientV4 { + fn set_libxmtp_version(&mut self, version: String) -> Result<(), Error> { + self.libxmtp_version = MetadataValue::try_from(&version) + .map_err(|e| Error::new(ErrorKind::MetadataError).with(e))?; + + Ok(()) + } + + fn set_app_version(&mut self, version: String) -> Result<(), Error> { + self.app_version = MetadataValue::try_from(&version) + .map_err(|e| Error::new(ErrorKind::MetadataError).with(e))?; + + Ok(()) + } +} + +#[async_trait::async_trait] +impl XmtpApiClient for ClientV4 { + type Subscription = Subscription; + type MutableSubscription = GrpcMutableSubscription; + + async fn publish( + &self, + token: String, + request: PublishRequest, + ) -> Result { + unimplemented!(); + } + + async fn subscribe(&self, request: SubscribeRequest) -> Result { + unimplemented!(); + } + + async fn subscribe2( + &self, + request: SubscribeRequest, + ) -> Result { + unimplemented!(); + } + + async fn query(&self, request: QueryRequest) -> Result { + unimplemented!(); + } + + async fn batch_query(&self, request: BatchQueryRequest) -> Result { + unimplemented!(); + } +} + +pub struct Subscription { + pending: Arc>>, + close_sender: Option>, + closed: Arc, +} + +impl Subscription { + pub async fn start(stream: Streaming) -> Self { + let pending = Arc::new(Mutex::new(Vec::new())); + let pending_clone = pending.clone(); + let (close_sender, close_receiver) = oneshot::channel::<()>(); + let closed = Arc::new(AtomicBool::new(false)); + let closed_clone = closed.clone(); + tokio::spawn(async move { + let mut stream = Box::pin(stream); + let mut close_receiver = Box::pin(close_receiver); + + loop { + tokio::select! { + item = stream.message() => { + match item { + Ok(Some(envelope)) => { + let mut pending = pending_clone.lock().unwrap(); + pending.push(envelope); + } + _ => break, + } + }, + _ = &mut close_receiver => { + break; + } + } + } + + closed_clone.store(true, Ordering::SeqCst); + }); + + Subscription { + pending, + closed, + close_sender: Some(close_sender), + } + } +} + +impl XmtpApiSubscription for Subscription { + fn is_closed(&self) -> bool { + self.closed.load(Ordering::SeqCst) + } + + fn get_messages(&self) -> Vec { + let mut pending = self.pending.lock().unwrap(); + let items = pending.drain(..).collect::>(); + items + } + + fn close_stream(&mut self) { + // Set this value here, even if it will be eventually set again when the loop exits + // This makes the `closed` status immediately correct + self.closed.store(true, Ordering::SeqCst); + if let Some(close_tx) = self.close_sender.take() { + let _ = close_tx.send(()); + } + } +} + +type EnvelopeStream = Pin> + Send>>; + +pub struct GrpcMutableSubscription { + envelope_stream: Abortable, + update_channel: futures::channel::mpsc::UnboundedSender, + abort_handle: AbortHandle, +} + +impl GrpcMutableSubscription { + pub fn new( + envelope_stream: EnvelopeStream, + update_channel: futures::channel::mpsc::UnboundedSender, + ) -> Self { + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + Self { + envelope_stream: Abortable::new(envelope_stream, abort_registration), + update_channel, + abort_handle, + } + } +} + +impl Stream for GrpcMutableSubscription { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.envelope_stream.poll_next_unpin(cx) + } +} + +#[async_trait::async_trait] +impl MutableApiSubscription for GrpcMutableSubscription { + async fn update(&mut self, req: SubscribeRequest) -> Result<(), Error> { + self.update_channel + .send(req) + .await + .map_err(|_| Error::new(ErrorKind::SubscriptionUpdateError))?; + + Ok(()) + } + + fn close(&self) { + self.abort_handle.abort(); + self.update_channel.close_channel(); + } +} + +#[async_trait::async_trait] +impl XmtpMlsClient for ClientV4 { + #[tracing::instrument(level = "trace", skip_all)] + async fn upload_key_package(&self, req: UploadKeyPackageRequest) -> Result<(), Error> { + unimplemented!(); + } + + #[tracing::instrument(level = "trace", skip_all)] + async fn fetch_key_packages( + &self, + req: FetchKeyPackagesRequest, + ) -> Result { + unimplemented!(); + } + + #[tracing::instrument(level = "trace", skip_all)] + async fn send_group_messages(&self, req: SendGroupMessagesRequest) -> Result<(), Error> { + unimplemented!(); + } + + #[tracing::instrument(level = "trace", skip_all)] + async fn send_welcome_messages(&self, req: SendWelcomeMessagesRequest) -> Result<(), Error> { + unimplemented!(); + } + + #[tracing::instrument(level = "trace", skip_all)] + async fn query_group_messages( + &self, + req: QueryGroupMessagesRequest, + ) -> Result { + unimplemented!(); + } + + #[tracing::instrument(level = "trace", skip_all)] + async fn query_welcome_messages( + &self, + req: QueryWelcomeMessagesRequest, + ) -> Result { + unimplemented!(); + } +} + +pub struct GroupMessageStream { + inner: tonic::codec::Streaming, +} + +impl From> for GroupMessageStream { + fn from(inner: tonic::codec::Streaming) -> Self { + GroupMessageStream { inner } + } +} + +impl Stream for GroupMessageStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner + .poll_next_unpin(cx) + .map(|data| data.map(|v| v.map_err(|e| Error::new(ErrorKind::SubscribeError).with(e)))) + } +} + +pub struct WelcomeMessageStream { + inner: tonic::codec::Streaming, +} + +impl From> for WelcomeMessageStream { + fn from(inner: tonic::codec::Streaming) -> Self { + WelcomeMessageStream { inner } + } +} + +impl Stream for WelcomeMessageStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner + .poll_next_unpin(cx) + .map(|data| data.map(|v| v.map_err(|e| Error::new(ErrorKind::SubscribeError).with(e)))) + } +} + +#[async_trait::async_trait] +impl XmtpMlsStreams for ClientV4 { + type GroupMessageStream<'a> = GroupMessageStream; + type WelcomeMessageStream<'a> = WelcomeMessageStream; + + async fn subscribe_group_messages( + &self, + req: SubscribeGroupMessagesRequest, + ) -> Result, Error> { + unimplemented!(); + } + + async fn subscribe_welcome_messages( + &self, + req: SubscribeWelcomeMessagesRequest, + ) -> Result, Error> { + unimplemented!(); + } +} + +#[async_trait::async_trait] +impl XmtpIdentityClient for ClientV4 { + #[tracing::instrument(level = "trace", skip_all)] + async fn publish_identity_update( + &self, + request: PublishIdentityUpdateRequest, + ) -> Result { + unimplemented!() + } + + #[tracing::instrument(level = "trace", skip_all)] + async fn get_inbox_ids( + &self, + request: GetInboxIdsRequest, + ) -> Result { + let client = &mut self.client.clone(); + let req = GetInboxIdsRequestV4 { + requests: request + .requests + .into_iter() + .map(|r| get_inbox_ids_request::Request { address: r.address }) + .collect(), + }; + + let res = client.get_inbox_ids(self.build_request(req)).await; + + res.map(|response| response.into_inner()) + .map(|response| GetInboxIdsResponse { + responses: response + .responses + .into_iter() + .map(|r| get_inbox_ids_response::Response { + address: r.address, + inbox_id: r.inbox_id, + }) + .collect(), + }) + .map_err(|err| Error::new(ErrorKind::IdentityError).with(err)) + } + + #[tracing::instrument(level = "trace", skip_all)] + async fn get_identity_updates_v2( + &self, + request: GetIdentityUpdatesV2Request, + ) -> Result { + unimplemented!() + } + + #[tracing::instrument(level = "trace", skip_all)] + async fn verify_smart_contract_wallet_signatures( + &self, + request: VerifySmartContractWalletSignaturesRequest, + ) -> Result { + unimplemented!() + } +} + +#[cfg(any(feature = "test-utils", test))] +#[async_trait::async_trait] +impl XmtpTestClient for ClientV4 { + async fn create_local() -> Self { + todo!() + } + + async fn create_dev() -> Self { + todo!() + } +} diff --git a/xmtp_api_http/Cargo.toml b/xmtp_api_http/Cargo.toml index ff5035fb4..e5ef41eb9 100644 --- a/xmtp_api_http/Cargo.toml +++ b/xmtp_api_http/Cargo.toml @@ -16,6 +16,7 @@ serde_json = { workspace = true } thiserror = "1.0" tokio = { workspace = true, features = ["sync", "rt", "macros"] } xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] } +async-trait = "0.1" [dev-dependencies] xmtp_proto = { path = "../xmtp_proto", features = ["test-utils"] } diff --git a/xmtp_api_http/src/lib.rs b/xmtp_api_http/src/lib.rs index b026f47df..668c059ae 100755 --- a/xmtp_api_http/src/lib.rs +++ b/xmtp_api_http/src/lib.rs @@ -75,6 +75,8 @@ where Error::new(ErrorKind::MetadataError).with(e) } +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] impl ClientWithMetadata for XmtpHttpApiClient { fn set_app_version(&mut self, version: String) -> Result<(), Error> { self.app_version = Some(version); @@ -123,6 +125,8 @@ impl ClientWithMetadata for XmtpHttpApiClient { } } +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] impl XmtpMlsClient for XmtpHttpApiClient { async fn upload_key_package(&self, request: UploadKeyPackageRequest) -> Result<(), Error> { let res = self @@ -233,6 +237,8 @@ impl XmtpMlsClient for XmtpHttpApiClient { } } +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] impl XmtpMlsStreams for XmtpHttpApiClient { // hard to avoid boxing here: // 1.) use `hyper` instead of `reqwest` and create our own `Stream` type @@ -275,6 +281,8 @@ impl XmtpMlsStreams for XmtpHttpApiClient { } } +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] impl XmtpIdentityClient for XmtpHttpApiClient { async fn publish_identity_update( &self, diff --git a/xmtp_api_http/src/util.rs b/xmtp_api_http/src/util.rs index f8d2b3b11..431ad0144 100644 --- a/xmtp_api_http/src/util.rs +++ b/xmtp_api_http/src/util.rs @@ -117,6 +117,8 @@ pub fn create_grpc_stream_inner< } #[cfg(feature = "test-utils")] +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] impl xmtp_proto::api_client::XmtpTestClient for crate::XmtpHttpApiClient { async fn create_local() -> Self { crate::XmtpHttpApiClient::new("http://localhost:5555".into()) diff --git a/xmtp_id/src/scw_verifier/mod.rs b/xmtp_id/src/scw_verifier/mod.rs index 931a93f27..40481332a 100644 --- a/xmtp_id/src/scw_verifier/mod.rs +++ b/xmtp_id/src/scw_verifier/mod.rs @@ -1,13 +1,12 @@ mod chain_rpc_verifier; mod remote_signature_verifier; -use std::{collections::HashMap, fs, path::Path}; - use crate::associations::AccountId; use ethers::{ providers::{Http, Provider, ProviderError}, types::{BlockNumber, Bytes}, }; +use std::{collections::HashMap, fs, path::Path, sync::Arc}; use thiserror::Error; use tracing::info; use url::Url; @@ -65,6 +64,25 @@ pub trait SmartContractSignatureVerifier { ) -> Result; } +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +impl SmartContractSignatureVerifier for Arc +where + T: SmartContractSignatureVerifier, +{ + async fn is_valid_signature( + &self, + account_id: AccountId, + hash: [u8; 32], + signature: Bytes, + block_number: Option, + ) -> Result { + (**self) + .is_valid_signature(account_id, hash, signature, block_number) + .await + } +} + #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] impl SmartContractSignatureVerifier for &T diff --git a/xmtp_id/src/scw_verifier/remote_signature_verifier.rs b/xmtp_id/src/scw_verifier/remote_signature_verifier.rs index 257b9768a..dedf4575e 100644 --- a/xmtp_id/src/scw_verifier/remote_signature_verifier.rs +++ b/xmtp_id/src/scw_verifier/remote_signature_verifier.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use super::{SmartContractSignatureVerifier, ValidationResponse, VerifierError}; use crate::associations::AccountId; use ethers::types::{BlockNumber, Bytes}; @@ -11,11 +13,11 @@ use xmtp_proto::{ }; pub struct RemoteSignatureVerifier { - identity_client: C, + identity_client: Arc, } impl RemoteSignatureVerifier { - pub fn new(identity_client: C) -> Self { + pub fn new(identity_client: Arc) -> Self { Self { identity_client } } } diff --git a/xmtp_id/src/utils/mod.rs b/xmtp_id/src/utils/mod.rs index fd484a1ca..0acaf32ef 100644 --- a/xmtp_id/src/utils/mod.rs +++ b/xmtp_id/src/utils/mod.rs @@ -1,4 +1,6 @@ use wasm_timer::{SystemTime, UNIX_EPOCH}; +#[cfg(any(test, feature = "test-utils"))] +pub mod test; pub const NS_IN_SEC: i64 = 1_000_000_000; @@ -9,159 +11,3 @@ pub fn now_ns() -> i64 { .expect("Time went backwards") .as_nanos() as i64 } - -#[cfg(any(test, feature = "test-utils"))] -pub mod test { - #![allow(clippy::unwrap_used)] - - use ethers::{ - contract::abigen, - core::k256::{elliptic_curve::SecretKey, Secp256k1}, - middleware::SignerMiddleware, - providers::{Http, Provider}, - signers::LocalWallet, - }; - abigen!( - CoinbaseSmartWallet, - "artifact/CoinbaseSmartWallet.json", - derives(serde::Serialize, serde::Deserialize) - ); - - abigen!( - CoinbaseSmartWalletFactory, - "artifact/CoinbaseSmartWalletFactory.json", - derives(serde::Serialize, serde::Deserialize) - ); - - pub struct SmartContracts { - coinbase_smart_wallet_factory: - CoinbaseSmartWalletFactory, LocalWallet>>, - } - - impl SmartContracts { - #[cfg(not(target_arch = "wasm32"))] - fn new( - coinbase_smart_wallet_factory: CoinbaseSmartWalletFactory< - SignerMiddleware, LocalWallet>, - >, - ) -> Self { - Self { - coinbase_smart_wallet_factory, - } - } - - pub fn coinbase_smart_wallet_factory( - &self, - ) -> &CoinbaseSmartWalletFactory, LocalWallet>> { - &self.coinbase_smart_wallet_factory - } - } - - pub struct AnvilMeta { - pub keys: Vec>, - pub endpoint: String, - pub chain_id: u64, - } - - #[cfg(not(target_arch = "wasm32"))] - pub async fn with_docker_smart_contracts(fun: Func) - where - Func: FnOnce( - AnvilMeta, - Provider, - SignerMiddleware, LocalWallet>, - SmartContracts, - ) -> Fut, - Fut: futures::Future, - { - use ethers::signers::Signer; - use ethers::utils::Anvil; - use std::sync::Arc; - - // Spawn an anvil instance to get the keys and chain_id - let anvil = Anvil::new().port(8546u16).spawn(); - - let anvil_meta = AnvilMeta { - keys: anvil.keys().to_vec(), - chain_id: 31337, - endpoint: "http://localhost:8545".to_string(), - }; - - let keys = anvil.keys().to_vec(); - let contract_deployer: LocalWallet = keys[9].clone().into(); - let provider = Provider::::try_from(&anvil_meta.endpoint).unwrap(); - let client = SignerMiddleware::new( - provider.clone(), - contract_deployer.clone().with_chain_id(anvil_meta.chain_id), - ); - // 1. coinbase smart wallet - // deploy implementation for factory - let implementation = CoinbaseSmartWallet::deploy(Arc::new(client.clone()), ()) - .unwrap() - .gas_price(100) - .send() - .await - .unwrap(); - // deploy factory - let factory = - CoinbaseSmartWalletFactory::deploy(Arc::new(client.clone()), implementation.address()) - .unwrap() - .gas_price(100) - .send() - .await - .unwrap(); - - let smart_contracts = SmartContracts::new(factory); - fun( - anvil_meta, - provider.clone(), - client.clone(), - smart_contracts, - ) - .await - } - - // anvil can't be used in wasm because it is a system binary - /// Test harness that loads a local anvil node with deployed smart contracts. - #[cfg(not(target_arch = "wasm32"))] - pub async fn with_smart_contracts(fun: Func) - where - Func: FnOnce( - ethers::utils::AnvilInstance, - Provider, - SignerMiddleware, LocalWallet>, - SmartContracts, - ) -> Fut, - Fut: futures::Future, - { - use ethers::signers::Signer; - use ethers::utils::Anvil; - use std::sync::Arc; - let anvil = Anvil::new().args(vec!["--base-fee", "100"]).spawn(); - let contract_deployer: LocalWallet = anvil.keys()[9].clone().into(); - let provider = Provider::::try_from(anvil.endpoint()).unwrap(); - let client = SignerMiddleware::new( - provider.clone(), - contract_deployer.clone().with_chain_id(anvil.chain_id()), - ); - // 1. coinbase smart wallet - // deploy implementation for factory - let implementation = CoinbaseSmartWallet::deploy(Arc::new(client.clone()), ()) - .unwrap() - .gas_price(100) - .send() - .await - .unwrap(); - // deploy factory - let factory = - CoinbaseSmartWalletFactory::deploy(Arc::new(client.clone()), implementation.address()) - .unwrap() - .gas_price(100) - .send() - .await - .unwrap(); - - let smart_contracts = SmartContracts::new(factory); - fun(anvil, provider.clone(), client.clone(), smart_contracts).await - } -} diff --git a/xmtp_id/src/utils/test.rs b/xmtp_id/src/utils/test.rs new file mode 100644 index 000000000..f39109ada --- /dev/null +++ b/xmtp_id/src/utils/test.rs @@ -0,0 +1,215 @@ +#![allow(clippy::unwrap_used)] + +use ethers::{ + contract::abigen, + core::k256::{elliptic_curve::SecretKey, Secp256k1}, + middleware::SignerMiddleware, + providers::{Http, Provider}, + signers::LocalWallet, +}; +use std::sync::LazyLock; + +abigen!( + CoinbaseSmartWallet, + "artifact/CoinbaseSmartWallet.json", + derives(serde::Serialize, serde::Deserialize) +); + +abigen!( + CoinbaseSmartWalletFactory, + "artifact/CoinbaseSmartWalletFactory.json", + derives(serde::Serialize, serde::Deserialize) +); + +pub struct SmartContracts { + coinbase_smart_wallet_factory: + CoinbaseSmartWalletFactory, LocalWallet>>, +} + +impl SmartContracts { + #[cfg(not(target_arch = "wasm32"))] + fn new( + coinbase_smart_wallet_factory: CoinbaseSmartWalletFactory< + SignerMiddleware, LocalWallet>, + >, + ) -> Self { + Self { + coinbase_smart_wallet_factory, + } + } + + pub fn coinbase_smart_wallet_factory( + &self, + ) -> &CoinbaseSmartWalletFactory, LocalWallet>> { + &self.coinbase_smart_wallet_factory + } +} + +pub static ANVIL_KEYS: LazyLock>> = LazyLock::new(|| { + vec![ + SecretKey::from_slice( + hex::decode("ac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80") + .unwrap() + .as_slice(), + ) + .unwrap(), + SecretKey::from_slice( + hex::decode("59c6995e998f97a5a0044966f0945389dc9e86dae88c7a8412f4603b6b78690d") + .unwrap() + .as_slice(), + ) + .unwrap(), + SecretKey::from_slice( + hex::decode("5de4111afa1a4b94908f83103eb1f1706367c2e68ca870fc3fb9a804cdab365a") + .unwrap() + .as_slice(), + ) + .unwrap(), + SecretKey::from_slice( + hex::decode("7c852118294e51e653712a81e05800f419141751be58f605c371e15141b007a6") + .unwrap() + .as_slice(), + ) + .unwrap(), + SecretKey::from_slice( + hex::decode("47e179ec197488593b187f80a00eb0da91f1b9d0b13f8733639f19c30a34926a") + .unwrap() + .as_slice(), + ) + .unwrap(), + SecretKey::from_slice( + hex::decode("8b3a350cf5c34c9194ca85829a2df0ec3153be0318b5e2d3348e872092edffba") + .unwrap() + .as_slice(), + ) + .unwrap(), + SecretKey::from_slice( + hex::decode("92db14e403b83dfe3df233f83dfa3a0d7096f21ca9b0d6d6b8d88b2b4ec1564e") + .unwrap() + .as_slice(), + ) + .unwrap(), + SecretKey::from_slice( + hex::decode("4bbbf85ce3377467afe5d46f804f221813b2bb87f24d81f60f1fcdbf7cbf4356") + .unwrap() + .as_slice(), + ) + .unwrap(), + SecretKey::from_slice( + hex::decode("dbda1821b80551c9d65939329250298aa3472ba22feea921c0cf5d620ea67b97") + .unwrap() + .as_slice(), + ) + .unwrap(), + SecretKey::from_slice( + hex::decode("2a871d0798f97d79848a013d4936a73bf4cc922c825d33c1cf7073dff6d409c6") + .unwrap() + .as_slice(), + ) + .unwrap(), + ] +}); + +pub struct AnvilMeta { + pub keys: Vec>, + pub endpoint: String, + pub chain_id: u64, +} + +#[cfg(not(target_arch = "wasm32"))] +pub async fn with_docker_smart_contracts(fun: Func) +where + Func: FnOnce( + AnvilMeta, + Provider, + SignerMiddleware, LocalWallet>, + SmartContracts, + ) -> Fut, + Fut: futures::Future, +{ + use ethers::signers::Signer; + use std::sync::Arc; + + let keys = ANVIL_KEYS.clone(); + let anvil_meta = AnvilMeta { + keys: keys.clone(), + chain_id: 31337, + endpoint: "http://localhost:8545".to_string(), + }; + + let contract_deployer: LocalWallet = keys[9].clone().into(); + let provider = Provider::::try_from(&anvil_meta.endpoint).unwrap(); + let client = SignerMiddleware::new( + provider.clone(), + contract_deployer.clone().with_chain_id(anvil_meta.chain_id), + ); + // 1. coinbase smart wallet + // deploy implementation for factory + let implementation = CoinbaseSmartWallet::deploy(Arc::new(client.clone()), ()) + .unwrap() + .gas_price(100) + .send() + .await + .unwrap(); + // deploy factory + let factory = + CoinbaseSmartWalletFactory::deploy(Arc::new(client.clone()), implementation.address()) + .unwrap() + .gas_price(100) + .send() + .await + .unwrap(); + + let smart_contracts = SmartContracts::new(factory); + fun( + anvil_meta, + provider.clone(), + client.clone(), + smart_contracts, + ) + .await +} + +// anvil can't be used in wasm because it is a system binary +/// Test harness that loads a local anvil node with deployed smart contracts. +#[cfg(not(target_arch = "wasm32"))] +pub async fn with_smart_contracts(fun: Func) +where + Func: FnOnce( + ethers::utils::AnvilInstance, + Provider, + SignerMiddleware, LocalWallet>, + SmartContracts, + ) -> Fut, + Fut: futures::Future, +{ + use ethers::signers::Signer; + use ethers::utils::Anvil; + use std::sync::Arc; + let anvil = Anvil::new().args(vec!["--base-fee", "100"]).spawn(); + let contract_deployer: LocalWallet = anvil.keys()[9].clone().into(); + let provider = Provider::::try_from(anvil.endpoint()).unwrap(); + let client = SignerMiddleware::new( + provider.clone(), + contract_deployer.clone().with_chain_id(anvil.chain_id()), + ); + // 1. coinbase smart wallet + // deploy implementation for factory + let implementation = CoinbaseSmartWallet::deploy(Arc::new(client.clone()), ()) + .unwrap() + .gas_price(100) + .send() + .await + .unwrap(); + // deploy factory + let factory = + CoinbaseSmartWalletFactory::deploy(Arc::new(client.clone()), implementation.address()) + .unwrap() + .gas_price(100) + .send() + .await + .unwrap(); + + let smart_contracts = SmartContracts::new(factory); + fun(anvil, provider.clone(), client.clone(), smart_contracts).await +} diff --git a/xmtp_mls/src/api/identity.rs b/xmtp_mls/src/api/identity.rs index 74d0597e0..02e6be91c 100644 --- a/xmtp_mls/src/api/identity.rs +++ b/xmtp_mls/src/api/identity.rs @@ -181,7 +181,7 @@ pub(crate) mod tests { .withf(move |req| req.identity_update.as_ref().unwrap().inbox_id.eq(&inbox_id)) .returning(move |_| Ok(PublishIdentityUpdateResponse {})); - let wrapper = ApiClientWrapper::new(mock_api, Retry::default()); + let wrapper = ApiClientWrapper::new(mock_api.into(), Retry::default()); let result = wrapper.publish_identity_update(identity_update).await; assert!(result.is_ok()); @@ -211,7 +211,7 @@ pub(crate) mod tests { }) }); - let wrapper = ApiClientWrapper::new(mock_api, Retry::default()); + let wrapper = ApiClientWrapper::new(mock_api.into(), Retry::default()); let result = wrapper .get_identity_updates_v2(vec![GetIdentityUpdatesV2Filter { inbox_id: inbox_id_clone_2.clone(), @@ -257,7 +257,7 @@ pub(crate) mod tests { }) }); - let wrapper = ApiClientWrapper::new(mock_api, Retry::default()); + let wrapper = ApiClientWrapper::new(mock_api.into(), Retry::default()); let result = wrapper .get_inbox_ids(vec![address.clone()]) .await diff --git a/xmtp_mls/src/api/mls.rs b/xmtp_mls/src/api/mls.rs index 22d453477..82d339ab1 100644 --- a/xmtp_mls/src/api/mls.rs +++ b/xmtp_mls/src/api/mls.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use super::ApiClientWrapper; use crate::{retry_async, XmtpApi}; -use xmtp_proto::api_client::{Error as ApiError, ErrorKind}; +use xmtp_proto::api_client::{Error as ApiError, ErrorKind, XmtpMlsStreams}; use xmtp_proto::xmtp::mls::api::v1::{ group_message_input::{Version as GroupMessageInputVersion, V1 as GroupMessageInputV1}, subscribe_group_messages_request::Filter as GroupFilterProto, @@ -262,7 +262,10 @@ where pub async fn subscribe_group_messages( &self, filters: Vec, - ) -> Result> + '_, ApiError> { + ) -> Result> + '_, ApiError> + where + ApiClient: XmtpMlsStreams, + { self.api_client .subscribe_group_messages(SubscribeGroupMessagesRequest { filters: filters.into_iter().map(|f| f.into()).collect(), @@ -274,7 +277,10 @@ where &self, installation_key: Vec, id_cursor: Option, - ) -> Result> + '_, ApiError> { + ) -> Result> + '_, ApiError> + where + ApiClient: XmtpMlsStreams, + { self.api_client .subscribe_welcome_messages(SubscribeWelcomeMessagesRequest { filters: vec![WelcomeFilterProto { @@ -320,7 +326,7 @@ pub mod tests { .eq(&key_package) }) .returning(move |_| Ok(())); - let wrapper = ApiClientWrapper::new(mock_api, Retry::default()); + let wrapper = ApiClientWrapper::new(mock_api.into(), Retry::default()); let result = wrapper.upload_key_package(key_package_clone, false).await; assert!(result.is_ok()); } @@ -343,7 +349,7 @@ pub mod tests { ], }) }); - let wrapper = ApiClientWrapper::new(mock_api, Retry::default()); + let wrapper = ApiClientWrapper::new(mock_api.into(), Retry::default()); let result = wrapper .fetch_key_packages(installation_keys.clone()) .await @@ -381,7 +387,7 @@ pub mod tests { }) }); - let wrapper = ApiClientWrapper::new(mock_api, Retry::default()); + let wrapper = ApiClientWrapper::new(mock_api.into(), Retry::default()); let result = wrapper .query_group_messages(group_id_clone, None) @@ -413,7 +419,7 @@ pub mod tests { }) }); - let wrapper = ApiClientWrapper::new(mock_api, Retry::default()); + let wrapper = ApiClientWrapper::new(mock_api.into(), Retry::default()); let result = wrapper .query_group_messages(group_id_clone, None) @@ -464,7 +470,7 @@ pub mod tests { }) }); - let wrapper = ApiClientWrapper::new(mock_api, Retry::default()); + let wrapper = ApiClientWrapper::new(mock_api.into(), Retry::default()); let result = wrapper .query_group_messages(group_id_clone2, None) @@ -498,7 +504,7 @@ pub mod tests { }) }); - let wrapper = ApiClientWrapper::new(mock_api, Retry::default()); + let wrapper = ApiClientWrapper::new(mock_api.into(), Retry::default()); let result = wrapper .query_group_messages(group_id_clone, None) diff --git a/xmtp_mls/src/api/mod.rs b/xmtp_mls/src/api/mod.rs index b623ef62f..0d1fc7daf 100644 --- a/xmtp_mls/src/api/mod.rs +++ b/xmtp_mls/src/api/mod.rs @@ -3,6 +3,8 @@ pub mod mls; #[cfg(test)] pub mod test_utils; +use std::sync::Arc; + use crate::{ retry::{Retry, RetryableError}, XmtpApi, @@ -30,7 +32,7 @@ impl RetryableError for WrappedApiError { #[derive(Clone, Debug)] pub struct ApiClientWrapper { - pub(crate) api_client: ApiClient, + pub(crate) api_client: Arc, pub(crate) retry_strategy: Retry, } @@ -38,7 +40,7 @@ impl ApiClientWrapper where ApiClient: XmtpApi, { - pub fn new(api_client: ApiClient, retry_strategy: Retry) -> Self { + pub fn new(api_client: Arc, retry_strategy: Retry) -> Self { Self { api_client, retry_strategy, diff --git a/xmtp_mls/src/api/test_utils.rs b/xmtp_mls/src/api/test_utils.rs index 09e58743a..e20def9cd 100644 --- a/xmtp_mls/src/api/test_utils.rs +++ b/xmtp_mls/src/api/test_utils.rs @@ -41,59 +41,137 @@ pub fn build_group_messages(num_messages: usize, group_id: Vec) -> Vec Result<(), Error>; - fn set_app_version(&mut self, version: String) -> Result<(), Error>; - } + #[async_trait::async_trait] + impl ClientWithMetadata for ApiClient { + fn set_libxmtp_version(&mut self, version: String) -> Result<(), Error>; + fn set_app_version(&mut self, version: String) -> Result<(), Error>; + } - impl XmtpMlsClient for ApiClient { - async fn upload_key_package(&self, request: UploadKeyPackageRequest) -> Result<(), Error>; - async fn fetch_key_packages( - &self, - request: FetchKeyPackagesRequest, - ) -> Result; - async fn send_group_messages(&self, request: SendGroupMessagesRequest) -> Result<(), Error>; - async fn send_welcome_messages(&self, request: SendWelcomeMessagesRequest) -> Result<(), Error>; - async fn query_group_messages(&self, request: QueryGroupMessagesRequest) -> Result; - async fn query_welcome_messages(&self, request: QueryWelcomeMessagesRequest) -> Result; - } + #[async_trait::async_trait] + impl XmtpMlsClient for ApiClient { + async fn upload_key_package(&self, request: UploadKeyPackageRequest) -> Result<(), Error>; + async fn fetch_key_packages( + &self, + request: FetchKeyPackagesRequest, + ) -> Result; + async fn send_group_messages(&self, request: SendGroupMessagesRequest) -> Result<(), Error>; + async fn send_welcome_messages(&self, request: SendWelcomeMessagesRequest) -> Result<(), Error>; + async fn query_group_messages(&self, request: QueryGroupMessagesRequest) -> Result; + async fn query_welcome_messages(&self, request: QueryWelcomeMessagesRequest) -> Result; + } - impl XmtpMlsStreams for ApiClient { - #[cfg(all(not(feature = "http-api"), not(target_arch = "wasm32")))] - type GroupMessageStream<'a> = xmtp_api_grpc::GroupMessageStream; - #[cfg(all(not(feature = "http-api"), not(target_arch = "wasm32")))] - type WelcomeMessageStream<'a> = xmtp_api_grpc::WelcomeMessageStream; + #[async_trait::async_trait] + impl XmtpMlsStreams for ApiClient { + #[cfg(all(not(feature = "http-api"), not(target_arch = "wasm32")))] + type GroupMessageStream<'a> = xmtp_api_grpc::GroupMessageStream; + #[cfg(all(not(feature = "http-api"), not(target_arch = "wasm32")))] + type WelcomeMessageStream<'a> = xmtp_api_grpc::WelcomeMessageStream; - #[cfg(all(feature = "http-api", not(target_arch = "wasm32")))] - type GroupMessageStream<'a> = futures::stream::BoxStream<'static, Result>; - #[cfg(all(feature = "http-api", not(target_arch = "wasm32")))] - type WelcomeMessageStream<'a> = futures::stream::BoxStream<'static, Result>; + #[cfg(all(feature = "http-api", not(target_arch = "wasm32")))] + type GroupMessageStream<'a> = futures::stream::BoxStream<'static, Result>; + #[cfg(all(feature = "http-api", not(target_arch = "wasm32")))] + type WelcomeMessageStream<'a> = futures::stream::BoxStream<'static, Result>; - #[cfg(target_arch = "wasm32")] - type GroupMessageStream<'a> = futures::stream::LocalBoxStream<'static, Result>; - #[cfg(target_arch = "wasm32")] - type WelcomeMessageStream<'a> = futures::stream::LocalBoxStream<'static, Result>; + #[cfg(target_arch = "wasm32")] + type GroupMessageStream<'a> = futures::stream::LocalBoxStream<'static, Result>; + #[cfg(target_arch = "wasm32")] + type WelcomeMessageStream<'a> = futures::stream::LocalBoxStream<'static, Result>; - async fn subscribe_group_messages(&self, request: SubscribeGroupMessagesRequest) -> Result<::GroupMessageStream<'static>, Error>; - async fn subscribe_welcome_messages(&self, request: SubscribeWelcomeMessagesRequest) -> Result<::WelcomeMessageStream<'static>, Error>; - } + async fn subscribe_group_messages(&self, request: SubscribeGroupMessagesRequest) -> Result<::GroupMessageStream<'static>, Error>; + async fn subscribe_welcome_messages(&self, request: SubscribeWelcomeMessagesRequest) -> Result<::WelcomeMessageStream<'static>, Error>; + } - impl XmtpIdentityClient for ApiClient { - async fn publish_identity_update(&self, request: PublishIdentityUpdateRequest) -> Result; - async fn get_identity_updates_v2(&self, request: GetIdentityUpdatesV2Request) -> Result; - async fn get_inbox_ids(&self, request: GetInboxIdsRequest) -> Result; - async fn verify_smart_contract_wallet_signatures(&self, request: VerifySmartContractWalletSignaturesRequest) - -> Result; + #[async_trait::async_trait] + impl XmtpIdentityClient for ApiClient { + async fn publish_identity_update(&self, request: PublishIdentityUpdateRequest) -> Result; + async fn get_identity_updates_v2(&self, request: GetIdentityUpdatesV2Request) -> Result; + async fn get_inbox_ids(&self, request: GetInboxIdsRequest) -> Result; + async fn verify_smart_contract_wallet_signatures(&self, request: VerifySmartContractWalletSignaturesRequest) + -> Result; + } + + #[async_trait::async_trait] + impl XmtpTestClient for ApiClient { + async fn create_local() -> Self { ApiClient } + async fn create_dev() -> Self { ApiClient } + } } +} + +#[cfg(target_arch = "wasm32")] +mod wasm { + use super::*; + mock! { + pub ApiClient {} + + #[async_trait::async_trait(?Send)] + impl ClientWithMetadata for ApiClient { + fn set_libxmtp_version(&mut self, version: String) -> Result<(), Error>; + fn set_app_version(&mut self, version: String) -> Result<(), Error>; + } + + #[async_trait::async_trait(?Send)] + impl XmtpMlsClient for ApiClient { + async fn upload_key_package(&self, request: UploadKeyPackageRequest) -> Result<(), Error>; + async fn fetch_key_packages( + &self, + request: FetchKeyPackagesRequest, + ) -> Result; + async fn send_group_messages(&self, request: SendGroupMessagesRequest) -> Result<(), Error>; + async fn send_welcome_messages(&self, request: SendWelcomeMessagesRequest) -> Result<(), Error>; + async fn query_group_messages(&self, request: QueryGroupMessagesRequest) -> Result; + async fn query_welcome_messages(&self, request: QueryWelcomeMessagesRequest) -> Result; + } + + #[async_trait::async_trait(?Send)] + impl XmtpMlsStreams for ApiClient { + #[cfg(all(not(feature = "http-api"), not(target_arch = "wasm32")))] + type GroupMessageStream<'a> = xmtp_api_grpc::GroupMessageStream; + #[cfg(all(not(feature = "http-api"), not(target_arch = "wasm32")))] + type WelcomeMessageStream<'a> = xmtp_api_grpc::WelcomeMessageStream; + + #[cfg(all(feature = "http-api", not(target_arch = "wasm32")))] + type GroupMessageStream<'a> = futures::stream::BoxStream<'static, Result>; + #[cfg(all(feature = "http-api", not(target_arch = "wasm32")))] + type WelcomeMessageStream<'a> = futures::stream::BoxStream<'static, Result>; + + #[cfg(target_arch = "wasm32")] + type GroupMessageStream<'a> = futures::stream::LocalBoxStream<'static, Result>; + #[cfg(target_arch = "wasm32")] + type WelcomeMessageStream<'a> = futures::stream::LocalBoxStream<'static, Result>; + + + async fn subscribe_group_messages(&self, request: SubscribeGroupMessagesRequest) -> Result<::GroupMessageStream<'static>, Error>; + async fn subscribe_welcome_messages(&self, request: SubscribeWelcomeMessagesRequest) -> Result<::WelcomeMessageStream<'static>, Error>; + } + + #[async_trait::async_trait(?Send)] + impl XmtpIdentityClient for ApiClient { + async fn publish_identity_update(&self, request: PublishIdentityUpdateRequest) -> Result; + async fn get_identity_updates_v2(&self, request: GetIdentityUpdatesV2Request) -> Result; + async fn get_inbox_ids(&self, request: GetInboxIdsRequest) -> Result; + async fn verify_smart_contract_wallet_signatures(&self, request: VerifySmartContractWalletSignaturesRequest) + -> Result; + } - impl XmtpTestClient for ApiClient { - async fn create_local() -> Self { ApiClient } - async fn create_dev() -> Self { ApiClient } + #[async_trait::async_trait(?Send)] + impl XmtpTestClient for ApiClient { + async fn create_local() -> Self { ApiClient } + async fn create_dev() -> Self { ApiClient } + } } } diff --git a/xmtp_mls/src/builder.rs b/xmtp_mls/src/builder.rs index cf2449088..bb60d7082 100644 --- a/xmtp_mls/src/builder.rs +++ b/xmtp_mls/src/builder.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use thiserror::Error; use tracing::debug; @@ -106,49 +108,40 @@ impl ClientBuilder { impl ClientBuilder where - ApiClient: XmtpApi + Clone, - V: SmartContractSignatureVerifier + Clone, + ApiClient: XmtpApi, + V: SmartContractSignatureVerifier, { /// Build with a custom smart contract wallet verifier pub async fn build_with_verifier(self) -> Result, ClientBuilderError> { - inner_build(self).await + let (builder, api_client) = inner_build_api_client(self)?; + inner_build(builder, api_client).await } } impl ClientBuilder> where - ApiClient: XmtpApi + Clone, + ApiClient: XmtpApi, { /// Build with the default [`RemoteSignatureVerifier`] - pub async fn build(mut self) -> Result, ClientBuilderError> { - let api_client = - self.api_client - .clone() - .take() - .ok_or(ClientBuilderError::MissingParameter { - parameter: "api_client", - })?; - self = self.scw_signature_verifier(RemoteSignatureVerifier::new(api_client)); - inner_build::>(self).await + pub async fn build(self) -> Result, ClientBuilderError> { + let (mut builder, api_client) = inner_build_api_client(self)?; + builder = builder.scw_signature_verifier(RemoteSignatureVerifier::new(api_client.clone())); + inner_build::>(builder, api_client).await } } -async fn inner_build(client: ClientBuilder) -> Result, ClientBuilderError> +fn inner_build_api_client( + mut builder: ClientBuilder, +) -> Result<(ClientBuilder, Arc), ClientBuilderError> where - C: XmtpApi + Clone, - V: SmartContractSignatureVerifier + Clone, + ApiClient: XmtpApi, { let ClientBuilder { - mut api_client, - mut store, - identity_strategy, - history_sync_url, - app_version, - mut scw_verifier, + ref mut api_client, + ref app_version, .. - } = client; + } = builder; - debug!("Building client"); let mut api_client = api_client .take() .ok_or(ClientBuilderError::MissingParameter { @@ -157,9 +150,30 @@ where api_client.set_libxmtp_version(env!("CARGO_PKG_VERSION").to_string())?; if let Some(app_version) = app_version { - api_client.set_app_version(app_version)?; + api_client.set_app_version(app_version.to_string())?; } + Ok((builder, Arc::new(api_client))) +} + +async fn inner_build( + client: ClientBuilder, + api_client: Arc, +) -> Result, ClientBuilderError> +where + C: XmtpApi, + V: SmartContractSignatureVerifier, +{ + let ClientBuilder { + mut store, + identity_strategy, + history_sync_url, + mut scw_verifier, + .. + } = client; + + debug!("Building client"); + let scw_verifier = scw_verifier .take() .ok_or(ClientBuilderError::MissingParameter { @@ -242,7 +256,7 @@ pub(crate) mod tests { Client, InboxOwner, }; - async fn register_client( + async fn register_client( client: &Client, owner: &impl InboxOwner, ) { @@ -526,7 +540,7 @@ pub(crate) mod tests { }) }); - let wrapper = ApiClientWrapper::new(mock_api, Retry::default()); + let wrapper = ApiClientWrapper::new(mock_api.into(), Retry::default()); let identity = IdentityStrategy::CreateIfNotFound("other_inbox_id".to_string(), address, nonce, None); @@ -568,7 +582,7 @@ pub(crate) mod tests { }) }); - let wrapper = ApiClientWrapper::new(mock_api, Retry::default()); + let wrapper = ApiClientWrapper::new(mock_api.into(), Retry::default()); let identity = IdentityStrategy::CreateIfNotFound(inbox_id.clone(), address, nonce, None); assert!(dbg!( @@ -608,7 +622,7 @@ pub(crate) mod tests { .unwrap(); stored.store(&store.conn().unwrap()).unwrap(); - let wrapper = ApiClientWrapper::new(mock_api, Retry::default()); + let wrapper = ApiClientWrapper::new(mock_api.into(), Retry::default()); let identity = IdentityStrategy::CreateIfNotFound(inbox_id.clone(), address, nonce, None); assert!(identity .initialize_identity(&wrapper, &store, &scw_verifier) @@ -646,7 +660,7 @@ pub(crate) mod tests { stored.store(&store.conn().unwrap()).unwrap(); - let wrapper = ApiClientWrapper::new(mock_api, Retry::default()); + let wrapper = ApiClientWrapper::new(mock_api.into(), Retry::default()); let inbox_id = "inbox_id".to_string(); let identity = diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index da3bb71c0..525a7aef9 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -228,21 +228,17 @@ pub struct FindGroupParams { /// Clients manage access to the network, identity, and data store pub struct Client> { - pub(crate) api_client: ApiClientWrapper, + pub(crate) api_client: Arc>, pub(crate) intents: Arc, pub(crate) context: Arc, pub(crate) history_sync_url: Option, pub(crate) local_events: broadcast::Sender>, /// The method of verifying smart contract wallet signatures for this Client - pub(crate) scw_verifier: V, + pub(crate) scw_verifier: Arc, } // most of these things are `Arc`'s -impl Clone for Client -where - ApiClient: Clone, - V: Clone, -{ +impl Clone for Client { fn clone(&self) -> Self { Self { api_client: self.api_client.clone(), @@ -300,8 +296,8 @@ impl XmtpMlsLocalContext { impl Client where - ApiClient: XmtpApi + Clone, - V: SmartContractSignatureVerifier + Clone, + ApiClient: XmtpApi, + V: SmartContractSignatureVerifier, { /// Create a new client with the given network, identity, and store. /// It is expected that most users will use the [`ClientBuilder`](crate::builder::ClientBuilder) instead of instantiating @@ -326,11 +322,11 @@ where }); let (tx, _) = broadcast::channel(10); Self { - api_client, + api_client: api_client.into(), context, history_sync_url, local_events: tx, - scw_verifier, + scw_verifier: scw_verifier.into(), intents, } } @@ -342,8 +338,8 @@ where impl Client where - ApiClient: XmtpApi + Clone, - V: SmartContractSignatureVerifier + Clone, + ApiClient: XmtpApi, + V: SmartContractSignatureVerifier, { /// Retrieves the client's installation public key, sometimes also called `installation_id` pub fn installation_public_key(&self) -> Vec { @@ -1309,8 +1305,8 @@ pub(crate) mod tests { } async fn get_key_package_init_key< - ApiClient: XmtpApi + Clone, - Verifier: SmartContractSignatureVerifier + Clone, + ApiClient: XmtpApi, + Verifier: SmartContractSignatureVerifier, >( client: &Client, installation_id: &[u8], diff --git a/xmtp_mls/src/groups/device_sync.rs b/xmtp_mls/src/groups/device_sync.rs index c1d1e3bb0..09bed12fd 100644 --- a/xmtp_mls/src/groups/device_sync.rs +++ b/xmtp_mls/src/groups/device_sync.rs @@ -104,8 +104,8 @@ pub enum DeviceSyncError { impl Client where - ApiClient: XmtpApi + Clone, - V: SmartContractSignatureVerifier + Clone, + ApiClient: XmtpApi, + V: SmartContractSignatureVerifier, { pub async fn enable_sync(&self, provider: &XmtpOpenMlsProvider) -> Result<(), GroupError> { let sync_group = match self.get_sync_group() { @@ -372,8 +372,8 @@ impl MessageHistoryUrls { impl Client where - ApiClient: XmtpApi + Clone, - V: SmartContractSignatureVerifier + Clone, + ApiClient: XmtpApi, + V: SmartContractSignatureVerifier, { pub fn get_sync_group(&self) -> Result, GroupError> { let conn = self.store().conn()?; diff --git a/xmtp_mls/src/groups/device_sync/consent_sync.rs b/xmtp_mls/src/groups/device_sync/consent_sync.rs index 351a30aee..b8f4e4a89 100644 --- a/xmtp_mls/src/groups/device_sync/consent_sync.rs +++ b/xmtp_mls/src/groups/device_sync/consent_sync.rs @@ -4,8 +4,8 @@ use xmtp_id::scw_verifier::SmartContractSignatureVerifier; impl Client where - ApiClient: XmtpApi + Clone, - V: SmartContractSignatureVerifier + Clone, + ApiClient: XmtpApi, + V: SmartContractSignatureVerifier, { pub async fn send_consent_sync_request( &self, diff --git a/xmtp_mls/src/groups/device_sync/message_sync.rs b/xmtp_mls/src/groups/device_sync/message_sync.rs index ca516fc80..96cb6ffc7 100644 --- a/xmtp_mls/src/groups/device_sync/message_sync.rs +++ b/xmtp_mls/src/groups/device_sync/message_sync.rs @@ -6,8 +6,8 @@ use xmtp_id::scw_verifier::SmartContractSignatureVerifier; impl Client where - ApiClient: XmtpApi + Clone, - V: SmartContractSignatureVerifier + Clone, + ApiClient: XmtpApi, + V: SmartContractSignatureVerifier, { // returns (request_id, pin_code) pub async fn send_history_sync_request( diff --git a/xmtp_mls/src/groups/message_history.rs b/xmtp_mls/src/groups/message_history.rs new file mode 100644 index 000000000..5efad7054 --- /dev/null +++ b/xmtp_mls/src/groups/message_history.rs @@ -0,0 +1,1469 @@ +use std::fs::{File, OpenOptions}; +use std::io::{BufRead, BufReader, Read, Write}; +use std::path::{Path, PathBuf}; + +use aes_gcm::aead::generic_array::GenericArray; +use aes_gcm::{ + aead::{Aead, KeyInit}, + Aes256Gcm, +}; +use rand::{ + distributions::{Alphanumeric, DistString}, + Rng, RngCore, +}; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use xmtp_cryptography::utils as crypto_utils; +use xmtp_id::scw_verifier::SmartContractSignatureVerifier; +use xmtp_proto::{ + xmtp::mls::message_contents::plaintext_envelope::v2::MessageType::{Reply, Request}, + xmtp::mls::message_contents::plaintext_envelope::{Content, V2}, + xmtp::mls::message_contents::PlaintextEnvelope, + xmtp::mls::message_contents::{ + message_history_key_type::Key, MessageHistoryKeyType, MessageHistoryReply, + MessageHistoryRequest, + }, +}; + +use super::group_metadata::ConversationType; +use super::{GroupError, MlsGroup}; + +use crate::storage::group_message::MsgQueryArgs; +use crate::XmtpApi; +use crate::{ + client::ClientError, + groups::{GroupMessageKind, StoredGroupMessage}, + storage::{group::StoredGroup, StorageError}, + Client, Store, +}; + +const ENC_KEY_SIZE: usize = 32; // 256-bit key +const NONCE_SIZE: usize = 12; // 96-bit nonce + +pub struct MessageHistoryUrls; + +impl MessageHistoryUrls { + pub const LOCAL_ADDRESS: &'static str = "http://0.0.0.0:5558"; + pub const DEV_ADDRESS: &'static str = "https://message-history.dev.ephemera.network/"; + pub const PRODUCTION_ADDRESS: &'static str = "https://message-history.ephemera.network/"; +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum MessageHistoryContent { + Request(MessageHistoryRequest), + Reply(MessageHistoryReply), +} + +#[derive(Debug, Error)] +pub enum MessageHistoryError { + #[error("pin not found")] + PinNotFound, + #[error("pin does not match the expected value")] + PinMismatch, + #[error("IO error: {0}")] + IO(#[from] std::io::Error), + #[error("Serialization/Deserialization Error {0}")] + Serde(#[from] serde_json::Error), + #[error("AES-GCM encryption error")] + AesGcm(#[from] aes_gcm::Error), + #[error("reqwest error: {0}")] + Reqwest(#[from] reqwest::Error), + #[error("storage error: {0}")] + Storage(#[from] StorageError), + #[error("type conversion error")] + Conversion, + #[error("utf-8 error: {0}")] + UTF8(#[from] std::str::Utf8Error), + #[error("client error: {0}")] + Client(#[from] ClientError), + #[error("group error: {0}")] + Group(#[from] GroupError), + #[error("request ID of reply does not match request")] + ReplyRequestIdMismatch, + #[error("reply already processed")] + ReplyAlreadyProcessed, + #[error("no pending request to reply to")] + NoPendingRequest, + #[error("no reply to process")] + NoReplyToProcess, + #[error("generic: {0}")] + Generic(String), + #[error("missing history sync url")] + MissingHistorySyncUrl, + #[error("invalid history message payload")] + InvalidPayload, + #[error("invalid history bundle url")] + InvalidBundleUrl, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +enum SyncableTables { + StoredGroup(StoredGroup), + StoredGroupMessage(StoredGroupMessage), +} + +impl Client +where + ApiClient: XmtpApi, + V: SmartContractSignatureVerifier, +{ + pub fn get_sync_group(&self) -> Result, GroupError> { + let conn = self.store().conn()?; + let sync_group_id = conn + .find_sync_groups()? + .pop() + .ok_or(GroupError::GroupNotFound)? + .id; + let sync_group = self.group(sync_group_id.clone())?; + + Ok(sync_group) + } + + pub async fn enable_history_sync(&self) -> Result<(), GroupError> { + // look for the sync group, create if not found + let sync_group = match self.get_sync_group() { + Ok(group) => group, + Err(_) => { + // create the sync group + self.create_sync_group()? + } + }; + + // sync the group + sync_group.sync().await?; + + Ok(()) + } + + pub async fn ensure_member_of_all_groups(&self, inbox_id: String) -> Result<(), GroupError> { + let conn = self.store().conn()?; + let groups = conn.find_groups(None, None, None, None, Some(ConversationType::Group))?; + for group in groups { + let group = self.group(group.id)?; + Box::pin(group.add_members_by_inbox_id(vec![inbox_id.clone()])).await?; + } + + Ok(()) + } + + // returns (request_id, pin_code) + pub async fn send_history_request(&self) -> Result<(String, String), MessageHistoryError> { + // find the sync group + let conn = self.store().conn()?; + let sync_group = self.get_sync_group()?; + + // sync the group + sync_group.sync().await?; + + let messages = sync_group + .find_messages(&MsgQueryArgs::default().kind(GroupMessageKind::Application))?; + + let last_message = messages.last(); + if let Some(msg) = last_message { + let message_history_content = + serde_json::from_slice::(&msg.decrypted_message_bytes)?; + + if let MessageHistoryContent::Request(request) = message_history_content { + return Ok((request.request_id, request.pin_code)); + } + }; + + // build the request + let history_request = HistoryRequest::new(); + let pin_code = history_request.pin_code.clone(); + let request_id = history_request.request_id.clone(); + + let content = MessageHistoryContent::Request(MessageHistoryRequest { + request_id: request_id.clone(), + pin_code: pin_code.clone(), + }); + let content_bytes = serde_json::to_vec(&content)?; + + let _message_id = + sync_group.prepare_message(content_bytes.as_slice(), &conn, move |_time_ns| { + PlaintextEnvelope { + content: Some(Content::V2(V2 { + message_type: Some(Request(history_request.into())), + idempotency_key: new_request_id(), + })), + } + })?; + + // publish the intent + if let Err(err) = sync_group.publish_intents(&conn.into()).await { + tracing::error!("error publishing sync group intents: {:?}", err); + } + + Ok((request_id, pin_code)) + } + + pub(crate) async fn send_history_reply( + &self, + contents: MessageHistoryReply, + ) -> Result<(), MessageHistoryError> { + // find the sync group + let conn = self.store().conn()?; + let sync_group = self.get_sync_group()?; + + // sync the group + Box::pin(sync_group.sync()).await?; + + let messages = sync_group + .find_messages(&MsgQueryArgs::default().kind(GroupMessageKind::Application))?; + + let last_message = match messages.last() { + Some(msg) => { + let message_history_content = + serde_json::from_slice::(&msg.decrypted_message_bytes)?; + match message_history_content { + MessageHistoryContent::Request(request) => { + // check that the request ID matches + if !request.request_id.eq(&contents.request_id) { + return Err(MessageHistoryError::ReplyRequestIdMismatch); + } + Some(msg) + } + MessageHistoryContent::Reply(_) => { + // if last message is a reply, it's already been processed + return Err(MessageHistoryError::ReplyAlreadyProcessed); + } + } + } + None => { + return Err(MessageHistoryError::NoPendingRequest); + } + }; + + tracing::info!("{:?}", last_message); + + if let Some(msg) = last_message { + // ensure the requester is a member of all the groups + self.ensure_member_of_all_groups(msg.sender_inbox_id.clone()) + .await?; + } + + // the reply message + let content = MessageHistoryContent::Reply(contents.clone()); + let content_bytes = serde_json::to_vec(&content)?; + + let _message_id = + sync_group.prepare_message(content_bytes.as_slice(), &conn, move |_time_ns| { + PlaintextEnvelope { + content: Some(Content::V2(V2 { + idempotency_key: new_request_id(), + message_type: Some(Reply(contents)), + })), + } + })?; + + // publish the intent + if let Err(err) = sync_group.publish_messages().await { + tracing::error!("error publishing sync group intents: {:?}", err); + } + Ok(()) + } + + pub async fn get_pending_history_request( + &self, + ) -> Result, MessageHistoryError> { + let sync_group = self.get_sync_group()?; + + // sync the group + sync_group.sync().await?; + + let messages = sync_group + .find_messages(&MsgQueryArgs::default().kind(GroupMessageKind::Application))?; + let last_message = messages.last(); + + let history_request: Option<(String, String)> = if let Some(msg) = last_message { + let message_history_content = + serde_json::from_slice::(&msg.decrypted_message_bytes)?; + match message_history_content { + // if the last message is a request, return its request ID and pin code + MessageHistoryContent::Request(request) => { + Some((request.request_id, request.pin_code)) + } + _ => None, + } + } else { + None + }; + + Ok(history_request) + } + + pub async fn reply_to_history_request( + &self, + ) -> Result { + let pending_request = self.get_pending_history_request().await?; + + if let Some((request_id, _)) = pending_request { + let reply = self.prepare_history_reply(&request_id).await?; + self.send_history_reply(reply.clone().into()).await?; + return Ok(reply.into()); + } + + Err(MessageHistoryError::NoPendingRequest) + } + + pub async fn get_latest_history_reply( + &self, + ) -> Result, MessageHistoryError> { + let sync_group = self.get_sync_group()?; + + // sync the group + sync_group.sync().await?; + + let messages = sync_group + .find_messages(&MsgQueryArgs::default().kind(GroupMessageKind::Application))?; + + let last_message = messages.last(); + + let reply: Option = match last_message { + Some(msg) => { + // if the message was sent by this installation, ignore it + if msg + .sender_installation_id + .eq(&self.installation_public_key()) + { + None + } else { + let message_history_content = serde_json::from_slice::( + &msg.decrypted_message_bytes, + )?; + match message_history_content { + // if the last message is a reply, return it + MessageHistoryContent::Reply(reply) => Some(reply), + _ => None, + } + } + } + None => None, + }; + + Ok(reply) + } + + pub async fn process_history_reply(&self) -> Result<(), MessageHistoryError> { + let reply = self.get_latest_history_reply().await?; + + if let Some(reply) = reply { + let Some(encryption_key) = reply.encryption_key.clone() else { + return Err(MessageHistoryError::InvalidPayload); + }; + + let history_bundle = download_history_bundle(&reply.url).await?; + let messages_path = std::env::temp_dir().join("messages.jsonl"); + + decrypt_history_file(&history_bundle, &messages_path, encryption_key)?; + + self.insert_history_bundle(&messages_path)?; + + // clean up temporary files associated with the bundle + std::fs::remove_file(history_bundle.as_path())?; + std::fs::remove_file(messages_path.as_path())?; + + self.sync_welcomes().await?; + + let conn = self.store().conn()?; + let groups = conn.find_groups(None, None, None, None, Some(ConversationType::Group))?; + for crate::storage::group::StoredGroup { id, .. } in groups.into_iter() { + let group = self.group(id)?; + Box::pin(group.sync()).await?; + } + + return Ok(()); + } + + Err(MessageHistoryError::NoReplyToProcess) + } + + pub(crate) fn verify_pin( + &self, + request_id: &str, + pin_code: &str, + ) -> Result<(), MessageHistoryError> { + let sync_group = self.get_sync_group()?; + let requests = sync_group + .find_messages(&MsgQueryArgs::default().kind(GroupMessageKind::Application))?; + let request = requests.into_iter().find(|msg| { + let message_history_content = + serde_json::from_slice::(&msg.decrypted_message_bytes); + + match message_history_content { + Ok(MessageHistoryContent::Request(request)) => { + request.request_id.eq(request_id) && request.pin_code.eq(pin_code) + } + Err(e) => { + tracing::debug!("serde_json error: {:?}", e); + false + } + _ => false, + } + }); + + if request.is_none() { + return Err(MessageHistoryError::PinNotFound); + } + + Ok(()) + } + + pub(crate) fn insert_history_bundle( + &self, + history_file: &Path, + ) -> Result<(), MessageHistoryError> { + let file = File::open(history_file)?; + let reader = BufReader::new(file); + let lines = reader.lines(); + + let conn = self.store().conn()?; + + for line in lines { + let line = line?; + let db_entry: SyncableTables = serde_json::from_str(&line)?; + match db_entry { + SyncableTables::StoredGroup(group) => { + // alternatively consider: group.store(&conn)? + conn.insert_or_replace_group(group)?; + } + SyncableTables::StoredGroupMessage(group_message) => { + group_message.store(&conn)?; + } + } + } + + Ok(()) + } + + pub(crate) async fn prepare_history_reply( + &self, + request_id: &str, + ) -> Result { + let (history_file, enc_key) = self.write_history_bundle().await?; + let url = match &self.history_sync_url { + Some(url) => url.as_str(), + None => return Err(MessageHistoryError::MissingHistorySyncUrl), + }; + let upload_url = format!("{}{}", url, "upload"); + tracing::info!("using upload url {:?}", upload_url); + + let bundle_file = upload_history_bundle(&upload_url, history_file.clone()).await?; + let bundle_url = format!("{}files/{}", url, bundle_file); + + tracing::info!("history bundle uploaded to {:?}", bundle_url); + + Ok(HistoryReply::new(request_id, &bundle_url, enc_key)) + } + + async fn write_history_bundle(&self) -> Result<(PathBuf, HistoryKeyType), MessageHistoryError> { + let groups = self.prepare_groups_to_sync().await?; + let messages = self.prepare_messages_to_sync().await?; + + let temp_file = std::env::temp_dir().join("history.jsonl.tmp"); + write_to_file(temp_file.as_path(), groups)?; + write_to_file(temp_file.as_path(), messages)?; + + let history_file = std::env::temp_dir().join("history.jsonl.enc"); + let enc_key = HistoryKeyType::new_chacha20_poly1305_key(); + encrypt_history_file( + temp_file.as_path(), + history_file.as_path(), + enc_key.as_bytes(), + )?; + + std::fs::remove_file(temp_file.as_path())?; + + Ok((history_file, enc_key)) + } + + async fn prepare_groups_to_sync(&self) -> Result, MessageHistoryError> { + let conn = self.store().conn()?; + Ok(conn.find_groups(None, None, None, None, Some(ConversationType::Group))?) + } + + async fn prepare_messages_to_sync( + &self, + ) -> Result, MessageHistoryError> { + let conn = self.store().conn()?; + let groups = conn.find_groups(None, None, None, None, Some(ConversationType::Group))?; + let mut all_messages: Vec = vec![]; + + for StoredGroup { id, .. } in groups.into_iter() { + let messages = conn.get_group_messages(&id, &MsgQueryArgs::default())?; + all_messages.extend(messages); + } + + Ok(all_messages) + } +} + +fn write_to_file( + file_path: &Path, + content: Vec, +) -> Result<(), MessageHistoryError> { + let mut file = OpenOptions::new() + .create(true) + .append(true) + .open(file_path)?; + for entry in content { + let entry_str = serde_json::to_string(&entry)?; + file.write_all(entry_str.as_bytes())?; + file.write_all(b"\n")?; + } + + Ok(()) +} + +fn encrypt_history_file( + input_path: &Path, + output_path: &Path, + encryption_key: &[u8; ENC_KEY_SIZE], +) -> Result<(), MessageHistoryError> { + // Read in the messages file content + let mut input_file = File::open(input_path)?; + let mut buffer = Vec::new(); + input_file.read_to_end(&mut buffer)?; + + let nonce = generate_nonce(); + + // Create a cipher instance + let cipher = Aes256Gcm::new(GenericArray::from_slice(encryption_key)); + let nonce_array = GenericArray::from_slice(&nonce); + + // Encrypt the file content + let ciphertext = cipher.encrypt(nonce_array, buffer.as_ref())?; + + // Write the nonce and ciphertext to the output file + let mut output_file = File::create(output_path)?; + output_file.write_all(&nonce)?; + output_file.write_all(&ciphertext)?; + + Ok(()) +} + +pub(crate) fn decrypt_history_file( + input_path: &Path, + output_path: &Path, + encryption_key: MessageHistoryKeyType, +) -> Result<(), MessageHistoryError> { + let enc_key: HistoryKeyType = encryption_key.try_into()?; + let enc_key_bytes = enc_key.as_bytes(); + // Read the messages file content + let mut input_file = File::open(input_path)?; + let mut buffer = Vec::new(); + input_file.read_to_end(&mut buffer)?; + + // Split the nonce and ciphertext + let (nonce, ciphertext) = buffer.split_at(NONCE_SIZE); + + // Create a cipher instance + let cipher = Aes256Gcm::new(GenericArray::from_slice(enc_key_bytes)); + let nonce_array = GenericArray::from_slice(nonce); + + // Decrypt the ciphertext + let plaintext = cipher.decrypt(nonce_array, ciphertext)?; + + // Write the plaintext to the output file + let mut output_file = File::create(output_path)?; + output_file.write_all(&plaintext)?; + + Ok(()) +} + +async fn upload_history_bundle( + url: &str, + file_path: PathBuf, +) -> Result { + let mut file = File::open(file_path)?; + let mut content = Vec::new(); + file.read_to_end(&mut content)?; + + let client = reqwest::Client::new(); + let response = client.post(url).body(content).send().await?; + + if response.status().is_success() { + Ok(response.text().await?) + } else { + tracing::error!( + "Failed to upload file. Status code: {} Response: {:?}", + response.status(), + response + ); + Err(MessageHistoryError::Reqwest( + response + .error_for_status() + .expect_err("Checked for success"), + )) + } +} + +pub(crate) async fn download_history_bundle(url: &str) -> Result { + let client = reqwest::Client::new(); + + tracing::info!("downloading history bundle from {:?}", url); + + let bundle_name = url + .split('/') + .last() + .ok_or(MessageHistoryError::InvalidBundleUrl)?; + + let response = client.get(url).send().await?; + + if response.status().is_success() { + let file_name = format!("{}.jsonl.enc", bundle_name); + let file_path = std::env::temp_dir().join(file_name); + let mut file = File::create(&file_path)?; + let bytes = response.bytes().await?; + file.write_all(&bytes)?; + tracing::info!("downloaded history bundle to {:?}", file_path); + Ok(file_path) + } else { + tracing::error!( + "Failed to download file. Status code: {} Response: {:?}", + response.status(), + response + ); + Err(MessageHistoryError::Reqwest( + response + .error_for_status() + .expect_err("Checked for success"), + )) + } +} + +#[derive(Clone)] +struct HistoryRequest { + pin_code: String, + request_id: String, +} + +impl HistoryRequest { + pub(crate) fn new() -> Self { + Self { + pin_code: new_pin(), + request_id: new_request_id(), + } + } +} + +impl From for MessageHistoryRequest { + fn from(req: HistoryRequest) -> Self { + MessageHistoryRequest { + pin_code: req.pin_code, + request_id: req.request_id, + } + } +} + +#[derive(Debug, Clone)] +pub(crate) struct HistoryReply { + /// Unique ID for each client Message History Request + request_id: String, + /// URL to download the backup bundle + url: String, + /// Encryption key for the backup bundle + encryption_key: HistoryKeyType, +} + +impl HistoryReply { + pub(crate) fn new(id: &str, url: &str, encryption_key: HistoryKeyType) -> Self { + Self { + request_id: id.into(), + url: url.into(), + encryption_key, + } + } +} + +impl From for MessageHistoryReply { + fn from(reply: HistoryReply) -> Self { + MessageHistoryReply { + request_id: reply.request_id, + url: reply.url, + encryption_key: Some(reply.encryption_key.into()), + } + } +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub(crate) enum HistoryKeyType { + Chacha20Poly1305([u8; ENC_KEY_SIZE]), +} + +impl HistoryKeyType { + fn new_chacha20_poly1305_key() -> Self { + let mut rng = crypto_utils::rng(); + let mut key = [0u8; ENC_KEY_SIZE]; + rng.fill_bytes(&mut key); + HistoryKeyType::Chacha20Poly1305(key) + } + + fn len(&self) -> usize { + match self { + HistoryKeyType::Chacha20Poly1305(key) => key.len(), + } + } + + fn as_bytes(&self) -> &[u8; ENC_KEY_SIZE] { + match self { + HistoryKeyType::Chacha20Poly1305(key) => key, + } + } +} + +impl From for MessageHistoryKeyType { + fn from(key: HistoryKeyType) -> Self { + match key { + HistoryKeyType::Chacha20Poly1305(key) => MessageHistoryKeyType { + key: Some(Key::Chacha20Poly1305(key.to_vec())), + }, + } + } +} + +impl TryFrom for HistoryKeyType { + type Error = MessageHistoryError; + fn try_from(key: MessageHistoryKeyType) -> Result { + let MessageHistoryKeyType { key } = key; + match key { + Some(k) => { + let Key::Chacha20Poly1305(hist_key) = k; + match hist_key.try_into() { + Ok(array) => Ok(HistoryKeyType::Chacha20Poly1305(array)), + Err(_) => Err(MessageHistoryError::Conversion), + } + } + None => Err(MessageHistoryError::Conversion), + } + } +} + +fn new_request_id() -> String { + Alphanumeric.sample_string(&mut rand::thread_rng(), ENC_KEY_SIZE) +} + +fn generate_nonce() -> [u8; NONCE_SIZE] { + let mut nonce = [0u8; NONCE_SIZE]; + let mut rng = crypto_utils::rng(); + rng.fill_bytes(&mut nonce); + nonce +} + +fn new_pin() -> String { + let mut rng = crypto_utils::rng(); + let pin: u32 = rng.gen_range(0..10000); + format!("{:04}", pin) +} + +#[cfg(all(not(target_arch = "wasm32"), test))] +pub(crate) mod tests { + #[cfg(target_arch = "wasm32")] + wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_dedicated_worker); + + const HISTORY_SERVER_HOST: &str = "0.0.0.0"; + const HISTORY_SERVER_PORT: u16 = 5558; + + use super::*; + use mockito; + use std::io::{BufRead, BufReader}; + use tempfile::NamedTempFile; + use xmtp_cryptography::utils::generate_local_wallet; + use xmtp_id::InboxOwner; + + use crate::{assert_ok, builder::ClientBuilder, groups::GroupMetadataOptions}; + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_enable_history_sync() { + let wallet = generate_local_wallet(); + let client = ClientBuilder::new_test_client(&wallet).await; + assert_ok!(client.enable_history_sync().await); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_installations_are_added_to_sync_group() { + let wallet = generate_local_wallet(); + let amal_a = ClientBuilder::new_test_client(&wallet).await; + let amal_b = ClientBuilder::new_test_client(&wallet).await; + let amal_c = ClientBuilder::new_test_client(&wallet).await; + assert_ok!(amal_c.enable_history_sync().await); + + amal_a.sync_welcomes().await.expect("sync_welcomes"); + amal_b.sync_welcomes().await.expect("sync_welcomes"); + + let conn_a = amal_a.store().conn().unwrap(); + let amal_a_sync_groups = conn_a.find_sync_groups().unwrap(); + + let conn_b = amal_b.store().conn().unwrap(); + let amal_b_sync_groups = conn_b.find_sync_groups().unwrap(); + + let conn_c = amal_c.store().conn().unwrap(); + let amal_c_sync_groups = conn_c.find_sync_groups().unwrap(); + + assert_eq!(amal_a_sync_groups.len(), 1); + assert_eq!(amal_b_sync_groups.len(), 1); + assert_eq!(amal_c_sync_groups.len(), 1); + // make sure all installations are in the same sync group + assert_eq!(amal_a_sync_groups[0].id, amal_b_sync_groups[0].id); + assert_eq!(amal_b_sync_groups[0].id, amal_c_sync_groups[0].id); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_send_history_request() { + let wallet = generate_local_wallet(); + let client = ClientBuilder::new_test_client(&wallet).await; + assert_ok!(client.enable_history_sync().await); + + // test that the request is sent, and that the pin code is returned + let (request_id, pin_code) = client + .send_history_request() + .await + .expect("history request"); + assert_eq!(request_id.len(), 32); + assert_eq!(pin_code.len(), 4); + + // test that another request will return the same request_id and + // pin_code because it hasn't been replied to yet + let (request_id2, pin_code2) = client + .send_history_request() + .await + .expect("history request"); + assert_eq!(request_id, request_id2); + assert_eq!(pin_code, pin_code2); + + // make sure there's only 1 message in the sync group + let sync_group = client.get_sync_group().unwrap(); + let messages = sync_group + .find_messages(&MsgQueryArgs::default().kind(GroupMessageKind::Application)) + .unwrap(); + assert_eq!(messages.len(), 1); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_send_history_reply() { + let wallet = generate_local_wallet(); + let client = ClientBuilder::new_test_client(&wallet).await; + assert_ok!(client.enable_history_sync().await); + + let request_id = new_request_id(); + let url = "https://test.com/abc-123"; + let encryption_key = HistoryKeyType::new_chacha20_poly1305_key(); + let reply = HistoryReply::new(&request_id, url, encryption_key); + let result = client.send_history_reply(reply.into()).await; + + // the reply should fail because there's no pending request to reply to + assert!(result.is_err()); + + let (request_id, _) = client + .send_history_request() + .await + .expect("history request"); + + let request_id2 = new_request_id(); + let url = "https://test.com/abc-123"; + let encryption_key = HistoryKeyType::new_chacha20_poly1305_key(); + let reply = HistoryReply::new(&request_id2, url, encryption_key); + let result = client.send_history_reply(reply.into()).await; + + // the reply should fail because there's a mismatched request ID + assert!(result.is_err()); + + let url = "https://test.com/abc-123"; + let encryption_key = HistoryKeyType::new_chacha20_poly1305_key(); + let reply = HistoryReply::new(&request_id, url, encryption_key); + let result = client.send_history_reply(reply.into()).await; + + // the reply should succeed with a valid request ID + assert_ok!(result); + + // make sure there's 2 messages in the sync group + let sync_group = client.get_sync_group().unwrap(); + let messages = sync_group + .find_messages(&MsgQueryArgs::default().kind(GroupMessageKind::Application)) + .unwrap(); + assert_eq!(messages.len(), 2); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_history_messages_stored_correctly() { + let wallet = generate_local_wallet(); + let amal_a = ClientBuilder::new_test_client(&wallet).await; + let amal_b = ClientBuilder::new_test_client(&wallet).await; + assert_ok!(amal_b.enable_history_sync().await); + + amal_a.sync_welcomes().await.expect("sync_welcomes"); + + let (_group_id, _pin_code) = amal_b + .send_history_request() + .await + .expect("history request"); + + // find the sync group + let amal_a_sync_groups = amal_a.store().conn().unwrap().find_sync_groups().unwrap(); + assert_eq!(amal_a_sync_groups.len(), 1); + // get the first sync group + let amal_a_sync_group = amal_a.group(amal_a_sync_groups[0].id.clone()).unwrap(); + amal_a_sync_group.sync().await.expect("sync"); + + // find the sync group (it should be the same as amal_a's sync group) + let amal_b_sync_groups = amal_b.store().conn().unwrap().find_sync_groups().unwrap(); + assert_eq!(amal_b_sync_groups.len(), 1); + // get the first sync group + let amal_b_sync_group = amal_b.group(amal_b_sync_groups[0].id.clone()).unwrap(); + amal_b_sync_group.sync().await.expect("sync"); + + // make sure they are the same group + assert_eq!(amal_a_sync_group.group_id, amal_b_sync_group.group_id); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + #[ignore] // this test is only relevant if we are enforcing the PIN challenge + async fn test_verify_pin() { + let wallet = generate_local_wallet(); + let amal_a = ClientBuilder::new_test_client(&wallet).await; + let amal_b = ClientBuilder::new_test_client(&wallet).await; + assert_ok!(amal_b.enable_history_sync().await); + + amal_a.sync_welcomes().await.expect("sync_welcomes"); + + let (request_id, pin_code) = amal_b + .send_history_request() + .await + .expect("history request"); + + let amal_a_sync_groups = amal_a.store().conn().unwrap().find_sync_groups().unwrap(); + assert_eq!(amal_a_sync_groups.len(), 1); + // get the first sync group + let amal_a_sync_group = amal_a.group(amal_a_sync_groups[0].id.clone()).unwrap(); + amal_a_sync_group.sync().await.expect("sync"); + let pin_challenge_result = amal_a.verify_pin(&request_id, &pin_code); + assert_ok!(pin_challenge_result); + + let pin_challenge_result_2 = amal_a.verify_pin("000", "000"); + assert!(pin_challenge_result_2.is_err()); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + #[ignore] + async fn test_request_reply_roundtrip() { + let options = mockito::ServerOpts { + host: HISTORY_SERVER_HOST, + port: HISTORY_SERVER_PORT + 1, + ..Default::default() + }; + let mut server = mockito::Server::new_with_opts_async(options).await; + + let _m = server + .mock("POST", "/upload") + .with_status(201) + .with_body("File uploaded") + .create(); + + let history_sync_url = format!( + "http://{}:{}/upload", + HISTORY_SERVER_HOST, + HISTORY_SERVER_PORT + 1 + ); + + let wallet = generate_local_wallet(); + let amal_a = ClientBuilder::new_test_client(&wallet).await; + let _group_a = amal_a + .create_group(None, GroupMetadataOptions::default()) + .expect("create group"); + + let groups = amal_a.prepare_groups_to_sync().await.unwrap(); + + let input_file = NamedTempFile::new().unwrap(); + let input_path = input_file.path(); + write_to_file(input_path, groups).unwrap(); + + let output_file = NamedTempFile::new().unwrap(); + let output_path = output_file.path(); + let encryption_key = HistoryKeyType::new_chacha20_poly1305_key(); + encrypt_history_file(input_path, output_path, encryption_key.as_bytes()).unwrap(); + + let mut file = File::open(output_path).unwrap(); + let mut content = Vec::new(); + file.read_to_end(&mut content).unwrap(); + + let _m = server + .mock("GET", "/upload") + .with_status(201) + .with_body(content) + .create(); + + let wallet = generate_local_wallet(); + let mut amal_a = ClientBuilder::new_test_client(&wallet).await; + amal_a.history_sync_url = Some(history_sync_url.clone()); + let amal_b = ClientBuilder::new_test_client(&wallet).await; + assert_ok!(amal_b.enable_history_sync().await); + + amal_a.sync_welcomes().await.expect("sync_welcomes"); + + // amal_b sends a message history request to sync group messages + let (_group_id, _pin_code) = amal_b + .send_history_request() + .await + .expect("history request"); + + let amal_a_sync_groups = amal_a.store().conn().unwrap().find_sync_groups().unwrap(); + assert_eq!(amal_a_sync_groups.len(), 1); + // get the first sync group + let amal_a_sync_group = amal_a.group(amal_a_sync_groups[0].id.clone()).unwrap(); + amal_a_sync_group.sync().await.expect("sync"); + + // amal_a builds and sends a message history reply back + let history_reply = HistoryReply::new(&new_request_id(), &history_sync_url, encryption_key); + amal_a + .send_history_reply(history_reply.into()) + .await + .expect("send reply"); + + amal_a_sync_group.sync().await.expect("sync"); + // amal_b should have received the reply + let amal_b_sync_groups = amal_b.store().conn().unwrap().find_sync_groups().unwrap(); + assert_eq!(amal_b_sync_groups.len(), 1); + + let amal_b_sync_group = amal_b.group(amal_b_sync_groups[0].id.clone()).unwrap(); + amal_b_sync_group.sync().await.expect("sync"); + + let amal_b_conn = amal_b.store().conn().unwrap(); + let amal_b_messages = amal_b_conn + .get_group_messages(&amal_b_sync_group.group_id, &MsgQueryArgs::default()) + .unwrap(); + + assert_eq!(amal_b_messages.len(), 1); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_prepare_groups_to_sync() { + let wallet = generate_local_wallet(); + let amal_a = ClientBuilder::new_test_client(&wallet).await; + let _group_a = amal_a + .create_group(None, GroupMetadataOptions::default()) + .expect("create group"); + let _group_b = amal_a + .create_group(None, GroupMetadataOptions::default()) + .expect("create group"); + + let result = amal_a.prepare_groups_to_sync().await.unwrap(); + assert_eq!(result.len(), 2); + } + + #[tokio::test] + async fn test_prepare_group_messages_to_sync() { + let wallet = generate_local_wallet(); + let amal_a = ClientBuilder::new_test_client(&wallet).await; + let group_a = amal_a + .create_group(None, GroupMetadataOptions::default()) + .expect("create group"); + let group_b = amal_a + .create_group(None, GroupMetadataOptions::default()) + .expect("create group"); + + group_a.send_message(b"hi").await.expect("send"); + group_a.send_message(b"hi x2").await.expect("send"); + group_b.send_message(b"hi").await.expect("send"); + group_b.send_message(b"hi x2").await.expect("send"); + + let messages_result = amal_a.prepare_messages_to_sync().await.unwrap(); + assert_eq!(messages_result.len(), 4); + } + + #[tokio::test] + async fn test_write_to_file() { + let wallet = generate_local_wallet(); + let amal_a = ClientBuilder::new_test_client(&wallet).await; + let group_a = amal_a + .create_group(None, GroupMetadataOptions::default()) + .expect("create group"); + let group_b = amal_a + .create_group(None, GroupMetadataOptions::default()) + .expect("create group"); + + group_a.send_message(b"hi").await.expect("send"); + group_a.send_message(b"hi").await.expect("send"); + group_b.send_message(b"hi").await.expect("send"); + group_b.send_message(b"hi").await.expect("send"); + + let groups = amal_a.prepare_groups_to_sync().await.unwrap(); + let messages = amal_a.prepare_messages_to_sync().await.unwrap(); + + let temp_file = NamedTempFile::new().expect("Unable to create temp file"); + let wrote_groups = write_to_file(temp_file.path(), groups); + assert!(wrote_groups.is_ok()); + let wrote_messages = write_to_file(temp_file.path(), messages); + assert!(wrote_messages.is_ok()); + + let file = File::open(temp_file.path()).expect("Unable to open test file"); + let reader = BufReader::new(file); + let n_lines_written = reader.lines().count(); + assert_eq!(n_lines_written, 6); + + std::fs::remove_file(temp_file).expect("Unable to remove test file"); + } + + #[test] + fn test_encrypt_decrypt_file() { + let key = HistoryKeyType::new_chacha20_poly1305_key(); + let converted_key: MessageHistoryKeyType = key.into(); + let key_bytes = key.as_bytes(); + let input_content = b"'{\"test\": \"data\"}\n{\"test\": \"data2\"}\n'"; + let input_file = NamedTempFile::new().expect("Unable to create temp file"); + let encrypted_file = NamedTempFile::new().expect("Unable to create temp file"); + let decrypted_file = NamedTempFile::new().expect("Unable to create temp file"); + + // Write test input file + std::fs::write(input_file.path(), input_content).expect("Unable to write test input file"); + + // Encrypt the file + encrypt_history_file(input_file.path(), encrypted_file.path(), key_bytes) + .expect("Encryption failed"); + + // Decrypt the file + decrypt_history_file(encrypted_file.path(), decrypted_file.path(), converted_key) + .expect("Decryption failed"); + + // Read the decrypted file content + let decrypted_content = + std::fs::read(decrypted_file.path()).expect("Unable to read decrypted file"); + + // Assert the decrypted content is the same as the original input content + assert_eq!(decrypted_content, input_content); + + // Clean up test files + std::fs::remove_file(input_file).expect("Unable to remove test input file"); + std::fs::remove_file(encrypted_file).expect("Unable to remove test encrypted file"); + std::fs::remove_file(decrypted_file).expect("Unable to remove test decrypted file"); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_upload_history_bundle() { + let options = mockito::ServerOpts { + host: HISTORY_SERVER_HOST, + port: HISTORY_SERVER_PORT + 1, + ..Default::default() + }; + let mut server = mockito::Server::new_with_opts_async(options).await; + + let _m = server + .mock("POST", "/upload") + .with_status(201) + .with_body("File uploaded") + .create(); + + let file_content = b"'{\"test\": \"data\"}\n{\"test\": \"data2\"}\n'"; + + let mut file = NamedTempFile::new().unwrap(); + file.write_all(file_content).unwrap(); + let file_path = file.path().to_str().unwrap().to_string(); + + let url = format!( + "http://{}:{}/upload", + HISTORY_SERVER_HOST, + HISTORY_SERVER_PORT + 1 + ); + let result = upload_history_bundle(&url, file_path.into()).await; + + assert!(result.is_ok()); + _m.assert_async().await; + server.reset(); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_download_history_bundle() { + let bundle_id = "test_bundle_id"; + let options = mockito::ServerOpts { + host: HISTORY_SERVER_HOST, + port: HISTORY_SERVER_PORT, + ..Default::default() + }; + let mut server = mockito::Server::new_with_opts_async(options).await; + + let _m = server + .mock("GET", format!("/files/{}", bundle_id).as_str()) + .with_status(200) + .with_body("encrypted_content") + .create(); + + let url = format!( + "http://{}:{}/files/{bundle_id}", + HISTORY_SERVER_HOST, HISTORY_SERVER_PORT + ); + let output_path = download_history_bundle(&url) + .await + .expect("could not download history bundle"); + + _m.assert_async().await; + std::fs::remove_file(output_path.as_path()).expect("Unable to remove test output file"); + server.reset(); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_prepare_history_reply() { + let wallet = generate_local_wallet(); + let mut amal_a = ClientBuilder::new_test_client(&wallet).await; + let amal_b = ClientBuilder::new_test_client(&wallet).await; + assert_ok!(amal_b.enable_history_sync().await); + + amal_a.sync_welcomes().await.expect("sync_welcomes"); + + let request_id = new_request_id(); + + let port = HISTORY_SERVER_PORT + 2; + let options = mockito::ServerOpts { + host: HISTORY_SERVER_HOST, + port, + ..Default::default() + }; + let mut server = mockito::Server::new_with_opts_async(options).await; + + let url = format!("http://{HISTORY_SERVER_HOST}:{port}/"); + let _m = server + .mock("POST", "/upload") + .with_status(201) + .with_body("encrypted_content") + .create(); + + amal_a.history_sync_url = Some(url); + let reply = amal_a.prepare_history_reply(&request_id).await; + assert!(reply.is_ok()); + _m.assert_async().await; + server.reset(); + } + + #[tokio::test] + async fn test_get_pending_history_request() { + let wallet = generate_local_wallet(); + let amal_a = ClientBuilder::new_test_client(&wallet).await; + + // enable history sync for the client + assert_ok!(amal_a.enable_history_sync().await); + + // ensure there's no pending request initially + let initial_request = amal_a.get_pending_history_request().await; + assert!(initial_request.is_ok()); + assert!(initial_request.unwrap().is_none()); + + // create a history request + let request = amal_a + .send_history_request() + .await + .expect("history request"); + + // check for the pending request + let pending_request = amal_a.get_pending_history_request().await; + assert!(pending_request.is_ok()); + let pending = pending_request.unwrap(); + assert!(pending.is_some()); + + let (request_id, pin_code) = pending.unwrap(); + assert_eq!(request_id, request.0); + assert_eq!(pin_code, request.1); + } + + #[tokio::test] + async fn test_get_latest_history_reply() { + let wallet = generate_local_wallet(); + let amal_a = ClientBuilder::new_test_client(&wallet).await; + let amal_b = ClientBuilder::new_test_client(&wallet).await; + + // enable history sync for both clients + assert_ok!(amal_a.enable_history_sync().await); + assert_ok!(amal_b.enable_history_sync().await); + + // ensure there's no reply initially + let initial_reply = amal_b.get_latest_history_reply().await; + assert!(initial_reply.is_ok()); + assert!(initial_reply.unwrap().is_none()); + + // amal_b sends a history request + let (request_id, _pin_code) = amal_b + .send_history_request() + .await + .expect("history request"); + + // sync amal_a + amal_a.sync_welcomes().await.expect("sync_welcomes"); + + // amal_a sends a reply + amal_a + .send_history_reply(MessageHistoryReply { + request_id: request_id.clone(), + url: "http://foo/bar".to_string(), + encryption_key: None, + }) + .await + .expect("send reply"); + + // check latest reply for amal_b + let latest_reply = amal_b.get_latest_history_reply().await; + assert!(latest_reply.is_ok()); + let received_reply = latest_reply.unwrap(); + assert!(received_reply.is_some()); + + let received_reply = received_reply.unwrap(); + assert_eq!(received_reply.request_id, request_id); + } + + #[tokio::test] + async fn test_reply_to_history_request() { + let wallet = generate_local_wallet(); + let mut amal_a = ClientBuilder::new_test_client(&wallet).await; + let amal_b = ClientBuilder::new_test_client(&wallet).await; + + // enable history sync for both clients + assert_ok!(amal_a.enable_history_sync().await); + assert_ok!(amal_b.enable_history_sync().await); + + // amal_b sends a history request + let (request_id, _pin_code) = amal_b + .send_history_request() + .await + .expect("history request"); + + // sync amal_a + amal_a.sync_welcomes().await.expect("sync_welcomes"); + + // start mock server + let options = mockito::ServerOpts { + host: HISTORY_SERVER_HOST, + port: HISTORY_SERVER_PORT + 3, + ..Default::default() + }; + let mut server = mockito::Server::new_with_opts_async(options).await; + + let _m = server + .mock("POST", "/upload") + .with_status(201) + .with_body("File uploaded") + .create(); + + let url = format!( + "http://{}:{}/", + HISTORY_SERVER_HOST, + HISTORY_SERVER_PORT + 3 + ); + amal_a.history_sync_url = Some(url); + + // amal_a replies to the history request + let reply = amal_a.reply_to_history_request().await; + assert!(reply.is_ok()); + let reply = reply.unwrap(); + + // verify the reply + assert_eq!(reply.request_id, request_id); + assert!(!reply.url.is_empty()); + assert!(reply.encryption_key.is_some()); + + // check if amal_b received the reply + let received_reply = amal_b.get_latest_history_reply().await; + assert!(received_reply.is_ok()); + let received_reply = received_reply.unwrap(); + assert!(received_reply.is_some()); + let received_reply = received_reply.unwrap(); + assert_eq!(received_reply.request_id, request_id); + assert_eq!(received_reply.url, reply.url); + assert_eq!(received_reply.encryption_key, reply.encryption_key); + + _m.assert_async().await; + server.reset(); + } + + #[tokio::test] + async fn test_insert_history_bundle() { + let wallet = generate_local_wallet(); + let amal_a = ClientBuilder::new_test_client(&wallet).await; + let amal_b = ClientBuilder::new_test_client(&wallet).await; + let group_a = amal_a + .create_group(None, GroupMetadataOptions::default()) + .expect("create group"); + + group_a.send_message(b"hi").await.expect("send message"); + + let (bundle_path, enc_key) = amal_a + .write_history_bundle() + .await + .expect("Unable to write history bundle"); + + let output_file = NamedTempFile::new().expect("Unable to create temp file"); + let converted_key: MessageHistoryKeyType = enc_key.into(); + decrypt_history_file(&bundle_path, output_file.path(), converted_key) + .expect("Unable to decrypt history file"); + + let inserted = amal_b.insert_history_bundle(output_file.path()); + assert!(inserted.is_ok()); + } + + #[tokio::test] + async fn test_externals_cant_join_sync_group() { + let wallet = generate_local_wallet(); + let amal = ClientBuilder::new_test_client(&wallet).await; + assert_ok!(amal.enable_history_sync().await); + amal.sync_welcomes().await.expect("sync welcomes"); + + let external_wallet = generate_local_wallet(); + let external_client = ClientBuilder::new_test_client(&external_wallet).await; + assert_ok!(external_client.enable_history_sync().await); + external_client + .sync_welcomes() + .await + .expect("sync welcomes"); + + let amal_sync_groups = amal + .store() + .conn() + .unwrap() + .find_sync_groups() + .expect("find sync groups"); + assert_eq!(amal_sync_groups.len(), 1); + + // try to join amal's sync group + let sync_group_id = amal_sync_groups[0].id.clone(); + let created_at_ns = amal_sync_groups[0].created_at_ns; + + let external_client_group = MlsGroup::new( + external_client.clone(), + sync_group_id.clone(), + created_at_ns, + ); + let result = external_client_group + .add_members(vec![external_wallet.get_address()]) + .await; + assert!(result.is_err()); + } + + #[test] + fn test_new_pin() { + let pin = new_pin(); + assert!(pin.chars().all(|c| c.is_numeric())); + assert_eq!(pin.len(), 4); + } + + #[test] + fn test_new_request_id() { + let request_id = new_request_id(); + assert_eq!(request_id.len(), ENC_KEY_SIZE); + } + + #[test] + fn test_new_key() { + let sig_key = HistoryKeyType::new_chacha20_poly1305_key(); + let enc_key = HistoryKeyType::new_chacha20_poly1305_key(); + assert_eq!(sig_key.len(), ENC_KEY_SIZE); + assert_eq!(enc_key.len(), ENC_KEY_SIZE); + // ensure keys are different (seed isn't reused) + assert_ne!(sig_key, enc_key); + } + + #[test] + fn test_generate_nonce() { + let nonce_1 = generate_nonce(); + let nonce_2 = generate_nonce(); + assert_eq!(nonce_1.len(), NONCE_SIZE); + // ensure nonces are different (seed isn't reused) + assert_ne!(nonce_1, nonce_2); + } +} diff --git a/xmtp_mls/src/groups/scoped_client.rs b/xmtp_mls/src/groups/scoped_client.rs index 10460cffa..be4b814d7 100644 --- a/xmtp_mls/src/groups/scoped_client.rs +++ b/xmtp_mls/src/groups/scoped_client.rs @@ -139,8 +139,8 @@ pub trait ScopedGroupClient: Sized { impl ScopedGroupClient for Client where - ApiClient: XmtpApi + Clone, - Verifier: SmartContractSignatureVerifier + Clone, + ApiClient: XmtpApi, + Verifier: SmartContractSignatureVerifier, { type ApiClient = ApiClient; @@ -149,7 +149,7 @@ where } fn context_ref(&self) -> &Arc { - self.context() + Client::::context(self) } fn intents(&self) -> &Arc { diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs index 51a6a3562..5e1a6adc3 100644 --- a/xmtp_mls/src/groups/subscriptions.rs +++ b/xmtp_mls/src/groups/subscriptions.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use std::sync::Arc; use tokio::sync::oneshot; use xmtp_proto::api_client::trait_impls::XmtpApi; +use xmtp_proto::api_client::XmtpMlsStreams; use super::{extract_message_v1, GroupError, MlsGroup, ScopedGroupClient}; use crate::api::GroupFilter; @@ -115,8 +116,7 @@ impl MlsGroup { ClientError, > where - ScopedClient: Clone, - ::ApiClient: Clone + 'static, + ::ApiClient: XmtpMlsStreams + 'static, { let group_list = HashMap::from([( self.group_id.clone(), @@ -135,8 +135,8 @@ impl MlsGroup { callback: impl FnMut(Result) + Send + 'static, ) -> impl crate::StreamHandle> where - ScopedClient: Clone + 'static, - ::ApiClient: Clone + 'static, + ScopedClient: 'static, + ::ApiClient: XmtpMlsStreams + 'static, { let group_list = HashMap::from([( group_id, @@ -155,8 +155,8 @@ pub(crate) async fn stream_messages( group_id_to_info: Arc, MessagesStreamInfo>>, ) -> Result> + '_, ClientError> where - ScopedClient: ScopedGroupClient + Clone, - ::ApiClient: XmtpApi + Clone + 'static, + ScopedClient: ScopedGroupClient, + ::ApiClient: XmtpApi + XmtpMlsStreams + 'static, { let filters: Vec = group_id_to_info .iter() @@ -179,8 +179,7 @@ where .ok_or(ClientError::StreamInconsistency( "Received message for a non-subscribed group".to_string(), ))?; - let mls_group = - MlsGroup::new(client.clone(), group_id, stream_info.convo_created_at_ns); + let mls_group = MlsGroup::new(client, group_id, stream_info.convo_created_at_ns); mls_group.process_stream_entry(envelope).await } }) @@ -204,8 +203,8 @@ pub(crate) fn stream_messages_with_callback( mut callback: impl FnMut(Result) + Send + 'static, ) -> impl crate::StreamHandle> where - ScopedClient: ScopedGroupClient + Clone + 'static, - ::ApiClient: XmtpApi + Clone + 'static, + ScopedClient: ScopedGroupClient + 'static, + ::ApiClient: XmtpApi + XmtpMlsStreams + 'static, { let (tx, rx) = oneshot::channel(); diff --git a/xmtp_mls/src/identity_updates.rs b/xmtp_mls/src/identity_updates.rs index 0d0202749..76030e20c 100644 --- a/xmtp_mls/src/identity_updates.rs +++ b/xmtp_mls/src/identity_updates.rs @@ -89,8 +89,8 @@ impl DbConnection { impl<'a, ApiClient, V> Client where - ApiClient: XmtpApi + Clone, - V: SmartContractSignatureVerifier + Clone, + ApiClient: XmtpApi, + V: SmartContractSignatureVerifier, { /// Get the association state for all provided `inbox_id`/optional `sequence_id` tuples, using the cache when available /// If the association state is not available in the cache, this falls back to reconstructing the association state @@ -544,8 +544,8 @@ pub(crate) mod tests { inbox_id: String, ) -> AssociationState where - ApiClient: XmtpApi + Clone, - Verifier: SmartContractSignatureVerifier + Clone, + ApiClient: XmtpApi, + Verifier: SmartContractSignatureVerifier, { let conn = client.store().conn().unwrap(); load_identity_updates(&client.api_client, &conn, vec![inbox_id.clone()]) diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index 281231278..d6467df35 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -4,7 +4,7 @@ use futures::{FutureExt, Stream, StreamExt}; use prost::Message; use tokio::{sync::oneshot, task::JoinHandle}; use xmtp_id::scw_verifier::SmartContractSignatureVerifier; -use xmtp_proto::xmtp::mls::api::v1::WelcomeMessage; +use xmtp_proto::{api_client::XmtpMlsStreams, xmtp::mls::api::v1::WelcomeMessage}; use crate::{ client::{extract_welcome_message, ClientError, MessageProcessingError}, @@ -44,11 +44,11 @@ impl LocalEvents { } } -impl Clone for LocalEvents { +impl Clone for LocalEvents { fn clone(&self) -> LocalEvents { use LocalEvents::*; match self { - NewGroup(c) => NewGroup(c.clone()), + NewGroup(group) => NewGroup(group.clone()), } } } @@ -127,8 +127,8 @@ impl RetryableError for SubscribeError { impl Client where - ApiClient: XmtpApi + Clone + Send + Sync + 'static, - V: SmartContractSignatureVerifier + Clone + Send + Sync + 'static, + ApiClient: XmtpApi + Send + Sync + 'static, + V: SmartContractSignatureVerifier + Send + Sync + 'static, { async fn process_streamed_welcome( &self, @@ -189,7 +189,10 @@ where pub async fn stream_conversations( &self, conversation_type: Option, - ) -> Result, SubscribeError>> + '_, ClientError> { + ) -> Result, SubscribeError>> + '_, ClientError> + where + ApiClient: XmtpMlsStreams, + { let event_queue = tokio_stream::wrappers::BroadcastStream::new( self.local_events.subscribe(), ) @@ -239,8 +242,8 @@ where impl Client where - ApiClient: XmtpApi + Clone + Send + Sync + 'static, - V: SmartContractSignatureVerifier + Clone + Send + Sync + 'static, + ApiClient: XmtpApi + XmtpMlsStreams + Send + Sync + 'static, + V: SmartContractSignatureVerifier + Send + Sync + 'static, { pub fn stream_conversations_with_callback( client: Arc>, diff --git a/xmtp_mls/src/utils/test/mod.rs b/xmtp_mls/src/utils/test/mod.rs index 8e9801d5e..1effb52f0 100755 --- a/xmtp_mls/src/utils/test/mod.rs +++ b/xmtp_mls/src/utils/test/mod.rs @@ -12,7 +12,7 @@ use xmtp_id::{ test_utils::MockSmartContractSignatureVerifier, unverified::{UnverifiedRecoverableEcdsaSignature, UnverifiedSignature}, }, - scw_verifier::{RemoteSignatureVerifier, SmartContractSignatureVerifier}, + scw_verifier::SmartContractSignatureVerifier, }; use xmtp_proto::api_client::XmtpTestClient; @@ -26,7 +26,6 @@ use crate::{ #[cfg(not(target_arch = "wasm32"))] pub mod traced_test; - #[cfg(not(target_arch = "wasm32"))] pub use traced_test::traced_test; @@ -103,13 +102,13 @@ impl ClientBuilder { ) } } - impl ClientBuilder { pub async fn new_test_client(owner: &impl InboxOwner) -> FullXmtpClient { let api_client = ::create_local().await; - inner_build( + + build_with_verifier( owner, - &api_client, + api_client, MockSmartContractSignatureVerifier::new(true), ) .await @@ -120,9 +119,10 @@ impl ClientBuilder { owner: impl InboxOwner, ) -> Client { let api_client = ::create_dev().await; - inner_build( + + build_with_verifier( owner, - &api_client, + api_client, MockSmartContractSignatureVerifier::new(true), ) .await @@ -133,22 +133,12 @@ impl ClientBuilder { /// Create a client pointed at the local container with the default remote verifier pub async fn new_local_client(owner: &impl InboxOwner) -> Client { let api_client = ::create_local().await; - inner_build( - owner, - &api_client, - RemoteSignatureVerifier::new(api_client.clone()), - ) - .await + inner_build(owner, api_client).await } pub async fn new_dev_client(owner: &impl InboxOwner) -> Client { let api_client = ::create_dev().await; - inner_build( - owner, - &api_client, - RemoteSignatureVerifier::new(api_client.clone()), - ) - .await + inner_build(owner, api_client).await } /// Add the local client to this builder @@ -161,10 +151,41 @@ impl ClientBuilder { } } -async fn inner_build(owner: impl InboxOwner, api_client: &A, scw_verifier: V) -> Client +async fn inner_build(owner: impl InboxOwner, api_client: A) -> Client where - A: XmtpApi + Clone, - V: SmartContractSignatureVerifier + Clone, + A: XmtpApi, +{ + let nonce = 1; + let inbox_id = generate_inbox_id(&owner.get_address(), &nonce); + + let client = Client::::builder(IdentityStrategy::CreateIfNotFound( + inbox_id, + owner.get_address(), + nonce, + None, + )); + + let client = client + .temp_store() + .await + .api_client(api_client) + .build() + .await + .unwrap(); + + register_client(&client, owner).await; + + client +} + +async fn build_with_verifier( + owner: impl InboxOwner, + api_client: A, + scw_verifier: V, +) -> Client +where + A: XmtpApi, + V: SmartContractSignatureVerifier, { let nonce = 1; let inbox_id = generate_inbox_id(&owner.get_address(), &nonce); @@ -174,19 +195,22 @@ where owner.get_address(), nonce, None, - )) - .temp_store() - .await - .api_client(api_client.clone()) - .scw_signature_verifier(scw_verifier) - .build_with_verifier() - .await - .unwrap(); + )); + + let client = client + .temp_store() + .await + .api_client(api_client) + .scw_signature_verifier(scw_verifier) + .build_with_verifier() + .await + .unwrap(); register_client(&client, owner).await; client } + /// wrapper over a `Notify` with a 60-scond timeout for waiting #[derive(Clone, Default)] pub struct Delivery { @@ -214,8 +238,8 @@ impl Delivery { impl Client where - ApiClient: XmtpApi + Clone, - V: SmartContractSignatureVerifier + Clone, + ApiClient: XmtpApi, + V: SmartContractSignatureVerifier, { pub async fn is_registered(&self, address: &String) -> bool { let ids = self @@ -227,7 +251,7 @@ where } } -pub async fn register_client( +pub async fn register_client( client: &Client, owner: impl InboxOwner, ) { diff --git a/xmtp_proto/Cargo.toml b/xmtp_proto/Cargo.toml index 18fc5f49b..21ce62bf9 100644 --- a/xmtp_proto/Cargo.toml +++ b/xmtp_proto/Cargo.toml @@ -10,7 +10,7 @@ pbjson-types.workspace = true pbjson.workspace = true prost = { workspace = true, features = ["prost-derive"] } serde = { workspace = true } -trait-variant = "0.1.2" +async-trait = "0.1" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] tonic = { workspace = true } @@ -40,4 +40,4 @@ proto_full = ["xmtp-identity","xmtp-identity-api-v1","xmtp-identity-associations "xmtp-xmtpv4-envelopes" = ["xmtp-identity-associations","xmtp-mls-api-v1"] "xmtp-xmtpv4-message_api" = ["xmtp-xmtpv4-envelopes"] "xmtp-xmtpv4-payer_api" = ["xmtp-xmtpv4-envelopes"] -## @@protoc_insertion_point(features) \ No newline at end of file +## @@protoc_insertion_point(features) diff --git a/xmtp_proto/src/api_client.rs b/xmtp_proto/src/api_client.rs index 796fd1d41..9ae07b664 100644 --- a/xmtp_proto/src/api_client.rs +++ b/xmtp_proto/src/api_client.rs @@ -20,8 +20,9 @@ use crate::xmtp::mls::api::v1::{ }; #[cfg(any(test, feature = "test-utils"))] -#[trait_variant::make(XmtpTestClient: Send)] -pub trait LocalXmtpTestClient { +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +pub trait XmtpTestClient { async fn create_local() -> Self; async fn create_dev() -> Self; } @@ -31,34 +32,21 @@ pub trait LocalXmtpTestClient { pub mod trait_impls { #[allow(unused)] #[cfg(any(test, feature = "test-utils"))] - use super::{LocalXmtpTestClient, XmtpTestClient}; + use super::XmtpTestClient; pub use inner::*; // native, release #[cfg(all(not(feature = "test-utils"), not(target_arch = "wasm32")))] mod inner { - use crate::api_client::{ - ClientWithMetadata, XmtpIdentityClient, XmtpMlsClient, XmtpMlsStreams, - }; + use crate::api_client::{ClientWithMetadata, XmtpIdentityClient, XmtpMlsClient}; pub trait XmtpApi where - Self: XmtpMlsClient - + XmtpMlsStreams - + XmtpIdentityClient - + ClientWithMetadata - + Send - + Sync, + Self: XmtpMlsClient + XmtpIdentityClient + ClientWithMetadata + Send + Sync, { } impl XmtpApi for T where - T: XmtpMlsClient - + XmtpMlsStreams - + XmtpIdentityClient - + ClientWithMetadata - + Send - + Sync - + ?Sized + T: XmtpMlsClient + XmtpIdentityClient + ClientWithMetadata + Send + Sync + ?Sized { } } @@ -68,54 +56,29 @@ pub mod trait_impls { mod inner { use crate::api_client::{ - ClientWithMetadata, LocalXmtpIdentityClient, LocalXmtpMlsClient, LocalXmtpMlsStreams, + ClientWithMetadata, XmtpIdentityClient, XmtpMlsClient, XmtpMlsStreams, }; pub trait XmtpApi where - Self: LocalXmtpMlsClient - + LocalXmtpMlsStreams - + LocalXmtpIdentityClient - + ClientWithMetadata, + Self: XmtpMlsClient + XmtpIdentityClient + ClientWithMetadata, { } - impl XmtpApi for T where - T: LocalXmtpMlsClient - + LocalXmtpMlsStreams - + LocalXmtpIdentityClient - + ClientWithMetadata - + ?Sized - { - } + impl XmtpApi for T where T: XmtpMlsClient + XmtpIdentityClient + ClientWithMetadata + ?Sized {} } // test, native #[cfg(all(feature = "test-utils", not(target_arch = "wasm32")))] mod inner { - use crate::api_client::{ - ClientWithMetadata, XmtpIdentityClient, XmtpMlsClient, XmtpMlsStreams, - }; + use crate::api_client::{ClientWithMetadata, XmtpIdentityClient, XmtpMlsClient}; pub trait XmtpApi where - Self: XmtpMlsClient - + XmtpMlsStreams - + XmtpIdentityClient - + super::XmtpTestClient - + ClientWithMetadata - + Send - + Sync, + Self: XmtpMlsClient + XmtpIdentityClient + ClientWithMetadata + Send + Sync, { } impl XmtpApi for T where - T: XmtpMlsClient - + XmtpMlsStreams - + XmtpIdentityClient - + super::XmtpTestClient - + ClientWithMetadata - + Send - + Sync - + ?Sized + T: XmtpMlsClient + XmtpIdentityClient + ClientWithMetadata + Send + Sync + ?Sized { } } @@ -123,29 +86,16 @@ pub mod trait_impls { // test, wasm32 #[cfg(all(feature = "test-utils", target_arch = "wasm32"))] mod inner { - use crate::api_client::{ - ClientWithMetadata, LocalXmtpIdentityClient, LocalXmtpMlsClient, LocalXmtpMlsStreams, - }; + use crate::api_client::{ClientWithMetadata, XmtpIdentityClient, XmtpMlsClient}; pub trait XmtpApi where - Self: LocalXmtpMlsClient - + LocalXmtpMlsStreams - + LocalXmtpIdentityClient - + super::LocalXmtpTestClient - + ClientWithMetadata, + Self: XmtpMlsClient + XmtpIdentityClient + ClientWithMetadata, { } impl XmtpApi for T where - T: LocalXmtpMlsClient - + LocalXmtpMlsStreams - + LocalXmtpIdentityClient - + super::LocalXmtpTestClient - + ClientWithMetadata - + Send - + Sync - + ?Sized + T: XmtpMlsClient + XmtpIdentityClient + ClientWithMetadata + Send + Sync + ?Sized { } } @@ -236,7 +186,8 @@ pub trait XmtpApiSubscription { fn close_stream(&mut self); } -#[allow(async_fn_in_trait)] +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] pub trait MutableApiSubscription: Stream> + Send { async fn update(&mut self, req: SubscribeRequest) -> Result<(), Error>; fn close(&self); @@ -247,6 +198,19 @@ pub trait ClientWithMetadata { fn set_app_version(&mut self, version: String) -> Result<(), Error>; } +impl ClientWithMetadata for Box +where + T: ClientWithMetadata + ?Sized, +{ + fn set_libxmtp_version(&mut self, version: String) -> Result<(), Error> { + (**self).set_libxmtp_version(version) + } + + fn set_app_version(&mut self, version: String) -> Result<(), Error> { + (**self).set_app_version(version) + } +} + /// Global Marker trait for WebAssembly #[cfg(target_arch = "wasm32")] pub trait Wasm {} @@ -254,10 +218,9 @@ pub trait Wasm {} impl Wasm for T {} // Wasm futures don't have `Send` or `Sync` bounds. -#[allow(async_fn_in_trait)] -#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(XmtpApiClient: Send))] -#[cfg_attr(target_arch = "wasm32", trait_variant::make(XmtpApiClient: Wasm))] -pub trait LocalXmtpApiClient { +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +pub trait XmtpApiClient { type Subscription: XmtpApiSubscription; type MutableSubscription: MutableApiSubscription; @@ -279,11 +242,49 @@ pub trait LocalXmtpApiClient { async fn batch_query(&self, request: BatchQueryRequest) -> Result; } +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +impl XmtpApiClient for Box +where + T: XmtpApiClient + Sync + ?Sized, +{ + type Subscription = ::Subscription; + + type MutableSubscription = ::MutableSubscription; + + async fn publish( + &self, + token: String, + request: PublishRequest, + ) -> Result { + (**self).publish(token, request).await + } + + async fn subscribe(&self, request: SubscribeRequest) -> Result { + (**self).subscribe(request).await + } + + async fn subscribe2( + &self, + request: SubscribeRequest, + ) -> Result { + (**self).subscribe2(request).await + } + + async fn query(&self, request: QueryRequest) -> Result { + (**self).query(request).await + } + + async fn batch_query(&self, request: BatchQueryRequest) -> Result { + (**self).batch_query(request).await + } +} + // Wasm futures don't have `Send` or `Sync` bounds. #[allow(async_fn_in_trait)] -#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(XmtpMlsClient: Send))] -#[cfg_attr(target_arch = "wasm32", trait_variant::make(XmtpMlsClient: Wasm))] -pub trait LocalXmtpMlsClient { +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +pub trait XmtpMlsClient { async fn upload_key_package(&self, request: UploadKeyPackageRequest) -> Result<(), Error>; async fn fetch_key_packages( &self, @@ -302,9 +303,73 @@ pub trait LocalXmtpMlsClient { ) -> Result; } -#[allow(async_fn_in_trait)] -#[cfg_attr(target_arch = "wasm32", trait_variant::make(XmtpMlsStreams: Wasm))] -pub trait LocalXmtpMlsStreams { +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +impl XmtpMlsClient for Box +where + T: XmtpMlsClient + Sync + ?Sized, +{ + async fn upload_key_package(&self, request: UploadKeyPackageRequest) -> Result<(), Error> { + (**self).upload_key_package(request).await + } + + async fn fetch_key_packages( + &self, + request: FetchKeyPackagesRequest, + ) -> Result { + (**self).fetch_key_packages(request).await + } + + async fn send_group_messages(&self, request: SendGroupMessagesRequest) -> Result<(), Error> { + (**self).send_group_messages(request).await + } + + async fn send_welcome_messages( + &self, + request: SendWelcomeMessagesRequest, + ) -> Result<(), Error> { + (**self).send_welcome_messages(request).await + } + + async fn query_group_messages( + &self, + request: QueryGroupMessagesRequest, + ) -> Result { + (**self).query_group_messages(request).await + } + + async fn query_welcome_messages( + &self, + request: QueryWelcomeMessagesRequest, + ) -> Result { + (**self).query_welcome_messages(request).await + } +} + +#[cfg(not(target_arch = "wasm32"))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +pub trait XmtpMlsStreams { + type GroupMessageStream<'a>: Stream> + Send + 'a + where + Self: 'a; + + type WelcomeMessageStream<'a>: Stream> + Send + 'a + where + Self: 'a; + + async fn subscribe_group_messages( + &self, + request: SubscribeGroupMessagesRequest, + ) -> Result, Error>; + async fn subscribe_welcome_messages( + &self, + request: SubscribeWelcomeMessagesRequest, + ) -> Result, Error>; +} + +#[cfg(target_arch = "wasm32")] +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +pub trait XmtpMlsStreams { type GroupMessageStream<'a>: Stream> + 'a where Self: 'a; @@ -323,34 +388,38 @@ pub trait LocalXmtpMlsStreams { ) -> Result, Error>; } -// we manually make a Local+Non-Local trait variant here b/c the -// macro breaks with GATs -#[allow(async_fn_in_trait)] -#[cfg(not(target_arch = "wasm32"))] -pub trait XmtpMlsStreams: Send { - type GroupMessageStream<'a>: Stream> + Send + 'a +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +impl XmtpMlsStreams for Box +where + T: XmtpMlsStreams + Sync + ?Sized, +{ + type GroupMessageStream<'a> = ::GroupMessageStream<'a> where Self: 'a; - type WelcomeMessageStream<'a>: Stream> + Send + 'a + type WelcomeMessageStream<'a> = ::WelcomeMessageStream<'a> where Self: 'a; - fn subscribe_group_messages( + async fn subscribe_group_messages( &self, request: SubscribeGroupMessagesRequest, - ) -> impl futures::Future, Error>> + Send; + ) -> Result, Error> { + (**self).subscribe_group_messages(request).await + } - fn subscribe_welcome_messages( + async fn subscribe_welcome_messages( &self, request: SubscribeWelcomeMessagesRequest, - ) -> impl futures::Future, Error>> + Send; + ) -> Result, Error> { + (**self).subscribe_welcome_messages(request).await + } } -#[allow(async_fn_in_trait)] -#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(XmtpIdentityClient: Send))] -#[cfg_attr(target_arch = "wasm32", trait_variant::make(XmtpIdentityClient: Wasm))] -pub trait LocalXmtpIdentityClient { +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +pub trait XmtpIdentityClient { async fn publish_identity_update( &self, request: PublishIdentityUpdateRequest, @@ -371,3 +440,40 @@ pub trait LocalXmtpIdentityClient { request: VerifySmartContractWalletSignaturesRequest, ) -> Result; } + +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +impl XmtpIdentityClient for Box +where + T: XmtpIdentityClient + Send + Sync + ?Sized, +{ + async fn publish_identity_update( + &self, + request: PublishIdentityUpdateRequest, + ) -> Result { + (**self).publish_identity_update(request).await + } + + async fn get_identity_updates_v2( + &self, + request: GetIdentityUpdatesV2Request, + ) -> Result { + (**self).get_identity_updates_v2(request).await + } + + async fn get_inbox_ids( + &self, + request: GetInboxIdsRequest, + ) -> Result { + (**self).get_inbox_ids(request).await + } + + async fn verify_smart_contract_wallet_signatures( + &self, + request: VerifySmartContractWalletSignaturesRequest, + ) -> Result { + (**self) + .verify_smart_contract_wallet_signatures(request) + .await + } +}