Use oneshot instead of notify
rdettai committed Aug 30, 2024
1 parent fbbdc0f commit 34dd452
Showing 2 changed files with 56 additions and 43 deletions.
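
The change swaps the `tokio::sync::Notify` wakeup used by the ingest `PublishTracker` for a `tokio::sync::oneshot` channel, and widens the closure subscriber bound in `pubsub.rs` from `Fn` to `FnMut` to support the new subscriber. A minimal sketch of the signaling pattern being adopted, using plain tokio rather than Quickwit code:

```rust
// Sketch only: plain tokio, not Quickwit code. A oneshot channel carries a
// one-time completion signal from the event subscriber to the waiter.
use tokio::sync::oneshot;

#[tokio::main]
async fn main() {
    let (sender, receiver) = oneshot::channel::<()>();

    tokio::spawn(async move {
        // ... work completes here ...
        // `send` consumes the sender, so the signal can fire at most once;
        // the Err case (receiver already dropped) is deliberately ignored.
        let _ = sender.send(());
    });

    // The signal is buffered by the channel: this await succeeds even when
    // `send` ran before we started waiting, so the wakeup cannot be lost.
    let _ = receiver.await;
}
```

Because the sender is consumed on use, the one-shot contract is enforced by the type system rather than by convention.
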
quickwit/quickwit-common/src/pubsub.rs (2 changes: 1 addition & 1 deletion)
@@ -42,7 +42,7 @@ pub trait EventSubscriber<E>: Send + Sync + 'static {
impl<E, F> EventSubscriber<E> for F
where
E: Event,
F: Fn(E) + Send + Sync + 'static,
F: FnMut(E) + Send + Sync + 'static,
{
async fn handle_event(&mut self, event: E) {
(self)(event);
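
The widened `FnMut` bound matches `handle_event(&mut self, ...)` and lets a closure subscriber mutate its captured state. A minimal sketch, independent of the Quickwit API, of a closure that the old `Fn` bound would have rejected:

```rust
// Sketch only: a closure that consumes a captured oneshot sender needs
// mutable access to its environment, so it is FnMut but not Fn.
use tokio::sync::oneshot;

fn main() {
    let (sender, _receiver) = oneshot::channel::<()>();
    let mut slot = Some(sender);

    // `slot.take()` mutates the captured variable, which FnMut permits
    // and Fn forbids.
    let mut on_event = move |_event: u32| {
        if let Some(sender) = slot.take() {
            let _ = sender.send(());
        }
    };

    on_event(1); // fires the completion signal
    on_event(2); // no-op: the sender was already consumed
}
```
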
quickwit/quickwit-ingest/src/ingest_v2/workbench.rs (97 changes: 55 additions & 42 deletions)
@@ -20,7 +20,7 @@
use std::collections::{BTreeMap, HashMap, HashSet};
use std::sync::{Arc, Mutex};

use quickwit_common::pubsub::{EventBroker, EventSubscriptionHandle};
use quickwit_common::pubsub::{EventBroker, EventSubscriber, EventSubscriptionHandle};
use quickwit_common::rate_limited_error;
use quickwit_proto::control_plane::{
GetOrCreateOpenShardsFailure, GetOrCreateOpenShardsFailureReason,
@@ -32,7 +32,8 @@ use quickwit_proto::ingest::router::{
};
use quickwit_proto::ingest::{IngestV2Error, RateLimitingCause};
use quickwit_proto::types::{NodeId, Position, ShardId, SubrequestId};
use tokio::sync::Notify;
use tokio::sync::oneshot;
use tonic::async_trait;
use tracing::warn;

use super::router::PersistRequestSummary;
@@ -43,58 +44,70 @@ struct PublishState {
already_published: HashMap<ShardId, Position>,
}

struct PublishTrackerSubscription {
state: Arc<Mutex<PublishState>>,
publish_complete_sender: Option<oneshot::Sender<()>>,
}

#[async_trait]
impl EventSubscriber<ShardPositionsUpdate> for PublishTrackerSubscription {
async fn handle_event(&mut self, update: ShardPositionsUpdate) {
let mut state_handle = self.state.lock().unwrap();
for (updated_shard_id, updated_position) in &update.updated_shard_positions {
if let Some(shard_position) = state_handle.awaiting_publish.get(updated_shard_id) {
if updated_position >= shard_position {
state_handle.awaiting_publish.remove(updated_shard_id);
if state_handle.awaiting_publish.is_empty() {
if let Some(sender) = self.publish_complete_sender.take() {
let _ = sender.send(());
} else {
break;
}
}
}
} else {
// Save this position update in case the publish update
// event arrived before the shard persist response. We
// might build a state that tracks irrelevant shards for
// the duration of the query but that should be fine.
state_handle
.already_published
.insert(updated_shard_id.clone(), updated_position.clone());
}
}
}
}

/// A helper for tracking the progress of the publish events when running in
/// `wait_for` commit mode.
///
/// Registers a set of shard positions and listens to [`ShardPositionsUpdate`]
/// events to assert when all the persisted events have been published. To make
/// sure that no events are missed:
/// - the tracker should be created before the persist requests are sent
/// - the `shard_persisted` method should for all successful persist subrequests
/// - `track_persisted_position()` should be called for all successful persist subrequests
struct PublishTracker {
state: Arc<Mutex<PublishState>>,
publish_complete: Arc<Notify>,
publish_complete: oneshot::Receiver<()>,
_publish_listen_handle: EventSubscriptionHandle,
}

impl PublishTracker {
fn new(event_tracker: EventBroker) -> Self {
let state = Arc::new(Mutex::new(PublishState::default()));
let state_clone = state.clone();
let publish_complete = Arc::new(Notify::new());
let publish_complete_notifier = publish_complete.clone();
let _publish_listen_handle =
event_tracker.subscribe(move |update: ShardPositionsUpdate| {
let mut state_handle = state_clone.lock().unwrap();
for (updated_shard_id, updated_position) in &update.updated_shard_positions {
if let Some(shard_position) =
state_handle.awaiting_publish.get(updated_shard_id)
{
if updated_position >= shard_position {
state_handle.awaiting_publish.remove(updated_shard_id);
if state_handle.awaiting_publish.is_empty() {
publish_complete_notifier.notify_one();
}
}
} else {
// Save this position update in case the publish update
// event arrived before the shard persist response. We
// might build a state that tracks irrelevant shards for
// the duration of the query but that should be fine.
state_handle
.already_published
.insert(updated_shard_id.clone(), updated_position.clone());
}
}
});
let (publish_complete_sender, publish_complete) = oneshot::channel();
let subscription = PublishTrackerSubscription {
state: state.clone(),
publish_complete_sender: Some(publish_complete_sender),
};
Self {
state,
_publish_listen_handle,
_publish_listen_handle: event_tracker.subscribe(subscription),
publish_complete,
}
}

fn shard_persisted(&self, shard_id: ShardId, new_position: Position) {
fn track_persisted_position(&self, shard_id: ShardId, new_position: Position) {
let mut state_handle = self.state.lock().unwrap();
match state_handle.already_published.get(&shard_id) {
Some(already_published_position) if new_position <= *already_published_position => {
@@ -115,7 +128,7 @@ impl PublishTracker {
if self.state.lock().unwrap().awaiting_publish.is_empty() {
return;
}
self.publish_complete.notified().await;
let _ = self.publish_complete.await;
}
}

@@ -246,10 +259,10 @@ impl IngestWorkbench {
);
return;
};
if let Some(publish_tracker) = &mut self.publish_tracker {
if let Some(publish_tracker) = &self.publish_tracker {
if let Some(position) = &persist_success.replication_position_inclusive {
publish_tracker
.shard_persisted(persist_success.shard_id().clone(), position.clone());
.track_persisted_position(persist_success.shard_id().clone(), position.clone());
}
}
self.num_successes += 1;
@@ -493,9 +506,9 @@ mod tests {
let shard_id_3 = ShardId::from("test-shard-3");
let shard_id_4 = ShardId::from("test-shard-4");

tracker.shard_persisted(shard_id_1.clone(), Position::offset(42usize));
tracker.shard_persisted(shard_id_2.clone(), Position::offset(42usize));
tracker.shard_persisted(shard_id_3.clone(), Position::offset(42usize));
tracker.track_persisted_position(shard_id_1.clone(), Position::offset(42usize));
tracker.track_persisted_position(shard_id_2.clone(), Position::offset(42usize));
tracker.track_persisted_position(shard_id_3.clone(), Position::offset(42usize));

event_broker.publish(ShardPositionsUpdate {
source_uid: SourceUid {
@@ -524,7 +537,7 @@
});

// persist response received after the publish event
tracker.shard_persisted(shard_id_4.clone(), Position::offset(42usize));
tracker.track_persisted_position(shard_id_4.clone(), Position::offset(42usize));

tokio::time::timeout(Duration::from_millis(200), tracker.wait_publish_complete())
.await
@@ -538,8 +551,8 @@
let tracker = PublishTracker::new(event_broker.clone());
let shard_id_1 = ShardId::from("test-shard-1");
let position = Position::offset(42usize);
tracker.shard_persisted(shard_id_1.clone(), position.clone());
tracker.shard_persisted(ShardId::from("test-shard-2"), position.clone());
tracker.track_persisted_position(shard_id_1.clone(), position.clone());
tracker.track_persisted_position(ShardId::from("test-shard-2"), position.clone());

event_broker.publish(ShardPositionsUpdate {
source_uid: SourceUid {
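
Taken together, the `workbench.rs` changes replace the closure subscriber with the named `PublishTrackerSubscription` struct, which takes its `Option<oneshot::Sender<()>>` exactly once, while `wait_publish_complete` awaits the receiver (awaiting a `oneshot::Receiver` consumes it, so the method signature, which sits outside the hunks shown, presumably takes `self` by value). The ordering race between publish events and persist responses is still resolved by the two maps in `PublishState`; the standalone sketch below uses hypothetical simplified types, not Quickwit's, to illustrate that idea:

```rust
// Sketch only: hypothetical simplified types, not Quickwit's. A publish
// event that arrives before the matching persist response is parked in
// `already_published` instead of being lost.
use std::collections::HashMap;

#[derive(Default)]
struct State {
    awaiting_publish: HashMap<String, u64>,
    already_published: HashMap<String, u64>,
}

impl State {
    // Called for each successful persist response.
    fn track_persisted(&mut self, shard: &str, position: u64) {
        match self.already_published.get(shard) {
            // The publish event already covered this position: nothing to await.
            Some(published) if position <= *published => {}
            // Otherwise this shard still has to be published.
            _ => {
                self.awaiting_publish.insert(shard.to_string(), position);
            }
        }
    }

    // Called for each publish event; returns true once nothing is pending.
    fn on_published(&mut self, shard: &str, position: u64) -> bool {
        match self.awaiting_publish.get(shard) {
            Some(expected) if position >= *expected => {
                self.awaiting_publish.remove(shard);
            }
            // Still short of the awaited position: keep waiting.
            Some(_) => {}
            // Event beat the persist response: park it for later.
            None => {
                self.already_published.insert(shard.to_string(), position);
            }
        }
        self.awaiting_publish.is_empty()
    }
}

fn main() {
    let mut state = State::default();
    // The publish event for "shard-1" arrives before its persist response.
    state.on_published("shard-1", 42);
    state.track_persisted("shard-1", 42);
    // No deadlock: the early event was kept, so nothing is awaited.
    assert!(state.awaiting_publish.is_empty());
}
```

Here `on_published` returns a flag where the real code takes and fires the oneshot sender, which additionally guarantees the completion signal is delivered at most once.
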
