From 4a1200bdf8f29e2a67357dbf5c31c28505ace1a3 Mon Sep 17 00:00:00 2001
From: Spencer Judge
Date: Mon, 6 Jan 2025 15:40:27 -0800
Subject: [PATCH 01/14] Add nexus polling to WorkerClient

---
 core/src/worker/client.rs       | 20 ++++++++++++++++++++
 core/src/worker/client/mocks.rs |  7 +++++--
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/core/src/worker/client.rs b/core/src/worker/client.rs
index ad3b7b6ae..527a4845b 100644
--- a/core/src/worker/client.rs
+++ b/core/src/worker/client.rs
@@ -104,6 +104,7 @@ pub(crate) trait WorkerClient: Sync + Send {
         task_queue: String,
         max_tasks_per_sec: Option<f64>,
     ) -> Result<PollActivityTaskQueueResponse>;
+    async fn poll_nexus_task(&self, task_queue: String) -> Result<PollNexusTaskQueueResponse>;
     async fn complete_workflow_task(
         &self,
         request: WorkflowTaskCompletion,
     ) -> Result<RespondWorkflowTaskCompletedResponse>;
@@ -201,6 +202,25 @@ impl WorkerClient for WorkerClientBag {
             .into_inner())
     }
 
+    async fn poll_nexus_task(&self, task_queue: String) -> Result<PollNexusTaskQueueResponse> {
+        let request = PollNexusTaskQueueRequest {
+            namespace: self.namespace.clone(),
+            task_queue: Some(TaskQueue {
+                name: task_queue,
+                kind: TaskQueueKind::Normal as i32,
+                normal_name: "".to_string(),
+            }),
+            identity: self.identity.clone(),
+            worker_version_capabilities: self.worker_version_capabilities(),
+        };
+
+        Ok(self
+            .cloned_client()
+            .poll_nexus_task_queue(request)
+            .await?
+            .into_inner())
+    }
+
     async fn complete_workflow_task(
         &self,
         request: WorkflowTaskCompletion,
diff --git a/core/src/worker/client/mocks.rs b/core/src/worker/client/mocks.rs
index 4bd296816..645f21ec9 100644
--- a/core/src/worker/client/mocks.rs
+++ b/core/src/worker/client/mocks.rs
@@ -1,7 +1,6 @@
 use super::*;
 use futures_util::Future;
-use std::sync::Arc;
-use std::sync::LazyLock;
+use std::sync::{Arc, LazyLock};
 use temporal_client::SlotManager;
 
 pub(crate) static DEFAULT_WORKERS_REGISTRY: LazyLock<Arc<SlotManager>> =
@@ -61,6 +60,10 @@ mockall::mock! {
             -> impl Future<Output = Result<PollActivityTaskQueueResponse>> + Send + 'b
             where 'a: 'b, Self: 'b;
 
+        fn poll_nexus_task<'a, 'b>(&self, task_queue: String)
+            -> impl Future<Output = Result<PollNexusTaskQueueResponse>> + Send + 'b
+            where 'a: 'b, Self: 'b;
+
         fn complete_workflow_task<'a, 'b>(
             &self,
             request: WorkflowTaskCompletion,
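Note on the patch above: the new trait method mirrors the existing activity poll path exactly. As orientation, here is a minimal, hypothetical sketch of core-internal code driving the new call in a long-poll loop. The function, queue name, and loop shape are illustrative only and are not part of the patch:

    // Hypothetical illustration; assumes a `client: Arc<dyn WorkerClient>` is in scope.
    async fn poll_nexus_loop(client: std::sync::Arc<dyn WorkerClient>) -> Result<()> {
        loop {
            let resp = client.poll_nexus_task("my-task-queue".to_string()).await?;
            if resp == PollNexusTaskQueueResponse::default() {
                // The server returns the default proto when a long poll times out.
                continue;
            }
            // hand `resp` off for processing here
        }
    }
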
From 499f3e12fa1e065abb9a6d4f362649541a5d42b9 Mon Sep 17 00:00:00 2001
From: Spencer Judge
Date: Mon, 6 Jan 2025 15:57:18 -0800
Subject: [PATCH 02/14] Add slot kind

---
 core-api/src/worker.rs                             | 23 ++++++++++++++-
 core/src/pollers/poll_buffer.rs                    | 29 ++++++++++++++++--
 core/src/worker/tuner.rs                           | 28 ++++++++++++++--
 core/src/worker/tuner/resource_based.rs            | 28 ++++++++++++++--
 .../temporal/sdk/core/core_interface.proto         |  6 ++++
 5 files changed, 106 insertions(+), 8 deletions(-)

diff --git a/core-api/src/worker.rs b/core-api/src/worker.rs
index 66c05b1f5..c1b4fc9a4 100644
--- a/core-api/src/worker.rs
+++ b/core-api/src/worker.rs
@@ -6,7 +6,7 @@ use std::{
     time::Duration,
 };
 use temporal_sdk_core_protos::coresdk::{
-    ActivitySlotInfo, LocalActivitySlotInfo, WorkflowSlotInfo,
+    ActivitySlotInfo, LocalActivitySlotInfo, NexusSlotInfo, WorkflowSlotInfo,
 };
 
 const MAX_CONCURRENT_WFT_POLLS_DEFAULT: usize = 5;
@@ -263,6 +263,11 @@ pub trait WorkerTuner {
         &self,
     ) -> Arc<dyn SlotSupplier<SlotKind = LocalActivitySlotKind> + Send + Sync>;
 
+    /// Return a [SlotSupplier] for nexus tasks
+    fn nexus_task_slot_supplier(
+        &self,
+    ) -> Arc<dyn SlotSupplier<SlotKind = NexusSlotKind> + Send + Sync>;
+
     /// Core will call this at worker initialization time, allowing the implementation to hook up to
     /// metrics if any are configured. If not, it will not be called.
     fn attach_metrics(&self, metrics: TemporalMeter);
@@ -364,6 +369,7 @@ pub enum SlotKindType {
     Workflow,
     Activity,
     LocalActivity,
+    Nexus,
 }
 
 #[derive(Debug, Copy, Clone)]
@@ -372,11 +378,14 @@ pub struct WorkflowSlotKind {}
 pub struct ActivitySlotKind {}
 #[derive(Debug, Copy, Clone)]
 pub struct LocalActivitySlotKind {}
+#[derive(Debug, Copy, Clone)]
+pub struct NexusSlotKind {}
 
 pub enum SlotInfo<'a> {
     Workflow(&'a WorkflowSlotInfo),
     Activity(&'a ActivitySlotInfo),
     LocalActivity(&'a LocalActivitySlotInfo),
+    Nexus(&'a NexusSlotInfo),
 }
 
 pub trait SlotInfoTrait: prost::Message {
@@ -397,6 +406,11 @@ impl SlotInfoTrait for LocalActivitySlotInfo {
         SlotInfo::LocalActivity(self)
     }
 }
+impl SlotInfoTrait for NexusSlotInfo {
+    fn downcast(&self) -> SlotInfo {
+        SlotInfo::Nexus(self)
+    }
+}
 
 pub trait SlotKind {
     type Info: SlotInfoTrait;
@@ -424,3 +438,10 @@ impl SlotKind for LocalActivitySlotKind {
         SlotKindType::LocalActivity
     }
 }
+impl SlotKind for NexusSlotKind {
+    type Info = NexusSlotInfo;
+
+    fn kind() -> SlotKindType {
+        SlotKindType::Nexus
+    }
+}
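With the variants added above, downstream code can match uniformly over every slot kind. A small illustrative sketch (not part of the patch; the function is hypothetical, and the `WorkflowSlotInfo`/`ActivitySlotInfo` field names are assumed from the existing protos):

    // Hypothetical example of consuming the widened SlotInfo enum.
    fn describe(info: SlotInfo<'_>) -> String {
        match info {
            SlotInfo::Workflow(w) => format!("workflow: {}", w.workflow_type),
            SlotInfo::Activity(a) => format!("activity: {}", a.activity_type),
            SlotInfo::LocalActivity(la) => format!("local activity: {}", la.activity_type),
            // New in this patch: nexus slots report the service and operation names.
            SlotInfo::Nexus(n) => format!("nexus: {}/{}", n.service, n.operation),
        }
    }
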
diff --git a/core/src/pollers/poll_buffer.rs b/core/src/pollers/poll_buffer.rs
index 32fe8674a..6b21ec50b 100644
--- a/core/src/pollers/poll_buffer.rs
+++ b/core/src/pollers/poll_buffer.rs
@@ -14,10 +14,12 @@ use std::{
     },
     time::Duration,
 };
-use temporal_sdk_core_api::worker::{ActivitySlotKind, SlotKind, WorkflowSlotKind};
+use temporal_sdk_core_api::worker::{ActivitySlotKind, NexusSlotKind, SlotKind, WorkflowSlotKind};
 use temporal_sdk_core_protos::temporal::api::{
     taskqueue::v1::TaskQueue,
-    workflowservice::v1::{PollActivityTaskQueueResponse, PollWorkflowTaskQueueResponse},
+    workflowservice::v1::{
+        PollActivityTaskQueueResponse, PollNexusTaskQueueResponse, PollWorkflowTaskQueueResponse,
+    },
 };
 use tokio::{
     sync::{
@@ -297,6 +299,29 @@ pub(crate) fn new_activity_task_buffer(
     )
 }
 
+pub(crate) type PollNexusTaskBuffer = LongPollBuffer<PollNexusTaskQueueResponse, NexusSlotKind>;
+pub(crate) fn new_nexus_task_buffer(
+    client: Arc<dyn WorkerClient>,
+    task_queue: String,
+    concurrent_pollers: usize,
+    semaphore: MeteredPermitDealer<NexusSlotKind>,
+    shutdown: CancellationToken,
+    num_pollers_handler: Option<impl Fn(usize) + Send + Sync + 'static>,
+) -> PollNexusTaskBuffer {
+    LongPollBuffer::new(
+        move || {
+            let client = client.clone();
+            let task_queue = task_queue.clone();
+            async move { client.poll_nexus_task(task_queue).await }
+        },
+        semaphore,
+        concurrent_pollers,
+        shutdown,
+        num_pollers_handler,
+        None::<fn(&mut PollNexusTaskQueueResponse) -> BoxFuture<'static, ()>>,
+    )
+}
+
 #[cfg(test)]
 #[derive(derive_more::Constructor)]
 pub(crate) struct MockPermittedPollBuffer<PT> {
diff --git a/core/src/worker/tuner.rs b/core/src/worker/tuner.rs
index 6b3506b29..1a34c7c1a 100644
--- a/core/src/worker/tuner.rs
+++ b/core/src/worker/tuner.rs
@@ -11,8 +11,8 @@ use std::sync::{Arc, OnceLock};
 use temporal_sdk_core_api::{
     telemetry::metrics::TemporalMeter,
     worker::{
-        ActivitySlotKind, LocalActivitySlotKind, SlotKind, SlotSupplier, WorkerConfig, WorkerTuner,
-        WorkflowSlotKind,
+        ActivitySlotKind, LocalActivitySlotKind, NexusSlotKind, SlotKind, SlotSupplier,
+        WorkerConfig, WorkerTuner, WorkflowSlotKind,
     },
 };
 
@@ -21,6 +21,7 @@ pub struct TunerHolder {
     wft_supplier: Arc<dyn SlotSupplier<SlotKind = WorkflowSlotKind> + Send + Sync>,
     act_supplier: Arc<dyn SlotSupplier<SlotKind = ActivitySlotKind> + Send + Sync>,
     la_supplier: Arc<dyn SlotSupplier<SlotKind = LocalActivitySlotKind> + Send + Sync>,
+    nexus_supplier: Arc<dyn SlotSupplier<SlotKind = NexusSlotKind> + Send + Sync>,
     metrics: OnceLock<TemporalMeter>,
 }
 
@@ -39,6 +40,9 @@ pub struct TunerHolderOptions {
     /// Options for local activity slots
     #[builder(default, setter(strip_option))]
     pub local_activity_slot_options: Option<SlotSupplierOptions<LocalActivitySlotKind>>,
+    /// Options for nexus slots
+    #[builder(default, setter(strip_option))]
+    pub nexus_slot_options: Option<SlotSupplierOptions<NexusSlotKind>>,
     /// Options that will apply to all resource based slot suppliers. Must be set if any slot
     /// options are [SlotSupplierOptions::ResourceBased]
     #[builder(default, setter(strip_option))]
@@ -165,6 +169,7 @@ pub struct TunerBuilder {
         Option<Arc<dyn SlotSupplier<SlotKind = ActivitySlotKind> + Send + Sync>>,
     local_activity_slot_supplier:
         Option<Arc<dyn SlotSupplier<SlotKind = LocalActivitySlotKind> + Send + Sync>>,
+    nexus_slot_supplier: Option<Arc<dyn SlotSupplier<SlotKind = NexusSlotKind> + Send + Sync>>,
 }
 
 impl TunerBuilder {
@@ -209,6 +214,15 @@ impl TunerBuilder {
         self
     }
 
+    /// Set a nexus slot supplier
+    pub fn nexus_slot_supplier(
+        &mut self,
+        supplier: Arc<dyn SlotSupplier<SlotKind = NexusSlotKind> + Send + Sync>,
+    ) -> &mut Self {
+        self.nexus_slot_supplier = Some(supplier);
+        self
+    }
+
     /// Build a [WorkerTuner] from the configured slot suppliers
     pub fn build(&mut self) -> TunerHolder {
         TunerHolder {
@@ -224,6 +238,10 @@ impl TunerBuilder {
                 .local_activity_slot_supplier
                 .clone()
                 .unwrap_or_else(|| Arc::new(FixedSizeSlotSupplier::new(100))),
+            nexus_supplier: self
+                .nexus_slot_supplier
+                .clone()
+                .unwrap_or_else(|| Arc::new(FixedSizeSlotSupplier::new(100))),
             metrics: OnceLock::new(),
         }
     }
@@ -248,6 +266,12 @@ impl WorkerTuner for TunerHolder {
         self.la_supplier.clone()
     }
 
+    fn nexus_task_slot_supplier(
+        &self,
+    ) -> Arc<dyn SlotSupplier<SlotKind = NexusSlotKind> + Send + Sync> {
+        self.nexus_supplier.clone()
+    }
+
     fn attach_metrics(&self, m: TemporalMeter) {
         let _ = self.metrics.set(m);
     }
diff --git a/core/src/worker/tuner/resource_based.rs b/core/src/worker/tuner/resource_based.rs
index 63b51bc29..792411f4d 100644
--- a/core/src/worker/tuner/resource_based.rs
+++ b/core/src/worker/tuner/resource_based.rs
@@ -11,9 +11,9 @@ use std::{
 use temporal_sdk_core_api::{
     telemetry::metrics::{CoreMeter, GaugeF64, MetricAttributes, TemporalMeter},
     worker::{
-        ActivitySlotKind, LocalActivitySlotKind, SlotInfo, SlotInfoTrait, SlotKind, SlotKindType,
-        SlotMarkUsedContext, SlotReleaseContext, SlotReservationContext, SlotSupplier,
-        SlotSupplierPermit, WorkerTuner, WorkflowSlotKind,
+        ActivitySlotKind, LocalActivitySlotKind, NexusSlotKind, SlotInfo, SlotInfoTrait, SlotKind,
+        SlotKindType, SlotMarkUsedContext, SlotReleaseContext, SlotReservationContext,
+        SlotSupplier, SlotSupplierPermit, WorkerTuner, WorkflowSlotKind,
     },
 };
 use tokio::{sync::watch, task::JoinHandle};
@@ -30,6 +30,7 @@ pub struct ResourceBasedTuner<MI> {
     wf_opts: Option<ResourceSlotOptions>,
     act_opts: Option<ResourceSlotOptions>,
     la_opts: Option<ResourceSlotOptions>,
+    nexus_opts: Option<ResourceSlotOptions>,
 }
 
 impl ResourceBasedTuner<RealSysInfo> {
@@ -59,6 +60,7 @@ impl<MI> ResourceBasedTuner<MI> {
             wf_opts: None,
             act_opts: None,
             la_opts: None,
+            nexus_opts: None,
         }
     }
 
@@ -79,6 +81,12 @@ impl<MI> ResourceBasedTuner<MI> {
         self.la_opts = Some(opts);
         self
     }
+
+    /// Set nexus slot options
+    pub fn with_nexus_slots_options(&mut self, opts: ResourceSlotOptions) -> &mut Self {
+        self.nexus_opts = Some(opts);
+        self
+    }
 }
 
 const DEFAULT_WF_SLOT_OPTS: ResourceSlotOptions = ResourceSlotOptions {
@@ -91,6 +99,13 @@ const DEFAULT_ACT_SLOT_OPTS: ResourceSlotOptions = ResourceSlotOptions {
     max_slots: 10_000,
     ramp_throttle: Duration::from_millis(50),
 };
+const DEFAULT_NEXUS_SLOT_OPTS: ResourceSlotOptions = ResourceSlotOptions {
+    min_slots: 1,
+    max_slots: 10_000,
+    // No ramp is chosen under the assumption that nexus tasks are unlikely to use many resources
+    // and would prefer lowest latency over protection against oversubscription.
+    ramp_throttle: Duration::from_millis(0),
+};
 
 /// Options for a specific slot type
 #[derive(Debug, Clone, Copy, derive_more::Constructor)]
@@ -375,6 +390,13 @@ impl WorkerTuner for ResourceBasedTuner {
         self.slots.as_kind(o)
     }
 
+    fn nexus_task_slot_supplier(
+        &self,
+    ) -> Arc<dyn SlotSupplier<SlotKind = NexusSlotKind> + Send + Sync> {
+        let o = self.nexus_opts.unwrap_or(DEFAULT_NEXUS_SLOT_OPTS);
+        self.slots.as_kind(o)
+    }
+
     fn attach_metrics(&self, metrics: TemporalMeter) {
         self.slots.attach_metrics(metrics);
     }
diff --git a/sdk-core-protos/protos/local/temporal/sdk/core/core_interface.proto b/sdk-core-protos/protos/local/temporal/sdk/core/core_interface.proto
index 089658cc5..bfe7a0a37 100644
--- a/sdk-core-protos/protos/local/temporal/sdk/core/core_interface.proto
+++ b/sdk-core-protos/protos/local/temporal/sdk/core/core_interface.proto
@@ -45,3 +45,9 @@ message ActivitySlotInfo {
 message LocalActivitySlotInfo {
   string activity_type = 1;
 }
+
+// Info about nexus task slot usage
+message NexusSlotInfo {
+  string service = 1;
+  string operation = 2;
+}
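Because `ResourceSlotOptions` derives `Constructor`, callers can override the zero-ramp nexus default shown above. A hedged usage sketch — the `ResourceBasedTuner::new` arguments are assumed to be the target memory/CPU fractions, and the numbers are arbitrary:

    // Illustrative only: raise the nexus slot floor and cap, keeping no ramp throttle.
    let mut tuner = ResourceBasedTuner::new(0.7, 0.7);
    tuner.with_nexus_slots_options(ResourceSlotOptions::new(
        2,                        // min_slots
        500,                      // max_slots
        Duration::from_millis(0), // ramp_throttle
    ));
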
From 552cc879e1dcac751a69bd127729d8cb1e80228a Mon Sep 17 00:00:00 2001
From: Spencer Judge
Date: Tue, 7 Jan 2025 10:56:08 -0800
Subject: [PATCH 03/14] Generic polling streams

---
 core/src/pollers/mod.rs                            | 165 +++++++++++++++++-
 core/src/worker/activities.rs                      |  46 ++---
 .../activities/activity_task_poller_stream.rs      |  78 ---------
 core/src/worker/workflow/mod.rs                    |   3 +-
 4 files changed, 180 insertions(+), 112 deletions(-)
 delete mode 100644 core/src/worker/activities/activity_task_poller_stream.rs

diff --git a/core/src/pollers/mod.rs b/core/src/pollers/mod.rs
index 5e7be07f7..00a89ae99 100644
--- a/core/src/pollers/mod.rs
+++ b/core/src/pollers/mod.rs
@@ -8,21 +8,28 @@ pub use temporal_client::{
     TlsConfig, WorkflowClientTrait,
 };
 
-use crate::abstractions::OwnedMeteredSemPermit;
+use crate::{
+    abstractions::{OwnedMeteredSemPermit, TrackedOwnedMeteredSemPermit},
+    telemetry::metrics::MetricsContext,
+};
+use anyhow::anyhow;
+use futures_util::{stream, Stream};
+use std::{fmt::Debug, marker::PhantomData};
+use temporal_sdk_core_api::worker::{ActivitySlotKind, NexusSlotKind, SlotKind, WorkflowSlotKind};
 use temporal_sdk_core_protos::temporal::api::workflowservice::v1::{
-    PollActivityTaskQueueResponse, PollWorkflowTaskQueueResponse,
+    PollActivityTaskQueueResponse, PollNexusTaskQueueResponse, PollWorkflowTaskQueueResponse,
 };
+use tokio::select;
+use tokio_util::sync::CancellationToken;
 
 #[cfg(test)]
 use futures_util::Future;
 #[cfg(test)]
 pub(crate) use poll_buffer::MockPermittedPollBuffer;
-use temporal_sdk_core_api::worker::{ActivitySlotKind, WorkflowSlotKind};
 
 pub(crate) type Result<T, E = tonic::Status> = std::result::Result<T, E>;
 
-/// A trait for things that poll the server. Hides complexity of concurrent polling or polling
-/// on sticky/nonsticky queues simultaneously.
+/// A trait for things that long poll the server.
 #[cfg_attr(test, mockall::automock)]
 #[cfg_attr(test, allow(unused))]
 #[async_trait::async_trait]
@@ -45,6 +52,10 @@ pub(crate) type BoxedActPoller = BoxedPoller<(
     PollActivityTaskQueueResponse,
     OwnedMeteredSemPermit<ActivitySlotKind>,
 )>;
+pub(crate) type BoxedNexusPoller = BoxedPoller<(
+    PollNexusTaskQueueResponse,
+    OwnedMeteredSemPermit<NexusSlotKind>,
+)>;
 
 #[async_trait::async_trait]
 impl<T> Poller<T> for Box<dyn Poller<T> + Send + Sync>
@@ -85,3 +96,147 @@ mockall::mock! {
         where Self: 'a;
     }
 }
+
+#[derive(Debug)]
+pub(crate) struct PermittedTqResp<T: ValidatableTask> {
+    pub(crate) permit: OwnedMeteredSemPermit<T::SlotKind>,
+    pub(crate) resp: T,
+}
+
+#[derive(Debug)]
+pub(crate) struct TrackedPermittedTqResp<T: ValidatableTask> {
+    pub(crate) permit: TrackedOwnedMeteredSemPermit<T::SlotKind>,
+    pub(crate) resp: T,
+}
+
+// Trait for validatable task responses
+pub(crate) trait ValidatableTask:
+    Debug + Default + PartialEq + Send + Sync + 'static
+{
+    type SlotKind: SlotKind;
+
+    fn validate(&self) -> Result<(), anyhow::Error>;
+    fn task_name() -> &'static str;
+}
+
+pub(crate) struct TaskPollerStream<P, T>
+where
+    P: Poller<(T, OwnedMeteredSemPermit<T::SlotKind>)>,
+    T: ValidatableTask,
+{
+    poller: P,
+    metrics: MetricsContext,
+    metrics_no_task: fn(&MetricsContext),
+    shutdown_token: CancellationToken,
+    poller_was_shutdown: bool,
+    _phantom: PhantomData<T>,
+}
+
+impl<P, T> TaskPollerStream<P, T>
+where
+    P: Poller<(T, OwnedMeteredSemPermit<T::SlotKind>)>,
+    T: ValidatableTask,
+{
+    pub(crate) fn new(
+        poller: P,
+        metrics: MetricsContext,
+        metrics_no_task: fn(&MetricsContext),
+        shutdown_token: CancellationToken,
+    ) -> Self {
+        Self {
+            poller,
+            metrics,
+            metrics_no_task,
+            shutdown_token,
+            poller_was_shutdown: false,
+            _phantom: PhantomData,
+        }
+    }
+
+    fn into_stream(self) -> impl Stream<Item = Result<PermittedTqResp<T>, tonic::Status>> {
+        stream::unfold(self, |mut state| async move {
+            loop {
+                let poll = async {
+                    loop {
+                        return match state.poller.poll().await {
+                            Some(Ok((task, permit))) => {
+                                if task == Default::default() {
+                                    // We get the default proto in the event that the long poll
+                                    // times out.
+                                    debug!("Poll {} task timeout", T::task_name());
+                                    (state.metrics_no_task)(&state.metrics);
+                                    continue;
+                                }
+
+                                if let Err(e) = task.validate() {
+                                    warn!(
+                                        "Received invalid {} task ({}): {:?}",
+                                        T::task_name(),
+                                        e,
+                                        &task
+                                    );
+                                    return Some(Err(tonic::Status::invalid_argument(
+                                        e.to_string(),
+                                    )));
+                                }
+
+                                Some(Ok(PermittedTqResp { resp: task, permit }))
+                            }
+                            Some(Err(e)) => {
+                                warn!(error=?e, "Error while polling for {} tasks", T::task_name());
+                                Some(Err(e))
+                            }
+                            // If poller returns None, it's dead, thus we also return None to
+                            // terminate this stream.
+                            None => None,
+                        };
+                    }
+                };
+                if state.poller_was_shutdown {
+                    return poll.await.map(|res| (res, state));
+                }
+                select! {
+                    biased;
+
+                    _ = state.shutdown_token.cancelled() => {
+                        state.poller.notify_shutdown();
+                        state.poller_was_shutdown = true;
+                        continue;
+                    }
+                    res = poll => {
+                        return res.map(|res| (res, state));
+                    }
+                }
+            }
+        })
+    }
+}
+
+impl ValidatableTask for PollActivityTaskQueueResponse {
+    type SlotKind = ActivitySlotKind;
+
+    fn validate(&self) -> Result<(), anyhow::Error> {
+        if self.task_token.is_empty() {
+            return Err(anyhow!("missing task token"));
+        }
+        Ok(())
+    }
+
+    fn task_name() -> &'static str {
+        "activity"
+    }
+}
+
+pub(crate) fn new_activity_task_poller(
+    poller: BoxedActPoller,
+    metrics: MetricsContext,
+    shutdown_token: CancellationToken,
+) -> impl Stream<Item = Result<PermittedTqResp<PollActivityTaskQueueResponse>, tonic::Status>> {
+    TaskPollerStream::new(
+        poller,
+        metrics,
+        MetricsContext::act_poll_timeout,
+        shutdown_token,
+    )
+    .into_stream()
+}
diff --git a/core/src/worker/activities.rs b/core/src/worker/activities.rs
index 643dfb568..31402df9a 100644
--- a/core/src/worker/activities.rs
+++ b/core/src/worker/activities.rs
@@ -1,5 +1,4 @@
 mod activity_heartbeat_manager;
-mod activity_task_poller_stream;
 mod local_activities;
 
 pub(crate) use local_activities::{
@@ -9,17 +8,13 @@ pub(crate) use local_activities::{
 
 use crate::{
     abstractions::{
-        ClosableMeteredPermitDealer, MeteredPermitDealer, OwnedMeteredSemPermit,
-        TrackedOwnedMeteredSemPermit, UsedMeteredSemPermit,
+        ClosableMeteredPermitDealer, MeteredPermitDealer, TrackedOwnedMeteredSemPermit,
+        UsedMeteredSemPermit,
     },
-    pollers::BoxedActPoller,
+    pollers::{new_activity_task_poller, BoxedActPoller, PermittedTqResp, TrackedPermittedTqResp},
     telemetry::metrics::{activity_type, eager, workflow_type, MetricsContext},
     worker::{
-        activities::{
-            activity_heartbeat_manager::ActivityHeartbeatError,
-            activity_task_poller_stream::new_activity_task_poller,
-        },
-        client::WorkerClient,
+        activities::activity_heartbeat_manager::ActivityHeartbeatError, client::WorkerClient,
     },
     PollActivityError, TaskToken,
 };
@@ -157,7 +152,7 @@ pub(crate) struct WorkerActivityTasks {
     /// Holds activity tasks we have received in direct response to workflow task completion (a.k.a
     /// eager activities). Tasks received in this stream hold a "tracked" permit that is issued by
     /// the `eager_activities_semaphore`.
-    eager_activities_tx: UnboundedSender<TrackedPermittedTqResp>,
+    eager_activities_tx: UnboundedSender<TrackedPermittedTqResp<PollActivityTaskQueueResponse>>,
     /// Ensures that no activities are in the middle of flushing their results to server while we
     /// try to shut down.
     completers_lock: tokio::sync::RwLock<()>,
@@ -176,7 +171,7 @@ pub(crate) struct WorkerActivityTasks {
 #[derive(derive_more::From)]
 enum ActivityTaskSource {
     PendingCancel(PendingActivityCancel),
-    PendingStart(Result<(PermittedTqResp, bool), PollActivityError>),
+    PendingStart(Result<(PermittedTqResp<PollActivityTaskQueueResponse>, bool), PollActivityError>),
 }
 
 impl WorkerActivityTasks {
@@ -245,11 +240,15 @@ impl WorkerActivityTasks {
     /// Merges the server poll and eager [ActivityTask] sources
     fn merge_start_task_sources(
-        non_poll_tasks_rx: UnboundedReceiver<TrackedPermittedTqResp>,
-        poller_stream: impl Stream<Item = Result<PermittedTqResp, tonic::Status>>,
+        non_poll_tasks_rx: UnboundedReceiver<TrackedPermittedTqResp<PollActivityTaskQueueResponse>>,
+        poller_stream: impl Stream<
+            Item = Result<PermittedTqResp<PollActivityTaskQueueResponse>, tonic::Status>,
+        >,
         eager_activities_semaphore: Arc<ClosableMeteredPermitDealer<ActivitySlotKind>>,
         on_complete_token: CancellationToken,
-    ) -> impl Stream<Item = Result<(PermittedTqResp, bool), PollActivityError>> {
+    ) -> impl Stream<
+        Item = Result<(PermittedTqResp<PollActivityTaskQueueResponse>, bool), PollActivityError>,
+    > {
         let non_poll_stream = stream::unfold(
             (non_poll_tasks_rx, eager_activities_semaphore),
             |(mut non_poll_tasks_rx, eager_activities_semaphore)| async move {
@@ -662,7 +661,7 @@ where
 /// Allows for the handling of activities returned by WFT completions.
 pub(crate) struct ActivitiesFromWFTsHandle {
     sem: Arc<ClosableMeteredPermitDealer<ActivitySlotKind>>,
-    tx: UnboundedSender<TrackedPermittedTqResp>,
+    tx: UnboundedSender<TrackedPermittedTqResp<PollActivityTaskQueueResponse>>,
 }
 
 impl ActivitiesFromWFTsHandle {
@@ -675,7 +674,10 @@ impl ActivitiesFromWFTsHandle {
     /// Queue new activity tasks for dispatch received from non-polling sources (ex: eager returns
     /// from WFT completion)
-    pub(crate) fn add_tasks(&self, tasks: impl IntoIterator<Item = TrackedPermittedTqResp>) {
+    pub(crate) fn add_tasks(
+        &self,
+        tasks: impl IntoIterator<Item = TrackedPermittedTqResp<PollActivityTaskQueueResponse>>,
+    ) {
         for t in tasks.into_iter() {
             // Technically we should be reporting `activity_task_received` here, but for simplicity
             // and time insensitivity, that metric is tracked in `about_to_issue_task`.
@@ -684,18 +686,6 @@ impl ActivitiesFromWFTsHandle {
     }
 }
 
-#[derive(Debug)]
-pub(crate) struct PermittedTqResp {
-    pub(crate) permit: OwnedMeteredSemPermit<ActivitySlotKind>,
-    pub(crate) resp: PollActivityTaskQueueResponse,
-}
-
-#[derive(Debug)]
-pub(crate) struct TrackedPermittedTqResp {
-    pub(crate) permit: TrackedOwnedMeteredSemPermit<ActivitySlotKind>,
-    pub(crate) resp: PollActivityTaskQueueResponse,
-}
-
 fn worker_shutdown_failure() -> Failure {
     Failure {
         message: "Worker is shutting down and this activity did not complete in time".to_string(),
diff --git a/core/src/worker/activities/activity_task_poller_stream.rs b/core/src/worker/activities/activity_task_poller_stream.rs
deleted file mode 100644
index 1b068ac04..000000000
--- a/core/src/worker/activities/activity_task_poller_stream.rs
+++ /dev/null
@@ -1,78 +0,0 @@
-use crate::{pollers::BoxedActPoller, worker::activities::PermittedTqResp, MetricsContext};
-use futures_util::{stream, Stream};
-use temporal_sdk_core_protos::temporal::api::workflowservice::v1::PollActivityTaskQueueResponse;
-use tokio::select;
-use tokio_util::sync::CancellationToken;
-
-struct StreamState {
-    poller: BoxedActPoller,
-    metrics: MetricsContext,
-    shutdown_token: CancellationToken,
-    poller_was_shutdown: bool,
-}
-
-pub(crate) fn new_activity_task_poller(
-    poller: BoxedActPoller,
-    metrics: MetricsContext,
-    shutdown_token: CancellationToken,
-) -> impl Stream<Item = Result<PermittedTqResp, tonic::Status>> {
-    let state = StreamState {
-        poller,
-        metrics,
-        shutdown_token,
-        poller_was_shutdown: false,
-    };
-    stream::unfold(state, |mut state| async move {
-        loop {
-            let poll = async {
-                loop {
-                    return match state.poller.poll().await {
-                        Some(Ok((resp, permit))) => {
-                            if resp == PollActivityTaskQueueResponse::default() {
-                                // We get the default proto in the event that the long poll times
-                                // out.
-                                debug!("Poll activity task timeout");
-                                state.metrics.act_poll_timeout();
-                                continue;
-                            }
-                            if let Some(reason) = validate_activity_task(&resp) {
-                                warn!("Received invalid activity task ({}): {:?}", reason, &resp);
-                                continue;
-                            }
-                            Some(Ok(PermittedTqResp { permit, resp }))
-                        }
-                        Some(Err(e)) => {
-                            warn!(error=?e, "Error while polling for activity tasks");
-                            Some(Err(e))
-                        }
-                        // If poller returns None, it's dead, thus we also return None to terminate
-                        // this stream.
-                        None => None,
-                    };
-                }
-            };
-            if state.poller_was_shutdown {
-                return poll.await.map(|res| (res, state));
-            }
-            select! {
-                biased;
-
-                _ = state.shutdown_token.cancelled() => {
-                    state.poller.notify_shutdown();
-                    state.poller_was_shutdown = true;
-                    continue;
-                }
-                res = poll => {
-                    return res.map(|res| (res, state));
-                }
-            }
-        }
-    })
-}
-
-fn validate_activity_task(task: &PollActivityTaskQueueResponse) -> Option<&'static str> {
-    if task.task_token.is_empty() {
-        return Some("missing task token");
-    }
-    None
-}
diff --git a/core/src/worker/workflow/mod.rs b/core/src/worker/workflow/mod.rs
index 1dfd53157..25475cd1c 100644
--- a/core/src/worker/workflow/mod.rs
+++ b/core/src/worker/workflow/mod.rs
@@ -20,10 +20,11 @@ use crate::{
         UsedMeteredSemPermit,
     },
     internal_flags::InternalFlags,
+    pollers::TrackedPermittedTqResp,
    protosext::{legacy_query_failure, protocol_messages::IncomingProtocolMessage},
     telemetry::{set_trace_subscriber_for_current_thread, TelemetryInstance, VecDisplayer},
     worker::{
-        activities::{ActivitiesFromWFTsHandle, LocalActivityManager, TrackedPermittedTqResp},
+        activities::{ActivitiesFromWFTsHandle, LocalActivityManager},
         client::{WorkerClient, WorkflowTaskCompletion},
         workflow::{
             history_update::HistoryPaginator,
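The behavior the new `ValidatableTask` machinery encodes is easiest to see through its activity implementation. A small, hypothetical unit-style check (not in the patch) of what the trait guarantees:

    // Illustrative test: a defaulted response fails validation (no task token).
    #[test]
    fn activity_task_without_token_is_invalid() {
        let resp = PollActivityTaskQueueResponse::default();
        assert!(resp.validate().is_err());
        assert_eq!(PollActivityTaskQueueResponse::task_name(), "activity");
    }
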
From 04b2fd9a7c0ea61c234b9c4aa114b3646daf4d8f Mon Sep 17 00:00:00 2001
From: Spencer Judge
Date: Tue, 7 Jan 2025 15:15:25 -0800
Subject: [PATCH 04/14] Polling Nexus tasks from worker

---
 core-api/src/lib.rs                                |  13 ++-
 core-api/src/worker.rs                             |   4 +
 core/src/pollers/mod.rs                            |  45 +++++++-
 core/src/telemetry/metrics.rs                      |  19 +++-
 core/src/test_help/mod.rs                          |   9 +-
 core/src/worker/mod.rs                             |  77 ++++++++++---
 core/src/worker/nexus.rs                           | 101 ++++++++++++++++++
 .../local/temporal/sdk/core/nexus/nexus.proto      |  13 +++
 tests/integ_tests/workflow_tests/nexus.rs          |  51 ++------
 9 files changed, 266 insertions(+), 66 deletions(-)
 create mode 100644 core/src/worker/nexus.rs

diff --git a/core-api/src/lib.rs b/core-api/src/lib.rs
index 7f9608b80..eb2930d95 100644
--- a/core-api/src/lib.rs
+++ b/core-api/src/lib.rs
@@ -9,9 +9,13 @@ use crate::{
     },
     worker::WorkerConfig,
 };
-use temporal_sdk_core_protos::coresdk::{
-    activity_task::ActivityTask, workflow_activation::WorkflowActivation,
-    workflow_completion::WorkflowActivationCompletion, ActivityHeartbeat, ActivityTaskCompletion,
+use temporal_sdk_core_protos::{
+    coresdk::{
+        activity_task::ActivityTask, workflow_activation::WorkflowActivation,
+        workflow_completion::WorkflowActivationCompletion, ActivityHeartbeat,
+        ActivityTaskCompletion,
+    },
+    temporal::api::workflowservice::v1::PollNexusTaskQueueResponse,
 };
 
 /// This trait is the primary way by which language specific SDKs interact with the core SDK.
@@ -45,6 +49,9 @@ pub trait Worker: Send + Sync {
     /// Do not call poll concurrently. It handles polling the server concurrently internally.
     async fn poll_activity_task(&self) -> Result<ActivityTask, PollActivityError>;
 
+    /// TODO: Keep or combine?
+    async fn poll_nexus_task(&self) -> Result<PollNexusTaskQueueResponse, PollActivityError>;
+
     /// Tell the worker that a workflow activation has completed. May (and should) be freely called
     /// concurrently. The future may take some time to resolve, as fetching more events might be
     /// necessary for completion to... complete - thus SDK implementers should make sure they do
diff --git a/core-api/src/worker.rs b/core-api/src/worker.rs
index c1b4fc9a4..a99b136c7 100644
--- a/core-api/src/worker.rs
+++ b/core-api/src/worker.rs
@@ -56,6 +56,10 @@ pub struct WorkerConfig {
     /// worker's task queue
     #[builder(default = "5")]
     pub max_concurrent_at_polls: usize,
+    /// Maximum number of concurrent poll nexus task requests we will perform at a time on this
+    /// worker's task queue
+    #[builder(default = "5")]
+    pub max_concurrent_nexus_polls: usize,
     /// If set to true this worker will only handle workflow tasks and local activities, it will not
     /// poll for activity tasks.
     #[builder(default = "false")]
diff --git a/core/src/pollers/mod.rs b/core/src/pollers/mod.rs
index 00a89ae99..4b2578b30 100644
--- a/core/src/pollers/mod.rs
+++ b/core/src/pollers/mod.rs
@@ -1,7 +1,7 @@
 mod poll_buffer;
 
 pub(crate) use poll_buffer::{
-    new_activity_task_buffer, new_workflow_task_buffer, WorkflowTaskPoller,
+    new_activity_task_buffer, new_nexus_task_buffer, new_workflow_task_buffer, WorkflowTaskPoller,
 };
 pub use temporal_client::{
     Client, ClientOptions, ClientOptionsBuilder, ClientTlsConfig, RetryClient, RetryConfig,
@@ -12,7 +12,7 @@ use crate::{
     abstractions::{OwnedMeteredSemPermit, TrackedOwnedMeteredSemPermit},
     telemetry::metrics::MetricsContext,
 };
-use anyhow::anyhow;
+use anyhow::{anyhow, bail};
 use futures_util::{stream, Stream};
 use std::{fmt::Debug, marker::PhantomData};
 use temporal_sdk_core_api::worker::{ActivitySlotKind, NexusSlotKind, SlotKind, WorkflowSlotKind};
@@ -109,7 +109,6 @@ pub(crate) struct TrackedPermittedTqResp<T: ValidatableTask> {
     pub(crate) resp: T,
 }
 
-// Trait for validatable task responses
 pub(crate) trait ValidatableTask:
     Debug + Default + PartialEq + Send + Sync + 'static
 {
@@ -240,3 +239,43 @@ pub(crate) fn new_activity_task_poller(
     )
     .into_stream()
 }
+
+impl ValidatableTask for PollNexusTaskQueueResponse {
+    type SlotKind = NexusSlotKind;
+
+    fn validate(&self) -> Result<(), anyhow::Error> {
+        if self.task_token.is_empty() {
+            bail!("missing task token");
+        } else if self.request.is_none() {
+            bail!("missing request field");
+        } else if self
+            .request
+            .as_ref()
+            .expect("just request exists")
+            .variant
+            .is_none()
+        {
+            bail!("missing request variant");
+        }
+        Ok(())
+    }
+
+    fn task_name() -> &'static str {
+        "nexus"
+    }
+}
+
+pub(crate) type NexusPollItem = Result<PermittedTqResp<PollNexusTaskQueueResponse>, tonic::Status>;
+pub(crate) fn new_nexus_task_poller(
+    poller: BoxedNexusPoller,
+    metrics: MetricsContext,
+    shutdown_token: CancellationToken,
+) -> impl Stream<Item = NexusPollItem> {
+    TaskPollerStream::new(
+        poller,
+        metrics,
+        MetricsContext::nexus_poll_timeout,
+        shutdown_token,
+    )
+    .into_stream()
+}
diff --git a/core/src/telemetry/metrics.rs b/core/src/telemetry/metrics.rs
index b83488a0b..4caee53a4 100644
--- a/core/src/telemetry/metrics.rs
+++ b/core/src/telemetry/metrics.rs
@@ -48,6 +48,7 @@ struct Instruments {
     la_exec_latency: Arc<dyn HistogramDuration>,
     la_exec_succeeded_latency: Arc<dyn HistogramDuration>,
     la_total: Arc<dyn Counter>,
+    nexus_poll_no_task: Arc<dyn Counter>,
     worker_registered: Arc<dyn Counter>,
     num_pollers: Arc<dyn Gauge>,
     task_slots_available: Arc<dyn Gauge>,
@@ -225,6 +226,11 @@ impl MetricsContext {
         self.instruments.la_total.add(1, &self.kvs);
     }
 
+    /// A nexus long poll timed out
+    pub(crate) fn nexus_poll_timeout(&self) {
+        self.instruments.nexus_poll_no_task.add(1, &self.kvs);
+    }
+
     /// A worker was registered
     pub(crate) fn worker_registered(&self) {
         self.instruments.worker_registered.add(1, &self.kvs);
@@ -386,6 +392,11 @@ impl Instruments {
                 description: "Count of local activities executed".into(),
                 unit: "".into(),
             }),
+            nexus_poll_no_task: meter.counter(MetricParameters {
+                name: "nexus_poll_no_task".into(),
+                description: "Count of nexus task queue poll timeouts (no new task)".into(),
+                unit: "".into(),
+            }),
             // name kept as worker start for compat with old sdk / what users expect
             worker_registered: meter.counter(MetricParameters {
                 name: "worker_start".into(),
@@ -452,6 +463,9 @@ pub(crate) fn workflow_sticky_poller() -> MetricKeyValue {
 pub(crate) fn activity_poller() -> MetricKeyValue {
     MetricKeyValue::new(KEY_POLLER_TYPE, "activity_task")
 }
+pub(crate) fn nexus_poller() -> MetricKeyValue {
+    MetricKeyValue::new(KEY_POLLER_TYPE, "nexus_task")
+}
 pub(crate) fn task_queue(tq: String) -> MetricKeyValue {
     MetricKeyValue::new(KEY_TASK_QUEUE, tq)
 }
@@ -470,6 +484,9 @@ pub(crate) fn activity_worker_type() -> MetricKeyValue {
 pub(crate) fn local_activity_worker_type() -> MetricKeyValue {
     MetricKeyValue::new(KEY_WORKER_TYPE, "LocalActivityWorker")
 }
+pub(crate) fn nexus_worker_type() -> MetricKeyValue {
+    MetricKeyValue::new(KEY_WORKER_TYPE, "NexusWorker")
+}
 pub(crate) fn eager(is_eager: bool) -> MetricKeyValue {
     MetricKeyValue::new(KEY_EAGER, is_eager)
 }
@@ -886,7 +903,7 @@ mod tests {
         a1.set(Arc::new(DummyCustomAttrs(1))).unwrap();
         // Verify all metrics are created. This number will need to get updated any time a metric
         // is added.
-        let num_metrics = 30;
+        let num_metrics = 31;
         #[allow(clippy::needless_range_loop)] // Sorry clippy, this reads easier.
         for metric_num in 1..=num_metrics {
             let hole = assert_matches!(&events[metric_num],
diff --git a/core/src/test_help/mod.rs b/core/src/test_help/mod.rs
index 4f0f2a97f..039f5b603 100644
--- a/core/src/test_help/mod.rs
+++ b/core/src/test_help/mod.rs
@@ -49,8 +49,8 @@ use temporal_sdk_core_protos::{
         protocol::v1::message,
         update,
         workflowservice::v1::{
-            PollActivityTaskQueueResponse, PollWorkflowTaskQueueResponse,
-            RespondWorkflowTaskCompletedResponse,
+            PollActivityTaskQueueResponse, PollNexusTaskQueueResponse,
+            PollWorkflowTaskQueueResponse, RespondWorkflowTaskCompletedResponse,
         },
     },
     utilities::pack_any,
@@ -170,6 +170,7 @@ pub(crate) fn mock_worker(mocks: MocksHolder) -> Worker {
         TaskPollers::Mocked {
             wft_stream: mocks.inputs.wft_stream,
            act_poller,
+            nexus_poller: mocks.inputs.nexus_poller,
         },
         None,
     )
@@ -221,6 +222,7 @@ impl MocksHolder {
 pub(crate) struct MockWorkerInputs {
     pub(crate) wft_stream: BoxStream<'static, Result<ValidPollWFTQResponse, tonic::Status>>,
     pub(crate) act_poller: Option<BoxedPoller<PollActivityTaskQueueResponse>>,
+    pub(crate) nexus_poller: Option<BoxedPoller<PollNexusTaskQueueResponse>>,
     pub(crate) config: WorkerConfig,
 }
 
@@ -237,6 +239,7 @@ impl MockWorkerInputs {
         Self {
             wft_stream,
             act_poller: None,
+            nexus_poller: None,
            config: test_worker_cfg().build().unwrap(),
         }
     }
@@ -268,6 +271,7 @@ impl MocksHolder {
         let mock_worker = MockWorkerInputs {
             wft_stream,
             act_poller: Some(mock_act_poller),
+            nexus_poller: None,
             config: test_worker_cfg().build().unwrap(),
         };
         Self {
@@ -290,6 +294,7 @@ impl MocksHolder {
         let mock_worker = MockWorkerInputs {
             wft_stream,
             act_poller: None,
+            nexus_poller: None,
             config: test_worker_cfg().build().unwrap(),
         };
         Self {
diff --git a/core/src/worker/mod.rs b/core/src/worker/mod.rs
index 37904d80d..8c3f91529 100644
--- a/core/src/worker/mod.rs
+++ b/core/src/worker/mod.rs
@@ -1,5 +1,6 @@
 mod activities;
 pub(crate) mod client;
+mod nexus;
 mod slot_provider;
 pub(crate) mod tuner;
 mod workflow;
@@ -18,22 +19,25 @@ pub(crate) use activities::{
 pub(crate) use workflow::{wft_poller::new_wft_poller, LEGACY_QUERY_ID};
 
 use crate::{
-    abstractions::{dbg_panic, MeteredPermitDealer},
+    abstractions::{dbg_panic, MeteredPermitDealer, PermitDealerContextData},
     errors::CompleteWfError,
     pollers::{
-        new_activity_task_buffer, new_workflow_task_buffer, BoxedActPoller, WorkflowTaskPoller,
+        new_activity_task_buffer, new_nexus_task_buffer, new_workflow_task_buffer, BoxedActPoller,
+        BoxedNexusPoller, WorkflowTaskPoller,
     },
     protosext::validate_activity_completion,
     telemetry::{
         metrics::{
-            activity_poller, activity_worker_type, workflow_poller, workflow_sticky_poller,
-            workflow_worker_type, MetricsContext,
+            activity_poller, activity_worker_type, local_activity_worker_type, nexus_poller,
+            nexus_worker_type, workflow_poller, workflow_sticky_poller, workflow_worker_type,
+            MetricsContext,
         },
         TelemetryInstance,
     },
     worker::{
         activities::{LACompleteAction, LocalActivityManager, NextPendingLAAction},
         client::WorkerClient,
+        nexus::NexusManager,
         workflow::{LAReqSink, LocalResolution, WorkflowBasics, Workflows},
     },
     ActivityHeartbeat, CompleteActivityError, PollActivityError, PollWfError, WorkerTrait,
@@ -52,6 +56,7 @@ use std::{
     time::Duration,
 };
 use temporal_client::{ConfiguredClient, TemporalServiceClientWithMetrics, WorkerKey};
+use temporal_sdk_core_api::errors::WorkerValidationError;
 use temporal_sdk_core_protos::{
     coresdk::{
         activity_result::activity_execution_result,
@@ -63,7 +68,7 @@ use temporal_sdk_core_protos::{
     temporal::api::{
         enums::v1::TaskQueueKind,
         taskqueue::v1::{StickyExecutionAttributes, TaskQueue},
-        workflowservice::v1::get_system_info_response,
+        workflowservice::v1::{get_system_info_response, PollNexusTaskQueueResponse},
     },
     TaskToken,
 };
@@ -71,10 +76,6 @@ use tokio::sync::{mpsc::unbounded_channel, watch};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_util::sync::CancellationToken;
 
-use crate::{
-    abstractions::PermitDealerContextData, telemetry::metrics::local_activity_worker_type,
-};
-use temporal_sdk_core_api::errors::WorkerValidationError;
 #[cfg(test)]
 use {
     crate::{
@@ -97,6 +98,8 @@ pub struct Worker {
     at_task_mgr: Option<WorkerActivityTasks>,
     /// Manages local activities
     local_act_mgr: Arc<LocalActivityManager>,
+    /// Manages Nexus tasks
+    nexus_mgr: Option<NexusManager>,
     /// Has shutdown been called?
     shutdown_token: CancellationToken,
     /// Will be called at the end of each activation completion
@@ -148,6 +151,16 @@ impl WorkerTrait for Worker {
         }
     }
 
+    #[instrument(skip(self))]
+    async fn poll_nexus_task(&self) -> Result<PollNexusTaskQueueResponse, PollActivityError> {
+        if let Some(nm) = self.nexus_mgr.as_ref() {
+            nm.next_nexus_task().await
+        } else {
+            self.shutdown_token.cancelled().await;
+            Err(PollActivityError::ShutDown)
+        }
+    }
+
     async fn complete_workflow_activation(
         &self,
         completion: WorkflowActivationCompletion,
@@ -316,7 +329,13 @@ impl Worker {
         );
         let act_permits = act_slots.get_extant_count_rcv();
         let (external_wft_tx, external_wft_rx) = unbounded_channel();
+        let nexus_slots = MeteredPermitDealer::new(
+            tuner.nexus_task_slot_supplier(),
+            metrics.with_new_attrs([nexus_worker_type()]),
+            None,
+            slot_context_data.clone(),
+        );
-        let (wft_stream, act_poller) = match task_pollers {
+        let (wft_stream, act_poller, nexus_poller) = match task_pollers {
             TaskPollers::Real => {
                 let max_nonsticky_polls = if sticky_queue_name.is_some() {
                     config.max_nonsticky_polls()
@@ -386,17 +405,36 @@ impl Worker {
                     wft_stream.right_stream()
                 };
 
+                // TODO: Use config option or always have instance depending on if we combine tasks
+                //   polling or not
+                let nexus_poll_buffer = if false {
+                    None
+                } else {
+                    let np_metrics = metrics.with_new_attrs([nexus_poller()]);
+                    Some(Box::new(new_nexus_task_buffer(
+                        client.clone(),
+                        config.task_queue.clone(),
+                        config.max_concurrent_nexus_polls,
+                        nexus_slots.clone(),
+                        shutdown_token.child_token(),
+                        Some(move |np| np_metrics.record_num_pollers(np)),
+                    )) as BoxedNexusPoller)
+                };
+
                 #[cfg(test)]
                 let wft_stream = wft_stream.left_stream();
-                (wft_stream, act_poll_buffer)
+                (wft_stream, act_poll_buffer, nexus_poll_buffer)
             }
             #[cfg(test)]
             TaskPollers::Mocked {
                 wft_stream,
                 act_poller,
+                nexus_poller,
             } => {
                 let ap = act_poller
                     .map(|ap| MockPermittedPollBuffer::new(Arc::new(act_slots.clone()), ap));
+                let np = nexus_poller
+                    .map(|np| MockPermittedPollBuffer::new(Arc::new(nexus_slots.clone()), np));
                 let wft_semaphore = wft_slots.clone();
                 let wfs = wft_stream.then(move |s| {
                     let wft_semaphore = wft_semaphore.clone();
@@ -406,7 +444,11 @@ impl Worker {
                     }
                 });
                 let wfs = wfs.right_stream();
-                (wfs, ap.map(|ap| Box::new(ap) as BoxedActPoller))
+                (
+                    wfs,
+                    ap.map(|ap| Box::new(ap) as BoxedActPoller),
+                    np.map(|np| Box::new(np) as BoxedNexusPoller),
+                )
             }
         };
 
@@ -441,6 +483,15 @@ impl Worker {
             info!("Activity polling is disabled for this worker");
         };
         let la_sink = LAReqSink::new(local_act_mgr.clone());
+
+        let nexus_mgr = nexus_poller.map(|np| {
+            NexusManager::new(
+                np,
+                metrics.with_new_attrs([nexus_worker_type()]),
+                shutdown_token.child_token(),
+            )
+        });
+
         let provider = SlotProvider::new(
             config.namespace.clone(),
             config.task_queue.clone(),
@@ -498,6 +549,7 @@ impl Worker {
                 act_permits,
                 la_permits,
             }),
+            nexus_mgr,
         }
     }
 
@@ -807,6 +859,7 @@ pub(crate) enum TaskPollers {
     Mocked {
         wft_stream: BoxStream<'static, Result<ValidPollWFTQResponse, tonic::Status>>,
         act_poller: Option<BoxedPoller<PollActivityTaskQueueResponse>>,
+        nexus_poller: Option<BoxedPoller<PollNexusTaskQueueResponse>>,
     },
 }
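From the lang side, the new trait method is meant to be driven much like activity polling. A hedged sketch of a consumer loop (illustrative only; a real SDK would dispatch each task to its registered Nexus handlers):

    // Hypothetical lang-side consumption of Worker::poll_nexus_task.
    async fn nexus_task_loop(worker: &dyn Worker) {
        loop {
            match worker.poll_nexus_task().await {
                Ok(task) => { /* dispatch to the appropriate nexus handler */ }
                Err(PollActivityError::ShutDown) => break, // worker is draining
                Err(_e) => { /* surface other poll errors to the user */ }
            }
        }
    }
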
diff --git a/core/src/worker/nexus.rs b/core/src/worker/nexus.rs
new file mode 100644
index 000000000..e65e00fea
--- /dev/null
+++ b/core/src/worker/nexus.rs
@@ -0,0 +1,101 @@
+use crate::{
+    abstractions::UsedMeteredSemPermit,
+    pollers::{new_nexus_task_poller, BoxedNexusPoller, NexusPollItem},
+    telemetry::metrics::MetricsContext,
+};
+use futures_util::{stream::BoxStream, Stream, StreamExt};
+use std::collections::HashMap;
+use temporal_sdk_core_api::{errors::PollActivityError, worker::NexusSlotKind};
+use temporal_sdk_core_protos::{
+    coresdk::NexusSlotInfo,
+    temporal::api::{nexus::v1::request::Variant, workflowservice::v1::PollNexusTaskQueueResponse},
+    TaskToken,
+};
+use tokio::sync::Mutex;
+use tokio_util::sync::CancellationToken;
+
+/// Centralizes all state related to received nexus tasks
+pub(super) struct NexusManager {
+    task_stream: Mutex<BoxStream<'static, Result<PollNexusTaskQueueResponse, PollActivityError>>>,
+    /// Token to notify when poll returned a shutdown error
+    poll_returned_shutdown_token: CancellationToken,
+}
+
+impl NexusManager {
+    pub(super) fn new(
+        poller: BoxedNexusPoller,
+        metrics: MetricsContext,
+        shutdown_initiated_token: CancellationToken,
+    ) -> Self {
+        let source_stream = new_nexus_task_poller(poller, metrics, shutdown_initiated_token);
+        let task_stream = NexusTaskStream::new(source_stream);
+        Self {
+            task_stream: Mutex::new(task_stream.into_stream().boxed()),
+            poll_returned_shutdown_token: CancellationToken::new(),
+        }
+    }
+
+    // TODO: Different error or combine
+    /// Block until the next nexus task is received from server
+    pub(super) async fn next_nexus_task(
+        &self,
+    ) -> Result<PollNexusTaskQueueResponse, PollActivityError> {
+        let mut sl = self.task_stream.lock().await;
+        sl.next().await.unwrap_or_else(|| {
+            self.poll_returned_shutdown_token.cancel();
+            Err(PollActivityError::ShutDown)
+        })
+    }
+}
+
+struct NexusTaskStream<S> {
+    source_stream: S,
+    outstanding_task_map: HashMap<TaskToken, NexusInFlightTask>,
+}
+
+struct NexusInFlightTask {
+    _permit: UsedMeteredSemPermit<NexusSlotKind>,
+}
+
+impl<S> NexusTaskStream<S>
+where
+    S: Stream<Item = NexusPollItem>,
+{
+    fn new(source: S) -> Self {
+        Self {
+            source_stream: source,
+            outstanding_task_map: HashMap::new(),
+        }
+    }
+
+    fn into_stream(
+        mut self,
+    ) -> impl Stream<Item = Result<PollNexusTaskQueueResponse, PollActivityError>> {
+        self.source_stream.map(move |t| match t {
+            Ok(t) => {
+                let (service, operation) = t
+                    .resp
+                    .request
+                    .as_ref()
+                    .and_then(|r| r.variant.as_ref())
+                    .map(|v| match v {
+                        Variant::StartOperation(s) => {
+                            (s.service.to_owned(), s.operation.to_owned())
+                        }
+                        Variant::CancelOperation(c) => {
+                            (c.service.to_owned(), c.operation.to_owned())
+                        }
+                    })
+                    .unwrap_or_default();
+                self.outstanding_task_map.insert(
+                    TaskToken(t.resp.task_token.clone()),
+                    NexusInFlightTask {
+                        _permit: t.permit.into_used(NexusSlotInfo { service, operation }),
+                    },
+                );
+                Ok(t.resp)
+            }
+            Err(e) => Err(PollActivityError::TonicError(e)),
+        })
+    }
+}
diff --git a/sdk-core-protos/protos/local/temporal/sdk/core/nexus/nexus.proto b/sdk-core-protos/protos/local/temporal/sdk/core/nexus/nexus.proto
index 3789ecb89..6dd791c19 100644
--- a/sdk-core-protos/protos/local/temporal/sdk/core/nexus/nexus.proto
+++ b/sdk-core-protos/protos/local/temporal/sdk/core/nexus/nexus.proto
@@ -5,6 +5,7 @@ option ruby_package = "Temporalio::Internal::Bridge::Api::Nexus";
 
 import "temporal/api/common/v1/message.proto";
 import "temporal/api/failure/v1/message.proto";
+import "temporal/api/nexus/v1/message.proto";
 import "temporal/sdk/core/common/common.proto";
 
 // Used by core to resolve nexus operations.
@@ -16,3 +17,15 @@ message NexusOperationResult {
     temporal.api.failure.v1.Failure timed_out = 4;
   }
 }
+
+// A response to a Nexus task
+message NexusTaskResponse {
+  // The unique identifier for this task provided in the poll response
+  bytes task_token = 1;
+  oneof status {
+    // The handler completed. Note that the response kind must match the
+    // request kind (start or cancel).
+    temporal.api.nexus.v1.Response completed = 2;
+    temporal.api.nexus.v1.HandlerError error = 3;
+  }
+}
\ No newline at end of file
diff --git a/tests/integ_tests/workflow_tests/nexus.rs b/tests/integ_tests/workflow_tests/nexus.rs
index 72a31cf28..d596d60fb 100644
--- a/tests/integ_tests/workflow_tests/nexus.rs
+++ b/tests/integ_tests/workflow_tests/nexus.rs
@@ -10,7 +10,6 @@ use temporal_sdk_core_protos::{
     },
     temporal::api::{
         common::v1::{callback, Callback},
-        enums::v1::TaskQueueKind,
         failure::v1::failure::FailureInfo,
         nexus,
         nexus::v1::{
@@ -18,11 +17,7 @@ use temporal_sdk_core_protos::{
             EndpointSpec, EndpointTarget, HandlerError, StartOperationResponse,
         },
         operatorservice::v1::CreateNexusEndpointRequest,
-        taskqueue::v1::TaskQueue,
-        workflowservice::v1::{
-            PollNexusTaskQueueRequest, RespondNexusTaskCompletedRequest,
-            RespondNexusTaskFailedRequest,
-        },
+        workflowservice::v1::{RespondNexusTaskCompletedRequest, RespondNexusTaskFailedRequest},
     },
 };
 use temporal_sdk_core_test_utils::{rand_6_chars, CoreWfStarter};
@@ -46,6 +41,7 @@ async fn nexus_basic(
     let mut starter = CoreWfStarter::new(wf_name);
     starter.worker_config.no_remote_activities(true);
     let mut worker = starter.worker().await;
+    let core_worker = starter.get_worker().await;
 
     let endpoint = mk_endpoint(&mut starter).await;
 
@@ -70,19 +66,7 @@ async fn nexus_basic(
     let mut client = starter.get_client().await.get_client().clone();
     let nexus_task_handle = async {
-        let nt = client
-            .poll_nexus_task_queue(PollNexusTaskQueueRequest {
-                namespace: client.namespace().to_owned(),
-                task_queue: Some(TaskQueue {
-                    name: starter.get_task_queue().to_owned(),
-                    kind: TaskQueueKind::Normal.into(),
-                    normal_name: "".to_string(),
-                }),
-                ..Default::default()
-            })
-            .await
-            .unwrap()
-            .into_inner();
+        let nt = core_worker.poll_nexus_task().await.unwrap();
         match outcome {
             Outcome::Succeed => {
                 client
@@ -180,6 +164,7 @@ async fn nexus_async(
     let mut starter = CoreWfStarter::new(wf_name);
     starter.worker_config.no_remote_activities(true);
     let mut worker = starter.worker().await;
+    let core_worker = starter.get_worker().await;
 
     let endpoint = mk_endpoint(&mut starter).await;
     let schedule_to_close_timeout = if outcome == Outcome::CancelAfterRecordedBeforeStarted {
@@ -235,19 +220,7 @@ async fn nexus_async(
     let mut client = starter.get_client().await.get_client().clone();
     let nexus_task_handle = async {
-        let nt = client
-            .poll_nexus_task_queue(PollNexusTaskQueueRequest {
-                namespace: client.namespace().to_owned(),
-                task_queue: Some(TaskQueue {
-                    name: starter.get_task_queue().to_owned(),
-                    kind: TaskQueueKind::Normal.into(),
-                    normal_name: "".to_string(),
-                }),
-                ..Default::default()
-            })
-            .await
-            .unwrap()
-            .into_inner();
+        let nt = core_worker.poll_nexus_task().await.unwrap();
         let start_req = assert_matches!(
             nt.request.unwrap().variant.unwrap(),
             request::Variant::StartOperation(sr) => sr
@@ -307,19 +280,7 @@ async fn nexus_async(
                 .unwrap();
         }
         if outcome == Outcome::Cancel {
-            let nt = client
-                .poll_nexus_task_queue(PollNexusTaskQueueRequest {
-                    namespace: client.namespace().to_owned(),
-                    task_queue: Some(TaskQueue {
-                        name: starter.get_task_queue().to_owned(),
-                        kind: TaskQueueKind::Normal.into(),
-                        normal_name: "".to_string(),
-                    }),
-                    ..Default::default()
-                })
-                .await
-                .unwrap()
-                .into_inner();
+            let nt = core_worker.poll_nexus_task().await.unwrap();
             assert_matches!(
                 nt.request.unwrap().variant.unwrap(),
                 request::Variant::CancelOperation(_)
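The variant handling in `NexusTaskStream::into_stream` above is the same pattern lang SDKs will need when routing tasks. Distilled into a standalone sketch (the helper function is hypothetical; it mirrors the match in the patch):

    // Illustrative: pull service/operation out of a polled nexus task.
    use temporal_sdk_core_protos::temporal::api::nexus::v1::request::Variant;

    fn service_and_op(task: &PollNexusTaskQueueResponse) -> Option<(String, String)> {
        match task.request.as_ref()?.variant.as_ref()? {
            Variant::StartOperation(s) => Some((s.service.clone(), s.operation.clone())),
            Variant::CancelOperation(c) => Some((c.service.clone(), c.operation.clone())),
        }
    }
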
From f8f8273c07201753a4f8a265dd73a3b43aef7877 Mon Sep 17 00:00:00 2001
From: Spencer Judge
Date: Wed, 8 Jan 2025 09:32:40 -0800
Subject: [PATCH 05/14] Completions through worker

---
 core-api/src/errors.rs                             |  15 +++
 core-api/src/lib.rs                                |  15 ++-
 core/src/worker/client.rs                          |  45 +++++++
 core/src/worker/client/mocks.rs                    |  14 ++
 core/src/worker/mod.rs                             |  24 +++-
 core/src/worker/nexus.rs                           | 125 +++++++++++++---
 .../local/temporal/sdk/core/nexus/nexus.proto      |   5 +-
 tests/integ_tests/workflow_tests/nexus.rs          | 102 ++++++++------
 8 files changed, 276 insertions(+), 69 deletions(-)

diff --git a/core-api/src/errors.rs b/core-api/src/errors.rs
index 4c365f569..10bf64b23 100644
--- a/core-api/src/errors.rs
+++ b/core-api/src/errors.rs
@@ -67,6 +67,21 @@ pub enum CompleteActivityError {
     },
 }
 
+/// Errors thrown by [crate::Worker::complete_nexus_task]
+#[derive(thiserror::Error, Debug)]
+pub enum CompleteNexusError {
+    /// Lang SDK sent us a malformed nexus completion. This likely means a bug in the lang sdk.
+    #[error("Lang SDK sent us a malformed nexus completion: {reason}")]
+    MalformedNexusCompletion {
+        /// Reason the completion was malformed
+        reason: String,
+    },
+    /// Nexus has not been enabled on this worker. If a user registers any Nexus handlers, the
+    /// TODO: xxx option must be set to true.
+    #[error("Nexus is not enabled on this worker")]
+    NexusNotEnabled,
+}
+
 /// Errors we can encounter during workflow processing which we may treat as either WFT failures
 /// or whole-workflow failures depending on user preference.
 #[derive(Clone, Debug, Eq, PartialEq, Hash)]
diff --git a/core-api/src/lib.rs b/core-api/src/lib.rs
index eb2930d95..e220a1f74 100644
--- a/core-api/src/lib.rs
+++ b/core-api/src/lib.rs
@@ -4,16 +4,16 @@ pub mod worker;
 
 use crate::{
     errors::{
-        CompleteActivityError, CompleteWfError, PollActivityError, PollWfError,
+        CompleteActivityError, CompleteNexusError, CompleteWfError, PollActivityError, PollWfError,
         WorkerValidationError,
     },
     worker::WorkerConfig,
 };
 use temporal_sdk_core_protos::{
     coresdk::{
-        activity_task::ActivityTask, workflow_activation::WorkflowActivation,
-        workflow_completion::WorkflowActivationCompletion, ActivityHeartbeat,
-        ActivityTaskCompletion,
+        activity_task::ActivityTask, nexus::NexusTaskCompletion,
+        workflow_activation::WorkflowActivation, workflow_completion::WorkflowActivationCompletion,
+        ActivityHeartbeat, ActivityTaskCompletion,
     },
     temporal::api::workflowservice::v1::PollNexusTaskQueueResponse,
 };
@@ -68,6 +68,13 @@ pub trait Worker: Send + Sync {
         completion: ActivityTaskCompletion,
     ) -> Result<(), CompleteActivityError>;
 
+    /// Tell the worker that a nexus task has completed. May (and should) be freely called
+    /// concurrently.
+    async fn complete_nexus_task(
+        &self,
+        completion: NexusTaskCompletion,
+    ) -> Result<(), CompleteNexusError>;
+
     /// Notify the Temporal service that an activity is still alive. Long running activities that
     /// take longer than `activity_heartbeat_timeout` to finish must call this function in order to
     /// report progress, otherwise the activity will timeout and a new attempt will be scheduled.
diff --git a/core/src/worker/client.rs b/core/src/worker/client.rs
index 527a4845b..c1b484c3e 100644
--- a/core/src/worker/client.rs
+++ b/core/src/worker/client.rs
@@ -14,6 +14,7 @@ use temporal_sdk_core_protos::{
         },
         enums::v1::{TaskQueueKind, WorkflowTaskFailedCause},
         failure::v1::Failure,
+        nexus,
         protocol::v1::Message as ProtocolMessage,
         query::v1::WorkflowQueryResult,
         sdk::v1::WorkflowTaskCompletedMetadata,
@@ -114,6 +115,11 @@ pub(crate) trait WorkerClient: Sync + Send {
         task_token: TaskToken,
         result: Option<ActivityExecutionResult>,
     ) -> Result<RespondActivityTaskCompletedResponse>;
+    async fn complete_nexus_task(
+        &self,
+        task_token: TaskToken,
+        response: nexus::v1::Response,
+    ) -> Result<RespondNexusTaskCompletedResponse>;
     async fn record_activity_heartbeat(
         &self,
         task_token: TaskToken,
@@ -135,6 +141,11 @@ pub(crate) trait WorkerClient: Sync + Send {
         cause: WorkflowTaskFailedCause,
         failure: Option<Failure>,
     ) -> Result<RespondWorkflowTaskFailedResponse>;
+    async fn fail_nexus_task(
+        &self,
+        task_token: TaskToken,
+        error: nexus::v1::HandlerError,
+    ) -> Result<RespondNexusTaskFailedResponse>;
     async fn get_workflow_execution_history(
         &self,
         workflow_id: String,
@@ -280,6 +291,23 @@ impl WorkerClient for WorkerClientBag {
             .into_inner())
     }
 
+    async fn complete_nexus_task(
+        &self,
+        task_token: TaskToken,
+        response: nexus::v1::Response,
+    ) -> Result<RespondNexusTaskCompletedResponse> {
+        Ok(self
+            .cloned_client()
+            .respond_nexus_task_completed(RespondNexusTaskCompletedRequest {
+                namespace: self.namespace.clone(),
+                identity: self.identity.clone(),
+                task_token: task_token.0,
+                response: Some(response),
+            })
+            .await?
+            .into_inner())
+    }
+
     async fn record_activity_heartbeat(
         &self,
         task_token: TaskToken,
@@ -358,6 +386,23 @@ impl WorkerClient for WorkerClientBag {
             .into_inner())
     }
 
+    async fn fail_nexus_task(
+        &self,
+        task_token: TaskToken,
+        error: nexus::v1::HandlerError,
+    ) -> Result<RespondNexusTaskFailedResponse> {
+        Ok(self
+            .cloned_client()
+            .respond_nexus_task_failed(RespondNexusTaskFailedRequest {
+                namespace: self.namespace.clone(),
+                identity: self.identity.clone(),
+                task_token: task_token.0,
+                error: Some(error),
+            })
+            .await?
+            .into_inner())
+    }
+
     async fn get_workflow_execution_history(
         &self,
         workflow_id: String,
diff --git a/core/src/worker/client/mocks.rs b/core/src/worker/client/mocks.rs
index 645f21ec9..41fcc5a62 100644
--- a/core/src/worker/client/mocks.rs
+++ b/core/src/worker/client/mocks.rs
@@ -77,6 +77,13 @@ mockall::mock! {
         ) -> impl Future<Output = Result<RespondActivityTaskCompletedResponse>> + Send + 'b
         where 'a: 'b, Self: 'b;
 
+        fn complete_nexus_task<'a, 'b>(
+            &self,
+            task_token: TaskToken,
+            response: nexus::v1::Response,
+        ) -> impl Future<Output = Result<RespondNexusTaskCompletedResponse>> + Send + 'b
+        where 'a: 'b, Self: 'b;
+
         fn cancel_activity_task<'a, 'b>(
             &self,
             task_token: TaskToken,
@@ -99,6 +106,13 @@ mockall::mock! {
         ) -> impl Future<Output = Result<RespondWorkflowTaskFailedResponse>> + Send + 'b
         where 'a: 'b, Self: 'b;
 
+        fn fail_nexus_task<'a, 'b>(
+            &self,
+            task_token: TaskToken,
+            error: nexus::v1::HandlerError,
+        ) -> impl Future<Output = Result<RespondNexusTaskFailedResponse>> + Send + 'b
+        where 'a: 'b, Self: 'b;
+
         fn record_activity_heartbeat<'a, 'b>(
             &self,
             task_token: TaskToken,
diff --git a/core/src/worker/mod.rs b/core/src/worker/mod.rs
index 8c3f91529..6025354b2 100644
--- a/core/src/worker/mod.rs
+++ b/core/src/worker/mod.rs
@@ -56,7 +56,7 @@ use std::{
     time::Duration,
 };
 use temporal_client::{ConfiguredClient, TemporalServiceClientWithMetrics, WorkerKey};
-use temporal_sdk_core_api::errors::WorkerValidationError;
+use temporal_sdk_core_api::errors::{CompleteNexusError, WorkerValidationError};
 use temporal_sdk_core_protos::{
     coresdk::{
         activity_result::activity_execution_result,
@@ -76,6 +76,7 @@ use tokio::sync::{mpsc::unbounded_channel, watch};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tokio_util::sync::CancellationToken;
 
+use temporal_sdk_core_protos::coresdk::nexus::NexusTaskCompletion;
 #[cfg(test)]
 use {
     crate::{
@@ -185,6 +186,27 @@ impl WorkerTrait for Worker {
         self.complete_activity(task_token, status).await
     }
 
+    async fn complete_nexus_task(
+        &self,
+        completion: NexusTaskCompletion,
+    ) -> Result<(), CompleteNexusError> {
+        let task_token = TaskToken(completion.task_token);
+
+        let status = if let Some(s) = completion.status {
+            s
+        } else {
+            return Err(CompleteNexusError::MalformedNexusCompletion {
+                reason: "Nexus completion had empty status field".to_owned(),
+            });
+        };
+
+        if let Some(nm) = self.nexus_mgr.as_ref() {
+            nm.complete_task(task_token, status, &*self.client).await
+        } else {
+            Err(CompleteNexusError::NexusNotEnabled)
+        }
+    }
+
     fn record_activity_heartbeat(&self, details: ActivityHeartbeat) {
         self.record_heartbeat(details);
     }
diff --git a/core/src/worker/nexus.rs b/core/src/worker/nexus.rs
index e65e00fea..b472ba55b 100644
--- a/core/src/worker/nexus.rs
+++ b/core/src/worker/nexus.rs
@@ -2,13 +2,20 @@ use crate::{
     abstractions::UsedMeteredSemPermit,
     pollers::{new_nexus_task_poller, BoxedNexusPoller, NexusPollItem},
     telemetry::metrics::MetricsContext,
+    worker::client::WorkerClient,
 };
 use futures_util::{stream::BoxStream, Stream, StreamExt};
-use std::collections::HashMap;
-use temporal_sdk_core_api::{errors::PollActivityError, worker::NexusSlotKind};
+use std::{collections::HashMap, sync::Arc};
+use temporal_sdk_core_api::{
+    errors::{CompleteNexusError, PollActivityError},
+    worker::NexusSlotKind,
+};
 use temporal_sdk_core_protos::{
-    coresdk::NexusSlotInfo,
-    temporal::api::{nexus::v1::request::Variant, workflowservice::v1::PollNexusTaskQueueResponse},
+    coresdk::{nexus::nexus_task_completion, NexusSlotInfo},
+    temporal::api::{
+        nexus::v1::{request::Variant, response},
+        workflowservice::v1::PollNexusTaskQueueResponse,
+    },
     TaskToken,
 };
 use tokio::sync::Mutex;
@@ -19,6 +26,8 @@ pub(super) struct NexusManager {
     task_stream: Mutex<BoxStream<'static, Result<PollNexusTaskQueueResponse, PollActivityError>>>,
     /// Token to notify when poll returned a shutdown error
     poll_returned_shutdown_token: CancellationToken,
+    /// Outstanding nexus tasks that have been issued to lang but not yet completed
+    outstanding_task_map: OutstandingTaskMap,
 }
 
 impl NexusManager {
@@ -29,9 +38,11 @@ impl NexusManager {
     ) -> Self {
         let source_stream = new_nexus_task_poller(poller, metrics, shutdown_initiated_token);
         let task_stream = NexusTaskStream::new(source_stream);
+        let outstanding_task_map = task_stream.outstanding_task_map.clone();
         Self {
             task_stream: Mutex::new(task_stream.into_stream().boxed()),
             poll_returned_shutdown_token: CancellationToken::new(),
+            outstanding_task_map,
         }
     }
 
@@ -46,15 +57,68 @@ impl NexusManager {
             Err(PollActivityError::ShutDown)
         })
     }
+
+    pub(super) async fn complete_task(
+        &self,
+        tt: TaskToken,
+        status: nexus_task_completion::Status,
+        client: &dyn WorkerClient,
+    ) -> Result<(), CompleteNexusError> {
+        if let Some(task_info) = self.outstanding_task_map.lock().remove(&tt) {
+            let maybe_net_err = match status {
+                nexus_task_completion::Status::Completed(c) => {
+                    // Server doesn't provide obvious errors for this validation, so it's done
+                    // here to make life easier for lang implementors.
+                    match &c.variant {
+                        Some(response::Variant::StartOperation(_)) => {
+                            if task_info.request_kind != RequestKind::Start {
+                                return Err(CompleteNexusError::MalformedNexusCompletion {
+                                    reason: "Nexus request was StartOperation but response was not"
+                                        .to_string(),
+                                });
+                            }
+                        }
+                        Some(response::Variant::CancelOperation(_)) => {
+                            if task_info.request_kind != RequestKind::Cancel {
+                                return Err(CompleteNexusError::MalformedNexusCompletion {
+                                    reason:
+                                        "Nexus request was CancelOperation but response was not"
+                                            .to_string(),
+                                });
+                            }
+                        }
+                        None => {
+                            return Err(CompleteNexusError::MalformedNexusCompletion {
+                                reason: "Nexus completion must contain a status variant"
+                                    .to_string(),
+                            })
+                        }
+                    }
+                    client.complete_nexus_task(tt, c).await.err()
+                }
+                nexus_task_completion::Status::Error(e) => {
+                    client.fail_nexus_task(tt, e).await.err()
+                }
+            };
+            if let Some(e) = maybe_net_err {
+                warn!(
+                    error=?e,
+                    "Network error while completing Nexus task",
+                );
+            }
+        } else {
+            warn!(
+                "Attempted to complete nexus task {} but we were not tracking it",
+                &tt
+            );
+        }
+        Ok(())
+    }
 }
 
-struct NexusTaskStream<S> {
+struct NexusTaskStream<S> {
     source_stream: S,
-    outstanding_task_map: HashMap<TaskToken, NexusInFlightTask>,
-}
-
-struct NexusInFlightTask {
-    _permit: UsedMeteredSemPermit<NexusSlotKind>,
+    outstanding_task_map: OutstandingTaskMap,
 }
 
 impl<S> NexusTaskStream<S>
@@ -64,32 +128,37 @@ where
     fn new(source: S) -> Self {
         Self {
             source_stream: source,
-            outstanding_task_map: HashMap::new(),
+            outstanding_task_map: Arc::new(Default::default()),
         }
     }
 
     fn into_stream(
-        mut self,
+        self,
     ) -> impl Stream<Item = Result<PollNexusTaskQueueResponse, PollActivityError>> {
         self.source_stream.map(move |t| match t {
             Ok(t) => {
-                let (service, operation) = t
+                let (service, operation, request_kind) = t
                     .resp
                     .request
                     .as_ref()
                     .and_then(|r| r.variant.as_ref())
                     .map(|v| match v {
-                        Variant::StartOperation(s) => {
-                            (s.service.to_owned(), s.operation.to_owned())
-                        }
-                        Variant::CancelOperation(c) => {
-                            (c.service.to_owned(), c.operation.to_owned())
-                        }
+                        Variant::StartOperation(s) => (
+                            s.service.to_owned(),
+                            s.operation.to_owned(),
+                            RequestKind::Start,
+                        ),
+                        Variant::CancelOperation(c) => (
+                            c.service.to_owned(),
+                            c.operation.to_owned(),
+                            RequestKind::Cancel,
+                        ),
                     })
                     .unwrap_or_default();
-                self.outstanding_task_map.insert(
+                self.outstanding_task_map.lock().insert(
                     TaskToken(t.resp.task_token.clone()),
                     NexusInFlightTask {
+                        request_kind,
                         _permit: t.permit.into_used(NexusSlotInfo { service, operation }),
                     },
                 );
@@ -99,3 +168,21 @@ where
         })
     }
 }
+
+type OutstandingTaskMap = Arc<Mutex<HashMap<TaskToken, NexusInFlightTask>>>;
+
+struct NexusInFlightTask {
+    request_kind: RequestKind,
+    _permit: UsedMeteredSemPermit<NexusSlotKind>,
+}
+
+#[derive(Eq, PartialEq, Copy, Clone)]
+enum RequestKind {
+    Start,
+    Cancel,
+}
+impl Default for RequestKind {
+    fn default() -> Self {
+        RequestKind::Start
+    }
+}
diff --git a/sdk-core-protos/protos/local/temporal/sdk/core/nexus/nexus.proto b/sdk-core-protos/protos/local/temporal/sdk/core/nexus/nexus.proto
index 6dd791c19..d417aaba9 100644
--- a/sdk-core-protos/protos/local/temporal/sdk/core/nexus/nexus.proto
+++ b/sdk-core-protos/protos/local/temporal/sdk/core/nexus/nexus.proto
@@ -19,13 +19,14 @@ message NexusOperationResult {
 }
 
 // A response to a Nexus task
-message NexusTaskResponse {
+message NexusTaskCompletion {
   // The unique identifier for this task provided in the poll response
   bytes task_token = 1;
   oneof status {
-    // The handler completed. Note that the response kind must match the
+    // The handler completed (successfully or not). Note that the response kind must match the
     // request kind (start or cancel).
     temporal.api.nexus.v1.Response completed = 2;
+    // The handler could not complete the request for some reason.
    temporal.api.nexus.v1.HandlerError error = 3;
   }
 }
\ No newline at end of file
diff --git a/tests/integ_tests/workflow_tests/nexus.rs b/tests/integ_tests/workflow_tests/nexus.rs
index d596d60fb..853f6e078 100644
--- a/tests/integ_tests/workflow_tests/nexus.rs
+++ b/tests/integ_tests/workflow_tests/nexus.rs
@@ -1,11 +1,14 @@
 use anyhow::bail;
 use assert_matches::assert_matches;
 use std::time::Duration;
-use temporal_client::{WfClientExt, WorkflowClientTrait, WorkflowOptions, WorkflowService};
+use temporal_client::{WfClientExt, WorkflowClientTrait, WorkflowOptions};
 use temporal_sdk::{CancellableFuture, NexusOperationOptions, WfContext, WfExitValue};
 use temporal_sdk_core_protos::{
     coresdk::{
-        nexus::{nexus_operation_result, NexusOperationResult},
+        nexus::{
+            nexus_operation_result, nexus_task_completion, NexusOperationResult,
+            NexusTaskCompletion,
+        },
         FromJsonPayloadExt,
     },
     temporal::api::{
@@ -14,10 +17,10 @@ use temporal_sdk_core_protos::{
         nexus,
         nexus::v1::{
             endpoint_target, request, start_operation_response, workflow_event_link_from_nexus,
-            EndpointSpec, EndpointTarget, HandlerError, StartOperationResponse,
+            CancelOperationResponse, EndpointSpec, EndpointTarget, HandlerError,
+            StartOperationResponse,
         },
         operatorservice::v1::CreateNexusEndpointRequest,
-        workflowservice::v1::{RespondNexusTaskCompletedRequest, RespondNexusTaskFailedRequest},
     },
 };
 use temporal_sdk_core_test_utils::{rand_6_chars, CoreWfStarter};
@@ -64,44 +67,44 @@ async fn nexus_basic(
     });
     starter.start_with_worker(wf_name, &mut worker).await;
 
-    let mut client = starter.get_client().await.get_client().clone();
+    let client = starter.get_client().await.get_client().clone();
     let nexus_task_handle = async {
         let nt = core_worker.poll_nexus_task().await.unwrap();
         match outcome {
             Outcome::Succeed => {
-                client
-                    .respond_nexus_task_completed(RespondNexusTaskCompletedRequest {
-                        namespace: client.namespace().to_owned(),
+                core_worker
+                    .complete_nexus_task(NexusTaskCompletion {
                         task_token: nt.task_token,
-                        response: Some(nexus::v1::Response {
-                            variant: Some(nexus::v1::response::Variant::StartOperation(
-                                StartOperationResponse {
-                                    variant: Some(start_operation_response::Variant::SyncSuccess(
-                                        start_operation_response::Sync {
-                                            payload: Some("yay".into()),
-                                        },
-                                    )),
-                                },
-                            )),
-                        }),
-                        ..Default::default()
+                        status: Some(nexus_task_completion::Status::Completed(
+                            nexus::v1::Response {
+                                variant: Some(nexus::v1::response::Variant::StartOperation(
+                                    StartOperationResponse {
+                                        variant: Some(
+                                            start_operation_response::Variant::SyncSuccess(
+                                                start_operation_response::Sync {
+                                                    payload: Some("yay".into()),
+                                                },
+                                            ),
+                                        ),
+                                    },
+                                )),
+                            },
+                        )),
                     })
                     .await
                     .unwrap();
             }
             Outcome::Fail => {
-                client
-                    .respond_nexus_task_failed(RespondNexusTaskFailedRequest {
-                        namespace: client.namespace().to_owned(),
+                core_worker
+                    .complete_nexus_task(NexusTaskCompletion {
                         task_token: nt.task_token,
-                        error: Some(HandlerError {
+                        status: Some(nexus_task_completion::Status::Error(HandlerError {
                             error_type: "BAD_REQUEST".to_string(), // bad req is non-retryable
                             failure: Some(nexus::v1::Failure {
                                 message: "busted".to_string(),
                                 ..Default::default()
                             }),
-                        }),
-                        identity: "whatever".to_string(),
+                        })),
                     })
                     .await
                     .unwrap();
@@ -218,7 +221,7 @@ async fn nexus_async(
     let submitter = worker.get_submitter_handle();
     starter.start_with_worker(wf_name, &mut worker).await;
 
-    let mut client = starter.get_client().await.get_client().clone();
+    let client = starter.get_client().await.get_client().clone();
     let nexus_task_handle = async {
         let nt = core_worker.poll_nexus_task().await.unwrap();
         let start_req = assert_matches!(
@@ -258,23 +261,23 @@ async fn nexus_async(
         }
         if outcome != Outcome::CancelAfterRecordedBeforeStarted {
             // Do not say the operation started if we are trying to test this type of cancel
-            client
-                .respond_nexus_task_completed(RespondNexusTaskCompletedRequest {
-                    namespace: client.namespace().to_owned(),
+            core_worker
+                .complete_nexus_task(NexusTaskCompletion {
                     task_token: nt.task_token,
-                    response: Some(nexus::v1::Response {
-                        variant: Some(nexus::v1::response::Variant::StartOperation(
-                            StartOperationResponse {
-                                variant: Some(start_operation_response::Variant::AsyncSuccess(
-                                    start_operation_response::Async {
-                                        operation_id: "op-1".to_string(),
-                                        links: vec![],
-                                    },
-                                )),
-                            },
-                        )),
-                    }),
-                    ..Default::default()
+                    status: Some(nexus_task_completion::Status::Completed(
+                        nexus::v1::Response {
+                            variant: Some(nexus::v1::response::Variant::StartOperation(
+                                StartOperationResponse {
+                                    variant: Some(start_operation_response::Variant::AsyncSuccess(
+                                        start_operation_response::Async {
+                                            operation_id: "op-1".to_string(),
+                                            links: vec![],
+                                        },
+                                    )),
+                                },
+                            )),
+                        },
+                    )),
                 })
                 .await
                 .unwrap();
@@ -289,6 +292,19 @@ async fn nexus_async(
                 .cancel_workflow_execution(completer_id, None, "nexus cancel".to_string(), None)
                 .await
                 .unwrap();
+            core_worker
+                .complete_nexus_task(NexusTaskCompletion {
+                    task_token: nt.task_token,
+                    status: Some(nexus_task_completion::Status::Completed(
+                        nexus::v1::Response {
+                            variant: Some(nexus::v1::response::Variant::CancelOperation(
+                                CancelOperationResponse {},
+                            )),
+                        },
+                    )),
+                })
+                .await
+                .unwrap();
         }
     };
.complete_nexus_task(NexusTaskCompletion { task_token: nt.task_token, - error: Some(HandlerError { + status: Some(nexus_task_completion::Status::Error(HandlerError { error_type: "BAD_REQUEST".to_string(), // bad req is non-retryable failure: Some(nexus::v1::Failure { message: "busted".to_string(), ..Default::default() }), - }), - identity: "whatever".to_string(), + })), }) .await .unwrap(); @@ -218,7 +221,7 @@ async fn nexus_async( let submitter = worker.get_submitter_handle(); starter.start_with_worker(wf_name, &mut worker).await; - let mut client = starter.get_client().await.get_client().clone(); + let client = starter.get_client().await.get_client().clone(); let nexus_task_handle = async { let nt = core_worker.poll_nexus_task().await.unwrap(); let start_req = assert_matches!( @@ -258,23 +261,23 @@ async fn nexus_async( } if outcome != Outcome::CancelAfterRecordedBeforeStarted { // Do not say the operation started if we are trying to test this type of cancel - client - .respond_nexus_task_completed(RespondNexusTaskCompletedRequest { - namespace: client.namespace().to_owned(), + core_worker + .complete_nexus_task(NexusTaskCompletion { task_token: nt.task_token, - response: Some(nexus::v1::Response { - variant: Some(nexus::v1::response::Variant::StartOperation( - StartOperationResponse { - variant: Some(start_operation_response::Variant::AsyncSuccess( - start_operation_response::Async { - operation_id: "op-1".to_string(), - links: vec![], - }, - )), - }, - )), - }), - ..Default::default() + status: Some(nexus_task_completion::Status::Completed( + nexus::v1::Response { + variant: Some(nexus::v1::response::Variant::StartOperation( + StartOperationResponse { + variant: Some(start_operation_response::Variant::AsyncSuccess( + start_operation_response::Async { + operation_id: "op-1".to_string(), + links: vec![], + }, + )), + }, + )), + }, + )), }) .await .unwrap(); @@ -289,6 +292,19 @@ async fn nexus_async( .cancel_workflow_execution(completer_id, None, "nexus cancel".to_string(), None) .await .unwrap(); + core_worker + .complete_nexus_task(NexusTaskCompletion { + task_token: nt.task_token, + status: Some(nexus_task_completion::Status::Completed( + nexus::v1::Response { + variant: Some(nexus::v1::response::Variant::CancelOperation( + CancelOperationResponse {}, + )), + }, + )), + }) + .await + .unwrap(); } }; From 157aa2b837101d3fd0bded0661bb4caa177e4886 Mon Sep 17 00:00:00 2001 From: Spencer Judge Date: Wed, 8 Jan 2025 11:08:44 -0800 Subject: [PATCH 06/14] Quick refactoring to make workflow handles more available --- client/src/workflow_handle/mod.rs | 5 ++ test-utils/src/lib.rs | 73 +++++++++++-------- tests/integ_tests/update_tests.rs | 54 +++++++------- .../workflow_tests/local_activities.rs | 28 ++++--- 4 files changed, 91 insertions(+), 69 deletions(-) diff --git a/client/src/workflow_handle/mod.rs b/client/src/workflow_handle/mod.rs index fff49c9fd..495ec4239 100644 --- a/client/src/workflow_handle/mod.rs +++ b/client/src/workflow_handle/mod.rs @@ -109,6 +109,11 @@ where &self.info } + /// Get the client attached to this handle + pub fn client(&self) -> &CT { + &self.client + } + /// Await the result of the workflow execution pub async fn get_workflow_result( &self, diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index e9b7619b8..518e023e2 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -11,7 +11,7 @@ pub mod workflows; pub use temporal_sdk_core::replay::HistoryForReplay; use crate::stream::{Stream, TryStreamExt}; -use anyhow::{Context, Error}; +use 
anyhow::Context; use assert_matches::assert_matches; use futures_util::{future, stream, stream::FuturesUnordered, StreamExt}; use parking_lot::Mutex; @@ -22,8 +22,8 @@ use std::{ time::Duration, }; use temporal_client::{ - Client, ClientTlsConfig, RetryClient, TlsConfig, WorkflowClientTrait, WorkflowExecutionInfo, - WorkflowOptions, + Client, ClientTlsConfig, RetryClient, TlsConfig, WfClientExt, WorkflowClientTrait, + WorkflowExecutionInfo, WorkflowHandle, WorkflowOptions, }; use temporal_sdk::{ interceptors::{FailOnNondeterminismInterceptor, WorkerInterceptor}, @@ -53,6 +53,7 @@ use temporal_sdk_core_protos::{ QuerySuccess, ScheduleActivity, ScheduleLocalActivity, StartTimer, }, workflow_completion::WorkflowActivationCompletion, + FromPayloadsExt, }, temporal::api::{ common::v1::Payload, history::v1::History, @@ -248,13 +249,13 @@ impl CoreWfStarter { self.start_wf_with_id(self.task_queue_name.clone()).await } - /// Starts the workflow using the worker, returns run id. + /// Starts the workflow using the worker pub async fn start_with_worker( &self, wf_name: impl Into, worker: &mut TestWorker, - ) -> String { - worker + ) -> WorkflowHandle, Vec> { + let run_id = worker .submit_wf( self.task_queue_name.clone(), wf_name.into(), @@ -262,7 +263,12 @@ impl CoreWfStarter { self.workflow_options.clone(), ) .await + .unwrap(); + self.initted_worker + .get() .unwrap() + .client + .get_untyped_workflow_handle(&self.task_queue_name, run_id) } pub async fn eager_start_with_worker( @@ -301,31 +307,6 @@ impl CoreWfStarter { .run_id } - /// Fetch the history for the indicated workflow and replay it using the provided worker. - /// Can be used after completing workflows normally to ensure replay works as well. - pub async fn fetch_history_and_replay( - &mut self, - wf_id: impl Into, - run_id: impl Into, - worker: &mut Worker, - ) -> Result<(), anyhow::Error> { - let wf_id = wf_id.into(); - // Fetch history and replay it - let history = self - .get_client() - .await - .get_workflow_execution_history(wf_id.clone(), Some(run_id.into()), vec![]) - .await? - .history - .expect("history field must be populated"); - let with_id = HistoryForReplay::new(history, wf_id); - let replay_worker = init_core_replay_preloaded(worker.task_queue(), [with_id]); - worker.with_new_core_worker(replay_worker); - worker.set_worker_interceptor(FailOnNondeterminismInterceptor {}); - worker.run().await.unwrap(); - Ok(()) - } - pub fn get_task_queue(&self) -> &str { &self.task_queue_name } @@ -617,7 +598,7 @@ impl WorkerInterceptor for TestWorkerCompletionIceptor { n.on_shutdown(sdk_worker); } } - async fn on_workflow_activation(&self, a: &WorkflowActivation) -> Result<(), Error> { + async fn on_workflow_activation(&self, a: &WorkflowActivation) -> Result<(), anyhow::Error> { if let Some(n) = self.next.as_ref() { n.on_workflow_activation(a).await?; } @@ -837,6 +818,34 @@ where } } +#[async_trait::async_trait(?Send)] +pub trait WorkflowHandleExt { + async fn fetch_history_and_replay(&self, worker: &mut Worker) -> Result<(), anyhow::Error>; +} + +#[async_trait::async_trait(?Send)] +impl WorkflowHandleExt for WorkflowHandle, R> +where + R: FromPayloadsExt, +{ + async fn fetch_history_and_replay(&self, worker: &mut Worker) -> Result<(), anyhow::Error> { + let wf_id = self.info().workflow_id.clone(); + let run_id = self.info().run_id.clone(); + let history = self + .client() + .get_workflow_execution_history(wf_id.clone(), run_id, vec![]) + .await? 
+ .history + .expect("history field must be populated"); + let with_id = HistoryForReplay::new(history, wf_id); + let replay_worker = init_core_replay_preloaded(worker.task_queue(), [with_id]); + worker.with_new_core_worker(replay_worker); + worker.set_worker_interceptor(FailOnNondeterminismInterceptor {}); + worker.run().await.unwrap(); + Ok(()) + } +} + /// Initiate shutdown, drain the pollers, and wait for shutdown to complete. pub async fn drain_pollers_and_shutdown(worker: &Arc) { worker.initiate_shutdown(); diff --git a/tests/integ_tests/update_tests.rs b/tests/integ_tests/update_tests.rs index a21e32feb..d413d0190 100644 --- a/tests/integ_tests/update_tests.rs +++ b/tests/integ_tests/update_tests.rs @@ -33,7 +33,7 @@ use temporal_sdk_core_protos::{ }; use temporal_sdk_core_test_utils::{ drain_pollers_and_shutdown, init_core_and_create_wf, init_core_replay_preloaded, - start_timer_cmd, CoreWfStarter, WorkerTestHelpers, + start_timer_cmd, CoreWfStarter, WorkerTestHelpers, WorkflowHandleExt, }; use tokio::{join, sync::Barrier}; use uuid::Uuid; @@ -672,7 +672,7 @@ async fn update_with_local_acts() { }, ); - let run_id = starter.start_with_worker(wf_name, &mut worker).await; + let handle = starter.start_with_worker(wf_name, &mut worker).await; let wf_id = starter.get_task_queue().to_string(); let update = async { // make sure update has a chance to get registered @@ -707,8 +707,8 @@ async fn update_with_local_acts() { worker.run_until_done().await.unwrap(); }; join!(update, run); - starter - .fetch_history_and_replay(wf_id, run_id, worker.inner_mut()) + handle + .fetch_history_and_replay(worker.inner_mut()) .await .unwrap(); } @@ -730,7 +730,7 @@ async fn update_rejection_sdk() { Ok(().into()) }); - let run_id = starter.start_with_worker(wf_name, &mut worker).await; + let handle = starter.start_with_worker(wf_name, &mut worker).await; let wf_id = starter.get_task_queue().to_string(); let update = async { let res = client @@ -751,8 +751,8 @@ async fn update_rejection_sdk() { worker.run_until_done().await.unwrap(); }; join!(update, run); - starter - .fetch_history_and_replay(wf_id, run_id, worker.inner_mut()) + handle + .fetch_history_and_replay(worker.inner_mut()) .await .unwrap(); } @@ -774,7 +774,7 @@ async fn update_fail_sdk() { Ok(().into()) }); - let run_id = starter.start_with_worker(wf_name, &mut worker).await; + let handle = starter.start_with_worker(wf_name, &mut worker).await; let wf_id = starter.get_task_queue().to_string(); let update = async { let res = client @@ -795,8 +795,8 @@ async fn update_fail_sdk() { worker.run_until_done().await.unwrap(); }; join!(update, run); - starter - .fetch_history_and_replay(wf_id, run_id, worker.inner_mut()) + handle + .fetch_history_and_replay(worker.inner_mut()) .await .unwrap(); } @@ -822,7 +822,7 @@ async fn update_timer_sequence() { Ok(().into()) }); - let run_id = starter.start_with_worker(wf_name, &mut worker).await; + let handle = starter.start_with_worker(wf_name, &mut worker).await; let wf_id = starter.get_task_queue().to_string(); let update = async { let res = client @@ -843,8 +843,8 @@ async fn update_timer_sequence() { worker.run_until_done().await.unwrap(); }; join!(update, run); - starter - .fetch_history_and_replay(wf_id, run_id, worker.inner_mut()) + handle + .fetch_history_and_replay(worker.inner_mut()) .await .unwrap(); } @@ -873,7 +873,7 @@ async fn task_failure_during_validation() { Ok(().into()) }); - let run_id = starter.start_with_worker(wf_name, &mut worker).await; + let handle = starter.start_with_worker(wf_name, 
&mut worker).await; let wf_id = starter.get_task_queue().to_string(); let update = async { let res = client @@ -894,8 +894,8 @@ async fn task_failure_during_validation() { worker.run_until_done().await.unwrap(); }; join!(update, run); - starter - .fetch_history_and_replay(wf_id.clone(), run_id, worker.inner_mut()) + handle + .fetch_history_and_replay(worker.inner_mut()) .await .unwrap(); // Verify we did not spam task failures. There should only be one. @@ -937,7 +937,7 @@ async fn task_failure_after_update() { Ok(().into()) }); - let run_id = starter.start_with_worker(wf_name, &mut worker).await; + let handle = starter.start_with_worker(wf_name, &mut worker).await; let wf_id = starter.get_task_queue().to_string(); let update = async { let res = client @@ -958,8 +958,8 @@ async fn task_failure_after_update() { worker.run_until_done().await.unwrap(); }; join!(update, run); - starter - .fetch_history_and_replay(wf_id.clone(), run_id, worker.inner_mut()) + handle + .fetch_history_and_replay(worker.inner_mut()) .await .unwrap(); } @@ -1002,7 +1002,8 @@ async fn worker_restarted_in_middle_of_update() { Ok(echo_me) }); - let run_id = starter.start_with_worker(wf_name, &mut worker).await; + let handle = starter.start_with_worker(wf_name, &mut worker).await; + let wf_id = starter.get_task_queue().to_string(); let update = async { let res = client @@ -1039,7 +1040,7 @@ async fn worker_restarted_in_middle_of_update() { BARR.wait().await; // Poke the workflow off the sticky queue to get it to complete faster than WFT timeout client - .reset_sticky_task_queue(wf_id.clone(), run_id.clone()) + .reset_sticky_task_queue(wf_id.clone(), "".to_string()) .await .unwrap(); }; @@ -1053,8 +1054,8 @@ async fn worker_restarted_in_middle_of_update() { worker.run_until_done().await.unwrap(); }; join!(update, run, stopper); - starter - .fetch_history_and_replay(wf_id, run_id, worker.inner_mut()) + handle + .fetch_history_and_replay(worker.inner_mut()) .await .unwrap(); } @@ -1108,7 +1109,8 @@ async fn update_after_empty_wft() { Ok(echo_me) }); - let run_id = starter.start_with_worker(wf_name, &mut worker).await; + let handle = starter.start_with_worker(wf_name, &mut worker).await; + let wf_id = starter.get_task_queue().to_string(); let update = async { client @@ -1140,8 +1142,8 @@ async fn update_after_empty_wft() { worker.run_until_done().await.unwrap(); }; join!(update, runner); - starter - .fetch_history_and_replay(wf_id, run_id, worker.inner_mut()) + handle + .fetch_history_and_replay(worker.inner_mut()) .await .unwrap(); } diff --git a/tests/integ_tests/workflow_tests/local_activities.rs b/tests/integ_tests/workflow_tests/local_activities.rs index 3ee1e6503..dc90620ee 100644 --- a/tests/integ_tests/workflow_tests/local_activities.rs +++ b/tests/integ_tests/workflow_tests/local_activities.rs @@ -5,7 +5,7 @@ use std::{ sync::atomic::{AtomicU8, Ordering}, time::Duration, }; -use temporal_client::WorkflowOptions; +use temporal_client::{WfClientExt, WorkflowOptions}; use temporal_sdk::{ interceptors::WorkerInterceptor, ActContext, ActivityError, ActivityOptions, CancellableFuture, LocalActivityOptions, WfContext, WorkflowResult, @@ -23,6 +23,7 @@ use temporal_sdk_core_protos::{ }; use temporal_sdk_core_test_utils::{ history_from_proto_binary, replay_sdk_worker, workflows::la_problem_workflow, CoreWfStarter, + WorkflowHandleExt, }; use tokio_util::sync::CancellationToken; @@ -47,11 +48,10 @@ async fn one_local_activity() { worker.register_wf(wf_name.to_owned(), one_local_activity_wf); 
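// The tests above all follow the same shape now: keep the WorkflowHandle that
// start_with_worker returns and hang replay verification off of it. A minimal,
// self-contained sketch of that extension-trait pattern, using a hypothetical
// Handle stand-in (not the SDK's WorkflowHandle) plus the real async-trait
// crate:
use async_trait::async_trait;

struct Handle {
    workflow_id: String,
}

// Attach an async helper to an existing type without touching its definition,
// mirroring how WorkflowHandleExt adds fetch_history_and_replay.
#[async_trait(?Send)]
trait ReplayExt {
    async fn fetch_and_replay(&self) -> Result<(), String>;
}

#[async_trait(?Send)]
impl ReplayExt for Handle {
    async fn fetch_and_replay(&self) -> Result<(), String> {
        // A real implementation would fetch history through the handle's
        // client and feed it to a replay worker.
        println!("replaying {}", self.workflow_id);
        Ok(())
    }
}

#[tokio::main]
async fn main() {
    let handle = Handle { workflow_id: "wf-1".into() };
    handle.fetch_and_replay().await.unwrap();
}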
worker.register_activity("echo_activity", echo); - let run_id = starter.start_with_worker(wf_name, &mut worker).await; + let handle = starter.start_with_worker(wf_name, &mut worker).await; worker.run_until_done().await.unwrap(); - let tq = starter.get_task_queue().to_string(); - starter - .fetch_history_and_replay(tq, run_id, worker.inner_mut()) + handle + .fetch_history_and_replay(worker.inner_mut()) .await .unwrap(); } @@ -190,8 +190,10 @@ async fn local_act_retry_timer_backoff() { .await .unwrap(); worker.run_until_done().await.unwrap(); - starter - .fetch_history_and_replay(wf_name, run_id, worker.inner_mut()) + let client = starter.get_client().await; + let handle = client.get_untyped_workflow_handle(wf_name, run_id); + handle + .fetch_history_and_replay(worker.inner_mut()) .await .unwrap(); } @@ -586,8 +588,10 @@ async fn repro_nondeterminism_with_timer_bug() { .await .unwrap(); worker.run_until_done().await.unwrap(); - starter - .fetch_history_and_replay(wf_name, run_id, worker.inner_mut()) + let client = starter.get_client().await; + let handle = client.get_untyped_workflow_handle(wf_name, run_id); + handle + .fetch_history_and_replay(worker.inner_mut()) .await .unwrap(); } @@ -737,8 +741,10 @@ async fn la_resolve_same_time_as_other_cancel() { .await .unwrap(); worker.run_until_done().await.unwrap(); - starter - .fetch_history_and_replay(wf_name, run_id, worker.inner_mut()) + let client = starter.get_client().await; + let handle = client.get_untyped_workflow_handle(wf_name, run_id); + handle + .fetch_history_and_replay(worker.inner_mut()) .await .unwrap(); } From 82a3005d974b41827741b6dc97f81e638c0ee209 Mon Sep 17 00:00:00 2001 From: Spencer Judge Date: Wed, 8 Jan 2025 14:34:20 -0800 Subject: [PATCH 07/14] Combine poll errors --- core-api/src/errors.rs | 28 ++++++--------------- core-api/src/lib.rs | 8 +++--- core/src/core_tests/activity_tasks.rs | 4 +-- core/src/core_tests/local_activities.rs | 9 +++---- core/src/core_tests/mod.rs | 11 ++++---- core/src/core_tests/workers.rs | 20 +++++++-------- core/src/core_tests/workflow_tasks.rs | 10 ++++---- core/src/lib.rs | 2 +- core/src/test_help/mod.rs | 11 +++----- core/src/worker/activities.rs | 25 +++++++++--------- core/src/worker/mod.rs | 22 ++++++++-------- core/src/worker/nexus.rs | 16 +++++------- core/src/worker/workflow/mod.rs | 8 +++--- core/src/worker/workflow/workflow_stream.rs | 14 +++++------ sdk/src/lib.rs | 9 +++---- test-utils/src/lib.rs | 6 ++--- tests/integ_tests/workflow_tests.rs | 6 ++--- tests/integ_tests/workflow_tests/replay.rs | 13 ++++------ 18 files changed, 96 insertions(+), 126 deletions(-) diff --git a/core-api/src/errors.rs b/core-api/src/errors.rs index 10bf64b23..c26a201f5 100644 --- a/core-api/src/errors.rs +++ b/core-api/src/errors.rs @@ -13,30 +13,18 @@ pub enum WorkerValidationError { }, } -/// Errors thrown by [crate::Worker::poll_workflow_activation] +/// Errors thrown by [crate::Worker] polling methods #[derive(thiserror::Error, Debug)] -pub enum PollWfError { - /// [crate::Worker::shutdown] was called, and there are no more replay tasks to be handled. Lang - /// must call [crate::Worker::complete_workflow_activation] for any remaining tasks, and then - /// may exit. - #[error("Core is shut down and there are no more workflow replay tasks")] +pub enum PollError { + /// [crate::Worker::shutdown] was called, and there are no more tasks to be handled from this + /// poll function. 
Lang must call [crate::Worker::complete_workflow_activation], + /// [crate::Worker::complete_activity_task], or + /// [crate::Worker::complete_nexus_task] for any remaining tasks, and then may exit. + #[error("Core is shut down and there are no more tasks of this kind")] ShutDown, /// Unhandled error when calling the temporal server. Core will attempt to retry any non-fatal /// errors, so lang should consider this fatal. - #[error("Unhandled grpc error when workflow polling: {0:?}")] - TonicError(#[from] tonic::Status), -} - -/// Errors thrown by [crate::Worker::poll_activity_task] -#[derive(thiserror::Error, Debug)] -pub enum PollActivityError { - /// [crate::Worker::shutdown] was called, we will no longer fetch new activity tasks. Lang must - /// ensure it is finished with any workflow replay, see [PollWfError::ShutDown] - #[error("Core is shut down")] - ShutDown, - /// Unhandled error when calling the temporal server. Core will attempt to retry any non-fatal - /// errors, so lang should consider this fatal. - #[error("Unhandled grpc error when activity polling: {0:?}")] + #[error("Unhandled grpc error when polling: {0:?}")] TonicError(#[from] tonic::Status), } diff --git a/core-api/src/lib.rs b/core-api/src/lib.rs index e220a1f74..f82e90b0b 100644 --- a/core-api/src/lib.rs +++ b/core-api/src/lib.rs @@ -4,7 +4,7 @@ pub mod worker; use crate::{ errors::{ - CompleteActivityError, CompleteNexusError, CompleteWfError, PollActivityError, PollWfError, + CompleteActivityError, CompleteNexusError, CompleteWfError, PollError, WorkerValidationError, }, worker::WorkerConfig, @@ -40,17 +40,17 @@ pub trait Worker: Send + Sync { /// & job processing. /// /// Do not call poll concurrently. It handles polling the server concurrently internally. - async fn poll_workflow_activation(&self) -> Result; + async fn poll_workflow_activation(&self) -> Result; /// Ask the worker for some work, returning an [ActivityTask]. It is then the language SDK's /// responsibility to call the appropriate activity code with the provided inputs. Blocks /// indefinitely until such work is available or [Worker::shutdown] is called. /// /// Do not call poll concurrently. It handles polling the server concurrently internally. - async fn poll_activity_task(&self) -> Result; + async fn poll_activity_task(&self) -> Result; /// TODO: Keep or combine? - async fn poll_nexus_task(&self) -> Result; + async fn poll_nexus_task(&self) -> Result; /// Tell the worker that a workflow activation has completed. May (and should) be freely called /// concurrently. The future may take some time to resolve, as fetching more events might be diff --git a/core/src/core_tests/activity_tasks.rs b/core/src/core_tests/activity_tasks.rs index eb8d224aa..cd48bf590 100644 --- a/core/src/core_tests/activity_tasks.rs +++ b/core/src/core_tests/activity_tasks.rs @@ -25,7 +25,7 @@ use std::{ use temporal_client::WorkflowOptions; use temporal_sdk::{ActivityOptions, WfContext}; use temporal_sdk_core_api::{ - errors::{CompleteActivityError, PollActivityError}, + errors::{CompleteActivityError, PollError}, Worker as WorkerTrait, }; use temporal_sdk_core_protos::{ @@ -984,7 +984,7 @@ async fn activity_tasks_from_completion_reserve_slots() { core.initiate_shutdown(); // Even though this test requests eager activity tasks, none are returned in poll responses. 
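// The doc comment above is the contract the unified PollError encodes:
// ShutDown is the normal, expected terminal signal from every poll function
// rather than a failure. A self-contained sketch of the lang-side loop this
// enables, with a hypothetical PollError-like enum standing in for the SDK's:
#[derive(Debug)]
enum PollError {
    ShutDown,
    Transport(String),
}

// Pretend poll source: yields three tasks, then reports shutdown forever.
async fn poll_once(n: &mut u32) -> Result<u32, PollError> {
    *n += 1;
    if *n > 3 {
        Err(PollError::ShutDown)
    } else {
        Ok(*n)
    }
}

#[tokio::main]
async fn main() {
    let mut n = 0;
    loop {
        match poll_once(&mut n).await {
            Ok(task) => println!("handling task {task}"),
            // Drain is complete; exit this poller's loop cleanly.
            Err(PollError::ShutDown) => break,
            // Anything else is fatal to the poller.
            Err(e) => panic!("poller failed: {e:?}"),
        }
    }
}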
let err = core.poll_activity_task().await.unwrap_err(); - assert_matches!(err, PollActivityError::ShutDown); + assert_matches!(err, PollError::ShutDown); }; // This wf poll should *not* set the flag that it wants tasks back since both slots are // occupied diff --git a/core/src/core_tests/local_activities.rs b/core/src/core_tests/local_activities.rs index b684e4ce7..b16eddd63 100644 --- a/core/src/core_tests/local_activities.rs +++ b/core/src/core_tests/local_activities.rs @@ -23,10 +23,7 @@ use temporal_client::WorkflowOptions; use temporal_sdk::{ ActContext, ActivityError, LocalActivityOptions, WfContext, WorkflowFunction, WorkflowResult, }; -use temporal_sdk_core_api::{ - errors::{PollActivityError, PollWfError}, - Worker, -}; +use temporal_sdk_core_api::{errors::PollError, Worker}; use temporal_sdk_core_protos::{ coresdk::{ activity_result::ActivityExecutionResult, @@ -1179,8 +1176,8 @@ async fn local_activities_can_be_delivered_during_shutdown() { }; let (wf_r, act_r) = join!(wf_poller, at_poller); - assert_matches!(wf_r.unwrap_err(), PollWfError::ShutDown); - assert_matches!(act_r.unwrap_err(), PollActivityError::ShutDown); + assert_matches!(wf_r.unwrap_err(), PollError::ShutDown); + assert_matches!(act_r.unwrap_err(), PollError::ShutDown); } #[tokio::test] diff --git a/core/src/core_tests/mod.rs b/core/src/core_tests/mod.rs index e4ad53735..3ea41bec2 100644 --- a/core/src/core_tests/mod.rs +++ b/core/src/core_tests/mod.rs @@ -10,14 +10,13 @@ mod workflow_cancels; mod workflow_tasks; use crate::{ - errors::{PollActivityError, PollWfError}, + errors::PollError, test_help::{build_mock_pollers, canned_histories, mock_worker, test_worker_cfg, MockPollCfg}, worker::client::mocks::{mock_manual_workflow_client, mock_workflow_client}, Worker, }; use futures_util::FutureExt; -use std::sync::LazyLock; -use std::time::Duration; +use std::{sync::LazyLock, time::Duration}; use temporal_sdk_core_api::Worker as WorkerTrait; use temporal_sdk_core_protos::coresdk::workflow_completion::WorkflowActivationCompletion; use tokio::{sync::Barrier, time::sleep}; @@ -40,7 +39,7 @@ async fn after_shutdown_server_is_not_polled() { worker.shutdown().await; assert_matches!( worker.poll_workflow_activation().await.unwrap_err(), - PollWfError::ShutDown + PollError::ShutDown ); worker.finalize_shutdown().await; } @@ -86,11 +85,11 @@ async fn shutdown_interrupts_both_polls() { tokio::join! 
{ async { assert_matches!(worker.poll_activity_task().await.unwrap_err(), - PollActivityError::ShutDown); + PollError::ShutDown); }, async { assert_matches!(worker.poll_workflow_activation().await.unwrap_err(), - PollWfError::ShutDown); + PollError::ShutDown); }, async { // Give polling a bit to get stuck, then shutdown diff --git a/core/src/core_tests/workers.rs b/core/src/core_tests/workers.rs index 00b70e184..44bc5f419 100644 --- a/core/src/core_tests/workers.rs +++ b/core/src/core_tests/workers.rs @@ -11,7 +11,7 @@ use crate::{ MockWorkerClient, }, }, - PollActivityError, PollWfError, + PollError, }; use futures_util::{stream, stream::StreamExt}; use std::{cell::RefCell, time::Duration}; @@ -53,7 +53,7 @@ async fn after_shutdown_of_worker_get_shutdown_err() { // Shutdown proceeds if the only outstanding activations are evictions assert_matches!( worker.poll_workflow_activation().await.unwrap_err(), - PollWfError::ShutDown + PollError::ShutDown ); }); } @@ -87,7 +87,7 @@ async fn shutdown_worker_can_complete_pending_activation() { // Shutdown proceeds if the only outstanding activations are evictions assert_matches!( worker.poll_workflow_activation().await.unwrap_err(), - PollWfError::ShutDown + PollError::ShutDown ); }); } @@ -120,7 +120,7 @@ async fn worker_shutdown_during_poll_doesnt_deadlock() { let _ = tx.send(true); }; let (pollres, _) = tokio::join!(pollfut, shutdownfut); - assert_matches!(pollres.unwrap_err(), PollWfError::ShutDown); + assert_matches!(pollres.unwrap_err(), PollError::ShutDown); worker.finalize_shutdown().await; } @@ -153,11 +153,11 @@ async fn can_shutdown_local_act_only_worker_when_act_polling() { // We need to see workflow poll return shutdown before activity poll will assert_matches!( worker.poll_workflow_activation().await.unwrap_err(), - PollWfError::ShutDown + PollError::ShutDown ); assert_matches!( worker.poll_activity_task().await.unwrap_err(), - PollActivityError::ShutDown + PollError::ShutDown ); } ); @@ -187,7 +187,7 @@ async fn complete_with_task_not_found_during_shutdown() { // This will return shutdown once the completion goes through assert_matches!( core.poll_workflow_activation().await.unwrap_err(), - PollWfError::ShutDown + PollError::ShutDown ); }; let complete_fut = async { @@ -284,12 +284,12 @@ async fn worker_can_shutdown_after_never_polling_ok(#[values(true, false)] poll_ // Must continue polling until polls return shutdown. 
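// As these tests demonstrate, shutdown is only complete once every poll
// surface has individually observed it. A self-contained sketch of that drain
// pattern, with a tokio watch channel standing in for the ShutDown signal
// (hypothetical stand-ins, not SDK types):
use tokio::sync::watch;

// Keep "polling" until the shutdown flag flips, like a poll function that
// keeps returning tasks until it finally yields Err(ShutDown).
async fn drain(mut rx: watch::Receiver<bool>, name: &str) {
    while !*rx.borrow() {
        if rx.changed().await.is_err() {
            break; // sender dropped; treat as shut down
        }
    }
    println!("{name} poller drained");
}

#[tokio::main]
async fn main() {
    let (tx, rx) = watch::channel(false);
    tokio::join!(
        drain(rx.clone(), "workflow"),
        drain(rx.clone(), "activity"),
        async { tx.send(true).unwrap() },
    );
}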
if poll_workflow { let res = core.poll_workflow_activation().await.unwrap_err(); - if !matches!(res, PollWfError::ShutDown) { + if !matches!(res, PollError::ShutDown) { continue; } } let res = core.poll_activity_task().await.unwrap_err(); - if !matches!(res, PollActivityError::ShutDown) { + if !matches!(res, PollError::ShutDown) { continue; } core.finalize_shutdown().await; @@ -356,7 +356,7 @@ async fn worker_shutdown_api(#[case] use_cache: bool, #[case] api_success: bool) // Shutdown proceeds if the only outstanding activations are evictions assert_matches!( worker.poll_workflow_activation().await.unwrap_err(), - PollWfError::ShutDown + PollError::ShutDown ); }); } diff --git a/core/src/core_tests/workflow_tasks.rs b/core/src/core_tests/workflow_tasks.rs index b2f937571..42fa25b4f 100644 --- a/core/src/core_tests/workflow_tasks.rs +++ b/core/src/core_tests/workflow_tasks.rs @@ -31,7 +31,7 @@ use std::{ use temporal_client::WorkflowOptions; use temporal_sdk::{ActivityOptions, CancellableFuture, TimerOptions, WfContext}; use temporal_sdk_core_api::{ - errors::PollWfError, + errors::PollError, worker::{ SlotMarkUsedContext, SlotReleaseContext, SlotReservationContext, SlotSupplier, SlotSupplierPermit, WorkflowSlotKind, @@ -2024,7 +2024,7 @@ async fn autocompletes_wft_no_work() { // work assert_matches!( core.poll_workflow_activation().await.unwrap_err(), - PollWfError::ShutDown + PollError::ShutDown ); core.shutdown().await; @@ -2590,7 +2590,7 @@ async fn _do_post_terminal_commands_test( .unwrap(); let act = core.poll_workflow_activation().await; - assert_matches!(act.unwrap_err(), PollWfError::ShutDown); + assert_matches!(act.unwrap_err(), PollError::ShutDown); core.shutdown().await; } @@ -2792,7 +2792,7 @@ async fn poller_wont_run_ahead_of_task_slots() { // This should end up getting shut down after the other routine finishes tasks assert_matches!( worker.poll_workflow_activation().await.unwrap_err(), - PollWfError::ShutDown + PollError::ShutDown ); }; // Wait for a bit concurrently with above, verify no extra tasks got taken, shutdown @@ -3047,7 +3047,7 @@ async fn slot_provider_cant_hand_out_more_permits_than_cache_size() { // This should end up getting shut down after the other routine finishes tasks assert_matches!( worker.poll_workflow_activation().await.unwrap_err(), - PollWfError::ShutDown + PollError::ShutDown ); }; // Wait for a bit concurrently with above, verify no extra tasks got taken, shutdown diff --git a/core/src/lib.rs b/core/src/lib.rs index 93880f995..8d1f71464 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -60,7 +60,7 @@ use futures_util::Stream; use std::sync::Arc; use temporal_client::{ConfiguredClient, TemporalServiceClientWithMetrics}; use temporal_sdk_core_api::{ - errors::{CompleteActivityError, PollActivityError, PollWfError}, + errors::{CompleteActivityError, PollError}, telemetry::TelemetryOptions, Worker as WorkerTrait, }; diff --git a/core/src/test_help/mod.rs b/core/src/test_help/mod.rs index 039f5b603..7a8c811c3 100644 --- a/core/src/test_help/mod.rs +++ b/core/src/test_help/mod.rs @@ -31,10 +31,7 @@ use std::{ time::Duration, }; use temporal_sdk::interceptors::FailOnNondeterminismInterceptor; -use temporal_sdk_core_api::{ - errors::{PollActivityError, PollWfError}, - Worker as WorkerTrait, -}; +use temporal_sdk_core_api::{errors::PollError, Worker as WorkerTrait}; use temporal_sdk_core_protos::{ coresdk::{ workflow_activation::{workflow_activation_job, WorkflowActivation}, @@ -1058,13 +1055,13 @@ impl WorkerExt for Worker { async { assert_matches!( 
self.poll_activity_task().await.unwrap_err(), - PollActivityError::ShutDown + PollError::ShutDown ); }, async { assert_matches!( self.poll_workflow_activation().await.unwrap_err(), - PollWfError::ShutDown + PollError::ShutDown ); } ); @@ -1075,7 +1072,7 @@ impl WorkerExt for Worker { self.initiate_shutdown(); assert_matches!( self.poll_activity_task().await.unwrap_err(), - PollActivityError::ShutDown + PollError::ShutDown ); self.shutdown().await; } diff --git a/core/src/worker/activities.rs b/core/src/worker/activities.rs index 31402df9a..28c4b6207 100644 --- a/core/src/worker/activities.rs +++ b/core/src/worker/activities.rs @@ -16,7 +16,7 @@ use crate::{ worker::{ activities::activity_heartbeat_manager::ActivityHeartbeatError, client::WorkerClient, }, - PollActivityError, TaskToken, + PollError, TaskToken, }; use activity_heartbeat_manager::ActivityHeartbeatManager; use dashmap::DashMap; @@ -142,7 +142,7 @@ pub(crate) struct WorkerActivityTasks { heartbeat_manager: ActivityHeartbeatManager, /// Combined stream for any ActivityTask producing source (polls, eager activities, /// cancellations) - activity_task_stream: Mutex>>, + activity_task_stream: Mutex>>, /// Activities that have been issued to lang but not yet completed outstanding_activity_tasks: OutstandingActMap, /// Ensures we don't exceed this worker's maximum concurrent activity limit for activities. This @@ -171,7 +171,7 @@ pub(crate) struct WorkerActivityTasks { #[derive(derive_more::From)] enum ActivityTaskSource { PendingCancel(PendingActivityCancel), - PendingStart(Result<(PermittedTqResp, bool), PollActivityError>), + PendingStart(Result<(PermittedTqResp, bool), PollError>), } impl WorkerActivityTasks { @@ -246,9 +246,8 @@ impl WorkerActivityTasks { >, eager_activities_semaphore: Arc>, on_complete_token: CancellationToken, - ) -> impl Stream< - Item = Result<(PermittedTqResp, bool), PollActivityError>, - > { + ) -> impl Stream, bool), PollError>> + { let non_poll_stream = stream::unfold( (non_poll_tasks_rx, eager_activities_semaphore), |(mut non_poll_tasks_rx, eager_activities_semaphore)| async move { @@ -301,13 +300,13 @@ impl WorkerActivityTasks { /// /// Polls the various task sources (server polls, eager activities, cancellations) while /// respecting the provided rate limits and allowed concurrency. Returns - /// [PollActivityError::ShutDown] after shutdown is completed and all tasks sources are + /// [PollError::ShutDown] after shutdown is completed and all tasks sources are /// depleted. 
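// The combined-stream design documented above reduces to a few lines of
// futures-util. This sketch merges two stand-in sources and appends a terminal
// item once both are exhausted, playing the role of the final Err(ShutDown);
// it is illustrative only, not the SDK's actual stream:
use futures_util::{stream, StreamExt};

#[tokio::main]
async fn main() {
    // Stand-ins for the real sources: server polls, cancels, eager starts.
    let polls = stream::iter(vec!["poll-1", "poll-2"]);
    let cancels = stream::iter(vec!["cancel-1"]);

    let merged = stream::select(polls, cancels).chain(stream::once(async { "shutdown" }));

    let seen: Vec<_> = merged.collect().await;
    assert_eq!(seen.last(), Some(&"shutdown"));
    println!("{seen:?}");
}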
- pub(crate) async fn poll(&self) -> Result { + pub(crate) async fn poll(&self) -> Result { let mut poller_stream = self.activity_task_stream.lock().await; poller_stream.next().await.unwrap_or_else(|| { self.poll_returned_shutdown_token.cancel(); - Err(PollActivityError::ShutDown) + Err(PollError::ShutDown) }) } @@ -484,7 +483,7 @@ where /// cancels_stream ------------------------------+--- activity_task_stream /// eager_activities_rx ---+--- starts_stream ---| /// server_poll_stream ---| - fn streamify(self) -> impl Stream> { + fn streamify(self) -> impl Stream> { let outstanding_tasks_clone = self.outstanding_tasks.clone(); let should_issue_immediate_cancel = Arc::new(AtomicBool::new(false)); let should_issue_immediate_cancel_clone = should_issue_immediate_cancel.clone(); @@ -784,7 +783,7 @@ mod tests { ) .await; atm.initiate_shutdown(); - assert_matches!(atm.poll().await.unwrap_err(), PollActivityError::ShutDown); + assert_matches!(atm.poll().await.unwrap_err(), PollError::ShutDown); atm.shutdown().await; } @@ -872,7 +871,7 @@ mod tests { } atm.initiate_shutdown(); - assert_matches!(atm.poll().await.unwrap_err(), PollActivityError::ShutDown); + assert_matches!(atm.poll().await.unwrap_err(), PollError::ShutDown); atm.shutdown().await; } @@ -957,7 +956,7 @@ mod tests { .await; atm.initiate_shutdown(); - assert_matches!(atm.poll().await.unwrap_err(), PollActivityError::ShutDown); + assert_matches!(atm.poll().await.unwrap_err(), PollError::ShutDown); atm.shutdown().await; } } diff --git a/core/src/worker/mod.rs b/core/src/worker/mod.rs index 6025354b2..954aedaad 100644 --- a/core/src/worker/mod.rs +++ b/core/src/worker/mod.rs @@ -40,7 +40,7 @@ use crate::{ nexus::NexusManager, workflow::{LAReqSink, LocalResolution, WorkflowBasics, Workflows}, }, - ActivityHeartbeat, CompleteActivityError, PollActivityError, PollWfError, WorkerTrait, + ActivityHeartbeat, CompleteActivityError, PollError, WorkerTrait, }; use activities::WorkerActivityTasks; use futures_util::{stream, StreamExt}; @@ -135,12 +135,12 @@ impl WorkerTrait for Worker { Ok(()) } - async fn poll_workflow_activation(&self) -> Result { + async fn poll_workflow_activation(&self) -> Result { self.next_workflow_activation().await } #[instrument(skip(self))] - async fn poll_activity_task(&self) -> Result { + async fn poll_activity_task(&self) -> Result { loop { match self.activity_poll().await.transpose() { Some(r) => break r, @@ -153,12 +153,12 @@ impl WorkerTrait for Worker { } #[instrument(skip(self))] - async fn poll_nexus_task(&self) -> Result { + async fn poll_nexus_task(&self) -> Result { if let Some(nm) = self.nexus_mgr.as_ref() { nm.next_nexus_task().await } else { self.shutdown_token.cancelled().await; - Err(PollActivityError::ShutDown) + Err(PollError::ShutDown) } } @@ -661,12 +661,12 @@ impl Worker { /// /// Returns `Ok(None)` in the event of a poll timeout or if the polling loop should otherwise /// be restarted - async fn activity_poll(&self) -> Result, PollActivityError> { + async fn activity_poll(&self) -> Result, PollError> { let local_activities_complete = self.local_activities_complete.load(Ordering::Relaxed); let non_local_activities_complete = self.non_local_activities_complete.load(Ordering::Relaxed); if local_activities_complete && non_local_activities_complete { - return Err(PollActivityError::ShutDown); + return Err(PollError::ShutDown); } let act_mgr_poll = async { if non_local_activities_complete { @@ -676,7 +676,7 @@ impl Worker { if let Some(ref act_mgr) = self.at_task_mgr { let res = act_mgr.poll().await; if 
let Err(err) = res.as_ref() { - if matches!(err, PollActivityError::ShutDown) { + if matches!(err, PollError::ShutDown) { self.non_local_activities_complete .store(true, Ordering::Relaxed); return Ok(None); @@ -718,7 +718,7 @@ impl Worker { }; // Since we consider network errors (at this level) fatal, we want to start shutdown if one // is encountered - if matches!(r, Err(PollActivityError::TonicError(_))) { + if matches!(r, Err(PollError::TonicError(_))) { self.initiate_shutdown(); } r @@ -761,7 +761,7 @@ impl Worker { } #[instrument(skip(self), fields(run_id, workflow_id, task_queue=%self.config.task_queue))] - pub(crate) async fn next_workflow_activation(&self) -> Result { + pub(crate) async fn next_workflow_activation(&self) -> Result { let r = self.workflows.next_workflow_activation().await; // In the event workflows are shutdown or erroring, begin shutdown of everything else. Once // they are shut down, tell the local activity manager that, so that it can know to cancel @@ -769,7 +769,7 @@ impl Worker { if let Err(ref e) = r { // This is covering the situation where WFT pollers dying is the reason for shutdown self.initiate_shutdown(); - if matches!(e, PollWfError::ShutDown) { + if matches!(e, PollError::ShutDown) { self.local_act_mgr.workflows_have_shutdown(); } } diff --git a/core/src/worker/nexus.rs b/core/src/worker/nexus.rs index b472ba55b..6d89f7459 100644 --- a/core/src/worker/nexus.rs +++ b/core/src/worker/nexus.rs @@ -7,7 +7,7 @@ use crate::{ use futures_util::{stream::BoxStream, Stream, StreamExt}; use std::{collections::HashMap, sync::Arc}; use temporal_sdk_core_api::{ - errors::{CompleteNexusError, PollActivityError}, + errors::{CompleteNexusError, PollError}, worker::NexusSlotKind, }; use temporal_sdk_core_protos::{ @@ -23,7 +23,7 @@ use tokio_util::sync::CancellationToken; /// Centralizes all state related to received nexus tasks pub(super) struct NexusManager { - task_stream: Mutex>>, + task_stream: Mutex>>, /// Token to notify when poll returned a shutdown error poll_returned_shutdown_token: CancellationToken, /// Outstanding nexus tasks that have been issued to lang but not yet completed @@ -48,13 +48,11 @@ impl NexusManager { // TODO Different error or combine /// Block until then next nexus task is received from server - pub(super) async fn next_nexus_task( - &self, - ) -> Result { + pub(super) async fn next_nexus_task(&self) -> Result { let mut sl = self.task_stream.lock().await; sl.next().await.unwrap_or_else(|| { self.poll_returned_shutdown_token.cancel(); - Err(PollActivityError::ShutDown) + Err(PollError::ShutDown) }) } @@ -132,9 +130,7 @@ where } } - fn into_stream( - self, - ) -> impl Stream> { + fn into_stream(self) -> impl Stream> { self.source_stream.map(move |t| match t { Ok(t) => { let (service, operation, request_kind) = t @@ -164,7 +160,7 @@ where ); Ok(t.resp) } - Err(e) => Err(PollActivityError::TonicError(e)), + Err(e) => Err(PollError::TonicError(e)), }) } } diff --git a/core/src/worker/workflow/mod.rs b/core/src/worker/workflow/mod.rs index 25475cd1c..61d0df6a4 100644 --- a/core/src/worker/workflow/mod.rs +++ b/core/src/worker/workflow/mod.rs @@ -56,7 +56,7 @@ use std::{ time::{Duration, Instant}, }; use temporal_sdk_core_api::{ - errors::{CompleteWfError, PollWfError}, + errors::{CompleteWfError, PollError}, worker::{ActivitySlotKind, WorkerConfig, WorkflowSlotKind}, }; use temporal_sdk_core_protos::{ @@ -101,7 +101,7 @@ const WFT_HEARTBEAT_TIMEOUT_FRACTION: f32 = 0.8; const MAX_EAGER_ACTIVITY_RESERVATIONS_PER_WORKFLOW_TASK: usize = 3; type 
Result = result::Result; -type BoxedActivationStream = BoxStream<'static, Result>; +type BoxedActivationStream = BoxStream<'static, Result>; type InternalFlagsRef = Rc>; /// Centralizes all state related to workflows and workflow tasks @@ -249,7 +249,7 @@ impl Workflows { } } - pub(super) async fn next_workflow_activation(&self) -> Result { + pub(super) async fn next_workflow_activation(&self) -> Result { self.ever_polled.store(true, atomic::Ordering::Release); loop { let al = { @@ -258,7 +258,7 @@ impl Workflows { if let Some(beginner) = beginner.take() { let _ = beginner.send(()); } - stream.next().await.unwrap_or(Err(PollWfError::ShutDown))? + stream.next().await.unwrap_or(Err(PollError::ShutDown))? }; match al { ActivationOrAuto::LangActivation(mut act) diff --git a/core/src/worker/workflow/workflow_stream.rs b/core/src/worker/workflow/workflow_stream.rs index 50afaa1c5..4ac9bf413 100644 --- a/core/src/worker/workflow/workflow_stream.rs +++ b/core/src/worker/workflow/workflow_stream.rs @@ -10,7 +10,7 @@ use crate::{ }; use futures_util::{stream, stream::PollNext, Stream, StreamExt}; use std::{collections::VecDeque, fmt::Debug, future, sync::Arc}; -use temporal_sdk_core_api::errors::PollWfError; +use temporal_sdk_core_api::errors::PollError; use temporal_sdk_core_protos::coresdk::workflow_activation::remove_from_cache::EvictionReason; use tokio_util::sync::CancellationToken; use tracing::{Level, Span}; @@ -64,7 +64,7 @@ impl WFStream { wft_stream: impl Stream> + Send + 'static, local_rx: impl Stream + Send + 'static, local_activity_request_sink: impl LocalActivityRequestSink, - ) -> impl Stream> { + ) -> impl Stream> { let all_inputs = stream::select_with_strategy( local_rx.map(Into::into), wft_stream @@ -82,7 +82,7 @@ impl WFStream { all_inputs: impl Stream, basics: WorkflowBasics, local_activity_request_sink: impl LocalActivityRequestSink, - ) -> impl Stream> { + ) -> impl Stream> { let mut state = WFStream { buffered_polls_need_cache_slot: Default::default(), runs: RunCache::new( @@ -165,7 +165,7 @@ impl WFStream { None } WFStreamInput::PollerError(e) => { - return Err(PollWfError::TonicError(e)); + return Err(PollError::TonicError(e)); } }; @@ -174,7 +174,7 @@ impl WFStream { if state.shutdown_done() { info!("Workflow shutdown is done"); - return Err(PollWfError::ShutDown); + return Err(PollError::ShutDown); } Ok(WFStreamOutput { @@ -184,7 +184,7 @@ impl WFStream { }) .inspect(|o| { if let Some(e) = o.as_ref().err() { - if !matches!(e, PollWfError::ShutDown) { + if !matches!(e, PollError::ShutDown) { error!( "Workflow processing encountered fatal error and must shut down {:?}", e @@ -193,7 +193,7 @@ impl WFStream { } }) // Stop the stream once we have shut down - .take_while(|o| future::ready(!matches!(o, Err(PollWfError::ShutDown)))) + .take_while(|o| future::ready(!matches!(o, Err(PollError::ShutDown)))) } /// Instantiate or update run machines with a new WFT diff --git a/sdk/src/lib.rs b/sdk/src/lib.rs index b529b439a..22598420d 100644 --- a/sdk/src/lib.rs +++ b/sdk/src/lib.rs @@ -80,10 +80,7 @@ use std::{ }; use temporal_client::ClientOptionsBuilder; use temporal_sdk_core::Url; -use temporal_sdk_core_api::{ - errors::{PollActivityError, PollWfError}, - Worker as CoreWorker, -}; +use temporal_sdk_core_api::{errors::PollError, Worker as CoreWorker}; use temporal_sdk_core_protos::{ coresdk::{ activity_result::{ActivityExecutionResult, ActivityResolution}, @@ -284,7 +281,7 @@ impl Worker { async { loop { let activation = match common.worker.poll_workflow_activation().await { - 
Err(PollWfError::ShutDown) => { + Err(PollError::ShutDown) => { break; } o => o?, @@ -319,7 +316,7 @@ impl Worker { if !act_half.activity_fns.is_empty() { loop { let activity = common.worker.poll_activity_task().await; - if matches!(activity, Err(PollActivityError::ShutDown)) { + if matches!(activity, Err(PollError::ShutDown)) { break; } act_half.activity_task_handler( diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index 518e023e2..cc01f014c 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -38,7 +38,7 @@ use temporal_sdk_core::{ ClientOptions, ClientOptionsBuilder, CoreRuntime, WorkerConfigBuilder, }; use temporal_sdk_core_api::{ - errors::{PollActivityError, PollWfError}, + errors::PollError, telemetry::{ metrics::CoreMeter, Logger, OtelCollectorOptionsBuilder, PrometheusExporterOptionsBuilder, TelemetryOptions, TelemetryOptionsBuilder, @@ -853,13 +853,13 @@ pub async fn drain_pollers_and_shutdown(worker: &Arc) { async { assert!(matches!( worker.poll_activity_task().await.unwrap_err(), - PollActivityError::ShutDown + PollError::ShutDown )); }, async { assert!(matches!( worker.poll_workflow_activation().await.unwrap_err(), - PollWfError::ShutDown, + PollError::ShutDown, )); } ); diff --git a/tests/integ_tests/workflow_tests.rs b/tests/integ_tests/workflow_tests.rs index ae4d9e237..31aeed9b5 100644 --- a/tests/integ_tests/workflow_tests.rs +++ b/tests/integ_tests/workflow_tests.rs @@ -30,7 +30,7 @@ use temporal_sdk::{ WorkflowResult, }; use temporal_sdk_core::{replay::HistoryForReplay, CoreRuntime}; -use temporal_sdk_core_api::errors::{PollWfError, WorkflowErrorType}; +use temporal_sdk_core_api::errors::{PollError, WorkflowErrorType}; use temporal_sdk_core_protos::{ coresdk::{ activity_result::ActivityExecutionResult, @@ -144,14 +144,14 @@ async fn shutdown_aborts_actively_blocked_poll() { }); assert_matches!( core.poll_workflow_activation().await.unwrap_err(), - PollWfError::ShutDown + PollError::ShutDown ); handle.await.unwrap(); // Ensure double-shutdown doesn't explode core.shutdown().await; assert_matches!( core.poll_workflow_activation().await.unwrap_err(), - PollWfError::ShutDown + PollError::ShutDown ); } diff --git a/tests/integ_tests/workflow_tests/replay.rs b/tests/integ_tests/workflow_tests/replay.rs index b48893922..4a58d8542 100644 --- a/tests/integ_tests/workflow_tests/replay.rs +++ b/tests/integ_tests/workflow_tests/replay.rs @@ -4,7 +4,7 @@ use parking_lot::Mutex; use std::{collections::HashSet, sync::Arc, time::Duration}; use temporal_sdk::{interceptors::WorkerInterceptor, WfContext, Worker, WorkflowFunction}; use temporal_sdk_core::replay::{HistoryFeeder, HistoryForReplay}; -use temporal_sdk_core_api::errors::{PollActivityError, PollWfError}; +use temporal_sdk_core_api::errors::PollError; use temporal_sdk_core_protos::{ coresdk::{ workflow_activation::remove_from_cache::EvictionReason, @@ -50,10 +50,7 @@ async fn timer_workflow_replay() { let task = core.poll_workflow_activation().await.unwrap(); // Verify that an in-progress poll is interrupted by completion finishing processing history let act_poll_fut = async { - assert_matches!( - core.poll_activity_task().await, - Err(PollActivityError::ShutDown) - ); + assert_matches!(core.poll_activity_task().await, Err(PollError::ShutDown)); }; let poll_fut = async { let evict_task = core @@ -66,7 +63,7 @@ async fn timer_workflow_replay() { .unwrap(); assert_matches!( core.poll_workflow_activation().await, - Err(PollWfError::ShutDown) + Err(PollError::ShutDown) ); }; let complete_fut = async { @@ -77,7 
+74,7 @@ async fn timer_workflow_replay() { // Subsequent polls should still return shutdown assert_matches!( core.poll_workflow_activation().await, - Err(PollWfError::ShutDown) + Err(PollError::ShutDown) ); core.shutdown().await; @@ -117,7 +114,7 @@ async fn workflow_nondeterministic_replay() { core.shutdown().await; assert_matches!( core.poll_workflow_activation().await, - Err(PollWfError::ShutDown) + Err(PollError::ShutDown) ); } From b9375c71f464ea62eb5e3a820966857db5c736a9 Mon Sep 17 00:00:00 2001 From: Spencer Judge Date: Wed, 8 Jan 2025 16:39:35 -0800 Subject: [PATCH 08/14] Shutdown --- core/src/worker/mod.rs | 6 +- core/src/worker/nexus.rs | 102 ++++++++++------ tests/integ_tests/workflow_tests/nexus.rs | 138 ++++++++++++++++++---- 3 files changed, 182 insertions(+), 64 deletions(-) diff --git a/core/src/worker/mod.rs b/core/src/worker/mod.rs index 954aedaad..e0ac03da7 100644 --- a/core/src/worker/mod.rs +++ b/core/src/worker/mod.rs @@ -607,13 +607,17 @@ impl Worker { if let Some(acts) = self.at_task_mgr.as_ref() { acts.shutdown().await; } + // Wait for nexus tasks to finish + if let Some(nm) = self.nexus_mgr.as_ref() { + nm.shutdown().await; + } // Wait for all permits to be released, but don't totally hang real-world shutdown. tokio::select! { _ = async { self.all_permits_tracker.lock().await.all_done().await } => {}, _ = tokio::time::sleep(Duration::from_secs(1)) => { dbg_panic!("Waiting for all slot permits to release took too long!"); } - }; + } } /// Finish shutting down by consuming the background pollers and freeing all resources diff --git a/core/src/worker/nexus.rs b/core/src/worker/nexus.rs index 6d89f7459..51738d706 100644 --- a/core/src/worker/nexus.rs +++ b/core/src/worker/nexus.rs @@ -4,8 +4,15 @@ use crate::{ telemetry::metrics::MetricsContext, worker::client::WorkerClient, }; -use futures_util::{stream::BoxStream, Stream, StreamExt}; -use std::{collections::HashMap, sync::Arc}; +use futures_util::{stream, stream::BoxStream, Stream, StreamExt}; +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::Duration, +}; use temporal_sdk_core_api::{ errors::{CompleteNexusError, PollError}, worker::NexusSlotKind, @@ -28,6 +35,7 @@ pub(super) struct NexusManager { poll_returned_shutdown_token: CancellationToken, /// Outstanding nexus tasks that have been issued to lang but not yet completed outstanding_task_map: OutstandingTaskMap, + ever_polled: AtomicBool, } impl NexusManager { @@ -43,17 +51,21 @@ impl NexusManager { task_stream: Mutex::new(task_stream.into_stream().boxed()), poll_returned_shutdown_token: CancellationToken::new(), outstanding_task_map, + ever_polled: AtomicBool::new(false), } } - // TODO Different error or combine /// Block until then next nexus task is received from server pub(super) async fn next_nexus_task(&self) -> Result { + self.ever_polled.store(true, Ordering::Relaxed); let mut sl = self.task_stream.lock().await; - sl.next().await.unwrap_or_else(|| { + let r = sl.next().await.unwrap_or_else(|| Err(PollError::ShutDown)); + // This can't happen in the or_else closure because ShutDown is typically returned by the + // stream directly, before it terminates. 
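// The ordering described in that comment matters: the token must be cancelled
// by the poll path itself, and only after the terminal ShutDown has actually
// been handed out, so that shutdown() cannot complete early. A self-contained
// sketch of that handshake using tokio-util's real CancellationToken:
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    let poll_returned_shutdown = CancellationToken::new();
    let waiter = poll_returned_shutdown.clone();

    let poller = async move {
        // ... imagine the task stream being drained here ...
        poll_returned_shutdown.cancel(); // terminal ShutDown was returned
    };
    let shutdown = async move {
        waiter.cancelled().await; // resolves only after cancel()
        println!("poller observed ShutDown; safe to finalize");
    };
    tokio::join!(poller, shutdown);
}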
+ if let Err(PollError::ShutDown) = &r { self.poll_returned_shutdown_token.cancel(); - Err(PollError::ShutDown) - }) + } + r } pub(super) async fn complete_task( @@ -112,6 +124,13 @@ impl NexusManager { } Ok(()) } + + pub(super) async fn shutdown(&self) { + if !self.ever_polled.load(Ordering::Relaxed) { + return; + } + self.poll_returned_shutdown_token.cancelled().await; + } } struct NexusTaskStream { @@ -131,37 +150,46 @@ where } fn into_stream(self) -> impl Stream> { - self.source_stream.map(move |t| match t { - Ok(t) => { - let (service, operation, request_kind) = t - .resp - .request - .as_ref() - .and_then(|r| r.variant.as_ref()) - .map(|v| match v { - Variant::StartOperation(s) => ( - s.service.to_owned(), - s.operation.to_owned(), - RequestKind::Start, - ), - Variant::CancelOperation(c) => ( - c.service.to_owned(), - c.operation.to_owned(), - RequestKind::Cancel, - ), - }) - .unwrap_or_default(); - self.outstanding_task_map.lock().insert( - TaskToken(t.resp.task_token.clone()), - NexusInFlightTask { - request_kind, - _permit: t.permit.into_used(NexusSlotInfo { service, operation }), - }, - ); - Ok(t.resp) - } - Err(e) => Err(PollError::TonicError(e)), - }) + let outstanding_task_clone = self.outstanding_task_map.clone(); + self.source_stream + .map(move |t| match t { + Ok(t) => { + let (service, operation, request_kind) = t + .resp + .request + .as_ref() + .and_then(|r| r.variant.as_ref()) + .map(|v| match v { + Variant::StartOperation(s) => ( + s.service.to_owned(), + s.operation.to_owned(), + RequestKind::Start, + ), + Variant::CancelOperation(c) => ( + c.service.to_owned(), + c.operation.to_owned(), + RequestKind::Cancel, + ), + }) + .unwrap_or_default(); + self.outstanding_task_map.lock().insert( + TaskToken(t.resp.task_token.clone()), + NexusInFlightTask { + request_kind, + _permit: t.permit.into_used(NexusSlotInfo { service, operation }), + }, + ); + Ok(t.resp) + } + Err(e) => Err(PollError::TonicError(e)), + }) + .chain(stream::once(async move { + while !outstanding_task_clone.lock().is_empty() { + // todo no spin + tokio::time::sleep(Duration::from_millis(10)).await; + } + Err(PollError::ShutDown) + })) } } diff --git a/tests/integ_tests/workflow_tests/nexus.rs b/tests/integ_tests/workflow_tests/nexus.rs index 853f6e078..2ec0a6ad7 100644 --- a/tests/integ_tests/workflow_tests/nexus.rs +++ b/tests/integ_tests/workflow_tests/nexus.rs @@ -3,6 +3,7 @@ use assert_matches::assert_matches; use std::time::Duration; use temporal_client::{WfClientExt, WorkflowClientTrait, WorkflowOptions}; use temporal_sdk::{CancellableFuture, NexusOperationOptions, WfContext, WfExitValue}; +use temporal_sdk_core_api::errors::PollError; use temporal_sdk_core_protos::{ coresdk::{ nexus::{ @@ -24,7 +25,7 @@ use temporal_sdk_core_protos::{ }, }; use temporal_sdk_core_test_utils::{rand_6_chars, CoreWfStarter}; -use tokio::join; +use tokio::{join, sync::mpsc}; #[derive(Debug, PartialEq, Eq, Clone, Copy)] enum Outcome { @@ -94,7 +95,12 @@ async fn nexus_basic( .await .unwrap(); } - Outcome::Fail => { + Outcome::Fail | Outcome::Timeout => { + if outcome == Outcome::Timeout { + // We have to complete the nexus task so Core can shut down, but make sure we + // don't do it until after it times out. 
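// The chained future in into_stream above holds the stream open until lang has
// completed all outstanding tasks, but it currently does so by re-checking the
// map every 10ms (hence the "todo no spin" note). A spin-free variant of the
// same idea, sketched with tokio's Notify; this is an assumption about one way
// to remove the sleep loop, not the patch's code:
use futures_util::{stream, StreamExt};
use std::sync::{Arc, Mutex};
use tokio::sync::Notify;

#[tokio::main]
async fn main() {
    let outstanding = Arc::new(Mutex::new(1usize)); // one task handed to lang
    let drained = Arc::new(Notify::new());

    // Yield tasks, then delay the terminal "shutdown" item until the
    // outstanding count reaches zero.
    let tasks = stream::iter(vec!["task-1"]).chain(stream::once({
        let outstanding = outstanding.clone();
        let drained = drained.clone();
        async move {
            while *outstanding.lock().unwrap() > 0 {
                drained.notified().await;
            }
            "shutdown"
        }
    }));

    // Lang completes the task concurrently, releasing the stream's tail.
    let completer = async {
        *outstanding.lock().unwrap() -= 1;
        drained.notify_one();
    };

    let (seen, ()) = tokio::join!(tasks.collect::<Vec<_>>(), completer);
    assert_eq!(seen, vec!["task-1", "shutdown"]);
}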
+ tokio::time::sleep(Duration::from_millis(3100)).await; + } core_worker .complete_nexus_task(NexusTaskCompletion { task_token: nt.task_token, @@ -109,9 +115,12 @@ async fn nexus_basic( .await .unwrap(); } - Outcome::Timeout => {} - Outcome::Cancel | Outcome::CancelAfterRecordedBeforeStarted => unimplemented!(), + Outcome::Cancel | Outcome::CancelAfterRecordedBeforeStarted => unreachable!(), } + assert_matches!( + core_worker.poll_nexus_task().await, + Err(PollError::ShutDown) + ); }; join!(nexus_task_handle, async { @@ -147,7 +156,7 @@ async fn nexus_basic( assert_eq!(f.message, "nexus operation completed unsuccessfully"); assert_eq!(f.cause.unwrap().message, "operation timed out"); } - Outcome::Cancel | Outcome::CancelAfterRecordedBeforeStarted => unimplemented!(), + Outcome::Cancel | Outcome::CancelAfterRecordedBeforeStarted => unreachable!(), } } @@ -219,7 +228,7 @@ async fn nexus_async( }, ); let submitter = worker.get_submitter_handle(); - starter.start_with_worker(wf_name, &mut worker).await; + let handle = starter.start_with_worker(wf_name, &mut worker).await; let client = starter.get_client().await.get_client().clone(); let nexus_task_handle = async { @@ -259,29 +268,34 @@ async fn nexus_async( .await .unwrap(); } - if outcome != Outcome::CancelAfterRecordedBeforeStarted { - // Do not say the operation started if we are trying to test this type of cancel - core_worker - .complete_nexus_task(NexusTaskCompletion { - task_token: nt.task_token, - status: Some(nexus_task_completion::Status::Completed( - nexus::v1::Response { - variant: Some(nexus::v1::response::Variant::StartOperation( - StartOperationResponse { - variant: Some(start_operation_response::Variant::AsyncSuccess( - start_operation_response::Async { - operation_id: "op-1".to_string(), - links: vec![], - }, - )), - }, - )), - }, - )), - }) + if outcome == Outcome::CancelAfterRecordedBeforeStarted { + // Do not say the operation started until after it's had a chance to already be + // cancelled in this case + handle + .get_workflow_result(Default::default()) .await .unwrap(); } + core_worker + .complete_nexus_task(NexusTaskCompletion { + task_token: nt.task_token, + status: Some(nexus_task_completion::Status::Completed( + nexus::v1::Response { + variant: Some(nexus::v1::response::Variant::StartOperation( + StartOperationResponse { + variant: Some(start_operation_response::Variant::AsyncSuccess( + start_operation_response::Async { + operation_id: "op-1".to_string(), + links: vec![], + }, + )), + }, + )), + }, + )), + }) + .await + .unwrap(); if outcome == Outcome::Cancel { let nt = core_worker.poll_nexus_task().await.unwrap(); assert_matches!( @@ -306,6 +320,10 @@ async fn nexus_async( .await .unwrap(); } + assert_matches!( + core_worker.poll_nexus_task().await, + Err(PollError::ShutDown) + ); }; join!(nexus_task_handle, async { @@ -393,6 +411,74 @@ async fn nexus_cancel_before_start() { worker.run_until_done().await.unwrap(); } +#[tokio::test] +async fn nexus_must_complete_task_to_shutdown() { + let wf_name = "nexus_must_complete_task_to_shutdown"; + let mut starter = CoreWfStarter::new(wf_name); + starter.worker_config.no_remote_activities(true); + let mut worker = starter.worker().await; + let core_worker = starter.get_worker().await; + + let endpoint = mk_endpoint(&mut starter).await; + + worker.register_wf(wf_name.to_owned(), move |ctx: WfContext| { + let endpoint = endpoint.clone(); + async move { + let started = ctx.start_nexus_operation(NexusOperationOptions { + endpoint: endpoint.clone(), + service: "svc".to_string(), 
+ operation: "op".to_string(), + ..Default::default() + }); + // Workflow completes right away, only having scheduled the operation. We need a timer + // to make sure the nexus task actually gets scheduled. + ctx.timer(Duration::from_millis(1)).await; + // started.await.unwrap(); + Ok(().into()) + } + }); + let handle = starter.start_with_worker(wf_name, &mut worker).await; + let (complete_order_tx, mut complete_order_rx) = mpsc::unbounded_channel(); + + let task_handle = async { + // Should get the nexus task first + let nt = core_worker.poll_nexus_task().await.unwrap(); + // The workflow will complete + handle + .get_workflow_result(Default::default()) + .await + .unwrap(); + // Complete the task + core_worker + .complete_nexus_task(NexusTaskCompletion { + task_token: nt.task_token, + status: Some(nexus_task_completion::Status::Error(HandlerError { + error_type: "BAD_REQUEST".to_string(), // bad req is non-retryable + failure: Some(nexus::v1::Failure { + message: "busted".to_string(), + ..Default::default() + }), + })), + }) + .await + .unwrap(); + dbg!("Completed nexus task"); + complete_order_tx.send("t").unwrap(); + assert_matches!( + core_worker.poll_nexus_task().await, + Err(PollError::ShutDown) + ); + }; + + join!(task_handle, async { + worker.run_until_done().await.unwrap(); + complete_order_tx.send("w").unwrap(); + }); + + // The first thing to finish needs to have been the nexus task completion + assert_eq!(complete_order_rx.recv().await.unwrap(), "t"); +} + async fn mk_endpoint(starter: &mut CoreWfStarter) -> String { let client = starter.get_client().await; let endpoint = format!("mycoolendpoint-{}", rand_6_chars()); From 786950337bcdcac58e738ddff0f4e16fc31d151d Mon Sep 17 00:00:00 2001 From: Spencer Judge Date: Wed, 8 Jan 2025 17:28:02 -0800 Subject: [PATCH 09/14] Introduce NexusTask --- core-api/src/lib.rs | 23 +++++++++------ core/src/worker/mod.rs | 6 ++-- core/src/worker/nexus.rs | 18 ++++++------ .../local/temporal/sdk/core/nexus/nexus.proto | 28 +++++++++++++++++++ sdk-core-protos/src/lib.rs | 12 ++++++++ tests/integ_tests/workflow_tests/nexus.rs | 8 +++--- 6 files changed, 71 insertions(+), 24 deletions(-) diff --git a/core-api/src/lib.rs b/core-api/src/lib.rs index f82e90b0b..c66418d7f 100644 --- a/core-api/src/lib.rs +++ b/core-api/src/lib.rs @@ -9,13 +9,12 @@ use crate::{ }, worker::WorkerConfig, }; -use temporal_sdk_core_protos::{ - coresdk::{ - activity_task::ActivityTask, nexus::NexusTaskCompletion, - workflow_activation::WorkflowActivation, workflow_completion::WorkflowActivationCompletion, - ActivityHeartbeat, ActivityTaskCompletion, - }, - temporal::api::workflowservice::v1::PollNexusTaskQueueResponse, +use temporal_sdk_core_protos::coresdk::{ + activity_task::ActivityTask, + nexus::{NexusTask, NexusTaskCompletion}, + workflow_activation::WorkflowActivation, + workflow_completion::WorkflowActivationCompletion, + ActivityHeartbeat, ActivityTaskCompletion, }; /// This trait is the primary way by which language specific SDKs interact with the core SDK. @@ -49,8 +48,14 @@ pub trait Worker: Send + Sync { /// Do not call poll concurrently. It handles polling the server concurrently internally. async fn poll_activity_task(&self) -> Result; - /// TODO: Keep or combine? - async fn poll_nexus_task(&self) -> Result; + /// Ask the worker for some nexus related work. It is then the language SDK's + /// responsibility to call the appropriate nexus operation handler code with the provided + /// inputs. 
Blocks indefinitely until such work is available or [Worker::shutdown] is called. + /// + /// All tasks must be responded to for shutdown to complete. + /// + /// Do not call poll concurrently. It handles polling the server concurrently internally. + async fn poll_nexus_task(&self) -> Result; /// Tell the worker that a workflow activation has completed. May (and should) be freely called /// concurrently. The future may take some time to resolve, as fetching more events might be diff --git a/core/src/worker/mod.rs b/core/src/worker/mod.rs index e0ac03da7..bd87dc646 100644 --- a/core/src/worker/mod.rs +++ b/core/src/worker/mod.rs @@ -68,7 +68,7 @@ use temporal_sdk_core_protos::{ temporal::api::{ enums::v1::TaskQueueKind, taskqueue::v1::{StickyExecutionAttributes, TaskQueue}, - workflowservice::v1::{get_system_info_response, PollNexusTaskQueueResponse}, + workflowservice::v1::get_system_info_response, }, TaskToken, }; @@ -76,7 +76,7 @@ use tokio::sync::{mpsc::unbounded_channel, watch}; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_util::sync::CancellationToken; -use temporal_sdk_core_protos::coresdk::nexus::NexusTaskCompletion; +use temporal_sdk_core_protos::coresdk::nexus::{NexusTask, NexusTaskCompletion}; #[cfg(test)] use { crate::{ @@ -153,7 +153,7 @@ impl WorkerTrait for Worker { } #[instrument(skip(self))] - async fn poll_nexus_task(&self) -> Result { + async fn poll_nexus_task(&self) -> Result { if let Some(nm) = self.nexus_mgr.as_ref() { nm.next_nexus_task().await } else { diff --git a/core/src/worker/nexus.rs b/core/src/worker/nexus.rs index 51738d706..b305e0bfb 100644 --- a/core/src/worker/nexus.rs +++ b/core/src/worker/nexus.rs @@ -18,11 +18,11 @@ use temporal_sdk_core_api::{ worker::NexusSlotKind, }; use temporal_sdk_core_protos::{ - coresdk::{nexus::nexus_task_completion, NexusSlotInfo}, - temporal::api::{ - nexus::v1::{request::Variant, response}, - workflowservice::v1::PollNexusTaskQueueResponse, + coresdk::{ + nexus::{nexus_task, nexus_task_completion, NexusTask}, + NexusSlotInfo, }, + temporal::api::nexus::v1::{request::Variant, response}, TaskToken, }; use tokio::sync::Mutex; @@ -30,7 +30,7 @@ use tokio_util::sync::CancellationToken; /// Centralizes all state related to received nexus tasks pub(super) struct NexusManager { - task_stream: Mutex>>, + task_stream: Mutex>>, /// Token to notify when poll returned a shutdown error poll_returned_shutdown_token: CancellationToken, /// Outstanding nexus tasks that have been issued to lang but not yet completed @@ -56,7 +56,7 @@ impl NexusManager { } /// Block until then next nexus task is received from server - pub(super) async fn next_nexus_task(&self) -> Result { + pub(super) async fn next_nexus_task(&self) -> Result { self.ever_polled.store(true, Ordering::Relaxed); let mut sl = self.task_stream.lock().await; let r = sl.next().await.unwrap_or_else(|| Err(PollError::ShutDown)); @@ -149,7 +149,7 @@ where } } - fn into_stream(self) -> impl Stream> { + fn into_stream(self) -> impl Stream> { let outstanding_task_clone = self.outstanding_task_map.clone(); self.source_stream .map(move |t| match t { @@ -179,7 +179,9 @@ where _permit: t.permit.into_used(NexusSlotInfo { service, operation }), }, ); - Ok(t.resp) + Ok(NexusTask { + variant: Some(nexus_task::Variant::Task(t.resp)), + }) } Err(e) => Err(PollError::TonicError(e)), }) diff --git a/sdk-core-protos/protos/local/temporal/sdk/core/nexus/nexus.proto b/sdk-core-protos/protos/local/temporal/sdk/core/nexus/nexus.proto index d417aaba9..a0b203a8f 100644 --- 
a/sdk-core-protos/protos/local/temporal/sdk/core/nexus/nexus.proto +++ b/sdk-core-protos/protos/local/temporal/sdk/core/nexus/nexus.proto @@ -6,6 +6,7 @@ option ruby_package = "Temporalio::Internal::Bridge::Api::Nexus"; import "temporal/api/common/v1/message.proto"; import "temporal/api/failure/v1/message.proto"; import "temporal/api/nexus/v1/message.proto"; +import "temporal/api/workflowservice/v1/request_response.proto"; import "temporal/sdk/core/common/common.proto"; // Used by core to resolve nexus operations. @@ -29,4 +30,31 @@ message NexusTaskCompletion { // The handler could not complete the request for some reason. temporal.api.nexus.v1.HandlerError error = 3; } +} + +message NexusTask { + oneof variant { + // A nexus task from server + temporal.api.workflowservice.v1.PollNexusTaskQueueResponse task = 1; + // A request by Core to notify an in-progress operation handler that it should cancel. This + // is distinct from a `CancelOperationRequest` from the server, which results from the user + // requesting the cancellation of an operation. Handling this variant should result in + // something like cancelling a cancellation token given to the user's operation handler. + // + // EX: Core knows the nexus operation has timed out, and it does not make sense for the + // user's operation handler to continue doing work. + CancelNexusTask cancel_task = 2; + } +} + +message CancelNexusTask { + // The task token from the PollNexusTaskQueueResponse + bytes task_token = 1; + // Why Core is asking for this operation to be cancelled + NexusTaskCancelReason reason = 2; +} + +enum NexusTaskCancelReason { + // The nexus task is known to have timed out + TIMED_OUT = 0; } \ No newline at end of file diff --git a/sdk-core-protos/src/lib.rs b/sdk-core-protos/src/lib.rs index 0ca48a52e..fce7dfc3c 100644 --- a/sdk-core-protos/src/lib.rs +++ b/sdk-core-protos/src/lib.rs @@ -757,7 +757,19 @@ pub mod coresdk { } pub mod nexus { + use crate::temporal::api::workflowservice::v1::PollNexusTaskQueueResponse; + tonic::include_proto!("coresdk.nexus"); + + impl NexusTask { + /// Unwrap the inner server-delivered nexus task if that's what this is, else panic. 
+ pub fn unwrap_task(self) -> PollNexusTaskQueueResponse { + if let Some(nexus_task::Variant::Task(t)) = self.variant { + return t; + } + panic!("Nexus task did not contain a server task"); + } + } } pub mod workflow_commands { diff --git a/tests/integ_tests/workflow_tests/nexus.rs b/tests/integ_tests/workflow_tests/nexus.rs index 2ec0a6ad7..13eba90f4 100644 --- a/tests/integ_tests/workflow_tests/nexus.rs +++ b/tests/integ_tests/workflow_tests/nexus.rs @@ -70,7 +70,7 @@ async fn nexus_basic( let client = starter.get_client().await.get_client().clone(); let nexus_task_handle = async { - let nt = core_worker.poll_nexus_task().await.unwrap(); + let nt = core_worker.poll_nexus_task().await.unwrap().unwrap_task(); match outcome { Outcome::Succeed => { core_worker @@ -232,7 +232,7 @@ async fn nexus_async( let client = starter.get_client().await.get_client().clone(); let nexus_task_handle = async { - let nt = core_worker.poll_nexus_task().await.unwrap(); + let nt = core_worker.poll_nexus_task().await.unwrap().unwrap_task(); let start_req = assert_matches!( nt.request.unwrap().variant.unwrap(), request::Variant::StartOperation(sr) => sr @@ -297,7 +297,7 @@ async fn nexus_async( .await .unwrap(); if outcome == Outcome::Cancel { - let nt = core_worker.poll_nexus_task().await.unwrap(); + let nt = core_worker.poll_nexus_task().await.unwrap().unwrap_task(); assert_matches!( nt.request.unwrap().variant.unwrap(), request::Variant::CancelOperation(_) @@ -442,7 +442,7 @@ async fn nexus_must_complete_task_to_shutdown() { let task_handle = async { // Should get the nexus task first - let nt = core_worker.poll_nexus_task().await.unwrap(); + let nt = core_worker.poll_nexus_task().await.unwrap().unwrap_task(); // The workflow will complete handle .get_workflow_result(Default::default()) From fa0c9a6ced0f9a757f4cb48ef4fc17f8a7897e1d Mon Sep 17 00:00:00 2001 From: Spencer Judge Date: Thu, 9 Jan 2025 14:37:55 -0800 Subject: [PATCH 10/14] Send cancels for timed out tasks (and req server 1.26) --- core/Cargo.toml | 2 +- core/src/worker/mod.rs | 34 ++- core/src/worker/nexus.rs | 219 +++++++++++++----- .../local/temporal/sdk/core/nexus/nexus.proto | 11 + sdk-core-protos/src/history_info.rs | 7 +- sdk-core-protos/src/lib.rs | 61 +++++ tests/integ_tests/workflow_tests/nexus.rs | 72 +++--- 7 files changed, 305 insertions(+), 101 deletions(-) diff --git a/core/Cargo.toml b/core/Cargo.toml index bcd623662..77cae804b 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -44,7 +44,7 @@ lru = "0.12" mockall = "0.13" opentelemetry = { workspace = true, features = ["metrics"], optional = true } opentelemetry_sdk = { version = "0.26", features = ["rt-tokio", "metrics"], optional = true } -opentelemetry-otlp = { version = "0.26", features = ["tokio", "metrics", "tls", "http-proto", "reqwest-client",], optional = true } +opentelemetry-otlp = { version = "0.26", features = ["tokio", "metrics", "tls", "http-proto", "reqwest-client", ], optional = true } opentelemetry-prometheus = { git = "https://github.com/open-telemetry/opentelemetry-rust.git", rev = "e911383", optional = true } parking_lot = { version = "0.12", features = ["send_guard"] } pid = "4.0" diff --git a/core/src/worker/mod.rs b/core/src/worker/mod.rs index bd87dc646..005600d09 100644 --- a/core/src/worker/mod.rs +++ b/core/src/worker/mod.rs @@ -76,7 +76,10 @@ use tokio::sync::{mpsc::unbounded_channel, watch}; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_util::sync::CancellationToken; -use temporal_sdk_core_protos::coresdk::nexus::{NexusTask, 
NexusTaskCompletion}; +use temporal_sdk_core_protos::coresdk::nexus::{ + nexus_task_completion, NexusTask, NexusTaskCompletion, +}; + #[cfg(test)] use { crate::{ @@ -84,7 +87,9 @@ use { protosext::ValidPollWFTQResponse, }, futures_util::stream::BoxStream, - temporal_sdk_core_protos::temporal::api::workflowservice::v1::PollActivityTaskQueueResponse, + temporal_sdk_core_protos::temporal::api::workflowservice::v1::{ + PollActivityTaskQueueResponse, PollNexusTaskQueueResponse, + }, }; /// A worker polls on a certain task queue @@ -190,8 +195,6 @@ impl WorkerTrait for Worker { &self, completion: NexusTaskCompletion, ) -> Result<(), CompleteNexusError> { - let task_token = TaskToken(completion.task_token); - let status = if let Some(s) = completion.status { s } else { @@ -200,11 +203,8 @@ impl WorkerTrait for Worker { }); }; - if let Some(nm) = self.nexus_mgr.as_ref() { - nm.complete_task(task_token, status, &*self.client).await - } else { - Err(CompleteNexusError::NexusNotEnabled) - } + self.complete_nexus_task(TaskToken(completion.task_token), status) + .await } fn record_activity_heartbeat(&self, details: ActivityHeartbeat) { @@ -799,6 +799,22 @@ impl Worker { Ok(()) } + #[instrument( + skip(self, tt, status), + fields(task_token=%&tt, status=%&status, task_queue=%self.config.task_queue) + )] + async fn complete_nexus_task( + &self, + tt: TaskToken, + status: nexus_task_completion::Status, + ) -> Result<(), CompleteNexusError> { + if let Some(nm) = self.nexus_mgr.as_ref() { + nm.complete_task(tt, status, &*self.client).await + } else { + Err(CompleteNexusError::NexusNotEnabled) + } + } + /// Request a workflow eviction pub(crate) fn request_wf_eviction( &self, diff --git a/core/src/worker/nexus.rs b/core/src/worker/nexus.rs index b305e0bfb..f04574d00 100644 --- a/core/src/worker/nexus.rs +++ b/core/src/worker/nexus.rs @@ -4,7 +4,12 @@ use crate::{ telemetry::metrics::MetricsContext, worker::client::WorkerClient, }; -use futures_util::{stream, stream::BoxStream, Stream, StreamExt}; +use anyhow::anyhow; +use futures_util::{ + stream, + stream::{BoxStream, PollNext}, + Stream, StreamExt, +}; use std::{ collections::HashMap, sync::{ @@ -19,15 +24,23 @@ use temporal_sdk_core_api::{ }; use temporal_sdk_core_protos::{ coresdk::{ - nexus::{nexus_task, nexus_task_completion, NexusTask}, + nexus::{ + nexus_task, nexus_task_completion, CancelNexusTask, NexusTask, NexusTaskCancelReason, + }, NexusSlotInfo, }, temporal::api::nexus::v1::{request::Variant, response}, TaskToken, }; -use tokio::sync::Mutex; +use tokio::{ + sync::{mpsc::UnboundedSender, Mutex, Notify}, + task::JoinHandle, +}; +use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_util::sync::CancellationToken; +static REQUEST_TIMEOUT_HEADER: &str = "Request-Timeout"; + /// Centralizes all state related to received nexus tasks pub(super) struct NexusManager { task_stream: Mutex>>, @@ -35,6 +48,8 @@ pub(super) struct NexusManager { poll_returned_shutdown_token: CancellationToken, /// Outstanding nexus tasks that have been issued to lang but not yet completed outstanding_task_map: OutstandingTaskMap, + /// Notified every time a task in the map is completed + task_completed_notify: Arc, ever_polled: AtomicBool, } @@ -45,12 +60,23 @@ impl NexusManager { shutdown_initiated_token: CancellationToken, ) -> Self { let source_stream = new_nexus_task_poller(poller, metrics, shutdown_initiated_token); - let task_stream = NexusTaskStream::new(source_stream); + let (cancels_tx, cancels_rx) = tokio::sync::mpsc::unbounded_channel(); + let 
task_stream_input = stream::select_with_strategy( + UnboundedReceiverStream::new(cancels_rx).map(TaskStreamInput::from), + source_stream + .map(TaskStreamInput::from) + .chain(stream::once(async move { TaskStreamInput::SourceComplete })), + |_: &mut ()| PollNext::Left, + ); + let task_completed_notify = Arc::new(Notify::new()); + let task_stream = + NexusTaskStream::new(task_stream_input, cancels_tx, task_completed_notify.clone()); let outstanding_task_map = task_stream.outstanding_task_map.clone(); Self { task_stream: Mutex::new(task_stream.into_stream().boxed()), poll_returned_shutdown_token: CancellationToken::new(), outstanding_task_map, + task_completed_notify, ever_polled: AtomicBool::new(false), } } @@ -74,7 +100,9 @@ impl NexusManager { status: nexus_task_completion::Status, client: &dyn WorkerClient, ) -> Result<(), CompleteNexusError> { - if let Some(task_info) = self.outstanding_task_map.lock().remove(&tt) { + let removed = self.outstanding_task_map.lock().remove(&tt); + if let Some(task_info) = removed { + task_info.timeout_task.inspect(|jh| jh.abort()); let maybe_net_err = match status { nexus_task_completion::Status::Completed(c) => { // Server doesn't provide obvious errors for this validation, so it's done @@ -83,7 +111,7 @@ impl NexusManager { Some(response::Variant::StartOperation(_)) => { if task_info.request_kind != RequestKind::Start { return Err(CompleteNexusError::MalformeNexusCompletion { - reason: "Nexus request was StartOperation but response was not" + reason: "Nexus response was StartOperation but request was not" .to_string(), }); } @@ -92,7 +120,7 @@ impl NexusManager { if task_info.request_kind != RequestKind::Cancel { return Err(CompleteNexusError::MalformeNexusCompletion { reason: - "Nexus request was CancelOperation but response was not" + "Nexus response was CancelOperation but request was not" .to_string(), }); } @@ -106,15 +134,21 @@ impl NexusManager { } client.complete_nexus_task(tt, c).await.err() } + nexus_task_completion::Status::AckCancel(_) => None, nexus_task_completion::Status::Error(e) => { client.fail_nexus_task(tt, e).await.err() } }; + + self.task_completed_notify.notify_waiters(); + if let Some(e) = maybe_net_err { - warn!( - error=?e, - "Network error while completing Nexus task", - ); + if e.code() == tonic::Code::NotFound { + warn!(details=?e, "Nexus task not found on completion. 
This \ + may happen if the operation has already been cancelled but completed anyway."); + } else { + warn!(error=?e, "Network error while completing Nexus task"); + } } } else { warn!( @@ -136,62 +170,110 @@ impl NexusManager { struct NexusTaskStream { source_stream: S, outstanding_task_map: OutstandingTaskMap, + cancels_tx: UnboundedSender, + task_completed_notify: Arc, } impl NexusTaskStream where - S: Stream, + S: Stream, { - fn new(source: S) -> Self { + fn new( + source: S, + cancels_tx: UnboundedSender, + task_completed_notify: Arc, + ) -> Self { Self { source_stream: source, outstanding_task_map: Arc::new(Default::default()), + cancels_tx, + task_completed_notify, } } fn into_stream(self) -> impl Stream> { let outstanding_task_clone = self.outstanding_task_map.clone(); + let source_done = CancellationToken::new(); + let source_done_clone = source_done.clone(); self.source_stream - .map(move |t| match t { - Ok(t) => { - let (service, operation, request_kind) = t - .resp - .request - .as_ref() - .and_then(|r| r.variant.as_ref()) - .map(|v| match v { - Variant::StartOperation(s) => ( - s.service.to_owned(), - s.operation.to_owned(), - RequestKind::Start, - ), - Variant::CancelOperation(c) => ( - c.service.to_owned(), - c.operation.to_owned(), - RequestKind::Cancel, - ), - }) - .unwrap_or_default(); - self.outstanding_task_map.lock().insert( - TaskToken(t.resp.task_token.clone()), - NexusInFlightTask { - request_kind, - _permit: t.permit.into_used(NexusSlotInfo { service, operation }), - }, - ); - Ok(NexusTask { - variant: Some(nexus_task::Variant::Task(t.resp)), - }) - } - Err(e) => Err(PollError::TonicError(e)), + .filter_map(move |t| { + let res = match t { + TaskStreamInput::Poll(Ok(t)) => { + let tt = TaskToken(t.resp.task_token.clone()); + let mut timeout_task = None; + if let Some(timeout_str) = t + .resp + .request + .as_ref() + .and_then(|r| r.header.get(REQUEST_TIMEOUT_HEADER)) + { + if let Ok(timeout_dur) = parse_request_timeout(timeout_str) { + let tt_clone = tt.clone(); + let cancels_tx = self.cancels_tx.clone(); + timeout_task = Some(tokio::task::spawn(async move { + tokio::time::sleep(timeout_dur).await; + debug!( + task_token=%tt_clone, + "Timing out nexus task due to elapsed local timeout timer" + ); + let _ = cancels_tx.send(CancelNexusTask { + task_token: tt_clone.0, + reason: NexusTaskCancelReason::TimedOut.into(), + }); + })); + } else { + // TODO: Auto-respond as bad request + } + } + + let (service, operation, request_kind) = t + .resp + .request + .as_ref() + .and_then(|r| r.variant.as_ref()) + .map(|v| match v { + Variant::StartOperation(s) => ( + s.service.to_owned(), + s.operation.to_owned(), + RequestKind::Start, + ), + Variant::CancelOperation(c) => ( + c.service.to_owned(), + c.operation.to_owned(), + RequestKind::Cancel, + ), + }) + .unwrap_or_default(); + self.outstanding_task_map.lock().insert( + tt, + NexusInFlightTask { + request_kind, + timeout_task, + _permit: t.permit.into_used(NexusSlotInfo { service, operation }), + }, + ); + Some(Ok(NexusTask { + variant: Some(nexus_task::Variant::Task(t.resp)), + })) + } + TaskStreamInput::Cancel(c) => Some(Ok(NexusTask { + variant: Some(nexus_task::Variant::CancelTask(c)), + })), + TaskStreamInput::SourceComplete => { + source_done.cancel(); + None + } + TaskStreamInput::Poll(Err(e)) => Some(Err(PollError::TonicError(e))), + }; + async move { res } }) - .chain(stream::once(async move { + .take_until(async move { + source_done_clone.cancelled().await; while !outstanding_task_clone.lock().is_empty() { - // todo 
no spin - tokio::time::sleep(Duration::from_millis(10)).await; + self.task_completed_notify.notified().await; } - Err(PollError::ShutDown) - })) + }) + .chain(stream::once(async move { Err(PollError::ShutDown) })) } } @@ -199,16 +281,45 @@ type OutstandingTaskMap = Arc>, _permit: UsedMeteredSemPermit, } -#[derive(Eq, PartialEq, Copy, Clone)] +#[derive(Eq, PartialEq, Copy, Clone, Default)] enum RequestKind { + #[default] Start, Cancel, } -impl Default for RequestKind { - fn default() -> Self { - RequestKind::Start + +#[derive(derive_more::From)] +enum TaskStreamInput { + Poll(NexusPollItem), + Cancel(CancelNexusTask), + SourceComplete, +} + +fn parse_request_timeout(timeout: &str) -> Result { + let timeout = timeout.trim(); + let (value, unit) = timeout.split_at( + timeout + .find(|c: char| !c.is_ascii_digit() && c != '.') + .unwrap_or(timeout.len()), + ); + + match unit { + "m" => value + .parse::() + .map(|v| Duration::from_secs_f64(60.0 * v)) + .map_err(Into::into), + "s" => value + .parse::() + .map(Duration::from_secs_f64) + .map_err(Into::into), + "ms" => value + .parse::() + .map(Duration::from_millis) + .map_err(Into::into), + _ => Err(anyhow!("Invalid timeout format")), } } diff --git a/sdk-core-protos/protos/local/temporal/sdk/core/nexus/nexus.proto b/sdk-core-protos/protos/local/temporal/sdk/core/nexus/nexus.proto index a0b203a8f..20c042198 100644 --- a/sdk-core-protos/protos/local/temporal/sdk/core/nexus/nexus.proto +++ b/sdk-core-protos/protos/local/temporal/sdk/core/nexus/nexus.proto @@ -29,6 +29,11 @@ message NexusTaskCompletion { temporal.api.nexus.v1.Response completed = 2; // The handler could not complete the request for some reason. temporal.api.nexus.v1.HandlerError error = 3; + // The lang SDK acknowledges that it is responding to a `CancelNexusTask` and thus the + // response is irrelevant. This is not the only way to respond to a cancel, the other + // variants can still be used, but this variant should be used when the handler was aborted + // by cancellation. + bool ack_cancel = 4; } } @@ -41,6 +46,10 @@ message NexusTask { // requesting the cancellation of an operation. Handling this variant should result in // something like cancelling a cancellation token given to the user's operation handler. // + // These do not count as a separate task for the purposes of completing all issued tasks, + // but rather count as a sort of modification to the already-issued task which is being + // cancelled. + // // EX: Core knows the nexus operation has timed out, and it does not make sense for the // user's operation handler to continue doing work. 
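For reference, the `parse_request_timeout` helper above accepts only `m`, `s`, and `ms` suffixes, allowing fractional values for minutes and seconds but only whole numbers for milliseconds. A sketch of unit tests pinning that behavior down (assuming the helper returns `anyhow::Result<Duration>`, as its use of `anyhow!` and `map_err(Into::into)` suggests):

    #[cfg(test)]
    mod request_timeout_parsing {
        use super::parse_request_timeout;
        use std::time::Duration;

        #[test]
        fn parses_supported_units() {
            // Minutes and seconds accept fractional values
            assert_eq!(
                parse_request_timeout("1.5m").unwrap(),
                Duration::from_secs_f64(90.0)
            );
            // Surrounding whitespace is trimmed before parsing
            assert_eq!(
                parse_request_timeout(" 2s ").unwrap(),
                Duration::from_secs_f64(2.0)
            );
            // Milliseconds parse as u64, so fractional values are rejected
            assert_eq!(
                parse_request_timeout("250ms").unwrap(),
                Duration::from_millis(250)
            );
            // A missing or unrecognized unit is an error
            assert!(parse_request_timeout("10").is_err());
            assert!(parse_request_timeout("10h").is_err());
        }
    }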
    CancelNexusTask cancel_task = 2;
@@ -57,4 +66,6 @@ message CancelNexusTask {
 enum NexusTaskCancelReason {
   // The nexus task is known to have timed out
   TIMED_OUT = 0;
+  // The worker is shutting down
+  WORKER_SHUTDOWN = 1;
 }
\ No newline at end of file
diff --git a/sdk-core-protos/src/history_info.rs b/sdk-core-protos/src/history_info.rs
index 0b69e78eb..cd849bf2c 100644
--- a/sdk-core-protos/src/history_info.rs
+++ b/sdk-core-protos/src/history_info.rs
@@ -56,10 +56,9 @@ impl HistoryInfo {
             let next_event = history.peek();

             if event.event_type == EventType::WorkflowTaskStarted as i32 {
-                let next_is_completed = next_event.map_or(false, |ne| {
-                    ne.event_type == EventType::WorkflowTaskCompleted as i32
-                });
-                let next_is_failed_or_timeout_or_term = next_event.map_or(false, |ne| {
+                let next_is_completed = next_event
+                    .is_some_and(|ne| ne.event_type == EventType::WorkflowTaskCompleted as i32);
+                let next_is_failed_or_timeout_or_term = next_event.is_some_and(|ne| {
                     matches!(
                         ne.event_type(),
                         EventType::WorkflowTaskFailed
diff --git a/sdk-core-protos/src/lib.rs b/sdk-core-protos/src/lib.rs
index fce7dfc3c..4e65a8791 100644
--- a/sdk-core-protos/src/lib.rs
+++ b/sdk-core-protos/src/lib.rs
@@ -758,6 +758,7 @@ pub mod coresdk {

         pub mod nexus {
             use crate::temporal::api::workflowservice::v1::PollNexusTaskQueueResponse;
+            use std::fmt::{Display, Formatter};

             tonic::include_proto!("coresdk.nexus");

@@ -769,6 +770,33 @@ pub mod coresdk {
                     }
                     panic!("Nexus task did not contain a server task");
                 }
+
+                /// Get the task token
+                pub fn task_token(&self) -> &[u8] {
+                    match &self.variant {
+                        Some(nexus_task::Variant::Task(t)) => t.task_token.as_slice(),
+                        Some(nexus_task::Variant::CancelTask(c)) => c.task_token.as_slice(),
+                        None => panic!("Nexus task did not contain a task token"),
+                    }
+                }
+            }
+
+            impl Display for nexus_task_completion::Status {
+                fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+                    write!(f, "NexusTaskCompletion(")?;
+                    match self {
+                        nexus_task_completion::Status::Completed(c) => {
+                            write!(f, "{c}")
+                        }
+                        nexus_task_completion::Status::Error(e) => {
+                            write!(f, "{e}")
+                        }
+                        nexus_task_completion::Status::AckCancel(_) => {
+                            write!(f, "AckCancel")
+                        }
+                    }?;
+                    write!(f, ")")
+                }
+            }
         }
@@ -2368,10 +2396,43 @@ pub mod temporal {
                 enums::v1::EventType,
             };
             use anyhow::{anyhow, bail};
+            use std::fmt::{Display, Formatter};
             use tonic::transport::Uri;

             tonic::include_proto!("temporal.api.nexus.v1");

+            impl Display for Response {
+                fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+                    write!(f, "NexusResponse(",)?;
+                    match &self.variant {
+                        None => {}
+                        Some(v) => {
+                            write!(f, "{v}")?;
+                        }
+                    }
+                    write!(f, ")")
+                }
+            }
+
+            impl Display for response::Variant {
+                fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+                    match self {
+                        response::Variant::StartOperation(_) => {
+                            write!(f, "StartOperation")
+                        }
+                        response::Variant::CancelOperation(_) => {
+                            write!(f, "CancelOperation")
+                        }
+                    }
+                }
+            }
+
+            impl Display for HandlerError {
+                fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+                    write!(f, "HandlerError")
+                }
+            }
+
             static SCHEME_PREFIX: &str = "temporal://";

             /// Attempt to parse a nexus link into a workflow event link
diff --git a/tests/integ_tests/workflow_tests/nexus.rs b/tests/integ_tests/workflow_tests/nexus.rs
index 13eba90f4..efb132e46 100644
--- a/tests/integ_tests/workflow_tests/nexus.rs
+++ b/tests/integ_tests/workflow_tests/nexus.rs
@@ -7,8 +7,8 @@ use temporal_sdk_core_api::errors::PollError;
 use temporal_sdk_core_protos::{
     coresdk::{
         nexus::{
-
nexus_operation_result, nexus_task_completion, NexusOperationResult, - NexusTaskCompletion, + nexus_operation_result, nexus_task, nexus_task_completion, NexusOperationResult, + NexusTaskCancelReason, NexusTaskCompletion, }, FromJsonPayloadExt, }, @@ -97,9 +97,13 @@ async fn nexus_basic( } Outcome::Fail | Outcome::Timeout => { if outcome == Outcome::Timeout { - // We have to complete the nexus task so Core can shut down, but make sure we - // don't do it until after it times out. - tokio::time::sleep(Duration::from_millis(3100)).await; + // Wait for the timeout task cancel to get sent + dbg!("Waiting"); + let timeout_t = core_worker.poll_nexus_task().await.unwrap(); + let cancel = assert_matches!(timeout_t.variant, + Some(nexus_task::Variant::CancelTask(ct)) => ct); + assert_eq!(cancel.reason, NexusTaskCancelReason::TimedOut as i32); + dbg!("Done waiting!"); } core_worker .complete_nexus_task(NexusTaskCompletion { @@ -180,8 +184,7 @@ async fn nexus_async( let endpoint = mk_endpoint(&mut starter).await; let schedule_to_close_timeout = if outcome == Outcome::CancelAfterRecordedBeforeStarted { - // There is some internal timer on the server that won't record cancel in this case until - // after some elapsed period, so, don't time out first then. + // If we set this, it'll time out before we can cancel it. None } else { Some(Duration::from_secs(5)) @@ -219,7 +222,7 @@ async fn nexus_async( move |ctx: WfContext| async move { match outcome { Outcome::Succeed => Ok("completed async".into()), - Outcome::Cancel => { + Outcome::Cancel | Outcome::CancelAfterRecordedBeforeStarted => { ctx.cancelled().await; Ok(WfExitValue::Cancelled) } @@ -228,20 +231,32 @@ async fn nexus_async( }, ); let submitter = worker.get_submitter_handle(); - let handle = starter.start_with_worker(wf_name, &mut worker).await; + starter.start_with_worker(wf_name, &mut worker).await; let client = starter.get_client().await.get_client().clone(); let nexus_task_handle = async { - let nt = core_worker.poll_nexus_task().await.unwrap().unwrap_task(); + let mut nt = core_worker.poll_nexus_task().await.unwrap().unwrap_task(); let start_req = assert_matches!( nt.request.unwrap().variant.unwrap(), request::Variant::StartOperation(sr) => sr ); let completer_id = format!("completer-{}", rand_6_chars()); - if !matches!( - outcome, - Outcome::Timeout | Outcome::CancelAfterRecordedBeforeStarted - ) { + if !matches!(outcome, Outcome::Timeout) { + if outcome == Outcome::CancelAfterRecordedBeforeStarted { + // Server does not permit cancels to happen in this state. So, we wait for one timeout + // to happen, then say the operation started, after which it will be cancelled. 
+ let ntt = core_worker.poll_nexus_task().await.unwrap(); + assert_matches!(ntt.variant, Some(nexus_task::Variant::CancelTask(_))); + core_worker + .complete_nexus_task(NexusTaskCompletion { + task_token: ntt.task_token().to_vec(), + status: Some(nexus_task_completion::Status::AckCancel(true)), + }) + .await + .unwrap(); + // Get the next start request + nt = core_worker.poll_nexus_task().await.unwrap().unwrap_task(); + } // Start the workflow which will act like the nexus handler and complete the async // operation submitter @@ -268,14 +283,6 @@ async fn nexus_async( .await .unwrap(); } - if outcome == Outcome::CancelAfterRecordedBeforeStarted { - // Do not say the operation started until after it's had a chance to already be - // cancelled in this case - handle - .get_workflow_result(Default::default()) - .await - .unwrap(); - } core_worker .complete_nexus_task(NexusTaskCompletion { task_token: nt.task_token, @@ -296,8 +303,12 @@ async fn nexus_async( }) .await .unwrap(); - if outcome == Outcome::Cancel { - let nt = core_worker.poll_nexus_task().await.unwrap().unwrap_task(); + if matches!( + outcome, + Outcome::Cancel | Outcome::CancelAfterRecordedBeforeStarted + ) { + let nt = core_worker.poll_nexus_task().await.unwrap(); + let nt = nt.unwrap_task(); assert_matches!( nt.request.unwrap().variant.unwrap(), request::Variant::CancelOperation(_) @@ -358,12 +369,7 @@ async fn nexus_async( Some(nexus_operation_result::Status::Cancelled(f)) => f ); assert_eq!(f.message, "nexus operation completed unsuccessfully"); - let msg = if outcome == Outcome::CancelAfterRecordedBeforeStarted { - "operation canceled before it was started" - } else { - "operation canceled" - }; - assert_eq!(f.cause.unwrap().message, msg); + assert_eq!(f.cause.unwrap().message, "operation canceled"); } Outcome::Timeout => { let f = assert_matches!( @@ -424,12 +430,13 @@ async fn nexus_must_complete_task_to_shutdown() { worker.register_wf(wf_name.to_owned(), move |ctx: WfContext| { let endpoint = endpoint.clone(); async move { - let started = ctx.start_nexus_operation(NexusOperationOptions { + // We just need to create the command, not await it. + drop(ctx.start_nexus_operation(NexusOperationOptions { endpoint: endpoint.clone(), service: "svc".to_string(), operation: "op".to_string(), ..Default::default() - }); + })); // Workflow completes right away, only having scheduled the operation. We need a timer // to make sure the nexus task actually gets scheduled. 
ctx.timer(Duration::from_millis(1)).await; @@ -462,7 +469,6 @@ async fn nexus_must_complete_task_to_shutdown() { }) .await .unwrap(); - dbg!("Completed nexus task"); complete_order_tx.send("t").unwrap(); assert_matches!( core_worker.poll_nexus_task().await, From b0bf7c77f4ad7b7590342433fcb62582ee3415d0 Mon Sep 17 00:00:00 2001 From: Spencer Judge Date: Fri, 10 Jan 2025 10:49:08 -0800 Subject: [PATCH 11/14] Rust 1.84 --- .github/workflows/per-pr.yml | 8 ++++---- .../no_handle_conversions_require_into_fail.stderr | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/per-pr.yml b/.github/workflows/per-pr.yml index 4e134aada..d351baddc 100644 --- a/.github/workflows/per-pr.yml +++ b/.github/workflows/per-pr.yml @@ -21,7 +21,7 @@ jobs: submodules: recursive - uses: dtolnay/rust-toolchain@stable with: - toolchain: 1.80.0 + toolchain: 1.84.0 - name: Install protoc uses: arduino/setup-protoc@v3 with: @@ -43,7 +43,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable with: - toolchain: 1.80.0 + toolchain: 1.84.0 - name: Install protoc uses: arduino/setup-protoc@v3 with: @@ -74,7 +74,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable with: - toolchain: 1.80.0 + toolchain: 1.84.0 - name: Install protoc uses: arduino/setup-protoc@v3 with: @@ -99,7 +99,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable with: - toolchain: 1.80.0 + toolchain: 1.84.0 - name: Install protoc uses: arduino/setup-protoc@v3 with: diff --git a/fsm/rustfsm_procmacro/tests/trybuild/no_handle_conversions_require_into_fail.stderr b/fsm/rustfsm_procmacro/tests/trybuild/no_handle_conversions_require_into_fail.stderr index 63fd3267f..35e804e86 100644 --- a/fsm/rustfsm_procmacro/tests/trybuild/no_handle_conversions_require_into_fail.stderr +++ b/fsm/rustfsm_procmacro/tests/trybuild/no_handle_conversions_require_into_fail.stderr @@ -2,7 +2,7 @@ error[E0277]: the trait bound `One: From` is not satisfied --> tests/trybuild/no_handle_conversions_require_into_fail.rs:11:5 | 11 | Two --(B)--> One; - | ^^^ the trait `From` is not implemented for `One`, which is required by `Two: Into` + | ^^^ the trait `From` is not implemented for `One` | = note: required for `Two` to implement `Into` note: required by a bound in `TransitionResult::::from` From 6209201af91151c8ba7ddb58f5deac91131c6c9c Mon Sep 17 00:00:00 2001 From: Spencer Judge Date: Fri, 10 Jan 2025 11:08:41 -0800 Subject: [PATCH 12/14] Add in grace period support --- core-api/src/worker.rs | 4 +- core/src/worker/mod.rs | 1 + core/src/worker/nexus.rs | 39 +++++++++++++++--- tests/integ_tests/workflow_tests/nexus.rs | 48 ++++++++++++++++------- 4 files changed, 70 insertions(+), 22 deletions(-) diff --git a/core-api/src/worker.rs b/core-api/src/worker.rs index a99b136c7..92adb8835 100644 --- a/core-api/src/worker.rs +++ b/core-api/src/worker.rs @@ -118,8 +118,8 @@ pub struct WorkerConfig { #[builder(default = "5")] pub fetching_concurrency: usize, - /// If set, core will issue cancels for all outstanding activities after shutdown has been - /// initiated and this amount of time has elapsed. + /// If set, core will issue cancels for all outstanding activities and nexus operations after + /// shutdown has been initiated and this amount of time has elapsed. 
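As a usage sketch, the grace period is set at worker-configuration time like any other builder field (builder and field names as in this file; the other required configuration fields are elided):

    use std::time::Duration;
    use temporal_sdk_core_api::worker::WorkerConfigBuilder;

    // Give outstanding activities and nexus tasks half a second after shutdown
    // begins before Core issues cancels for them.
    let config = WorkerConfigBuilder::default()
        .namespace("default")
        .task_queue("my-queue")
        .graceful_shutdown_period(Duration::from_millis(500))
        // ...remaining required fields elided...
        .build()
        .unwrap();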
#[builder(default)] pub graceful_shutdown_period: Option, diff --git a/core/src/worker/mod.rs b/core/src/worker/mod.rs index 005600d09..7e8b4137f 100644 --- a/core/src/worker/mod.rs +++ b/core/src/worker/mod.rs @@ -510,6 +510,7 @@ impl Worker { NexusManager::new( np, metrics.with_new_attrs([nexus_worker_type()]), + config.graceful_shutdown_period, shutdown_token.child_token(), ) }); diff --git a/core/src/worker/nexus.rs b/core/src/worker/nexus.rs index f04574d00..f41e18805 100644 --- a/core/src/worker/nexus.rs +++ b/core/src/worker/nexus.rs @@ -33,6 +33,7 @@ use temporal_sdk_core_protos::{ TaskToken, }; use tokio::{ + join, sync::{mpsc::UnboundedSender, Mutex, Notify}, task::JoinHandle, }; @@ -57,6 +58,7 @@ impl NexusManager { pub(super) fn new( poller: BoxedNexusPoller, metrics: MetricsContext, + graceful_shutdown: Option, shutdown_initiated_token: CancellationToken, ) -> Self { let source_stream = new_nexus_task_poller(poller, metrics, shutdown_initiated_token); @@ -69,8 +71,12 @@ impl NexusManager { |_: &mut ()| PollNext::Left, ); let task_completed_notify = Arc::new(Notify::new()); - let task_stream = - NexusTaskStream::new(task_stream_input, cancels_tx, task_completed_notify.clone()); + let task_stream = NexusTaskStream::new( + task_stream_input, + cancels_tx, + task_completed_notify.clone(), + graceful_shutdown, + ); let outstanding_task_map = task_stream.outstanding_task_map.clone(); Self { task_stream: Mutex::new(task_stream.into_stream().boxed()), @@ -172,6 +178,7 @@ struct NexusTaskStream { outstanding_task_map: OutstandingTaskMap, cancels_tx: UnboundedSender, task_completed_notify: Arc, + grace_period: Option, } impl NexusTaskStream @@ -182,12 +189,14 @@ where source: S, cancels_tx: UnboundedSender, task_completed_notify: Arc, + grace_period: Option, ) -> Self { Self { source_stream: source, outstanding_task_map: Arc::new(Default::default()), cancels_tx, task_completed_notify, + grace_period, } } @@ -195,6 +204,7 @@ where let outstanding_task_clone = self.outstanding_task_map.clone(); let source_done = CancellationToken::new(); let source_done_clone = source_done.clone(); + let cancels_tx_clone = self.cancels_tx.clone(); self.source_stream .filter_map(move |t| { let res = match t { @@ -269,9 +279,28 @@ where }) .take_until(async move { source_done_clone.cancelled().await; - while !outstanding_task_clone.lock().is_empty() { - self.task_completed_notify.notified().await; - } + let (grace_killer, stop_grace) = futures_util::future::abortable(async { + if let Some(gp) = self.grace_period { + tokio::time::sleep(gp).await; + for (tt, _) in outstanding_task_clone.lock().iter() { + let _ = cancels_tx_clone.send(CancelNexusTask { + task_token: tt.0.clone(), + reason: NexusTaskCancelReason::WorkerShutdown.into(), + }); + } + } + }); + join!( + async { + while !outstanding_task_clone.lock().is_empty() { + self.task_completed_notify.notified().await; + } + // If we were waiting for the grace period but everything already finished, + // we don't need to keep waiting. 
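On the lang side, the counterpart to this shutdown-driven cancel emission is a poll loop that treats `CancelTask` as a modification of an already-issued task and still answers it, so the outstanding-task map can drain. A minimal sketch under that assumption (the two free functions are hypothetical stand-ins for lang SDK dispatch):

    use std::sync::Arc;
    use temporal_sdk_core_api::{errors::PollError, Worker};
    use temporal_sdk_core_protos::{
        coresdk::nexus::{nexus_task, nexus_task_completion, NexusTaskCompletion},
        temporal::api::workflowservice::v1::PollNexusTaskQueueResponse,
    };

    fn dispatch_to_handler(_task: PollNexusTaskQueueResponse) {
        // Hand the request to user handler code, which must eventually call
        // complete_nexus_task with a Completed or Error status.
    }

    fn abort_in_flight_handler(_token: &[u8]) {
        // E.g. trigger the cancellation token given to the user's handler.
    }

    async fn nexus_poll_loop(worker: Arc<dyn Worker>) {
        loop {
            let task = match worker.poll_nexus_task().await {
                Ok(t) => t,
                // Returned only once every issued task has been responded to
                Err(PollError::ShutDown) => break,
                Err(e) => panic!("unexpected poll error: {e:?}"),
            };
            match task.variant {
                Some(nexus_task::Variant::Task(t)) => dispatch_to_handler(t),
                Some(nexus_task::Variant::CancelTask(c)) => {
                    abort_in_flight_handler(&c.task_token);
                    // Acknowledge so the slot is released and shutdown can proceed
                    worker
                        .complete_nexus_task(NexusTaskCompletion {
                            task_token: c.task_token,
                            status: Some(nexus_task_completion::Status::AckCancel(true)),
                        })
                        .await
                        .expect("completion should be accepted");
                }
                None => {}
            }
        }
    }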
+ stop_grace.abort(); + }, + grace_killer + ) }) .chain(stream::once(async move { Err(PollError::ShutDown) })) } diff --git a/tests/integ_tests/workflow_tests/nexus.rs b/tests/integ_tests/workflow_tests/nexus.rs index efb132e46..5dbef9723 100644 --- a/tests/integ_tests/workflow_tests/nexus.rs +++ b/tests/integ_tests/workflow_tests/nexus.rs @@ -417,11 +417,17 @@ async fn nexus_cancel_before_start() { worker.run_until_done().await.unwrap(); } +#[rstest::rstest] #[tokio::test] -async fn nexus_must_complete_task_to_shutdown() { +async fn nexus_must_complete_task_to_shutdown(#[values(true, false)] use_grace_period: bool) { let wf_name = "nexus_must_complete_task_to_shutdown"; let mut starter = CoreWfStarter::new(wf_name); starter.worker_config.no_remote_activities(true); + if use_grace_period { + starter + .worker_config + .graceful_shutdown_period(Duration::from_millis(500)); + } let mut worker = starter.worker().await; let core_worker = starter.get_worker().await; @@ -455,20 +461,32 @@ async fn nexus_must_complete_task_to_shutdown() { .get_workflow_result(Default::default()) .await .unwrap(); - // Complete the task - core_worker - .complete_nexus_task(NexusTaskCompletion { - task_token: nt.task_token, - status: Some(nexus_task_completion::Status::Error(HandlerError { - error_type: "BAD_REQUEST".to_string(), // bad req is non-retryable - failure: Some(nexus::v1::Failure { - message: "busted".to_string(), - ..Default::default() - }), - })), - }) - .await - .unwrap(); + if use_grace_period { + // Wait for cancel to be sent + let nt = core_worker.poll_nexus_task().await.unwrap(); + assert_matches!(nt.variant, Some(nexus_task::Variant::CancelTask(_))); + core_worker + .complete_nexus_task(NexusTaskCompletion { + task_token: nt.task_token().to_vec(), + status: Some(nexus_task_completion::Status::AckCancel(true)), + }) + .await + .unwrap(); + } else { + core_worker + .complete_nexus_task(NexusTaskCompletion { + task_token: nt.task_token, + status: Some(nexus_task_completion::Status::Error(HandlerError { + error_type: "BAD_REQUEST".to_string(), // bad req is non-retryable + failure: Some(nexus::v1::Failure { + message: "busted".to_string(), + ..Default::default() + }), + })), + }) + .await + .unwrap(); + } complete_order_tx.send("t").unwrap(); assert_matches!( core_worker.poll_nexus_task().await, From 65b20b8c43f8d353aef20d9dd1dc1a6f9b423ffc Mon Sep 17 00:00:00 2001 From: Spencer Judge Date: Fri, 10 Jan 2025 12:27:43 -0800 Subject: [PATCH 13/14] Add more metrics --- core-api/src/worker.rs | 6 + core/src/telemetry/metrics.rs | 64 ++++++- core/src/worker/mod.rs | 2 +- core/src/worker/nexus.rs | 68 ++++++- core/src/worker/tuner.rs | 3 + sdk-core-protos/src/lib.rs | 35 +++- tests/integ_tests/metrics_tests.rs | 210 +++++++++++++++++++++- tests/integ_tests/workflow_tests/nexus.rs | 41 +---- tests/main.rs | 36 +++- 9 files changed, 409 insertions(+), 56 deletions(-) diff --git a/core-api/src/worker.rs b/core-api/src/worker.rs index 92adb8835..b35749b1b 100644 --- a/core-api/src/worker.rs +++ b/core-api/src/worker.rs @@ -157,6 +157,12 @@ pub struct WorkerConfig { /// Mutually exclusive with `tuner` #[builder(setter(into, strip_option), default)] pub max_outstanding_local_activities: Option, + /// The maximum number of nexus tasks that will ever be given to this worker + /// concurrently + /// + /// Mutually exclusive with `tuner` + #[builder(setter(into, strip_option), default)] + pub max_outstanding_nexus_tasks: Option, } impl WorkerConfig { diff --git a/core/src/telemetry/metrics.rs 
b/core/src/telemetry/metrics.rs index 4caee53a4..447da9c9a 100644 --- a/core/src/telemetry/metrics.rs +++ b/core/src/telemetry/metrics.rs @@ -49,6 +49,10 @@ struct Instruments { la_exec_succeeded_latency: Arc, la_total: Arc, nexus_poll_no_task: Arc, + nexus_task_schedule_to_start_latency: Arc, + nexus_task_e2e_latency: Arc, + nexus_task_execution_latency: Arc, + nexus_task_execution_failed: Arc, worker_registered: Arc, num_pollers: Arc, task_slots_available: Arc, @@ -231,6 +235,34 @@ impl MetricsContext { self.instruments.nexus_poll_no_task.add(1, &self.kvs); } + /// Record nexus task schedule to start time + pub(crate) fn nexus_task_sched_to_start_latency(&self, dur: Duration) { + self.instruments + .nexus_task_schedule_to_start_latency + .record(dur, &self.kvs); + } + + /// Record nexus task end-to-end time + pub(crate) fn nexus_task_e2e_latency(&self, dur: Duration) { + self.instruments + .nexus_task_e2e_latency + .record(dur, &self.kvs); + } + + /// Record nexus task execution time + pub(crate) fn nexus_task_execution_latency(&self, dur: Duration) { + self.instruments + .nexus_task_execution_latency + .record(dur, &self.kvs); + } + + /// Record a nexus task execution failure + pub(crate) fn nexus_task_execution_failed(&self) { + self.instruments + .nexus_task_execution_failed + .add(1, &self.kvs); + } + /// A worker was registered pub(crate) fn worker_registered(&self) { self.instruments.worker_registered.add(1, &self.kvs); @@ -397,6 +429,26 @@ impl Instruments { description: "Count of nexus task queue poll timeouts (no new task)".into(), unit: "".into(), }), + nexus_task_schedule_to_start_latency: meter.histogram_duration(MetricParameters { + name: "nexus_task_schedule_to_start_latency".into(), + unit: "duration".into(), + description: "Histogram of nexus task schedule-to-start latencies".into(), + }), + nexus_task_e2e_latency: meter.histogram_duration(MetricParameters { + name: "nexus_task_endtoend_latency".into(), + unit: "duration".into(), + description: "Histogram of nexus task end-to-end latencies".into(), + }), + nexus_task_execution_latency: meter.histogram_duration(MetricParameters { + name: "nexus_task_execution_latency".into(), + unit: "duration".into(), + description: "Histogram of nexus task execution latencies".into(), + }), + nexus_task_execution_failed: meter.counter(MetricParameters { + name: "nexus_task_execution_failed".into(), + description: "Count of nexus task execution failures".into(), + unit: "".into(), + }), // name kept as worker start for compat with old sdk / what users expect worker_registered: meter.counter(MetricParameters { name: "worker_start".into(), @@ -493,12 +545,18 @@ pub(crate) fn eager(is_eager: bool) -> MetricKeyValue { pub(crate) enum FailureReason { Nondeterminism, Workflow, + Timeout, + NexusOperation(String), + NexusHandlerError(String), } impl Display for FailureReason { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let str = match self { - FailureReason::Nondeterminism => "NonDeterminismError", - FailureReason::Workflow => "WorkflowError", + FailureReason::Nondeterminism => "NonDeterminismError".to_owned(), + FailureReason::Workflow => "WorkflowError".to_owned(), + FailureReason::Timeout => "timeout".to_owned(), + FailureReason::NexusOperation(op) => format!("operation_{}", op), + FailureReason::NexusHandlerError(op) => format!("handler_error_{}", op), }; write!(f, "{}", str) } @@ -903,7 +961,7 @@ mod tests { a1.set(Arc::new(DummyCustomAttrs(1))).unwrap(); // Verify all metrics are created. 
This number will need to get updated any time a metric // is added. - let num_metrics = 31; + let num_metrics = 35; #[allow(clippy::needless_range_loop)] // Sorry clippy, this reads easier. for metric_num in 1..=num_metrics { let hole = assert_matches!(&events[metric_num], diff --git a/core/src/worker/mod.rs b/core/src/worker/mod.rs index 7e8b4137f..7cb928dde 100644 --- a/core/src/worker/mod.rs +++ b/core/src/worker/mod.rs @@ -509,7 +509,7 @@ impl Worker { let nexus_mgr = nexus_poller.map(|np| { NexusManager::new( np, - metrics.with_new_attrs([nexus_worker_type()]), + metrics.clone(), config.graceful_shutdown_period, shutdown_token.child_token(), ) diff --git a/core/src/worker/nexus.rs b/core/src/worker/nexus.rs index f41e18805..a85cb2287 100644 --- a/core/src/worker/nexus.rs +++ b/core/src/worker/nexus.rs @@ -1,7 +1,10 @@ use crate::{ abstractions::UsedMeteredSemPermit, pollers::{new_nexus_task_poller, BoxedNexusPoller, NexusPollItem}, - telemetry::metrics::MetricsContext, + telemetry::{ + metrics, + metrics::{FailureReason, MetricsContext}, + }, worker::client::WorkerClient, }; use anyhow::anyhow; @@ -16,7 +19,7 @@ use std::{ atomic::{AtomicBool, Ordering}, Arc, }, - time::Duration, + time::{Duration, Instant, SystemTime}, }; use temporal_sdk_core_api::{ errors::{CompleteNexusError, PollError}, @@ -29,7 +32,7 @@ use temporal_sdk_core_protos::{ }, NexusSlotInfo, }, - temporal::api::nexus::v1::{request::Variant, response}, + temporal::api::nexus::v1::{request::Variant, response, start_operation_response}, TaskToken, }; use tokio::{ @@ -51,7 +54,9 @@ pub(super) struct NexusManager { outstanding_task_map: OutstandingTaskMap, /// Notified every time a task in the map is completed task_completed_notify: Arc, + ever_polled: AtomicBool, + metrics: MetricsContext, } impl NexusManager { @@ -61,7 +66,8 @@ impl NexusManager { graceful_shutdown: Option, shutdown_initiated_token: CancellationToken, ) -> Self { - let source_stream = new_nexus_task_poller(poller, metrics, shutdown_initiated_token); + let source_stream = + new_nexus_task_poller(poller, metrics.clone(), shutdown_initiated_token); let (cancels_tx, cancels_rx) = tokio::sync::mpsc::unbounded_channel(); let task_stream_input = stream::select_with_strategy( UnboundedReceiverStream::new(cancels_rx).map(TaskStreamInput::from), @@ -76,6 +82,7 @@ impl NexusManager { cancels_tx, task_completed_notify.clone(), graceful_shutdown, + metrics.clone(), ); let outstanding_task_map = task_stream.outstanding_task_map.clone(); Self { @@ -84,6 +91,7 @@ impl NexusManager { outstanding_task_map, task_completed_notify, ever_polled: AtomicBool::new(false), + metrics, } } @@ -108,13 +116,24 @@ impl NexusManager { ) -> Result<(), CompleteNexusError> { let removed = self.outstanding_task_map.lock().remove(&tt); if let Some(task_info) = removed { + self.metrics + .nexus_task_execution_latency(task_info.start_time.elapsed()); task_info.timeout_task.inspect(|jh| jh.abort()); - let maybe_net_err = match status { + let (did_send, maybe_net_err) = match status { nexus_task_completion::Status::Completed(c) => { // Server doesn't provide obvious errors for this validation, so it's done // here to make life easier for lang implementors. 
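Concretely, this validation rejects a completion whose response variant disagrees with the kind of request that was handed out, before anything is sent to the server. An illustrative fragment (the `completion` value is hypothetical; the error variant is spelled as in this file):

    // Answering a CancelOperation task with a StartOperation response fails locally:
    let err = worker
        .complete_nexus_task(completion)
        .await
        .unwrap_err();
    assert!(matches!(
        err,
        CompleteNexusError::MalformeNexusCompletion { .. }
    ));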
                match &c.variant {
-                    Some(response::Variant::StartOperation(_)) => {
+                    Some(response::Variant::StartOperation(so)) => {
+                        if let Some(start_operation_response::Variant::OperationError(oe)) =
+                            so.variant.as_ref()
+                        {
+                            self.metrics
+                                .with_new_attrs([metrics::failure_reason(
+                                    FailureReason::NexusOperation(oe.operation_state.clone()),
+                                )])
+                                .nexus_task_execution_failed();
+                        };
                         if task_info.request_kind != RequestKind::Start {
                             return Err(CompleteNexusError::MalformeNexusCompletion {
                                 reason: "Nexus response was StartOperation but request was not"
                                     .to_string(),
                             });
                         }
@@ -138,11 +157,21 @@ impl NexusManager {
                         })
                     }
                 }
-                client.complete_nexus_task(tt, c).await.err()
+                (true, client.complete_nexus_task(tt, c).await.err())
+            }
+            nexus_task_completion::Status::AckCancel(_) => {
+                self.metrics
+                    .with_new_attrs([metrics::failure_reason(FailureReason::Timeout)])
+                    .nexus_task_execution_failed();
+                (false, None)
             }
-            nexus_task_completion::Status::AckCancel(_) => None,
             nexus_task_completion::Status::Error(e) => {
-                client.fail_nexus_task(tt, e).await.err()
+                self.metrics
+                    .with_new_attrs([metrics::failure_reason(
+                        FailureReason::NexusHandlerError(e.error_type.clone()),
+                    )])
+                    .nexus_task_execution_failed();
+                (true, client.fail_nexus_task(tt, e).await.err())
             }
         };
@@ -155,6 +184,11 @@ impl NexusManager {
             } else {
                 warn!(error=?e, "Network error while completing Nexus task");
             }
+        } else if did_send {
+            // Record e2e latency if we replied to the server without an RPC error
+            if let Some(elapsed) = task_info.scheduled_time.and_then(|t| t.elapsed().ok()) {
+                self.metrics.nexus_task_e2e_latency(elapsed);
+            }
         }
     } else {
         warn!(
@@ -179,6 +213,7 @@ struct NexusTaskStream {
     cancels_tx: UnboundedSender,
     task_completed_notify: Arc,
     grace_period: Option,
+    metrics: MetricsContext,
 }

 impl NexusTaskStream
 where
@@ -190,6 +225,7 @@ where
         cancels_tx: UnboundedSender,
         task_completed_notify: Arc,
         grace_period: Option,
+        metrics: MetricsContext,
     ) -> Self {
         Self {
             source_stream: source,
@@ -197,6 +233,7 @@ where
             cancels_tx,
             task_completed_notify,
             grace_period,
+            metrics,
         }
     }

@@ -209,6 +246,10 @@ where
             .filter_map(move |t| {
                 let res = match t {
                     TaskStreamInput::Poll(Ok(t)) => {
+                        if let Some(dur) = t.resp.sched_to_start() {
+                            self.metrics.nexus_task_sched_to_start_latency(dur);
+                        };
+
                         let tt = TaskToken(t.resp.task_token.clone());
                         let mut timeout_task = None;
                         if let Some(timeout_str) = t
@@ -259,6 +300,13 @@ where
                             NexusInFlightTask {
                                 request_kind,
                                 timeout_task,
+                                scheduled_time: t
+                                    .resp
+                                    .request
+                                    .as_ref()
+                                    .and_then(|r| r.scheduled_time)
+                                    .and_then(|t| t.try_into().ok()),
+                                start_time: Instant::now(),
                                 _permit: t.permit.into_used(NexusSlotInfo { service, operation }),
                             },
                         );
@@ -311,6 +359,8 @@ type OutstandingTaskMap = Arc>,
+    scheduled_time: Option,
+    start_time: Instant,
     _permit: UsedMeteredSemPermit,
 }

diff --git a/core/src/worker/tuner.rs b/core/src/worker/tuner.rs
index 1a34c7c1a..c6159ae52 100644
--- a/core/src/worker/tuner.rs
+++ b/core/src/worker/tuner.rs
@@ -184,6 +184,9 @@ impl TunerBuilder {
         if let Some(m) = cfg.max_outstanding_local_activities {
             builder.local_activity_slot_supplier(Arc::new(FixedSizeSlotSupplier::new(m)));
         }
+        if let Some(m) = cfg.max_outstanding_nexus_tasks {
+            builder.nexus_slot_supplier(Arc::new(FixedSizeSlotSupplier::new(m)));
+        }

         builder
     }
diff --git a/sdk-core-protos/src/lib.rs b/sdk-core-protos/src/lib.rs
index 4e65a8791..96d85050c 100644
--- a/sdk-core-protos/src/lib.rs
+++ b/sdk-core-protos/src/lib.rs
@@ -2527,10 +2527,8 @@ pub mod temporal {
                     if let Some((sch, st)) =
self.$sched_field.clone().zip(self.started_time.clone()) { - let sch: Result = sch.try_into(); - let st: Result = st.try_into(); - if let (Ok(sch), Ok(st)) = (sch, st) { - return st.duration_since(sch).ok(); + if let Some(value) = elapsed_between_prost_times(sch, st) { + return value; } } None @@ -2538,6 +2536,18 @@ pub mod temporal { }; } + fn elapsed_between_prost_times( + from: prost_wkt_types::Timestamp, + to: prost_wkt_types::Timestamp, + ) -> Option> { + let from: Result = from.try_into(); + let to: Result = to.try_into(); + if let (Ok(from), Ok(to)) = (from, to) { + return Some(to.duration_since(from).ok()); + } + None + } + impl PollWorkflowTaskQueueResponse { sched_to_start_impl!(scheduled_time); } @@ -2584,6 +2594,23 @@ pub mod temporal { sched_to_start_impl!(current_attempt_scheduled_time); } + impl PollNexusTaskQueueResponse { + pub fn sched_to_start(&self) -> Option { + if let Some((sch, st)) = self + .request + .as_ref() + .and_then(|r| r.scheduled_time) + .clone() + .zip(SystemTime::now().try_into().ok()) + { + if let Some(value) = elapsed_between_prost_times(sch, st) { + return value; + } + } + None + } + } + impl QueryWorkflowResponse { /// Unwrap a successful response as vec of payloads pub fn unwrap(self) -> Vec { diff --git a/tests/integ_tests/metrics_tests.rs b/tests/integ_tests/metrics_tests.rs index 993ca0186..e3838b6ff 100644 --- a/tests/integ_tests/metrics_tests.rs +++ b/tests/integ_tests/metrics_tests.rs @@ -1,3 +1,4 @@ +use crate::integ_tests::mk_nexus_endpoint; use anyhow::anyhow; use assert_matches::assert_matches; use std::{ @@ -12,7 +13,8 @@ use temporal_client::{ WorkflowClientTrait, WorkflowOptions, WorkflowService, REQUEST_LATENCY_HISTOGRAM_NAME, }; use temporal_sdk::{ - ActContext, ActivityError, ActivityOptions, CancellableFuture, LocalActivityOptions, WfContext, + ActContext, ActivityError, ActivityOptions, CancellableFuture, LocalActivityOptions, + NexusOperationOptions, WfContext, }; use temporal_sdk_core::{ init_worker, @@ -20,6 +22,7 @@ use temporal_sdk_core::{ CoreRuntime, TokioRuntimeBuilder, }; use temporal_sdk_core_api::{ + errors::PollError, telemetry::{ metrics::{CoreMeter, MetricAttributes, MetricParameters}, HistogramBucketOverrides, OtelCollectorOptionsBuilder, OtlpProtocol, @@ -32,6 +35,7 @@ use temporal_sdk_core_api::{ use temporal_sdk_core_protos::{ coresdk::{ activity_result::ActivityExecutionResult, + nexus::{nexus_task, nexus_task_completion, NexusTaskCompletion}, workflow_activation::{workflow_activation_job, WorkflowActivationJob}, workflow_commands::{ workflow_command, CancelWorkflowExecution, CompleteWorkflowExecution, @@ -45,6 +49,11 @@ use temporal_sdk_core_protos::{ common::v1::RetryPolicy, enums::v1::WorkflowIdReusePolicy, failure::v1::Failure, + nexus, + nexus::v1::{ + request::Variant, start_operation_response, HandlerError, StartOperationResponse, + UnsuccessfulOperationError, + }, query::v1::WorkflowQuery, workflowservice::v1::{DescribeNamespaceRequest, ListNamespacesRequest}, }, @@ -181,6 +190,7 @@ async fn one_slot_worker_reports_available_slot() { .max_outstanding_local_activities(1_usize) // Need to use two for WFTs because there are a minimum of 2 pollers b/c of sticky polling .max_outstanding_workflow_tasks(2_usize) + .max_outstanding_nexus_tasks(1_usize) .max_concurrent_wft_polls(1_usize) .build() .unwrap(); @@ -258,6 +268,10 @@ async fn one_slot_worker_reports_available_slot() { act_task_barr.wait().await; }; + let nexus_polling = async { + let _ = worker.poll_nexus_task().await; + }; + let testing = async { // 
diff --git a/tests/integ_tests/metrics_tests.rs b/tests/integ_tests/metrics_tests.rs
index 993ca0186..e3838b6ff 100644
--- a/tests/integ_tests/metrics_tests.rs
+++ b/tests/integ_tests/metrics_tests.rs
@@ -1,3 +1,4 @@
+use crate::integ_tests::mk_nexus_endpoint;
 use anyhow::anyhow;
 use assert_matches::assert_matches;
 use std::{
@@ -12,7 +13,8 @@ use temporal_client::{
     WorkflowClientTrait, WorkflowOptions, WorkflowService, REQUEST_LATENCY_HISTOGRAM_NAME,
 };
 use temporal_sdk::{
-    ActContext, ActivityError, ActivityOptions, CancellableFuture, LocalActivityOptions, WfContext,
+    ActContext, ActivityError, ActivityOptions, CancellableFuture, LocalActivityOptions,
+    NexusOperationOptions, WfContext,
 };
 use temporal_sdk_core::{
     init_worker,
@@ -20,6 +22,7 @@ use temporal_sdk_core::{
     CoreRuntime, TokioRuntimeBuilder,
 };
 use temporal_sdk_core_api::{
+    errors::PollError,
     telemetry::{
         metrics::{CoreMeter, MetricAttributes, MetricParameters},
         HistogramBucketOverrides, OtelCollectorOptionsBuilder, OtlpProtocol,
@@ -32,6 +35,7 @@ use temporal_sdk_core_api::{
 use temporal_sdk_core_protos::{
     coresdk::{
         activity_result::ActivityExecutionResult,
+        nexus::{nexus_task, nexus_task_completion, NexusTaskCompletion},
         workflow_activation::{workflow_activation_job, WorkflowActivationJob},
         workflow_commands::{
             workflow_command, CancelWorkflowExecution, CompleteWorkflowExecution,
@@ -45,6 +49,11 @@ use temporal_sdk_core_protos::{
         common::v1::RetryPolicy,
         enums::v1::WorkflowIdReusePolicy,
         failure::v1::Failure,
+        nexus,
+        nexus::v1::{
+            request::Variant, start_operation_response, HandlerError, StartOperationResponse,
+            UnsuccessfulOperationError,
+        },
         query::v1::WorkflowQuery,
         workflowservice::v1::{DescribeNamespaceRequest, ListNamespacesRequest},
     },
@@ -181,6 +190,7 @@ async fn one_slot_worker_reports_available_slot() {
         .max_outstanding_local_activities(1_usize)
         // Need to use two for WFTs because there are a minimum of 2 pollers b/c of sticky polling
         .max_outstanding_workflow_tasks(2_usize)
+        .max_outstanding_nexus_tasks(1_usize)
         .max_concurrent_wft_polls(1_usize)
         .build()
         .unwrap();
@@ -258,6 +268,10 @@ async fn one_slot_worker_reports_available_slot() {
         act_task_barr.wait().await;
     };
 
+    let nexus_polling = async {
+        let _ = worker.poll_nexus_task().await;
+    };
+
     let testing = async {
         // Wait just a beat for the poller to initiate
         tokio::time::sleep(Duration::from_millis(50)).await;
@@ -277,6 +291,11 @@ async fn one_slot_worker_reports_available_slot() {
             service_name=\"temporal-core-sdk\",task_queue=\"one_slot_worker_tq\",\
             worker_type=\"LocalActivityWorker\"}} 1"
         )));
+        assert!(body.contains(&format!(
+            "temporal_worker_task_slots_available{{namespace=\"{NAMESPACE}\",\
+            service_name=\"temporal-core-sdk\",task_queue=\"one_slot_worker_tq\",\
+            worker_type=\"NexusWorker\"}} 1"
+        )));
 
         // Start a workflow so that a task will get delivered
         client
@@ -329,6 +348,11 @@ async fn one_slot_worker_reports_available_slot() {
             service_name=\"temporal-core-sdk\",task_queue=\"one_slot_worker_tq\",\
             worker_type=\"LocalActivityWorker\"}} 0"
         )));
+        assert!(body.contains(&format!(
+            "temporal_worker_task_slots_used{{namespace=\"{NAMESPACE}\",\
+            service_name=\"temporal-core-sdk\",task_queue=\"one_slot_worker_tq\",\
+            worker_type=\"NexusWorker\"}} 0"
+        )));
 
         // Now we allow the complete to proceed. Once it goes through, there should be 2 WFT slots
         // open but 0 activity slots
@@ -387,8 +411,9 @@ async fn one_slot_worker_reports_available_slot() {
             service_name=\"temporal-core-sdk\",task_queue=\"one_slot_worker_tq\",\
             worker_type=\"LocalActivityWorker\"}} 1"
         )));
+        worker.initiate_shutdown();
     };
-    join!(wf_polling, act_polling, testing);
+    join!(wf_polling, act_polling, nexus_polling, testing);
 }
 
 #[rstest::rstest]
@@ -878,6 +903,187 @@ async fn activity_metrics() {
     )));
 }
 
+#[tokio::test]
+async fn nexus_metrics() {
+    let (telemopts, addr, _aborter) = prom_metrics(None);
+    let rt = CoreRuntime::new_assume_tokio(telemopts).unwrap();
+    let wf_name = "nexus_metrics";
+    let mut starter = CoreWfStarter::new_with_runtime(wf_name, rt);
+    starter.worker_config.no_remote_activities(true);
+    let task_queue = starter.get_task_queue().to_owned();
+    let mut worker = starter.worker().await;
+    let core_worker = starter.get_worker().await;
+    let endpoint = mk_nexus_endpoint(&mut starter).await;
+
+    worker.register_wf(wf_name.to_string(), move |ctx: WfContext| {
+        let partial_op = NexusOperationOptions {
+            endpoint: endpoint.clone(),
+            service: "mysvc".to_string(),
+            operation: "myop".to_string(),
+            ..Default::default()
+        };
+        async move {
+            join!(
+                async {
+                    ctx.start_nexus_operation(partial_op.clone())
+                        .await
+                        .unwrap()
+                        .result()
+                        .await
+                },
+                async {
+                    ctx.start_nexus_operation(NexusOperationOptions {
+                        input: Some("fail".into()),
+                        ..partial_op.clone()
+                    })
+                    .await
+                    .unwrap()
+                    .result()
+                    .await
+                },
+                async {
+                    ctx.start_nexus_operation(NexusOperationOptions {
+                        input: Some("handler-fail".into()),
+                        ..partial_op.clone()
+                    })
+                    .await
+                    .unwrap()
+                    .result()
+                    .await
+                },
+                async {
+                    ctx.start_nexus_operation(NexusOperationOptions {
+                        input: Some("timeout".into()),
+                        schedule_to_close_timeout: Some(Duration::from_secs(2)),
+                        ..partial_op.clone()
+                    })
+                    .await
+                    .unwrap()
+                    .result()
+                    .await
+                }
+            );
+            Ok(().into())
+        }
+    });
+
+    starter.start_with_worker(wf_name, &mut worker).await;
+
+    let nexus_polling = async {
+        for _ in 0..5 {
+            let nt = core_worker.poll_nexus_task().await.unwrap();
+            let task_token = nt.task_token().to_vec();
+            let status = if matches!(nt.variant, Some(nexus_task::Variant::CancelTask(_))) {
+                nexus_task_completion::Status::AckCancel(true)
+            } else {
+                let nt = nt.unwrap_task();
+                match nt.request.unwrap().variant.unwrap() {
+                    Variant::StartOperation(s) => match s.payload {
+                        Some(p) if p.data.is_empty() => {
+                            nexus_task_completion::Status::Completed(nexus::v1::Response {
+                                variant: Some(nexus::v1::response::Variant::StartOperation(
+                                    StartOperationResponse {
+                                        variant: Some(
+                                            start_operation_response::Variant::SyncSuccess(
+                                                start_operation_response::Sync {
+                                                    payload: Some("yay".into()),
+                                                },
+                                            ),
+                                        ),
+                                    },
+                                )),
+                            })
+                        }
+                        Some(p) if p == "fail".into() => {
+                            nexus_task_completion::Status::Completed(nexus::v1::Response {
+                                variant: Some(nexus::v1::response::Variant::StartOperation(
+                                    StartOperationResponse {
+                                        variant: Some(
+                                            start_operation_response::Variant::OperationError(
+                                                UnsuccessfulOperationError {
+                                                    operation_state: "failed".to_string(),
+                                                    failure: Some(nexus::v1::Failure {
+                                                        message: "fail".to_string(),
+                                                        ..Default::default()
+                                                    }),
+                                                },
+                                            ),
+                                        ),
+                                    },
+                                )),
+                            })
+                        }
+                        Some(p) if p == "handler-fail".into() => {
+                            nexus_task_completion::Status::Error(HandlerError {
+                                error_type: "BAD_REQUEST".to_string(),
+                                failure: Some(nexus::v1::Failure {
+                                    message: "handler-fail".to_string(),
+                                    ..Default::default()
+                                }),
+                            })
+                        }
+                        Some(p) if p == "timeout".into() => {
+                            // Don't do anything, will wait for timeout task
+                            continue;
+                        }
+                        _ => unreachable!(),
+                    },
+                    _ => unreachable!(),
+                }
+            };
+            core_worker
+                .complete_nexus_task(NexusTaskCompletion {
+                    task_token,
+                    status: Some(status),
+                })
+                .await
+                .unwrap();
+        }
+        // Gotta get shutdown poll
+        assert_matches!(
+            core_worker.poll_nexus_task().await,
+            Err(PollError::ShutDown)
+        );
+    };
+
+    join!(nexus_polling, async {
+        worker.run_until_done().await.unwrap()
+    });
+
+    let body = get_text(format!("http://{addr}/metrics")).await;
+    assert!(body.contains(&format!(
+        "temporal_nexus_task_execution_failed{{failure_reason=\"handler_error_BAD_REQUEST\",\
+        namespace=\"{NAMESPACE}\",service_name=\"temporal-core-sdk\",\
+        task_queue=\"{task_queue}\"}} 1"
+    )));
+    assert!(body.contains(&format!(
+        "temporal_nexus_task_execution_failed{{failure_reason=\"timeout\",\
+        namespace=\"{NAMESPACE}\",service_name=\"temporal-core-sdk\",\
+        task_queue=\"{task_queue}\"}} 1"
+    )));
+    assert!(body.contains(&format!(
+        "temporal_nexus_task_execution_failed{{failure_reason=\"operation_failed\",\
+        namespace=\"{NAMESPACE}\",service_name=\"temporal-core-sdk\",\
+        task_queue=\"{task_queue}\"}} 1"
+    )));
+    assert!(body.contains(&format!(
+        "temporal_nexus_task_schedule_to_start_latency_count{{\
+        namespace=\"{NAMESPACE}\",service_name=\"temporal-core-sdk\",\
+        task_queue=\"{task_queue}\"}} 4"
+    )));
+    assert!(body.contains(&format!(
+        "temporal_nexus_task_execution_latency_count{{\
+        namespace=\"{NAMESPACE}\",service_name=\"temporal-core-sdk\",\
+        task_queue=\"{task_queue}\"}} 4"
+    )));
+    // Only 3 actually finished - the timed-out one will not have an e2e latency
+    assert!(body.contains(&format!(
+        "temporal_nexus_task_endtoend_latency_count{{\
+        namespace=\"{NAMESPACE}\",service_name=\"temporal-core-sdk\",\
+        task_queue=\"{task_queue}\"}} 3"
+    )));
+}
+
 #[tokio::test]
 async fn evict_on_complete_does_not_count_as_forced_eviction() {
     let (telemopts, addr, _aborter) = prom_metrics(None);
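The assertions above (like those in the slot test earlier) match raw lines of Prometheus text exposition by substring. If that ever becomes brittle, extracting a sample's value is straightforward; a standalone sketch, not part of this change:

    // Sketch: pull a single sample's value out of Prometheus text exposition.
    // The tests above do the equivalent with exact-substring assertions.
    fn sample_value(body: &str, name_and_labels: &str) -> Option<f64> {
        body.lines()
            .find(|line| line.starts_with(name_and_labels))
            .and_then(|line| line.rsplit(' ').next())
            .and_then(|v| v.parse().ok())
    }

    fn main() {
        let body = "temporal_nexus_task_endtoend_latency_count{task_queue=\"q\"} 3\n";
        assert_eq!(
            sample_value(body, "temporal_nexus_task_endtoend_latency_count{task_queue=\"q\"}"),
            Some(3.0)
        );
    }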
diff --git a/tests/integ_tests/workflow_tests/nexus.rs b/tests/integ_tests/workflow_tests/nexus.rs
index 5dbef9723..366be343b 100644
--- a/tests/integ_tests/workflow_tests/nexus.rs
+++ b/tests/integ_tests/workflow_tests/nexus.rs
@@ -1,3 +1,4 @@
+use crate::integ_tests::mk_nexus_endpoint;
 use anyhow::bail;
 use assert_matches::assert_matches;
 use std::time::Duration;
@@ -17,11 +18,9 @@ use temporal_sdk_core_protos::{
         failure::v1::failure::FailureInfo,
         nexus,
         nexus::v1::{
-            endpoint_target, request, start_operation_response, workflow_event_link_from_nexus,
-            CancelOperationResponse, EndpointSpec, EndpointTarget, HandlerError,
-            StartOperationResponse,
+            request, start_operation_response, workflow_event_link_from_nexus,
+            CancelOperationResponse, HandlerError, StartOperationResponse,
         },
-        operatorservice::v1::CreateNexusEndpointRequest,
     },
 };
 use temporal_sdk_core_test_utils::{rand_6_chars, CoreWfStarter};
@@ -47,7 +46,7 @@ async fn nexus_basic(
     let mut worker = starter.worker().await;
     let core_worker = starter.get_worker().await;
 
-    let endpoint = mk_endpoint(&mut starter).await;
+    let endpoint = mk_nexus_endpoint(&mut starter).await;
 
     worker.register_wf(wf_name.to_owned(), move |ctx: WfContext| {
         let endpoint = endpoint.clone();
@@ -98,12 +97,10 @@ async fn nexus_basic(
         Outcome::Fail | Outcome::Timeout => {
             if outcome == Outcome::Timeout {
                 // Wait for the timeout task cancel to get sent
-                dbg!("Waiting");
                 let timeout_t = core_worker.poll_nexus_task().await.unwrap();
                 let cancel = assert_matches!(timeout_t.variant,
                     Some(nexus_task::Variant::CancelTask(ct)) => ct);
                 assert_eq!(cancel.reason, NexusTaskCancelReason::TimedOut as i32);
-                dbg!("Done waiting!");
             }
             core_worker
                 .complete_nexus_task(NexusTaskCompletion {
@@ -182,7 +179,7 @@ async fn nexus_async(
     let mut worker = starter.worker().await;
     let core_worker = starter.get_worker().await;
 
-    let endpoint = mk_endpoint(&mut starter).await;
+    let endpoint = mk_nexus_endpoint(&mut starter).await;
     let schedule_to_close_timeout = if outcome == Outcome::CancelAfterRecordedBeforeStarted {
         // If we set this, it'll time out before we can cancel it.
         None
@@ -388,7 +385,7 @@ async fn nexus_cancel_before_start() {
     starter.worker_config.no_remote_activities(true);
     let mut worker = starter.worker().await;
 
-    let endpoint = mk_endpoint(&mut starter).await;
+    let endpoint = mk_nexus_endpoint(&mut starter).await;
 
     worker.register_wf(wf_name.to_owned(), move |ctx: WfContext| {
         let endpoint = endpoint.clone();
@@ -431,7 +428,7 @@ async fn nexus_must_complete_task_to_shutdown(#[values(true, false)] use_grace_p
     let mut worker = starter.worker().await;
     let core_worker = starter.get_worker().await;
 
-    let endpoint = mk_endpoint(&mut starter).await;
+    let endpoint = mk_nexus_endpoint(&mut starter).await;
 
     worker.register_wf(wf_name.to_owned(), move |ctx: WfContext| {
         let endpoint = endpoint.clone();
@@ -502,27 +499,3 @@ async fn nexus_must_complete_task_to_shutdown(#[values(true, false)] use_grace_p
     // The first thing to finish needs to have been the nexus task completion
     assert_eq!(complete_order_rx.recv().await.unwrap(), "t");
 }
-
-async fn mk_endpoint(starter: &mut CoreWfStarter) -> String {
-    let client = starter.get_client().await;
-    let endpoint = format!("mycoolendpoint-{}", rand_6_chars());
-    let mut op_client = client.get_client().inner().operator_svc().clone();
-    op_client
-        .create_nexus_endpoint(CreateNexusEndpointRequest {
-            spec: Some(EndpointSpec {
-                name: endpoint.to_owned(),
-                description: None,
-                target: Some(EndpointTarget {
-                    variant: Some(endpoint_target::Variant::Worker(endpoint_target::Worker {
-                        namespace: client.namespace().to_owned(),
-                        task_queue: starter.get_task_queue().to_owned(),
-                    })),
-                }),
-            }),
-        })
-        .await
-        .unwrap();
-    // Endpoint creation can (as of server 1.25.2 at least) return before they are actually usable.
-    tokio::time::sleep(Duration::from_millis(800)).await;
-    endpoint
-}
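The deleted mk_endpoint reappears below in tests/main.rs as mk_nexus_endpoint so the metrics tests can share it. The random suffix matters because endpoints are created server-side and names must not collide across runs; an illustrative stand-in for the suffixing (the repo itself uses rand_6_chars from test-utils):

    use std::time::{SystemTime, UNIX_EPOCH};

    // Illustrative stand-in for rand_6_chars-style suffixing; not the repo's
    // implementation. A nanosecond-derived suffix is enough to show the shape.
    fn unique_endpoint_name(prefix: &str) -> String {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock before epoch")
            .subsec_nanos();
        format!("{prefix}-{nanos:06x}")
    }

    fn main() {
        let name = unique_endpoint_name("mycoolendpoint");
        assert!(name.starts_with("mycoolendpoint-"));
        println!("{name}"); // e.g. mycoolendpoint-2f9a1c
    }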
diff --git a/tests/main.rs b/tests/main.rs
index 69288cd09..7e7850e52 100644
--- a/tests/main.rs
+++ b/tests/main.rs
@@ -19,15 +19,21 @@ mod integ_tests {
     mod worker_tests;
     mod workflow_tests;
 
-    use std::{env, str::FromStr};
+    use std::{env, str::FromStr, time::Duration};
     use temporal_client::WorkflowService;
     use temporal_sdk_core::{
         init_worker, ClientOptionsBuilder, ClientTlsConfig, CoreRuntime, TlsConfig,
         WorkflowClientTrait,
     };
     use temporal_sdk_core_api::worker::WorkerConfigBuilder;
-    use temporal_sdk_core_protos::temporal::api::workflowservice::v1::ListNamespacesRequest;
-    use temporal_sdk_core_test_utils::{get_integ_server_options, get_integ_telem_options};
+    use temporal_sdk_core_protos::temporal::api::{
+        nexus::v1::{endpoint_target, EndpointSpec, EndpointTarget},
+        operatorservice::v1::CreateNexusEndpointRequest,
+        workflowservice::v1::ListNamespacesRequest,
+    };
+    use temporal_sdk_core_test_utils::{
+        get_integ_server_options, get_integ_telem_options, rand_6_chars, CoreWfStarter,
+    };
     use url::Url;
 
     // Create a worker like a bridge would (unwraps aside)
@@ -103,4 +109,28 @@ mod integ_tests {
         .await
         .unwrap();
     }
+
+    pub(crate) async fn mk_nexus_endpoint(starter: &mut CoreWfStarter) -> String {
+        let client = starter.get_client().await;
+        let endpoint = format!("mycoolendpoint-{}", rand_6_chars());
+        let mut op_client = client.get_client().inner().operator_svc().clone();
+        op_client
+            .create_nexus_endpoint(CreateNexusEndpointRequest {
+                spec: Some(EndpointSpec {
+                    name: endpoint.to_owned(),
+                    description: None,
+                    target: Some(EndpointTarget {
+                        variant: Some(endpoint_target::Variant::Worker(endpoint_target::Worker {
+                            namespace: client.namespace().to_owned(),
+                            task_queue: starter.get_task_queue().to_owned(),
+                        })),
+                    }),
+                }),
+            })
+            .await
+            .unwrap();
+        // Endpoint creation can (as of server 1.25.2, at least) return before the endpoint is actually usable.
+        tokio::time::sleep(Duration::from_millis(800)).await;
+        endpoint
+    }
 }
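The fixed 800ms sleep in mk_nexus_endpoint papers over endpoint-propagation delay. If it ever proves flaky, a bounded readiness poll is the usual alternative; a sketch with a hypothetical is_ready probe, which does not exist anywhere in this series:

    use std::{future::Future, time::Duration};

    // Sketch of a bounded readiness poll with linear backoff. `is_ready` is a
    // hypothetical probe (e.g. wrapping a describe/list call); the patch above
    // simply sleeps a fixed 800ms instead.
    async fn wait_until_ready<F, Fut>(mut is_ready: F, attempts: u32) -> bool
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = bool>,
    {
        for attempt in 0..attempts {
            if is_ready().await {
                return true;
            }
            tokio::time::sleep(Duration::from_millis(100 * u64::from(attempt + 1))).await;
        }
        false
    }

    #[tokio::main]
    async fn main() {
        let mut calls = 0;
        let ready = wait_until_ready(
            || {
                calls += 1;
                let ok = calls >= 3;
                async move { ok }
            },
            10,
        )
        .await;
        assert!(ready);
    }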
From 1f204cbef7ae7f0308a48af43cbf7893ac093012 Mon Sep 17 00:00:00 2001
From: Spencer Judge
Date: Fri, 10 Jan 2025 16:57:33 -0800
Subject: [PATCH 14/14] Fix dumb name

---
 core-api/src/errors.rs   | 2 +-
 core/src/worker/mod.rs   | 2 +-
 core/src/worker/nexus.rs | 6 +++---
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/core-api/src/errors.rs b/core-api/src/errors.rs
index c26a201f5..b4a8e3641 100644
--- a/core-api/src/errors.rs
+++ b/core-api/src/errors.rs
@@ -60,7 +60,7 @@ pub enum CompleteActivityError {
 pub enum CompleteNexusError {
     /// Lang SDK sent us a malformed nexus completion. This likely means a bug in the lang sdk.
     #[error("Lang SDK sent us a malformed nexus completion: {reason}")]
-    MalformeNexusCompletion {
+    MalformedNexusCompletion {
         /// Reason the completion was malformed
         reason: String,
     },
diff --git a/core/src/worker/mod.rs b/core/src/worker/mod.rs
index 7cb928dde..aa4f755bc 100644
--- a/core/src/worker/mod.rs
+++ b/core/src/worker/mod.rs
@@ -198,7 +198,7 @@ impl WorkerTrait for Worker {
         let status = if let Some(s) = completion.status {
             s
         } else {
-            return Err(CompleteNexusError::MalformeNexusCompletion {
+            return Err(CompleteNexusError::MalformedNexusCompletion {
                 reason: "Nexus completion had empty status field".to_owned(),
             });
         };
diff --git a/core/src/worker/nexus.rs b/core/src/worker/nexus.rs
index a85cb2287..0e110e73e 100644
--- a/core/src/worker/nexus.rs
+++ b/core/src/worker/nexus.rs
@@ -135,7 +135,7 @@ impl NexusManager {
                 .nexus_task_execution_failed();
             };
             if task_info.request_kind != RequestKind::Start {
-                return Err(CompleteNexusError::MalformeNexusCompletion {
+                return Err(CompleteNexusError::MalformedNexusCompletion {
                     reason: "Nexus response was StartOperation but request was not"
                         .to_string(),
                 });
@@ -143,7 +143,7 @@ impl NexusManager {
             Some(response::Variant::CancelOperation(_)) => {
                 if task_info.request_kind != RequestKind::Cancel {
-                    return Err(CompleteNexusError::MalformeNexusCompletion {
+                    return Err(CompleteNexusError::MalformedNexusCompletion {
                         reason: "Nexus response was CancelOperation but request was not"
                             .to_string(),
                     });
                 }
             }
@@ -151,7 +151,7 @@ impl NexusManager {
             None => {
-                return Err(CompleteNexusError::MalformeNexusCompletion {
+                return Err(CompleteNexusError::MalformedNexusCompletion {
                    reason: "Nexus completion must contain a status variant"
                         .to_string(),
                 })