Ensure all permits are released before shutdown resolves (#842)
Sushisource authored Nov 18, 2024
1 parent 4925755 commit 9879b55
Showing 4 changed files with 64 additions and 25 deletions.
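
The core of the change: each MeteredPermitDealer now publishes its count of extant permits on a tokio watch channel (get_extant_count_rcv), the Worker collects one receiver per dealer into an AllPermitsTracker, and shutdown waits for every count to reach zero, bounded by a short timeout so a leaked permit cannot hang shutdown forever. The sketch below is a simplified, self-contained illustration of that watch-channel pattern; PermitCounter, extant_count_rcv, and all_permits_released are names invented for this example and are not the crate's real API.

use std::time::Duration;
use tokio::sync::watch;

/// Hypothetical stand-in for a permit dealer: publishes how many permits
/// are currently extant on a watch channel.
struct PermitCounter {
    count: watch::Sender<usize>,
}

impl PermitCounter {
    fn new() -> Self {
        let (tx, _rx) = watch::channel(0usize);
        Self { count: tx }
    }

    /// Analogous to `get_extant_count_rcv` in the diff.
    fn extant_count_rcv(&self) -> watch::Receiver<usize> {
        self.count.subscribe()
    }

    fn acquire(&self) {
        self.count.send_modify(|c| *c += 1);
    }

    fn release(&self) {
        self.count.send_modify(|c| *c -= 1);
    }
}

/// Waits until every tracked counter reads zero; `wait_for` resolves
/// immediately if a count is already zero.
async fn all_permits_released(mut counters: Vec<watch::Receiver<usize>>) {
    for rcv in counters.iter_mut() {
        let _ = rcv.wait_for(|c| *c == 0).await;
    }
}

#[tokio::main]
async fn main() {
    let dealer = PermitCounter::new();
    let rcv = dealer.extant_count_rcv();

    dealer.acquire();
    dealer.release();

    // Bound the wait so a permit that is never released can't hang shutdown.
    if tokio::time::timeout(Duration::from_secs(1), all_permits_released(vec![rcv]))
        .await
        .is_err()
    {
        eprintln!("Waiting for all slot permits to release took too long!");
    }
}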
6 changes: 5 additions & 1 deletion core/src/abstractions.rs
@@ -107,6 +107,10 @@ where
}
}

pub(crate) fn get_extant_count_rcv(&self) -> watch::Receiver<usize> {
self.extant_permits.1.clone()
}

fn build_owned(&self, res: SlotSupplierPermit) -> OwnedMeteredSemPermit<SK> {
self.unused_claimants.fetch_add(1, Ordering::Release);
self.extant_permits.0.send_modify(|ep| *ep += 1);
@@ -331,7 +335,7 @@ impl<SK: SlotKind> Drop for OwnedMeteredSemPermit<SK> {
if let Some(uc) = self.unused_claimants.take() {
uc.fetch_sub(1, Ordering::Release);
}
(self.release_fn)(&self.release_ctx)
(self.release_fn)(&self.release_ctx);
}
}
impl<SK: SlotKind> Debug for OwnedMeteredSemPermit<SK> {
32 changes: 14 additions & 18 deletions core/src/worker/activities/local_activities.rs
@@ -1,11 +1,8 @@
use crate::{
abstractions::{
dbg_panic, MeteredPermitDealer, OwnedMeteredSemPermit, PermitDealerContextData,
UsedMeteredSemPermit,
},
abstractions::{dbg_panic, MeteredPermitDealer, OwnedMeteredSemPermit, UsedMeteredSemPermit},
protosext::ValidScheduleLA,
retry_logic::RetryPolicyExt,
telemetry::metrics::{activity_type, local_activity_worker_type, workflow_type},
telemetry::metrics::{activity_type, workflow_type},
worker::workflow::HeartbeatTimeoutMsg,
MetricsContext, TaskToken,
};
@@ -17,11 +14,10 @@ use std::{
collections::{hash_map::Entry, HashMap},
fmt::{Debug, Formatter},
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::{Duration, Instant, SystemTime},
};
use temporal_sdk_core_api::worker::{LocalActivitySlotKind, SlotSupplier};
use temporal_sdk_core_api::worker::LocalActivitySlotKind;
use temporal_sdk_core_protos::{
coresdk::{
activity_result::{Cancellation, Failure as ActFail, Success},
@@ -214,26 +210,19 @@ impl LAMData {

impl LocalActivityManager {
pub(crate) fn new(
slot_supplier: Arc<dyn SlotSupplier<SlotKind = LocalActivitySlotKind> + Send + Sync>,
namespace: String,
permit_dealer: MeteredPermitDealer<LocalActivitySlotKind>,
heartbeat_timeout_tx: UnboundedSender<HeartbeatTimeoutMsg>,
metrics_context: MetricsContext,
context_data: Arc<PermitDealerContextData>,
) -> Self {
let (act_req_tx, act_req_rx) = unbounded_channel();
let (cancels_req_tx, cancels_req_rx) = unbounded_channel();
let shutdown_complete_tok = CancellationToken::new();
let semaphore = MeteredPermitDealer::new(
slot_supplier,
metrics_context.with_new_attrs([local_activity_worker_type()]),
None,
context_data,
);
Self {
namespace,
rcvs: tokio::sync::Mutex::new(RcvChans::new(
act_req_rx,
semaphore,
permit_dealer,
cancels_req_rx,
shutdown_complete_tok.clone(),
)),
@@ -255,15 +244,20 @@ impl LocalActivityManager {
#[cfg(test)]
fn test(max_concurrent: usize) -> Self {
use crate::worker::tuner::FixedSizeSlotSupplier;
use std::sync::Arc;

let ss = Arc::new(FixedSizeSlotSupplier::new(max_concurrent));
let (hb_tx, _hb_rx) = unbounded_channel();
Self::new(
ss,
"fake_ns".to_string(),
MeteredPermitDealer::new(
ss,
MetricsContext::no_op(),
None,
Arc::new(Default::default()),
),
hb_tx,
MetricsContext::no_op(),
Arc::new(Default::default()),
)
}

@@ -740,6 +734,8 @@ impl LocalActivityManager {
while !self.set_shutdown_complete_if_ready(&mut self.dat.lock()) {
self.complete_notify.notified().await;
}
// This makes sure we drop any permits that might be held inside the stream
self.rcvs.lock().await.inner = stream::empty().boxed();
}

/// Try to close the activity stream as soon as worker shutdown is initiated. This is required
47 changes: 43 additions & 4 deletions core/src/worker/mod.rs
@@ -49,6 +49,7 @@ use std::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use temporal_client::{ConfiguredClient, TemporalServiceClientWithMetrics, WorkerKey};
use temporal_sdk_core_protos::{
@@ -66,11 +67,13 @@ use temporal_sdk_core_protos::{
},
TaskToken,
};
use tokio::sync::mpsc::unbounded_channel;
use tokio::sync::{mpsc::unbounded_channel, watch};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::sync::CancellationToken;

use crate::abstractions::PermitDealerContextData;
use crate::{
abstractions::PermitDealerContextData, telemetry::metrics::local_activity_worker_type,
};
use temporal_sdk_core_api::errors::WorkerValidationError;
#[cfg(test)]
use {
@@ -103,6 +106,22 @@ pub struct Worker {
non_local_activities_complete: Arc<AtomicBool>,
/// Set when local activities are complete and should stop being polled
local_activities_complete: Arc<AtomicBool>,
/// Used to track that all permits have been released
all_permits_tracker: tokio::sync::Mutex<AllPermitsTracker>,
}

struct AllPermitsTracker {
wft_permits: watch::Receiver<usize>,
act_permits: watch::Receiver<usize>,
la_permits: watch::Receiver<usize>,
}

impl AllPermitsTracker {
async fn all_done(&mut self) {
let _ = self.wft_permits.wait_for(|x| *x == 0).await;
let _ = self.act_permits.wait_for(|x| *x == 0).await;
let _ = self.la_permits.wait_for(|x| *x == 0).await;
}
}

#[async_trait::async_trait]
@@ -288,12 +307,14 @@ impl Worker {
},
slot_context_data.clone(),
);
let wft_permits = wft_slots.get_extant_count_rcv();
let act_slots = MeteredPermitDealer::new(
tuner.activity_task_slot_supplier(),
metrics.with_new_attrs([activity_worker_type()]),
None,
slot_context_data.clone(),
);
let act_permits = act_slots.get_extant_count_rcv();
let (external_wft_tx, external_wft_rx) = unbounded_channel();
let (wft_stream, act_poller) = match task_pollers {
TaskPollers::Real => {
Expand Down Expand Up @@ -390,12 +411,18 @@ impl Worker {
};

let (hb_tx, hb_rx) = unbounded_channel();
let local_act_mgr = Arc::new(LocalActivityManager::new(
let la_permit_dealer = MeteredPermitDealer::new(
tuner.local_activity_slot_supplier(),
metrics.with_new_attrs([local_activity_worker_type()]),
None,
slot_context_data,
);
let la_permits = la_permit_dealer.get_extant_count_rcv();
let local_act_mgr = Arc::new(LocalActivityManager::new(
config.namespace.clone(),
la_permit_dealer,
hb_tx,
metrics.clone(),
slot_context_data,
));
let at_task_mgr = act_poller.map(|ap| {
WorkerActivityTasks::new(
@@ -463,6 +490,11 @@ impl Worker {
// Non-local activities are already complete if configured not to poll for them.
non_local_activities_complete: Arc::new(AtomicBool::new(!poll_on_non_local_activities)),
local_activities_complete: Default::default(),
all_permits_tracker: tokio::sync::Mutex::new(AllPermitsTracker {
wft_permits,
act_permits,
la_permits,
}),
}
}

@@ -484,6 +516,13 @@ impl Worker {
if let Some(acts) = self.at_task_mgr.as_ref() {
acts.shutdown().await;
}
// Wait for all permits to be released, but don't totally hang real-world shutdown.
tokio::select! {
_ = async { self.all_permits_tracker.lock().await.all_done().await } => {},
_ = tokio::time::sleep(Duration::from_secs(1)) => {
dbg_panic!("Waiting for all slot permits to release took too long!");
}
};
}

/// Finish shutting down by consuming the background pollers and freeing all resources
4 changes: 2 additions & 2 deletions core/src/worker/workflow/mod.rs
@@ -194,8 +194,8 @@ impl Workflows {
local_activity_request_sink,
);

// However, we want to avoid plowing ahead until we've been asked to poll at least
// once. This supports activity-only workers.
// However, we want to avoid plowing ahead until we've been asked to poll at
// least once. This supports activity-only workers.
let do_poll = tokio::select! {
sp = start_polling_rx => {
sp.is_ok()
