Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor nn: helpers and enum #592

Merged
merged 4 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/operators/nn.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ mod core;
mod implementations;
mod functional;
mod common;
mod helpers;

use orion::operators::nn::common::{AUTO_PAD, POOLING_TYPE};
use orion::operators::nn::common::{AUTO_PAD, MODE, PADDING_MODE, POOLING_TYPE};

use orion::operators::nn::core::NNTrait;

Expand Down
14 changes: 14 additions & 0 deletions src/operators/nn/common.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,17 @@ enum POOLING_TYPE {
LPPOOL,
MAX,
}

#[derive(Copy, Drop)]
enum MODE {
NEAREST,
LINEAR,
CUBIC,
}

#[derive(Copy, Drop)]
enum PADDING_MODE {
ZEROS,
BORDER,
REFLECTION,
}
8 changes: 4 additions & 4 deletions src/operators/nn/core.cairo
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use orion::operators::tensor::core::Tensor;
use orion::operators::nn::AUTO_PAD;
use orion::operators::nn::{AUTO_PAD, MODE, PADDING_MODE};

/// Trait
///
Expand Down Expand Up @@ -1087,7 +1087,7 @@ trait NNTrait<T> {
X: @Tensor<T>,
W: @Tensor<T>,
B: Option<@Tensor<T>>,
auto_pad: Option<orion::operators::nn::functional::conv_transpose::AUTO_PAD>,
auto_pad: Option<AUTO_PAD>,
dilations: Option<Span<usize>>,
group: Option<usize>,
kernel_shape: Option<Span<usize>>,
Expand Down Expand Up @@ -1302,8 +1302,8 @@ trait NNTrait<T> {
X: @Tensor<T>,
grid: @Tensor<T>,
align_corner: Option<usize>,
mode: Option<orion::operators::nn::functional::grid_sample::MODE>,
padding_mode: Option<orion::operators::nn::functional::grid_sample::PADDING_MODE>,
mode: Option<MODE>,
padding_mode: Option<PADDING_MODE>,
) -> Tensor<T>;
///
/// # NNTrait::max_pool
Expand Down
52 changes: 10 additions & 42 deletions src/operators/nn/functional/col2im.cairo
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use orion::numbers::NumberTrait;
use orion::operators::tensor::core::{stride};
use orion::operators::tensor::core::{stride, unravel_index};
use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor,};
use orion::operators::vec::{NullableVec, NullableVecImpl};
use orion::operators::nn::helpers::{is_out, prod};

fn col2im<T, MAG, +TensorTrait<T>, +NumberTrait<T, MAG>, +Copy<T>, +Drop<T>, +Add<T>, +Mul<T>,>(
fn col2im<T, MAG, +TensorTrait<T>, +NumberTrait<T, MAG>, +Copy<T>, +Drop<T>, +Add<T>, +MulEq<T>,>(
data: @Tensor<T>,
image_shape: Span<usize>,
block_shape: Span<usize>,
Expand Down Expand Up @@ -53,7 +54,7 @@ fn col2im<T, MAG, +TensorTrait<T>, +NumberTrait<T, MAG>, +Copy<T>, +Drop<T>, +Ad
},
};

let bl = prod(block_shape, 0);
let bl = prod(block_shape);
let C = *(*data).shape.at(1) / bl;

let mut new_shape: Array<i32> = array![
Expand Down Expand Up @@ -158,15 +159,15 @@ fn col2im_naive_implementation<
let mut data_im = NullableVecImpl::new();
data_im.set(*image_shape.at(0) * *stride_img.at(0) - 1, NumberTrait::zero());

let kernel_size = prod(kernel_shape, 0);
let col_size = prod(dim_col, 0);
let kernel_size = prod(kernel_shape);
let col_size = prod(dim_col);
let mut c_col = 0;
while c_col != kernel_size {
let offset = get_indices(c_col, kernel_shape).span();
let offset = unravel_index(c_col, kernel_shape);

let mut col = 0;
while col != col_size {
let ind_col = get_indices(col, dim_col).span();
let ind_col = unravel_index(col, dim_col);
let mut ind_im: Array<usize> = array![];
let mut i = 0;
while i != n_dims {
Expand Down Expand Up @@ -218,7 +219,7 @@ fn col2im_shape_check<T, +TensorTrait<T>, +Copy<T>, +Drop<T>,>(
) {
let n_input_plane = *(*X).shape.at(0);

let kernel_size = prod(kernel_shape, 0);
let kernel_size = prod(kernel_shape);

assert(n_input_plane % kernel_size == 0, 'wrong input dimension');

Expand All @@ -240,7 +241,7 @@ fn col2im_shape_check<T, +TensorTrait<T>, +Copy<T>, +Drop<T>,>(
i += 1;
};

let block_size = prod(n_blocks.span(), 0);
let block_size = prod(n_blocks.span());

assert(input_length == block_size, 'input_length != block_size');
}
Expand All @@ -267,36 +268,3 @@ fn get_indices(index: usize, shape: Span<usize>,) -> Array<usize> {

new_res
}

fn is_out(ind: Span<usize>, shape: Span<usize>,) -> bool {
let mut n = 0;
let is_out = loop {
if n == ind.len() {
break false;
}
let s = *shape.at(n);
let i = *ind.at(n);
if i < 0 {
break true;
}
if i >= s {
break true;
}
n += 1;
};

is_out
}

fn prod<T, MAG, +Drop<T>, +Copy<T>, +NumberTrait<T, MAG>, +TensorTrait<T>, +Mul<T>,>(
pA: Span<T>, start: usize
) -> T {
let mut i = start;
let mut prod = NumberTrait::one();
while i != pA.len() {
prod = prod * (*pA.at(i));
i += 1;
};

prod
}
163 changes: 3 additions & 160 deletions src/operators/nn/functional/conv.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use orion::numbers::{U32IntoI32, I32IntoU32, I32Div, I32Number};
use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor,};
use orion::operators::vec::{NullableVec, NullableVecImpl};
use orion::operators::tensor::core::{stride};

use orion::operators::nn::helpers::{cartesian, arange, max_in_tensor, min_in_tensor, dot};
use orion::operators::nn::AUTO_PAD;


Expand Down Expand Up @@ -230,7 +230,8 @@ fn conv<
}

// group == 1
if *dilations.at(0) != 1 || min(dilations.clone()) != max(dilations.clone()) {
if *dilations.at(0) != 1
|| min_in_tensor(dilations.clone()) != min_in_tensor(dilations.clone()) {
// computation of the dilated kernel
let nd = dilations.len();
let mut new_kernel_shape: Array<usize> = array![];
Expand Down Expand Up @@ -1213,161 +1214,3 @@ fn r_index_check(r_index: Span<i32>, shape_out: Span<usize>) -> bool {
flag
}

fn prod<T, MAG, +Drop<T>, +Copy<T>, +NumberTrait<T, MAG>, +TensorTrait<T>, +Mul<T>,>(
pA: Span<T>, start: usize
) -> T {
let mut i = start;
let mut prod = NumberTrait::one();
while i != pA.len() {
prod = prod * (*pA.at(i));
i += 1;
};

prod
}

fn min(mut a: Span<usize>) -> usize {
assert(a.len() > 0, 'span cannot be empty');

let mut min = *a.at(0);
loop {
match a.pop_front() {
Option::Some(v) => { if *v < min {
min = *v;
}; },
Option::None => { break min; }
};
}
}

fn max(mut a: Span<usize>) -> usize {
assert(a.len() > 0, 'span cannot be empty');

let mut max = *a.at(0);
loop {
match a.pop_front() {
Option::Some(v) => { if *v > max {
max = *v;
}; },
Option::None => { break max; }
};
}
}

fn arange(start: usize, end: usize, step: usize) -> Span<usize> {
assert((end - start) % step == 0, 'incompatible step value');

let mut arr: Array<usize> = array![];
let mut i = start;
while i < end {
arr.append(i);
i += step;
};

arr.span()
}


fn cartesian(mut arrays: Span<Span<usize>>,) -> Span<Span<usize>> {
let mut n = 1;
let mut i = arrays.len() - 1;
loop {
n = n * (*(arrays.at(i))).len();
if i == 0 {
break;
}
i -= 1;
};

let mut i = 0;
let mut size_arrays: Array<usize> = array![];
while i != arrays.len() {
size_arrays.append((*(arrays.at(i))).len());
i += 1;
};

let size_arrays = size_arrays.span();
let mut output_arrays = array![];
let mut m = n;

let mut i = 0;
while i != arrays.len() {
m = m / (*(arrays.at(i))).len();
let mut out = repeat(*(arrays.at(i)), m);
out = repeat_2(out, size_arrays, i);

output_arrays.append(out);
i += 1;
};

let output_arrays = output_arrays.span();

let mut i = 0;
let mut ret = ArrayTrait::new();
while i != n {
let mut j = 0;
let mut x: Array<usize> = array![];
while j != arrays.len() {
x.append(*(output_arrays.at(j)).at(i));
j += 1;
};

ret.append(x.span());
i += 1;
};

ret.span()
}

fn repeat_2(mut array: Array<usize>, size_array: Span<usize>, index: usize) -> Array<usize> {
let mut size = array.len();
let mut i = 0;
while i != index {
let mut j = 1;
while j != *size_array.at(index - 1 - i) {
let mut k = 0;
while k != size {
array.append(*array.at(k));
k += 1;
};

j += 1;
};

size = size * *size_array.at(index - 1 - i);
i += 1;
};

array
}

fn repeat(array: Span<usize>, m: usize,) -> Array<usize> {
let mut out: Array<usize> = array![];
let mut j = 0;
while j != array.len() {
let mut k = 0;
while k != m {
out.append(*array.at(j));
k += 1;
};

j += 1;
};

out
}

fn dot<
T, MAG, +Drop<T>, +Copy<T>, +NumberTrait<T, MAG>, +Add<T>, +TensorTrait<T>, +AddEq<T>, +Mul<T>,
>(
a: Span<T>, b: Span<T>
) -> T {
let mut i = 0;
let mut sum = NumberTrait::zero();
while i != a.len() {
sum = sum + *a.at(i) * *b.at(i);
i += 1;
};

sum
}
Loading
Loading