diff --git a/ferrunix-core/src/cycle_detection.rs b/ferrunix-core/src/cycle_detection.rs index cfd875a..ed19f59 100644 --- a/ferrunix-core/src/cycle_detection.rs +++ b/ferrunix-core/src/cycle_detection.rs @@ -1,7 +1,6 @@ //! Implementation of a cycle detection algorithm for our dependency resolution algorithm. use std::any::TypeId; -use std::sync::atomic::{AtomicBool, Ordering}; use crate::dependency_builder::{self, DepBuilder}; use crate::types::{ @@ -15,16 +14,45 @@ pub enum ValidationError { /// A cycle between dependencies has been detected. Cycle, /// Dependencies are missing. - Missing(Vec), + Missing, } impl std::fmt::Display for ValidationError { #[allow(clippy::use_debug)] fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Cycle => write!(fmt, "cycle detected:"), + Self::Cycle => write!(fmt, "cycle detected!"), + Self::Missing => write!(fmt, "dependencies missing!"), + } + } +} + +impl std::error::Error for ValidationError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + None + } +} + +/// Detailed validation errors. +#[derive(Debug, Clone, PartialEq, Hash)] +#[non_exhaustive] +pub enum FullValidationError { + /// A cycle between dependencies has been detected. + Cycle(Option), + /// Dependencies are missing. + Missing(Vec), +} + +impl std::fmt::Display for FullValidationError { + #[allow(clippy::use_debug)] + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Cycle(ref node) => match node { + Some(node) => write!(fmt, "cycle detected at {node}"), + None => write!(fmt, "cycle detected!"), + }, Self::Missing(ref all_missing) => { - writeln!(fmt, "dependencies missing!")?; + writeln!(fmt, "dependencies missing:")?; for missing in all_missing { writeln!( @@ -44,7 +72,7 @@ impl std::fmt::Display for ValidationError { } } -impl std::error::Error for ValidationError { +impl std::error::Error for FullValidationError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { None } @@ -80,8 +108,6 @@ pub(crate) struct DependencyValidator { /// The visitor callbacks. Those are necessary because we only want to register each type once /// we have collected them all. visitor: NonAsyncRwLock>, - /// Whether we have already visited all visitors. - visitor_visited: AtomicBool, /// Context for visitors. context: NonAsyncRwLock, } @@ -91,7 +117,6 @@ impl DependencyValidator { pub(crate) fn new() -> Self { Self { visitor: NonAsyncRwLock::new(HashMap::new()), - visitor_visited: AtomicBool::new(false), context: NonAsyncRwLock::new(VisitorContext::new()), } } @@ -113,8 +138,16 @@ impl DependencyValidator { index }); - self.visitor_visited.store(false, Ordering::Release); - self.visitor.write().insert(TypeId::of::(), visitor); + { + let mut visitors = self.visitor.write(); + visitors.insert(TypeId::of::(), visitor); + drop(visitors); + } + { + let mut context = self.context.write(); + context.reset(); + drop(context); + } } /// Register a new singleton, without any dependencies. @@ -192,8 +225,16 @@ impl DependencyValidator { current }); - self.visitor_visited.store(false, Ordering::Release); - self.visitor.write().insert(TypeId::of::(), visitor); + { + let mut visitors = self.visitor.write(); + visitors.insert(TypeId::of::(), visitor); + drop(visitors); + } + { + let mut context = self.context.write(); + context.reset(); + drop(context); + } } /// Register a new singleton, with dependencies specified via `Deps`. @@ -210,65 +251,104 @@ impl DependencyValidator { /// Walk the dependency graph and validate that all types can be constructed, all dependencies /// are fulfillable and there are no cycles in the graph. pub(crate) fn validate_all(&self) -> Result<(), ValidationError> { - // This **must** be a separate `if`, otherwise the lock is held also in the `else`. - // if let Some(cache) = &*self.validation_cache.read() { - // // Validation is cached. - // { - // let missing = self.missing_cache.read(); - // if missing.len() > 0 { - // let mut vec = Vec::with_capacity(missing.len()); - // for (_, ty) in missing.iter() { - // vec.push(ty.clone()); - // } - // return Err(ValidationError::Missing(vec)); - // } - // } - - // // EARLY RETURN ABSOLUTELY REQUIRED! - // return match cache { - // Ok(_) => Ok(()), - // Err(_err) => Err(ValidationError::Cycle), - // }; - // } - - // Validation is **not** cached. - - // if self - // .visitor_visited - // .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) - // .is_ok() - // { - let mut context = self.context.write(); - - // Make sure we have all types registered. - { - let visitor = self.visitor.read(); - for (_type_id, cb) in visitor.iter() { - // To avoid a dead lock due to other visitors needing to be called, we pass in the - // visitors hashmap. - (cb.0)(self, &visitor, &mut context); - } + let read_context = self.context.read(); + if Self::validate_context(&read_context)? { + // Validation result is still cached. + return Ok(()); } + // No validation result is cached, drop the read lock and acquire an exclusive lock to + // update the cached validation result. + drop(read_context); + let mut write_context = self.context.write(); + if Self::validate_context(&write_context)? { + // Context was updated by another thread while we waited for the exclusive write lock + // to be acquired. + return Ok(()); + } + + // Validation did not run, we need to run it. + self.calculate_validation(&mut write_context); + + // Throws an error if our dependency graph is invalid. + Self::validate_context(&write_context)?; + + Ok(()) + } + + /// Walk the dependency graph and validate that all types can be constructed, all dependencies + /// are fulfillable and there are no cycles in the graph. + pub(crate) fn validate_all_full(&self) -> Result<(), FullValidationError> { + let mut context = VisitorContext::new(); + self.calculate_validation(&mut context); + + // Evaluate whether we want to make this available via an option? It takes ages to + // calculate! + // let tarjan = petgraph::algo::tarjan_scc(&context.graph); + // dbg!(&tarjan); + if !context.missing.is_empty() { let mut vec = Vec::with_capacity(context.missing.len()); - for (_, ty) in &context.missing { + context.missing.iter().for_each(|(_, ty)| { vec.push(ty.clone()); + }); + return Err(FullValidationError::Missing(vec)); + } + + if let Some(cached) = &context.validation_cache { + return match cached { + Ok(_) => Ok(()), + Err(err) => { + let index = err.node_id(); + let node_name = context.graph.node_weight(index); + return Err(FullValidationError::Cycle( + node_name.map(|el| (*el).to_owned()), + )); + } + }; + } + + unreachable!("this is a bug") + } + + /// Inspect `context`, and return a [`ValidationError`] if there are errors in the dependency + /// graph. + /// + /// Returns `Ok(true)` if the validation result is cached. + /// Returns `Ok(false)` if the validation result is outdated and needs to be recalculated. + fn validate_context( + context: &VisitorContext, + ) -> Result { + if !context.missing.is_empty() { + return Err(ValidationError::Missing); + } + + if let Some(cached) = &context.validation_cache { + return match cached { + Ok(_) => Ok(true), + Err(_) => Err(ValidationError::Cycle), + }; + } + + Ok(false) + } + + /// Visit all visitors in `self.visitor`, and create the new dependency graph. + fn calculate_validation(&self, context: &mut VisitorContext) { + { + // Keep the lock as short as possible. + let visitor = self.visitor.read(); + for (_type_id, cb) in visitor.iter() { + // To avoid a dead lock due to other visitors needing to be called, we pass in the + // visitors hashmap. + (cb.0)(self, &visitor, context); } - return Err(ValidationError::Missing(vec)); } + // We only calculate whether we have let mut space = petgraph::algo::DfsSpace::new(&context.graph); context.validation_cache = Some(petgraph::algo::toposort(&context.graph, Some(&mut space))); - - let ret = match context.validation_cache { - Some(Ok(_)) => Ok(()), - Some(Err(_)) => Err(ValidationError::Cycle), - _ => unreachable!("it's written above"), - }; - - ret } /// Validate whether the type `T` is constructible. diff --git a/ferrunix-core/src/lib.rs b/ferrunix-core/src/lib.rs index 226ede9..b7731b8 100644 --- a/ferrunix-core/src/lib.rs +++ b/ferrunix-core/src/lib.rs @@ -18,7 +18,6 @@ pub mod cycle_detection; pub mod dependencies; pub mod dependency_builder; pub mod error; -pub mod lazy_locked_cache; pub mod object_builder; pub mod registration; pub mod registry; diff --git a/ferrunix-core/src/registry.rs b/ferrunix-core/src/registry.rs index ebc3f5c..43ea58b 100644 --- a/ferrunix-core/src/registry.rs +++ b/ferrunix-core/src/registry.rs @@ -4,7 +4,7 @@ use std::any::TypeId; use std::marker::PhantomData; -use crate::cycle_detection::{DependencyValidator, ValidationError}; +use crate::cycle_detection::{DependencyValidator, FullValidationError, ValidationError}; use crate::dependency_builder::DepBuilder; use crate::object_builder::Object; use crate::types::{ @@ -352,12 +352,27 @@ impl Registry { /// Nontheless, it's recommended to call this before using the [`Registry`]. /// /// # Errors - /// Returns a [`ValidationError`] when the dependency graph is missing dependencies or has cycles. + /// Returns a [`ValidationError`] when the dependency graph is missing dependencies or + /// has cycles. #[cfg_attr(feature = "tracing", tracing::instrument)] pub fn validate_all(&self) -> Result<(), ValidationError> { self.validator.validate_all() } + /// Check whether all registered types have the required dependencies and returns a + /// detailed error about what's missing or where a cycle was detected. + /// + /// This is a potentially expensive call since it needs to go through the + /// entire dependency tree for each registered type. + /// + /// # Errors + /// Returns a [`ValidationError`] when the dependency graph is missing dependencies or + /// has cycles. + #[cfg_attr(feature = "tracing", tracing::instrument)] + pub fn validate_all_full(&self) -> Result<(), FullValidationError> { + self.validator.validate_all_full() + } + /// Check whether the type `T` is registered in this registry, and all /// dependencies of the type `T` are also registered. /// diff --git a/ferrunix/tests/it/cycle_test.rs b/ferrunix/tests/it/cycle_test.rs index 1971e53..818694b 100644 --- a/ferrunix/tests/it/cycle_test.rs +++ b/ferrunix/tests/it/cycle_test.rs @@ -74,6 +74,7 @@ fn detect_cycle() { assert!(registry.validate::().is_err()); assert!(registry.validate_all().is_err()); + assert!(registry.validate_all_full().is_err()); } #[test] @@ -107,6 +108,7 @@ fn detect_missing() { assert!(registry.validate::().is_err()); assert!(registry.validate_all().is_err()); + assert!(registry.validate_all_full().is_err()); } #[test] @@ -136,4 +138,5 @@ fn all_fine() { registry.validate::().unwrap(); registry.validate_all().unwrap(); + registry.validate_all_full().unwrap(); } diff --git a/ferrunix/tests/it/stress.rs b/ferrunix/tests/it/stress.rs index d36920e..3659be2 100644 --- a/ferrunix/tests/it/stress.rs +++ b/ferrunix/tests/it/stress.rs @@ -25,7 +25,7 @@ macro_rules! make_type { #[derive(Debug, Default)] pub(super) struct $base { $( - pub(super) [<_ $deps>]: $deps, + pub(super) [<_ $deps>]: Box<$deps>, )* } @@ -39,7 +39,7 @@ macro_rules! make_type { .transient(|($( [<_ $deps>], )*)| $base {$( - [<_ $deps>]: [<_ $deps>].get(), + [<_ $deps>]: Box::new([<_ $deps>].get()), )*}); } } @@ -56,9 +56,14 @@ macro_rules! make_many_types { make_type!(TypeZero, Dep0); make_type!(Dep0, Dep1); make_type!(Dep1, Dep2); - make_type!(Dep2, Dep3); + make_type!(Dep2, Dep3, TypeNoDeps0, TypeNoDeps1); make_type!(Dep3, Dep4); - make_type!(Dep4); + make_type!(Dep4, Dep5, Config); + make_type!(Dep5, Dep6, Config); + make_type!(Dep6, Dep7); + make_type!(Dep7, Dep8); + make_type!(Dep8, Dep9); + make_type!(Dep9); make_type!(TypeNoDeps0); make_type!(TypeNoDeps1); @@ -115,6 +120,11 @@ macro_rules! register_all_types { $modname::Dep2::register(&$reg); $modname::Dep3::register(&$reg); $modname::Dep4::register(&$reg); + $modname::Dep5::register(&$reg); + $modname::Dep6::register(&$reg); + $modname::Dep7::register(&$reg); + $modname::Dep8::register(&$reg); + $modname::Dep9::register(&$reg); $modname::TypeNoDeps0::register(&$reg); $modname::TypeNoDeps1::register(&$reg); @@ -143,7 +153,10 @@ macro_rules! register_all_types { $modname::TypeSingleDep9::register(&$reg); $modname::TypeSingleDep10::register(&$reg); - $reg.validate_all().unwrap(); + // Error ignored, because it might fail when some other thread is in + // between adding types. + #[allow(clippy::let_underscore_must_use)] + let _ = $reg.validate_all(); }; } @@ -154,6 +167,9 @@ make_many_types!(manytypes3); make_many_types!(manytypes4); make_many_types!(manytypes5); make_many_types!(manytypes6); +make_many_types!(manytypes7); +make_many_types!(manytypes8); +make_many_types!(manytypes9); #[test] fn stress_registration() { @@ -163,42 +179,30 @@ fn stress_registration() { let registry = Arc::clone(®istry); std::thread::spawn(move || { register_all_types!(manytypes0, registry); - }) - }; - let handle1 = { - let registry = Arc::clone(®istry); - std::thread::spawn(move || { register_all_types!(manytypes1, registry); - }) - }; - let handle2 = { - let registry = Arc::clone(®istry); - std::thread::spawn(move || { register_all_types!(manytypes2, registry); }) }; - let handle3 = { + let handle1 = { let registry = Arc::clone(®istry); std::thread::spawn(move || { register_all_types!(manytypes3, registry); - }) - }; - let handle4 = { - let registry = Arc::clone(®istry); - std::thread::spawn(move || { register_all_types!(manytypes4, registry); + register_all_types!(manytypes5, registry); }) }; - let handle5 = { + let handle2 = { let registry = Arc::clone(®istry); std::thread::spawn(move || { - register_all_types!(manytypes5, registry); + register_all_types!(manytypes6, registry); + register_all_types!(manytypes7, registry); + register_all_types!(manytypes8, registry); }) }; - let handle6 = { + let handle3 = { let registry = Arc::clone(®istry); std::thread::spawn(move || { - register_all_types!(manytypes6, registry); + register_all_types!(manytypes9, registry); }) }; @@ -206,10 +210,8 @@ fn stress_registration() { handle1.join().unwrap(); handle2.join().unwrap(); handle3.join().unwrap(); - handle4.join().unwrap(); - handle5.join().unwrap(); - handle6.join().unwrap(); + registry.validate_all_full().unwrap(); registry.validate_all().unwrap(); // println!("{}", registry.dotgraph().unwrap()); }