diff --git a/ferrunix-core/Cargo.toml b/ferrunix-core/Cargo.toml index c29e255..6b604ec 100644 --- a/ferrunix-core/Cargo.toml +++ b/ferrunix-core/Cargo.toml @@ -17,11 +17,14 @@ categories.workspace = true workspace = true [features] -default = ["multithread"] +default = ["multithread", "tokio"] multithread = ["once_cell/parking_lot"] +tokio = ["dep:tokio", "dep:futures"] [dependencies] once_cell = { version = "1.11" } parking_lot = "0.12" thiserror = "1" inventory = "0.3.1" +tokio = { version = "1.5", default-features = false, features = ["sync"], optional = true } +futures = { version = "0.3.11", optional = true } diff --git a/ferrunix-core/src/registry.rs b/ferrunix-core/src/registry.rs index bdedd75..175cc6b 100644 --- a/ferrunix-core/src/registry.rs +++ b/ferrunix-core/src/registry.rs @@ -20,10 +20,22 @@ enum Object { Singleton(BoxedSingletonGetter, SingletonCell), } +/// All possible "objects" that can be held by the registry. +#[cfg(feature = "tokio")] +enum AsyncObject { + AsyncTransient(crate::types::AsyncBoxedCtor), + AsyncSingleton( + crate::types::AsyncBoxedSingletonGetter, + Ref, + ), +} + /// Registry for all types that can be constructed or otherwise injected. pub struct Registry { objects: RwLock>, validation: RwLock>, + #[cfg(feature = "tokio")] + objects_async: crate::types::AsyncRwLock>, } impl Registry { @@ -40,6 +52,8 @@ impl Registry { Self { objects: RwLock::new(HashMap::new()), validation: RwLock::new(HashMap::new()), + #[cfg(feature = "tokio")] + objects_async: crate::types::AsyncRwLock::new(HashMap::new()), } } @@ -69,6 +83,7 @@ impl Registry { where T: Registerable, { + // TODO: Move construction out of locked region. self.objects.write().insert( TypeId::of::(), Object::Transient(Box::new(move |_| -> Option { @@ -76,6 +91,43 @@ impl Registry { Some(Box::new(obj)) })), ); + // TODO: Move construction out of locked region. + self.validation + .write() + .insert(TypeId::of::(), Box::new(|_| true)); + } + + /// Register a new transient object, without dependencies. + /// + /// To register a type with dependencies, use the builder returned from + /// [`Registry::with_deps`]. + /// + /// # Parameters + /// * `ctor`: A constructor function returning the newly constructed `T`. + /// This constructor will be called for every `T` that is requested. + #[cfg(feature = "tokio")] + pub async fn transient_async(&self, ctor: F) + where + T: Registerable + Clone, + F: std::future::Future + Send + Sync + 'static, + { + use futures::future::FutureExt; + let sharable_ctor = ctor.shared(); + let boxed: crate::types::AsyncBoxedCtor = Box::new(move |_| { + let cloned_ctor = sharable_ctor.clone(); + let fut = async move { + let obj = cloned_ctor.await; + Option::::Some(Box::new(obj)) + }; + Box::pin(fut) + }); + + // TODO: Move construction out of locked region. + self.objects_async + .write() + .await + .insert(TypeId::of::(), AsyncObject::AsyncTransient(boxed)); + // TODO: Move construction out of locked region. self.validation .write() .insert(TypeId::of::(), Box::new(|_| true)); @@ -100,15 +152,69 @@ impl Registry { Some(Ref::clone(rc)) }, ); + + // TODO: Move construction out of locked region. self.objects.write().insert( TypeId::of::(), Object::Singleton(getter, OnceCell::new()), ); + // TODO: Move construction out of locked region. self.validation .write() .insert(TypeId::of::(), Box::new(|_| true)); } + /// Register a new singleton object, without dependencies. + /// + /// To register a type with dependencies, use the builder returned from + /// [`Registry::with_deps`]. + /// + /// # Parameters + /// * `ctor`: A constructor function returning the newly constructed `T`. + /// This constructor will be called once, lazily, when the first + /// instance of `T` is requested. + #[cfg(feature = "tokio")] + pub async fn singleton_async(&self, ctor: F) + where + T: Registerable + Clone, + F: std::future::Future + Send + Sync + 'static, + { + use futures::future::FutureExt; + let sharable_ctor = ctor.shared(); + let getter: crate::types::AsyncBoxedSingletonGetter = Box::new( + move |_this: &Self, + cell: &Ref| { + let cloned_ctor = sharable_ctor.clone(); + let cell = Ref::clone(cell); + let fut = async move { + let rc = cell + .get_or_init(move || async move { + let obj = cloned_ctor.await; + Ref::new(obj) as RefAny + }) + .await; + Option::::Some(Ref::clone(rc)) + }; + Box::pin(fut) + }, + ); + + let singleton = AsyncObject::AsyncSingleton( + getter, + Ref::new(crate::types::AsyncSingletonCell::new()), + ); + { + let mut lock = self.objects_async.write().await; + lock.insert(TypeId::of::(), singleton); + } + + let validator: Validator = Box::new(|_| true); + { + let mut lock = self.validation.write(); + lock.insert(TypeId::of::(), validator); + } + } + /// Retrieves a newly constructed `T` from this registry. /// /// Returns `None` if `T` wasn't registered or failed to construct. @@ -116,10 +222,32 @@ impl Registry { where T: Registerable, { - if let Some(Object::Transient(ctor)) = - self.objects.read().get(&TypeId::of::()) - { + let lock = self.objects.read(); + if let Some(Object::Transient(ctor)) = lock.get(&TypeId::of::()) { let boxed = (ctor)(self)?; + drop(lock); + if let Ok(obj) = boxed.downcast::() { + return Some(*obj); + } + } + + None + } + + /// Retrieves a newly constructed `T` from this registry. + /// + /// Returns `None` if `T` wasn't registered or failed to construct. + #[cfg(feature = "tokio")] + pub async fn get_transient_async(&self) -> Option + where + T: Registerable, + { + let lock = self.objects_async.read().await; + if let Some(AsyncObject::AsyncTransient(ctor)) = + lock.get(&TypeId::of::()) + { + let boxed = (ctor)(self).await?; + drop(lock); if let Ok(obj) = boxed.downcast::() { return Some(*obj); } @@ -136,10 +264,35 @@ impl Registry { where T: Registerable, { + let lock = self.objects.read(); if let Some(Object::Singleton(getter, cell)) = - self.objects.read().get(&TypeId::of::()) + lock.get(&TypeId::of::()) { let singleton = (getter)(self, cell)?; + drop(lock); + if let Ok(obj) = singleton.downcast::() { + return Some(obj); + } + } + + None + } + + /// Retrieves the singleton `T` from this registry. + /// + /// Returns `None` if `T` wasn't registered or failed to construct. The + /// singleton is a ref-counted pointer object (either `Arc` or `Rc`). + #[cfg(feature = "tokio")] + pub async fn get_singleton_async(&self) -> Option> + where + T: Registerable, + { + let lock = self.objects_async.read().await; + if let Some(AsyncObject::AsyncSingleton(getter, cell)) = + lock.get(&TypeId::of::()) + { + let singleton = (getter)(self, cell).await?; + drop(lock); if let Ok(obj) = singleton.downcast::() { return Some(obj); } @@ -280,8 +433,7 @@ where /// For single dependencies, the destructured tuple needs to end with a /// comma: `(dep,)`. pub fn transient(&self, ctor: fn(Deps) -> T) { - self.registry.objects.write().insert( - TypeId::of::(), + let transient = Object::Transient(Box::new(move |this| -> Option { #[allow(clippy::option_if_let_else)] match Deps::build( @@ -292,23 +444,71 @@ where Some(obj) => Some(Box::new(obj)), None => None, } - })), - ); - self.registry.validation.write().insert( - TypeId::of::(), - Box::new(|registry: &Registry| { - let type_ids = - Deps::as_typeids(dependency_builder::private::SealToken); - type_ids.iter().all(|el| { - if let Some(validator) = registry.validation.read().get(el) - { - return (validator)(registry); - } + })); + { + let mut lock = self.registry.objects.write(); + lock.insert(TypeId::of::(), transient); + } - false - }) - }), - ); + let validator: Validator = Box::new(|registry: &Registry| { + let type_ids = + Deps::as_typeids(dependency_builder::private::SealToken); + type_ids.iter().all(|el| { + if let Some(validator) = registry.validation.read().get(el) { + return (validator)(registry); + } + + false + }) + }); + + { + let mut lock = self.registry.validation.write(); + lock.insert(TypeId::of::(), validator); + } + } + + #[cfg(feature = "tokio")] + pub async fn transient_async( + &self, + ctor: fn( + Deps, + ) + -> Box + Send + Sync>, + ) { + let transient = + Object::Transient(Box::new(move |this| -> Option { + #[allow(clippy::option_if_let_else)] + match Deps::build( + this, + ctor, + dependency_builder::private::SealToken, + ) { + Some(obj) => Some(Box::new(obj)), + None => None, + } + })); + { + let mut lock = self.registry.objects.write(); + lock.insert(TypeId::of::(), transient); + } + + let validator: Validator = Box::new(|registry: &Registry| { + let type_ids = + Deps::as_typeids(dependency_builder::private::SealToken); + type_ids.iter().all(|el| { + if let Some(validator) = registry.validation.read().get(el) { + return (validator)(registry); + } + + false + }) + }); + + { + let mut lock = self.registry.validation.write(); + lock.insert(TypeId::of::(), validator); + } } /// Register a new singleton object, with dependencies specified in @@ -355,25 +555,27 @@ where } }, ); - self.registry.objects.write().insert( - TypeId::of::(), - Object::Singleton(getter, OnceCell::new()), - ); - self.registry.validation.write().insert( - TypeId::of::(), - Box::new(|registry: &Registry| { - let type_ids = - Deps::as_typeids(dependency_builder::private::SealToken); - type_ids.iter().all(|el| { - if let Some(validator) = registry.validation.read().get(el) - { - return (validator)(registry); - } + let singleton = Object::Singleton(getter, OnceCell::new()); + { + let mut lock = self.registry.objects.write(); + lock.insert(TypeId::of::(), singleton); + } - false - }) - }), - ); + let validator: Validator = Box::new(|registry: &Registry| { + let type_ids = + Deps::as_typeids(dependency_builder::private::SealToken); + type_ids.iter().all(|el| { + if let Some(validator) = registry.validation.read().get(el) { + return (validator)(registry); + } + + false + }) + }); + { + let mut lock = self.registry.validation.write(); + lock.insert(TypeId::of::(), validator); + } } } diff --git a/ferrunix-core/src/types.rs b/ferrunix-core/src/types.rs index cb314b4..ce2bf1e 100644 --- a/ferrunix-core/src/types.rs +++ b/ferrunix-core/src/types.rs @@ -5,6 +5,11 @@ dead_code )] +#[cfg(all(feature = "tokio", not(feature = "multithread")))] +compile_error!( + "the `tokio` feature can only be enabled if `multithread` is also enabled." +); + /// Types that are enabled when the `multithread` feature is set. #[cfg(feature = "multithread")] mod sync { @@ -38,6 +43,37 @@ mod sync { Box Option + Send + Sync>; pub(crate) type Validator = Box bool + Send + Sync>; + #[cfg(feature = "tokio")] + mod tokio_ext { + use super::*; + use std::future::Future; + + // `RwLock` types. + pub(crate) type AsyncRwLock = ::tokio::sync::RwLock; + + pub(crate) type AsyncBoxedCtor = Box< + dyn Fn( + &Registry, + ) -> std::pin::Pin< + Box> + Send + Sync>, + > + Send + + Sync, + >; + pub(crate) type AsyncBoxedSingletonGetter = Box< + dyn Fn( + &Registry, + &Ref, + ) -> std::pin::Pin< + Box> + Send + Sync>, + > + Send + + Sync, + >; + pub(crate) type AsyncSingletonCell = ::tokio::sync::OnceCell; + } + + #[cfg(feature = "tokio")] + pub(crate) use tokio_ext::*; + /// A generic reference type that's used as the default type for types with /// the singleton lifetime. /// diff --git a/ferrunix/Cargo.toml b/ferrunix/Cargo.toml index 0c501e0..5d36043 100644 --- a/ferrunix/Cargo.toml +++ b/ferrunix/Cargo.toml @@ -17,9 +17,10 @@ categories.workspace = true workspace = true [features] -default = ["multithread", "derive"] +default = ["multithread", "derive", "tokio"] multithread = ["ferrunix-core/multithread"] derive = ["dep:ferrunix-macros"] +tokio = ["ferrunix-core/tokio"] [dependencies] ferrunix-core = { path = "../ferrunix-core", default-features = false, version = "=0.2.0" }