diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 26819c8..6d67a23 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,6 +32,9 @@ jobs: - run: rustup toolchain install ${{ env.MSRV }} --profile minimal - run: rustup override set ${{ env.MSRV }} - run: rustup show active-toolchain -v + - run: cargo update -p native-tls --precise 0.2.13 # 0.2.14 requires rustc 1.80 + - run: cargo update -p litemap --precise 0.7.4 # 0.7.5 requires rustc 1.81 + - run: cargo update -p zerofrom --precise 0.1.5 # 0.1.6 requires rustc 1.81 - run: cargo build - run: cargo build --no-default-features - run: cargo build --features uuid,time,chrono diff --git a/Cargo.toml b/Cargo.toml index 1bf6b88..d14572e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ license = "MIT OR Apache-2.0" readme = "README.md" edition = "2021" # update `derive/Cargo.toml` and CI if changed +# TODO: after bumping to v1.80, remove `--precise` in the "msrv" CI job rust-version = "1.73.0" [lints.rust] @@ -70,6 +71,8 @@ uuid = ["dep:uuid"] time = ["dep:time"] lz4 = ["dep:lz4_flex", "dep:cityhash-rs"] chrono = ["dep:chrono"] +futures03 = [] + ## TLS native-tls = ["dep:hyper-tls"] # ext: native-tls-alpn diff --git a/benches/README.md b/benches/README.md index bd2ccf9..d39bc8a 100644 --- a/benches/README.md +++ b/benches/README.md @@ -23,8 +23,7 @@ Then upload the `perf.script` file to [Firefox Profiler](https://profiler.firefo These benchmarks are run against a real ClickHouse server, so it must be started: ```bash -docker run -d -p 8123:8123 -p 9000:9000 --name ch clickhouse/clickhouse-server - +docker compose up -d cargo bench --bench ``` diff --git a/benches/select.rs b/benches/select.rs index 316015e..89836ba 100644 --- a/benches/select.rs +++ b/benches/select.rs @@ -14,7 +14,10 @@ use hyper::{ }; use serde::Deserialize; -use clickhouse::{error::Result, Client, Compression, Row}; +use clickhouse::{ + error::{Error, Result}, + Client, Compression, Row, +}; mod common; @@ -52,8 +55,7 @@ fn select(c: &mut Criterion) { let _server = common::start_server(addr, move |req| serve(req, chunk.clone())); let runner = common::start_runner(); - #[allow(dead_code)] - #[derive(Debug, Row, Deserialize)] + #[derive(Default, Debug, Row, Deserialize)] struct SomeRow { a: u64, b: i64, @@ -61,27 +63,72 @@ fn select(c: &mut Criterion) { d: u32, } - async fn run(client: Client, iters: u64) -> Result { + async fn select_rows(client: Client, iters: u64) -> Result { + let mut sum = SomeRow::default(); let start = Instant::now(); let mut cursor = client .query("SELECT ?fields FROM some") .fetch::()?; for _ in 0..iters { - black_box(cursor.next().await?); + let Some(row) = cursor.next().await? 
else { + return Err(Error::NotEnoughData); + }; + sum.a = sum.a.wrapping_add(row.a); + sum.b = sum.b.wrapping_add(row.b); + sum.c = sum.c.wrapping_add(row.c); + sum.d = sum.d.wrapping_add(row.d); } + black_box(sum); Ok(start.elapsed()) } - let mut group = c.benchmark_group("select"); + async fn select_bytes(client: Client, min_size: u64) -> Result { + let start = Instant::now(); + let mut cursor = client + .query("SELECT value FROM some") + .fetch_bytes("RowBinary")?; + + let mut size = 0; + while size < min_size { + let buf = black_box(cursor.next().await?); + size += buf.unwrap().len() as u64; + } + + Ok(start.elapsed()) + } + + let mut group = c.benchmark_group("rows"); group.throughput(Throughput::Bytes(mem::size_of::() as u64)); - group.bench_function("no compression", |b| { + group.bench_function("uncompressed", |b| { + b.iter_custom(|iters| { + let client = Client::default() + .with_url(format!("http://{addr}")) + .with_compression(Compression::None); + runner.run(select_rows(client, iters)) + }) + }); + #[cfg(feature = "lz4")] + group.bench_function("lz4", |b| { + b.iter_custom(|iters| { + let client = Client::default() + .with_url(format!("http://{addr}")) + .with_compression(Compression::Lz4); + runner.run(select_rows(client, iters)) + }) + }); + group.finish(); + + const MIB: u64 = 1024 * 1024; + let mut group = c.benchmark_group("mbytes"); + group.throughput(Throughput::Bytes(MIB)); + group.bench_function("uncompressed", |b| { b.iter_custom(|iters| { let client = Client::default() .with_url(format!("http://{addr}")) .with_compression(Compression::None); - runner.run(run(client, iters)) + runner.run(select_bytes(client, iters * MIB)) }) }); #[cfg(feature = "lz4")] @@ -90,7 +137,7 @@ fn select(c: &mut Criterion) { let client = Client::default() .with_url(format!("http://{addr}")) .with_compression(Compression::Lz4); - runner.run(run(client, iters)) + runner.run(select_bytes(client, iters * MIB)) }) }); group.finish(); diff --git a/examples/README.md b/examples/README.md index f3724fc..67f1783 100644 --- a/examples/README.md +++ b/examples/README.md @@ -28,6 +28,8 @@ If something is missing, or you found a mistake in one of these examples, please - [custom_http_headers.rs](custom_http_headers.rs) - setting additional HTTP headers to the client, or overriding the generated ones - [query_id.rs](query_id.rs) - setting a specific `query_id` on the query level - [session_id.rs](session_id.rs) - using the client in the session context with temporary tables +- [stream_into_file.rs](stream_into_file.rs) - streaming the query result as raw bytes into a file in an arbitrary format. Required cargo features: `futures03`. +- [stream_arbitrary_format_rows.rs](stream_arbitrary_format_rows.rs) - streaming the query result in an arbitrary format, row by row. Required cargo features: `futures03`. ## How to run diff --git a/examples/stream_arbitrary_format_rows.rs b/examples/stream_arbitrary_format_rows.rs new file mode 100644 index 0000000..a14e766 --- /dev/null +++ b/examples/stream_arbitrary_format_rows.rs @@ -0,0 +1,33 @@ +use tokio::io::AsyncBufReadExt; + +use clickhouse::Client; + +/// An example of streaming raw data in an arbitrary format leveraging the +/// [`AsyncBufReadExt`] helpers. In this case, the format is `JSONEachRow`. +/// Incoming data is then split into lines, and each line is deserialized into +/// `serde_json::Value`, a dynamic representation of JSON values. 
+/// +/// Similarly, it can be used with other formats such as CSV, TSV, and others +/// that produce each row on a new line; the only difference will be in how the +/// data is parsed. See also: https://clickhouse.com/docs/en/interfaces/formats +/// +/// Note: `lines()` produces a new `String` for each line, so it's not the +/// most performant way to iterate over lines. +#[tokio::main] +async fn main() { + let client = Client::default().with_url("http://localhost:8123"); + let mut lines = client + .query( + "SELECT number, hex(randomPrintableASCII(20)) AS hex_str + FROM system.numbers + LIMIT 100", + ) + .fetch_bytes("JSONEachRow") + .unwrap() + .lines(); + + while let Some(line) = lines.next_line().await.unwrap() { + let value: serde_json::Value = serde_json::de::from_str(&line).unwrap(); + println!("JSONEachRow value: {}", value); + } +} diff --git a/examples/stream_into_file.rs b/examples/stream_into_file.rs new file mode 100644 index 0000000..ae1c7ec --- /dev/null +++ b/examples/stream_into_file.rs @@ -0,0 +1,78 @@ +use clickhouse::{query::BytesCursor, Client}; +use std::time::Instant; +use tokio::{fs::File, io::AsyncWriteExt}; + +// Examples of streaming the result of a query in an arbitrary format into a +// file. In this case, the `CSVWithNamesAndTypes` format is used. +// See also other formats in https://clickhouse.com/docs/en/interfaces/formats. +// +// Note: there is no need to wrap `File` into `BufWriter` because `BytesCursor` +// is already buffered internally and produces chunks of data. + +const NUMBERS: u32 = 100_000; + +fn query(numbers: u32) -> BytesCursor { + let client = Client::default().with_url("http://localhost:8123"); + + client + .query( + "SELECT number, hex(randomPrintableASCII(20)) AS hex_str + FROM system.numbers + LIMIT {limit: Int32}", + ) + .param("limit", numbers) + .fetch_bytes("CSVWithNamesAndTypes") + .unwrap() +} + +// Pattern 1: use the `tokio::io::copy_buf` helper. +// +// It shows integration with the `tokio::io::AsyncBufRead` trait, which `copy_buf` requires. +async fn tokio_copy_buf(filename: &str) { + let mut cursor = query(NUMBERS); + let mut file = File::create(filename).await.unwrap(); + tokio::io::copy_buf(&mut cursor, &mut file).await.unwrap(); +} + +// Pattern 2: use `BytesCursor::next()`. +async fn cursor_next(filename: &str) { + let mut cursor = query(NUMBERS); + let mut file = File::create(filename).await.unwrap(); + + while let Some(bytes) = cursor.next().await.unwrap() { + file.write_all(&bytes).await.unwrap(); + println!("chunk of {}B written to {filename}", bytes.len()); + } +} + +// Pattern 3: use the `futures::(Try)StreamExt` traits. 
+#[cfg(feature = "futures03")] +async fn futures03_stream(filename: &str) { + use futures::TryStreamExt; + + let mut cursor = query(NUMBERS); + let mut file = File::create(filename).await.unwrap(); + + while let Some(bytes) = cursor.try_next().await.unwrap() { + file.write_all(&bytes).await.unwrap(); + println!("chunk of {}B written to {filename}", bytes.len()); + } +} + +#[tokio::main] +async fn main() { + let start = Instant::now(); + tokio_copy_buf("output-1.csv").await; + println!("written to output-1.csv in {:?}", start.elapsed()); + + let start = Instant::now(); + cursor_next("output-2.csv").await; + println!("written to output-2.csv in {:?}", start.elapsed()); + + #[cfg(feature = "futures03")] + { + let start = Instant::now(); + futures03_stream("output-3.csv").await; + println!("written to output-3.csv in {:?}", start.elapsed()); + } +} diff --git a/src/bytes_ext.rs b/src/bytes_ext.rs index 1911468..feebded 100644 --- a/src/bytes_ext.rs +++ b/src/bytes_ext.rs @@ -17,6 +17,12 @@ impl BytesExt { self.bytes.len() - self.cursor } + #[inline(always)] + pub(crate) fn is_empty(&self) -> bool { + debug_assert!(self.cursor <= self.bytes.len()); + self.cursor >= self.bytes.len() + } + #[inline(always)] pub(crate) fn set_remaining(&mut self, n: usize) { // We can use `bytes.advance()` here, but it's slower. @@ -26,13 +32,15 @@ impl BytesExt { #[cfg(any(test, feature = "lz4", feature = "watch"))] #[inline(always)] pub(crate) fn advance(&mut self, n: usize) { + debug_assert!(n <= self.remaining()); + // We can use `bytes.advance()` here, but it's slower. self.cursor += n; } #[inline(always)] pub(crate) fn extend(&mut self, chunk: Bytes) { - if self.cursor == self.bytes.len() { + if self.is_empty() { // Most of the time, we read the next chunk after consuming the previous one. 
self.bytes = chunk; self.cursor = 0; diff --git a/src/cursor.rs b/src/cursor.rs deleted file mode 100644 index b207433..0000000 --- a/src/cursor.rs +++ /dev/null @@ -1,221 +0,0 @@ -use std::marker::PhantomData; - -use bytes::Bytes; -use futures::TryStreamExt; -use serde::Deserialize; - -use crate::{ - bytes_ext::BytesExt, - error::{Error, Result}, - response::{Chunks, Response, ResponseFuture}, - rowbinary, -}; - -// === RawCursor === - -struct RawCursor(RawCursorInner); - -enum RawCursorInner { - Waiting(ResponseFuture), - Loading(RawCursorLoading), -} - -struct RawCursorLoading { - chunks: Chunks, - net_size: u64, - data_size: u64, -} - -impl RawCursor { - fn new(response: Response) -> Self { - Self(RawCursorInner::Waiting(response.into_future())) - } - - async fn next(&mut self) -> Result> { - if matches!(self.0, RawCursorInner::Waiting(_)) { - self.resolve().await?; - } - - let state = match &mut self.0 { - RawCursorInner::Loading(state) => state, - RawCursorInner::Waiting(_) => unreachable!(), - }; - - match state.chunks.try_next().await { - Ok(Some(chunk)) => { - state.net_size += chunk.net_size as u64; - state.data_size += chunk.data.len() as u64; - Ok(Some(chunk.data)) - } - Ok(None) => Ok(None), - Err(err) => Err(err), - } - } - - async fn resolve(&mut self) -> Result<()> { - if let RawCursorInner::Waiting(future) = &mut self.0 { - let chunks = future.await; - self.0 = RawCursorInner::Loading(RawCursorLoading { - chunks: chunks?, - net_size: 0, - data_size: 0, - }); - } - Ok(()) - } - - fn received_bytes(&self) -> u64 { - match &self.0 { - RawCursorInner::Waiting(_) => 0, - RawCursorInner::Loading(state) => state.net_size, - } - } - - fn decoded_bytes(&self) -> u64 { - match &self.0 { - RawCursorInner::Waiting(_) => 0, - RawCursorInner::Loading(state) => state.data_size, - } - } -} - -// XXX: it was a workaround for https://github.com/rust-lang/rust/issues/51132, -// but introduced #24 and must be fixed. -fn workaround_51132<'a, T: ?Sized>(ptr: &T) -> &'a T { - // SAFETY: actually, it leads to unsoundness, see #24 - unsafe { &*(ptr as *const T) } -} - -// === RowCursor === - -/// A cursor that emits rows. -#[must_use] -pub struct RowCursor { - raw: RawCursor, - bytes: BytesExt, - _marker: PhantomData, -} - -impl RowCursor { - pub(crate) fn new(response: Response) -> Self { - Self { - raw: RawCursor::new(response), - bytes: BytesExt::default(), - _marker: PhantomData, - } - } - - /// Emits the next row. - /// - /// An result is unspecified if it's called after `Err` is returned. - pub async fn next<'a, 'b: 'a>(&'a mut self) -> Result> - where - T: Deserialize<'b>, - { - loop { - let mut slice = workaround_51132(self.bytes.slice()); - - match rowbinary::deserialize_from(&mut slice) { - Ok(value) => { - self.bytes.set_remaining(slice.len()); - return Ok(Some(value)); - } - Err(Error::NotEnoughData) => {} - Err(err) => return Err(err), - } - - match self.raw.next().await? { - Some(chunk) => self.bytes.extend(chunk), - None if self.bytes.remaining() > 0 => { - // If some data is left, we have an incomplete row in the buffer. - // This is usually a schema mismatch on the client side. - return Err(Error::NotEnoughData); - } - None => return Ok(None), - } - } - } - - /// Returns the total size in bytes received from the CH server since - /// the cursor was created. - /// - /// This method counts only size without HTTP headers for now. - /// It can be changed in the future without notice. 
- #[inline] - pub fn received_bytes(&self) -> u64 { - self.raw.received_bytes() - } - - /// Returns the total size in bytes decompressed since the cursor was - /// created. - #[inline] - pub fn decoded_bytes(&self) -> u64 { - self.raw.decoded_bytes() - } -} - -// === JsonCursor === - -#[cfg(feature = "watch")] -pub(crate) struct JsonCursor<T> { - raw: RawCursor, - bytes: BytesExt, - line: String, - _marker: PhantomData<T>, -} - -// We use `JSONEachRowWithProgress` to avoid infinite HTTP connections. -// See https://github.com/ClickHouse/ClickHouse/issues/22996 for details. -#[cfg(feature = "watch")] -#[derive(Deserialize)] -#[serde(rename_all = "lowercase")] -enum JsonRow<T> { - Row(T), - Progress {}, -} - -#[cfg(feature = "watch")] -impl<T> JsonCursor<T> { - const INITIAL_BUFFER_SIZE: usize = 1024; - - pub(crate) fn new(response: Response) -> Self { - Self { - raw: RawCursor::new(response), - bytes: BytesExt::default(), - line: String::with_capacity(Self::INITIAL_BUFFER_SIZE), - _marker: PhantomData, - } - } - - pub(crate) async fn next<'a, 'b: 'a>(&'a mut self) -> Result<Option<T>> - where - T: Deserialize<'b>, - { - use bytes::Buf; - use std::io::BufRead; - - loop { - self.line.clear(); - - let read = match self.bytes.slice().reader().read_line(&mut self.line) { - Ok(read) => read, - Err(err) => return Err(Error::Custom(err.to_string())), - }; - - if let Some(line) = self.line.strip_suffix('\n') { - self.bytes.advance(read); - - match serde_json::from_str(workaround_51132(line)) { - Ok(JsonRow::Row(value)) => return Ok(Some(value)), - Ok(JsonRow::Progress { .. }) => continue, - Err(err) => return Err(Error::BadResponse(err.to_string())), - } - } - - match self.raw.next().await? { - Some(chunk) => self.bytes.extend(chunk), - None => return Ok(None), - } - } - } -} diff --git a/src/cursors/bytes.rs b/src/cursors/bytes.rs new file mode 100644 index 0000000..df4a663 --- /dev/null +++ b/src/cursors/bytes.rs @@ -0,0 +1,217 @@ +use crate::{cursors::RawCursor, error::Result, response::Response}; +use bytes::{Buf, Bytes, BytesMut}; +use std::{ + io::Result as IoResult, + pin::Pin, + task::{ready, Context, Poll}, +}; +use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf}; + +/// A cursor over raw bytes of the response returned by [`Query::fetch_bytes`]. +/// +/// Unlike [`RowCursor`], which emits rows deserialized as structures from +/// RowBinary, this cursor emits raw bytes without deserialization. +/// +/// # Integration +/// +/// In addition to [`BytesCursor::next`] and [`BytesCursor::collect`], +/// this cursor implements: +/// * [`AsyncRead`] and [`AsyncBufRead`] for the `tokio`-based ecosystem. +/// * [`futures::Stream`], [`futures::AsyncRead`] and [`futures::AsyncBufRead`] +/// for the `futures`-based ecosystem (requires the `futures03` feature). +/// +/// For instance, if the requested format emits each row on a newline +/// (e.g. `JSONEachRow`, `CSV`, `TSV`, etc.), the cursor can be read line by +/// line using `AsyncBufReadExt::lines`. Note that this method +/// produces a new `String` for each line, so it's not the most performant way +/// to iterate. +/// +/// Note: methods of these traits use [`std::io::Error`] for errors. +/// To get back the original error from this crate, use the `From` conversion. +/// +/// [`RowCursor`]: crate::query::RowCursor +/// [`Query::fetch_bytes`]: crate::query::Query::fetch_bytes +pub struct BytesCursor { + raw: RawCursor, + bytes: Bytes, +} + +// TODO: what if any next/poll_* called AFTER error returned? 
+ +impl BytesCursor { + pub(crate) fn new(response: Response) -> Self { + Self { + raw: RawCursor::new(response), + bytes: Bytes::default(), + } + } + + /// Emits the next bytes chunk. + /// + /// # Cancel safety + /// + /// This method is cancellation safe. + pub async fn next(&mut self) -> Result> { + assert!( + self.bytes.is_empty(), + "mixing `BytesCursor::next()` and `AsyncRead` API methods is not allowed" + ); + + self.raw.next().await + } + + /// Collects the whole response into a single [`Bytes`]. + /// + /// # Cancel safety + /// + /// This method is NOT cancellation safe. + /// If cancelled, already collected bytes are lost. + pub async fn collect(&mut self) -> Result { + let mut chunks = Vec::new(); + let mut total_len = 0; + + while let Some(chunk) = self.next().await? { + total_len += chunk.len(); + chunks.push(chunk); + } + + // The whole response is in a single chunk. + if chunks.len() == 1 { + return Ok(chunks.pop().unwrap()); + } + + let mut collected = BytesMut::with_capacity(total_len); + for chunk in chunks { + collected.extend_from_slice(&chunk); + } + debug_assert_eq!(collected.capacity(), total_len); + + Ok(collected.freeze()) + } + + #[cold] + fn poll_refill(&mut self, cx: &mut Context<'_>) -> Poll> { + debug_assert_eq!(self.bytes.len(), 0); + + // Theoretically, `self.raw.poll_next(cx)` can return empty chunks. + // In this case, we should continue polling until we get a non-empty chunk or + // end of stream in order to avoid false positive `Ok(0)` in I/O traits. + while self.bytes.is_empty() { + match ready!(self.raw.poll_next(cx)?) { + Some(chunk) => self.bytes = chunk, + None => return Poll::Ready(Ok(false)), + } + } + + Poll::Ready(Ok(true)) + } + + /// Returns the total size in bytes received from the CH server since + /// the cursor was created. + /// + /// This method counts only size without HTTP headers for now. + /// It can be changed in the future without notice. + #[inline] + pub fn received_bytes(&self) -> u64 { + self.raw.received_bytes() + } + + /// Returns the total size in bytes decompressed since the cursor was + /// created. + #[inline] + pub fn decoded_bytes(&self) -> u64 { + self.raw.decoded_bytes() + } +} + +impl AsyncRead for BytesCursor { + #[inline] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + while buf.remaining() > 0 { + if self.bytes.is_empty() && !ready!(self.poll_refill(cx)?) 
{ + break; + } + + let len = self.bytes.len().min(buf.remaining()); + let bytes = self.bytes.slice(..len); + buf.put_slice(&bytes[0..len]); + self.bytes.advance(len); + } + + Poll::Ready(Ok(())) + } +} + +impl AsyncBufRead for BytesCursor { + #[inline] + fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.bytes.is_empty() { + ready!(self.poll_refill(cx)?); + } + + Poll::Ready(Ok(&self.get_mut().bytes)) + } + + #[inline] + fn consume(mut self: Pin<&mut Self>, amt: usize) { + assert!( + amt <= self.bytes.len(), + "invalid `AsyncBufRead::consume` usage" + ); + self.bytes.advance(amt); + } +} + +#[cfg(feature = "futures03")] +impl futures::AsyncRead for BytesCursor { + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let mut buf = ReadBuf::new(buf); + ready!(AsyncRead::poll_read(self, cx, &mut buf)?); + Poll::Ready(Ok(buf.filled().len())) + } +} + +#[cfg(feature = "futures03")] +impl futures::AsyncBufRead for BytesCursor { + #[inline] + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncBufRead::poll_fill_buf(self, cx) + } + + #[inline] + fn consume(self: Pin<&mut Self>, amt: usize) { + AsyncBufRead::consume(self, amt); + } +} + +#[cfg(feature = "futures03")] +impl futures::stream::Stream for BytesCursor { + type Item = crate::error::Result; + + #[inline] + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + assert!( + self.bytes.is_empty(), + "mixing `Stream` and `AsyncRead` API methods is not allowed" + ); + + self.raw.poll_next(cx).map(Result::transpose) + } +} + +#[cfg(feature = "futures03")] +impl futures::stream::FusedStream for BytesCursor { + #[inline] + fn is_terminated(&self) -> bool { + self.bytes.is_empty() && self.raw.is_terminated() + } +} diff --git a/src/cursors/json.rs b/src/cursors/json.rs new file mode 100644 index 0000000..44b2d7d --- /dev/null +++ b/src/cursors/json.rs @@ -0,0 +1,69 @@ +use crate::{ + bytes_ext::BytesExt, + cursors::RawCursor, + error::{Error, Result}, + response::Response, +}; +use serde::Deserialize; +use std::marker::PhantomData; + +pub(crate) struct JsonCursor { + raw: RawCursor, + bytes: BytesExt, + line: String, + _marker: PhantomData, +} + +// We use `JSONEachRowWithProgress` to avoid infinite HTTP connections. +// See https://github.com/ClickHouse/ClickHouse/issues/22996 for details. +#[derive(Deserialize)] +#[serde(rename_all = "lowercase")] +enum JsonRow { + Row(T), + Progress {}, +} + +impl JsonCursor { + const INITIAL_BUFFER_SIZE: usize = 1024; + + pub(crate) fn new(response: Response) -> Self { + Self { + raw: RawCursor::new(response), + bytes: BytesExt::default(), + line: String::with_capacity(Self::INITIAL_BUFFER_SIZE), + _marker: PhantomData, + } + } + + pub(crate) async fn next<'a, 'b: 'a>(&'a mut self) -> Result> + where + T: Deserialize<'b>, + { + use bytes::Buf; + use std::io::BufRead; + + loop { + self.line.clear(); + + let read = match self.bytes.slice().reader().read_line(&mut self.line) { + Ok(read) => read, + Err(err) => return Err(Error::Custom(err.to_string())), + }; + + if let Some(line) = self.line.strip_suffix('\n') { + self.bytes.advance(read); + + match serde_json::from_str(super::workaround_51132(line)) { + Ok(JsonRow::Row(value)) => return Ok(Some(value)), + Ok(JsonRow::Progress { .. }) => continue, + Err(err) => return Err(Error::BadResponse(err.to_string())), + } + } + + match self.raw.next().await? 
{ + Some(chunk) => self.bytes.extend(chunk), + None => return Ok(None), + } + } + } +} diff --git a/src/cursors/mod.rs b/src/cursors/mod.rs new file mode 100644 index 0000000..3492a7b --- /dev/null +++ b/src/cursors/mod.rs @@ -0,0 +1,17 @@ +#[cfg(feature = "watch")] +pub(crate) use self::json::JsonCursor; +pub(crate) use self::raw::RawCursor; +pub use self::{bytes::BytesCursor, row::RowCursor}; + +mod bytes; +#[cfg(feature = "watch")] +mod json; +mod raw; +mod row; + +// XXX: it was a workaround for https://github.com/rust-lang/rust/issues/51132, +// but introduced #24 and must be fixed. +fn workaround_51132<'a, T: ?Sized>(ptr: &T) -> &'a T { + // SAFETY: actually, it leads to unsoundness, see #24 + unsafe { &*(ptr as *const T) } +} diff --git a/src/cursors/raw.rs b/src/cursors/raw.rs new file mode 100644 index 0000000..7aa416a --- /dev/null +++ b/src/cursors/raw.rs @@ -0,0 +1,98 @@ +use crate::{ + error::Result, + response::{Chunks, Response, ResponseFuture}, +}; +use bytes::Bytes; +use futures::Stream; +use std::{ + pin::pin, + task::{ready, Context, Poll}, +}; + +/// A cursor over raw bytes of a query response. +/// All other cursors are built on top of this one. +pub(crate) struct RawCursor(RawCursorState); + +enum RawCursorState { + Waiting(ResponseFuture), + Loading(RawCursorLoading), +} + +struct RawCursorLoading { + chunks: Chunks, + net_size: u64, + data_size: u64, +} + +impl RawCursor { + pub(crate) fn new(response: Response) -> Self { + Self(RawCursorState::Waiting(response.into_future())) + } + + pub(crate) async fn next(&mut self) -> Result> { + std::future::poll_fn(|cx| self.poll_next(cx)).await + } + + pub(crate) fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll>> { + if let RawCursorState::Loading(state) = &mut self.0 { + let chunks = pin!(&mut state.chunks); + + Poll::Ready(match ready!(chunks.poll_next(cx)?) { + Some(chunk) => { + state.net_size += chunk.net_size as u64; + state.data_size += chunk.data.len() as u64; + Ok(Some(chunk.data)) + } + None => Ok(None), + }) + } else { + ready!(self.poll_resolve(cx)?); + self.poll_next(cx) + } + } + + #[cold] + #[inline(never)] + fn poll_resolve(&mut self, cx: &mut Context<'_>) -> Poll> { + let RawCursorState::Waiting(future) = &mut self.0 else { + panic!("poll_resolve called in invalid state"); + }; + + // Poll the future, but don't return the result yet. + // In case of an error, we should replace the current state anyway + // in order to provide proper fused behavior of the cursor. 
+ let res = ready!(future.as_mut().poll(cx)); + let mut chunks = Chunks::empty(); + let res = res.map(|c| chunks = c); + + self.0 = RawCursorState::Loading(RawCursorLoading { + chunks, + net_size: 0, + data_size: 0, + }); + + Poll::Ready(res) + } + + pub(crate) fn received_bytes(&self) -> u64 { + match &self.0 { + RawCursorState::Loading(state) => state.net_size, + RawCursorState::Waiting(_) => 0, + } + } + + pub(crate) fn decoded_bytes(&self) -> u64 { + match &self.0 { + RawCursorState::Loading(state) => state.data_size, + RawCursorState::Waiting(_) => 0, + } + } + + #[cfg(feature = "futures03")] + pub(crate) fn is_terminated(&self) -> bool { + match &self.0 { + RawCursorState::Loading(state) => state.chunks.is_terminated(), + RawCursorState::Waiting(_) => false, + } + } +} diff --git a/src/cursors/row.rs b/src/cursors/row.rs new file mode 100644 index 0000000..6f17cfc --- /dev/null +++ b/src/cursors/row.rs @@ -0,0 +1,79 @@ +use crate::{ + bytes_ext::BytesExt, + cursors::RawCursor, + error::{Error, Result}, + response::Response, + rowbinary, +}; +use serde::Deserialize; +use std::marker::PhantomData; + +/// A cursor that emits rows deserialized as structures from RowBinary. +#[must_use] +pub struct RowCursor { + raw: RawCursor, + bytes: BytesExt, + _marker: PhantomData, +} + +impl RowCursor { + pub(crate) fn new(response: Response) -> Self { + Self { + raw: RawCursor::new(response), + bytes: BytesExt::default(), + _marker: PhantomData, + } + } + + /// Emits the next row. + /// + /// The result is unspecified if it's called after `Err` is returned. + /// + /// # Cancel safety + /// + /// This method is cancellation safe. + pub async fn next<'a, 'b: 'a>(&'a mut self) -> Result> + where + T: Deserialize<'b>, + { + loop { + let mut slice = super::workaround_51132(self.bytes.slice()); + + match rowbinary::deserialize_from(&mut slice) { + Ok(value) => { + self.bytes.set_remaining(slice.len()); + return Ok(Some(value)); + } + Err(Error::NotEnoughData) => {} + Err(err) => return Err(err), + } + + match self.raw.next().await? { + Some(chunk) => self.bytes.extend(chunk), + None if self.bytes.remaining() > 0 => { + // If some data is left, we have an incomplete row in the buffer. + // This is usually a schema mismatch on the client side. + return Err(Error::NotEnoughData); + } + None => return Ok(None), + } + } + } + + /// Returns the total size in bytes received from the CH server since + /// the cursor was created. + /// + /// This method counts only size without HTTP headers for now. + /// It can be changed in the future without notice. + #[inline] + pub fn received_bytes(&self) -> u64 { + self.raw.received_bytes() + } + + /// Returns the total size in bytes decompressed since the cursor was + /// created. + #[inline] + pub fn decoded_bytes(&self) -> u64 { + self.raw.decoded_bytes() + } +} diff --git a/src/error.rs b/src/error.rs index 0ab1d61..f4bde3c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -7,19 +7,21 @@ use serde::{de, ser}; /// A result with a specified [`Error`] type. pub type Result = result::Result; +type BoxedError = Box; + /// Represents all possible errors. 
#[derive(Debug, thiserror::Error)] #[non_exhaustive] #[allow(missing_docs)] pub enum Error { #[error("invalid params: {0}")] - InvalidParams(#[source] Box), + InvalidParams(#[source] BoxedError), #[error("network error: {0}")] - Network(#[source] Box), + Network(#[source] BoxedError), #[error("compression error: {0}")] - Compression(#[source] Box), + Compression(#[source] BoxedError), #[error("decompression error: {0}")] - Decompression(#[source] Box), + Decompression(#[source] BoxedError), #[error("no rows returned by a query that expected to return at least one row")] RowNotFound, #[error("sequences must have a known size ahead of time")] @@ -42,6 +44,8 @@ pub enum Error { TimedOut, #[error("unsupported: {0}")] Unsupported(String), + #[error("{0}")] + Other(BoxedError), } assert_impl_all!(Error: StdError, Send, Sync); @@ -70,18 +74,19 @@ impl de::Error for Error { } } -impl Error { - #[allow(dead_code)] - pub(crate) fn into_io(self) -> io::Error { - io::Error::new(io::ErrorKind::Other, self) +impl From for io::Error { + fn from(error: Error) -> Self { + io::Error::new(io::ErrorKind::Other, error) } +} - #[allow(dead_code)] - pub(crate) fn decode_io(err: io::Error) -> Self { - if err.get_ref().map(|r| r.is::()).unwrap_or(false) { - *err.into_inner().unwrap().downcast::().unwrap() +impl From for Error { + fn from(error: io::Error) -> Self { + // TODO: after MSRV 1.79 replace with `io::Error::downcast`. + if error.get_ref().is_some_and(|r| r.is::()) { + *error.into_inner().unwrap().downcast::().unwrap() } else { - Self::Decompression(Box::new(err)) + Self::Other(error.into()) } } } @@ -89,7 +94,14 @@ impl Error { #[test] fn roundtrip_io_error() { let orig = Error::NotEnoughData; - let io = orig.into_io(); - let err = Error::decode_io(io); - assert!(matches!(err, Error::NotEnoughData)); + + // Error -> io::Error + let orig_str = orig.to_string(); + let io = io::Error::from(orig); + assert_eq!(io.kind(), io::ErrorKind::Other); + assert_eq!(io.to_string(), orig_str); + + // io::Error -> Error + let orig = Error::from(io); + assert!(matches!(orig, Error::NotEnoughData)); } diff --git a/src/lib.rs b/src/lib.rs index f1e4eb3..7d02cdc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,7 +25,7 @@ pub mod watch; mod bytes_ext; mod compression; -mod cursor; +mod cursors; mod headers; mod http_client; mod request_body; diff --git a/src/query.rs b/src/query.rs index eba1915..7a76dbd 100644 --- a/src/query.rs +++ b/src/query.rs @@ -15,7 +15,7 @@ use crate::{ const MAX_QUERY_LEN_TO_USE_GET: usize = 8192; -pub use crate::cursor::RowCursor; +pub use crate::cursors::{BytesCursor, RowCursor}; #[must_use] #[derive(Clone)] @@ -84,7 +84,7 @@ impl Query { /// ``` pub fn fetch(mut self) -> Result> { self.sql.bind_fields::(); - self.sql.append(" FORMAT RowBinary"); + self.sql.set_output_format("RowBinary"); let response = self.do_execute(true)?; Ok(RowCursor::new(response)) @@ -132,6 +132,16 @@ impl Query { Ok(result) } + /// Executes the query, returning a [`BytesCursor`] to obtain results as raw + /// bytes containing data in the [provided format]. 
+ /// + /// [provided format]: https://clickhouse.com/docs/en/interfaces/formats + pub fn fetch_bytes(mut self, format: impl Into) -> Result { + self.sql.set_output_format(format); + let response = self.do_execute(true)?; + Ok(BytesCursor::new(response)) + } + pub(crate) fn do_execute(self, read_only: bool) -> Result { let query = self.sql.finish()?; diff --git a/src/response.rs b/src/response.rs index b500ba6..e9363f8 100644 --- a/src/response.rs +++ b/src/response.rs @@ -147,23 +147,30 @@ impl Chunks { fn new(stream: Incoming, compression: Compression) -> Self { let stream = IncomingStream(stream); let stream = Decompress::new(stream, compression); - let stream = DetectDbException::new(stream); + let stream = DetectDbException(stream); Self(Some(Box::new(stream))) } + + pub(crate) fn empty() -> Self { + Self(None) + } + + #[cfg(feature = "futures03")] + pub(crate) fn is_terminated(&self) -> bool { + self.0.is_none() + } } impl Stream for Chunks { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - // `take()` prevents from use after caught panic. + // We use `take()` to make the stream fused, including the case of panics. if let Some(mut stream) = self.0.take() { let res = Pin::new(&mut stream).poll_next(cx); if matches!(res, Poll::Pending | Poll::Ready(Some(Ok(_)))) { self.0 = Some(stream); - } else { - assert!(self.0.is_none()); } res @@ -244,16 +251,7 @@ where // === DetectDbException === -enum DetectDbException { - Stream(S), - Exception(Option), -} - -impl DetectDbException { - fn new(stream: S) -> Self { - Self::Stream(stream) - } -} +struct DetectDbException(S); impl Stream for DetectDbException where @@ -262,22 +260,15 @@ where type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match &mut *self { - Self::Stream(stream) => { - let mut res = Pin::new(stream).poll_next(cx); + let res = Pin::new(&mut self.0).poll_next(cx); - if let Poll::Ready(Some(Ok(chunk))) = &mut res { - if let Some(err) = extract_exception(&mut chunk.data) { - *self = Self::Exception(Some(err)); - - // NOTE: `chunk` can be empty, but it's ok for callers. - } - } - - res + if let Poll::Ready(Some(Ok(chunk))) = &res { + if let Some(err) = extract_exception(&chunk.data) { + return Poll::Ready(Some(Err(err))); } - Self::Exception(err) => Poll::Ready(err.take().map(Err)), } + + res } } @@ -285,7 +276,7 @@ where // ``` // Code: . DB::Exception: (version (official build))\n // ``` -fn extract_exception(chunk: &mut Bytes) -> Option { +fn extract_exception(chunk: &[u8]) -> Option { // `))\n` is very rare in real data, so it's fast dirty check. // In random data, it occurs with a probability of ~6*10^-8 only. 
if chunk.ends_with(b"))\n") { @@ -297,14 +288,13 @@ fn extract_exception(chunk: &mut Bytes) -> Option { #[cold] #[inline(never)] -fn extract_exception_slow(chunk: &mut Bytes) -> Option { +fn extract_exception_slow(chunk: &[u8]) -> Option { let index = chunk.rfind(b"Code:")?; if !chunk[index..].contains_str(b"DB::Exception:") { return None; } - let exception = chunk.split_off(index); - let exception = String::from_utf8_lossy(&exception[..exception.len() - 1]); + let exception = String::from_utf8_lossy(&chunk[index..chunk.len() - 1]); Some(Error::BadResponse(exception.into())) } diff --git a/src/sql/mod.rs b/src/sql/mod.rs index 66330f6..d4a7e3b 100644 --- a/src/sql/mod.rs +++ b/src/sql/mod.rs @@ -13,7 +13,7 @@ pub(crate) mod ser; #[derive(Debug, Clone)] pub(crate) enum SqlBuilder { - InProgress(Vec), + InProgress(Vec, Option), Failed(String), } @@ -21,7 +21,6 @@ pub(crate) enum SqlBuilder { pub(crate) enum Part { Arg, Fields, - Str(&'static str), Text(String), } @@ -29,15 +28,17 @@ pub(crate) enum Part { impl fmt::Display for SqlBuilder { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - SqlBuilder::InProgress(parts) => { + SqlBuilder::InProgress(parts, output_format_opt) => { for part in parts { match part { Part::Arg => f.write_char('?')?, Part::Fields => f.write_str("?fields")?, - Part::Str(text) => f.write_str(text)?, Part::Text(text) => f.write_str(text)?, } } + if let Some(output_format) = output_format_opt { + f.write_str(&format!(" FORMAT {output_format}"))? + } } SqlBuilder::Failed(err) => f.write_str(err)?, } @@ -71,11 +72,17 @@ impl SqlBuilder { parts.push(Part::Text(rest.to_string())); } - SqlBuilder::InProgress(parts) + SqlBuilder::InProgress(parts, None) + } + + pub(crate) fn set_output_format(&mut self, format: impl Into) { + if let Self::InProgress(_, format_opt) = self { + *format_opt = Some(format.into()); + } } pub(crate) fn bind_arg(&mut self, value: impl Bind) { - let Self::InProgress(parts) = self else { + let Self::InProgress(parts, _) = self else { return; }; @@ -93,7 +100,7 @@ impl SqlBuilder { } pub(crate) fn bind_fields(&mut self) { - let Self::InProgress(parts) = self else { + let Self::InProgress(parts, _) = self else { return; }; @@ -106,21 +113,12 @@ impl SqlBuilder { } } - pub(crate) fn append(&mut self, suffix: &'static str) { - let Self::InProgress(parts) = self else { - return; - }; - - parts.push(Part::Str(suffix)); - } - pub(crate) fn finish(mut self) -> Result { let mut sql = String::new(); - if let Self::InProgress(parts) = &self { + if let Self::InProgress(parts, _) = &self { for part in parts { match part { - Part::Str(text) => sql.push_str(text), Part::Text(text) => sql.push_str(text), Part::Arg => { self.error("unbound query argument"); @@ -135,7 +133,12 @@ impl SqlBuilder { } match self { - Self::InProgress(_) => Ok(sql), + Self::InProgress(_, output_format_opt) => { + if let Some(output_format) = output_format_opt { + sql.push_str(&format!(" FORMAT {output_format}")) + } + Ok(sql) + } Self::Failed(err) => Err(Error::InvalidParams(err.into())), } } diff --git a/src/ticks.rs b/src/ticks.rs index 6a643a0..b7c0021 100644 --- a/src/ticks.rs +++ b/src/ticks.rs @@ -46,7 +46,7 @@ impl Ticks { } pub(crate) fn reached(&self) -> bool { - self.next_at.map_or(false, |n| Instant::now() >= n) + self.next_at.is_some_and(|n| Instant::now() >= n) } pub(crate) fn reschedule(&mut self) { diff --git a/src/watch.rs b/src/watch.rs index 1914399..deaa446 100644 --- a/src/watch.rs +++ b/src/watch.rs @@ -4,7 +4,7 @@ use serde::Deserialize; use 
sha1::{Digest, Sha1}; use crate::{ - cursor::JsonCursor, + cursors::JsonCursor, error::{Error, Result}, row::Row, sql::{Bind, SqlBuilder}, diff --git a/tests/it/fetch_bytes.rs b/tests/it/fetch_bytes.rs new file mode 100644 index 0000000..8578231 --- /dev/null +++ b/tests/it/fetch_bytes.rs @@ -0,0 +1,141 @@ +use clickhouse::error::Error; +use std::str::from_utf8; +use tokio::io::{AsyncBufReadExt, AsyncReadExt}; + +#[tokio::test] +async fn single_chunk() { + let client = prepare_database!(); + + let mut cursor = client + .query("SELECT number FROM system.numbers LIMIT 3") + .fetch_bytes("CSV") + .unwrap(); + + let mut total_chunks = 0; + let mut buffer = Vec::<u8>::new(); + while let Some(chunk) = cursor.next().await.unwrap() { + buffer.extend(chunk); + total_chunks += 1; + } + + assert_eq!(from_utf8(&buffer).unwrap(), "0\n1\n2\n"); + assert_eq!(total_chunks, 1); + assert_eq!(cursor.decoded_bytes(), 6); +} + +#[tokio::test] +async fn multiple_chunks() { + let client = prepare_database!(); + + let mut cursor = client + .query("SELECT number FROM system.numbers LIMIT 3") + // each number will go into a separate chunk + .with_option("max_block_size", "1") + .fetch_bytes("CSV") + .unwrap(); + + let mut total_chunks = 0; + let mut buffer = Vec::<u8>::new(); + while let Some(data) = cursor.next().await.unwrap() { + buffer.extend(data); + total_chunks += 1; + } + + assert_eq!(from_utf8(&buffer).unwrap(), "0\n1\n2\n"); + assert_eq!(total_chunks, 3); + assert_eq!(cursor.decoded_bytes(), 6); +} + +#[tokio::test] +async fn error() { + let client = prepare_database!(); + + let mut bytes_cursor = client + .query("SELECT sleepEachRow(0.05) AS s FROM system.numbers LIMIT 30") + .with_option("max_block_size", "1") + .with_option("max_execution_time", "0.01") + .fetch_bytes("JSONEachRow") + .unwrap(); + + let err = bytes_cursor.next().await; + println!("{:?}", err); + assert!(matches!(err, Err(Error::BadResponse(_)))); +} + +#[tokio::test] +async fn lines() { + let client = prepare_database!(); + let expected = ["0", "1", "2"]; + + for n in 0..4 { + let mut lines = client + .query("SELECT number FROM system.numbers LIMIT {limit: Int32}") + .param("limit", n) + // each number will go into a separate chunk + .with_option("max_block_size", "1") + .fetch_bytes("CSV") + .unwrap() + .lines(); + + let mut actual = Vec::<String>::new(); + while let Some(data) = lines.next_line().await.unwrap() { + actual.push(data); + } + + assert_eq!(actual, &expected[..n]); + } +} + +#[tokio::test] +async fn collect() { + let client = prepare_database!(); + let expected = b"0\n1\n2\n3\n"; + + for n in 0..4 { + let mut cursor = client + .query("SELECT number FROM system.numbers LIMIT {limit: Int32}") + .param("limit", n) + // each number will go into a separate chunk + .with_option("max_block_size", "1") + .fetch_bytes("CSV") + .unwrap(); + + let data = cursor.collect().await.unwrap(); + assert_eq!(&data[..], &expected[..n * 2]); + + // The cursor is fused. 
+ assert_eq!(&cursor.collect().await.unwrap()[..], b""); + } +} + +#[tokio::test] +async fn async_read() { + let client = prepare_database!(); + let limit = 1000; + + let mut cursor = client + .query("SELECT number, number FROM system.numbers LIMIT {limit: Int32}") + .param("limit", limit) + .with_option("max_block_size", "3") + .fetch_bytes("CSV") + .unwrap(); + + #[allow(clippy::format_collect)] + let expected = (0..limit) + .map(|n| format!("{n},{n}\n")) + .collect::<String>() + .into_bytes(); + + let mut actual = vec![0; expected.len()]; + let mut index = 0; + while index < actual.len() { + let step = (1 + index % 10).min(actual.len() - index); + let buf = &mut actual[index..(index + step)]; + assert_eq!(cursor.read_exact(buf).await.unwrap(), step); + index += step; + } + + assert_eq!(cursor.read(&mut [0]).await.unwrap(), 0); // EOF + assert_eq!(cursor.decoded_bytes(), expected.len() as u64); + assert_eq!(actual, expected); +} diff --git a/tests/it/main.rs b/tests/it/main.rs index c35a6cc..5e0385d 100644 --- a/tests/it/main.rs +++ b/tests/it/main.rs @@ -58,6 +58,7 @@ mod chrono; mod compression; mod cursor_error; mod cursor_stats; +mod fetch_bytes; mod insert; mod inserter; mod ip;