From d7e509c67bf41357e46790c3bad6d10cc50c765a Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Thu, 15 Aug 2024 11:35:44 -0400 Subject: [PATCH] Statement Stream for AsyncConnection::Load (#961) * statement stream * fix lifetimes on load query * execute returning count * transaction manager * cleanup --- diesel-wasm-sqlite/.vscode/settings.json | 3 +- diesel-wasm-sqlite/Cargo.lock | 2 +- diesel-wasm-sqlite/src/backend.rs | 2 +- .../src/connection/bind_collector.rs | 4 +- diesel-wasm-sqlite/src/connection/mod.rs | 138 ++++++------ diesel-wasm-sqlite/src/connection/raw.rs | 5 +- diesel-wasm-sqlite/src/connection/row.rs | 30 +-- .../src/connection/sqlite_value.rs | 69 +++--- .../src/connection/statement_iterator.rs | 172 --------------- .../src/connection/statement_stream.rs | 199 ++++++++++++++++++ diesel-wasm-sqlite/src/connection/stmt.rs | 88 +++----- diesel-wasm-sqlite/src/query_builder/mod.rs | 2 +- 12 files changed, 355 insertions(+), 359 deletions(-) mode change 100644 => 100755 diesel-wasm-sqlite/src/connection/raw.rs delete mode 100644 diesel-wasm-sqlite/src/connection/statement_iterator.rs create mode 100644 diesel-wasm-sqlite/src/connection/statement_stream.rs diff --git a/diesel-wasm-sqlite/.vscode/settings.json b/diesel-wasm-sqlite/.vscode/settings.json index cb6099b62..ae0d8b87d 100644 --- a/diesel-wasm-sqlite/.vscode/settings.json +++ b/diesel-wasm-sqlite/.vscode/settings.json @@ -13,7 +13,8 @@ "napi-derive": ["napi"], "async-recursion": ["async_recursion"], "ctor": ["ctor"], - "tokio": ["test"] + "tokio": ["test"], + "diesel": ["table"], } } }, diff --git a/diesel-wasm-sqlite/Cargo.lock b/diesel-wasm-sqlite/Cargo.lock index aadc60362..f179de053 100644 --- a/diesel-wasm-sqlite/Cargo.lock +++ b/diesel-wasm-sqlite/Cargo.lock @@ -142,7 +142,7 @@ dependencies = [ [[package]] name = "diesel-async" version = "0.5.0" -source = "git+https://github.com/insipx/diesel_async?branch=insipx/make-stmt-cache-public#86a24a38d9d841ef9e92022cd983bbd700286397" +source = "git+https://github.com/insipx/diesel_async?branch=insipx/make-stmt-cache-public#f1c4838ae6d7951b78572c249ba65f7a107488a0" dependencies = [ "async-trait", "diesel", diff --git a/diesel-wasm-sqlite/src/backend.rs b/diesel-wasm-sqlite/src/backend.rs index 4891c37e0..05ecf0950 100644 --- a/diesel-wasm-sqlite/src/backend.rs +++ b/diesel-wasm-sqlite/src/backend.rs @@ -39,7 +39,7 @@ pub enum SqliteType { impl Backend for WasmSqlite { type QueryBuilder = SqliteQueryBuilder; - type RawValue<'a> = SqliteValue<'a, 'a, 'a>; + type RawValue<'a> = SqliteValue<'a, 'a>; type BindCollector<'a> = SqliteBindCollector<'a>; } diff --git a/diesel-wasm-sqlite/src/connection/bind_collector.rs b/diesel-wasm-sqlite/src/connection/bind_collector.rs index 6eef85140..dc51e115a 100644 --- a/diesel-wasm-sqlite/src/connection/bind_collector.rs +++ b/diesel-wasm-sqlite/src/connection/bind_collector.rs @@ -185,7 +185,7 @@ impl<'a> BindCollector<'a, WasmSqlite> for SqliteBindCollector<'a> { #[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] -enum OwnedSqliteBindValue { +pub enum OwnedSqliteBindValue { String(Box), Binary(Box<[u8]>), I32(i32), @@ -229,7 +229,7 @@ impl<'a> std::convert::From<&OwnedSqliteBindValue> for InternalSqliteBindValue<' #[derive(Debug)] /// Sqlite bind collector data that is movable across threads pub struct SqliteBindCollectorData { - binds: Vec<(OwnedSqliteBindValue, SqliteType)>, + pub binds: Vec<(OwnedSqliteBindValue, SqliteType)>, } impl MoveableBindCollector for SqliteBindCollector<'_> { diff --git a/diesel-wasm-sqlite/src/connection/mod.rs b/diesel-wasm-sqlite/src/connection/mod.rs index fac2bbe8b..2694eadfe 100644 --- a/diesel-wasm-sqlite/src/connection/mod.rs +++ b/diesel-wasm-sqlite/src/connection/mod.rs @@ -5,7 +5,7 @@ mod raw; mod row; // mod serialized_database; mod sqlite_value; -// mod statement_iterator; +mod statement_stream; mod stmt; pub(crate) use self::bind_collector::SqliteBindCollector; @@ -17,15 +17,19 @@ use self::raw::RawConnection; // use self::statement_iterator::*; use self::stmt::{Statement, StatementUse}; use crate::query_builder::*; -use diesel::{connection::{statement_cache::StatementCacheKey, DefaultLoadingMode, LoadConnection}, deserialize::{FromSqlRow, StaticallySizedRow}, expression::QueryMetadata, query_builder::QueryBuilder as _, result::*, serialize::ToSql, sql_types::HasSqlType}; -use futures::{FutureExt, TryFutureExt}; +use diesel::query_builder::MoveableBindCollector; +use diesel::{connection::{statement_cache::StatementCacheKey}, query_builder::QueryBuilder as _, result::*}; +use futures::future::LocalBoxFuture; +use futures::stream::LocalBoxStream; +use futures::FutureExt; +use owned_row::OwnedSqliteRow; +use statement_stream::StatementStream; +use std::future::Future; use std::sync::{Arc, Mutex}; -use diesel::{connection::{ConnectionSealed, Instrumentation, WithMetadataLookup}, query_builder::{AsQuery, QueryFragment, QueryId}, sql_types::TypeMetadata, QueryResult}; +use diesel::{connection::{ConnectionSealed, Instrumentation}, query_builder::{AsQuery, QueryFragment, QueryId}, QueryResult}; pub use diesel_async::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection, TransactionManager, stmt_cache::StmtCache}; -use futures::{future::BoxFuture, stream::BoxStream}; -use row::SqliteRow; use crate::{get_sqlite_unchecked, WasmSqlite, WasmSqliteError}; @@ -35,10 +39,9 @@ pub struct WasmSqliteConnection { // connection itself statement_cache: StmtCache, pub raw_connection: RawConnection, - transaction_state: AnsiTransactionManager, + transaction_manager: AnsiTransactionManager, // this exists for the sole purpose of implementing `WithMetadataLookup` trait // and avoiding static mut which will be deprecated in 2024 edition - metadata_lookup: (), instrumentation: Arc>>>, } @@ -69,37 +72,44 @@ impl SimpleAsyncConnection for WasmSqliteConnection { impl AsyncConnection for WasmSqliteConnection { type Backend = WasmSqlite; type TransactionManager = AnsiTransactionManager; - type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult>; - type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult>>; - type Stream<'conn, 'query> = BoxStream<'static, QueryResult>>; - type Row<'conn, 'query> = SqliteRow<'conn, 'query>; + type ExecuteFuture<'conn, 'query> = LocalBoxFuture<'conn, QueryResult>; + type LoadFuture<'conn, 'query> = LocalBoxFuture<'conn, QueryResult>>; + type Stream<'conn, 'query> = LocalBoxStream<'conn, QueryResult>>; + type Row<'conn, 'query> = OwnedSqliteRow; async fn establish(database_url: &str) -> diesel::prelude::ConnectionResult { WasmSqliteConnection::establish_inner(database_url).await } - fn load<'conn, 'query, T>(&'conn mut self, _source: T) -> Self::LoadFuture<'conn, 'query> + fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query> where - T: AsQuery + 'query, - T::Query: QueryFragment + QueryId + 'query, + T: AsQuery, + T::Query: QueryFragment + QueryId, { - todo!() + let query = source.as_query(); + self.with_prepared_statement(query, |_, statement| async move { + Ok(StatementStream::new(statement).stream()) + }) } fn execute_returning_count<'conn, 'query, T>( &'conn mut self, - _source: T, + query: T, ) -> Self::ExecuteFuture<'conn, 'query> where T: QueryFragment + QueryId + 'query, { - todo!() + self.with_prepared_statement(query, |conn, statement| async move { + statement.run().await.map(|_| { + conn.rows_affected_by_last_query() + }) + }) } fn transaction_state( &mut self, ) -> &mut >::TransactionStateData{ - todo!() + &mut self.transaction_manager } fn instrumentation(&mut self) -> &mut dyn Instrumentation { @@ -111,32 +121,7 @@ impl AsyncConnection for WasmSqliteConnection { } } -/* -impl LoadConnection for WasmSqliteConnection { - type Cursor<'conn, 'query> = StatementIterator<'conn, 'query>; - type Row<'conn, 'query> = self::row::SqliteRow<'conn, 'query>; - - fn load<'conn, 'query, T>( - &'conn mut self, - source: T, - ) -> QueryResult> - where - T: Query + QueryFragment + QueryId + 'query, - Self::Backend: QueryMetadata, - { - let statement = self.prepared_query(source)?; - Ok(StatementIterator::new(statement)) - } -} -*/ -/* -impl WithMetadataLookup for WasmSqliteConnection { - fn metadata_lookup(&mut self) -> &mut ::MetadataLookup { - &mut self.metadata_lookup - } -} - */ #[cfg(feature = "r2d2")] impl crate::r2d2::R2D2Connection for crate::sqlite::SqliteConnection { @@ -243,39 +228,53 @@ impl WasmSqliteConnection { } } - async fn prepared_query<'conn, 'query, T>( + fn with_prepared_statement<'conn, Q, F, R>( &'conn mut self, - source: T, - ) -> QueryResult> + query: Q, + callback: impl (FnOnce(&'conn mut RawConnection, StatementUse<'conn>) -> F) + 'conn + ) -> LocalBoxFuture<'_, QueryResult> where - T: QueryFragment + QueryId + 'query, + Q: QueryFragment + QueryId, + F: Future>, { - let raw_connection = &self.raw_connection; - let cache = &mut self.statement_cache; - let maybe_type_id = T::query_id(); - let cache_key = StatementCacheKey::for_source(maybe_type_id, &source, &[], &WasmSqlite)?; - - - let is_safe_to_cache_prepared = source.is_safe_to_cache_prepared(&WasmSqlite)?; - let mut qb = SqliteQueryBuilder::new(); - let sql = source.to_sql(&mut qb, &WasmSqlite).map(|()| qb.finish())?; + let WasmSqliteConnection { + ref mut raw_connection, + ref mut statement_cache, + ref mut instrumentation, + .. + } = self; - let statement = cache.cached_prepared_statement( - cache_key, - sql, - is_safe_to_cache_prepared, - &[], - raw_connection.clone(), - &self.instrumentation, - ).await?.0; // Cloned RawConnection is dropped here + let maybe_type_id = Q::query_id(); + let instrumentation = instrumentation.clone(); - - Ok(StatementUse::bind(statement, source, self.instrumentation.as_ref())?) + let cache_key = StatementCacheKey::for_source(maybe_type_id, &query, &[], &WasmSqlite); + let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&WasmSqlite); + + // C put this in box to avoid virtual fn call for SQLite C + // not sure if that still applies here + let query = Box::new(query); + let mut bind_collector = SqliteBindCollector::new(); + let bind_collector = query.collect_binds(&mut bind_collector, &mut (), &WasmSqlite).map(|()| bind_collector.moveable()); + + let mut qb = SqliteQueryBuilder::new(); + let sql = query.to_sql(&mut qb, &WasmSqlite).map(|()| qb.finish()); + async move { + let (statement, conn) = statement_cache.cached_prepared_statement( + cache_key?, + sql?, + is_safe_to_cache_prepared?, + &[], + raw_connection, + &instrumentation, + ).await?; // Cloned RawConnection is dropped here + let statement = StatementUse::bind(statement, bind_collector?, instrumentation)?; + callback(conn, statement).await + }.boxed_local() } async fn establish_inner(database_url: &str) -> Result { - use diesel::result::ConnectionError::CouldntSetupConfiguration; + // use diesel::result::ConnectionError::CouldntSetupConfiguration; let raw_connection = RawConnection::establish(database_url).await.unwrap(); let sqlite3 = crate::get_sqlite().await; @@ -284,8 +283,7 @@ impl WasmSqliteConnection { Ok(Self { statement_cache: StmtCache::new(), raw_connection, - transaction_state: AnsiTransactionManager::default(), - metadata_lookup: (), + transaction_manager: AnsiTransactionManager::default(), instrumentation: Arc::new(Mutex::new(None)), }) } diff --git a/diesel-wasm-sqlite/src/connection/raw.rs b/diesel-wasm-sqlite/src/connection/raw.rs old mode 100644 new mode 100755 index dd11abf25..65d3cb4bc --- a/diesel-wasm-sqlite/src/connection/raw.rs +++ b/diesel-wasm-sqlite/src/connection/raw.rs @@ -1,3 +1,6 @@ +#![allow(dead_code)] +// functions are needed, but missing functionality means they aren't used yet. + use crate::{ sqlite_types::{SqliteFlags, SqliteOpenFlags}, SqliteType, WasmSqlite, WasmSqliteError, @@ -159,7 +162,7 @@ impl RawConnection { } #[async_trait::async_trait(?Send)] -impl diesel_async::stmt_cache::PrepareCallback for RawConnection { +impl diesel_async::stmt_cache::PrepareCallback for &'_ mut RawConnection { async fn prepare( self, sql: &str, diff --git a/diesel-wasm-sqlite/src/connection/row.rs b/diesel-wasm-sqlite/src/connection/row.rs index 41c2c132d..d63a29783 100644 --- a/diesel-wasm-sqlite/src/connection/row.rs +++ b/diesel-wasm-sqlite/src/connection/row.rs @@ -12,20 +12,20 @@ use diesel::{ }; #[allow(missing_debug_implementations)] -pub struct SqliteRow<'stmt, 'query> { - pub(super) inner: Rc>>, +pub struct SqliteRow<'stmt> { + pub(super) inner: Rc>>, pub(super) field_count: usize, } -pub(super) enum PrivateSqliteRow<'stmt, 'query> { - Direct(StatementUse<'stmt, 'query>), +pub(super) enum PrivateSqliteRow<'stmt> { + Direct(StatementUse<'stmt>), Duplicated { values: Vec>, column_names: Rc<[Option]>, }, } -impl<'stmt> IntoOwnedRow<'stmt, WasmSqlite> for SqliteRow<'stmt, '_> { +impl<'stmt> IntoOwnedRow<'stmt, WasmSqlite> for SqliteRow<'stmt> { type OwnedRow = OwnedSqliteRow; type Cache = Option]>>; @@ -35,11 +35,11 @@ impl<'stmt> IntoOwnedRow<'stmt, WasmSqlite> for SqliteRow<'stmt, '_> { } } -impl<'stmt, 'query> PrivateSqliteRow<'stmt, 'query> { +impl<'stmt> PrivateSqliteRow<'stmt> { pub(super) fn duplicate( &mut self, column_names: &mut Option]>>, - ) -> PrivateSqliteRow<'stmt, 'query> { + ) -> PrivateSqliteRow<'stmt> { match self { PrivateSqliteRow::Direct(stmt) => { let column_names = if let Some(column_names) = column_names { @@ -129,10 +129,10 @@ impl<'stmt, 'query> PrivateSqliteRow<'stmt, 'query> { } } -impl<'stmt, 'query> RowSealed for SqliteRow<'stmt, 'query> {} +impl<'stmt> RowSealed for SqliteRow<'stmt> {} -impl<'stmt, 'query> Row<'stmt, WasmSqlite> for SqliteRow<'stmt, 'query> { - type Field<'field> = SqliteField<'field, 'field> where 'stmt: 'field, Self: 'field; +impl<'stmt> Row<'stmt, WasmSqlite> for SqliteRow<'stmt> { + type Field<'field> = SqliteField<'field> where 'stmt: 'field, Self: 'field; type InnerPartialRow = Self; fn field_count(&self) -> usize { @@ -156,7 +156,7 @@ impl<'stmt, 'query> Row<'stmt, WasmSqlite> for SqliteRow<'stmt, 'query> { } } -impl<'stmt, 'query> RowIndex for SqliteRow<'stmt, 'query> { +impl<'stmt> RowIndex for SqliteRow<'stmt> { fn idx(&self, idx: usize) -> Option { if idx < self.field_count { Some(idx) @@ -166,7 +166,7 @@ impl<'stmt, 'query> RowIndex for SqliteRow<'stmt, 'query> { } } -impl<'stmt, 'idx, 'query> RowIndex<&'idx str> for SqliteRow<'stmt, 'query> { +impl<'stmt, 'idx> RowIndex<&'idx str> for SqliteRow<'stmt> { fn idx(&self, field_name: &'idx str) -> Option { match &mut *self.inner.borrow_mut() { PrivateSqliteRow::Direct(stmt) => stmt.index_for_column_name(field_name), @@ -178,12 +178,12 @@ impl<'stmt, 'idx, 'query> RowIndex<&'idx str> for SqliteRow<'stmt, 'query> { } #[allow(missing_debug_implementations)] -pub struct SqliteField<'stmt, 'query> { - pub(super) row: Ref<'stmt, PrivateSqliteRow<'stmt, 'query>>, +pub struct SqliteField<'stmt> { + pub(super) row: Ref<'stmt, PrivateSqliteRow<'stmt>>, pub(super) col_idx: i32, } -impl<'stmt, 'query> Field<'stmt, WasmSqlite> for SqliteField<'stmt, 'query> { +impl<'stmt> Field<'stmt, WasmSqlite> for SqliteField<'stmt> { fn field_name(&self) -> Option<&str> { match &*self.row { PrivateSqliteRow::Direct(stmt) => stmt.field_name(self.col_idx), diff --git a/diesel-wasm-sqlite/src/connection/sqlite_value.rs b/diesel-wasm-sqlite/src/connection/sqlite_value.rs index e94e77217..7510a892b 100644 --- a/diesel-wasm-sqlite/src/connection/sqlite_value.rs +++ b/diesel-wasm-sqlite/src/connection/sqlite_value.rs @@ -2,7 +2,7 @@ use std::cell::Ref; -use crate::ffi::{self, SQLiteCompatibleType}; +use crate::ffi::SQLiteCompatibleType; use crate::{backend::SqliteType, sqlite_types}; use wasm_bindgen::JsValue; @@ -14,11 +14,11 @@ use super::row::PrivateSqliteRow; /// Use existing `FromSql` implementations to convert this into /// rust values #[allow(missing_debug_implementations, missing_copy_implementations)] -pub struct SqliteValue<'row, 'stmt, 'query> { +pub struct SqliteValue<'row, 'stmt> { // This field exists to ensure that nobody // can modify the underlying row while we are // holding a reference to some row value here - _row: Option>>, + _row: Option>>, // we extract the raw value pointer as part of the constructor // to safe the match statements for each method // According to benchmarks this leads to a ~20-30% speedup @@ -39,11 +39,11 @@ pub(super) struct OwnedSqliteValue { // see https://www.sqlite.org/c3ref/value.html unsafe impl Send for OwnedSqliteValue {} -impl<'row, 'stmt, 'query> SqliteValue<'row, 'stmt, 'query> { +impl<'row, 'stmt> SqliteValue<'row, 'stmt> { pub(super) fn new( - row: Ref<'row, PrivateSqliteRow<'stmt, 'query>>, + row: Ref<'row, PrivateSqliteRow<'stmt>>, col_idx: i32, - ) -> Option> { + ) -> Option> { let value = match &*row { PrivateSqliteRow::Direct(stmt) => stmt.column_value(col_idx)?, PrivateSqliteRow::Duplicated { values, .. } => values @@ -67,7 +67,7 @@ impl<'row, 'stmt, 'query> SqliteValue<'row, 'stmt, 'query> { pub(super) fn from_owned_row( row: &'row OwnedSqliteRow, col_idx: i32, - ) -> Option> { + ) -> Option> { let value = row .values .get(col_idx as usize) @@ -81,38 +81,39 @@ impl<'row, 'stmt, 'query> SqliteValue<'row, 'stmt, 'query> { Some(ret) } } + /* + pub(crate) fn parse_string(&self, f: impl FnOnce(String) -> R) -> R { + let sqlite3 = crate::get_sqlite_unchecked(); + let s = sqlite3.value_text(&self.value); + f(s) + } - pub(crate) fn parse_string(&self, f: impl FnOnce(String) -> R) -> R { - let sqlite3 = crate::get_sqlite_unchecked(); - let s = sqlite3.value_text(&self.value); - f(s) - } - - // TODO: Wasm bindgen can't work with references yet - // not sure if this will effect perf - pub(crate) fn read_text(&self) -> String { - self.parse_string(|s| s) - } + // TODO: Wasm bindgen can't work with references yet + // not sure if this will effect perf + pub(crate) fn read_text(&self) -> String { + self.parse_string(|s| s) + } - pub(crate) fn read_blob(&self) -> Vec { - let sqlite3 = crate::get_sqlite_unchecked(); - sqlite3.value_blob(&self.value) - } + pub(crate) fn read_blob(&self) -> Vec { + let sqlite3 = crate::get_sqlite_unchecked(); + sqlite3.value_blob(&self.value) + } - pub(crate) fn read_integer(&self) -> i32 { - let sqlite3 = crate::get_sqlite_unchecked(); - sqlite3.value_int(&self.value) - } + pub(crate) fn read_integer(&self) -> i32 { + let sqlite3 = crate::get_sqlite_unchecked(); + sqlite3.value_int(&self.value) + } - pub(crate) fn read_long(&self) -> i64 { - let sqlite3 = crate::get_sqlite_unchecked(); - sqlite3.value_int64(&self.value) - } + pub(crate) fn read_long(&self) -> i64 { + let sqlite3 = crate::get_sqlite_unchecked(); + sqlite3.value_int64(&self.value) + } - pub(crate) fn read_double(&self) -> f64 { - let sqlite3 = crate::get_sqlite_unchecked(); - sqlite3.value_double(&self.value) - } + pub(crate) fn read_double(&self) -> f64 { + let sqlite3 = crate::get_sqlite_unchecked(); + sqlite3.value_double(&self.value) + } + */ /// Get the type of the value as returned by sqlite pub fn value_type(&self) -> Option { diff --git a/diesel-wasm-sqlite/src/connection/statement_iterator.rs b/diesel-wasm-sqlite/src/connection/statement_iterator.rs deleted file mode 100644 index 393ec9e47..000000000 --- a/diesel-wasm-sqlite/src/connection/statement_iterator.rs +++ /dev/null @@ -1,172 +0,0 @@ -use std::cell::RefCell; -use std::rc::Rc; - -use super::row::{PrivateSqliteRow, SqliteRow}; -use super::stmt::StatementUse; -use crate::result::QueryResult; - -#[allow(missing_debug_implementations)] -pub struct StatementIterator<'stmt, 'query> { - inner: PrivateStatementIterator<'stmt, 'query>, - column_names: Option]>>, - field_count: usize, -} - -impl<'stmt, 'query> StatementIterator<'stmt, 'query> { - #[cold] - #[allow(unsafe_code)] // call to unsafe function - fn handle_duplicated_row_case( - outer_last_row: &mut Rc>>, - column_names: &mut Option]>>, - field_count: usize, - ) -> Option>> { - // We don't own the statement. There is another existing reference, likely because - // a user stored the row in some long time container before calling next another time - // In this case we copy out the current values into a temporary store and advance - // the statement iterator internally afterwards - let last_row = { - let mut last_row = match outer_last_row.try_borrow_mut() { - Ok(o) => o, - Err(_e) => { - return Some(Err(crate::result::Error::DeserializationError( - "Failed to reborrow row. Try to release any `SqliteField` or `SqliteValue` \ - that exists at this point" - .into(), - ))); - } - }; - let last_row = &mut *last_row; - let duplicated = last_row.duplicate(column_names); - std::mem::replace(last_row, duplicated) - }; - if let PrivateSqliteRow::Direct(mut stmt) = last_row { - let res = unsafe { - // This is actually safe here as we've already - // performed one step. For the first step we would have - // used `PrivateStatementIterator::NotStarted` where we don't - // have access to `PrivateSqliteRow` at all - stmt.step(false) - }; - *outer_last_row = Rc::new(RefCell::new(PrivateSqliteRow::Direct(stmt))); - match res { - Err(e) => Some(Err(e)), - Ok(false) => None, - Ok(true) => Some(Ok(SqliteRow { - inner: Rc::clone(outer_last_row), - field_count, - })), - } - } else { - // any other state than `PrivateSqliteRow::Direct` is invalid here - // and should not happen. If this ever happens this is a logic error - // in the code above - unreachable!( - "You've reached an impossible internal state. \ - If you ever see this error message please open \ - an issue at https://github.com/diesel-rs/diesel \ - providing example code how to trigger this error." - ) - } - } -} - -enum PrivateStatementIterator<'stmt, 'query> { - NotStarted(Option>), - Started(Rc>>), -} - -impl<'stmt, 'query> StatementIterator<'stmt, 'query> { - pub fn new(stmt: StatementUse<'stmt, 'query>) -> StatementIterator<'stmt, 'query> { - Self { - inner: PrivateStatementIterator::NotStarted(Some(stmt)), - column_names: None, - field_count: 0, - } - } -} - -impl<'stmt, 'query> Iterator for StatementIterator<'stmt, 'query> { - type Item = QueryResult>; - - #[allow(unsafe_code)] // call to unsafe function - fn next(&mut self) -> Option { - use PrivateStatementIterator::{NotStarted, Started}; - match &mut self.inner { - NotStarted(ref mut stmt @ Some(_)) => { - let mut stmt = stmt - .take() - .expect("It must be there because we checked that above"); - let step = unsafe { - // This is safe as we pass `first_step = true` to reset the cached column names - stmt.step(true) - }; - match step { - Err(e) => Some(Err(e)), - Ok(false) => None, - Ok(true) => { - let field_count = stmt.column_count() as usize; - self.field_count = field_count; - let inner = Rc::new(RefCell::new(PrivateSqliteRow::Direct(stmt))); - self.inner = Started(inner.clone()); - Some(Ok(SqliteRow { inner, field_count })) - } - } - } - Started(ref mut last_row) => { - // There was already at least one iteration step - // We check here if the caller already released the row value or not - // by checking if our Rc owns the data or not - if let Some(last_row_ref) = Rc::get_mut(last_row) { - // We own the statement, there is no other reference here. - // This means we don't need to copy out values from the sqlite provided - // datastructures for now - // We don't need to use the runtime borrowing system of the RefCell here - // as we have a mutable reference, so all of this below is checked at compile time - if let PrivateSqliteRow::Direct(ref mut stmt) = last_row_ref.get_mut() { - let step = unsafe { - // This is actually safe here as we've already - // performed one step. For the first step we would have - // used `PrivateStatementIterator::NotStarted` where we don't - // have access to `PrivateSqliteRow` at all - - stmt.step(false) - }; - match step { - Err(e) => Some(Err(e)), - Ok(false) => None, - Ok(true) => { - let field_count = self.field_count; - Some(Ok(SqliteRow { - inner: Rc::clone(last_row), - field_count, - })) - } - } - } else { - // any other state than `PrivateSqliteRow::Direct` is invalid here - // and should not happen. If this ever happens this is a logic error - // in the code above - unreachable!( - "You've reached an impossible internal state. \ - If you ever see this error message please open \ - an issue at https://github.com/diesel-rs/diesel \ - providing example code how to trigger this error." - ) - } - } else { - Self::handle_duplicated_row_case( - last_row, - &mut self.column_names, - self.field_count, - ) - } - } - NotStarted(_s) => { - // we likely got an error while executing the other - // `NotStarted` branch above. In this case we just want to stop - // iterating here - None - } - } - } -} diff --git a/diesel-wasm-sqlite/src/connection/statement_stream.rs b/diesel-wasm-sqlite/src/connection/statement_stream.rs new file mode 100644 index 000000000..2b44cd5a3 --- /dev/null +++ b/diesel-wasm-sqlite/src/connection/statement_stream.rs @@ -0,0 +1,199 @@ +use std::cell::RefCell; +use std::rc::Rc; + +use super::owned_row::OwnedSqliteRow; +use super::row::{PrivateSqliteRow, SqliteRow}; +use super::stmt::StatementUse; +use diesel::result::QueryResult; +use diesel::row::IntoOwnedRow; +use futures::stream::LocalBoxStream; + +#[allow(missing_debug_implementations)] +pub struct StatementStream<'stmt> { + inner: StatementStreamState<'stmt>, + column_names: Option]>>, + field_count: usize, +} + +impl<'stmt> StatementStream<'stmt> { + #[cold] + async fn handle_duplicated_row_case( + outer_last_row: &mut Rc>>, + column_names: &mut Option]>>, + field_count: usize, + ) -> Option> { + // We don't own the statement. There is another existing reference, likely because + // a user stored the row in some long time container before calling next another time + // In this case we copy out the current values into a temporary store and advance + // the statement iterator internally afterwards + let last_row = { + let mut last_row = match outer_last_row.try_borrow_mut() { + Ok(o) => o, + Err(_e) => { + return Some(Err(diesel::result::Error::DeserializationError( + "Failed to reborrow row. Try to release any `SqliteField` or `SqliteValue` \ + that exists at this point" + .into(), + ))); + } + }; + let last_row = &mut *last_row; + let duplicated = last_row.duplicate(column_names); + std::mem::replace(last_row, duplicated) + }; + if let PrivateSqliteRow::Direct(mut stmt) = last_row { + let res = stmt.step(false).await; + *outer_last_row = Rc::new(RefCell::new(PrivateSqliteRow::Direct(stmt))); + match res { + Err(e) => Some(Err(e)), + Ok(false) => None, + Ok(true) => Some(Ok(SqliteRow { + inner: Rc::clone(outer_last_row), + field_count, + } + .into_owned(&mut None))), + } + } else { + // any other state than `PrivateSqliteRow::Direct` is invalid here + // and should not happen. If this ever happens this is a logic error + // in the code above + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) + } + } +} + +enum StatementStreamState<'stmt> { + NotStarted(Option>), + Started(Rc>>), +} + +impl<'stmt> StatementStream<'stmt> { + pub fn new(stmt: StatementUse<'stmt>) -> StatementStream<'stmt> { + Self { + inner: StatementStreamState::NotStarted(Some(stmt)), + column_names: None, + field_count: 0, + } + } +} +/// Rolling a custom `Stream` impl on StatementStream was taking too long/tricky +/// so using `futures::unfold`. Rolling a custom `Stream` would probably be better, +/// but performance wise/code-readability sense not very different +impl<'stmt> StatementStream<'stmt> { + pub fn stream(self) -> LocalBoxStream<'stmt, QueryResult> { + use StatementStreamState::{NotStarted, Started}; + let stream = futures::stream::unfold(self, |mut statement| async move { + match statement.inner { + NotStarted(mut stmt @ Some(_)) => { + let mut stmt = stmt + .take() + .expect("It must be there because we checked that above"); + match stmt.step(true).await { + Ok(true) => { + let field_count = stmt.column_count() as usize; + statement.field_count = field_count; + let inner = Rc::new(RefCell::new(PrivateSqliteRow::Direct(stmt))); + let new_inner = inner.clone(); + Some(( + Ok(SqliteRow { inner, field_count }.into_owned(&mut None)), + Self { + inner: Started(new_inner), + ..statement + }, + )) + } + Ok(false) => None, + Err(e) => Some(( + Err(e), + Self { + inner: NotStarted(Some(stmt)), + ..statement + }, + )), + } + // res.poll_next(cx).map(|t| t.flatten()) + } + Started(ref mut last_row) => { + // There was already at least one iteration step + // We check here if the caller already released the row value or not + // by checking if our Rc owns the data or not + if let Some(last_row_ref) = Rc::get_mut(last_row) { + // We own the statement, there is no other reference here. + // This means we don't need to copy out values from the sqlite provided + // datastructures for now + // We don't need to use the runtime borrowing system of the RefCell here + // as we have a mutable reference, so all of this below is checked at compile time + if let PrivateSqliteRow::Direct(ref mut stmt) = last_row_ref.get_mut() { + // This is actually safe here as we've already + // performed one step. For the first step we would have + // used `StatementStreamState::NotStarted` where we don't + // have access to `PrivateSqliteRow` at all + match stmt.step(false).await { + Err(e) => Some(( + Err(e), + Self { + inner: Started(Rc::clone(last_row)), + ..statement + }, + )), + Ok(false) => None, + Ok(true) => { + let field_count = statement.field_count; + Some(( + Ok(SqliteRow { + inner: Rc::clone(last_row), + field_count, + } + .into_owned(&mut None)), + Self { + inner: Started(Rc::clone(last_row)), + ..statement + }, + )) + } + } + } else { + // any other state than `PrivateSqliteRow::Direct` is invalid here + // and should not happen. If this ever happens this is a logic error + // in the code above + unreachable!( + "You've reached an impossible internal state. \ + If you ever see this error message please open \ + an issue at https://github.com/diesel-rs/diesel \ + providing example code how to trigger this error." + ) + } + } else { + let res = Self::handle_duplicated_row_case( + last_row, + &mut statement.column_names, + statement.field_count, + ) + .await; + res.map(|r| { + ( + r, + Self { + inner: Started(Rc::clone(last_row)), + ..statement + }, + ) + }) + } + } + NotStarted(_s) => { + // we likely got an error while executing the other + // `NotStarted` branch above. In this case we just want to stop + // iterating here + None + } + } + }); + Box::pin(stream) + } +} diff --git a/diesel-wasm-sqlite/src/connection/stmt.rs b/diesel-wasm-sqlite/src/connection/stmt.rs index 53fb5ba78..37958f19c 100644 --- a/diesel-wasm-sqlite/src/connection/stmt.rs +++ b/diesel-wasm-sqlite/src/connection/stmt.rs @@ -1,22 +1,21 @@ #![allow(unsafe_code)] //TODO: can probably remove for wa-sqlite -use super::bind_collector::{InternalSqliteBindValue, SqliteBindCollector}; +use super::bind_collector::{OwnedSqliteBindValue, SqliteBindCollectorData}; use super::raw::RawConnection; use super::sqlite_value::OwnedSqliteValue; use crate::ffi::SQLiteCompatibleType; use crate::{ sqlite_types::{self, PrepareOptions, SqlitePrepareFlags}, - SqliteType, WasmSqlite, + SqliteType, }; use diesel::{ connection::{ statement_cache::{MaybeCached, PrepareForCache}, Instrumentation, }, - query_builder::{QueryFragment, QueryId}, - result::{Error::DatabaseError, *}, + result::{Error, QueryResult}, }; use std::cell::OnceCell; -use std::sync::Mutex; +use std::sync::{Arc, Mutex}; use wasm_bindgen::JsValue; @@ -90,7 +89,7 @@ impl Statement { fn bind( &self, _tpe: SqliteType, - value: InternalSqliteBindValue<'_>, + value: OwnedSqliteBindValue, bind_index: i32, ) -> QueryResult { let sqlite3 = crate::get_sqlite_unchecked(); @@ -100,7 +99,7 @@ impl Statement { .bind(&self.inner_statement, bind_index, value.into()) .unwrap(); - // TODO:insipx Pretty sure we can have a simpler implementation here + // TODO:insipx Pretty sure we can have a simpler implementation here vs diesel // making use of `wa-sqlite` `bind` which abstracts over the individual bind functions in // sqlite3. However, not sure how this will work further up the stack. // This might not work because of differences in how serde_json recognizes js types @@ -178,59 +177,41 @@ impl Drop for Statement { // * https://github.com/weiznich/diesel/pull/7 // * https://users.rust-lang.org/t/code-review-for-unsafe-code-in-diesel/66798/ // * https://github.com/rust-lang/unsafe-code-guidelines/issues/194 -struct BoundStatement<'stmt, 'query> { +struct BoundStatement<'stmt> { statement: MaybeCached<'stmt, Statement>, // we need to store the query here to ensure no one does // drop it till the end of the statement // We use a boxed queryfragment here just to erase the // generic type, we use NonNull to communicate // that this is a shared buffer - query: Option + 'query>>, - instrumentation: &'stmt Mutex, + // query: Option>>, + #[allow(unused)] + instrumentation: Arc>, has_error: bool, } -impl<'stmt, 'query> BoundStatement<'stmt, 'query> { - fn bind( +impl<'stmt> BoundStatement<'stmt> { + fn bind( statement: MaybeCached<'stmt, Statement>, - query: T, - instrumentation: &'stmt Mutex, - ) -> QueryResult> - where - T: QueryFragment + QueryId + 'query, - { - // Don't use a trait object here to prevent using a virtual function call - // For sqlite this can introduce a measurable overhead - // Query is boxed here to make sure it won't move in memory anymore, so any bind - // it could output would stay valid. - let query = Box::new(query); - - let mut bind_collector = SqliteBindCollector::new(); - query.collect_binds(&mut bind_collector, &mut (), &WasmSqlite)?; - let SqliteBindCollector { binds } = bind_collector; - + bind_collector: SqliteBindCollectorData, + instrumentation: Arc>, + ) -> QueryResult> { + let SqliteBindCollectorData { binds } = bind_collector; let mut ret = BoundStatement { statement, - query: None, instrumentation, has_error: false, }; ret.bind_buffers(binds)?; - let query = query as Box + 'query>; - ret.query = Some(query); - Ok(ret) } // This is a separated function so that // not the whole constructor is generic over the query type T. // This hopefully prevents binary bloat. - fn bind_buffers( - &mut self, - binds: Vec<(InternalSqliteBindValue<'_>, SqliteType)>, - ) -> QueryResult<()> { + fn bind_buffers(&mut self, binds: Vec<(OwnedSqliteBindValue, SqliteType)>) -> QueryResult<()> { for (bind_idx, (bind, tpe)) in (1..).zip(binds) { // It's safe to call bind here as: // * The type and value matches @@ -252,18 +233,6 @@ impl<'stmt, 'query> BoundStatement<'stmt, 'query> { fn finish_query_with_error(mut self, _e: &Error) { self.has_error = true; - /* - if let Some(q) = self.query { - // it's safe to get a reference from this ptr as it's guaranteed to not be null - let q = unsafe { q.as_ref() }; - self.instrumentation.on_connection_event( - diesel::connection::InstrumentationEvent::FinishQuery { - query: &crate::debug_query(&q), - error: Some(e), - }, - ); - } - */ } // FIXME: [`AsyncDrop`](https://github.com/rust-lang/rust/issues/126482) is a missing feature in rust. @@ -275,7 +244,7 @@ impl<'stmt, 'query> BoundStatement<'stmt, 'query> { } // Eventually replace with `AsyncDrop`: https://github.com/rust-lang/rust/issues/126482 -impl<'stmt, 'query> Drop for BoundStatement<'stmt, 'query> { +impl<'stmt> Drop for BoundStatement<'stmt> { fn drop(&mut self) { let sender = self .statement @@ -283,27 +252,24 @@ impl<'stmt, 'query> Drop for BoundStatement<'stmt, 'query> { .take() .expect("Drop may only be ran once"); let _ = sender.send(self.statement.inner_statement.clone()); - self.query.take(); } } #[allow(missing_debug_implementations)] -pub struct StatementUse<'stmt, 'query> { - statement: BoundStatement<'stmt, 'query>, +pub struct StatementUse<'stmt> { + statement: BoundStatement<'stmt>, column_names: OnceCell>, } -impl<'stmt, 'query> StatementUse<'stmt, 'query> { - pub(super) fn bind( +impl<'stmt> StatementUse<'stmt> { + pub(super) fn bind( statement: MaybeCached<'stmt, Statement>, - query: T, - instrumentation: &'stmt Mutex, - ) -> QueryResult> - where - T: QueryFragment + QueryId + 'query, - { + bind_collector: SqliteBindCollectorData, + instrumentation: Arc>, + ) -> QueryResult> +where { Ok(Self { - statement: BoundStatement::bind(statement, query, instrumentation)?, + statement: BoundStatement::bind(statement, bind_collector, instrumentation)?, column_names: OnceCell::new(), }) } diff --git a/diesel-wasm-sqlite/src/query_builder/mod.rs b/diesel-wasm-sqlite/src/query_builder/mod.rs index 8bf48119b..3489876b8 100644 --- a/diesel-wasm-sqlite/src/query_builder/mod.rs +++ b/diesel-wasm-sqlite/src/query_builder/mod.rs @@ -6,7 +6,7 @@ use diesel::result::QueryResult; mod limit_offset; mod query_fragment_impls; -mod returning; +// mod returning; /// Constructs SQL queries for use with the SQLite backend #[allow(missing_debug_implementations)]