Skip to content

Commit

Permalink
Merge pull request #523 from robertknight/refactor-conv-error-tests
Browse files Browse the repository at this point in the history
Refactor tests for invalid Conv inputs, fix panic if group count is zero
  • Loading branch information
robertknight authored Jan 8, 2025
2 parents 4bf7eb8 + 6947c38 commit 6323411
Showing 1 changed file with 83 additions and 40 deletions.
123 changes: 83 additions & 40 deletions src/ops/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ where
));
}

if groups == 0 {
return Err(OpError::InvalidValue("Group count must be > 0"));
}

let out_channels_per_group = out_c / groups;
let in_channels_per_group = in_c / groups;

Expand All @@ -201,7 +205,7 @@ where
));
}

if groups == 0 || in_c % groups != 0 || out_c % groups != 0 {
if in_c % groups != 0 || out_c % groups != 0 {
return Err(OpError::IncompatibleInputShapes(
"Input channels and output channels must be divisible by group count",
));
Expand Down Expand Up @@ -1182,51 +1186,90 @@ mod tests {
}

#[test]
fn test_conv_input_too_small() {
fn test_conv_invalid() {
let mut rng = XorShiftRng::new(1234);
let input = Tensor::rand(&[1, 1, 2, 2], &mut rng);
let kernel = Tensor::rand(&[1, 1, 3, 3], &mut rng);

let pool = new_pool();
let result = conv(
&pool,
input.view(),
kernel.view(),
None,
[0; 4].into(),
1, /* groups */
&[1, 1], /* stride */
&[1, 1], /* dilations */
);

assert_eq!(
result.err(),
Some(OpError::InvalidValue("Input too small for kernel size"))
);
}
struct Case<'a> {
input: Tensor<f32>,
kernel: Tensor<f32>,
strides: &'a [usize],
groups: usize,
dilations: &'a [usize],
expected: OpError,
}

#[test]
fn test_conv_zero_stride() {
let mut rng = XorShiftRng::new(1234);
let input = Tensor::rand(&[1, 1, 2, 2], &mut rng);
let kernel = Tensor::rand(&[1, 1, 2, 2], &mut rng);
let cases = [
// Input too small
Case {
input: Tensor::rand(&[1, 1, 2, 2], &mut rng),
kernel: Tensor::rand(&[1, 1, 3, 3], &mut rng),
strides: &[1, 1],
dilations: &[1, 1],
groups: 1,
expected: OpError::InvalidValue("Input too small for kernel size"),
},
// Zero stride
Case {
input: Tensor::rand(&[1, 1, 2, 2], &mut rng),
kernel: Tensor::rand(&[1, 1, 2, 2], &mut rng),
strides: &[0, 0],
dilations: &[1, 1],
groups: 1,
expected: OpError::InvalidValue("Strides must be > 0"),
},
// Unsupported stride count
Case {
input: Tensor::rand(&[1, 1, 2, 2], &mut rng),
kernel: Tensor::rand(&[1, 1, 2, 2], &mut rng),
strides: &[1, 1, 1],
dilations: &[1, 1],
groups: 1,
expected: OpError::InvalidValue("expected 2 stride values"),
},
// Unsupported dilation count
Case {
input: Tensor::rand(&[1, 1, 2, 2], &mut rng),
kernel: Tensor::rand(&[1, 1, 2, 2], &mut rng),
strides: &[1, 1],
dilations: &[1, 1, 1],
groups: 1,
expected: OpError::InvalidValue("expected 2 dilation values"),
},
// Zero groups
Case {
input: Tensor::rand(&[1, 1, 2, 2], &mut rng),
kernel: Tensor::rand(&[1, 1, 2, 2], &mut rng),
strides: &[1, 1],
dilations: &[1, 1],
groups: 0,
expected: OpError::InvalidValue("Group count must be > 0"),
},
];

let pool = new_pool();
let result = conv(
&pool,
input.view(),
kernel.view(),
None,
[0; 4].into(),
1, /* groups */
&[0, 0], /* stride */
&[1, 1], /* dilations */
);

assert_eq!(
result.err(),
Some(OpError::InvalidValue("Strides must be > 0"))
);
for Case {
input,
kernel,
strides,
dilations,
groups,
expected,
} in cases
{
let result = conv(
&pool,
input.view(),
kernel.view(),
None,
[0; 4].into(),
groups,
strides,
dilations,
);

assert_eq!(result.err(), Some(expected));
}
}

#[test]
Expand Down

0 comments on commit 6323411

Please sign in to comment.