From 0cd4098956bcdf73e65ea12ee5be792bc4e69f61 Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Mon, 12 Aug 2024 14:20:26 +0200 Subject: [PATCH] Recycle more xmm registers at the same time --- src/evaluate.rs | 392 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 281 insertions(+), 111 deletions(-) diff --git a/src/evaluate.rs b/src/evaluate.rs index 0f05886..d916deb 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -1268,149 +1268,320 @@ impl ExpressionEvaluator { }; } - let mut in_asm_block = false; - let mut regcount = 0; - - fn reg_unused(reg: usize, instr: &[Instr], out: &[usize]) -> bool { - if out.contains(®) { - return false; - } + let mut reg_last_use = vec![self.instructions.len(); self.instructions.len()]; + let mut stack_to_reg = HashMap::default(); - for ins in instr { - match ins { - Instr::Add(r, a) | Instr::Mul(r, a) => { - if a.iter().any(|x| *x == reg) { - return false; + for (i, ins) in instr.iter().enumerate() { + match ins { + Instr::Add(r, a) | Instr::Mul(r, a) => { + for x in a { + if x >= &self.reserved_indices { + reg_last_use[stack_to_reg[x]] = i; } + } - if r == ® { - return true; - } + stack_to_reg.insert(r, i); + } + Instr::Pow(r, b, _) => { + if b >= &self.reserved_indices { + reg_last_use[stack_to_reg[b]] = i; } - Instr::Pow(r, b, _) => { - if *b == reg { - return false; - } - if r == ® { - return true; - } + stack_to_reg.insert(r, i); + } + Instr::Powf(r, b, e) => { + if b >= &self.reserved_indices { + reg_last_use[stack_to_reg[b]] = i; } - Instr::Powf(r, b, e) => { - if *b == reg || *e == reg { - return false; - } - if r == ® { - return true; - } + if e >= &self.reserved_indices { + reg_last_use[stack_to_reg[e]] = i; } - Instr::BuiltinFun(r, _, b) => { - if *b == reg { - return false; - } - if r == ® { - return true; - } + stack_to_reg.insert(r, i); + } + Instr::BuiltinFun(r, _, b) => { + if b >= &self.reserved_indices { + reg_last_use[stack_to_reg[b]] = i; } + stack_to_reg.insert(r, i); } } + } - true + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + enum MemOrReg { + Mem(usize), + Reg(usize), } - let mut recycle_register: (Option<(usize, u32)>, Option<(usize, u32)>) = (None, None); // old and current register - for (i, ins) in instr.iter().enumerate() { - // keep results in xmm registers if the last use is in the next instruction - if let Some(ii) = instr.get(i + 1) { - match ins { - Instr::Add(r, _) | Instr::Mul(r, _) => match ii { - Instr::Add(j, _) | Instr::Mul(j, _) => { - if r == j || reg_unused(*r, &instr[i + 2..], &self.result_indices) { - if let Some(old) = recycle_register.0 { - recycle_register.1 = Some((*r, old.1)); - } else { - recycle_register.1 = Some((*r, regcount)); + #[derive(Debug, Clone)] + enum RegInstr { + Add(MemOrReg, u16, Vec), + Mul(MemOrReg, u16, Vec), + Pow(usize, usize, i64), + Powf(usize, usize, usize), + BuiltinFun(usize, Symbol, usize), + } + + let mut new_instr: Vec = instr + .iter() + .map(|i| match i { + Instr::Add(r, a) => RegInstr::Add( + MemOrReg::Mem(*r), + u16::MAX, + a.iter().map(|x| MemOrReg::Mem(*x)).collect(), + ), + Instr::Mul(r, a) => RegInstr::Mul( + MemOrReg::Mem(*r), + u16::MAX, + a.iter().map(|x| MemOrReg::Mem(*x)).collect(), + ), + Instr::Pow(r, b, e) => RegInstr::Pow(*r, *b, *e), + Instr::Powf(r, b, e) => RegInstr::Powf(*r, *b, *e), + Instr::BuiltinFun(r, s, a) => RegInstr::BuiltinFun(*r, *s, *a), + }) + .collect(); + + for i in 0..40 { + for (j, last_use) in reg_last_use.iter().enumerate() { + if *last_use - j == i { + if *last_use == self.instructions.len() { + continue; + } + + let old_reg = + if let RegInstr::Add(r, _, _) | RegInstr::Mul(r, _, _) = &new_instr[j] { + if let MemOrReg::Mem(r) = r { + *r + } else { + continue; + } + } else { + continue; + }; + + // find free registers in the range + // start at j+1 as we can recycle registers that are last used in iteration j + let mut free_regs = u16::MAX; + + for k in &new_instr[j + 1..=*last_use] { + match k { + RegInstr::Add(_, f, _) | RegInstr::Mul(_, f, _) => { + free_regs &= f; + } + _ => { + free_regs = 0; // the current instruction is not allowed to be used outside of ASM blocks + } + } + } + + if free_regs == 0 { + continue; + } + + for k in 0..15 { + if free_regs & (1 << k) != 0 { + if let RegInstr::Add(r, _, _) | RegInstr::Mul(r, _, _) = + &mut new_instr[j] + { + *r = MemOrReg::Reg(k); + } + + for l in &mut new_instr[j + 1..=*last_use] { + match l { + RegInstr::Add(_, f, a) | RegInstr::Mul(_, f, a) => { + *f &= !(1 << k); // FIXME: do not set on last use? + for x in a { + if *x == MemOrReg::Mem(old_reg) { + *x = MemOrReg::Reg(k); + } + } + } + RegInstr::Pow(_, a, _) => { + if *a == old_reg { + panic!("use outside of ASM block"); + } + } + RegInstr::Powf(_, a, b) => { + if *a == old_reg { + panic!("use outside of ASM block"); + } + if *b == old_reg { + panic!("use outside of ASM block"); + } + } + RegInstr::BuiltinFun(_, _, a) => { + if *a == old_reg { + panic!("use outside of ASM block"); + } + } } } + + break; } - _ => {} - }, - _ => {} + } } } + } + //for (i, x) in new_instr.iter().enumerate() { + // println!("{} {:?}", i, x); + //} + + let mut in_asm_block = false; + for ins in &new_instr { match ins { - Instr::Add(o, a) | Instr::Mul(o, a) => { + RegInstr::Add(o, free, a) | RegInstr::Mul(o, free, a) => { if !in_asm_block { *out += "\t__asm__(\n"; in_asm_block = true; } - let oper = if matches!(ins, Instr::Add(_, _)) { + let oper = if matches!(ins, RegInstr::Add(_, _, _)) { "add" } else { "mul" }; - if let Some(old) = recycle_register.0 { - assert!(a.iter().any(|rr| *rr == old.0)); // the last value must be used + match o { + MemOrReg::Reg(out_reg) => { + if let Some(j) = a.iter().find(|x| **x == MemOrReg::Reg(*out_reg)) { + // we can recycle the register completely + for i in a { + if i != j { + match i { + MemOrReg::Reg(k) => { + *out += &format!( + "\t\t\"{}sd xmm{}, xmm{}\\n\\t\"\n", + oper, out_reg, k + ); + } + MemOrReg::Mem(k) => { + *out += &format!( + "\t\t\"{}sd xmm{}, QWORD {}\\n\\t\"\n", + oper, + out_reg, + format_addr!(*k) + ); + } + } + } + } + } else if let Some(MemOrReg::Reg(j)) = + a.iter().find(|x| matches!(x, MemOrReg::Reg(_))) + { + *out += &format!("\t\t\"movapd xmm{}, xmm{}\\n\\t\"\n", out_reg, j); + + for i in a { + if *i != MemOrReg::Reg(*j) { + match i { + MemOrReg::Reg(k) => { + *out += &format!( + "\t\t\"{}sd xmm{}, xmm{}\\n\\t\"\n", + oper, out_reg, k + ); + } + MemOrReg::Mem(k) => { + *out += &format!( + "\t\t\"{}sd xmm{}, QWORD {}\\n\\t\"\n", + oper, + out_reg, + format_addr!(*k) + ); + } + } + } + } + } else { + if let MemOrReg::Mem(k) = &a[0] { + *out += &format!( + "\t\t\"movsd xmm{}, QWORD {}\\n\\t\"\n", + out_reg, + format_addr!(*k) + ); + } else { + unreachable!(); + } - for i in a { - if *i != old.0 { - *out += &format!( - "\t\t\"{}sd xmm{}, QWORD {}\\n\\t\"\n", - oper, - old.1, - format_addr!(*i) - ); + for i in &a[1..] { + if let MemOrReg::Mem(k) = i { + *out += &format!( + "\t\t\"{}sd xmm{}, QWORD {}\\n\\t\"\n", + oper, + out_reg, + format_addr!(*k) + ); + } + } } } + MemOrReg::Mem(out_mem) => { + // we need to find a free temporary register + if *free == 0 { + panic!("no free registers"); + // we can move the value of xmm0 into the memory location of the output register + // and then swap later + } - if recycle_register.1.is_none() { - *out += &format!( - "\t\t\"movsd QWORD {}, xmm{}\\n\\t\"\n", - format_addr!(*o), - old.1, - ); - } - } else if let Some(new) = recycle_register.1 { - *out += &format!( - "\t\t\"movsd xmm{}, QWORD {}\\n\\t\"\n", - new.1, - format_addr!(a[0]) - ); - - for i in &a[1..] { - *out += &format!( - "\t\t\"{}sd xmm{}, QWORD {}\\n\\t\"\n", - oper, - new.1, - format_addr!(*i) - ); - } - } else { - *out += &format!( - "\t\t\"movsd xmm{}, QWORD {}\\n\\t\"\n", - regcount, - format_addr!(a[0]) - ); + if let Some(out_reg) = (0..15).position(|k| free & (1 << k) != 0) { + if let Some(MemOrReg::Reg(j)) = + a.iter().find(|x| matches!(x, MemOrReg::Reg(_))) + { + *out += + &format!("\t\t\"movapd xmm{}, xmm{}\\n\\t\"\n", out_reg, j); + + for i in a { + if *i != MemOrReg::Reg(*j) { + match i { + MemOrReg::Reg(k) => { + *out += &format!( + "\t\t\"{}sd xmm{}, xmm{}\\n\\t\"\n", + oper, out_reg, k + ); + } + MemOrReg::Mem(k) => { + *out += &format!( + "\t\t\"{}sd xmm{}, QWORD {}\\n\\t\"\n", + oper, + out_reg, + format_addr!(*k) + ); + } + } + } + } + } else { + if let MemOrReg::Mem(k) = &a[0] { + *out += &format!( + "\t\t\"movsd xmm{}, QWORD {}\\n\\t\"\n", + out_reg, + format_addr!(*k) + ); + } else { + unreachable!(); + } + + for i in &a[1..] { + if let MemOrReg::Mem(k) = i { + *out += &format!( + "\t\t\"{}sd xmm{}, QWORD {}\\n\\t\"\n", + oper, + out_reg, + format_addr!(*k) + ); + } + } + } - for i in &a[1..] { - *out += &format!( - "\t\t\"{}sd xmm{}, QWORD {}\\n\\t\"\n", - oper, - regcount, - format_addr!(*i) - ); + *out += &format!( + "\t\t\"movsd QWORD {}, xmm{}\\n\\t\"\n", + format_addr!(*out_mem), + out_reg + ); + } } - - *out += &format!( - "\t\t\"movsd QWORD {}, xmm{}\\n\\t\"\n", - format_addr!(*o), - regcount, - ); } } - Instr::Pow(o, b, e) => { + RegInstr::Pow(o, b, e) => { if *e == -1 { if !in_asm_block { *out += "\t__asm__(\n"; @@ -1421,7 +1592,7 @@ impl ExpressionEvaluator { "\t\t\"movsd xmm{0}, QWORD PTR [%1+{1}]\\n\\t\" \t\t\"divsd xmm{0}, QWORD {2}\\n\\t\" \t\t\"movsd QWORD {3}, xmm{0}\\n\\t\"\n", - regcount, + 0, (self.reserved_indices - self.param_count) * 8, format_addr!(*b), format_addr!(*o) @@ -1433,14 +1604,14 @@ impl ExpressionEvaluator { *out += format!("\tZ[{}] = pow({}, {});\n", o, base, e).as_str(); } } - Instr::Powf(o, b, e) => { + RegInstr::Powf(o, b, e) => { end_asm_block!(in_asm_block); let base = get_input!(*b); let exp = get_input!(*e); *out += format!("\tZ[{}] = pow({}, {});\n", o, base, exp).as_str(); } - Instr::BuiltinFun(o, s, a) => { + RegInstr::BuiltinFun(o, s, a) => { end_asm_block!(in_asm_block); let arg = get_input!(*a); @@ -1465,12 +1636,11 @@ impl ExpressionEvaluator { } } } - - recycle_register.0 = recycle_register.1.take(); } end_asm_block!(in_asm_block); + let mut regcount = 0; *out += "\t__asm__(\n"; for (i, r) in &mut self.result_indices.iter().enumerate() { if *r < self.param_count {