Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wide FMA smoketest #68

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion egglog_src/churchroad.egg
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
(datatype Op
(And)
(Add)
(Mul)
(Or)
(Xor)
;;; TODO(@gussmith23): logical or arithmetic?
(Shr)
(Shl)
; Returns a bitvector of width 1.
(Eq)
; Bitwise not.
Expand Down Expand Up @@ -141,6 +144,7 @@
(AllBitwidthsMatch (Or))
(AllBitwidthsMatch (Xor))
(AllBitwidthsMatch (Shr))
(AllBitwidthsMatch (Mul))
; Have to write this one as a rule, unfortunately.
(ruleset core)
(rule ((Reg n)) ((AllBitwidthsMatch (Reg n))) :ruleset core)
Expand Down Expand Up @@ -230,7 +234,7 @@
(rule
((Op1 (Extract high low) expr)
(HasType expr (Bitvector n))
(>= 0 low)
(>= low 0)
(< high n))
((HasType (Op1 (Extract high low) expr) (Bitvector (+ 1 (- high low)))))
:ruleset typing)
Expand Down
142 changes: 142 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1539,4 +1539,146 @@ endmodule",
to_verilog_egraph_serialize(&serialized, &out, "")
);
}

#[test]
fn smoketest_2024_05_06_wide_fma() {
let churchroad_program = r#"
; wire declarations
; $mul$/Users/gus/churchroad/yosys-plugin/tests/wide-fma.sv:10$1_Y
(let v0 (Wire "v0" 32))
; a
(let v1 (Wire "v1" 32))
; b
(let v2 (Wire "v2" 32))
; c
(let v3 (Wire "v3" 32))
; out
(let v4 (Wire "v4" 32))

; cells
(union v4 (Op2 (Add) v0 v3))
(union v0 (Op2 (Mul) v1 v2))

; inputs
(let a (Var "a" 32))
(IsPort "" "a" (Input) a)
(union v1 a)
(let b (Var "b" 32))
(IsPort "" "b" (Input) b)
(union v2 b)
(let c (Var "c" 32))
(IsPort "" "c" (Input) c)
(union v3 c)

; outputs
(let out v4)
(IsPort "" "out" (Output) out)

; delete wire expressions
(delete (Wire "v0" 32))
(delete (Wire "v1" 32))
(delete (Wire "v2" 32))
(delete (Wire "v3" 32))
(delete (Wire "v4" 32))
"#;

let mut egraph = EGraph::default();
import_churchroad(&mut egraph);
egraph.parse_and_run_program(&churchroad_program).unwrap();

egraph
.parse_and_run_program(
r#"
(run-schedule (saturate typing))
(check (HasType out (Bitvector 32)))
"#,
)
.unwrap();

// Verifying the following rewrite with Rosette (note that we really
// should do this with higher bitwidth -- I don't consider this verified
// at the moment)
// > (require rosette)
// > (define-symbolic a b (bitvector 2))
// > (define ahi (bit 1 a))
// > (define alo (bit 0 a))
// > (define bhi (bit 1 b))
// > (define blo (bit 0 b))
// > (define part0 (zero-extend (bvmul alo blo) (bitvector 2)))
// > (define part1 (bvshl (zero-extend (bvmul alo bhi) (bitvector 2)) (bv 1 2)))
// > (define part2 (bvshl (zero-extend (bvmul ahi blo) (bitvector 2)) (bv 1 2)))
// > (verify (assert (bvmul a b) (bvadd part0 part1 part2)))
// (unsat)
//
// Ok, here's for 8 bits:
// > (require rosette)
// > (define-symbolic a b (bitvector 8))
// > (define ahi (extract 7 5 a))
// > (define ahi (extract 7 4 a))
// > (define alo (extract 3 0 a))
// > (define bhi (extract 7 4 b))
// > (define blo (extract 3 0 b))
// > (define part0 (zero-extend (bvmul alo blo) (bitvector 8)))
// > (define part1 (bvshl (zero-extend (bvmul alo bhi) (bitvector 8)) (bv 4 8)))
// > (define part2 (bvshl (zero-extend (bvmul ahi blo) (bitvector 8)) (bv 4 8)))
// > (verify (assert (bvmul a b) (bvadd part0 part1 part2)))
// (unsat)
egraph
.parse_and_run_program(
r#"
; Rewrite larger FMA into smaller FMA that could fit on one DSP.
(ruleset mul)
(rule
(
(= bigmul (Op2 (Mul) a b))
(HasType a (Bitvector n))
(HasType b (Bitvector n))
(= (% n 2) 0)
)
(
(let mid (/ n 2))
(let ahi (Op1 (Extract (- n 1) mid) a))
(let alo (Op1 (Extract (- mid 1) 0) a))
(let bhi (Op1 (Extract (- n 1) mid) b))
(let blo (Op1 (Extract (- mid 1) 0) b))
(let mul0 (Op1 (ZeroExtend n) (Op2 (Mul) alo blo)))
(let mul1 (Op1 (Shl) (Op1 (ZeroExtend n) (Op2 (Mul) ahi blo))))
(let mul2 (Op1 (Shl) (Op1 (ZeroExtend n) (Op2 (Mul) alo bhi))))
(let smallmuls (Op2 (Add) (Op2 (Add) mul0 mul1) mul2))
(union bigmul smallmuls)
)
:ruleset mul
)
"#,
)
.unwrap();

// 2: 10 * 1: 01
// =
// 10 * 1 + (10 * 0) * 10
// 1*1*10 + 0*1 + (1*0*10 + 0*0)*10

// 0010 (2) * 0011 (3)

// a1a0 * b1b0
// = a1*b1*4 + a1*b0*2 + a0*b1*2 + a0*b0
//
//

dbg!(egraph
.parse_and_run_program(
r#"
(run-schedule (repeat 5 (seq (saturate typing) (repeat 5 mul))))
(run-schedule (saturate typing))
(query-extract (Op2 (Mul) ?a ?b))
(query-extract (Bitvector n))

(run typing 5)
(check (HasType (Var "a" 32) (Bitvector 32)))
(query-extract (HasType (Op1 (Extract ?v0 ?v1) ?expr) (Bitvector 16)))
(check (HasType (Op1 (Extract 31 16) (Var "a" 32)) (Bitvector 16)))
"#,
)
.unwrap());
}
}
35 changes: 35 additions & 0 deletions yosys-plugin/tests/wide-fma.sv
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: $YOSYS -q -m $CHURCHROAD_DIR/yosys-plugin/churchroad.so \
// RUN: -p 'read_verilog -sv %s; prep -top top; write_lakeroad' \
// RUN: | cat

module top
(
input [31:0] a, b, c,
output [31:0] out
);
assign out = (a*b)+c;
endmodule

// CHECK: (let v0 (Wire "v0" 32))
// CHECK: (let v1 (Wire "v1" 32))
// CHECK: (let v2 (Wire "v2" 32))
// CHECK: (let v3 (Wire "v3" 32))
// CHECK: (let v4 (Wire "v4" 32))
// CHECK: (union v4 (Op2 (Add) v0 v3))
// CHECK: (union v0 (Op2 (Mul) v1 v2))
// CHECK: (let a (Var "a" 32))
// CHECK: (IsPort "" "a" (Input) a)
// CHECK: (union v1 a)
// CHECK: (let b (Var "b" 32))
// CHECK: (IsPort "" "b" (Input) b)
// CHECK: (union v2 b)
// CHECK: (let c (Var "c" 32))
// CHECK: (IsPort "" "c" (Input) c)
// CHECK: (union v3 c)
// CHECK: (let out v4)
// CHECK: (IsPort "" "out" (Output) out)
// CHECK: (delete (Wire "v0" 32))
// CHECK: (delete (Wire "v1" 32))
// CHECK: (delete (Wire "v2" 32))
// CHECK: (delete (Wire "v3" 32))
// CHECK: (delete (Wire "v4" 32))