Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vadixidav committed Aug 29, 2024
1 parent 30a004d commit 9665475
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#![no_std]
// #![no_std]

#[cfg(feature = "nalgebra")]
mod tonalgebra;
Expand Down
32 changes: 17 additions & 15 deletions src/tonalgebra/ndarray_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use nalgebra::Dyn;
/// ```
/// use nshare::IntoNalgebra;
///
/// let arr = ndarray::arr1(&[0.1, 0.2, 0.3, 0.4]);
/// let arr = ndarray::array![0.1, 0.2, 0.3, 0.4];
/// let m = arr.view().into_nalgebra();
/// assert!(m.iter().eq(&[0.1, 0.2, 0.3, 0.4]));
/// assert_eq!(m.shape(), (4, 1));
Expand All @@ -35,7 +35,7 @@ where
/// ```
/// use nshare::IntoNalgebra;
///
/// let mut arr = ndarray::arr1(&[0.1, 0.2, 0.3, 0.4]);
/// let mut arr = ndarray::array![0.1, 0.2, 0.3, 0.4];
/// let m = arr.view_mut().into_nalgebra();
/// assert!(m.iter().eq(&[0.1, 0.2, 0.3, 0.4]));
/// assert_eq!(m.shape(), (4, 1));
Expand Down Expand Up @@ -63,7 +63,7 @@ where
/// ```
/// use nshare::IntoNalgebra;
///
/// let arr = ndarray::arr1(&[0.1, 0.2, 0.3, 0.4]);
/// let arr = ndarray::array![0.1, 0.2, 0.3, 0.4];
/// let m = arr.into_nalgebra();
/// assert!(m.iter().eq(&[0.1, 0.2, 0.3, 0.4]));
/// assert_eq!(m.shape(), (4, 1));
Expand All @@ -87,12 +87,12 @@ where
/// ```
/// use nshare::IntoNalgebra;
///
/// let arr = ndarray::arr2(&[
/// let arr = ndarray::array![
/// [0.1, 0.2, 0.3, 0.4],
/// [0.5, 0.6, 0.7, 0.8],
/// [1.1, 1.2, 1.3, 1.4],
/// [1.5, 1.6, 1.7, 1.8],
/// ]);
/// ];
/// let m = arr.view().into_nalgebra();
/// assert!(m.row(1).iter().eq(&[0.5, 0.6, 0.7, 0.8]));
/// assert_eq!(m.shape(), (4, 4));
Expand All @@ -107,9 +107,10 @@ where
let nrows = Dyn(self.nrows());
let ncols = Dyn(self.ncols());
let ptr = self.as_ptr();
let stride_row: usize = TryFrom::try_from(self.strides()[0]).expect("Negative row stride");
let stride_col: usize =
TryFrom::try_from(self.strides()[1]).expect("Negative column stride");
let stride_row: usize = TryFrom::try_from(self.strides()[0])
.expect("can only convert positive row stride to nalgebra");
let stride_col: usize = TryFrom::try_from(self.strides()[1])
.expect("can only convert positive col stride to nalgebra");
let storage = unsafe {
nalgebra::ViewStorage::from_raw_parts(
ptr,
Expand All @@ -124,12 +125,12 @@ where
/// ```
/// use nshare::IntoNalgebra;
///
/// let mut arr = ndarray::arr2(&[
/// let mut arr = ndarray::array![
/// [0.1, 0.2, 0.3, 0.4],
/// [0.5, 0.6, 0.7, 0.8],
/// [1.1, 1.2, 1.3, 1.4],
/// [1.5, 1.6, 1.7, 1.8],
/// ]);
/// ];
/// let m = arr.view_mut().into_nalgebra();
/// assert!(m.row(1).iter().eq(&[0.5, 0.6, 0.7, 0.8]));
/// assert_eq!(m.shape(), (4, 4));
Expand All @@ -143,9 +144,10 @@ where
fn into_nalgebra(mut self) -> Self::Out {
let nrows = Dyn(self.nrows());
let ncols = Dyn(self.ncols());
let stride_row: usize = TryFrom::try_from(self.strides()[0]).expect("Negative row stride");
let stride_col: usize =
TryFrom::try_from(self.strides()[1]).expect("Negative column stride");
let stride_row: usize = TryFrom::try_from(self.strides()[0])
.expect("can only convert positive row stride to nalgebra");
let stride_col: usize = TryFrom::try_from(self.strides()[1])
.expect("can only convert positive col stride to nalgebra");
let ptr = self.as_mut_ptr();
let storage = unsafe {
nalgebra::ViewStorageMut::from_raw_parts(
Expand All @@ -161,12 +163,12 @@ where
/// ```
/// use nshare::IntoNalgebra;
///
/// let mut arr = ndarray::arr2(&[
/// let mut arr = ndarray::array![
/// [0.1, 0.2, 0.3, 0.4],
/// [0.5, 0.6, 0.7, 0.8],
/// [1.1, 1.2, 1.3, 1.4],
/// [1.5, 1.6, 1.7, 1.8],
/// ]);
/// ];
/// let m = arr.clone().into_nalgebra();
/// assert!(m.row(1).iter().eq(&[0.5, 0.6, 0.7, 0.8]));
/// assert_eq!(m.shape(), (4, 4));
Expand Down
16 changes: 16 additions & 0 deletions tests/nalgebra.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use nshare::IntoNalgebra;

#[test]
fn single_row_ndarray_to_nalgebra() {
let mut arr = ndarray::array![[0.1, 0.2, 0.3, 0.4]];
let m = arr.view_mut().into_nalgebra();
assert!(m.row(0).iter().eq(&[0.1, 0.2, 0.3, 0.4]));
assert_eq!(m.shape(), (1, 4));
assert!(arr
.view_mut()
.reversed_axes()
.into_nalgebra()
.column(0)
.iter()
.eq(&[0.1, 0.2, 0.3, 0.4]));
}

0 comments on commit 9665475

Please sign in to comment.