Skip to content

Commit

Permalink
fix col2im and conv_transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Apr 22, 2024
1 parent f131f78 commit f0e7287
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 222 deletions.
33 changes: 0 additions & 33 deletions src/operators/nn/functional/col2im.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -268,36 +268,3 @@ fn get_indices(index: usize, shape: Span<usize>,) -> Array<usize> {

new_res
}

fn is_out(ind: Span<usize>, shape: Span<usize>,) -> bool {
let mut n = 0;
let is_out = loop {
if n == ind.len() {
break false;
}
let s = *shape.at(n);
let i = *ind.at(n);
if i < 0 {
break true;
}
if i >= s {
break true;
}
n += 1;
};

is_out
}

fn prod<T, MAG, +Drop<T>, +Copy<T>, +NumberTrait<T, MAG>, +TensorTrait<T>, +Mul<T>,>(
pA: Span<T>, start: usize
) -> T {
let mut i = start;
let mut prod = NumberTrait::one();
while i != pA.len() {
prod = prod * (*pA.at(i));
i += 1;
};

prod
}
189 changes: 0 additions & 189 deletions src/operators/nn/functional/conv_transpose.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -486,192 +486,3 @@ fn conv_transpose<

TensorTrait::new(shape.span(), final.span())
}

fn get_image<T, +Drop<T>, +Copy<T>>(self: @Tensor<T>, row: usize) -> Span<T> {
assert((*self).shape.len() == 2, 'Expected a 2D tensor');

let row_length = *self.shape[1];
let start = row * row_length;

(*self).data.slice(start, row_length)
}

fn col2im_naive_implementation<
T, MAG, +TensorTrait<T>, +NumberTrait<T, MAG>, +Copy<T>, +Drop<T>, +Add<T>,
>(
data: @Tensor<T>,
image_shape: Span<usize>,
kernel_shape: Span<usize>,
dilations: Span<usize>,
pads: Span<usize>,
strides: Span<usize>,
) -> NullableVec<T> {
let n_dims = pads.len() / 2;

col2im_shape_check(data, image_shape, kernel_shape, dilations, pads, strides);

let mut dim_col: Array<usize> = array![];
let mut i = 0;
while i != n_dims {
dim_col
.append(
(*image_shape.at(i)
+ (*pads.at(i) + *pads.at(i + n_dims))
- (*dilations.at(i) * (*kernel_shape.at(i) - 1) + 1))
/ *strides.at(i)
+ 1
);

i += 1;
};

let dim_col = dim_col.span();

let stride_img = stride(image_shape);

let mut data_im = NullableVecImpl::new();
data_im.set(*image_shape.at(0) * *stride_img.at(0) - 1, NumberTrait::zero());

let kernel_size = prod(kernel_shape, 0);
let col_size = prod(dim_col, 0);
let mut c_col = 0;
while c_col != kernel_size {
let offset = get_indices(c_col, kernel_shape).span();

let mut col = 0;
while col != col_size {
let ind_col = get_indices(col, dim_col).span();
let mut ind_im: Array<usize> = array![];
let mut i = 0;
while i != n_dims {
if (*ind_col.at(i) * *strides.at(i) + *offset.at(i) * *dilations.at(i)) < *pads
.at(i) {
let neg_index = *pads.at(i)
- (*ind_col.at(i) * *strides.at(i) + *offset.at(i) * *dilations.at(i));
ind_im.append(*image_shape.at(i) + neg_index);
} else {
ind_im
.append(
*ind_col.at(i) * *strides.at(i)
+ *offset.at(i) * *dilations.at(i)
- *pads.at(i)
);
}

i += 1;
};

let ind_im = ind_im.span();
if !is_out(ind_im, image_shape) {
let mut index = 0;
let mut i = 0;
while i != image_shape.len() {
index += *stride_img.at(i) * *ind_im.at(i);
i += 1;
};

data_im.set(index, data_im.at(index) + *(*data).data.at(c_col * col_size + col));
}

col += 1;
};

c_col += 1;
};

data_im
}

fn col2im_shape_check<T, +TensorTrait<T>, +Copy<T>, +Drop<T>,>(
X: @Tensor<T>,
output_shape: Span<usize>,
kernel_shape: Span<usize>,
dilations: Span<usize>,
pads: Span<usize>,
strides: Span<usize>,
) {
let n_input_plane = *(*X).shape.at(0);

let kernel_size = prod(kernel_shape, 0);

assert(n_input_plane % kernel_size == 0, 'wrong input dimension');

let input_length = *(*X).shape.at(1);
let n_dims = output_shape.len();
let mut n_blocks: Array<usize> = array![];

let mut i = 0;
while i != n_dims {
n_blocks
.append(
(*output_shape.at(i)
+ (*pads.at(i) + *pads.at(i + n_dims))
- *dilations.at(i) * (*kernel_shape.at(i) - 1)
- 1)
/ *strides.at(i)
+ 1
);
i += 1;
};

let block_size = prod(n_blocks.span(), 0);

assert(input_length == block_size, 'input_length != block_size');
}


fn get_indices(index: usize, shape: Span<usize>,) -> Array<usize> {
let mut i = index;
let mut res: Array<usize> = array![];
let mut k = shape.len() - 1;
while k != 0 {
let m = i % *shape.at(k);
res.append(m);
i -= m;
i /= *shape.at(k);
k -= 1;
};

let mut new_res: Array<usize> = array![];
new_res.append(i);
let mut i = shape.len() - 1;
while i != 0 {
new_res.append(*res.at(i - 1));
i -= 1;
};

new_res
}

fn is_out(ind: Span<usize>, shape: Span<usize>,) -> bool {
let mut n = 0;
let is_out = loop {
if n == ind.len() {
break false;
}
let s = *shape.at(n);
let i = *ind.at(n);
if i < 0 {
break true;
}
if i >= s {
break true;
}
n += 1;
};

is_out
}

fn prod<T, MAG, +Drop<T>, +Copy<T>, +NumberTrait<T, MAG>, +TensorTrait<T>, +Mul<T>,>(
pA: Span<T>, start: usize
) -> T {
let mut i = start;
let mut prod = NumberTrait::one();
while i != pA.len() {
prod = prod * (*pA.at(i));
i += 1;
};

prod
}

0 comments on commit f0e7287

Please sign in to comment.