diff --git a/src/lib.rs b/src/lib.rs index a2d3157..9ac8040 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -862,6 +862,291 @@ endmodule", registers = registers, ) } + +pub fn to_rewrite_rule_egraph_serialize( + egraph: &egraph_serialize::EGraph, + choices: &IndexMap, + name: &str, +) -> String { + let (inputs, outputs) = get_inputs_and_outputs_serialized(egraph); + + // Add all of the outputs to the queue + let mut queue: Vec = outputs.into_iter().map(|c| c.1).collect(); + let mut done = HashSet::new(); + + fn maybe_push_expr_on_queue( + queue: &mut Vec, + done: &HashSet, + class_id: &ClassId, + ) { + if !queue.contains(class_id) && !done.contains(class_id) { + queue.push(class_id.clone()); + } + } + + let id_to_wire_name = |id: &ClassId| -> String { + let v = inputs.iter().find(|(_k, v)| v == id); + match v { + Some((var, _id)) => var.to_string(), + None => format!("wire_{}", id), + } + }; + + let mut rules = String::new(); + + // This is to deal with the generation of (Op2 (Extract b b) (Var x)) -> "x_b" + // Keys are the variable names and values are list of which bits are extracted + // This is used at the end to Concat all of the variables together. + let mut input_extract_map: HashMap> = HashMap::new(); + + while let Some(id) = queue.pop() { + done.insert(id.clone()); + let term = &egraph[&choices[&id]]; + + let op = &term.op; + + match op.as_str() { + // Things to ignore. + // + // Ignore the Unit. + "()" | + // Ignore various relations/facts. + "IsPort" | + "Input" | + "Output" | + // Ignore the nodes for the ops themselves. + "ZeroExtend" | + "Concat" | + "Extract" | + "Or" | + "And" | + "Add" | + "Shr" | + "Eq" | + "Xor" | + "Reg" => (), + // Ignore integer literals. + v if v.parse::().is_ok() => (), + + "Op0" | "Op1" | "Op2" => { + let op_node = &egraph[&term.children[0]]; + // BV, ZeroExtend, Add, Reg, working + match op_node.op.as_str() { + "ZeroExtend" => { + assert_eq!(op_node.children.len(), 1); + assert_eq!(term.children.len(), 2); + let bw = egraph[&op_node.children[0]].op.parse::().unwrap(); + let val_id = &egraph[&term.children[1]].eclass; + let o = format!( + "(= {this_wire} (Op1 (ZeroExtend {bw}) {value}))\n", + this_wire = id_to_wire_name(&id), + value = id_to_wire_name(val_id) + ); + rules.push_str(o.as_str()); + + maybe_push_expr_on_queue(&mut queue, &done, val_id); + }, + + "Not" => { + assert_eq!(term.children.len(), 2); + let val_id = &egraph[&term.children[1]].eclass; + let o = format!( + "(= {this_wire} (Op1 (Not) {value}))\n", + this_wire = id_to_wire_name(&id), + value = id_to_wire_name(val_id) + ); + rules.push_str(o.as_str()); + maybe_push_expr_on_queue(&mut queue, &done, val_id); + + }, + "BV" => { + assert_eq!(op_node.children.len(), 2); + let value = egraph[&op_node.children[0]].op.parse::().unwrap(); + let bw = egraph[&op_node.children[1]].op.parse::().unwrap(); + let o = format!( + "(= {this_wire} (Op0 (BV {value} {bw})))\n", + this_wire = id_to_wire_name(&id) + ); + rules.push_str(o.as_str()); + }, + "Reg" => { + let default_val = egraph[&op_node.children[0]].op.parse::().unwrap(); + let d_id = &egraph[&term.children[1]].eclass; + let val_id = &egraph[&term.children[2]].eclass; + rules.push_str( + format!( + "(= {this_wire} (Op2 (Reg {default_val}) {val} {val1}))\n", + this_wire = id_to_wire_name(&id), + val = id_to_wire_name(d_id), + val1 = id_to_wire_name(val_id) + ).as_str() + ); + if !done.contains(d_id) { + queue.push(d_id.clone()); + } + + maybe_push_expr_on_queue(&mut queue, &done, val_id); + }, + "Concat" | "Xor" | "And" | "Or" | "Add" => { + let expr0_id = &egraph[&term.children[1]].eclass; + let expr1_id = &egraph[&term.children[2]].eclass; + + rules.push_str( + format!( + "(= {this_wire} (Op2 ({op}) {expr0} {expr1}))\n", + this_wire = id_to_wire_name(&id), + op = op_node.op.as_str(), + expr0 = id_to_wire_name(expr0_id), + expr1 = id_to_wire_name(expr1_id) + ).as_str() + ); + maybe_push_expr_on_queue(&mut queue, &done, expr0_id); + maybe_push_expr_on_queue(&mut queue, &done, expr1_id); + }, + "Extract" => { + assert_eq!(term.children.len(), 2); + assert_eq!(op_node.children.len(), 2); + // TODO: need to think out the semantics of when Extract + // hi != lo + // i.e. how to construct the module + let hi:i64 = egraph[&op_node.children[0]].op.parse().unwrap(); + let lo:i64 = egraph[&op_node.children[1]].op.parse().unwrap(); + assert_eq!(hi, lo); + let id = &term.eclass; + // TODO: I need to check if an expression is an input here + let expr_id = &egraph[&term.children[1]].eclass; + + let v = inputs.iter().find(|(_k, v)| v == expr_id); + let expr = match v { + Some((var, _id)) => { + let v = input_extract_map.entry(var.clone()).or_default(); + v.push(hi); + + format!("{}_{hi}", var) + }, + None => format!("(Op1 (Extract {hi} {lo}) wire_{})", id), + }; + rules.push_str( + format!( + "(= {this_wire} {expr})\n", + this_wire = id_to_wire_name(id), + ).as_str() + ); + + maybe_push_expr_on_queue(&mut queue, &done, expr_id); + + }, + v => todo!("{:?}", v) + + } + }, + "Var" => { }, + + _ => todo!("{:?}", &term), + } + } + + // TODO: need to figure out how to do definitions - what does this look like for register.? + fn vec_list_to_str_cons(v: &Vec) -> String { + let mut str: String = String::new(); + if v.is_empty() { + return str; + } + + for i in v { + let s = format!("(StringCons \"{i}\" "); + str.push_str(s.as_str()); + } + str.push_str("(StringNil)"); + + for _i in v { + str.push(')'); + } + + str + } + fn vec_list_to_expr_cons(v: &Vec) -> String { + let mut str: String = String::new(); + if v.is_empty() { + return str; + } + + for i in v { + let s = format!("(ExprCons {i} "); + str.push_str(s.as_str()); + } + str.push_str("(ExprNil)"); + + for _i in v { + str.push(')'); + } + + str + } + + fn vec_list_to_concat(v: &mut [String]) -> String { + assert!(v.len() > 0); + if v.len() == 1 { + return v[0].clone(); + } + let mut str: String = String::new(); + // assuming it's [v0, v1, v2...] + // want (Concat v0 (Concat v1 v0)) + if v.len() == 2 { + let s = format!("(Op2 (Concat) {} {})", v[0], v[1]); + str.push_str(s.as_str()); + return s; + } + let sz: usize = v.len() - 2; + for i in &mut v[0..sz] { + let s = format!("(Op2 (Concat) {i} "); + str.push_str(s.as_str()); + } + let s = format!("(Op2 (Concat) {} {})", v[v.len() - 2], v[v.len() - 1]); + str.push_str(s.as_str()); + + for _i in &mut v[0..sz] { + str.push(')'); + } + + str + } + + let input_names = inputs.iter().map(|a| a.0.clone()).collect(); + let inputs_str = vec_list_to_str_cons(&input_names); + let expr_cons = vec_list_to_expr_cons(&input_names); + + let mut maybe_let = String::new(); + let mut vec = input_extract_map.drain().collect::>(); + vec.sort(); + + for (k, v) in &mut vec { + // let s = format!("(let {} (Wire \"{}\"))", &k, &k); + // sort the vector + v.sort(); + let mut v1: Vec<_> = v.iter_mut().map(|bw| format!("{k}_{bw}")).collect(); + let s1 = vec_list_to_concat(&mut v1); + let s = format!("(let {k} {s1})\n"); + maybe_let.push_str(s.as_str()); + } + + let rule = format!( + r#"(rule + ;; set of definitions + ({rules}) + ;; set of declarations + ( +{maybe_let} +(let instance (ModuleInstance "{name}" (StringNil) (ExprNil) + {inputs_str} + {expr_cons} + ) + )) :ruleset module_rewrites)"# + ); + + rule +} + pub fn to_verilog(term_dag: &TermDag, id: usize) -> String { // let mut wires = HashMap::default(); @@ -1936,4 +2221,173 @@ endmodule", get_inputs_and_outputs_serialized(&egraph.serialize(SerializeConfig::default())); } + #[test] + fn compile_rewrite_rule() { + let mut egraph = EGraph::default(); + import_churchroad(&mut egraph); + egraph + .parse_and_run_program( + r#" +(let v0 (Wire "v0" 4)) +; clk +(let v1 (Wire "v1" 1)) +; out +(let v2 (Wire "v2" 4)) + +; cells +; 1'1 +(let v3 (Op0 (BV 1 1))) +; TODO not handling signedness +(let v4 (Op1 (ZeroExtend 4) v3)) +(union v0 (Op2 (Add) v2 v4)) +; TODO: assuming 0 default for Reg +(union v2 (Op2 (Reg 0) v1 v0)) + +; inputs +(let clk (Var "clk" 1)) + + +(IsPort "" "clk" (Input) clk) +(union v1 clk) + +; outputs +(let out v2) +(IsPort "" "out" (Output) out) + +; delete wire expressions +(delete (Wire "v0" 4)) +(delete (Wire "v1" 1)) +(delete (Wire "v2" 4)) + "#, + ) + .unwrap(); + + let serialized = egraph.serialize(SerializeConfig::default()); + let imap = AnythingExtractor.extract(&serialized, &[]); + + let out = to_rewrite_rule_egraph_serialize(&serialized, &imap, "REG"); + println!("{out}"); + + assert_eq!( + r#"(rule + ;; set of definitions + ((= wire_24 (Op2 (Reg 0) clk wire_6)) +(= wire_6 (Op2 (Add) wire_24 wire_15)) +(= wire_15 (Op1 (ZeroExtend 4) wire_12)) +(= wire_12 (Op0 (BV 1 1))) +) + ;; set of declarations + ( + +(let instance (ModuleInstance "REG" (StringNil) (ExprNil) + (StringCons "clk" (StringNil)) + (ExprCons clk (ExprNil)) + ) + )) :ruleset module_rewrites)"#, + out + ); + } + + #[test] + fn compile_rewrite_rule_1() { + let mut egraph = EGraph::default(); + import_churchroad(&mut egraph); + egraph + .parse_and_run_program( + r#" +; wire declarations +; $abc$84$auto$blifparse.cc:396:parse_blif$85.A +(let v0 (Wire "v0" 1)) +; $abc$84$auto$blifparse.cc:396:parse_blif$85.Y +(let v1 (Wire "v1" 1)) +; $abc$84$auto$blifparse.cc:396:parse_blif$86.A +(let v2 (Wire "v2" 1)) +; $abc$84$auto$blifparse.cc:396:parse_blif$86.Y +(let v3 (Wire "v3" 1)) +; $abc$84$auto$blifparse.cc:396:parse_blif$87.B +(let v4 (Wire "v4" 1)) +; $abc$84$auto$blifparse.cc:396:parse_blif$87.Y +(let v5 (Wire "v5" 1)) +; $abc$84$auto$blifparse.cc:396:parse_blif$88.B +(let v6 (Wire "v6" 1)) +; $abc$84$auto$blifparse.cc:396:parse_blif$88.Y +(let v7 (Wire "v7" 1)) +; i_a +(let v8 (Wire "v8" 2)) +; i_b +(let v9 (Wire "v9" 2)) +; o_res +(let v10 (Wire "v10" 2)) + +; cells +(union v10 (Op2 (Concat) v7 v5)) +(union v4 (Op1 (Extract 0 0) v9)) +(union v6 (Op1 (Extract 1 1) v9)) +(union v0 (Op1 (Extract 0 0) v8)) +(union v2 (Op1 (Extract 1 1) v8)) +(union v1 (Op1 (Not) v0)) +(union v3 (Op1 (Not) v2)) +(union v5 (Op2 (Xor) v1 v4)) +(union v7 (Op2 (Xor) v3 v6)) + +; inputs +(let i_a (Var "i_a" 2)) +(IsPort "" "i_a" (Input) i_a) +(union v8 i_a) +(let i_b (Var "i_b" 2)) +(IsPort "" "i_b" (Input) i_b) +(union v9 i_b) + +; outputs +(let o_res v10) +(IsPort "" "o_res" (Output) o_res) + +; delete wire expressions +(delete (Wire "v0" 1)) +(delete (Wire "v1" 1)) +(delete (Wire "v2" 1)) +(delete (Wire "v3" 1)) +(delete (Wire "v4" 1)) +(delete (Wire "v5" 1)) +(delete (Wire "v6" 1)) +(delete (Wire "v7" 1)) +(delete (Wire "v8" 2)) +(delete (Wire "v9" 2)) +(delete (Wire "v10" 2)) + "#, + ) + .unwrap(); + + let serialized = egraph.serialize(SerializeConfig::default()); + let imap = AnythingExtractor.extract(&serialized, &[]); + + let out = to_rewrite_rule_egraph_serialize(&serialized, &imap, "ALU"); + println!("\n{out}\nend"); + assert_eq!( + r#"(rule + ;; set of definitions + ((= wire_46 (Op2 (Concat) wire_20 wire_16)) +(= wire_16 (Op2 (Xor) wire_8 wire_14)) +(= wire_14 i_b_0) +(= wire_8 (Op1 (Not) wire_6)) +(= wire_6 i_a_0) +(= wire_20 (Op2 (Xor) wire_12 wire_18)) +(= wire_18 i_b_1) +(= wire_12 (Op1 (Not) wire_10)) +(= wire_10 i_a_1) +) + ;; set of declarations + ( +(let i_a (Op2 (Concat) i_a_0 i_a_1)) +(let i_b (Op2 (Concat) i_b_0 i_b_1)) + +(let instance (ModuleInstance "ALU" (StringNil) (ExprNil) + (StringCons "i_a" (StringCons "i_b" (StringNil))) + (ExprCons i_a (ExprCons i_b (ExprNil))) + ) + )) :ruleset module_rewrites)"#, + out + ) + // println!("rule:\n {out}"); + } }