Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rough implementation of task poll callbacks #7107

Closed
wants to merge 12 commits into from
111 changes: 111 additions & 0 deletions tokio/src/runtime/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ pub struct Builder {
/// To run before each task is spawned.
pub(super) before_spawn: Option<TaskCallback>,

/// To run before each poll
#[cfg(tokio_unstable)]
pub(super) before_poll: Option<TaskCallback>,

/// To run after each poll
#[cfg(tokio_unstable)]
pub(super) after_poll: Option<TaskCallback>,

/// To run after each task is terminated.
pub(super) after_termination: Option<TaskCallback>,

Expand Down Expand Up @@ -306,6 +314,11 @@ impl Builder {
before_spawn: None,
after_termination: None,

#[cfg(tokio_unstable)]
before_poll: None,
#[cfg(tokio_unstable)]
after_poll: None,

keep_alive: None,

// Defaults for these values depend on the scheduler kind, so we get them
Expand Down Expand Up @@ -743,6 +756,92 @@ impl Builder {
self
}

/// Executes function `f` just before a task is polled
///
/// `f` is called within the Tokio context, so functions like
/// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being
/// invoked immediately.
///
/// **Note**: This is an [unstable API][unstable]. The public API of this type
/// may break in 1.x releases. See [the documentation on unstable
/// features][unstable] for details.
///
/// [unstable]: crate#unstable-features
///
/// # Examples
///
/// ```
/// # use std::sync::{atomic::AtomicUsize, Arc};
/// # use tokio::task::yield_now;
/// # pub fn main() {
/// let poll_start_counter = Arc::new(AtomicUsize::new(0));
/// let poll_start = poll_start_counter.clone();
/// let rt = tokio::runtime::Builder::new_multi_thread()
/// .enable_all()
/// .on_before_task_poll(move |meta| {
/// println!("task {} is about to be polled", meta.id())
/// })
/// .build()
/// .unwrap();
/// let task = rt.spawn(async {
/// yield_now().await;
/// });
/// let _ = rt.block_on(task);
///
/// # }
/// ```
#[cfg(tokio_unstable)]
pub fn on_before_task_poll<F>(&mut self, f: F) -> &mut Self
where
F: Fn(&TaskMeta<'_>) + Send + Sync + 'static,
{
self.before_poll = Some(std::sync::Arc::new(f));
self
}

/// Executes function `f` just after a task is polled
///
/// `f` is called within the Tokio context, so functions like
/// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being
/// invoked immediately.
///
/// **Note**: This is an [unstable API][unstable]. The public API of this type
/// may break in 1.x releases. See [the documentation on unstable
/// features][unstable] for details.
///
/// [unstable]: crate#unstable-features
///
/// # Examples
///
/// ```
/// # use std::sync::{atomic::AtomicUsize, Arc};
/// # use tokio::task::yield_now;
/// # pub fn main() {
/// let poll_stop_counter = Arc::new(AtomicUsize::new(0));
/// let poll_stop = poll_stop_counter.clone();
/// let rt = tokio::runtime::Builder::new_multi_thread()
/// .enable_all()
/// .on_after_task_poll(move |meta| {
/// println!("task {} completed polling", meta.id());
/// })
/// .build()
/// .unwrap();
/// let task = rt.spawn(async {
/// yield_now().await;
/// });
/// let _ = rt.block_on(task);
///
/// # }
/// ```
#[cfg(tokio_unstable)]
pub fn on_after_task_poll<F>(&mut self, f: F) -> &mut Self
where
F: Fn(&TaskMeta<'_>) + Send + Sync + 'static,
{
self.after_poll = Some(std::sync::Arc::new(f));
self
}

/// Executes function `f` just after a task is terminated.
///
/// `f` is called within the Tokio context, so functions like
Expand Down Expand Up @@ -1410,6 +1509,10 @@ impl Builder {
before_park: self.before_park.clone(),
after_unpark: self.after_unpark.clone(),
before_spawn: self.before_spawn.clone(),
#[cfg(tokio_unstable)]
before_poll: self.before_poll.clone(),
#[cfg(tokio_unstable)]
after_poll: self.after_poll.clone(),
after_termination: self.after_termination.clone(),
global_queue_interval: self.global_queue_interval,
event_interval: self.event_interval,
Expand Down Expand Up @@ -1560,6 +1663,10 @@ cfg_rt_multi_thread! {
before_park: self.before_park.clone(),
after_unpark: self.after_unpark.clone(),
before_spawn: self.before_spawn.clone(),
#[cfg(tokio_unstable)]
before_poll: self.before_poll.clone(),
#[cfg(tokio_unstable)]
after_poll: self.after_poll.clone(),
after_termination: self.after_termination.clone(),
global_queue_interval: self.global_queue_interval,
event_interval: self.event_interval,
Expand Down Expand Up @@ -1610,6 +1717,10 @@ cfg_rt_multi_thread! {
after_unpark: self.after_unpark.clone(),
before_spawn: self.before_spawn.clone(),
after_termination: self.after_termination.clone(),
#[cfg(tokio_unstable)]
before_poll: self.before_poll.clone(),
#[cfg(tokio_unstable)]
after_poll: self.after_poll.clone(),
global_queue_interval: self.global_queue_interval,
event_interval: self.event_interval,
local_queue_capacity: self.local_queue_capacity,
Expand Down
8 changes: 8 additions & 0 deletions tokio/src/runtime/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ pub(crate) struct Config {
/// To run after each task is terminated.
pub(crate) after_termination: Option<TaskCallback>,

/// To run before each poll
#[cfg(tokio_unstable)]
pub(crate) before_poll: Option<TaskCallback>,

/// To run after each poll
#[cfg(tokio_unstable)]
pub(crate) after_poll: Option<TaskCallback>,

/// The multi-threaded scheduler includes a per-worker LIFO slot used to
/// store the last scheduled task. This can improve certain usage patterns,
/// especially message passing between tasks. However, this LIFO slot is not
Expand Down
13 changes: 13 additions & 0 deletions tokio/src/runtime/scheduler/current_thread/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ impl CurrentThread {
task_hooks: TaskHooks {
task_spawn_callback: config.before_spawn.clone(),
task_terminate_callback: config.after_termination.clone(),
#[cfg(tokio_unstable)]
before_poll_callback: config.before_poll.clone(),
#[cfg(tokio_unstable)]
after_poll_callback: config.after_poll.clone(),
},
shared: Shared {
inject: Inject::new(),
Expand Down Expand Up @@ -766,8 +770,17 @@ impl CoreGuard<'_> {

let task = context.handle.shared.owned.assert_owner(task);

#[cfg(tokio_unstable)]
let task_id = task.task_id();

let (c, ()) = context.run_task(core, || {
#[cfg(tokio_unstable)]
context.handle.task_hooks.poll_start_callback(task_id);

task.run();

#[cfg(tokio_unstable)]
context.handle.task_hooks.poll_stop_callback(task_id);
});

core = c;
Expand Down
30 changes: 26 additions & 4 deletions tokio/src/runtime/scheduler/multi_thread/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,7 @@ pub(super) fn create(

let remotes_len = remotes.len();
let handle = Arc::new(Handle {
task_hooks: TaskHooks {
task_spawn_callback: config.before_spawn.clone(),
task_terminate_callback: config.after_termination.clone(),
},
task_hooks: TaskHooks::from_config(&config),
shared: Shared {
remotes: remotes.into_boxed_slice(),
inject,
Expand Down Expand Up @@ -574,6 +571,9 @@ impl Context {
}

fn run_task(&self, task: Notified, mut core: Box<Core>) -> RunResult {
#[cfg(tokio_unstable)]
let task_id = task.task_id();

let task = self.worker.handle.shared.owned.assert_owner(task);

// Make sure the worker is not in the **searching** state. This enables
Expand All @@ -593,7 +593,19 @@ impl Context {

// Run the task
coop::budget(|| {
// Unlike the poll time above, poll start callback is attached to the task id,
// so it is tightly associated with the actual poll invocation.
#[cfg(tokio_unstable)]
self.worker.handle.task_hooks.poll_start_callback(task_id);

task.run();

// <we pause here for infinity>
// thread.sleep(60 seconds)

#[cfg(tokio_unstable)]
self.worker.handle.task_hooks.poll_stop_callback(task_id);

let mut lifo_polls = 0;

// As long as there is budget remaining and a task exists in the
Expand Down Expand Up @@ -656,7 +668,17 @@ impl Context {
// Run the LIFO task, then loop
*self.core.borrow_mut() = Some(core);
let task = self.worker.handle.shared.owned.assert_owner(task);

#[cfg(tokio_unstable)]
let task_id = task.task_id();

#[cfg(tokio_unstable)]
self.worker.handle.task_hooks.poll_start_callback(task_id);

task.run();

#[cfg(tokio_unstable)]
self.worker.handle.task_hooks.poll_stop_callback(task_id);
}
})
}
Expand Down
5 changes: 1 addition & 4 deletions tokio/src/runtime/scheduler/multi_thread_alt/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,7 @@ pub(super) fn create(
let (inject, inject_synced) = inject::Shared::new();

let handle = Arc::new(Handle {
task_hooks: TaskHooks {
task_spawn_callback: config.before_spawn.clone(),
task_terminate_callback: config.after_termination.clone(),
},
task_hooks: TaskHooks::from_config(&config),
shared: Shared {
remotes: remotes.into_boxed_slice(),
inject,
Expand Down
32 changes: 23 additions & 9 deletions tokio/src/runtime/task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,13 @@ pub(crate) struct LocalNotified<S: 'static> {
_not_send: PhantomData<*const ()>,
}

impl<S> LocalNotified<S> {
#[cfg(tokio_unstable)]
pub(crate) fn task_id(&self) -> Id {
self.task.id()
}
}

/// A task that is not owned by any `OwnedTasks`. Used for blocking tasks.
/// This type holds two ref-counts.
pub(crate) struct UnownedTask<S: 'static> {
Expand Down Expand Up @@ -386,6 +393,16 @@ impl<S: 'static> Task<S> {
self.raw.header_ptr()
}

/// Returns a [task ID] that uniquely identifies this task relative to other
/// currently spawned tasks.
///
/// [task ID]: crate::task::Id
#[cfg(tokio_unstable)]
pub(crate) fn id(&self) -> crate::task::Id {
// Safety: The header pointer is valid.
unsafe { Header::get_id(self.raw.header_ptr()) }
}

cfg_taskdump! {
/// Notify the task for task dumping.
///
Expand All @@ -400,22 +417,19 @@ impl<S: 'static> Task<S> {
}
}

/// Returns a [task ID] that uniquely identifies this task relative to other
/// currently spawned tasks.
///
/// [task ID]: crate::task::Id
#[cfg(tokio_unstable)]
pub(crate) fn id(&self) -> crate::task::Id {
// Safety: The header pointer is valid.
unsafe { Header::get_id(self.raw.header_ptr()) }
}
}
}

impl<S: 'static> Notified<S> {
fn header(&self) -> &Header {
self.0.header()
}

#[cfg(tokio_unstable)]
#[allow(dead_code)]
pub(crate) fn task_id(&self) -> crate::task::Id {
self.0.id()
}
}

impl<S: 'static> Notified<S> {
Expand Down
40 changes: 40 additions & 0 deletions tokio/src/runtime/task_hooks.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,57 @@
use std::marker::PhantomData;

use super::Config;

impl TaskHooks {
pub(crate) fn spawn(&self, meta: &TaskMeta<'_>) {
if let Some(f) = self.task_spawn_callback.as_ref() {
f(meta)
}
}

#[allow(dead_code)]
pub(crate) fn from_config(config: &Config) -> Self {
Self {
task_spawn_callback: config.before_spawn.clone(),
task_terminate_callback: config.after_termination.clone(),
#[cfg(tokio_unstable)]
before_poll_callback: config.before_poll.clone(),
#[cfg(tokio_unstable)]
after_poll_callback: config.after_poll.clone(),
}
}

#[cfg(tokio_unstable)]
#[inline]
pub(crate) fn poll_start_callback(&self, id: super::task::Id) {
if let Some(poll_start) = &self.before_poll_callback {
(poll_start)(&TaskMeta {
id,
_phantom: std::marker::PhantomData,
})
}
}

#[cfg(tokio_unstable)]
#[inline]
pub(crate) fn poll_stop_callback(&self, id: super::task::Id) {
if let Some(poll_stop) = &self.after_poll_callback {
(poll_stop)(&TaskMeta {
id,
_phantom: std::marker::PhantomData,
})
}
}
}

#[derive(Clone)]
pub(crate) struct TaskHooks {
pub(crate) task_spawn_callback: Option<TaskCallback>,
pub(crate) task_terminate_callback: Option<TaskCallback>,
#[cfg(tokio_unstable)]
pub(crate) before_poll_callback: Option<TaskCallback>,
#[cfg(tokio_unstable)]
pub(crate) after_poll_callback: Option<TaskCallback>,
}

/// Task metadata supplied to user-provided hooks for task events.
Expand Down
Loading
Loading