diff --git a/crates/luminal_symbolic/src/simplify.rs b/crates/luminal_symbolic/src/simplify.rs index 919ea484..df3f426b 100644 --- a/crates/luminal_symbolic/src/simplify.rs +++ b/crates/luminal_symbolic/src/simplify.rs @@ -27,6 +27,8 @@ pub fn reduce_triples( if let (Term::Num(a), Term::Num(b)) = (a_term, b_term) { if let Some(c) = term.as_op().unwrap()(a, b) { stack.push((None, Term::Num(c))); + } else { + break; } } else if let Term::Var(a) = a_term { stack.push((None, Term::Var(a))); diff --git a/src/shape/tracker.rs b/src/shape/tracker.rs index 8a1f9985..0ce12b48 100644 --- a/src/shape/tracker.rs +++ b/src/shape/tracker.rs @@ -100,9 +100,10 @@ impl ShapeTracker { /// Create an expression to translate logical indexes into physical indexes pub fn index_expression(&self) -> BigExpression { - println!("ORIG: {:?}", self); + if !self.is_reshaped() { + return 'z'.into(); + } let shape = combine_dims(*self); - println!("Combined: {:?}", shape); let strides = shape.unordered_strides(); // Dimension strides in original order let mut ind_expr = BigExpression::from(0); // The final index expression let mut current_elem_size = BigExpression::from(1); // Keep track of the size of each element of the current dim (last dim elem size: 1) @@ -138,6 +139,9 @@ impl ShapeTracker { /// If this expression evaluates to 0, the logical index is invalid. Otherwise it is valid pub fn valid_expression(&self) -> BigExpression { + if !self.is_reshaped() { + return true.into(); + } let shape = combine_dims(*self); let mut ret = BigExpression::from(1); let mut acc = BigExpression::from(1);