Skip to content

Commit

Permalink
mostly fixed symbolic
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Aug 4, 2024
1 parent ecf6d65 commit caa7e55
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 79 deletions.
16 changes: 0 additions & 16 deletions src/hl_ops/other.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,6 @@ impl GraphTensor {
}
}

impl From<f32> for ConstantValue {
fn from(value: f32) -> Self {
ConstantValue::Float(value)
}
}
impl From<f64> for ConstantValue {
fn from(value: f64) -> Self {
ConstantValue::Float(value as f32)
}
}
impl<T: Into<Expression>> From<T> for ConstantValue {
fn from(value: T) -> Self {
ConstantValue::Expression(value.into())
}
}

impl Graph {
/// A scalar constant
pub fn constant(&mut self, i: impl Into<ConstantValue>) -> GraphTensor {
Expand Down
20 changes: 18 additions & 2 deletions src/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,22 @@ pub enum ConstantValue {
Float(f32),
}

impl From<f32> for ConstantValue {
fn from(value: f32) -> Self {
ConstantValue::Float(value)
}
}
impl From<f64> for ConstantValue {
fn from(value: f64) -> Self {
ConstantValue::Float(value as f32)
}
}
impl<T: Into<Expression>> From<T> for ConstantValue {
fn from(value: T) -> Self {
ConstantValue::Expression(value.into())
}
}

/// Produces a single number constant from an expression or a float
#[derive(Clone, PartialEq)]
pub struct Constant(pub ConstantValue, pub *const FxHashMap<char, usize>);
Expand Down Expand Up @@ -280,7 +296,6 @@ pub struct Mul;
impl Operator for Mul {
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
let (lhs, rhs) = (get_vec(&inp[0].0), get_vec(&inp[1].0));
println!("EXPR: {:?}", inp[0].1.dims());
let mut out_data = vec![0.; inp[0].1.n_elements().to_usize().unwrap()];
let lexpr = (inp[0].1.index_expression(), inp[0].1.valid_expression());
let rexpr = (inp[1].1.index_expression(), inp[1].1.valid_expression());
Expand Down Expand Up @@ -389,7 +404,8 @@ fn get_index(
index: usize,
) -> f32 {
if val.exec_single_var_stack(index, stack) != 0 {
data[ind.exec_single_var_stack(index, stack)]
let i = ind.exec_single_var_stack(index, stack);
data[i]
} else {
0.0
}
Expand Down
170 changes: 109 additions & 61 deletions src/shape/symbolic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use std::{
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, Div, DivAssign, Mul, MulAssign,
Rem, RemAssign, Sub, SubAssign,
},
time::Duration,
};
use symbolic_expressions::Sexp;

Expand Down Expand Up @@ -476,6 +475,13 @@ impl From<&bool> for Expression {
}
}

impl Add<Expression> for usize {
type Output = Expression;
fn add(self, rhs: Expression) -> Self::Output {
rhs + self
}
}

impl Sub<Expression> for usize {
type Output = Expression;
fn sub(self, rhs: Expression) -> Self::Output {
Expand All @@ -497,6 +503,76 @@ impl Div<Expression> for usize {
}
}

impl Rem<Expression> for usize {
type Output = Expression;
fn rem(self, rhs: Expression) -> Self::Output {
Expression::from(self) % rhs
}
}

impl BitAnd<Expression> for usize {
type Output = Expression;
fn bitand(self, rhs: Expression) -> Self::Output {
rhs & self
}
}

impl BitOr<Expression> for usize {
type Output = Expression;
fn bitor(self, rhs: Expression) -> Self::Output {
rhs | self
}
}

impl Add<Expression> for i32 {
type Output = Expression;
fn add(self, rhs: Expression) -> Self::Output {
rhs + self
}
}

impl Sub<Expression> for i32 {
type Output = Expression;
fn sub(self, rhs: Expression) -> Self::Output {
Expression::from(self) - rhs
}
}

impl Mul<Expression> for i32 {
type Output = Expression;
fn mul(self, rhs: Expression) -> Self::Output {
rhs * self
}
}

impl Div<Expression> for i32 {
type Output = Expression;
fn div(self, rhs: Expression) -> Self::Output {
Expression::from(self) / rhs
}
}

impl Rem<Expression> for i32 {
type Output = Expression;
fn rem(self, rhs: Expression) -> Self::Output {
Expression::from(self) % rhs
}
}

impl BitAnd<Expression> for i32 {
type Output = Expression;
fn bitand(self, rhs: Expression) -> Self::Output {
rhs & self
}
}

impl BitOr<Expression> for i32 {
type Output = Expression;
fn bitor(self, rhs: Expression) -> Self::Output {
rhs | self
}
}

impl<E: Into<Expression>> Add<E> for Expression {
type Output = Self;
fn add(self, rhs: E) -> Self::Output {
Expand Down Expand Up @@ -883,37 +959,13 @@ impl Analysis<Math> for ConstantFold {
a.checked_div(b)?
}
}
Math::Mod([a, b]) if x(b) != Some(0) => x(a)?.checked_rem(x(b)?)?,
Math::Min([a, b]) if x(b) != Some(0) => x(a)?.min(x(b)?),
Math::Max([a, b]) if x(b) != Some(0) => x(a)?.max(x(b)?),
Math::And([a, b]) if x(b) != Some(0) => {
if x(a)? != 0 && x(b)? != 0 {
1
} else {
0
}
}
Math::Or([a, b]) if x(b) != Some(0) => {
if x(a)? != 0 || x(b)? != 0 {
1
} else {
0
}
}
Math::LessThan([a, b]) if x(b) != Some(0) => {
if x(a)? < x(b)? {
1
} else {
0
}
}
Math::GreaterThanEqual([a, b]) if x(b) != Some(0) => {
if x(a)? >= x(b)? {
1
} else {
0
}
}
Math::Mod([a, b]) => x(a)?.checked_rem(x(b)?)?,
Math::Min([a, b]) => x(a)?.min(x(b)?),
Math::Max([a, b]) => x(a)?.max(x(b)?),
Math::And([a, b]) => (x(a)? != 0 && x(b)? != 0) as i32,
Math::Or([a, b]) => (x(a)? != 0 || x(b)? != 0) as i32,
Math::LessThan([a, b]) => (x(a)? < x(b)?) as i32,
Math::GreaterThanEqual([a, b]) => (x(a)? >= x(b)?) as i32,
_ => return None,
})
}
Expand Down Expand Up @@ -965,23 +1017,23 @@ fn make_rules() -> Vec<Rewrite> {
rewrite!("assoc-mul"; "(* ?a (* ?b ?c))" => "(* (* ?a ?b) ?c)"),
rewrite!("assoc-div"; "(/ (/ ?a ?b) ?c)" => "(/ ?a (* ?b ?c))"),
rewrite!("mul-div-associative"; "(/ (* ?a ?b) ?c)" => "(* ?a (/ ?b ?c))"),
// rewrite!("mul-div-associative-rev"; "(* ?a (/ ?b ?c))" => "(/ (* ?a ?b) ?c)"),
// rewrite!("sub-canon"; "(- ?a ?b)" => "(+ ?a (* -1 ?b))"),
// rewrite!("mul-div-associative-rev"; "(* ?a (/ ?b ?c))" => "(/ (* ?a ?b) ?c)"), // BAD? Makes test_pool_1d fail
rewrite!("sub-canon"; "(- ?a ?b)" => "(+ ?a (* -1 ?b))"),
// Distributive
rewrite!("distribute-mul"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"),
// rewrite!("distribute-div"; "(/ (+ ?a ?b) ?c)" => "(+ (/ ?a ?c) (/ ?b ?c))"),
rewrite!("distribute-div"; "(/ (+ ?a ?b) ?c)" => "(+ (/ ?a ?c) (/ ?b ?c))"),
rewrite!("distribute-max"; "(* ?a (max ?b ?c))" => "(max (* ?a ?b) (* ?a ?c))" if is_const_positive(&["?a"])),
// rewrite!("distribute-min"; "(* ?a (min ?b ?c))" => "(min (* ?a ?b) (* ?a ?c))"),
rewrite!("distribute-min"; "(* ?a (min ?b ?c))" => "(min (* ?a ?b) (* ?a ?c))"),
// rewrite!("distribute-mod"; "(* (% ?b ?c) ?a)" => "(% (* ?b ?a) (* ?c ?a))"),
// Factoring
rewrite!("factor-mul" ; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"),
rewrite!("factor-div" ; "(+ (/ ?a ?b) (/ ?a ?c))" => "(/ ?a (+ ?b ?c))"),
rewrite!("group-terms"; "(+ ?a ?a)" => "(* 2 ?a)" if is_const_positive(&["?a"])),
// rewrite!("factor-div" ; "(+ (/ ?a ?b) (/ ?a ?c))" => "(/ ?a (+ ?b ?c))"),
rewrite!("group-terms"; "(+ ?a ?a)" => "(* 2 ?a)"),
// Other
// rewrite!("explicit-truncate"; "(* (/ ?a ?b) ?b)" => "(- ?a (% ?a ?b))"),
// rewrite!("mul-mod"; "(% (* ?a ?b) ?b)" => "0"),
rewrite!("div-move-inside"; "(+ (/ ?a ?b) ?c)" => "(/ (+ ?a (* ?c ?b)) ?b)"),
// rewrite!("mul-distribute"; "(* ?a (% (/ ?b ?c) ?d))" => "(% (/ ?b (* ?c ?a)) (* ?d ?a))"),
// rewrite!("mul-distribute"; "(* ?a (% (/ ?b ?c) ?d))" => "(% (/ ?b (* ?c ?a)) (* ?d ?a))"), // BAD
// rewrite!("div-mod-mul"; "(% (/ ?a ?b) ?c)" => "(% ?a (* ?b ?c))"),
// Simple binary reductions
rewrite!("add-0"; "(+ ?a 0)" => "?a"),
Expand All @@ -1008,11 +1060,12 @@ fn egg_simplify(e: Expression) -> Expression {
let expr = luminal_to_egg(&e);
// Simplify
let runner = Runner::default()
.with_iter_limit(1_000)
.with_time_limit(Duration::from_secs(30))
.with_node_limit(100_000)
// .with_iter_limit(1_000)
// .with_time_limit(std::time::Duration::from_secs(30))
// .with_node_limit(100_000_000)
.with_expr(&expr)
.run(&make_rules());
// runner.print_report();
let extractor = Extractor::new(&runner.egraph, AstSize);
let (_, best) = extractor.find_best(runner.roots[0]);
// Convert back to luminal expression
Expand Down Expand Up @@ -1048,15 +1101,15 @@ mod tests {
let main = Expression::from('x') - 255;
let sub = Expression::from('x') / 2;
let new = main.substitute('x', sub).simplify();
assert_eq!(new, (Expression::from('x') / 2) + -255);
assert_eq!(new.len(), 5);
expression_cleanup();
}

#[test]
fn test_group_terms() {
let s = Expression::from('s');
let expr = (s * ((s - 4) + 1)) + (((s + 1) * ((s - 4) + 1)) - (s * ((s - 4) + 1)));
assert_eq!(expr.simplify().terms.read().len(), 7);
assert_eq!(expr.simplify().len(), 7);
expression_cleanup();
}

Expand All @@ -1073,27 +1126,22 @@ mod tests {
let z = Expression::from('z');
let w = Expression::from('w');
let h = Expression::from('h');
let x = ((z
/ ((Expression::from(-5)
+ (((((Expression::from(-5) + ((((((w + 153) / 2) / 2) / 2) / 2) / 2)) * 4)
+ 9)
/ 2)
/ 2))
* (Expression::from(-5)
+ (((Expression::from(9)
+ (4 * (Expression::from(-5)
+ ((((((Expression::from(153) + h) / 2) / 2) / 2) / 2) / 2))))
/ 2)
/ 2))))
% 64)
.simplify();
panic!("{x}")
let o = (z
/ ((-5 + (((((-5 + ((((((w + 153) / 2) / 2) / 2) / 2) / 2)) * 4) + 9) / 2) / 2))
* (-5 + (((9 + (4 * (-5 + ((((((153 + h) / 2) / 2) / 2) / 2) / 2)))) / 2) / 2))))
% 64;
let x = o.simplify();
assert_eq!(x.len(), 23); // Should be 21 if we can re-enable mul-div-associative-rev
expression_cleanup();
}

#[test]
fn test_final() {
let z = Expression::from('z');
let w = Expression::from('w');
let x = (((w + -7) / 32) + (Expression::from(-11) / 4)).simplify();
assert_eq!(x.len(), 5);
let h = Expression::from('h');
let x = (z % (((((153 + h) / 8) + -31) * ((((w + 153) / 8) + -31) / 16)) * 64)).simplify();
assert_eq!(x.len(), 15); // Should be 11 if we can re-enable mul-div-associative-rev
expression_cleanup();
}
}

0 comments on commit caa7e55

Please sign in to comment.