diff --git a/sys/kern/src/arch/arm_m.rs b/sys/kern/src/arch/arm_m.rs
index cb2928d71..1f5cd19c2 100644
--- a/sys/kern/src/arch/arm_m.rs
+++ b/sys/kern/src/arch/arm_m.rs
@@ -519,7 +519,7 @@ pub fn apply_memory_protection(task: &task::Task) {
     }
 }
 
-pub fn start_first_task(tick_divisor: u32, task: &mut task::Task) -> ! {
+pub fn start_first_task(tick_divisor: u32, task: &task::Task) -> ! {
     // Enable faults and set fault/exception priorities to reasonable settings.
     // Our goal here is to keep the kernel non-preemptive, which means the
     // kernel entry points (SVCall, PendSV, SysTick, interrupt handlers) must be
@@ -654,7 +654,7 @@ pub fn start_first_task(tick_divisor: u32, task: &mut task::Task) -> ! {
         mpu.ctrl.write(ENABLE | PRIVDEFENA);
     }
 
-    CURRENT_TASK_PTR.store(task, Ordering::Relaxed);
+    CURRENT_TASK_PTR.store(task as *const _ as *mut _, Ordering::Relaxed);
 
     extern "C" {
         // Exposed by the linker script.
@@ -901,9 +901,9 @@ cfg_if::cfg_if! {
 /// This records a pointer that aliases `task`. As long as you don't read that
 /// pointer while you have access to `task`, and as long as the `task` being
 /// stored is actually in the task table, you'll be okay.
-pub unsafe fn set_current_task(task: &mut task::Task) {
-    CURRENT_TASK_PTR.store(task, Ordering::Relaxed);
-    crate::profiling::event_context_switch(task as *mut _ as usize);
+pub unsafe fn set_current_task(task: &task::Task) {
+    CURRENT_TASK_PTR.store(task as *const _ as *mut _, Ordering::Relaxed);
+    crate::profiling::event_context_switch(task as *const _ as usize);
 }
 
 /// Reads the tick counter.
@@ -1079,7 +1079,6 @@ unsafe extern "C" fn pendsv_entry() {
 
     with_task_table(|tasks| {
         let next = task::select(current, tasks);
-        let next = &mut tasks[next];
         apply_memory_protection(next);
         // Safety: next comes from the task table and we don't use it again
         // until next kernel entry, so we meet set_current_task's requirements.
@@ -1474,16 +1473,15 @@ unsafe extern "C" fn handle_fault(task: *mut task::Task) {
     // switch to a task to run.
     with_task_table(|tasks| {
         let next = match task::force_fault(tasks, idx, fault) {
-            task::NextTask::Specific(i) => i,
+            task::NextTask::Specific(i) => &tasks[i],
             task::NextTask::Other => task::select(idx, tasks),
-            task::NextTask::Same => idx,
+            task::NextTask::Same => &tasks[idx],
         };
 
-        if next == idx {
+        if core::ptr::eq(next as *const _, task as *const _) {
             panic!("attempt to return to Task #{idx} after fault");
         }
 
-        let next = &mut tasks[next];
         apply_memory_protection(next);
         // Safety: next comes from the task table and we don't use it again
         // until next kernel entry, so we meet set_current_task's requirements.
@@ -1688,16 +1686,15 @@ unsafe extern "C" fn handle_fault(
     // fault!)
     with_task_table(|tasks| {
         let next = match task::force_fault(tasks, idx, fault) {
-            task::NextTask::Specific(i) => i,
+            task::NextTask::Specific(i) => &tasks[i],
            task::NextTask::Other => task::select(idx, tasks),
-            task::NextTask::Same => idx,
+            task::NextTask::Same => &tasks[idx],
         };
 
-        if next == idx {
+        if core::ptr::eq(next as *const _, task as *const _) {
             panic!("attempt to return to Task #{idx} after fault");
         }
 
-        let next = &mut tasks[next];
         apply_memory_protection(next);
         // Safety: this leaks a pointer aliasing next into static scope, but
         // we're not going to read it back until the next kernel entry, so we
diff --git a/sys/kern/src/startup.rs b/sys/kern/src/startup.rs
index 4bc1edd8f..be46a405b 100644
--- a/sys/kern/src/startup.rs
+++ b/sys/kern/src/startup.rs
@@ -82,15 +82,11 @@ pub unsafe fn start_kernel(tick_divisor: u32) -> !
 {
     // Great! Pick our first task. We'll act like we're scheduling after the
     // last task, which will cause a scan from 0 on.
-    let first_task_index =
-        crate::task::select(task_table.len() - 1, task_table);
+    let first_task = crate::task::select(task_table.len() - 1, task_table);
 
-    crate::arch::apply_memory_protection(&task_table[first_task_index]);
+    crate::arch::apply_memory_protection(first_task);
     TASK_TABLE_IN_USE.store(false, Ordering::Release);
-    crate::arch::start_first_task(
-        tick_divisor,
-        &mut task_table[first_task_index],
-    )
+    crate::arch::start_first_task(tick_divisor, first_task)
 }
 
 /// Runs `body` with a reference to the task table.
diff --git a/sys/kern/src/syscalls.rs b/sys/kern/src/syscalls.rs
index d97c08f1e..3fd501a6a 100644
--- a/sys/kern/src/syscalls.rs
+++ b/sys/kern/src/syscalls.rs
@@ -99,14 +99,14 @@ pub unsafe extern "C" fn syscall_entry(nr: u32, task: *mut Task) {
             NextTask::Specific(i) => {
                 // Safety: this is a valid task from the tasks table, meeting
                 // switch_to's requirements.
-                unsafe { switch_to(&mut tasks[i]) }
+                unsafe { switch_to(&tasks[i]) }
             }
 
             NextTask::Other => {
                 let next = task::select(idx, tasks);
                 // Safety: this is a valid task from the tasks table, meeting
                 // switch_to's requirements.
-                unsafe { switch_to(&mut tasks[next]) }
+                unsafe { switch_to(next) }
             }
         }
     });
@@ -322,7 +322,7 @@ fn recv(tasks: &mut [Task], caller: usize) -> Result<NextTask, UserError> {
     let mut last = caller; // keep track of scan position.
 
     // Is anyone blocked waiting to send to us?
-    while let Some(sender) = task::priority_scan(last, tasks, |t| {
+    while let Some((sender, _)) = task::priority_scan(last, tasks, |t| {
         t.state().is_sending_to(caller_id)
     }) {
         // Oh hello sender!
@@ -667,7 +667,7 @@ fn borrow_lease(
 /// To avoid causing problems, ensure that `task` is a member of the task table,
 /// with memory protection generated by the build system, and that your access
 /// to `task` goes out of scope before next kernel entry.
-unsafe fn switch_to(task: &mut Task) {
+unsafe fn switch_to(task: &Task) {
     arch::apply_memory_protection(task);
     // Safety: our contract above is sufficient to ensure that this is safe.
     unsafe {
diff --git a/sys/kern/src/task.rs b/sys/kern/src/task.rs
index b34392038..31095b4b7 100644
--- a/sys/kern/src/task.rs
+++ b/sys/kern/src/task.rs
@@ -820,47 +820,55 @@ pub fn check_task_id_against_table(
 /// Selects a new task to run after `previous`. Tries to be fair, kind of.
 ///
 /// If no tasks are runnable, the kernel panics.
-pub fn select(previous: usize, tasks: &[Task]) -> usize {
-    priority_scan(previous, tasks, |t| t.is_runnable())
-        .expect("no tasks runnable")
+pub fn select(previous: usize, tasks: &[Task]) -> &Task {
+    match priority_scan(previous, tasks, |t| t.is_runnable()) {
+        Some((_index, task)) => task,
+        None => panic!(),
+    }
 }
 
+/// Scans the task table to find a prioritized candidate.
+///
 /// Scans `tasks` for the next task, after `previous`, that satisfies `pred`. If
 /// more than one task satisfies `pred`, returns the most important one. If
 /// multiple tasks with the same priority satisfy `pred`, prefers the first one
-/// in order after `previous`, mod `tasks.len()`.
+/// in order after `previous`, mod `tasks.len()`. Finally, if no tasks satisfy
+/// `pred`, returns `None`
 ///
 /// Whew.
 ///
 /// This is generally the right way to search a task table, and is used to
 /// implement (among other bits) the scheduler.
 ///
-/// # Panics
-///
-/// If `previous` is not a valid index in `tasks`.
+/// On success, the return value is the task's index in the task table, and a
+/// direct reference to the task.
 pub fn priority_scan(
     previous: usize,
     tasks: &[Task],
     pred: impl Fn(&Task) -> bool,
-) -> Option<usize> {
-    uassert!(previous < tasks.len());
-    let search_order = (previous + 1..tasks.len()).chain(0..previous + 1);
-    let mut choice = None;
-    for i in search_order {
-        if !pred(&tasks[i]) {
+) -> Option<(usize, &Task)> {
+    let mut pos = previous;
+    let mut choice: Option<(usize, &Task)> = None;
+    for _step_no in 0..tasks.len() {
+        pos = pos.wrapping_add(1);
+        if pos >= tasks.len() {
+            pos = 0;
+        }
+        let t = &tasks[pos];
+        if !pred(t) {
            continue;
         }
 
-        if let Some((_, prio)) = choice {
-            if !tasks[i].priority.is_more_important_than(prio) {
+        if let Some((_, best_task)) = choice {
+            if !t.priority.is_more_important_than(best_task.priority) {
                 continue;
             }
         }
 
-        choice = Some((i, tasks[i].priority));
+        choice = Some((pos, t));
     }
 
-    choice.map(|(idx, _)| idx)
+    choice
 }
 
 /// Puts a task into a forced fault condition.
diff --git a/sys/kerncore/src/lib.rs b/sys/kerncore/src/lib.rs
index 8c7cabb12..f7657b22e 100644
--- a/sys/kerncore/src/lib.rs
+++ b/sys/kerncore/src/lib.rs
@@ -9,6 +9,8 @@
 #![cfg_attr(not(test), no_std)]
 #![forbid(clippy::wildcard_imports)]
 
+use core::cmp::Ordering;
+
 /// Describes types that act as "slices" (in the very abstract sense) referenced
 /// by tasks in syscalls.
 ///
@@ -84,6 +86,23 @@
     fn end_addr(&self) -> usize;
 }
 
+/// Compares a memory region to an address for use in binary-searching a region
+/// table.
+///
+/// This will return `Equal` if the address falls within the region, `Greater`
+/// if the address is lower, `Less` if the address is higher. i.e. it returns
+/// the status of the region relative to the address, not vice versa.
+#[inline(always)]
+fn region_compare(region: &impl MemoryRegion, addr: usize) -> Ordering {
+    if addr < region.base_addr() {
+        Ordering::Greater
+    } else if addr >= region.end_addr() {
+        Ordering::Less
+    } else {
+        Ordering::Equal
+    }
+}
+
 impl<T: MemoryRegion> MemoryRegion for &T {
     #[inline(always)]
     fn contains(&self, addr: usize) -> bool {
@@ -159,35 +178,53 @@ where
 
     // Per the function's preconditions, the region table is sorted in ascending
     // order of base address, and the regions within it do not overlap. This
-    // lets us use a one-pass algorithm.
+    // lets us use a binary search followed by a short scan
     let mut scan_addr = slice.base_addr();
     let end_addr = slice.end_addr();
 
-    for region in table {
-        if region.contains(scan_addr) {
-            // Make sure it's permissible!
-            if !region_ok(region) {
-                // bail to the fail handling code at the end.
-                break;
-            }
-
-            if end_addr <= region.end_addr() {
-                // We've exhausted the slice in this region, we don't have
-                // to continue processing.
-                return true;
-            }
+    let Ok(index) =
+        table.binary_search_by(|reg| region_compare(reg, scan_addr))
+    else {
+        // No region contained the start address.
+        return false;
+    };
+
+    // Perform fast checks on the initial region. In practical testing this
+    // provides a ~1% performance improvement over only using the loop below.
+    let first_region = &table[index];
+    if !region_ok(first_region) {
+        return false;
+    }
+    // Advance to the end of the first region
+    scan_addr = first_region.end_addr();
+    if scan_addr >= end_addr {
+        // That was easy
+        return true;
+    }
 
-            // Continue scanning at the end of this region.
-            scan_addr = region.end_addr();
-        } else if region.base_addr() > scan_addr {
-            // We've passed our target address without finding regions that
-            // work!
+    // Scan adjacent regions.
+    for region in &table[index + 1..] {
+        if !region.contains(scan_addr) {
+            // We've hit a hole without finishing our scan.
+            break;
+        }
+        // Make sure the region is permissible!
+        if !region_ok(region) {
+            // bail to the fail handling code at the end.
             break;
         }
+
+        if end_addr <= region.end_addr() {
+            // This region contains the end of our slice! We made it!
+            return true;
+        }
+
+        // Continue scanning at the end of this region.
+        scan_addr = region.end_addr();
     }
 
-    // We reach this point by exhausting the region table, or finding a
-    // region at a higher address than the slice.
+    // We reach this point by exhausting the region table without reaching the
+    // end of the slice, or hitting a hole.
     false
 }
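As an aside, and not part of the patch above: the kerncore change leans on the standard library's `binary_search_by`, driven by a comparator that reports the region's position relative to the address rather than the other way around. The standalone sketch below illustrates that trick with a hypothetical `Region` struct standing in for the kernel's `MemoryRegion` implementors and made-up addresses; it mirrors the comparator logic from the patch but is only an illustration, not kernel code.

use core::cmp::Ordering;

#[derive(Debug)]
struct Region {
    // Hypothetical stand-in for a MemoryRegion implementor.
    base: usize,
    size: usize,
}

impl Region {
    fn base_addr(&self) -> usize {
        self.base
    }
    fn end_addr(&self) -> usize {
        self.base + self.size
    }
}

// Reports the region's position relative to `addr`: Greater if the region sits
// above the address, Less if it sits below, Equal if the address falls inside.
fn region_compare(region: &Region, addr: usize) -> Ordering {
    if addr < region.base_addr() {
        Ordering::Greater
    } else if addr >= region.end_addr() {
        Ordering::Less
    } else {
        Ordering::Equal
    }
}

fn main() {
    // Sorted by base address and non-overlapping, like the kernel's region tables.
    let table = [
        Region { base: 0x0800_0000, size: 0x1000 },
        Region { base: 0x2000_0000, size: 0x2000 },
        Region { base: 0x4000_0000, size: 0x1000 },
    ];

    // 0x2000_0100 falls inside the second region, so the search returns Ok(1).
    assert_eq!(
        table.binary_search_by(|r| region_compare(r, 0x2000_0100)),
        Ok(1)
    );

    // An address in a hole between regions comes back as Err(insertion point).
    assert!(table
        .binary_search_by(|r| region_compare(r, 0x3000_0000))
        .is_err());

    println!("binary search over the region table behaves as expected");
}

Because `Equal` means "the address is inside this region," a successful search hands back the index of the containing region directly, which is exactly what `can_access` needs before it starts its short forward scan.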