diff --git a/Cargo.lock b/Cargo.lock index ce3dbba..e032965 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,6 +23,15 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" +[[package]] +name = "atomic" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d818003e740b63afc82337e3160717f4f63078720a810b7b903e70a5d1d2994" +dependencies = [ + "bytemuck", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -41,6 +50,12 @@ version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" +[[package]] +name = "bytemuck" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" + [[package]] name = "cast" version = "0.3.0" @@ -406,6 +421,7 @@ dependencies = [ name = "poc-kokkos-rs" version = "0.1.0" dependencies = [ + "atomic", "cfg-if", "criterion", "cxx", diff --git a/Cargo.toml b/Cargo.toml index a8185cd..6e87c64 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ rayon = ["dep:rayon"] cxx = "*" cfg-if = "*" rayon = {version = "*", optional=true} +atomic = "0.6.0" [dev-dependencies] criterion = { version = "*", features = ["html_reports"] } diff --git a/src/view/mod.rs b/src/view/mod.rs index 5ebd148..a70c6e9 100644 --- a/src/view/mod.rs +++ b/src/view/mod.rs @@ -44,6 +44,8 @@ pub struct ViewBase<'a, const N: usize, T> { pub stride: [usize; N], } +#[cfg(not(any(feature = "rayon", feature = "threads")))] +// ~~~~~~~~ Constructors impl<'a, const N: usize, T> ViewBase<'a, N, T> where T: Default + Clone, // fair assumption imo @@ -122,6 +124,27 @@ where stride: self.stride, } } +} + +impl<'a, const N: usize, T> ViewBase<'a, N, T> { + // ~~~~~~~~ Uniform writing interface across all features + + #[inline(always)] + #[cfg(not(any(feature = "rayon", feature = "threads")))] + /// Serial writing interface. Uses mutable indexing implementation. + pub fn set(&mut self, index: [usize; N], val: T) { + self[index] = val; + } + + #[inline(always)] + #[cfg(any(feature = "rayon", feature = "threads"))] + /// Thread-safe writing interface. Uses non-mutable indexing and + /// immutability of atomic type methods. + pub fn set(&self, index: [usize; N], val: T) { + self[index].store(val, atomic::Ordering::Relaxed); + } + + // ~~~~~~~~ Convenience pub fn raw_val<'b>(self) -> Result, ViewError<'b>> { if let DataType::Owned(v) = self.data { @@ -143,10 +166,8 @@ where } } -impl<'a, const N: usize, T> Index<[usize; N]> for ViewBase<'a, N, T> -where - T: Default + Clone, -{ +/// Read-only access is always implemented. +impl<'a, const N: usize, T> Index<[usize; N]> for ViewBase<'a, N, T> { type Output = T; fn index(&self, index: [usize; N]) -> &Self::Output { @@ -168,10 +189,10 @@ where } } -impl<'a, const N: usize, T> IndexMut<[usize; N]> for ViewBase<'a, N, T> -where - T: Default + Clone, -{ +#[cfg(not(any(feature = "rayon", feature = "threads")))] +/// Read-write access is implemented using [IndexMut] trait when no parallel +/// features are enabled. +impl<'a, const N: usize, T> IndexMut<[usize; N]> for ViewBase<'a, N, T> { fn index_mut(&mut self, index: [usize; N]) -> &mut Self::Output { let flat_idx: usize = self.flat_idx(index); match &mut self.data { diff --git a/src/view/parameters.rs b/src/view/parameters.rs index d61a38e..ff5dfe0 100644 --- a/src/view/parameters.rs +++ b/src/view/parameters.rs @@ -13,20 +13,29 @@ //! - Memory traits? //! +#[cfg(any(feature = "rayon", feature = "threads"))] +use atomic::Atomic; + /// Maximum possible depth (i.e. number of dimensions) for a view. pub const MAX_VIEW_DEPTH: usize = 8; +#[cfg(not(any(feature = "rayon", feature = "threads")))] +pub type InnerDataType = T; + +#[cfg(any(feature = "rayon", feature = "threads"))] +pub type InnerDataType = Atomic; + #[derive(Debug)] /// Enum used to identify the type of data the view is holding. See variants for more /// information. The policy used to implement the [PartialEq] trait is based on Kokkos' /// [`equal` algorithm](https://kokkos.github.io/kokkos-core-wiki/API/algorithms/std-algorithms/all/StdEqual.html). pub enum DataType<'a, T> { /// The view owns the data. - Owned(Vec), + Owned(Vec>), /// The view borrows the data and can only read it. - Borrowed(&'a [T]), + Borrowed(&'a [InnerDataType]), /// The view borrows the data and can both read and modify it. - MutBorrowed(&'a mut [T]), + MutBorrowed(&'a mut [InnerDataType]), } /// Kokkos implements equality check by comparing the pointers, i.e.