From 006f45375b0a2cc1168bf2f0d63675baa10b7be4 Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Thu, 16 Jan 2025 16:03:20 +0100 Subject: [PATCH 1/9] first proptest implem for matmul --- core/src/ops/einsum/as_matmul.rs | 6 +- core/src/ops/einsum/optimize.rs | 38 +++- test-rt/suite-unit/src/bin_einsum.rs | 313 +++++++++++++++++++++++++++ test-rt/suite-unit/src/lib.rs | 2 + 4 files changed, 354 insertions(+), 5 deletions(-) create mode 100644 test-rt/suite-unit/src/bin_einsum.rs diff --git a/core/src/ops/einsum/as_matmul.rs b/core/src/ops/einsum/as_matmul.rs index 0d9300d904..2ca8e7ae0b 100644 --- a/core/src/ops/einsum/as_matmul.rs +++ b/core/src/ops/einsum/as_matmul.rs @@ -37,9 +37,9 @@ fn einsum_rules( let op = match ensure_mkn_axes(op, model, node).context("Figuring out m, k and n axes")? { AxesOrPatch::Annotated(op) => op, AxesOrPatch::Patch(p) => return Ok(Some(p)), - AxesOrPatch::NotAMatMul(axis) => { - bail!("{} is not a matmul because of axis {}", op.axes, axis.repr) - } + AxesOrPatch::NotAMatMul(axes) => { + bail!("{} is not a matmul because of axis {}", op.axes, axes.iter().map(|a| a.repr).join(", ") ) + } }; let prefix: String = op .axes diff --git a/core/src/ops/einsum/optimize.rs b/core/src/ops/einsum/optimize.rs index b8182cafc2..edfe32d340 100644 --- a/core/src/ops/einsum/optimize.rs +++ b/core/src/ops/einsum/optimize.rs @@ -18,7 +18,7 @@ use crate::ops::nn::{Reduce, Reducer}; pub enum AxesOrPatch<'a> { Annotated(EinSumAnnotatedAsMatMul<'a>), Patch(TypedModelPatch), - NotAMatMul(&'a Axis), + NotAMatMul(Vec<&'a Axis>), } pub struct EinSumAnnotatedAsMatMul<'a> { @@ -69,6 +69,7 @@ pub(crate) fn optimize( { return Ok(None); } + let annotated = match ensure_mkn_axes(op, model, node)? 
{ AxesOrPatch::Annotated(op) => op, AxesOrPatch::Patch(p) => return Ok(Some(p)), @@ -114,6 +115,22 @@ pub(crate) fn ensure_mkn_axes<'a>( let Some(k_axis) = k_axis else { return Ok(AxesOrPatch::Patch(inject_k_axis(op, model, node)?)); }; + + let non_trivial_m_axes = op.axes.iter_all_axes().filter(|a| { + a.inputs[0].len() == 1 + && a.outputs[0].len() == 1 + && a.inputs[1].len() == 0 + && !input_shapes[0][a.inputs[0][0]].is_one() + }).collect_vec(); + //let mut m_axes_pos = non_trivial_m_axes.iter().map(|axis| { axis.inputs[0][0] }).collect_vec(); + //m_axes_pos.sort(); + // + //let consecutive_m_axes = m_axes_pos.windows(2).all(|window| { (window[1] - window[0]) == 1}); + // + if non_trivial_m_axes.len() > 1 { + return Ok(AxesOrPatch::NotAMatMul(non_trivial_m_axes)) + } + let m_axis = op .axes .iter_all_axes() @@ -126,6 +143,22 @@ pub(crate) fn ensure_mkn_axes<'a>( let Some(m_axis) = m_axis else { return Ok(AxesOrPatch::Patch(inject_m_or_n_axis(op, model, node, false, &[k_axis])?)); }; + + let non_trivial_n_axes = op.axes.iter_all_axes().filter(|a| { + a.inputs[1].len() == 1 + && a.outputs[0].len() == 1 + && a.inputs[0].len() == 0 + && !input_shapes[1][a.inputs[1][0]].is_one() + }).collect_vec(); + + //let mut n_axes_pos = non_trivial_n_axes.iter().map(|axis| { axis.inputs[1][0] }).collect_vec(); + //n_axes_pos.sort(); +// + //let consecutive_n_axes = n_axes_pos.windows(2).all(|window| { (window[1] - window[0]) == 1}); + if non_trivial_n_axes.len() > 1 { + return Ok(AxesOrPatch::NotAMatMul(non_trivial_n_axes)) + } + let n_axis = op .axes .iter_all_axes() @@ -133,6 +166,7 @@ pub(crate) fn ensure_mkn_axes<'a>( (a.inputs[0].len() == 0 || input_shapes[0][a.inputs[0][0]].is_one()) && a.inputs[1].len() == 1 && a.outputs[0].len() == 1 + && *a != m_axis }) .max_by_key(|a| output_shape[a.outputs[0][0]].as_i64().unwrap_or(i64::MAX)); let Some(n_axis) = n_axis else { @@ -152,7 +186,7 @@ pub(crate) fn ensure_mkn_axes<'a>( axis.inputs[1].first().map(|pos| 
&input_shapes[1][*pos]).unwrap_or(&one) != &one; let in_out = axis.outputs[0].first().map(|pos| &output_shape[*pos]).unwrap_or(&one) != &one; if (in_left ^ in_right) && !in_out { - return Ok(AxesOrPatch::NotAMatMul(axis)); + return Ok(AxesOrPatch::NotAMatMul(vec![axis])); } } let m = input_shapes[0][m_axis.inputs[0][0]].clone(); diff --git a/test-rt/suite-unit/src/bin_einsum.rs b/test-rt/suite-unit/src/bin_einsum.rs new file mode 100644 index 0000000000..a7e2719c9c --- /dev/null +++ b/test-rt/suite-unit/src/bin_einsum.rs @@ -0,0 +1,313 @@ +use std::{fmt, ops::Mul}; + +use infra::{Test, TestResult, TestSuite}; +use proptest::prelude::*; +use proptest::strategy::BoxedStrategy; +use tract_core::internal::*; +use tract_ndarray::{ArrayD, Axis, Dimension}; + +use tract_core::ops::einsum::EinSum; +use tract_num_traits::{One, Zero}; + +#[derive(Debug, Clone, Default)] +pub struct BinEinsumProblemParams { + pub no_iter_axes: bool, +} + +#[derive(Clone)] +pub struct BinEinsumProblem { + expr: AxesMapping, + a: Tensor, + b: Tensor, + a_constant: bool, + b_constant: bool, + unicast_add_constant: Option, +} + +impl std::fmt::Debug for BinEinsumProblem { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{} A:{:?} B:{:?} a_constant:{:?} b_constant:{:?} unicast_add_constant:{:?}", + self.expr, self.a, self.b, self.a_constant, self.b_constant, self.unicast_add_constant + ) + } +} + +impl Arbitrary for BinEinsumProblem { + type Parameters = BinEinsumProblemParams; + type Strategy = BoxedStrategy; + + fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { + (1..2usize, 1..2usize, 0..3usize, 0..2usize, 0..2usize) + .prop_map(|(m_axes, n_axes, iter_axes, trivial_m_axes, trivial_n_axes)| { + let m_axes: String = ('a'..).take(m_axes).collect(); + let trivial_m_axes: String = ('m'..).take(trivial_m_axes).collect(); + let trivial_n_axes: String = ('p'..).take(trivial_n_axes).collect(); + let n_axes: String = ('g'..).take(n_axes).collect(); + let 
iter_axes: String = ('w'..).take(iter_axes).collect(); + let a_axes: Vec = (m_axes.clone() + &trivial_m_axes + &iter_axes + "k").chars().collect(); + let b_axes: Vec = (n_axes.clone() + &trivial_n_axes + &iter_axes + "k").chars().collect(); + let c_axes: Vec = (m_axes + &n_axes + &trivial_m_axes + &trivial_n_axes + &iter_axes).chars().collect(); + (Just(a_axes), Just(b_axes), Just(c_axes)) + }) + .prop_flat_map(|(a, b, c)| (a.prop_shuffle(), b.prop_shuffle(), c.prop_shuffle())) + .prop_map(|(a, b, c)| { + let a: String = a.into_iter().collect(); + let b: String = b.into_iter().collect(); + let c: String = c.into_iter().collect(); + let expr: AxesMapping = format!("{a},{b}->{c}").parse().unwrap(); + eprintln!("{expr}"); + expr + }) + .prop_flat_map(|expr| { + let dims = expr.iter_all_axes().count(); + (Just(expr), proptest::collection::vec(1..4usize, dims..=dims)) + }) + .prop_flat_map(|(expr, axis_dims)| { + let shape_a: TVec = expr + .axes(InOut::In(0)) + .map(|axis| {if !('m'..='v').contains(&axis.repr) { expr.iter_all_axes().position(|x| x == axis).unwrap() } else {1000} }) + .map(|dim| {if dim != 1000 {axis_dims[dim]} else { 1 } }) + .collect(); + let shape_b: TVec = expr + .axes(InOut::In(1)) + .map(|axis| {if !('m'..='v').contains(&axis.repr) { expr.iter_all_axes().position(|x| x == axis).unwrap() } else {1000} }) + .map(|dim| {if dim != 1000 {axis_dims[dim]} else { 1 } }) + .collect(); + let shape_output: TVec = expr + .axes(InOut::Out(0)) + .map(|axis| {if !('m'..='v').contains(&axis.repr) { expr.iter_all_axes().position(|x| x == axis).unwrap() } else {1000} }) + .map(|dim| {if dim != 1000 {axis_dims[dim]} else { 1 } }) + .collect(); + let unicast_add_constant = proptest::option::of(tensor(&shape_output)); + (Just(expr), tensor(&shape_a), tensor(&shape_b), 0..3usize, unicast_add_constant) + }) + .prop_map(|(expr, a, b, a_b_constant, unicast_add_constant)| { + let a_constant = (a_b_constant & 0x1) != 0; + let b_constant = (a_b_constant & 0x2) != 0; + 
BinEinsumProblem { expr, a, b, a_constant, b_constant, unicast_add_constant } + }) + .boxed() + } +} + +pub fn tensor(shape: &[usize]) -> BoxedStrategy { + let len = shape.iter().product::(); + let shape: Vec = shape.into(); + proptest::collection::vec((-10i8..=10i8).prop_map(|i| i as f32), len..=len) + .prop_map(move |vec| ArrayD::from_shape_vec(shape.clone(), vec).unwrap().into_tensor()) + .boxed() +} + +impl BinEinsumProblem { + fn tract(&self) -> TractResult { + let mut model = TypedModel::default(); + let a = if self.a_constant { + model.add_const("a", self.a.clone())? + } else { + model.add_source("a", TypedFact::shape_and_dt_of(&self.a))? + }; + let b = if self.b_constant { + model.add_const("b", self.b.clone())? + } else { + model.add_source("b", TypedFact::shape_and_dt_of(&self.b))? + }; + + let output = model.wire_node( + "einsum", + EinSum { axes: self.expr.clone(), operating_dt: f32::datum_type(), q_params: None }, + &[a, b], + )?; + + //if let Some(c) = &self.unicast_add_constant { + // let c = model.add_const("c", c.clone())?; + // output = model.wire_node("add", tract_core::ops::math::add(), &[output[0], c])?; + //} + + model.set_output_outlets(&output)?; + + //let test = model.node_by_name("einsum")?.op.as_op().downcast_ref::().unwrap(); + + model = model.into_decluttered()?; + //let test1 = model.node_by_name("einsum")?.op.as_op().downcast_ref::().unwrap(); + //dbg!(&test1.axes); + Ok(model) + } + + fn output_shape(&self) -> TVec { + self.expr + .axes(InOut::Out(0)) + .map(|axis| { + let dim_in_a = axis.inputs[0].get(0).map(|pos| self.a.shape()[*pos]).unwrap_or(1); + let dim_in_b = axis.inputs[1].get(0).map(|pos| self.b.shape()[*pos]).unwrap_or(1); + dim_in_a.max(dim_in_b) + }) + .collect() + } + + fn reference>(&self) -> ArrayD { + let output_shape = self.output_shape(); + + let a = self.a.cast_to::().unwrap(); + let b = self.b.cast_to::().unwrap(); + + let a = a.to_array_view::().unwrap(); + let b = b.to_array_view::().unwrap(); + + let k_axes: 
TVec<_> = self + .expr + .iter_all_axes() + .filter(|axis| { + axis.outputs[0].len() == 0 && axis.inputs[0].len() == 1 && axis.inputs[1].len() == 1 + }) + .collect(); + + let summing_shape: TVec = k_axes + .iter() + .map(|axis| { + let dim_in_a = axis.inputs[0].get(0).map(|pos| self.a.shape()[*pos]).unwrap_or(1); + let dim_in_b = axis.inputs[1].get(0).map(|pos| self.b.shape()[*pos]).unwrap_or(1); + dim_in_a.max(dim_in_b) + }) + .collect(); + + let output = tract_ndarray::ArrayD::::from_shape_fn(&*output_shape, |coords| { + let coords = coords.as_array_view(); + let mut a = a.clone(); + let mut b = b.clone(); + for (axis, x) in self.expr.axes(InOut::Out(0)).zip(coords.iter()) { + if let Some(pos) = axis.inputs[0].get(0) { + a.collapse_axis(Axis(*pos), if a.shape()[*pos] > 1 { *x } else { 0 }); + } + + if let Some(pos) = axis.inputs[1].get(0) { + b.collapse_axis(Axis(*pos), if b.shape()[*pos] > 1 { *x } else { 0 }); + } + } + + let mut sum: Acc = Acc::zero(); + for sum_coords in tract_ndarray::indices(&*summing_shape) { + let mut a = a.clone(); + let mut b = b.clone(); + + let sum_coords = sum_coords.as_array_view(); + for (axis, x) in k_axes.iter().zip(sum_coords) { + a.collapse_axis(Axis(axis.inputs[0][0]), *x); + b.collapse_axis(Axis(axis.inputs[1][0]), *x); + } + + let product = *a.iter().next().unwrap() * *b.iter().next().unwrap(); + sum = sum + product; + } + sum + }); + output + } +} + +impl Test for BinEinsumProblem { + fn run_with_approx( + &self, + _suite: &str, + id: &str, + runtime: &dyn Runtime, + approx: Approximation, + ) -> TestResult { + let reference = self.reference::().into_tensor(); + //dbg!(&reference); + let mut model = self.tract()?; + + model.properties.insert("tract-rt-test.id".to_string(), rctensor0(id.to_string())); + let mut inputs = tvec![]; + if !self.a_constant { + inputs.push(self.a.clone().into()); + } + if !self.b_constant { + inputs.push(self.b.clone().into()); + } + let mut output = runtime.prepare(model)?.run(inputs)?; + let 
output = output.remove(0).into_tensor(); + output.close_enough(&reference, approx) + } +} + +pub fn suite() -> TractResult { + let mut suite = TestSuite::default(); + + suite.add_arbitrary::("proptest", BinEinsumProblemParams::default()); + + suite.add( + "unicast_0", + BinEinsumProblem { + expr: "ak,gk->ag".parse().unwrap(), + a: Tensor::zero::(&[1, 2]).unwrap(), + b: Tensor::zero::(&[1, 2]).unwrap(), + a_constant: false, + b_constant: false, + unicast_add_constant: Some(Tensor::zero::(&[1, 1]).unwrap()), + }, + ); + + suite.add( + "unicast_1", + BinEinsumProblem { + expr: "ak,gk->ag".parse().unwrap(), + a: Tensor::zero::(&[2, 1]).unwrap(), + b: Tensor::zero::(&[2, 1]).unwrap(), + a_constant: false, + b_constant: false, + unicast_add_constant: Some(tensor2(&[[0f32, 0.], [0., 1.]])), + }, + ); + + suite.add( + "unicast_2", + BinEinsumProblem { + expr: "abk,gk->abg".parse().unwrap(), + a: Tensor::zero::(&[2, 2, 1]).unwrap(), + b: Tensor::zero::(&[1, 1]).unwrap(), + a_constant: false, + b_constant: false, + unicast_add_constant: Some(tensor3(&[[[0f32], [0.]], [[0.], [1.]]])), + }, + ); + + suite.add( + "trivial_0", + BinEinsumProblem { + expr: "ak,gk->ag".parse().unwrap(), + a: tensor2(&[[1f32]]), + b: tensor2(&[[0f32], [1f32]]), + a_constant: false, + b_constant: false, + unicast_add_constant: None, + }, + ); + + suite.add( + "trivial_1", + BinEinsumProblem { + expr: "akb,gk->gba".parse().unwrap(), + a: tensor3(&[[[0f32], [0f32]]]), + b: tensor2(&[[0f32, 0f32]]), + a_constant: true, + b_constant: false, + unicast_add_constant: None, + }, + ); + + // TODO: fix ensure_mkn() to handle multiple n axes + //suite.add( + // "multiple_n_axes", + // BinEinsumProblem { + // expr: "kwa,gkwh->gahw".parse().unwrap(), + // a: Tensor::zero::(&[1, 2, 1]).unwrap(), + // b: Tensor::zero::(&[2, 1, 2, 2]).unwrap(), + // a_constant: false, + // b_constant: false, + // unicast_add_constant: None, + // } + //); + Ok(suite) +} diff --git a/test-rt/suite-unit/src/lib.rs 
b/test-rt/suite-unit/src/lib.rs index 88cf1a3d6f..b8ddcaf36d 100644 --- a/test-rt/suite-unit/src/lib.rs +++ b/test-rt/suite-unit/src/lib.rs @@ -6,6 +6,7 @@ use tract_core::ops::cnn::*; use tract_core::ops::nn::*; use tract_ndarray::*; +pub mod bin_einsum; pub mod conv_f32; pub mod conv_q; pub mod deconv; @@ -18,6 +19,7 @@ pub mod slice; pub fn suite() -> TractResult { let mut suite: TestSuite = Default::default(); + suite.add("bin_einsum", bin_einsum::suite()?); suite.add("conv_f32", conv_f32::suite()?); suite.add("conv_q", conv_q::suite()?); suite.add("deconv", deconv::suite()?); From c95a346c7d4f936f37a7d61252b2aef8cce7d0c2 Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Wed, 29 Jan 2025 16:03:31 +0100 Subject: [PATCH 2/9] ignore some configurations for metal conv and einsum tests --- core/src/ops/einsum/optimize.rs | 23 ++++++------ test-rt/suite-unit/src/bin_einsum.rs | 56 ++++++++++++++++++---------- test-rt/suite-unit/src/conv_f32.rs | 19 ++++++++-- test-rt/suite-unit/src/conv_q.rs | 2 +- test-rt/suite-unit/src/deconv.rs | 3 +- test-rt/suite-unit/src/q_binary.rs | 2 +- test-rt/test-metal/suite.rs | 20 +++++++++- 7 files changed, 85 insertions(+), 40 deletions(-) diff --git a/core/src/ops/einsum/optimize.rs b/core/src/ops/einsum/optimize.rs index edfe32d340..cccb3dbf33 100644 --- a/core/src/ops/einsum/optimize.rs +++ b/core/src/ops/einsum/optimize.rs @@ -122,12 +122,13 @@ pub(crate) fn ensure_mkn_axes<'a>( && a.inputs[1].len() == 0 && !input_shapes[0][a.inputs[0][0]].is_one() }).collect_vec(); - //let mut m_axes_pos = non_trivial_m_axes.iter().map(|axis| { axis.inputs[0][0] }).collect_vec(); - //m_axes_pos.sort(); - // - //let consecutive_m_axes = m_axes_pos.windows(2).all(|window| { (window[1] - window[0]) == 1}); - // - if non_trivial_m_axes.len() > 1 { + + let mut m_axes_pos = non_trivial_m_axes.iter().map(|axis| { axis.inputs[0][0] }).collect_vec(); + m_axes_pos.sort(); + + let consecutive_m_axes = m_axes_pos.windows(2).all(|window| { (window[1] - 
window[0]) == 1}); + + if non_trivial_m_axes.len() > 1 && !consecutive_m_axes { return Ok(AxesOrPatch::NotAMatMul(non_trivial_m_axes)) } @@ -151,11 +152,11 @@ pub(crate) fn ensure_mkn_axes<'a>( && !input_shapes[1][a.inputs[1][0]].is_one() }).collect_vec(); - //let mut n_axes_pos = non_trivial_n_axes.iter().map(|axis| { axis.inputs[1][0] }).collect_vec(); - //n_axes_pos.sort(); -// - //let consecutive_n_axes = n_axes_pos.windows(2).all(|window| { (window[1] - window[0]) == 1}); - if non_trivial_n_axes.len() > 1 { + let mut n_axes_pos = non_trivial_n_axes.iter().map(|axis| { axis.inputs[1][0] }).collect_vec(); + n_axes_pos.sort(); + + let consecutive_n_axes = n_axes_pos.windows(2).all(|window| { (window[1] - window[0]) == 1}); + if non_trivial_n_axes.len() > 1 && !consecutive_n_axes { return Ok(AxesOrPatch::NotAMatMul(non_trivial_n_axes)) } diff --git a/test-rt/suite-unit/src/bin_einsum.rs b/test-rt/suite-unit/src/bin_einsum.rs index a7e2719c9c..e8b5ecf7c5 100644 --- a/test-rt/suite-unit/src/bin_einsum.rs +++ b/test-rt/suite-unit/src/bin_einsum.rs @@ -11,7 +11,7 @@ use tract_num_traits::{One, Zero}; #[derive(Debug, Clone, Default)] pub struct BinEinsumProblemParams { - pub no_iter_axes: bool, + pub force_unique_non_trivial_m_n: bool, } #[derive(Clone)] @@ -38,17 +38,23 @@ impl Arbitrary for BinEinsumProblem { type Parameters = BinEinsumProblemParams; type Strategy = BoxedStrategy; - fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { - (1..2usize, 1..2usize, 0..3usize, 0..2usize, 0..2usize) + fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { + let m_n_axes_range = if args.force_unique_non_trivial_m_n { 1..2usize } else { 1..3usize }; + (m_n_axes_range.clone(), m_n_axes_range, 0..3usize, 0..2usize, 0..2usize) .prop_map(|(m_axes, n_axes, iter_axes, trivial_m_axes, trivial_n_axes)| { let m_axes: String = ('a'..).take(m_axes).collect(); let trivial_m_axes: String = ('m'..).take(trivial_m_axes).collect(); let trivial_n_axes: String = 
('p'..).take(trivial_n_axes).collect(); let n_axes: String = ('g'..).take(n_axes).collect(); let iter_axes: String = ('w'..).take(iter_axes).collect(); - let a_axes: Vec = (m_axes.clone() + &trivial_m_axes + &iter_axes + "k").chars().collect(); - let b_axes: Vec = (n_axes.clone() + &trivial_n_axes + &iter_axes + "k").chars().collect(); - let c_axes: Vec = (m_axes + &n_axes + &trivial_m_axes + &trivial_n_axes + &iter_axes).chars().collect(); + let a_axes: Vec = + (m_axes.clone() + &trivial_m_axes + &iter_axes + "k").chars().collect(); + let b_axes: Vec = + (n_axes.clone() + &trivial_n_axes + &iter_axes + "k").chars().collect(); + let c_axes: Vec = + (m_axes + &n_axes + &trivial_m_axes + &trivial_n_axes + &iter_axes) + .chars() + .collect(); (Just(a_axes), Just(b_axes), Just(c_axes)) }) .prop_flat_map(|(a, b, c)| (a.prop_shuffle(), b.prop_shuffle(), c.prop_shuffle())) @@ -67,18 +73,30 @@ impl Arbitrary for BinEinsumProblem { .prop_flat_map(|(expr, axis_dims)| { let shape_a: TVec = expr .axes(InOut::In(0)) - .map(|axis| {if !('m'..='v').contains(&axis.repr) { expr.iter_all_axes().position(|x| x == axis).unwrap() } else {1000} }) - .map(|dim| {if dim != 1000 {axis_dims[dim]} else { 1 } }) + .map(|axis| { + expr.iter_all_axes() + .position(|x| (x == axis) && !('m'..='v').contains(&axis.repr)) + .map(|dim| axis_dims[dim]) + .unwrap_or(1) + }) .collect(); let shape_b: TVec = expr .axes(InOut::In(1)) - .map(|axis| {if !('m'..='v').contains(&axis.repr) { expr.iter_all_axes().position(|x| x == axis).unwrap() } else {1000} }) - .map(|dim| {if dim != 1000 {axis_dims[dim]} else { 1 } }) + .map(|axis| { + expr.iter_all_axes() + .position(|x| (x == axis) && !('m'..='v').contains(&axis.repr)) + .map(|dim| axis_dims[dim]) + .unwrap_or(1) + }) .collect(); let shape_output: TVec = expr .axes(InOut::Out(0)) - .map(|axis| {if !('m'..='v').contains(&axis.repr) { expr.iter_all_axes().position(|x| x == axis).unwrap() } else {1000} }) - .map(|dim| {if dim != 1000 {axis_dims[dim]} else { 1 
} }) + .map(|axis| { + expr.iter_all_axes() + .position(|x| (x == axis) && !('m'..='v').contains(&axis.repr)) + .map(|dim| axis_dims[dim]) + .unwrap_or(1) + }) .collect(); let unicast_add_constant = proptest::option::of(tensor(&shape_output)); (Just(expr), tensor(&shape_a), tensor(&shape_b), 0..3usize, unicast_add_constant) @@ -139,8 +157,8 @@ impl BinEinsumProblem { self.expr .axes(InOut::Out(0)) .map(|axis| { - let dim_in_a = axis.inputs[0].get(0).map(|pos| self.a.shape()[*pos]).unwrap_or(1); - let dim_in_b = axis.inputs[1].get(0).map(|pos| self.b.shape()[*pos]).unwrap_or(1); + let dim_in_a = axis.inputs[0].first().map(|pos| self.a.shape()[*pos]).unwrap_or(1); + let dim_in_b = axis.inputs[1].first().map(|pos| self.b.shape()[*pos]).unwrap_or(1); dim_in_a.max(dim_in_b) }) .collect() @@ -159,15 +177,15 @@ impl BinEinsumProblem { .expr .iter_all_axes() .filter(|axis| { - axis.outputs[0].len() == 0 && axis.inputs[0].len() == 1 && axis.inputs[1].len() == 1 + axis.outputs[0].is_empty() && axis.inputs[0].len() == 1 && axis.inputs[1].len() == 1 }) .collect(); let summing_shape: TVec = k_axes .iter() .map(|axis| { - let dim_in_a = axis.inputs[0].get(0).map(|pos| self.a.shape()[*pos]).unwrap_or(1); - let dim_in_b = axis.inputs[1].get(0).map(|pos| self.b.shape()[*pos]).unwrap_or(1); + let dim_in_a = axis.inputs[0].first().map(|pos| self.a.shape()[*pos]).unwrap_or(1); + let dim_in_b = axis.inputs[1].first().map(|pos| self.b.shape()[*pos]).unwrap_or(1); dim_in_a.max(dim_in_b) }) .collect(); @@ -177,11 +195,11 @@ impl BinEinsumProblem { let mut a = a.clone(); let mut b = b.clone(); for (axis, x) in self.expr.axes(InOut::Out(0)).zip(coords.iter()) { - if let Some(pos) = axis.inputs[0].get(0) { + if let Some(pos) = axis.inputs[0].first() { a.collapse_axis(Axis(*pos), if a.shape()[*pos] > 1 { *x } else { 0 }); } - if let Some(pos) = axis.inputs[1].get(0) { + if let Some(pos) = axis.inputs[1].first() { b.collapse_axis(Axis(*pos), if b.shape()[*pos] > 1 { *x } else { 0 }); } } diff 
--git a/test-rt/suite-unit/src/conv_f32.rs b/test-rt/suite-unit/src/conv_f32.rs index b050b89a24..dcd713b052 100644 --- a/test-rt/suite-unit/src/conv_f32.rs +++ b/test-rt/suite-unit/src/conv_f32.rs @@ -10,6 +10,7 @@ pub struct ConvProblemParams { pub no_group: bool, pub no_arbitrary_grouping: bool, pub geo_rank: Option>, + pub no_batch: bool, } #[derive(Debug, Clone)] @@ -200,19 +201,29 @@ impl Arbitrary for ConvProblem { type Parameters = ConvProblemParams; type Strategy = BoxedStrategy; fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { + let batch_range = if params.no_batch { 1usize..=1 } else { 1usize..=3 }; let geo_rank = params.geo_rank.unwrap_or(1..4); ( data_format(), kernel_format(), prop_oneof![Just(PaddingSpec::Valid), Just(PaddingSpec::SameUpper)], - 1usize..=3, + batch_range, 1usize..=4, 1usize..=4, 1usize..=(if params.no_group { 1 } else { 3 }), geo_rank.prop_flat_map(shapes), ) .prop_flat_map( - move |(df, kf, pad, n, mut ci0, mut co0, group, (mut ker_shape, data_shape))| { + move |( + df, + kf, + pad, + batch, + mut ci0, + mut co0, + group, + (mut ker_shape, data_shape), + )| { // FIXME in HWIO order, only regular and depthwise are supported if params.no_arbitrary_grouping && group > 1 { ci0 = 1; @@ -221,7 +232,7 @@ impl Arbitrary for ConvProblem { if kf == KernelFormat::HWIO && group > 1 { ci0 = 1; } - let shape_in = df.from_n_c_hw(n, ci0 * group, data_shape).unwrap(); + let shape_in = df.from_n_c_hw(batch, ci0 * group, data_shape).unwrap(); let data_in = tensor(&*shape_in.shape); match kf { KernelFormat::HWIO => { @@ -268,7 +279,7 @@ impl Test for ConvProblem { approx: Approximation, ) -> TestResult { let reference = self.reference().into_tensor(); - dbg!(&reference); + // dbg!(&reference); let mut model = self.tract()?; // dbg!(&model); model.declutter()?; diff --git a/test-rt/suite-unit/src/conv_q.rs b/test-rt/suite-unit/src/conv_q.rs index e31088eada..50c2ee701b 100644 --- a/test-rt/suite-unit/src/conv_q.rs +++ 
b/test-rt/suite-unit/src/conv_q.rs @@ -1230,7 +1230,7 @@ pub fn suite() -> TractResult { co: 2, kernel_format: OIHW, group: 2, - kernel: tensor4(&[[[[1i8]]],[[[0i8]]]]), + kernel: tensor4(&[[[[1i8]]], [[[0i8]]]]), bias: None, data: Tensor::zero::(&[2, 4, 4, 2]).unwrap(), qp, diff --git a/test-rt/suite-unit/src/deconv.rs b/test-rt/suite-unit/src/deconv.rs index 369ecadf30..2f085a5457 100644 --- a/test-rt/suite-unit/src/deconv.rs +++ b/test-rt/suite-unit/src/deconv.rs @@ -145,8 +145,7 @@ impl DeconvProblem { self.kernel_format.input_channels(self.kernel.shape(), self.group).into_owned(), self.kernel_format.output_channels(self.kernel.shape(), self.group).into_owned(), ); - let op = - Deconv::new(pool_spec, self.kernel_format, self.adjustments.clone(), self.group); + let op = Deconv::new(pool_spec, self.kernel_format, self.adjustments.clone(), self.group); Ok(op) } diff --git a/test-rt/suite-unit/src/q_binary.rs b/test-rt/suite-unit/src/q_binary.rs index f59bdc8e95..4e0c5f3a34 100644 --- a/test-rt/suite-unit/src/q_binary.rs +++ b/test-rt/suite-unit/src/q_binary.rs @@ -268,7 +268,7 @@ pub fn suite() -> TractResult { c_dt: qu8_dt(1, 0.5), }, ); - + suite.add( "bug_aligned_dt_0", QBinaryOpProblem { diff --git a/test-rt/test-metal/suite.rs b/test-rt/test-metal/suite.rs index edd9c4ab95..c7a3b18a72 100644 --- a/test-rt/test-metal/suite.rs +++ b/test-rt/test-metal/suite.rs @@ -1,4 +1,6 @@ use infra::Test; +use suite_unit::bin_einsum::{BinEinsumProblem, BinEinsumProblemParams}; +use suite_unit::conv_f32::{ConvProblem, ConvProblemParams}; pub fn suite() -> &'static infra::TestSuite { lazy_static::lazy_static! 
{ @@ -13,13 +15,27 @@ fn mk_suite() -> infra::TestSuite { onnx.ignore(&ignore_onnx); let mut unit = suite_unit::suite().unwrap().clone(); + unit.get_sub_mut("bin_einsum").add_arbitrary::( + "proptest", + BinEinsumProblemParams { + force_unique_non_trivial_m_n: true, + ..BinEinsumProblemParams::default() + }, + ); + unit.get_sub_mut("conv_f32").add_arbitrary::( + "proptest", + ConvProblemParams { no_batch: true, ..ConvProblemParams::default() }, + ); + unit.ignore_case(&ignore_unit); infra::TestSuite::default().with("onnx", onnx).with("unit", unit) } -fn ignore_unit(_t: &[String], _case: &dyn Test) -> bool { - false +fn ignore_unit(t: &[String], case: &dyn Test) -> bool { + case.is::() + || case.is::() + || (t[0] == "conv_f32" && t[1] == "bug_metal_0") } fn ignore_onnx(t: &[String]) -> bool { From da8a3c1aaa30a558edb83cad53352a4477abdde2 Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Wed, 29 Jan 2025 16:52:41 +0100 Subject: [PATCH 3/9] Removed broken edge cases management handled by arbitrary --- core/src/ops/einsum/optimize.rs | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/core/src/ops/einsum/optimize.rs b/core/src/ops/einsum/optimize.rs index cccb3dbf33..a57f4ed455 100644 --- a/core/src/ops/einsum/optimize.rs +++ b/core/src/ops/einsum/optimize.rs @@ -115,22 +115,6 @@ pub(crate) fn ensure_mkn_axes<'a>( let Some(k_axis) = k_axis else { return Ok(AxesOrPatch::Patch(inject_k_axis(op, model, node)?)); }; - - let non_trivial_m_axes = op.axes.iter_all_axes().filter(|a| { - a.inputs[0].len() == 1 - && a.outputs[0].len() == 1 - && a.inputs[1].len() == 0 - && !input_shapes[0][a.inputs[0][0]].is_one() - }).collect_vec(); - - let mut m_axes_pos = non_trivial_m_axes.iter().map(|axis| { axis.inputs[0][0] }).collect_vec(); - m_axes_pos.sort(); - - let consecutive_m_axes = m_axes_pos.windows(2).all(|window| { (window[1] - window[0]) == 1}); - - if non_trivial_m_axes.len() > 1 && !consecutive_m_axes { - return
Ok(AxesOrPatch::NotAMatMul(non_trivial_m_axes)) - } let m_axis = op .axes @@ -145,21 +129,6 @@ pub(crate) fn ensure_mkn_axes<'a>( return Ok(AxesOrPatch::Patch(inject_m_or_n_axis(op, model, node, false, &[k_axis])?)); }; - let non_trivial_n_axes = op.axes.iter_all_axes().filter(|a| { - a.inputs[1].len() == 1 - && a.outputs[0].len() == 1 - && a.inputs[0].len() == 0 - && !input_shapes[1][a.inputs[1][0]].is_one() - }).collect_vec(); - - let mut n_axes_pos = non_trivial_n_axes.iter().map(|axis| { axis.inputs[1][0] }).collect_vec(); - n_axes_pos.sort(); - - let consecutive_n_axes = n_axes_pos.windows(2).all(|window| { (window[1] - window[0]) == 1}); - if non_trivial_n_axes.len() > 1 && !consecutive_n_axes { - return Ok(AxesOrPatch::NotAMatMul(non_trivial_n_axes)) - } - let n_axis = op .axes .iter_all_axes() From c13ab41c77852fc1272c02016a2e6b98e3dff8ea Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Wed, 29 Jan 2025 16:54:59 +0100 Subject: [PATCH 4/9] Limit inputs to 8 for Nnef tests --- test-rt/suite-unit/src/bin_einsum.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-rt/suite-unit/src/bin_einsum.rs b/test-rt/suite-unit/src/bin_einsum.rs index e8b5ecf7c5..28c231ab3e 100644 --- a/test-rt/suite-unit/src/bin_einsum.rs +++ b/test-rt/suite-unit/src/bin_einsum.rs @@ -40,7 +40,7 @@ impl Arbitrary for BinEinsumProblem { fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { let m_n_axes_range = if args.force_unique_non_trivial_m_n { 1..2usize } else { 1..3usize }; - (m_n_axes_range.clone(), m_n_axes_range, 0..3usize, 0..2usize, 0..2usize) + (m_n_axes_range.clone(), m_n_axes_range, 0..2usize, 0..2usize, 0..2usize) .prop_map(|(m_axes, n_axes, iter_axes, trivial_m_axes, trivial_n_axes)| { let m_axes: String = ('a'..).take(m_axes).collect(); let trivial_m_axes: String = ('m'..).take(trivial_m_axes).collect(); From b0c6c019529667f8862a5f4a88f95cbb378bd78d Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Thu, 30 Jan 2025 10:20:13 +0100 
Subject: [PATCH 5/9] no trivial axes for nnef and tf tests --- test-rt/suite-unit/src/bin_einsum.rs | 9 ++++++--- test-rt/test-nnef-cycle/suite.rs | 5 +++++ test-rt/test-tflite/suite.rs | 7 +++++++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/test-rt/suite-unit/src/bin_einsum.rs b/test-rt/suite-unit/src/bin_einsum.rs index 28c231ab3e..69501fc879 100644 --- a/test-rt/suite-unit/src/bin_einsum.rs +++ b/test-rt/suite-unit/src/bin_einsum.rs @@ -12,6 +12,7 @@ use tract_num_traits::{One, Zero}; #[derive(Debug, Clone, Default)] pub struct BinEinsumProblemParams { pub force_unique_non_trivial_m_n: bool, + pub no_trivial_axes: bool, } #[derive(Clone)] @@ -38,9 +39,11 @@ impl Arbitrary for BinEinsumProblem { type Parameters = BinEinsumProblemParams; type Strategy = BoxedStrategy; - fn arbitrary_with(args: Self::Parameters) -> Self::Strategy { - let m_n_axes_range = if args.force_unique_non_trivial_m_n { 1..2usize } else { 1..3usize }; - (m_n_axes_range.clone(), m_n_axes_range, 0..2usize, 0..2usize, 0..2usize) + fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { + let m_n_axes_range = if params.force_unique_non_trivial_m_n { 1..2usize } else { 1..3usize }; + let trivial_axes_range = if params.no_trivial_axes { 0..1usize } else { 0..2usize }; + + (m_n_axes_range.clone(), m_n_axes_range, 0..3usize, trivial_axes_range.clone(), trivial_axes_range) .prop_map(|(m_axes, n_axes, iter_axes, trivial_m_axes, trivial_n_axes)| { let m_axes: String = ('a'..).take(m_axes).collect(); let trivial_m_axes: String = ('m'..).take(trivial_m_axes).collect(); diff --git a/test-rt/test-nnef-cycle/suite.rs b/test-rt/test-nnef-cycle/suite.rs index 8cf8d3366b..7c79474b81 100644 --- a/test-rt/test-nnef-cycle/suite.rs +++ b/test-rt/test-nnef-cycle/suite.rs @@ -1,4 +1,5 @@ use infra::Test; +use suite_unit::bin_einsum::{BinEinsumProblem, BinEinsumProblemParams}; use suite_unit::conv_q::{QConvProblem, QConvProblemParams}; pub fn suite() -> &'static infra::TestSuite { @@ -21,6
+22,9 @@ fn mk_suite() -> infra::TestSuite { compatible_conv_q, ); + let einsum_params = BinEinsumProblemParams {no_trivial_axes: true, ..BinEinsumProblemParams::default()}; + unit.get_sub_mut("bin_einsum").add_arbitrary::("proptest", einsum_params.clone()); + infra::TestSuite::default().with("onnx", onnx).with("unit", unit) } @@ -84,6 +88,7 @@ fn ignore_unit(t: &[String], tc: &dyn Test) -> bool { if let Some(qcp) = tc.downcast_ref::() { return !compatible_conv_q(qcp); } + if t[0] == "bin_einsum" && t[1] == "proptest" { return true; } false } diff --git a/test-rt/test-tflite/suite.rs b/test-rt/test-tflite/suite.rs index 28058db688..fc31242ec1 100644 --- a/test-rt/test-tflite/suite.rs +++ b/test-rt/test-tflite/suite.rs @@ -1,5 +1,6 @@ use infra::Test; use regex::Regex; +use suite_unit::bin_einsum::{BinEinsumProblem, BinEinsumProblemParams}; use suite_unit::conv_f32::{ConvProblem, ConvProblemParams}; use suite_unit::conv_q::{QConvProblem, QConvProblemParams}; use tract_core::internal::*; @@ -27,6 +28,9 @@ fn mk_suite() -> infra::TestSuite { QConvProblemParams { conv: cv, tflite_rules: true, ..QConvProblemParams::default() }, compatible_conv_q, ); + + let einsum_params = BinEinsumProblemParams {no_trivial_axes: true, ..BinEinsumProblemParams::default()}; + unit.get_sub_mut("bin_einsum").add_arbitrary::("proptest", einsum_params.clone()); infra::TestSuite::default().with("onnx", onnx).with("unit", unit) } @@ -146,6 +150,9 @@ fn ignore_unit(t: &[String], case: &dyn Test) -> bool { return true; } } + + if t[0] == "bin_einsum" && t[1] == "proptest" { return true; } + let [section, _unit] = t else { return false }; ["deconv", "q_flavours", "q_binary", "q_elmwise"].contains(&&**section) } From c563972a34d9932b4d5e765e4eabd98f140acacc Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Thu, 30 Jan 2025 12:55:18 +0100 Subject: [PATCH 6/9] force 1 iter axis for tflite --- test-rt/suite-unit/src/bin_einsum.rs | 5 +++-- test-rt/test-tflite/suite.rs | 6 +++--- 2 files 
changed, 6 insertions(+), 5 deletions(-) diff --git a/test-rt/suite-unit/src/bin_einsum.rs b/test-rt/suite-unit/src/bin_einsum.rs index 69501fc879..3813dfbcd4 100644 --- a/test-rt/suite-unit/src/bin_einsum.rs +++ b/test-rt/suite-unit/src/bin_einsum.rs @@ -13,6 +13,7 @@ use tract_num_traits::{One, Zero}; pub struct BinEinsumProblemParams { pub force_unique_non_trivial_m_n: bool, pub no_trivial_axes: bool, + pub force_max_one_iter_axis: bool, } #[derive(Clone)] @@ -42,8 +43,8 @@ impl Arbitrary for BinEinsumProblem { fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { let m_n_axes_range = if params.force_unique_non_trivial_m_n { 1..2usize } else { 1..3usize }; let trivial_axes_range = if params.no_trivial_axes { 0..1usize } else { 0..2usize }; - - (m_n_axes_range.clone(), m_n_axes_range, 0..3usize, trivial_axes_range.clone(), trivial_axes_range) + let iter_axes_range = if params.force_max_one_iter_axis { 0..2usize } else { 0..3usize }; + (m_n_axes_range.clone(), m_n_axes_range, iter_axes_range, trivial_axes_range.clone(), trivial_axes_range) .prop_map(|(m_axes, n_axes, iter_axes, trivial_m_axes, trivial_n_axes)| { let m_axes: String = ('a'..).take(m_axes).collect(); let trivial_m_axes: String = ('m'..).take(trivial_m_axes).collect(); diff --git a/test-rt/test-tflite/suite.rs b/test-rt/test-tflite/suite.rs index fc31242ec1..c6cc78d9dd 100644 --- a/test-rt/test-tflite/suite.rs +++ b/test-rt/test-tflite/suite.rs @@ -29,7 +29,7 @@ fn mk_suite() -> infra::TestSuite { compatible_conv_q, ); - let einsum_params = BinEinsumProblemParams {no_trivial_axes: true, ..BinEinsumProblemParams::default()}; + let einsum_params = BinEinsumProblemParams {no_trivial_axes: true, force_max_one_iter_axis: true, ..BinEinsumProblemParams::default()}; unit.get_sub_mut("bin_einsum").add_arbitrary::("proptest", einsum_params.clone()); infra::TestSuite::default().with("onnx", onnx).with("unit", unit) } @@ -150,9 +150,9 @@ fn ignore_unit(t: &[String], case: &dyn Test) -> bool { return 
true; } } - + if t[0] == "bin_einsum" && t[1] == "proptest" { return true; } - + let [section, _unit] = t else { return false }; ["deconv", "q_flavours", "q_binary", "q_elmwise"].contains(&&**section) } From 28fccf2a78a6f8c506928ade38aea72a456ea7b0 Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Thu, 30 Jan 2025 13:58:46 +0100 Subject: [PATCH 7/9] force dim <= 4 for tflite --- test-rt/test-tflite/suite.rs | 2 +- tflite/src/rewriter.rs | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/test-rt/test-tflite/suite.rs b/test-rt/test-tflite/suite.rs index c6cc78d9dd..5f54552b4b 100644 --- a/test-rt/test-tflite/suite.rs +++ b/test-rt/test-tflite/suite.rs @@ -29,7 +29,7 @@ fn mk_suite() -> infra::TestSuite { compatible_conv_q, ); - let einsum_params = BinEinsumProblemParams {no_trivial_axes: true, force_max_one_iter_axis: true, ..BinEinsumProblemParams::default()}; + let einsum_params = BinEinsumProblemParams { force_unique_non_trivial_m_n: true, no_trivial_axes: true, ..BinEinsumProblemParams::default()}; unit.get_sub_mut("bin_einsum").add_arbitrary::("proptest", einsum_params.clone()); infra::TestSuite::default().with("onnx", onnx).with("unit", unit) } diff --git a/tflite/src/rewriter.rs b/tflite/src/rewriter.rs index 9d6aa409cc..f2ade9fd67 100644 --- a/tflite/src/rewriter.rs +++ b/tflite/src/rewriter.rs @@ -43,6 +43,8 @@ fn trivial_axes_around_matmul( let trivial_axes = (0..rank - 2) .filter(|axis| facts[0].shape[*axis].is_one() && facts[1].shape[*axis].is_one()) .collect_vec(); + + ensure!(!trivial_axes.is_empty(), "Found Einsum with 4 > axes and no trivial axes"); let mut patch = TypedModelPatch::default(); let mut wire = patch.taps(model, &node.inputs)?; for axis in trivial_axes.iter().rev() { From d1af7c7f21aad60f0cc1fad92433cb7721e8bda9 Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Thu, 30 Jan 2025 14:48:24 +0100 Subject: [PATCH 8/9] properly handle dimension limits --- test-rt/suite-unit/src/bin_einsum.rs | 83 
++++++++++++++++++++++++---- test-rt/test-f16/build.rs | 10 +++- test-rt/test-f16/suite.rs | 2 + test-rt/test-nnef-cycle/build.rs | 10 +++- test-rt/test-nnef-cycle/suite.rs | 5 -- test-rt/test-tflite/build.rs | 10 +++- test-rt/test-tflite/suite.rs | 15 +++-- 7 files changed, 105 insertions(+), 30 deletions(-) diff --git a/test-rt/suite-unit/src/bin_einsum.rs b/test-rt/suite-unit/src/bin_einsum.rs index 3813dfbcd4..507c2a919b 100644 --- a/test-rt/suite-unit/src/bin_einsum.rs +++ b/test-rt/suite-unit/src/bin_einsum.rs @@ -9,11 +9,23 @@ use tract_ndarray::{ArrayD, Axis, Dimension}; use tract_core::ops::einsum::EinSum; use tract_num_traits::{One, Zero}; -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone)] pub struct BinEinsumProblemParams { pub force_unique_non_trivial_m_n: bool, pub no_trivial_axes: bool, pub force_max_one_iter_axis: bool, + pub max_dims: usize, +} + +impl Default for BinEinsumProblemParams { + fn default() -> BinEinsumProblemParams { + BinEinsumProblemParams { + force_unique_non_trivial_m_n: false, + no_trivial_axes: false, + force_max_one_iter_axis: false, + max_dims: 8, + } + } } #[derive(Clone)] @@ -41,22 +53,71 @@ impl Arbitrary for BinEinsumProblem { type Strategy = BoxedStrategy; fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { - let m_n_axes_range = if params.force_unique_non_trivial_m_n { 1..2usize } else { 1..3usize }; - let trivial_axes_range = if params.no_trivial_axes { 0..1usize } else { 0..2usize }; - let iter_axes_range = if params.force_max_one_iter_axis { 0..2usize } else { 0..3usize }; - (m_n_axes_range.clone(), m_n_axes_range, iter_axes_range, trivial_axes_range.clone(), trivial_axes_range) - .prop_map(|(m_axes, n_axes, iter_axes, trivial_m_axes, trivial_n_axes)| { - let m_axes: String = ('a'..).take(m_axes).collect(); + let supp_m_n_axes_range = + if params.force_unique_non_trivial_m_n { 0..1usize } else { 0..2usize }; + assert!(params.max_dims >= 3); + let remaining = params.max_dims - 3; // At least 1 m, and 
n + + supp_m_n_axes_range + .clone() + .prop_flat_map(move |supp_m_axes| { + let remaining = remaining - supp_m_axes; + let n_axes_range = if remaining < supp_m_n_axes_range.end { + 0..(remaining + 1) + } else { + supp_m_n_axes_range.clone() + }; + let iter_axes_range = + if params.force_max_one_iter_axis { 0..2usize } else { 0..3usize }; + n_axes_range.prop_flat_map(move |supp_n_axes| { + let remaining = remaining - supp_n_axes; + let iter_axes_range = if remaining < iter_axes_range.end { + 0..(remaining + 1) + } else { + iter_axes_range.clone() + }; + iter_axes_range.clone().prop_flat_map(move |iter_axes| { + let remaining = remaining - iter_axes; + let trivial_m_n_axes_range = + if params.no_trivial_axes { 0..1usize } else { 0..2usize }; + let trivial_m_axes_range = if remaining < trivial_m_n_axes_range.end { + 0..(remaining + 1) + } else { + trivial_m_n_axes_range.clone() + }; + trivial_m_axes_range.clone().prop_flat_map(move |trivial_m_axes| { + let remaining = remaining - trivial_m_axes; + let trivial_n_axes_range = if remaining < trivial_m_n_axes_range.end { + 0..(remaining + 1) + } else { + trivial_m_n_axes_range.clone() + }; + trivial_n_axes_range.clone().prop_flat_map(move |trivial_n_axes| { + Just(( + supp_m_axes, + supp_n_axes, + iter_axes, + trivial_m_axes, + trivial_n_axes, + )) + }) + }) + }) + }) + }) + .prop_map(|(supp_m_axes, supp_n_axes, iter_axes, trivial_m_axes, trivial_n_axes)| { + dbg!(supp_m_axes, supp_n_axes, iter_axes, trivial_m_axes, trivial_n_axes); + let m_axes: String = ('b'..).take(supp_m_axes).collect(); let trivial_m_axes: String = ('m'..).take(trivial_m_axes).collect(); let trivial_n_axes: String = ('p'..).take(trivial_n_axes).collect(); - let n_axes: String = ('g'..).take(n_axes).collect(); + let n_axes: String = ('h'..).take(supp_n_axes).collect(); let iter_axes: String = ('w'..).take(iter_axes).collect(); let a_axes: Vec = - (m_axes.clone() + &trivial_m_axes + &iter_axes + "k").chars().collect(); + (m_axes.clone() + "a" + 
&trivial_m_axes + &iter_axes + "k").chars().collect(); let b_axes: Vec = - (n_axes.clone() + &trivial_n_axes + &iter_axes + "k").chars().collect(); + (n_axes.clone() + "g" + &trivial_n_axes + &iter_axes + "k").chars().collect(); let c_axes: Vec = - (m_axes + &n_axes + &trivial_m_axes + &trivial_n_axes + &iter_axes) + (m_axes + &n_axes + "ag" + &trivial_m_axes + &trivial_n_axes + &iter_axes) .chars() .collect(); (Just(a_axes), Just(b_axes), Just(c_axes)) diff --git a/test-rt/test-f16/build.rs b/test-rt/test-f16/build.rs index 1d07109787..79117e02a2 100644 --- a/test-rt/test-f16/build.rs +++ b/test-rt/test-f16/build.rs @@ -1,7 +1,11 @@ -#[path="suite.rs"] +#[path = "suite.rs"] mod suite; fn main() { - suite::suite().test_runtime("tests", "suite::suite()", "runtime()", "Approximation::SuperApproximate"); + suite::suite().test_runtime( + "tests", + "suite::suite()", + "runtime()", + "Approximation::SuperApproximate", + ); } - diff --git a/test-rt/test-f16/suite.rs b/test-rt/test-f16/suite.rs index f98cbe15a3..f0822b50b2 100644 --- a/test-rt/test-f16/suite.rs +++ b/test-rt/test-f16/suite.rs @@ -19,6 +19,7 @@ fn mk_suite() -> infra::TestSuite { QConvProblemParams::default(), compatible_conv_q, ); + infra::TestSuite::default().with("onnx", onnx).with("unit", unit) } @@ -28,6 +29,7 @@ fn ignore_unit(t: &[String], case: &dyn Test) -> bool { return true; } } + let [section, _unit] = t else { return false }; ["q_flavours"].contains(&&**section) } diff --git a/test-rt/test-nnef-cycle/build.rs b/test-rt/test-nnef-cycle/build.rs index 67d54cf809..390e8f9843 100644 --- a/test-rt/test-nnef-cycle/build.rs +++ b/test-rt/test-nnef-cycle/build.rs @@ -1,7 +1,11 @@ -#[path="suite.rs"] +#[path = "suite.rs"] mod suite; fn main() { - suite::suite().test_runtime("nnef_cycle", "suite::suite()", "runtime()", "Approximation::Approximate"); + suite::suite().test_runtime( + "nnef_cycle", + "suite::suite()", + "runtime()", + "Approximation::Approximate", + ); } - diff --git 
a/test-rt/test-nnef-cycle/suite.rs b/test-rt/test-nnef-cycle/suite.rs index 7c79474b81..8cf8d3366b 100644 --- a/test-rt/test-nnef-cycle/suite.rs +++ b/test-rt/test-nnef-cycle/suite.rs @@ -1,5 +1,4 @@ use infra::Test; -use suite_unit::bin_einsum::{BinEinsumProblem, BinEinsumProblemParams}; use suite_unit::conv_q::{QConvProblem, QConvProblemParams}; pub fn suite() -> &'static infra::TestSuite { @@ -22,9 +21,6 @@ fn mk_suite() -> infra::TestSuite { compatible_conv_q, ); - let einsum_params = BinEinsumProblemParams {no_trivial_axes: true, ..BinEinsumProblemParams::default()}; - unit.get_sub_mut("bin_einsum").add_arbitrary::("proptest", einsum_params.clone()); - infra::TestSuite::default().with("onnx", onnx).with("unit", unit) } @@ -88,7 +84,6 @@ fn ignore_unit(t: &[String], tc: &dyn Test) -> bool { if let Some(qcp) = tc.downcast_ref::() { return !compatible_conv_q(qcp); } - if t[0] == "bin_einsum" && t[1] == "proptest" { return true; } false } diff --git a/test-rt/test-tflite/build.rs b/test-rt/test-tflite/build.rs index 09d4238f2d..61d18726aa 100644 --- a/test-rt/test-tflite/build.rs +++ b/test-rt/test-tflite/build.rs @@ -1,7 +1,11 @@ -#[path="suite.rs"] +#[path = "suite.rs"] mod suite; fn main() { - suite::suite().test_runtime("tests", "suite::suite()", "runtime()", "Approximation::Approximate"); + suite::suite().test_runtime( + "tests", + "suite::suite()", + "runtime()", + "Approximation::Approximate", + ); } - diff --git a/test-rt/test-tflite/suite.rs b/test-rt/test-tflite/suite.rs index 5f54552b4b..edfd7ae7b7 100644 --- a/test-rt/test-tflite/suite.rs +++ b/test-rt/test-tflite/suite.rs @@ -29,8 +29,9 @@ fn mk_suite() -> infra::TestSuite { compatible_conv_q, ); - let einsum_params = BinEinsumProblemParams { force_unique_non_trivial_m_n: true, no_trivial_axes: true, ..BinEinsumProblemParams::default()}; - unit.get_sub_mut("bin_einsum").add_arbitrary::("proptest", einsum_params.clone()); + let einsum_params = BinEinsumProblemParams { max_dims: 4, 
..BinEinsumProblemParams::default() }; + unit.get_sub_mut("bin_einsum") + .add_arbitrary::("proptest", einsum_params.clone()); infra::TestSuite::default().with("onnx", onnx).with("unit", unit) } @@ -104,7 +105,8 @@ fn ignore_onnx(t: &[String]) -> bool { test_thresholdrelu ", ); - let excluded = patterns(" + let excluded = patterns( + " test_slice_start_out_of_bounds test_Conv1d_groups test_Conv2d_groups @@ -122,7 +124,8 @@ fn ignore_onnx(t: &[String]) -> bool { pool_2d_same_lower test_cosh.* test_sinh.* - "); + ", + ); !included.iter().any(|pat| pat.is_match(name)) || excluded.iter().any(|pat| pat.is_match(name)) } @@ -151,7 +154,9 @@ fn ignore_unit(t: &[String], case: &dyn Test) -> bool { } } - if t[0] == "bin_einsum" && t[1] == "proptest" { return true; } + if t[0] == "bin_einsum" && t[1] == "proptest" { + return true; + } let [section, _unit] = t else { return false }; ["deconv", "q_flavours", "q_binary", "q_elmwise"].contains(&&**section) From 1e764055e0080af16b4d0d356cbae4c339a6c680 Mon Sep 17 00:00:00 2001 From: LouisChourakiSonos Date: Thu, 30 Jan 2025 16:57:47 +0100 Subject: [PATCH 9/9] renamed axes and add back support for unicas constant --- core/src/ops/einsum/as_matmul.rs | 8 ++++++-- test-rt/suite-unit/src/bin_einsum.rs | 30 ++++++++++++++++------------ 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/core/src/ops/einsum/as_matmul.rs b/core/src/ops/einsum/as_matmul.rs index 2ca8e7ae0b..f61ffc7a5c 100644 --- a/core/src/ops/einsum/as_matmul.rs +++ b/core/src/ops/einsum/as_matmul.rs @@ -38,8 +38,12 @@ fn einsum_rules( AxesOrPatch::Annotated(op) => op, AxesOrPatch::Patch(p) => return Ok(Some(p)), AxesOrPatch::NotAMatMul(axes) => { - bail!("{} is not a matmul because of axis {}", op.axes, axes.iter().map(|a| a.repr).join(", ") ) - } + bail!( + "{} is not a matmul because of axis {}", + op.axes, + axes.iter().map(|a| a.repr).join(", ") + ) + } }; let prefix: String = op .axes diff --git a/test-rt/suite-unit/src/bin_einsum.rs 
b/test-rt/suite-unit/src/bin_einsum.rs index 507c2a919b..c67b4755d8 100644 --- a/test-rt/suite-unit/src/bin_einsum.rs +++ b/test-rt/suite-unit/src/bin_einsum.rs @@ -106,18 +106,17 @@ impl Arbitrary for BinEinsumProblem { }) }) .prop_map(|(supp_m_axes, supp_n_axes, iter_axes, trivial_m_axes, trivial_n_axes)| { - dbg!(supp_m_axes, supp_n_axes, iter_axes, trivial_m_axes, trivial_n_axes); - let m_axes: String = ('b'..).take(supp_m_axes).collect(); - let trivial_m_axes: String = ('m'..).take(trivial_m_axes).collect(); - let trivial_n_axes: String = ('p'..).take(trivial_n_axes).collect(); + let m_axes: String = ('a'..).take(supp_m_axes).collect(); + let trivial_m_axes: String = ('e'..).take(trivial_m_axes).collect(); let n_axes: String = ('h'..).take(supp_n_axes).collect(); + let trivial_n_axes: String = ('o'..).take(trivial_n_axes).collect(); let iter_axes: String = ('w'..).take(iter_axes).collect(); let a_axes: Vec = - (m_axes.clone() + "a" + &trivial_m_axes + &iter_axes + "k").chars().collect(); + (m_axes.clone() + "m" + &trivial_m_axes + &iter_axes + "k").chars().collect(); let b_axes: Vec = - (n_axes.clone() + "g" + &trivial_n_axes + &iter_axes + "k").chars().collect(); + (n_axes.clone() + "n" + &trivial_n_axes + &iter_axes + "k").chars().collect(); let c_axes: Vec = - (m_axes + &n_axes + "ag" + &trivial_m_axes + &trivial_n_axes + &iter_axes) + (m_axes + &n_axes + "mn" + &trivial_m_axes + &trivial_n_axes + &iter_axes) .chars() .collect(); (Just(a_axes), Just(b_axes), Just(c_axes)) @@ -197,16 +196,16 @@ impl BinEinsumProblem { model.add_source("b", TypedFact::shape_and_dt_of(&self.b))? 
}; - let output = model.wire_node( + let mut output = model.wire_node( "einsum", EinSum { axes: self.expr.clone(), operating_dt: f32::datum_type(), q_params: None }, &[a, b], )?; - //if let Some(c) = &self.unicast_add_constant { - // let c = model.add_const("c", c.clone())?; - // output = model.wire_node("add", tract_core::ops::math::add(), &[output[0], c])?; - //} + if let Some(c) = &self.unicast_add_constant { + let c = model.add_const("c", c.clone())?; + output = model.wire_node("add", tract_core::ops::math::add(), &[output[0], c])?; + } model.set_output_outlets(&output)?; @@ -285,7 +284,12 @@ impl BinEinsumProblem { } sum }); - output + if let Some(unicast_const) = self.unicast_add_constant.clone() { + output + unicast_const.into_array::().unwrap() + } + else { + output + } } }