Skip to content

Commit

Permalink
run block decompression from executor (#2386)
Browse files Browse the repository at this point in the history
* run block decompression from executor

* add a wrapper with is_closed to oneshot channel

* add cancelation test to Executor::spawn_blocking
  • Loading branch information
trinity-1686a authored May 8, 2024
1 parent 2b76335 commit 8cd7ddc
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 9 deletions.
107 changes: 106 additions & 1 deletion src/core/executor.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(feature = "quickwit")]
use futures_util::{future::Either, FutureExt};
use rayon::{ThreadPool, ThreadPoolBuilder};

use crate::TantivyError;
Expand Down Expand Up @@ -91,11 +93,84 @@ impl Executor {
}
}
}

/// Spawn a task on the pool, returning a future completing on task success.
///
/// If the task panic, returns `Err(())`.
#[cfg(feature = "quickwit")]
pub fn spawn_blocking<T: Send + 'static>(
&self,
cpu_intensive_task: impl FnOnce() -> T + Send + 'static,
) -> impl std::future::Future<Output = Result<T, ()>> {
match self {
Executor::SingleThread => Either::Left(std::future::ready(Ok(cpu_intensive_task()))),
Executor::ThreadPool(pool) => {
let (sender, receiver) = oneshot_with_sentinel::channel();
pool.spawn(|| {
if sender.is_closed() {
return;
}
let task_result = cpu_intensive_task();
let _ = sender.send(task_result);
});

let res = receiver.map(|res| res.map_err(|_| ()));
Either::Right(res)
}
}
}
}

#[cfg(feature = "quickwit")]
mod oneshot_with_sentinel {
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
// TODO get ride of this if oneshot ever gains a is_closed()

pub struct SenderWithSentinel<T> {
tx: oneshot::Sender<T>,
guard: Arc<()>,
}

pub struct ReceiverWithSentinel<T> {
rx: oneshot::Receiver<T>,
_guard: Arc<()>,
}

pub fn channel<T>() -> (SenderWithSentinel<T>, ReceiverWithSentinel<T>) {
let (tx, rx) = oneshot::channel();
let guard = Arc::new(());
(
SenderWithSentinel {
tx,
guard: guard.clone(),
},
ReceiverWithSentinel { rx, _guard: guard },
)
}

impl<T> SenderWithSentinel<T> {
pub fn send(self, message: T) -> Result<(), oneshot::SendError<T>> {
self.tx.send(message)
}

pub fn is_closed(&self) -> bool {
Arc::strong_count(&self.guard) == 1
}
}

impl<T> std::future::Future for ReceiverWithSentinel<T> {
type Output = Result<T, oneshot::RecvError>;

fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.rx).poll(ctx)
}
}
}

#[cfg(test)]
mod tests {

use super::Executor;

#[test]
Expand Down Expand Up @@ -147,4 +222,34 @@ mod tests {
assert_eq!(result[i], i * 2);
}
}

#[cfg(feature = "quickwit")]
#[test]
fn test_cancel_cpu_intensive_tasks() {
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;

let counter: Arc<AtomicU64> = Default::default();
let mut futures = Vec::new();
let executor = Executor::multi_thread(3, "search-test").unwrap();
for _ in 0..1_000 {
let counter_clone = counter.clone();
let fut = executor.spawn_blocking(move || {
std::thread::sleep(Duration::from_millis(4));
counter_clone.fetch_add(1, Ordering::SeqCst)
});
futures.push(fut);
}
std::thread::sleep(Duration::from_millis(5));
// The first few num_cores tasks should run, but the other should get cancelled.
drop(futures);
while Arc::strong_count(&counter) > 1 {
std::thread::sleep(Duration::from_millis(10));
}
// with ideal timing, we expect the result to always be 6, but as long as we run some, and
// cancelled most, the test is a success
assert!(counter.load(Ordering::SeqCst) > 0);
assert!(counter.load(Ordering::SeqCst) < 50);
}
}
3 changes: 2 additions & 1 deletion src/core/searcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@ impl Searcher {
&self,
doc_address: DocAddress,
) -> crate::Result<D> {
let executor = self.inner.index.search_executor();
let store_reader = &self.inner.store_readers[doc_address.segment_ord as usize];
store_reader.get_async(doc_address.doc_id).await
store_reader.get_async(doc_address.doc_id, executor).await
}

/// Access the schema associated with the index of this searcher.
Expand Down
32 changes: 25 additions & 7 deletions src/store/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use crate::schema::document::{BinaryDocumentDeserializer, DocumentDeserialize};
use crate::space_usage::StoreSpaceUsage;
use crate::store::index::Checkpoint;
use crate::DocId;
#[cfg(feature = "quickwit")]
use crate::Executor;

pub(crate) const DOCSTORE_CACHE_CAPACITY: usize = 100;

Expand Down Expand Up @@ -341,7 +343,11 @@ impl StoreReader {
/// In most cases use [`get_async`](Self::get_async)
///
/// Loads and decompresses a block asynchronously.
async fn read_block_async(&self, checkpoint: &Checkpoint) -> io::Result<Block> {
async fn read_block_async(
&self,
checkpoint: &Checkpoint,
executor: &Executor,
) -> io::Result<Block> {
let cache_key = checkpoint.byte_range.start;
if let Some(block) = self.cache.get_from_cache(checkpoint.byte_range.start) {
return Ok(block);
Expand All @@ -353,8 +359,12 @@ impl StoreReader {
.read_bytes_async()
.await?;

let decompressed_block =
OwnedBytes::new(self.decompressor.decompress(compressed_block.as_ref())?);
let decompressor = self.decompressor;
let maybe_decompressed_block = executor
.spawn_blocking(move || decompressor.decompress(compressed_block.as_ref()))
.await
.expect("decompression panicked");
let decompressed_block = OwnedBytes::new(maybe_decompressed_block?);

self.cache
.put_into_cache(cache_key, decompressed_block.clone());
Expand All @@ -363,15 +373,23 @@ impl StoreReader {
}

/// Reads raw bytes of a given document asynchronously.
pub async fn get_document_bytes_async(&self, doc_id: DocId) -> crate::Result<OwnedBytes> {
pub async fn get_document_bytes_async(
&self,
doc_id: DocId,
executor: &Executor,
) -> crate::Result<OwnedBytes> {
let checkpoint = self.block_checkpoint(doc_id)?;
let block = self.read_block_async(&checkpoint).await?;
let block = self.read_block_async(&checkpoint, executor).await?;
Self::get_document_bytes_from_block(block, doc_id, &checkpoint)
}

/// Fetches a document asynchronously. Async version of [`get`](Self::get).
pub async fn get_async<D: DocumentDeserialize>(&self, doc_id: DocId) -> crate::Result<D> {
let mut doc_bytes = self.get_document_bytes_async(doc_id).await?;
pub async fn get_async<D: DocumentDeserialize>(
&self,
doc_id: DocId,
executor: &Executor,
) -> crate::Result<D> {
let mut doc_bytes = self.get_document_bytes_async(doc_id, executor).await?;

let deserializer = BinaryDocumentDeserializer::from_reader(&mut doc_bytes)
.map_err(crate::TantivyError::from)?;
Expand Down

0 comments on commit 8cd7ddc

Please sign in to comment.