Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BinEinsum to suite-unit #1630

Merged
merged 9 commits into from
Jan 31, 2025
Prev Previous commit
Next Next commit
force 1 iter axis for tflite
LouisChourakiSonos committed Jan 30, 2025
commit c563972a34d9932b4d5e765e4eabd98f140acacc
5 changes: 3 additions & 2 deletions test-rt/suite-unit/src/bin_einsum.rs
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@ use tract_num_traits::{One, Zero};
pub struct BinEinsumProblemParams {
pub force_unique_non_trivial_m_n: bool,
pub no_trivial_axes: bool,
pub force_max_one_iter_axis: bool,
}

#[derive(Clone)]
@@ -42,8 +43,8 @@ impl Arbitrary for BinEinsumProblem {
fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
let m_n_axes_range = if params.force_unique_non_trivial_m_n { 1..2usize } else { 1..3usize };
let trivial_axes_range = if params.no_trivial_axes { 0..1usize } else { 0..2usize };

(m_n_axes_range.clone(), m_n_axes_range, 0..3usize, trivial_axes_range.clone(), trivial_axes_range)
let iter_axes_range = if params.force_max_one_iter_axis { 0..2usize } else { 0..3usize };
(m_n_axes_range.clone(), m_n_axes_range, iter_axes_range, trivial_axes_range.clone(), trivial_axes_range)
.prop_map(|(m_axes, n_axes, iter_axes, trivial_m_axes, trivial_n_axes)| {
let m_axes: String = ('a'..).take(m_axes).collect();
let trivial_m_axes: String = ('m'..).take(trivial_m_axes).collect();
6 changes: 3 additions & 3 deletions test-rt/test-tflite/suite.rs
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@ fn mk_suite() -> infra::TestSuite {
compatible_conv_q,
);

let einsum_params = BinEinsumProblemParams {no_trivial_axes: true, ..BinEinsumProblemParams::default()};
let einsum_params = BinEinsumProblemParams {no_trivial_axes: true, force_max_one_iter_axis: true, ..BinEinsumProblemParams::default()};
unit.get_sub_mut("bin_einsum").add_arbitrary::<BinEinsumProblem>("proptest", einsum_params.clone());
infra::TestSuite::default().with("onnx", onnx).with("unit", unit)
}
@@ -150,9 +150,9 @@ fn ignore_unit(t: &[String], case: &dyn Test) -> bool {
return true;
}
}

if t[0] == "bin_einsum" && t[1] == "proptest" { return true; }

let [section, _unit] = t else { return false };
["deconv", "q_flavours", "q_binary", "q_elmwise"].contains(&&**section)
}