Skip to content

Commit

Permalink
Add Postgres COPY FROM/TO support
Browse files Browse the repository at this point in the history
This commit adds support for PostgreSQL `COPY` commands.

For `COPY FROM` we expose a variant that allows users to configure the
stream manually and write directly to the stream. We also support a
variant that takes essentially a `Vec<Insertable>` (or equivalent batch
insert containers) and uses the binary format to perform a streamed
batch insert.

For `COPY TO` we expose again a variant that allows the user to
configure the stream manually and read directly from the stream. We also
support loading the results directly via the binary protocol into a
iterator of `Queryable` structs.
  • Loading branch information
weiznich committed Mar 1, 2024
1 parent d3f7099 commit eb721e9
Show file tree
Hide file tree
Showing 17 changed files with 1,504 additions and 9 deletions.
2 changes: 1 addition & 1 deletion diesel/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ ipnetwork = ">=0.12.2, <0.21.0"
quickcheck = "1.0.3"

[features]
default = ["with-deprecated", "32-column-tables"]
default = ["postgres"]
extras = ["chrono", "time", "serde_json", "uuid", "network-address", "numeric", "r2d2"]
unstable = ["diesel_derives/nightly"]
large-tables = ["32-column-tables"]
Expand Down
2 changes: 1 addition & 1 deletion diesel/src/expression/bound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::sql_types::{DieselNumericOps, HasSqlType, SqlType};
#[doc(hidden)] // This is used by the `AsExpression` derive
#[derive(Debug, Clone, Copy, DieselNumericOps)]
pub struct Bound<T, U> {
item: U,
pub(crate) item: U,
_marker: PhantomData<T>,
}

Expand Down
2 changes: 1 addition & 1 deletion diesel/src/insertable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ pub trait InsertValues<T: Table, DB: Backend>: QueryFragment<DB> {
#[derive(Debug, Copy, Clone, QueryId)]
#[doc(hidden)]
pub struct ColumnInsertValue<Col, Expr> {
expr: Expr,
pub(crate) expr: Expr,
p: PhantomData<Col>,
}

Expand Down
7 changes: 7 additions & 0 deletions diesel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,10 @@ pub mod dsl {
delete, insert_into, insert_or_ignore_into, replace_into, select, sql_query, update,
};

#[doc(inline)]
#[cfg(feature = "postgres")]
pub use crate::query_builder::functions::copy_in;

#[doc(inline)]
pub use diesel_derives::auto_type;
}
Expand Down Expand Up @@ -683,6 +687,9 @@ pub use crate::prelude::*;
#[doc(inline)]
pub use crate::query_builder::debug_query;
#[doc(inline)]
#[cfg(feature = "postgres")]
pub use crate::query_builder::functions::{copy_in, copy_out};
#[doc(inline)]
pub use crate::query_builder::functions::{
delete, insert_into, insert_or_ignore_into, replace_into, select, sql_query, update,
};
Expand Down
121 changes: 121 additions & 0 deletions diesel/src/pg/connection/copy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
use core::ffi;
use std::io::BufRead;
use std::io::Read;
use std::io::Write;

use super::raw::RawConnection;
use super::result::PgResult;
use crate::QueryResult;

#[allow(missing_debug_implementations)] // `PgConnection` is not debug
pub struct CopyIn<'conn> {
conn: &'conn mut RawConnection,
}

impl<'conn> CopyIn<'conn> {
pub(super) fn new(conn: &'conn mut RawConnection) -> Self {
Self { conn }
}

pub(super) fn finish(self, err: Option<String>) -> QueryResult<()> {
self.conn.finish_copy_in(err)
}
}

impl<'conn> Write for CopyIn<'conn> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.conn.put_copy_data(buf).unwrap();
Ok(buf.len())
}

fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}

#[allow(missing_debug_implementations)] // `PgConnection` is not debug
pub struct CopyOut<'conn> {
conn: &'conn mut RawConnection,
ptr: *mut ffi::c_char,
offset: usize,
len: usize,
result: PgResult,
}

impl<'conn> CopyOut<'conn> {
pub(super) fn new(conn: &'conn mut RawConnection, result: PgResult) -> Self {
Self {
conn,
ptr: std::ptr::null_mut(),
offset: 0,
len: 0,
result,
}
}

#[allow(unsafe_code)] // construct a slice from a raw ptr
pub(crate) fn data_slice(&self) -> &[u8] {
if self.ptr.is_null() {
&[]
} else if self.offset < self.len {
let slice = unsafe { std::slice::from_raw_parts(self.ptr as *const u8, self.len - 1) };
&slice[self.offset..]
} else {
&[]
}
}

pub(crate) fn get_result(&self) -> &PgResult {
&self.result
}
}

impl<'conn> Drop for CopyOut<'conn> {
#[allow(unsafe_code)] // ffi code
fn drop(&mut self) {
if !self.ptr.is_null() {
unsafe { pq_sys::PQfreemem(self.ptr as *mut ffi::c_void) };
self.ptr = std::ptr::null_mut();
}
}
}

impl<'conn> Read for CopyOut<'conn> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let data = self.fill_buf()?;
let len = usize::min(buf.len(), data.len());
buf[..len].copy_from_slice(&data[..len]);
self.consume(len);
Ok(len)
}
}

impl<'conn> BufRead for CopyOut<'conn> {
#[allow(unsafe_code)] // ffi code + ptr arithmetic
fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
if self.data_slice().is_empty() {
unsafe {
if !self.ptr.is_null() {
pq_sys::PQfreemem(self.ptr as *mut ffi::c_void);
self.ptr = std::ptr::null_mut();
}
let len =
pq_sys::PQgetCopyData(self.conn.internal_connection.as_ptr(), &mut self.ptr, 0);
if len >= 0 {
self.len = len as usize + 1;
} else if len == -1 {
self.len = 0;
} else {
let error = self.conn.last_error_message();
return Err(std::io::Error::new(std::io::ErrorKind::Other, error));
}
self.offset = 0;
}
}
Ok(self.data_slice())
}

fn consume(&mut self, amt: usize) {
self.offset = usize::min(self.len, self.offset + amt);
}
}
60 changes: 56 additions & 4 deletions diesel/src/pg/connection/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
pub(super) mod copy;
pub(crate) mod cursor;
mod raw;
mod result;
mod row;
mod stmt;

use self::copy::CopyIn;
use self::copy::CopyOut;
use self::cursor::*;
use self::private::ConnectionAndTransactionManager;
use self::raw::{PgTransactionStatus, RawConnection};
use self::result::PgResult;
use self::stmt::Statement;
use crate::connection::instrumentation::DebugQuery;
use crate::connection::instrumentation::StrQueryHelper;
Expand All @@ -16,6 +18,7 @@ use crate::connection::statement_cache::{MaybeCached, StatementCache};
use crate::connection::*;
use crate::expression::QueryMetadata;
use crate::pg::metadata_lookup::{GetPgMetadataCache, PgMetadataCache};
use crate::pg::query_builder::copy::InternalCopyInQuery;
use crate::pg::{Pg, TransactionBuilder};
use crate::query_builder::bind_collector::RawBytesBindCollector;
use crate::query_builder::*;
Expand All @@ -26,6 +29,12 @@ use std::ffi::CString;
use std::fmt::Debug;
use std::os::raw as libc;

use super::query_builder::copy::CopyInExpression;
use super::query_builder::copy::CopyTarget;
use super::query_builder::copy::CopyTo;

pub(super) use self::result::PgResult;

/// The connection string expected by `PgConnection::establish`
/// should be a PostgreSQL connection string, as documented at
/// <https://www.postgresql.org/docs/9.4/static/libpq-connect.html#LIBPQ-CONNSTRING>
Expand Down Expand Up @@ -393,7 +402,46 @@ impl PgConnection {
TransactionBuilder::new(self)
}

fn with_prepared_query<'conn, T: QueryFragment<Pg> + QueryId, R>(
pub(crate) fn copy_in<S, T>(&mut self, target: S) -> Result<usize, S::Error>
where
S: CopyInExpression<T>,
{
let query = InternalCopyInQuery::new(target);
let res = self.with_prepared_query(query, true, |stmt, binds, conn, source| {
let _res = stmt.execute(&mut conn.raw_connection, &binds, false)?;
let mut copy_in = CopyIn::new(&mut conn.raw_connection);
let r = source.target.callback(&mut copy_in);
copy_in.finish(r.as_ref().err().map(|e| e.to_string()))?;
let next_res = conn.raw_connection.get_next_result()?.expect("exists");
let rows = next_res.rows_affected();
// need to pull out any other result
while let Some(_r) = conn.raw_connection.get_next_result()? {}
// it's important to only return a potential error here as
// we need to ensure that `finish` is called and we pull
// all the results
r?;
Ok::<_, S::Error>(rows)
})?;

Ok(res)
}

pub(crate) fn copy_out<T>(&mut self, command: CopyTo<T>) -> QueryResult<CopyOut<'_>>
where
T: CopyTarget,
{
let res = self.with_prepared_query::<_, _, Error>(
command,
true,
|stmt, binds, conn, _source| {
let res = stmt.execute(&mut conn.raw_connection, &binds, false)?;
Ok(CopyOut::new(&mut conn.raw_connection, res))
},
)?;
Ok(res)
}

fn with_prepared_query<'conn, T, R, E>(
&'conn mut self,
source: T,
execute_returning_count: bool,
Expand All @@ -402,8 +450,12 @@ impl PgConnection {
Vec<Option<Vec<u8>>>,
&'conn mut ConnectionAndTransactionManager,
T,
) -> QueryResult<R>,
) -> QueryResult<R> {
) -> Result<R, E>,
) -> Result<R, E>
where
T: QueryFragment<Pg> + QueryId,
E: From<crate::result::Error>,
{
self.connection_and_transaction_manager
.instrumentation
.on_connection_event(InstrumentationEvent::StartQuery {
Expand Down
38 changes: 37 additions & 1 deletion diesel/src/pg/connection/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use super::result::PgResult;

#[allow(missing_debug_implementations, missing_copy_implementations)]
pub(super) struct RawConnection {
internal_connection: NonNull<PGconn>,
pub(super) internal_connection: NonNull<PGconn>,
}

impl RawConnection {
Expand Down Expand Up @@ -140,6 +140,42 @@ impl RawConnection {
))
}
}

pub(super) fn put_copy_data(&mut self, buf: &[u8]) -> QueryResult<()> {
for c in buf.chunks(i32::MAX as usize) {
let res = unsafe {
pq_sys::PQputCopyData(
self.internal_connection.as_ptr(),
c.as_ptr() as *const i8,
c.len() as i32,
)
};
if res != 1 {
return Err(Error::DatabaseError(
DatabaseErrorKind::Unknown,
Box::new(self.last_error_message()),
));
}
}
Ok(())
}

pub(crate) fn finish_copy_in(&self, err: Option<String>) -> QueryResult<()> {
let error = err.map(CString::new).transpose()?;
let error = error
.as_ref()
.map(|l| l.as_ptr())
.unwrap_or(std::ptr::null());
let ret = unsafe { pq_sys::PQputCopyEnd(self.internal_connection.as_ptr(), error) };
if ret == 1 {
Ok(())
} else {
Err(Error::DatabaseError(
DatabaseErrorKind::Unknown,
Box::new(self.last_error_message()),
))
}
}
}

/// Represents the current in-transaction status of the connection
Expand Down
4 changes: 3 additions & 1 deletion diesel/src/pg/connection/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ impl PgResult {
match result_status {
ExecStatusType::PGRES_SINGLE_TUPLE
| ExecStatusType::PGRES_COMMAND_OK
| ExecStatusType::PGRES_COPY_IN
| ExecStatusType::PGRES_COPY_OUT
| ExecStatusType::PGRES_TUPLES_OK => {
let column_count = unsafe { PQnfields(internal_result.as_ptr()) as usize };
let row_count = unsafe { PQntuples(internal_result.as_ptr()) as usize };
Expand Down Expand Up @@ -138,7 +140,7 @@ impl PgResult {
}
}

pub(super) fn column_type(&self, col_idx: usize) -> NonZeroU32 {
pub(in crate::pg) fn column_type(&self, col_idx: usize) -> NonZeroU32 {
let type_oid = unsafe { PQftype(self.internal_result.as_ptr(), col_idx as libc::c_int) };
NonZeroU32::new(type_oid).expect(
"Got a zero oid from postgres. If you see this error message \
Expand Down
1 change: 1 addition & 0 deletions diesel/src/pg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub use self::query_builder::DistinctOnClause;
pub use self::query_builder::OrderDecorator;
#[doc(inline)]
pub use self::query_builder::PgQueryBuilder;
pub use self::query_builder::{CopyFormat, CopyHeader};
#[doc(inline)]
pub use self::transaction::TransactionBuilder;
#[doc(inline)]
Expand Down
Loading

0 comments on commit eb721e9

Please sign in to comment.