diff --git a/egglog_src/churchroad.egg b/egglog_src/churchroad.egg index 86e0ad3..e25e6e9 100644 --- a/egglog_src/churchroad.egg +++ b/egglog_src/churchroad.egg @@ -12,9 +12,12 @@ (datatype Op (And) (Add) + (Mul) (Or) (Xor) + ;;; TODO(@gussmith23): logical or arithmetic? (Shr) + (Shl) ; Returns a bitvector of width 1. (Eq) ; Bitwise not. @@ -141,6 +144,7 @@ (AllBitwidthsMatch (Or)) (AllBitwidthsMatch (Xor)) (AllBitwidthsMatch (Shr)) +(AllBitwidthsMatch (Mul)) ; Have to write this one as a rule, unfortunately. (ruleset core) (rule ((Reg n)) ((AllBitwidthsMatch (Reg n))) :ruleset core) @@ -230,7 +234,7 @@ (rule ((Op1 (Extract high low) expr) (HasType expr (Bitvector n)) - (>= 0 low) + (>= low 0) (< high n)) ((HasType (Op1 (Extract high low) expr) (Bitvector (+ 1 (- high low))))) :ruleset typing) diff --git a/src/lib.rs b/src/lib.rs index bfd0d4c..fca9c5c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1539,4 +1539,146 @@ endmodule", to_verilog_egraph_serialize(&serialized, &out, "") ); } + + #[test] + fn smoketest_2024_05_06_wide_fma() { + let churchroad_program = r#" + ; wire declarations + ; $mul$/Users/gus/churchroad/yosys-plugin/tests/wide-fma.sv:10$1_Y + (let v0 (Wire "v0" 32)) + ; a + (let v1 (Wire "v1" 32)) + ; b + (let v2 (Wire "v2" 32)) + ; c + (let v3 (Wire "v3" 32)) + ; out + (let v4 (Wire "v4" 32)) + + ; cells + (union v4 (Op2 (Add) v0 v3)) + (union v0 (Op2 (Mul) v1 v2)) + + ; inputs + (let a (Var "a" 32)) + (IsPort "" "a" (Input) a) + (union v1 a) + (let b (Var "b" 32)) + (IsPort "" "b" (Input) b) + (union v2 b) + (let c (Var "c" 32)) + (IsPort "" "c" (Input) c) + (union v3 c) + + ; outputs + (let out v4) + (IsPort "" "out" (Output) out) + + ; delete wire expressions + (delete (Wire "v0" 32)) + (delete (Wire "v1" 32)) + (delete (Wire "v2" 32)) + (delete (Wire "v3" 32)) + (delete (Wire "v4" 32)) + "#; + + let mut egraph = EGraph::default(); + import_churchroad(&mut egraph); + egraph.parse_and_run_program(&churchroad_program).unwrap(); + + egraph + .parse_and_run_program( + r#" + (run-schedule (saturate typing)) + (check (HasType out (Bitvector 32))) + "#, + ) + .unwrap(); + + // Verifying the following rewrite with Rosette (note that we really + // should do this with higher bitwidth -- I don't consider this verified + // at the moment) + // > (require rosette) + // > (define-symbolic a b (bitvector 2)) + // > (define ahi (bit 1 a)) + // > (define alo (bit 0 a)) + // > (define bhi (bit 1 b)) + // > (define blo (bit 0 b)) + // > (define part0 (zero-extend (bvmul alo blo) (bitvector 2))) + // > (define part1 (bvshl (zero-extend (bvmul alo bhi) (bitvector 2)) (bv 1 2))) + // > (define part2 (bvshl (zero-extend (bvmul ahi blo) (bitvector 2)) (bv 1 2))) + // > (verify (assert (bvmul a b) (bvadd part0 part1 part2))) + // (unsat) + // + // Ok, here's for 8 bits: + // > (require rosette) + // > (define-symbolic a b (bitvector 8)) + // > (define ahi (extract 7 5 a)) + // > (define ahi (extract 7 4 a)) + // > (define alo (extract 3 0 a)) + // > (define bhi (extract 7 4 b)) + // > (define blo (extract 3 0 b)) + // > (define part0 (zero-extend (bvmul alo blo) (bitvector 8))) + // > (define part1 (bvshl (zero-extend (bvmul alo bhi) (bitvector 8)) (bv 4 8))) + // > (define part2 (bvshl (zero-extend (bvmul ahi blo) (bitvector 8)) (bv 4 8))) + // > (verify (assert (bvmul a b) (bvadd part0 part1 part2))) + // (unsat) + egraph + .parse_and_run_program( + r#" + ; Rewrite larger FMA into smaller FMA that could fit on one DSP. + (ruleset mul) + (rule + ( + (= bigmul (Op2 (Mul) a b)) + (HasType a (Bitvector n)) + (HasType b (Bitvector n)) + (= (% n 2) 0) + ) + ( + (let mid (/ n 2)) + (let ahi (Op1 (Extract (- n 1) mid) a)) + (let alo (Op1 (Extract (- mid 1) 0) a)) + (let bhi (Op1 (Extract (- n 1) mid) b)) + (let blo (Op1 (Extract (- mid 1) 0) b)) + (let mul0 (Op1 (ZeroExtend n) (Op2 (Mul) alo blo))) + (let mul1 (Op1 (Shl) (Op1 (ZeroExtend n) (Op2 (Mul) ahi blo)))) + (let mul2 (Op1 (Shl) (Op1 (ZeroExtend n) (Op2 (Mul) alo bhi)))) + (let smallmuls (Op2 (Add) (Op2 (Add) mul0 mul1) mul2)) + (union bigmul smallmuls) + ) + :ruleset mul + ) + "#, + ) + .unwrap(); + + // 2: 10 * 1: 01 + // = + // 10 * 1 + (10 * 0) * 10 + // 1*1*10 + 0*1 + (1*0*10 + 0*0)*10 + + // 0010 (2) * 0011 (3) + + // a1a0 * b1b0 + // = a1*b1*4 + a1*b0*2 + a0*b1*2 + a0*b0 + // + // + + dbg!(egraph + .parse_and_run_program( + r#" + (run-schedule (repeat 5 (seq (saturate typing) (repeat 5 mul)))) + (run-schedule (saturate typing)) + (query-extract (Op2 (Mul) ?a ?b)) + (query-extract (Bitvector n)) + + (run typing 5) + (check (HasType (Var "a" 32) (Bitvector 32))) + (query-extract (HasType (Op1 (Extract ?v0 ?v1) ?expr) (Bitvector 16))) + (check (HasType (Op1 (Extract 31 16) (Var "a" 32)) (Bitvector 16))) + "#, + ) + .unwrap()); + } } diff --git a/yosys-plugin/tests/wide-fma.sv b/yosys-plugin/tests/wide-fma.sv new file mode 100644 index 0000000..aeb677a --- /dev/null +++ b/yosys-plugin/tests/wide-fma.sv @@ -0,0 +1,35 @@ +// RUN: $YOSYS -q -m $CHURCHROAD_DIR/yosys-plugin/churchroad.so \ +// RUN: -p 'read_verilog -sv %s; prep -top top; write_lakeroad' \ +// RUN: | cat + +module top + ( + input [31:0] a, b, c, + output [31:0] out + ); + assign out = (a*b)+c; +endmodule + +// CHECK: (let v0 (Wire "v0" 32)) +// CHECK: (let v1 (Wire "v1" 32)) +// CHECK: (let v2 (Wire "v2" 32)) +// CHECK: (let v3 (Wire "v3" 32)) +// CHECK: (let v4 (Wire "v4" 32)) +// CHECK: (union v4 (Op2 (Add) v0 v3)) +// CHECK: (union v0 (Op2 (Mul) v1 v2)) +// CHECK: (let a (Var "a" 32)) +// CHECK: (IsPort "" "a" (Input) a) +// CHECK: (union v1 a) +// CHECK: (let b (Var "b" 32)) +// CHECK: (IsPort "" "b" (Input) b) +// CHECK: (union v2 b) +// CHECK: (let c (Var "c" 32)) +// CHECK: (IsPort "" "c" (Input) c) +// CHECK: (union v3 c) +// CHECK: (let out v4) +// CHECK: (IsPort "" "out" (Output) out) +// CHECK: (delete (Wire "v0" 32)) +// CHECK: (delete (Wire "v1" 32)) +// CHECK: (delete (Wire "v2" 32)) +// CHECK: (delete (Wire "v3" 32)) +// CHECK: (delete (Wire "v4" 32))