From 1fc724e580a2aa4f2b3c69d7d151cbdb6a001611 Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Thu, 9 Nov 2023 21:20:20 -0500 Subject: [PATCH] proto scaffolding for just using a Mutex on SqliteConnection --- xmtp_mls/Cargo.toml | 2 +- xmtp_mls/src/storage/connection.rs | 108 +++++++++++++++++++++++++++++ xmtp_mls/src/storage/mod.rs | 1 + 3 files changed, 110 insertions(+), 1 deletion(-) create mode 100644 xmtp_mls/src/storage/connection.rs diff --git a/xmtp_mls/Cargo.toml b/xmtp_mls/Cargo.toml index f25854025..2278e21d2 100644 --- a/xmtp_mls/Cargo.toml +++ b/xmtp_mls/Cargo.toml @@ -19,7 +19,7 @@ native = ["libsqlite3-sys/bundled-sqlcipher-vendored-openssl"] [dependencies] anyhow = "1.0.71" async-trait = "0.1.68" -diesel = { version = "2.1.3", features = ["sqlite", "r2d2", "returning_clauses_for_sqlite_3_35"] } +diesel = { version = "2.1.3", features = ["sqlite", "r2d2", "returning_clauses_for_sqlite_3_35", "i-implement-a-third-party-backend-and-opt-into-breaking-changes"] } diesel_migrations = { version = "2.1.0", features = ["sqlite"] } ethers = "2.0.4" ethers-core = "2.0.4" diff --git a/xmtp_mls/src/storage/connection.rs b/xmtp_mls/src/storage/connection.rs new file mode 100644 index 000000000..2c5148af4 --- /dev/null +++ b/xmtp_mls/src/storage/connection.rs @@ -0,0 +1,108 @@ +//! An SqliteConnection wrapped in a Arc/Mutex to make it Sync + +use std::sync::{Arc, Mutex}; + +use diesel::{ + associations::HasTable, + connection::{ + AnsiTransactionManager, ConnectionSealed, DefaultLoadingMode, LoadConnection, + SimpleConnection, TransactionManager, + }, + expression::{is_aggregate, MixedAggregates, ValidGrouping}, + helper_types::{Find, Update}, + prelude::{Connection, Identifiable, SqliteConnection}, + query_builder::{AsChangeset, IntoUpdateTarget, QueryFragment, QueryId}, + query_dsl::{ + methods::{ExecuteDsl, FindDsl, LoadQuery}, + UpdateAndFetchResults, + }, + r2d2::R2D2Connection, + sqlite::Sqlite, + ConnectionResult, QueryResult, Table, +}; + +struct SyncSqliteConnection { + inner: Arc>, +} + +/// This is safe because all operations happen through Arc> +unsafe impl Sync for SyncSqliteConnection {} + +impl Connection for SyncSqliteConnection { + type Backend = Sqlite; + type TransactionManager = AnsiTransactionManager; + + fn establish(database_url: &str) -> ConnectionResult { + todo!() + } + + fn execute_returning_count(&mut self, source: &T) -> QueryResult + where + T: QueryFragment + QueryId, + { + todo!() + } + + fn transaction_state( + &mut self, + ) -> &mut >::TransactionStateData { + todo!() + } +} + +impl ConnectionSealed for SyncSqliteConnection {} + +impl SimpleConnection for SyncSqliteConnection { + fn batch_execute(&mut self, query: &str) -> QueryResult<()> { + todo!() + } +} + +impl From for SyncSqliteConnection { + fn from(connection: SqliteConnection) -> Self { + Self { + inner: Arc::new(Mutex::new(connection)), + } + } +} + +impl R2D2Connection for SyncSqliteConnection { + fn ping(&mut self) -> QueryResult<()> { + let mut conn = self.inner.lock().unwrap(); + (*conn).ping() + } + + fn is_broken(&mut self) -> bool { + let mut conn = self.inner.lock().unwrap(); + (*conn).is_broken() + } +} + +impl<'b, Changes, Output> UpdateAndFetchResults for SyncSqliteConnection +where + Changes: Copy + Identifiable, + Changes: AsChangeset::Table> + IntoUpdateTarget, + Changes::Table: FindDsl, + Update: ExecuteDsl, + Find: LoadQuery<'b, SqliteConnection, Output>, + ::AllColumns: ValidGrouping<()>, + <::AllColumns as ValidGrouping<()>>::IsAggregate: + MixedAggregates, +{ + fn update_and_fetch(&mut self, changeset: Changes) -> QueryResult { + let mut conn = self.inner.lock().unwrap(); + (*conn).update_and_fetch(changeset) + } +} + +/* +impl LoadConnection for SyncSqliteConnection { + type Cursor<'conn, 'query> + where + Self: 'conn; + + type Row<'conn, 'query> + where + Self: 'conn; +} +*/ diff --git a/xmtp_mls/src/storage/mod.rs b/xmtp_mls/src/storage/mod.rs index 85d7fc25a..4412f2940 100644 --- a/xmtp_mls/src/storage/mod.rs +++ b/xmtp_mls/src/storage/mod.rs @@ -1,3 +1,4 @@ +mod connection; mod encrypted_store; mod errors; mod serialization;