diff --git a/src/arrow1/mod.rs b/src/arrow1/mod.rs index 0d27c0fe..fcc641bf 100644 --- a/src/arrow1/mod.rs +++ b/src/arrow1/mod.rs @@ -11,6 +11,9 @@ pub mod writer; #[cfg(feature = "writer")] pub mod writer_properties; +#[cfg(all(feature = "writer", feature = "async"))] +pub mod writer_async; + pub mod error; #[cfg(all(feature = "reader", feature = "async"))] diff --git a/src/arrow1/wasm.rs b/src/arrow1/wasm.rs index 5ad40113..792403ce 100644 --- a/src/arrow1/wasm.rs +++ b/src/arrow1/wasm.rs @@ -82,3 +82,24 @@ pub async fn read_parquet_stream( }); Ok(wasm_streams::ReadableStream::from_stream(stream).into_raw()) } + +#[wasm_bindgen(js_name = "transformParquetStream")] +#[cfg(all(feature = "writer", feature = "async"))] +pub fn transform_parquet_stream( + stream: wasm_streams::readable::sys::ReadableStream, + writer_properties: Option, +) -> WasmResult { + use futures::StreamExt; + let batches = wasm_streams::ReadableStream::from_raw(stream) + .into_stream() + .map(|maybe_chunk| { + let chunk = maybe_chunk.unwrap(); + let transformed: arrow_wasm::arrow1::RecordBatch = chunk.try_into().unwrap(); + transformed + }); + let output_stream = super::writer_async::transform_parquet_stream( + batches, + writer_properties.unwrap_or_default(), + ); + Ok(output_stream.unwrap()) +} diff --git a/src/arrow1/writer_async.rs b/src/arrow1/writer_async.rs new file mode 100644 index 00000000..82dd61eb --- /dev/null +++ b/src/arrow1/writer_async.rs @@ -0,0 +1,38 @@ +use crate::arrow1::error::Result; +use crate::common::stream::WrappedWritableStream; +use async_compat::CompatExt; +use futures::StreamExt; +use parquet::arrow::async_writer::AsyncArrowWriter; +use wasm_bindgen_futures::spawn_local; + +pub fn transform_parquet_stream( + batches: impl futures::Stream + 'static, + writer_properties: crate::arrow1::writer_properties::WriterProperties, +) -> Result { + let options = Some(writer_properties.into()); + // let encoding = writer_properties.get_encoding(); + + let (writable_stream, output_stream) = { + let raw_stream = wasm_streams::transform::sys::TransformStream::new(); + let raw_writable = raw_stream.writable(); + let inner_writer = wasm_streams::WritableStream::from_raw(raw_writable).into_async_write(); + let writable_stream = WrappedWritableStream { + stream: inner_writer, + }; + (writable_stream, raw_stream.readable()) + }; + spawn_local::<_>(async move { + let mut adapted_stream = batches.peekable(); + let mut pinned_stream = std::pin::pin!(adapted_stream); + let first_batch = pinned_stream.as_mut().peek().await.unwrap(); + let schema = first_batch.schema().into_inner(); + // Need to create an encoding for each column + let mut writer = + AsyncArrowWriter::try_new(writable_stream.compat(), schema, 1024, options).unwrap(); + while let Some(batch) = pinned_stream.next().await { + let _ = writer.write(&batch.into()).await; + } + let _ = writer.close().await; + }); + Ok(output_stream) +} diff --git a/src/arrow2/writer_async.rs b/src/arrow2/writer_async.rs index acab151d..703aa70f 100644 --- a/src/arrow2/writer_async.rs +++ b/src/arrow2/writer_async.rs @@ -1,38 +1,9 @@ use crate::arrow2::error::Result; +use crate::common::stream::WrappedWritableStream; use arrow2::io::parquet::write::FileSink; -use futures::{AsyncWrite, SinkExt, StreamExt}; +use futures::{SinkExt, StreamExt}; use wasm_bindgen_futures::spawn_local; -struct WrappedWritableStream<'writer> { - stream: wasm_streams::writable::IntoAsyncWrite<'writer>, -} - -impl<'writer> AsyncWrite for WrappedWritableStream<'writer> { - fn poll_write( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - AsyncWrite::poll_write(std::pin::Pin::new(&mut self.get_mut().stream), cx, buf) - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - AsyncWrite::poll_flush(std::pin::Pin::new(&mut self.get_mut().stream), cx) - } - - fn poll_close( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - AsyncWrite::poll_close(std::pin::Pin::new(&mut self.get_mut().stream), cx) - } -} - -unsafe impl<'writer> Send for WrappedWritableStream<'writer> {} - pub fn transform_parquet_stream( batches: impl futures::Stream + 'static, writer_properties: crate::arrow2::writer_properties::WriterProperties, diff --git a/src/common/mod.rs b/src/common/mod.rs index 7781af36..ed1ce218 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -3,3 +3,6 @@ pub mod writer_properties; #[cfg(feature = "async")] pub mod fetch; + +#[cfg(feature = "async")] +pub mod stream; diff --git a/src/common/stream.rs b/src/common/stream.rs new file mode 100644 index 00000000..74474996 --- /dev/null +++ b/src/common/stream.rs @@ -0,0 +1,31 @@ +use futures::AsyncWrite; + +pub struct WrappedWritableStream<'writer> { + pub stream: wasm_streams::writable::IntoAsyncWrite<'writer>, +} + +impl<'writer> AsyncWrite for WrappedWritableStream<'writer> { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + AsyncWrite::poll_write(std::pin::Pin::new(&mut self.get_mut().stream), cx, buf) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + AsyncWrite::poll_flush(std::pin::Pin::new(&mut self.get_mut().stream), cx) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + AsyncWrite::poll_close(std::pin::Pin::new(&mut self.get_mut().stream), cx) + } +} + +unsafe impl<'writer> Send for WrappedWritableStream<'writer> {} diff --git a/tests/js/arrow1.ts b/tests/js/arrow1.ts index d65fc663..8c336480 100644 --- a/tests/js/arrow1.ts +++ b/tests/js/arrow1.ts @@ -2,7 +2,7 @@ import * as test from "tape"; import * as wasm from "../../pkg/node/arrow1"; import { readFileSync } from "fs"; import { tableFromIPC, tableToIPC } from "apache-arrow"; -import { testArrowTablesEqual, readExpectedArrowData } from "./utils"; +import { testArrowTablesEqual, readExpectedArrowData, temporaryServer } from "./utils"; // Path from repo root const dataDir = "tests/data"; @@ -83,3 +83,22 @@ test("error produced trying to read file with arrayBuffer", (t) => { t.end(); }); + +test("read stream-write stream-read stream round trip (no writer properties provided)", async (t) => { + const server = await temporaryServer(); + const listeningPort = server.addresses()[0].port; + const rootUrl = `http://localhost:${listeningPort}`; + + const expectedTable = readExpectedArrowData(); + + const url = `${rootUrl}/1-partition-brotli.parquet`; + const originalStream = await wasm.readParquetStream(url); + + const stream = await wasm.transformParquetStream(originalStream); + const accumulatedBuffer = new Uint8Array(await new Response(stream).arrayBuffer()); + const roundtripTable = tableFromIPC(wasm.readParquet(accumulatedBuffer).intoIPC()); + + testArrowTablesEqual(t, expectedTable, roundtripTable); + await server.close(); + t.end(); +})