Skip to content

Commit

Permalink
properly handle dimension limits
Browse files Browse the repository at this point in the history
  • Loading branch information
LouisChourakiSonos committed Jan 30, 2025
1 parent 1c36cee commit 3561f3e
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 36 deletions.
83 changes: 72 additions & 11 deletions test-rt/suite-unit/src/bin_einsum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,23 @@ use tract_ndarray::{ArrayD, Axis, Dimension};
use tract_core::ops::einsum::EinSum;
use tract_num_traits::{One, Zero};

#[derive(Debug, Clone, Default)]
#[derive(Debug, Clone)]
pub struct BinEinsumProblemParams {
pub force_unique_non_trivial_m_n: bool,
pub no_trivial_axes: bool,
pub force_max_one_iter_axis: bool,
pub max_dims: usize,
}

impl Default for BinEinsumProblemParams {
fn default() -> BinEinsumProblemParams {
BinEinsumProblemParams {
force_unique_non_trivial_m_n: false,
no_trivial_axes: false,
force_max_one_iter_axis: false,
max_dims: 8,
}
}
}

#[derive(Clone)]
Expand Down Expand Up @@ -41,22 +53,71 @@ impl Arbitrary for BinEinsumProblem {
type Strategy = BoxedStrategy<BinEinsumProblem>;

fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
let m_n_axes_range = if params.force_unique_non_trivial_m_n { 1..2usize } else { 1..3usize };
let trivial_axes_range = if params.no_trivial_axes { 0..1usize } else { 0..2usize };
let iter_axes_range = if params.force_max_one_iter_axis { 0..2usize } else { 0..3usize };
(m_n_axes_range.clone(), m_n_axes_range, iter_axes_range, trivial_axes_range.clone(), trivial_axes_range)
.prop_map(|(m_axes, n_axes, iter_axes, trivial_m_axes, trivial_n_axes)| {
let m_axes: String = ('a'..).take(m_axes).collect();
let supp_m_n_axes_range =
if params.force_unique_non_trivial_m_n { 0..1usize } else { 0..2usize };
assert!(params.max_dims >= 3);
let remaining = params.max_dims - 3; // At least 1 m, and n

supp_m_n_axes_range
.clone()
.prop_flat_map(move |supp_m_axes| {
let remaining = remaining - supp_m_axes;
let n_axes_range = if remaining < supp_m_n_axes_range.end {
0..(remaining + 1)
} else {
supp_m_n_axes_range.clone()
};
let iter_axes_range =
if params.force_max_one_iter_axis { 0..2usize } else { 0..3usize };
n_axes_range.prop_flat_map(move |supp_n_axes| {
let remaining = remaining - supp_n_axes;
let iter_axes_range = if remaining < iter_axes_range.end {
0..(remaining + 1)
} else {
iter_axes_range.clone()
};
iter_axes_range.clone().prop_flat_map(move |iter_axes| {
let remaining = remaining - iter_axes;
let trivial_m_n_axes_range =
if params.no_trivial_axes { 0..1usize } else { 0..2usize };
let trivial_m_axes_range = if remaining < trivial_m_n_axes_range.end {
0..(remaining + 1)
} else {
trivial_m_n_axes_range.clone()
};
trivial_m_axes_range.clone().prop_flat_map(move |trivial_m_axes| {
let remaining = remaining - trivial_m_axes;
let trivial_n_axes_range = if remaining < trivial_m_n_axes_range.end {
0..(remaining + 1)
} else {
trivial_m_n_axes_range.clone()
};
trivial_n_axes_range.clone().prop_flat_map(move |trivial_n_axes| {
Just((
supp_m_axes,
supp_n_axes,
iter_axes,
trivial_m_axes,
trivial_n_axes,
))
})
})
})
})
})
.prop_map(|(supp_m_axes, supp_n_axes, iter_axes, trivial_m_axes, trivial_n_axes)| {
dbg!(supp_m_axes, supp_n_axes, iter_axes, trivial_m_axes, trivial_n_axes);
let m_axes: String = ('b'..).take(supp_m_axes).collect();
let trivial_m_axes: String = ('m'..).take(trivial_m_axes).collect();
let trivial_n_axes: String = ('p'..).take(trivial_n_axes).collect();
let n_axes: String = ('g'..).take(n_axes).collect();
let n_axes: String = ('h'..).take(supp_n_axes).collect();
let iter_axes: String = ('w'..).take(iter_axes).collect();
let a_axes: Vec<char> =
(m_axes.clone() + &trivial_m_axes + &iter_axes + "k").chars().collect();
(m_axes.clone() + "a" + &trivial_m_axes + &iter_axes + "k").chars().collect();
let b_axes: Vec<char> =
(n_axes.clone() + &trivial_n_axes + &iter_axes + "k").chars().collect();
(n_axes.clone() + "g" + &trivial_n_axes + &iter_axes + "k").chars().collect();
let c_axes: Vec<char> =
(m_axes + &n_axes + &trivial_m_axes + &trivial_n_axes + &iter_axes)
(m_axes + &n_axes + "ag" + &trivial_m_axes + &trivial_n_axes + &iter_axes)
.chars()
.collect();
(Just(a_axes), Just(b_axes), Just(c_axes))
Expand Down
10 changes: 7 additions & 3 deletions test-rt/test-f16/build.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
#[path="suite.rs"]
#[path = "suite.rs"]
mod suite;

fn main() {
suite::suite().test_runtime("tests", "suite::suite()", "runtime()", "Approximation::SuperApproximate");
suite::suite().test_runtime(
"tests",
"suite::suite()",
"runtime()",
"Approximation::SuperApproximate",
);
}

7 changes: 1 addition & 6 deletions test-rt/test-f16/suite.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use infra::Test;
use suite_unit::bin_einsum::{BinEinsumProblem, BinEinsumProblemParams};
use suite_unit::conv_q::{QConvProblem, QConvProblemParams};

pub fn suite() -> &'static infra::TestSuite {
Expand All @@ -20,9 +19,7 @@ fn mk_suite() -> infra::TestSuite {
QConvProblemParams::default(),
compatible_conv_q,
);
let einsum_params = BinEinsumProblemParams {no_trivial_axes: true, ..BinEinsumProblemParams::default()};
unit.get_sub_mut("bin_einsum").add_arbitrary::<BinEinsumProblem>("proptest", einsum_params.clone());


infra::TestSuite::default().with("onnx", onnx).with("unit", unit)
}

Expand All @@ -32,8 +29,6 @@ fn ignore_unit(t: &[String], case: &dyn Test) -> bool {
return true;
}
}

if t[0] == "bin_einsum" && t[1] == "proptest" { return true; }

let [section, _unit] = t else { return false };
["q_flavours"].contains(&&**section)
Expand Down
10 changes: 7 additions & 3 deletions test-rt/test-nnef-cycle/build.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
#[path="suite.rs"]
#[path = "suite.rs"]
mod suite;

fn main() {
suite::suite().test_runtime("nnef_cycle", "suite::suite()", "runtime()", "Approximation::Approximate");
suite::suite().test_runtime(
"nnef_cycle",
"suite::suite()",
"runtime()",
"Approximation::Approximate",
);
}

5 changes: 0 additions & 5 deletions test-rt/test-nnef-cycle/suite.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use infra::Test;
use suite_unit::bin_einsum::{BinEinsumProblem, BinEinsumProblemParams};
use suite_unit::conv_q::{QConvProblem, QConvProblemParams};

pub fn suite() -> &'static infra::TestSuite {
Expand All @@ -22,9 +21,6 @@ fn mk_suite() -> infra::TestSuite {
compatible_conv_q,
);

let einsum_params = BinEinsumProblemParams {no_trivial_axes: true, ..BinEinsumProblemParams::default()};
unit.get_sub_mut("bin_einsum").add_arbitrary::<BinEinsumProblem>("proptest", einsum_params.clone());

infra::TestSuite::default().with("onnx", onnx).with("unit", unit)
}

Expand Down Expand Up @@ -88,7 +84,6 @@ fn ignore_unit(t: &[String], tc: &dyn Test) -> bool {
if let Some(qcp) = tc.downcast_ref::<QConvProblem>() {
return !compatible_conv_q(qcp);
}
if t[0] == "bin_einsum" && t[1] == "proptest" { return true; }
false
}

Expand Down
10 changes: 7 additions & 3 deletions test-rt/test-tflite/build.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
#[path="suite.rs"]
#[path = "suite.rs"]
mod suite;

fn main() {
suite::suite().test_runtime("tests", "suite::suite()", "runtime()", "Approximation::Approximate");
suite::suite().test_runtime(
"tests",
"suite::suite()",
"runtime()",
"Approximation::Approximate",
);
}

15 changes: 10 additions & 5 deletions test-rt/test-tflite/suite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ fn mk_suite() -> infra::TestSuite {
compatible_conv_q,
);

let einsum_params = BinEinsumProblemParams { force_unique_non_trivial_m_n: true, no_trivial_axes: true, ..BinEinsumProblemParams::default()};
unit.get_sub_mut("bin_einsum").add_arbitrary::<BinEinsumProblem>("proptest", einsum_params.clone());
let einsum_params = BinEinsumProblemParams { max_dims: 4, ..BinEinsumProblemParams::default() };
unit.get_sub_mut("bin_einsum")
.add_arbitrary::<BinEinsumProblem>("proptest", einsum_params.clone());
infra::TestSuite::default().with("onnx", onnx).with("unit", unit)
}

Expand Down Expand Up @@ -104,7 +105,8 @@ fn ignore_onnx(t: &[String]) -> bool {
test_thresholdrelu
",
);
let excluded = patterns("
let excluded = patterns(
"
test_slice_start_out_of_bounds
test_Conv1d_groups
test_Conv2d_groups
Expand All @@ -122,7 +124,8 @@ fn ignore_onnx(t: &[String]) -> bool {
pool_2d_same_lower
test_cosh.*
test_sinh.*
");
",
);
!included.iter().any(|pat| pat.is_match(name)) || excluded.iter().any(|pat| pat.is_match(name))
}

Expand Down Expand Up @@ -151,7 +154,9 @@ fn ignore_unit(t: &[String], case: &dyn Test) -> bool {
}
}

if t[0] == "bin_einsum" && t[1] == "proptest" { return true; }
if t[0] == "bin_einsum" && t[1] == "proptest" {
return true;
}

let [section, _unit] = t else { return false };
["deconv", "q_flavours", "q_binary", "q_elmwise"].contains(&&**section)
Expand Down

0 comments on commit 3561f3e

Please sign in to comment.