diff --git a/diesel/Cargo.toml b/diesel/Cargo.toml index e23bacc0efdf..088c9f16a7b3 100644 --- a/diesel/Cargo.toml +++ b/diesel/Cargo.toml @@ -46,7 +46,7 @@ ipnetwork = ">=0.12.2, <0.21.0" quickcheck = "1.0.3" [features] -default = ["with-deprecated", "32-column-tables"] +default = ["postgres"] extras = ["chrono", "time", "serde_json", "uuid", "network-address", "numeric", "r2d2"] unstable = ["diesel_derives/nightly"] large-tables = ["32-column-tables"] diff --git a/diesel/src/expression/bound.rs b/diesel/src/expression/bound.rs index 26b53d5c0a65..414716d2251c 100644 --- a/diesel/src/expression/bound.rs +++ b/diesel/src/expression/bound.rs @@ -10,7 +10,7 @@ use crate::sql_types::{DieselNumericOps, HasSqlType, SqlType}; #[doc(hidden)] // This is used by the `AsExpression` derive #[derive(Debug, Clone, Copy, DieselNumericOps)] pub struct Bound { - item: U, + pub(crate) item: U, _marker: PhantomData, } diff --git a/diesel/src/insertable.rs b/diesel/src/insertable.rs index 4747dce5f9c6..e2010457f6ce 100644 --- a/diesel/src/insertable.rs +++ b/diesel/src/insertable.rs @@ -122,7 +122,7 @@ pub trait InsertValues: QueryFragment { #[derive(Debug, Copy, Clone, QueryId)] #[doc(hidden)] pub struct ColumnInsertValue { - expr: Expr, + pub(crate) expr: Expr, p: PhantomData, } diff --git a/diesel/src/lib.rs b/diesel/src/lib.rs index ef8f22404528..be93494a50dd 100644 --- a/diesel/src/lib.rs +++ b/diesel/src/lib.rs @@ -332,6 +332,10 @@ pub mod dsl { delete, insert_into, insert_or_ignore_into, replace_into, select, sql_query, update, }; + #[doc(inline)] + #[cfg(feature = "postgres")] + pub use crate::query_builder::functions::copy_in; + #[doc(inline)] pub use diesel_derives::auto_type; } @@ -683,6 +687,9 @@ pub use crate::prelude::*; #[doc(inline)] pub use crate::query_builder::debug_query; #[doc(inline)] +#[cfg(feature = "postgres")] +pub use crate::query_builder::functions::{copy_in, copy_out}; +#[doc(inline)] pub use crate::query_builder::functions::{ delete, insert_into, insert_or_ignore_into, replace_into, select, sql_query, update, }; diff --git a/diesel/src/pg/connection/copy.rs b/diesel/src/pg/connection/copy.rs new file mode 100644 index 000000000000..a4759129cef1 --- /dev/null +++ b/diesel/src/pg/connection/copy.rs @@ -0,0 +1,121 @@ +use core::ffi; +use std::io::BufRead; +use std::io::Read; +use std::io::Write; + +use super::raw::RawConnection; +use super::result::PgResult; +use crate::QueryResult; + +#[allow(missing_debug_implementations)] // `PgConnection` is not debug +pub struct CopyIn<'conn> { + conn: &'conn mut RawConnection, +} + +impl<'conn> CopyIn<'conn> { + pub(super) fn new(conn: &'conn mut RawConnection) -> Self { + Self { conn } + } + + pub(super) fn finish(self, err: Option) -> QueryResult<()> { + self.conn.finish_copy_in(err) + } +} + +impl<'conn> Write for CopyIn<'conn> { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.conn.put_copy_data(buf).unwrap(); + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } +} + +#[allow(missing_debug_implementations)] // `PgConnection` is not debug +pub struct CopyOut<'conn> { + conn: &'conn mut RawConnection, + ptr: *mut ffi::c_char, + offset: usize, + len: usize, + result: PgResult, +} + +impl<'conn> CopyOut<'conn> { + pub(super) fn new(conn: &'conn mut RawConnection, result: PgResult) -> Self { + Self { + conn, + ptr: std::ptr::null_mut(), + offset: 0, + len: 0, + result, + } + } + + #[allow(unsafe_code)] // construct a slice from a raw ptr + pub(crate) fn data_slice(&self) -> &[u8] { + if self.ptr.is_null() { + &[] + } else if self.offset < self.len { + let slice = unsafe { std::slice::from_raw_parts(self.ptr as *const u8, self.len - 1) }; + &slice[self.offset..] + } else { + &[] + } + } + + pub(crate) fn get_result(&self) -> &PgResult { + &self.result + } +} + +impl<'conn> Drop for CopyOut<'conn> { + #[allow(unsafe_code)] // ffi code + fn drop(&mut self) { + if !self.ptr.is_null() { + unsafe { pq_sys::PQfreemem(self.ptr as *mut ffi::c_void) }; + self.ptr = std::ptr::null_mut(); + } + } +} + +impl<'conn> Read for CopyOut<'conn> { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let data = self.fill_buf()?; + let len = usize::min(buf.len(), data.len()); + buf[..len].copy_from_slice(&data[..len]); + self.consume(len); + Ok(len) + } +} + +impl<'conn> BufRead for CopyOut<'conn> { + #[allow(unsafe_code)] // ffi code + ptr arithmetic + fn fill_buf(&mut self) -> std::io::Result<&[u8]> { + if self.data_slice().is_empty() { + unsafe { + if !self.ptr.is_null() { + pq_sys::PQfreemem(self.ptr as *mut ffi::c_void); + self.ptr = std::ptr::null_mut(); + } + let len = + pq_sys::PQgetCopyData(self.conn.internal_connection.as_ptr(), &mut self.ptr, 0); + if len >= 0 { + self.len = len as usize + 1; + } else if len == -1 { + self.len = 0; + } else { + let error = self.conn.last_error_message(); + return Err(std::io::Error::new(std::io::ErrorKind::Other, error)); + } + self.offset = 0; + } + } + Ok(self.data_slice()) + } + + fn consume(&mut self, amt: usize) { + self.offset = usize::min(self.len, self.offset + amt); + } +} diff --git a/diesel/src/pg/connection/mod.rs b/diesel/src/pg/connection/mod.rs index aa1bec8a55e7..1a8de2a02b9e 100644 --- a/diesel/src/pg/connection/mod.rs +++ b/diesel/src/pg/connection/mod.rs @@ -1,13 +1,15 @@ +pub(super) mod copy; pub(crate) mod cursor; mod raw; mod result; mod row; mod stmt; +use self::copy::CopyIn; +use self::copy::CopyOut; use self::cursor::*; use self::private::ConnectionAndTransactionManager; use self::raw::{PgTransactionStatus, RawConnection}; -use self::result::PgResult; use self::stmt::Statement; use crate::connection::instrumentation::DebugQuery; use crate::connection::instrumentation::StrQueryHelper; @@ -16,6 +18,7 @@ use crate::connection::statement_cache::{MaybeCached, StatementCache}; use crate::connection::*; use crate::expression::QueryMetadata; use crate::pg::metadata_lookup::{GetPgMetadataCache, PgMetadataCache}; +use crate::pg::query_builder::copy::InternalCopyInQuery; use crate::pg::{Pg, TransactionBuilder}; use crate::query_builder::bind_collector::RawBytesBindCollector; use crate::query_builder::*; @@ -26,6 +29,12 @@ use std::ffi::CString; use std::fmt::Debug; use std::os::raw as libc; +use super::query_builder::copy::CopyInExpression; +use super::query_builder::copy::CopyTarget; +use super::query_builder::copy::CopyTo; + +pub(super) use self::result::PgResult; + /// The connection string expected by `PgConnection::establish` /// should be a PostgreSQL connection string, as documented at /// @@ -393,7 +402,46 @@ impl PgConnection { TransactionBuilder::new(self) } - fn with_prepared_query<'conn, T: QueryFragment + QueryId, R>( + pub(crate) fn copy_in(&mut self, target: S) -> Result + where + S: CopyInExpression, + { + let query = InternalCopyInQuery::new(target); + let res = self.with_prepared_query(query, true, |stmt, binds, conn, source| { + let _res = stmt.execute(&mut conn.raw_connection, &binds, false)?; + let mut copy_in = CopyIn::new(&mut conn.raw_connection); + let r = source.target.callback(&mut copy_in); + copy_in.finish(r.as_ref().err().map(|e| e.to_string()))?; + let next_res = conn.raw_connection.get_next_result()?.expect("exists"); + let rows = next_res.rows_affected(); + // need to pull out any other result + while let Some(_r) = conn.raw_connection.get_next_result()? {} + // it's important to only return a potential error here as + // we need to ensure that `finish` is called and we pull + // all the results + r?; + Ok::<_, S::Error>(rows) + })?; + + Ok(res) + } + + pub(crate) fn copy_out(&mut self, command: CopyTo) -> QueryResult> + where + T: CopyTarget, + { + let res = self.with_prepared_query::<_, _, Error>( + command, + true, + |stmt, binds, conn, _source| { + let res = stmt.execute(&mut conn.raw_connection, &binds, false)?; + Ok(CopyOut::new(&mut conn.raw_connection, res)) + }, + )?; + Ok(res) + } + + fn with_prepared_query<'conn, T, R, E>( &'conn mut self, source: T, execute_returning_count: bool, @@ -402,8 +450,12 @@ impl PgConnection { Vec>>, &'conn mut ConnectionAndTransactionManager, T, - ) -> QueryResult, - ) -> QueryResult { + ) -> Result, + ) -> Result + where + T: QueryFragment + QueryId, + E: From, + { self.connection_and_transaction_manager .instrumentation .on_connection_event(InstrumentationEvent::StartQuery { diff --git a/diesel/src/pg/connection/raw.rs b/diesel/src/pg/connection/raw.rs index f20313b41bfd..4caf963b3a56 100644 --- a/diesel/src/pg/connection/raw.rs +++ b/diesel/src/pg/connection/raw.rs @@ -15,7 +15,7 @@ use super::result::PgResult; #[allow(missing_debug_implementations, missing_copy_implementations)] pub(super) struct RawConnection { - internal_connection: NonNull, + pub(super) internal_connection: NonNull, } impl RawConnection { @@ -140,6 +140,42 @@ impl RawConnection { )) } } + + pub(super) fn put_copy_data(&mut self, buf: &[u8]) -> QueryResult<()> { + for c in buf.chunks(i32::MAX as usize) { + let res = unsafe { + pq_sys::PQputCopyData( + self.internal_connection.as_ptr(), + c.as_ptr() as *const i8, + c.len() as i32, + ) + }; + if res != 1 { + return Err(Error::DatabaseError( + DatabaseErrorKind::Unknown, + Box::new(self.last_error_message()), + )); + } + } + Ok(()) + } + + pub(crate) fn finish_copy_in(&self, err: Option) -> QueryResult<()> { + let error = err.map(CString::new).transpose()?; + let error = error + .as_ref() + .map(|l| l.as_ptr()) + .unwrap_or(std::ptr::null()); + let ret = unsafe { pq_sys::PQputCopyEnd(self.internal_connection.as_ptr(), error) }; + if ret == 1 { + Ok(()) + } else { + Err(Error::DatabaseError( + DatabaseErrorKind::Unknown, + Box::new(self.last_error_message()), + )) + } + } } /// Represents the current in-transaction status of the connection diff --git a/diesel/src/pg/connection/result.rs b/diesel/src/pg/connection/result.rs index 3dc22f8fb555..d66b6d72d173 100644 --- a/diesel/src/pg/connection/result.rs +++ b/diesel/src/pg/connection/result.rs @@ -31,6 +31,8 @@ impl PgResult { match result_status { ExecStatusType::PGRES_SINGLE_TUPLE | ExecStatusType::PGRES_COMMAND_OK + | ExecStatusType::PGRES_COPY_IN + | ExecStatusType::PGRES_COPY_OUT | ExecStatusType::PGRES_TUPLES_OK => { let column_count = unsafe { PQnfields(internal_result.as_ptr()) as usize }; let row_count = unsafe { PQntuples(internal_result.as_ptr()) as usize }; @@ -138,7 +140,7 @@ impl PgResult { } } - pub(super) fn column_type(&self, col_idx: usize) -> NonZeroU32 { + pub(in crate::pg) fn column_type(&self, col_idx: usize) -> NonZeroU32 { let type_oid = unsafe { PQftype(self.internal_result.as_ptr(), col_idx as libc::c_int) }; NonZeroU32::new(type_oid).expect( "Got a zero oid from postgres. If you see this error message \ diff --git a/diesel/src/pg/mod.rs b/diesel/src/pg/mod.rs index 952e8e49d217..8c948e737191 100644 --- a/diesel/src/pg/mod.rs +++ b/diesel/src/pg/mod.rs @@ -27,6 +27,7 @@ pub use self::query_builder::DistinctOnClause; pub use self::query_builder::OrderDecorator; #[doc(inline)] pub use self::query_builder::PgQueryBuilder; +pub use self::query_builder::{CopyFormat, CopyHeader}; #[doc(inline)] pub use self::transaction::TransactionBuilder; #[doc(inline)] diff --git a/diesel/src/pg/query_builder/copy/copy_in.rs b/diesel/src/pg/query_builder/copy/copy_in.rs new file mode 100644 index 000000000000..1e8976747978 --- /dev/null +++ b/diesel/src/pg/query_builder/copy/copy_in.rs @@ -0,0 +1,442 @@ +use std::borrow::Cow; +use std::io::Write; +use std::marker::PhantomData; + +use byteorder::NetworkEndian; +use byteorder::WriteBytesExt; + +use super::CommonOptions; +use super::CopyFormat; +use super::CopyTarget; +use crate::expression::bound::Bound; +use crate::insertable::ColumnInsertValue; +use crate::pg::backend::FailedToLookupTypeError; +use crate::pg::connection::copy::CopyIn; +use crate::pg::metadata_lookup::PgMetadataCacheKey; +use crate::pg::Pg; +use crate::pg::PgMetadataLookup; +use crate::query_builder::BatchInsert; +use crate::query_builder::QueryFragment; +use crate::query_builder::QueryId; +use crate::query_builder::ValuesClause; +use crate::serialize::IsNull; +use crate::serialize::ToSql; +use crate::Insertable; +use crate::PgConnection; +use crate::QueryResult; +use crate::{Column, Table}; + +#[derive(Debug, Copy, Clone)] +pub enum CopyHeader { + Set(bool), + Match, +} + +#[derive(Debug, Default)] +pub struct CopyFromOptions { + common: CommonOptions, + default: Option, + header: Option, +} + +impl QueryFragment for CopyFromOptions { + fn walk_ast<'b>( + &'b self, + mut pass: crate::query_builder::AstPass<'_, 'b, Pg>, + ) -> crate::QueryResult<()> { + if self.any_set() { + let mut comma = ""; + pass.push_sql(" WITH ("); + self.common.walk_ast(pass.reborrow(), &mut comma)?; + if let Some(ref default) = self.default { + pass.push_sql(comma); + comma = ", "; + pass.push_sql("DEFAULT '"); + // cannot use binds here :( + pass.push_sql(default); + pass.push_sql("'"); + } + if let Some(ref header) = self.header { + pass.push_sql(comma); + // commented out because rustc complains otherwise + //comma = ", "; + pass.push_sql("HEADER "); + match header { + CopyHeader::Set(true) => pass.push_sql("1"), + CopyHeader::Set(false) => pass.push_sql("0"), + CopyHeader::Match => pass.push_sql("MATCH"), + } + } + + pass.push_sql(")"); + } + Ok(()) + } +} + +impl CopyFromOptions { + fn any_set(&self) -> bool { + self.common.any_set() || self.default.is_some() || self.header.is_some() + } +} + +#[derive(Debug)] +pub struct CopyFrom { + options: CopyFromOptions, + copy_callback: F, + p: PhantomData, +} + +pub(crate) struct InternalCopyInQuery { + pub(crate) target: S, + p: PhantomData, +} + +impl InternalCopyInQuery { + pub(crate) fn new(target: S) -> Self { + Self { + target, + p: PhantomData, + } + } +} + +impl QueryId for InternalCopyInQuery +where + S: CopyInExpression, +{ + const HAS_STATIC_QUERY_ID: bool = false; + type QueryId = (); +} + +impl QueryFragment for InternalCopyInQuery +where + S: CopyInExpression, +{ + fn walk_ast<'b>( + &'b self, + mut pass: crate::query_builder::AstPass<'_, 'b, Pg>, + ) -> crate::QueryResult<()> { + pass.unsafe_to_cache_prepared(); + pass.push_sql("COPY "); + self.target.walk_target(pass.reborrow())?; + pass.push_sql(" FROM STDIN"); + self.target.options().walk_ast(pass.reborrow())?; + // todo: where? + Ok(()) + } +} + +pub trait CopyInExpression { + type Error: From + std::error::Error; + + fn callback(self, copy: &mut CopyIn<'_>) -> Result<(), Self::Error>; + + fn walk_target<'b>( + &'b self, + pass: crate::query_builder::AstPass<'_, 'b, Pg>, + ) -> crate::QueryResult<()>; + + fn options(&self) -> &CopyFromOptions; +} + +impl CopyInExpression for CopyFrom +where + E: From + std::error::Error, + S: CopyTarget, + F: Fn(&mut CopyIn<'_>) -> Result<(), E>, +{ + type Error = E; + + fn callback(self, copy: &mut CopyIn<'_>) -> Result<(), Self::Error> { + (self.copy_callback)(copy) + } + + fn options(&self) -> &CopyFromOptions { + &self.options + } + + fn walk_target<'b>( + &'b self, + pass: crate::query_builder::AstPass<'_, 'b, Pg>, + ) -> crate::QueryResult<()> { + S::walk_target(pass) + } +} + +struct Dummy; + +impl PgMetadataLookup for Dummy { + fn lookup_type(&mut self, type_name: &str, schema: Option<&str>) -> crate::pg::PgTypeMetadata { + let cache_key = PgMetadataCacheKey::new( + schema.map(Into::into).map(Cow::Owned), + Cow::Owned(type_name.into()), + ); + crate::pg::PgTypeMetadata(Err(FailedToLookupTypeError::new_internal(cache_key))) + } +} + +trait CopyFromInsertableHelper { + type Target: CopyTarget; + const COLUMN_COUNT: i16; + + fn write_to_buffer(&self, idx: i16, out: &mut Vec) -> QueryResult; +} + +macro_rules! impl_copy_from_insertable_helper_for_values_clause { + ($( + $Tuple:tt { + $(($idx:tt) -> $T:ident, $ST:ident, $TT:ident,)+ + } + )+) => { + $( + impl CopyFromInsertableHelper for ValuesClause< + ($(ColumnInsertValue<$ST, Bound<$T, $TT>>,)*), + T> + where + T: Table, + $($ST: Column,)* + ($($ST,)*): CopyTarget, + $($TT: ToSql<$T, Pg>,)* + { + type Target = ($($ST,)*); + const COLUMN_COUNT: i16 = $Tuple as i16; + + fn write_to_buffer(&self, idx: i16, out: &mut Vec) -> QueryResult { + use crate::query_builder::ByteWrapper; + use crate::serialize::Output; + + let values = &self.values; + match idx { + $($idx =>{ + let item = &values.$idx.expr.item; + let is_null = ToSql::<$T, Pg>::to_sql( + item, + &mut Output::new( ByteWrapper(out), &mut Dummy as _) + ).map_err(crate::result::Error::SerializationError)?; + return Ok(is_null); + })* + _ => unreachable!(), + } + } + } + + impl<'a, T, $($ST,)* $($T,)* $($TT,)*> CopyFromInsertableHelper for ValuesClause< + ($(ColumnInsertValue<$ST, &'a Bound<$T, $TT>>,)*), + T> + where + T: Table, + $($ST: Column
,)* + ($($ST,)*): CopyTarget, + $($TT: ToSql<$T, Pg>,)* + { + type Target = ($($ST,)*); + const COLUMN_COUNT: i16 = $Tuple as i16; + + fn write_to_buffer(&self, idx: i16, out: &mut Vec) -> QueryResult { + use crate::query_builder::ByteWrapper; + use crate::serialize::Output; + + let values = &self.values; + match idx { + $($idx =>{ + let item = &values.$idx.expr.item; + let is_null = ToSql::<$T, Pg>::to_sql( + item, + &mut Output::new( ByteWrapper(out), &mut Dummy as _) + ).map_err(crate::result::Error::SerializationError)?; + return Ok(is_null); + })* + _ => unreachable!(), + } + } + } + )* + } +} + +diesel_derives::__diesel_for_each_tuple!(impl_copy_from_insertable_helper_for_values_clause); + +#[derive(Debug)] +pub struct InsertableWrapper(I); + +impl CopyInExpression for InsertableWrapper +where + I: Insertable, T, QId, STATIC_QUERY_ID>>, + V: CopyFromInsertableHelper, +{ + type Error = crate::result::Error; + + fn callback(self, copy: &mut CopyIn<'_>) -> Result<(), Self::Error> { + let io_result_mapper = |e| crate::result::Error::DeserializationError(Box::new(e)); + // see https://www.postgresql.org/docs/current/sql-copy.html for + // a description of the binary format + // + // We don't write oids + + // write the header + copy.write_all(&super::COPY_MAGIC_HEADER) + .map_err(io_result_mapper)?; + copy.write_i32::(0) + .map_err(io_result_mapper)?; + copy.write_i32::(0) + .map_err(io_result_mapper)?; + // write the data + // we reuse the same buffer here again and again + // as we expect the data to be "similar" + // this skips reallocating + let mut buffer = Vec::::new(); + let values = self.0.values(); + for i in values.values { + // column count + buffer + .write_i16::(V::COLUMN_COUNT) + .map_err(io_result_mapper)?; + for idx in 0..V::COLUMN_COUNT { + // first write the null indicator as dummy value + buffer + .write_i32::(-1) + .map_err(io_result_mapper)?; + let len_before = buffer.len(); + let is_null = i.write_to_buffer(idx, &mut buffer)?; + if is_null == IsNull::No { + // fill in the length afterwards + let len_after = buffer.len(); + let diff = (len_after - len_before) as i32; + let bytes = i32::to_be_bytes(diff); + for (b, t) in bytes.into_iter().zip(&mut buffer[len_before - 4..]) { + *t = b; + } + } + } + copy.write_all(&buffer).map_err(io_result_mapper)?; + buffer.clear(); + } + // write the trailer + copy.write_i16::(-1) + .map_err(io_result_mapper)?; + Ok(()) + } + + fn options(&self) -> &CopyFromOptions { + &CopyFromOptions { + common: CommonOptions { + format: Some(CopyFormat::Binary), + freeze: None, + delimiter: None, + null: None, + quote: None, + escape: None, + }, + default: None, + header: None, + } + } + + fn walk_target<'b>( + &'b self, + pass: crate::query_builder::AstPass<'_, 'b, Pg>, + ) -> crate::QueryResult<()> { + ::Target::walk_target(pass) + } +} + +#[derive(Debug)] +pub struct CopyInQuery { + table: T, + action: Action, +} + +impl CopyInQuery +where + T: Table, +{ + pub fn from_raw_data(self, _target: C, action: F) -> CopyInQuery> + where + C: CopyTarget
, + F: Fn(&mut CopyIn<'_>) -> Result<(), E>, + { + CopyInQuery { + table: self.table, + action: CopyFrom { + p: PhantomData, + options: Default::default(), + copy_callback: action, + }, + } + } + + pub fn from_insertable(self, insertable: I) -> CopyInQuery> + where + InsertableWrapper: CopyInExpression, + { + CopyInQuery { + table: self.table, + action: InsertableWrapper(insertable), + } + } +} + +impl CopyInQuery> { + pub fn with_format(mut self, format: CopyFormat) -> Self { + self.action.options.common.format = Some(format); + self + } + + pub fn with_freeze(mut self, freeze: bool) -> Self { + self.action.options.common.freeze = Some(freeze); + self + } + + pub fn with_delimiter(mut self, delimiter: char) -> Self { + self.action.options.common.delimiter = Some(delimiter); + self + } + + pub fn with_null(mut self, null: impl Into) -> Self { + self.action.options.common.null = Some(null.into()); + self + } + + pub fn with_quote(mut self, quote: char) -> Self { + self.action.options.common.quote = Some(quote); + self + } + + pub fn with_escape(mut self, escape: char) -> Self { + self.action.options.common.escape = Some(escape); + self + } + + pub fn with_default(mut self, default: impl Into) -> Self { + self.action.options.default = Some(default.into()); + self + } + + pub fn with_header(mut self, header: CopyHeader) -> Self { + self.action.options.header = Some(header); + self + } +} + +impl CopyInQuery +where + A: CopyInExpression, +{ + pub fn execute(self, conn: &mut PgConnection) -> Result { + conn.copy_in::(self.action) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct NotSet; + +pub fn copy_in(table: T) -> CopyInQuery +where + T: Table, +{ + CopyInQuery { + table, + action: NotSet, + } +} diff --git a/diesel/src/pg/query_builder/copy/copy_out.rs b/diesel/src/pg/query_builder/copy/copy_out.rs new file mode 100644 index 000000000000..eaab623f70ff --- /dev/null +++ b/diesel/src/pg/query_builder/copy/copy_out.rs @@ -0,0 +1,338 @@ +use std::io::BufRead; +use std::marker::PhantomData; + +use super::CommonOptions; +use super::CopyFormat; +use super::CopyTarget; +use crate::deserialize::FromSqlRow; +use crate::pg::connection::copy::CopyOut; +use crate::pg::connection::PgResult; +use crate::pg::value::TypeOidLookup; +use crate::pg::Pg; +use crate::pg::PgValue; +use crate::query_builder::QueryFragment; +use crate::query_builder::QueryId; +use crate::row; +use crate::row::Field; +use crate::row::PartialRow; +use crate::row::Row; +use crate::row::RowIndex; +use crate::row::RowSealed; +use crate::PgConnection; +use crate::QueryResult; + +#[derive(Default, Debug)] +pub struct CopyToOptions { + common: CommonOptions, +} + +impl CopyToOptions { + fn any_set(&self) -> bool { + self.common.any_set() + } +} + +impl QueryFragment for CopyToOptions { + fn walk_ast<'b>( + &'b self, + mut pass: crate::query_builder::AstPass<'_, 'b, Pg>, + ) -> crate::QueryResult<()> { + if self.any_set() { + let mut comma = ""; + pass.push_sql(" WITH ("); + self.common.walk_ast(pass.reborrow(), &mut comma)?; + + pass.push_sql(")"); + } + Ok(()) + } +} + +#[derive(Debug)] +pub(crate) struct CopyTo { + options: CopyToOptions, + p: PhantomData, +} + +impl QueryId for CopyTo +where + S: CopyTarget, +{ + type QueryId = (); + + const HAS_STATIC_QUERY_ID: bool = false; +} + +impl QueryFragment for CopyTo +where + S: CopyTarget, +{ + fn walk_ast<'b>( + &'b self, + mut pass: crate::query_builder::AstPass<'_, 'b, Pg>, + ) -> crate::QueryResult<()> { + pass.unsafe_to_cache_prepared(); + pass.push_sql("COPY "); + S::walk_target(pass.reborrow())?; + pass.push_sql(" TO STDOUT"); + self.options.walk_ast(pass.reborrow())?; + Ok(()) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct NotSet; + +pub trait CopyOutMarker: Sized { + fn setup_options(q: CopyOutQuery) -> CopyOutQuery; +} + +impl CopyOutMarker for NotSet { + fn setup_options(q: CopyOutQuery) -> CopyOutQuery { + CopyOutQuery { + target: q.target, + options: CopyToOptions::default(), + } + } +} +impl CopyOutMarker for CopyToOptions { + fn setup_options(q: CopyOutQuery) -> CopyOutQuery { + q + } +} + +#[derive(Debug)] +pub struct CopyOutQuery { + target: T, + options: O, +} + +struct CopyRow<'a> { + buffers: Vec>, + result: &'a PgResult, +} + +struct CopyField<'a> { + field: &'a Option<&'a [u8]>, + result: &'a PgResult, + col_idx: usize, +} + +impl<'f> Field<'f, Pg> for CopyField<'f> { + fn field_name(&self) -> Option<&str> { + None + } + + fn value(&self) -> Option<::RawValue<'_>> { + let value = self.field.as_deref()?; + Some(PgValue::new_internal(value, self)) + } +} + +impl<'a> TypeOidLookup for CopyField<'a> { + fn lookup(&self) -> std::num::NonZeroU32 { + self.result.column_type(self.col_idx) + } +} + +impl RowSealed for CopyRow<'_> {} + +impl RowIndex for CopyRow<'_> { + fn idx(&self, idx: usize) -> Option { + if idx < self.field_count() { + Some(idx) + } else { + None + } + } +} + +impl<'a> RowIndex<&'a str> for CopyRow<'_> { + fn idx(&self, _idx: &'a str) -> Option { + None + } +} + +impl<'a> Row<'a, Pg> for CopyRow<'_> { + type Field<'f> = CopyField<'f> + where + 'a: 'f, + Self: 'f; + + type InnerPartialRow = Self; + + fn field_count(&self) -> usize { + self.buffers.len() + } + + fn get<'b, I>(&'b self, idx: I) -> Option> + where + 'a: 'b, + Self: RowIndex, + { + let idx = self.idx(idx)?; + let buffer = self.buffers.get(idx)?; + Some(CopyField { + field: buffer, + result: self.result, + col_idx: idx, + }) + } + + fn partial_row( + &self, + range: std::ops::Range, + ) -> row::PartialRow<'_, Self::InnerPartialRow> { + PartialRow::new(self, range) + } +} + +impl CopyOutQuery +where + T: CopyTarget, +{ + pub fn load<'a, U>( + self, + conn: &'a mut PgConnection, + ) -> QueryResult> + 'a> + where + U: FromSqlRow, + { + let io_result_mapper = |e| crate::result::Error::DeserializationError(Box::new(e)); + + let command = CopyTo { + p: PhantomData::, + options: CopyToOptions { + common: CommonOptions { + format: Some(CopyFormat::Binary), + ..Default::default() + }, + }, + }; + // see https://www.postgresql.org/docs/current/sql-copy.html for + // a description of the binary format + // + // We don't write oids + + let mut out = conn.copy_out(command)?; + out.fill_buf().map_err(io_result_mapper)?; + let buffer = out.data_slice(); + if &buffer[..super::COPY_MAGIC_HEADER.len()] != super::COPY_MAGIC_HEADER { + return Err(crate::result::Error::DeserializationError( + "Unexpected protocol header".into(), + )); + } + // we care only about bit 16-31 here, so we can just skip the bytes in between + let flags_backward_incompatible = i16::from_be_bytes( + (&buffer[super::COPY_MAGIC_HEADER.len() + 2..super::COPY_MAGIC_HEADER.len() + 4]) + .try_into() + .expect("Exactly 2 byte"), + ); + if flags_backward_incompatible != 0 { + return Err(crate::result::Error::DeserializationError( + format!("Unexpected flag value: {flags_backward_incompatible:x}").into(), + )); + } + let header_size = i32::from_be_bytes( + (&buffer[super::COPY_MAGIC_HEADER.len() + 4..super::COPY_MAGIC_HEADER.len() + 8]) + .try_into() + .expect("Exactly 4 byte"), + ); + out.consume(super::COPY_MAGIC_HEADER.len() + 8 + header_size as usize); + let mut len = None; + Ok(std::iter::from_fn(move || { + if let Some(len) = len { + out.consume(len); + if let Err(e) = out.fill_buf().map_err(io_result_mapper) { + return Some(Err(e)); + } + } + let buffer = out.data_slice(); + len = Some(buffer.len()); + let tuple_count = + i16::from_be_bytes((&buffer[..2]).try_into().expect("Exactly 2 bytes")); + if tuple_count > 0 { + let mut buffers = Vec::with_capacity(tuple_count as usize); + let mut offset = 2; + for _t in 0..tuple_count { + let data_size = i32::from_be_bytes( + (&buffer[offset..offset + 4]) + .try_into() + .expect("Exactly 4 bytes"), + ); + if data_size < 0 { + buffers.push(None); + } else { + buffers.push(Some(&buffer[offset + 4..offset + 4 + data_size as usize])); + offset = offset + 4 + data_size as usize; + } + } + let row = CopyRow { + buffers, + result: out.get_result(), + }; + Some(U::build_from_row(&row).map_err(crate::result::Error::DeserializationError)) + } else { + None + } + })) + } +} + +impl CopyOutQuery +where + O: CopyOutMarker, + T: CopyTarget, +{ + pub fn load_raw(self, conn: &mut PgConnection) -> QueryResult> { + let q = O::setup_options(self); + let command = CopyTo { + p: PhantomData::, + options: q.options, + }; + conn.copy_out(command) + } + + pub fn with_format(self, format: CopyFormat) -> CopyOutQuery { + let mut out = O::setup_options(self); + out.options.common.format = Some(format); + out + } + + pub fn with_freeze(self, freeze: bool) -> CopyOutQuery { + let mut out = O::setup_options(self); + out.options.common.freeze = Some(freeze); + out + } + + pub fn with_delimiter(self, delimiter: char) -> CopyOutQuery { + let mut out = O::setup_options(self); + out.options.common.delimiter = Some(delimiter); + out + } + + pub fn with_null(self, null: impl Into) -> CopyOutQuery { + let mut out = O::setup_options(self); + out.options.common.null = Some(null.into()); + out + } + + pub fn with_quote(self, quote: char) -> CopyOutQuery { + let mut out = O::setup_options(self); + out.options.common.quote = Some(quote); + out + } + + pub fn with_escape(self, escape: char) -> CopyOutQuery { + let mut out = O::setup_options(self); + out.options.common.escape = Some(escape); + out + } +} + +pub fn copy_out(target: T) -> CopyOutQuery { + CopyOutQuery { + target, + options: NotSet, + } +} diff --git a/diesel/src/pg/query_builder/copy/mod.rs b/diesel/src/pg/query_builder/copy/mod.rs new file mode 100644 index 000000000000..0a745789d101 --- /dev/null +++ b/diesel/src/pg/query_builder/copy/mod.rs @@ -0,0 +1,158 @@ +use crate::pg::Pg; +use crate::query_builder::nodes::StaticQueryFragment; +use crate::query_builder::ColumnList; +use crate::query_builder::QueryFragment; +use crate::sql_types::SqlType; +use crate::Expression; +use crate::{Column, Table}; + +pub(crate) mod copy_in; +pub(crate) mod copy_out; + +pub(crate) use self::copy_in::{CopyInExpression, InternalCopyInQuery}; +pub(crate) use self::copy_out::CopyTo; + +pub use self::copy_in::CopyHeader; + +const COPY_MAGIC_HEADER: [u8; 11] = [ + 0x50, 0x47, 0x43, 0x4F, 0x50, 0x59, 0x0A, 0xFF, 0x0D, 0x0A, 0x00, +]; + +#[derive(Default, Debug, Copy, Clone)] +pub enum CopyFormat { + #[default] + Text, + Csv, + Binary, +} + +impl CopyFormat { + fn to_sql_format(&self) -> &'static str { + match self { + CopyFormat::Text => "text", + CopyFormat::Csv => "csv", + CopyFormat::Binary => "binary", + } + } +} + +#[derive(Default, Debug)] +struct CommonOptions { + format: Option, + freeze: Option, + delimiter: Option, + null: Option, + quote: Option, + escape: Option, +} + +impl CommonOptions { + fn any_set(&self) -> bool { + self.format.is_some() + || self.freeze.is_some() + || self.delimiter.is_some() + || self.null.is_some() + || self.quote.is_some() + || self.escape.is_some() + } + + fn walk_ast<'b>( + &'b self, + mut pass: crate::query_builder::AstPass<'_, 'b, Pg>, + comma: &mut &'static str, + ) -> crate::QueryResult<()> { + if let Some(format) = self.format { + pass.push_sql(*comma); + *comma = ", "; + pass.push_sql("FORMAT "); + pass.push_sql(format.to_sql_format()); + } + if let Some(freeze) = self.freeze { + pass.push_sql(&format!("{comma}FREEZE {}", freeze as u8)); + *comma = ", "; + } + if let Some(delimiter) = self.delimiter { + pass.push_sql(&format!("{comma}DELIMITER '{delimiter}'")); + *comma = ", "; + } + if let Some(ref null) = self.null { + pass.push_sql(*comma); + *comma = ", "; + pass.push_sql("NULL '"); + // we cannot use binds here :( + pass.push_sql(null); + pass.push_sql("'"); + } + if let Some(quote) = self.quote { + pass.push_sql(&format!("{comma}QUOTE '{quote}'")); + *comma = ", "; + } + if let Some(escape) = self.escape { + pass.push_sql(&format!("{comma}ESCAPE '{escape}'")); + *comma = ", "; + } + Ok(()) + } +} + +pub trait CopyTarget { + type Table: Table; + type SqlType: SqlType; + + fn walk_target<'b>(pass: crate::query_builder::AstPass<'_, 'b, Pg>) -> crate::QueryResult<()>; +} + +impl CopyTarget for T +where + T: Table + StaticQueryFragment, + T::SqlType: SqlType, + T::AllColumns: ColumnList, + T::Component: QueryFragment, +{ + type Table = Self; + type SqlType = T::SqlType; + + fn walk_target<'b>( + mut pass: crate::query_builder::AstPass<'_, 'b, Pg>, + ) -> crate::QueryResult<()> { + T::STATIC_COMPONENT.walk_ast(pass.reborrow())?; + pass.push_sql("("); + T::all_columns().walk_ast(pass.reborrow())?; + pass.push_sql(")"); + Ok(()) + } +} + +macro_rules! copy_target_for_columns { + ($( + $Tuple:tt { + $(($idx:tt) -> $T:ident, $ST:ident, $TT:ident,)+ + } + )+) => { + $( + impl CopyTarget for ($($ST,)*) + where + $($ST: Column
,)* + ($(<$ST as Expression>::SqlType,)*): SqlType, + T: Table + StaticQueryFragment, + T::Component: QueryFragment, + Self: ColumnList + Default, + { + type Table = T; + type SqlType = crate::dsl::SqlTypeOf; + + fn walk_target<'b>( + mut pass: crate::query_builder::AstPass<'_, 'b, Pg>, + ) -> crate::QueryResult<()> { + T::STATIC_COMPONENT.walk_ast(pass.reborrow())?; + pass.push_sql("("); + ::walk_ast(&Self::default(), pass.reborrow())?; + pass.push_sql(")"); + Ok(()) + } + } + )* + } +} + +diesel_derives::__diesel_for_each_tuple!(copy_target_for_columns); diff --git a/diesel/src/pg/query_builder/mod.rs b/diesel/src/pg/query_builder/mod.rs index 3d56825cfee8..37441d938fcf 100644 --- a/diesel/src/pg/query_builder/mod.rs +++ b/diesel/src/pg/query_builder/mod.rs @@ -2,11 +2,13 @@ use super::backend::Pg; use crate::query_builder::QueryBuilder; use crate::result::QueryResult; +pub(crate) mod copy; mod distinct_on; mod limit_offset; pub(crate) mod on_constraint; pub(crate) mod only; mod query_fragment_impls; +pub use self::copy::{CopyFormat, CopyHeader}; pub use self::distinct_on::DistinctOnClause; pub use self::distinct_on::OrderDecorator; diff --git a/diesel/src/query_builder/functions.rs b/diesel/src/query_builder/functions.rs index eb29845a198c..d69e70cc9c4c 100644 --- a/diesel/src/query_builder/functions.rs +++ b/diesel/src/query_builder/functions.rs @@ -603,3 +603,8 @@ pub fn replace_into(target: T) -> IncompleteReplaceStatement { pub fn sql_query>(query: T) -> SqlQuery { SqlQuery::from_sql(query.into()) } + +#[cfg(feature = "postgres")] +pub use crate::pg::query_builder::copy::copy_in::copy_in; +#[cfg(feature = "postgres")] +pub use crate::pg::query_builder::copy::copy_out::copy_out; diff --git a/diesel/src/query_builder/mod.rs b/diesel/src/query_builder/mod.rs index 6da06828c1bb..c39e716c47c4 100644 --- a/diesel/src/query_builder/mod.rs +++ b/diesel/src/query_builder/mod.rs @@ -120,6 +120,7 @@ pub(crate) use self::insert_statement::ColumnList; #[cfg(feature = "postgres_backend")] pub use crate::pg::query_builder::only::Only; +pub(crate) use self::bind_collector::ByteWrapper; use crate::backend::Backend; use crate::result::QueryResult; use std::error::Error; diff --git a/diesel_tests/tests/copy.rs b/diesel_tests/tests/copy.rs new file mode 100644 index 000000000000..cf4b4aeb8e16 --- /dev/null +++ b/diesel_tests/tests/copy.rs @@ -0,0 +1,328 @@ +use crate::schema::*; +use diesel::pg::{CopyFormat, CopyHeader}; +use diesel::prelude::*; +use std::io::{Read, Write}; + +#[test] +fn copy_in_csv() { + let conn = &mut connection(); + + let user_count_query = users::table.count(); + let users = user_count_query.get_result::(conn).unwrap(); + assert_eq!(users, 0); + + let count = diesel::copy_in(users::table) + .from_raw_data(users::table, |copy| { + writeln!(copy, "1,Sean,").unwrap(); + writeln!(copy, "2,Tess,").unwrap(); + diesel::QueryResult::Ok(()) + }) + .with_format(CopyFormat::Csv) + .execute(conn) + .unwrap(); + + assert_eq!(count, 2); + + let users = user_count_query.get_result::(conn).unwrap(); + assert_eq!(users, 2); +} + +#[test] +fn copy_in_text() { + let conn = &mut connection(); + + let user_count_query = users::table.count(); + let users = user_count_query.get_result::(conn).unwrap(); + assert_eq!(users, 0); + + let count = diesel::copy_in(users::table) + .from_raw_data(users::table, |copy| { + writeln!(copy, "1\tSean\t").unwrap(); + writeln!(copy, "2\tTess\t").unwrap(); + diesel::QueryResult::Ok(()) + }) + .with_format(CopyFormat::Text) + .execute(conn) + .unwrap(); + + assert_eq!(count, 2); + + let users = user_count_query.get_result::(conn).unwrap(); + assert_eq!(users, 2); + + // default is text + let count = diesel::copy_in(users::table) + .from_raw_data(users::table, |copy| { + writeln!(copy, "3\tSean\t").unwrap(); + writeln!(copy, "4\tTess\t").unwrap(); + diesel::QueryResult::Ok(()) + }) + .execute(conn) + .unwrap(); + + assert_eq!(count, 2); + + let users = user_count_query.get_result::(conn).unwrap(); + assert_eq!(users, 4); +} + +#[test] +fn copy_in_allows_to_return_error() { + // use a connection without transaction here as otherwise + // we fail the last query + let conn = &mut connection_without_transaction(); + + let user_count_query = users::table.count(); + let users = user_count_query.get_result::(conn).unwrap(); + assert_eq!(users, 0); + + let res = diesel::copy_in(users::table) + .from_raw_data(users::table, |copy| { + writeln!(copy, "1,Sean,").unwrap(); + diesel::QueryResult::Err(diesel::result::Error::RollbackTransaction) + }) + .with_format(CopyFormat::Csv) + .execute(conn); + + assert!(res.is_err()); + + let users = user_count_query.get_result::(conn).unwrap(); + assert_eq!(users, 0); +} + +#[test] +fn copy_in_with_columns() { + let conn = &mut connection(); + + let user_count_query = users::table.count(); + let users = user_count_query.get_result::(conn).unwrap(); + assert_eq!(users, 0); + + let count = diesel::copy_in(users::table) + .from_raw_data((users::name, users::id), |copy| { + writeln!(copy, "Sean\t1").unwrap(); + writeln!(copy, "Tess\t2").unwrap(); + diesel::QueryResult::Ok(()) + }) + .with_format(CopyFormat::Text) + .execute(conn) + .unwrap(); + + assert_eq!(count, 2); + let users = user_count_query.get_result::(conn).unwrap(); + assert_eq!(users, 2); +} + +#[test] +fn copy_in_csv_all_options() { + let conn = &mut connection(); + + let user_count_query = users::table.count(); + let users = user_count_query.get_result::(conn).unwrap(); + assert_eq!(users, 0); + + let count = diesel::copy_in(users::table) + .from_raw_data((users::id, users::name, users::hair_color), |copy| { + // need to send the header here + // as we set header = match below + writeln!(copy, "id;name;hair_color").unwrap(); + writeln!(copy, "1;Sean;").unwrap(); + writeln!(copy, "2;Tess;").unwrap(); + diesel::QueryResult::Ok(()) + }) + .with_format(CopyFormat::Csv) + .with_freeze(false) + .with_delimiter(';') + .with_null("") + .with_quote('"') + .with_escape('\\') + .with_header(CopyHeader::Match) + // that option is new in postgres 16, + // so just skip testing it for now + //.set_default("default") + .execute(conn) + .unwrap(); + + assert_eq!(count, 2); + let users = user_count_query.get_result::(conn).unwrap(); + assert_eq!(users, 2); +} + +#[test] +fn copy_in_from_insertable_struct() { + let conn = &mut connection(); + + #[derive(Insertable)] + #[diesel(table_name = users)] + #[diesel(treat_none_as_default_value = false)] + struct NewUser { + name: &'static str, + hair_color: Option<&'static str>, + } + + let user_count_query = users::table.count(); + let users = user_count_query.get_result::(conn).unwrap(); + assert_eq!(users, 0); + + let users = vec![ + NewUser { + name: "Sean", + hair_color: None, + }, + NewUser { + name: "Tess", + hair_color: Some("green"), + }, + ]; + let count = diesel::copy_in(users::table) + .from_insertable(&users) + .execute(conn) + .unwrap(); + assert_eq!(count, 2); + let user_count = user_count_query.get_result::(conn).unwrap(); + assert_eq!(user_count, 2); + let users = users::table + .select((users::name, users::hair_color)) + .load::<(String, Option)>(conn) + .unwrap(); + + assert_eq!(users[0], ("Sean".to_owned(), None)); + assert_eq!(users[1], ("Tess".to_owned(), Some("green".into()))); +} + +#[test] +fn copy_in_from_insertable_tuple() { + let conn = &mut connection(); + + let user_count_query = users::table.count(); + let users = user_count_query.get_result::(conn).unwrap(); + assert_eq!(users, 0); + + let users = vec![ + (users::name.eq("Sean"), users::hair_color.eq(None)), + (users::name.eq("Tess"), users::hair_color.eq(Some("green"))), + ]; + let count = diesel::copy_in(users::table) + .from_insertable(&users) + .execute(conn) + .unwrap(); + assert_eq!(count, 2); + let user_count = user_count_query.get_result::(conn).unwrap(); + assert_eq!(user_count, 2); + let users = users::table + .select((users::name, users::hair_color)) + .load::<(String, Option)>(conn) + .unwrap(); + + assert_eq!(users[0], ("Sean".to_owned(), None)); + assert_eq!(users[1], ("Tess".to_owned(), Some("green".into()))); +} + +#[test] +fn copy_in_from_insertable_vec() { + let conn = &mut connection(); + + let user_count_query = users::table.count(); + let users = user_count_query.get_result::(conn).unwrap(); + assert_eq!(users, 0); + + let users = vec![ + (users::name.eq("Sean"), users::hair_color.eq(None)), + (users::name.eq("Tess"), users::hair_color.eq(Some("green"))), + ]; + let count = diesel::copy_in(users::table) + .from_insertable(users) + .execute(conn) + .unwrap(); + assert_eq!(count, 2); + let user_count = user_count_query.get_result::(conn).unwrap(); + assert_eq!(user_count, 2); + let users = users::table + .select((users::name, users::hair_color)) + .load::<(String, Option)>(conn) + .unwrap(); + + assert_eq!(users[0], ("Sean".to_owned(), None)); + assert_eq!(users[1], ("Tess".to_owned(), Some("green".into()))); +} + +#[test] +fn copy_out_csv() { + let conn = &mut connection_with_sean_and_tess_in_users_table(); + + let mut out = String::new(); + let mut copy = diesel::copy_out(users::table) + .with_format(CopyFormat::Csv) + .load_raw(conn) + .unwrap(); + copy.read_to_string(&mut out).unwrap(); + + assert_eq!(out, "1,Sean,\n2,Tess,\n"); +} + +#[test] +fn copy_out_text() { + let conn = &mut connection_with_sean_and_tess_in_users_table(); + { + let mut out = String::new(); + let mut copy = diesel::copy_out(users::table) + .with_format(CopyFormat::Text) + .load_raw(conn) + .unwrap(); + copy.read_to_string(&mut out).unwrap(); + assert_eq!(out, "1\tSean\t\\N\n2\tTess\t\\N\n"); + } + let mut out = String::new(); + // default is text + let mut copy = diesel::copy_out(users::table).load_raw(conn).unwrap(); + copy.read_to_string(&mut out).unwrap(); + assert_eq!(out, "1\tSean\t\\N\n2\tTess\t\\N\n"); +} + +#[test] +fn copy_out_csv_all_options() { + let conn = &mut connection_with_sean_and_tess_in_users_table(); + let mut out = String::new(); + let mut copy = diesel::copy_out(users::table) + .with_format(CopyFormat::Csv) + .with_freeze(true) + .with_delimiter(';') + .with_quote('"') + .with_escape('\\') + .with_null("") + .load_raw(conn) + .unwrap(); + + copy.read_to_string(&mut out).unwrap(); + assert_eq!(out, "1;Sean;\n2;Tess;\n"); +} + +#[test] +fn copy_out_queryable() { + let conn = &mut connection_with_sean_and_tess_in_users_table(); + + #[derive(Queryable)] + struct User { + name: String, + hair_color: Option, + } + + let out = diesel::copy_out((users::name, users::hair_color)) + .load::(conn) + .unwrap() + .collect::, _>>() + .unwrap(); + + assert_eq!(out[0].name, "Sean"); + assert_eq!(out[0].hair_color, None); + assert_eq!(out[1].name, "Tess"); + assert_eq!(out[1].hair_color, None); + + // some query afterwards + let name = users::table + .select(users::name) + .filter(users::name.eq("Sean")) + .get_result::(conn) + .unwrap(); + assert_eq!(name, "Sean"); +} diff --git a/diesel_tests/tests/lib.rs b/diesel_tests/tests/lib.rs index 81c78f73186c..f3f0d6ba8463 100644 --- a/diesel_tests/tests/lib.rs +++ b/diesel_tests/tests/lib.rs @@ -14,6 +14,8 @@ mod boxed_queries; mod combination; mod connection; #[cfg(feature = "postgres")] +mod copy; +#[cfg(feature = "postgres")] mod custom_types; mod debug; mod delete;