Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Outer product implementation #690

Closed
wants to merge 9 commits into from
124 changes: 124 additions & 0 deletions src/linalg/impl_linalg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::numeric_util;
use crate::{LinalgScalar, Zip};

use std::any::TypeId;
use std::mem::MaybeUninit;

#[cfg(feature = "blas")]
use std::cmp;
Expand Down Expand Up @@ -823,3 +824,126 @@ mod blas_tests {
assert!(blas_column_major_2d::<f32, _>(&m));
}
}

#[allow(dead_code)]
fn general_outer_to_dyn<Sa, Sb, F, T>(
a: &ArrayBase<Sa, IxDyn>,
b: &ArrayBase<Sb, IxDyn>,
mut f: F,
) -> ArrayD<T>
where
T: Copy,
Sa: Data<Elem = T>,
Sb: Data<Elem = T>,
F: FnMut(T, T) -> T,
{
//Iterators on the shapes, compelted by 1s
let a_shape_iter = a.shape().iter().chain([1].iter().cycle());
let b_shape_iter = b.shape().iter().chain([1].iter().cycle());

let res_ndim = std::cmp::max(a.ndim(), b.ndim());
let res_dim: Vec<Ix> = a_shape_iter
.zip(b_shape_iter)
.take(res_ndim)
.map(|(x, y)| x * y)
.collect();

let mut res: ArrayD<MaybeUninit<T>> = ArrayBase::maybe_uninit(res_dim);
let res_chunks = res.exact_chunks_mut(b.shape());
Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| {
Zip::from(b)
.apply_assign_into(res_chunk, |&b_elem| MaybeUninit::new(f(a_elem, b_elem)))
});
unsafe { res.assume_init() }
}

#[allow(dead_code, clippy::type_repetition_in_bounds)]
fn kron_to_dyn<Sa, Sb, T>(a: &ArrayBase<Sa, IxDyn>, b: &ArrayBase<Sb, IxDyn>) -> Array<T, IxDyn>
where
T: Copy,
Sa: Data<Elem = T>,
Sb: Data<Elem = T>,
T: crate::ScalarOperand + std::ops::Mul<Output = T>,
{
general_outer_to_dyn(a, b, std::ops::Mul::mul)
}

#[allow(dead_code)]
fn general_outer_same_size<Sa, I, Sb, F, T>(
a: &ArrayBase<Sa, I>,
b: &ArrayBase<Sb, I>,
mut f: F,
) -> Array<T, I>
where
T: Copy,
Sa: Data<Elem = T>,
Sb: Data<Elem = T>,
I: Dimension,
F: FnMut(T, T) -> T,
{
let mut res_dim = a.raw_dim();
let mut res_dim_view = res_dim.as_array_view_mut();
res_dim_view *= &b.raw_dim().as_array_view();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Obviously I kind of love this, that we can use array methods on dimensions 🙂.

The unsafe code looks good, we only need to prove that all elements are assigned to, and they will be if the exact chunks don't leave any uneven remainder. And that looks good to me, B's shape evenly divides the result's shape, right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah no, we also have to check the product. We need to use a saturating multiplication, otherwise we can overflow, and then the above doesn't hold?


let mut res: Array<MaybeUninit<T>, I> = ArrayBase::maybe_uninit(res_dim);
let res_chunks = res.exact_chunks_mut(b.raw_dim());
Zip::from(res_chunks).and(a).apply(|res_chunk, &a_elem| {
Zip::from(b)
.apply_assign_into(res_chunk, |&b_elem| MaybeUninit::new(f(a_elem, b_elem)))
});
unsafe { res.assume_init() }
}

#[allow(dead_code, clippy::type_repetition_in_bounds)]
fn kron_same_size<Sa, I, Sb, T>(a: &ArrayBase<Sa, I>, b: &ArrayBase<Sb, I>) -> Array<T, I>
where
T: Copy,
Sa: Data<Elem = T>,
Sb: Data<Elem = T>,
I: Dimension,
T: crate::ScalarOperand + std::ops::Mul<Output = T>,
{
general_outer_same_size(a, b, std::ops::Mul::mul)
}

#[cfg(test)]
mod kron_test {
use super::*;

#[test]
fn test_same_size() {
let a = array![
[[1, 2, 3], [4, 5, 6]],
[[17, 42, 69], [0, -1, 1]],
[[1337, 1, 0], [-1337, -1, 0]]
];
let b = array![
[[55, 66, 77], [88, 99, 1010]],
[[42, 42, 0], [1, -3, 10]],
[[110, 0, 7], [523, 21, -12]]
];
let res1 = kron_same_size(&a, &b);
let res2 = kron_to_dyn(&a.clone().into_dyn(), &b.clone().into_dyn());
assert_eq!(res1.clone().into_dyn(), res2);
for a0 in 0..a.len_of(Axis(0)) {
for a1 in 0..a.len_of(Axis(1)) {
for a2 in 0..a.len_of(Axis(2)) {
for b0 in 0..b.len_of(Axis(0)) {
for b1 in 0..b.len_of(Axis(1)) {
for b2 in 0..b.len_of(Axis(2)) {
assert_eq!(
res2[[
b.shape()[0] * a0 + b0,
b.shape()[1] * a1 + b1,
b.shape()[2] * a2 + b2
]],
a[[a0, a1, a2]] * b[[b0, b1, b2]]
)
}
}
}
}
}
}
}
}