Skip to content

Commit

Permalink
Cleaned up slicing syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed May 3, 2024
1 parent dfd40ed commit bb9552e
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 93 deletions.
14 changes: 14 additions & 0 deletions crates/luminal_metal/src/tests/fp16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,20 @@ fn test_movement() {
assert_exact(&c.data(), &d_c.as_vec());
}

#[test]
fn test_slice_add() {
let mut cx = Graph::new();
let a = cx.tensor().set(random_array::<256>());
let mut b = (a.slice(0..64) + a.slice(64..128) + a.slice(128..192) + a.slice(192..256))
.realize::<R1<64>>()
.expand::<R2<4, 64>, _>()
.retrieve();

cx.compile(MetalCompiler::<f16>::default(), &mut b);
cx.execute();
cx.display();
}

#[test]
fn test_conv2d() {
let mut cx = Graph::new();
Expand Down
2 changes: 1 addition & 1 deletion crates/luminal_metal/src/tests/fp32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ fn test_conv2d() {
);
cx.execute();

assert_close_precision(
assert_close(
&out1.data(),
&[
3.9600, -0.3300, -1.7800, 4.0400, 1.5300, 0.2900, 2.8700, 3.0000, 0.9600, -1.8700,
Expand Down
2 changes: 1 addition & 1 deletion crates/luminal_nn/src/convolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ mod tests {

exp_out1.retrieve();

let model: Conv2D<CH_IN, CH_OUT, KERNELX, KERNELY> = Conv2D::initialize(&mut cx);
let model = Conv2D::<CH_IN, CH_OUT, KERNELX, KERNELY>::initialize(&mut cx);
model.weight.set(vec![
0.1600, 0.2000, 0.1900, -0.1100, 0.0100, -0.0300, -0.1200, -0.0800, -0.1300, -0.0300,
0.1600, -0.1700, -0.0000, 0.1900, 0.1300, 0.0300, -0.1500, 0.0900, 0.0100, 0.0200,
Expand Down
171 changes: 88 additions & 83 deletions src/shape/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,40 +20,94 @@ fn get_end_bound<D: Into<Expression> + Copy, S: Into<Expression>>(
}
}

fn dim_to_size(r: Expression) -> usize {
r.to_usize().unwrap_or(i32::MAX as usize)
}

pub trait RangeToDim<D: Dimension> {
type Dimension: Dimension;
fn bounds(&self, size: impl Into<Expression>) -> (Expression, Expression);
}

impl<D: Dimension> RangeToDim<D> for RangeFrom<usize> {
type Dimension = Dyn<'-'>;
fn bounds(&self, size: impl Into<Expression>) -> (Expression, Expression) {
(
get_start_bound(self.start_bound()),
get_end_bound(self.end_bound(), size),
)
}
}
impl<D: Dimension> RangeToDim<D> for RangeTo<usize> {
type Dimension = Dyn<'-'>;
fn bounds(&self, size: impl Into<Expression>) -> (Expression, Expression) {
(
get_start_bound(self.start_bound()),
get_end_bound(self.end_bound(), size),
)
}
}
impl<D: Dimension> RangeToDim<D> for RangeToInclusive<usize> {
type Dimension = Dyn<'-'>;
fn bounds(&self, size: impl Into<Expression>) -> (Expression, Expression) {
(
get_start_bound(self.start_bound()),
get_end_bound(self.end_bound(), size),
)
}
}
impl<D: Dimension> RangeToDim<D> for Range<usize> {
type Dimension = Dyn<'-'>;
fn bounds(&self, size: impl Into<Expression>) -> (Expression, Expression) {
(
get_start_bound(self.start_bound()),
get_end_bound(self.end_bound(), size),
)
}
}
impl<D: Dimension> RangeToDim<D> for RangeFrom<Expression> {
type Dimension = Dyn<'-'>;
fn bounds(&self, size: impl Into<Expression>) -> (Expression, Expression) {
(
get_start_bound(self.start_bound()),
get_end_bound(self.end_bound(), size),
)
}
}
impl<D: Dimension> RangeToDim<D> for RangeTo<Expression> {
type Dimension = Dyn<'-'>;
fn bounds(&self, size: impl Into<Expression>) -> (Expression, Expression) {
(
get_start_bound(self.start_bound()),
get_end_bound(self.end_bound(), size),
)
}
}
impl<D: Dimension> RangeToDim<D> for RangeToInclusive<Expression> {
type Dimension = Dyn<'-'>;
fn bounds(&self, size: impl Into<Expression>) -> (Expression, Expression) {
(
get_start_bound(self.start_bound()),
get_end_bound(self.end_bound(), size),
)
}
}
impl<D: Dimension> RangeToDim<D> for Range<Expression> {
type Dimension = Dyn<'-'>;
fn bounds(&self, size: impl Into<Expression>) -> (Expression, Expression) {
(
get_start_bound(self.start_bound()),
get_end_bound(self.end_bound(), size),
)
}
}
impl<D: Dimension> RangeToDim<D> for RangeFull {
type Dimension = D;
fn bounds(&self, size: impl Into<Expression>) -> (Expression, Expression) {
(0.into(), size.into())
}
}
impl<D: Dimension, R: RangeToDim<D>> RangeToDim<D> for (R,) {
type Dimension = R::Dimension;
fn bounds(&self, size: impl Into<Expression>) -> (Expression, Expression) {
self.0.bounds(size)
}
}

pub trait SliceOfShape<S: Shape> {
Expand All @@ -68,34 +122,21 @@ impl SliceOfShape<R0> for () {
}
}

impl<A: Dimension, R: RangeBounds<Expression> + RangeToDim<A>> SliceOfShape<(A,)> for (R,) {
impl<A: Dimension, R: RangeToDim<A>> SliceOfShape<(A,)> for R {
type OutputShape = (R::Dimension,);
fn to_range_vec(&self) -> Vec<(Expression, Expression)> {
vec![(
get_start_bound(self.0.start_bound()),
get_end_bound(self.0.end_bound(), dim_to_size(A::const_size())),
)]
vec![self.bounds(A::const_size())]
}
}

impl<
A: Dimension,
B: Dimension,
R1: RangeBounds<Expression> + RangeToDim<A>,
R2: RangeBounds<Expression> + RangeToDim<B>,
> SliceOfShape<(A, B)> for (R1, R2)
impl<A: Dimension, B: Dimension, R1: RangeToDim<A>, R2: RangeToDim<B>> SliceOfShape<(A, B)>
for (R1, R2)
{
type OutputShape = (R1::Dimension, R2::Dimension);
fn to_range_vec(&self) -> Vec<(Expression, Expression)> {
vec![
(
get_start_bound(self.0.start_bound()),
get_end_bound(self.0.end_bound(), dim_to_size(A::const_size())),
),
(
get_start_bound(self.1.start_bound()),
get_end_bound(self.1.end_bound(), dim_to_size(B::const_size())),
),
self.0.bounds(A::const_size()),
self.1.bounds(B::const_size()),
]
}
}
Expand All @@ -104,26 +145,17 @@ impl<
A: Dimension,
B: Dimension,
C: Dimension,
R1: RangeBounds<Expression> + RangeToDim<A>,
R2: RangeBounds<Expression> + RangeToDim<B>,
R3: RangeBounds<Expression> + RangeToDim<C>,
R1: RangeToDim<A>,
R2: RangeToDim<B>,
R3: RangeToDim<C>,
> SliceOfShape<(A, B, C)> for (R1, R2, R3)
{
type OutputShape = (R1::Dimension, R2::Dimension, R3::Dimension);
fn to_range_vec(&self) -> Vec<(Expression, Expression)> {
vec![
(
get_start_bound(self.0.start_bound()),
get_end_bound(self.0.end_bound(), dim_to_size(A::const_size())),
),
(
get_start_bound(self.1.start_bound()),
get_end_bound(self.1.end_bound(), dim_to_size(B::const_size())),
),
(
get_start_bound(self.2.start_bound()),
get_end_bound(self.2.end_bound(), dim_to_size(C::const_size())),
),
self.0.bounds(A::const_size()),
self.1.bounds(B::const_size()),
self.2.bounds(C::const_size()),
]
}
}
Expand All @@ -133,31 +165,19 @@ impl<
B: Dimension,
C: Dimension,
D: Dimension,
R1: RangeBounds<Expression> + RangeToDim<A>,
R2: RangeBounds<Expression> + RangeToDim<B>,
R3: RangeBounds<Expression> + RangeToDim<C>,
R4: RangeBounds<Expression> + RangeToDim<C>,
R1: RangeToDim<A>,
R2: RangeToDim<B>,
R3: RangeToDim<C>,
R4: RangeToDim<C>,
> SliceOfShape<(A, B, C, D)> for (R1, R2, R3, R4)
{
type OutputShape = (R1::Dimension, R2::Dimension, R3::Dimension, R4::Dimension);
fn to_range_vec(&self) -> Vec<(Expression, Expression)> {
vec![
(
get_start_bound(self.0.start_bound()),
get_end_bound(self.0.end_bound(), dim_to_size(A::const_size())),
),
(
get_start_bound(self.1.start_bound()),
get_end_bound(self.1.end_bound(), dim_to_size(B::const_size())),
),
(
get_start_bound(self.2.start_bound()),
get_end_bound(self.2.end_bound(), dim_to_size(C::const_size())),
),
(
get_start_bound(self.3.start_bound()),
get_end_bound(self.3.end_bound(), dim_to_size(D::const_size())),
),
self.0.bounds(A::const_size()),
self.1.bounds(B::const_size()),
self.2.bounds(C::const_size()),
self.3.bounds(D::const_size()),
]
}
}
Expand All @@ -168,11 +188,11 @@ impl<
C: Dimension,
D: Dimension,
E: Dimension,
R1: RangeBounds<Expression> + RangeToDim<A>,
R2: RangeBounds<Expression> + RangeToDim<B>,
R3: RangeBounds<Expression> + RangeToDim<C>,
R4: RangeBounds<Expression> + RangeToDim<C>,
R5: RangeBounds<Expression> + RangeToDim<C>,
R1: RangeToDim<A>,
R2: RangeToDim<B>,
R3: RangeToDim<C>,
R4: RangeToDim<C>,
R5: RangeToDim<C>,
> SliceOfShape<(A, B, C, D, E)> for (R1, R2, R3, R4, R5)
{
type OutputShape = (
Expand All @@ -184,26 +204,11 @@ impl<
);
fn to_range_vec(&self) -> Vec<(Expression, Expression)> {
vec![
(
get_start_bound(self.0.start_bound()),
get_end_bound(self.0.end_bound(), dim_to_size(A::const_size())),
),
(
get_start_bound(self.1.start_bound()),
get_end_bound(self.1.end_bound(), dim_to_size(B::const_size())),
),
(
get_start_bound(self.2.start_bound()),
get_end_bound(self.2.end_bound(), dim_to_size(C::const_size())),
),
(
get_start_bound(self.3.start_bound()),
get_end_bound(self.3.end_bound(), dim_to_size(D::const_size())),
),
(
get_start_bound(self.4.start_bound()),
get_end_bound(self.4.end_bound(), dim_to_size(E::const_size())),
),
self.0.bounds(A::const_size()),
self.1.bounds(B::const_size()),
self.2.bounds(C::const_size()),
self.3.bounds(D::const_size()),
self.4.bounds(E::const_size()),
]
}
}
24 changes: 16 additions & 8 deletions src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,22 @@ pub fn assert_exact<T: PartialEq + Debug>(a_vec: &[T], b_vec: &[T]) {
}
}

pub fn random_array<const N: usize>() -> [f32; N] {
let mut rng = thread_rng();
random_array_rng(&mut rng)
}

pub fn random_array_rng<const N: usize, R: Rng>(rng: &mut R) -> [f32; N] {
let mut arr = [0.; N];
for i in &mut arr {
*i = rng.gen_range(-0.5..0.5);
}
arr
}

pub fn random_vec(n: usize) -> Vec<f32> {
let mut rng = thread_rng();
(0..n).map(|_| rng.gen_range(-0.5..0.5)).collect()
random_vec_rng(n, &mut rng)
}

pub fn random_vec_rng<R: Rng>(n: usize, rng: &mut R) -> Vec<f32> {
Expand All @@ -127,13 +140,8 @@ macro_rules! test_imports {
Axis as LAxis, Const as LConst, *,
},
tests::{
assert_close,
assert_close_precision,
assert_exact,
// harness::{test_compilers_close, test_compilers_exact},
random_vec,
random_vec_rng,
test_graphs,
assert_close, assert_close_precision, assert_exact, random_array, random_array_rng,
random_vec, random_vec_rng, test_graphs,
},
};
};
Expand Down

0 comments on commit bb9552e

Please sign in to comment.