Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add BitVec.umulOverflow and BitVec.smulOverflow definitions and additional theorems #6949

Draft
wants to merge 19 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/Init/Data/BitVec/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -698,4 +698,16 @@ def reverse : {w : Nat} → BitVec w → BitVec w
| 0, x => x
| w + 1, x => concat (reverse (x.truncate w)) (x.msb)

/-- Overflow predicate for unsigned multiplication modulo 2^w.
SMT-Lib name: `bvumulo`.
-/

def umulOverflow {w : Nat} (x y : BitVec w) : Bool := x.toNat * y.toNat ≥ 2 ^ w

/-- Overflow predicate for signed multiplication on w-bit 2's complement.
SMT-Lib name: `bvsmulo`.
-/

def smulOverflow {w : Nat} (x y : BitVec w) : Bool := (x.toInt * y.toInt ≥ 2 ^ (w - 1)) || (x.toInt * y.toInt < - 2 ^ (w - 1))

end BitVec
31 changes: 31 additions & 0 deletions src/Init/Data/BitVec/Bitblast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1251,6 +1251,37 @@ theorem saddOverflow_eq {w : Nat} (x y : BitVec w) :
simp
omega

theorem umulOverflow_eq {w : Nat} (x y : BitVec w) :
umulOverflow x y =
(0 < w && BitVec.twoPow (w * 2) w ≤ x.zeroExtend (w * 2) * y.zeroExtend (w * 2)) := by
simp only [umulOverflow, ge_iff_le, truncate_eq_setWidth, le_def, toNat_twoPow, toNat_mul,
toNat_setWidth, Nat.mul_mod_mod, Nat.mod_mul_mod]
rcases w with _|w'
· decide +revert
· simp only [show 0 < w' + 1 by omega, decide_true, Bool.true_and, decide_eq_decide]
rw [Nat.mod_eq_of_lt BitVec.toNat_mul_toNat_lt, Nat.mod_eq_of_lt]
have := Nat.pow_lt_pow_of_lt (a := 2) (n := (w' + 1)) (m := (w' + 1) * 2)
omega

theorem smulOverflow_eq {w : Nat} (x y : BitVec w) :
smulOverflow x y
= (w ≠ 0) || (((BitVec.intMin w).signExtend (w * 2)).sle (x.signExtend (w * 2) * y.signExtend (w * 2)) &&
(x.signExtend (w * 2) * y.signExtend (w * 2)).sle ((BitVec.intMax w).signExtend (w * 2))) := by
simp only [smulOverflow]
rcases w with _|w
· decide +revert
· simp only [Bool.false_or, BitVec.intMin, BitVec.intMax, BitVec.sle, BitVec.toInt_mul,
decide_eq_true_eq, BitVec.ofNat_eq_ofNat]
have := Int.pow_lt_pow (a := 2) (b := ((w + 1) * 2 - 2)) (c := ((w + 1) * 2 - 1))
have := @BitVec.le_toInt_mul_toInt (w + 1) x y
have hlb : -((2 ^ ((w + 1) * 2 - 1) : Nat) * 2) ≤ x.toInt * y.toInt * 2 := by push_cast; omega
have := @BitVec.toInt_mul_toInt_lt (w + 1) x y
have hub : x.toInt * y.toInt * 2 < ((2 ^ ((w + 1) * 2 - 1): Nat) * 2) := Int.mul_lt_mul_of_pos_right (by norm_cast at *; omega) (by omega)
rw [BitVec.toInt_signExtend_of_lt (by omega), BitVec.toInt_signExtend_of_lt (by omega),
BitVec.toInt_signExtend_of_lt (by omega), BitVec.toInt_signExtend_of_lt (by omega), BitVec.toInt_twoPow_of_eq (by omega), ←Nat.two_pow_pred_add_two_pow_pred (by omega)]
simp only [← Nat.mul_two, Int.bmod_eq_of_le_of_lt hlb hub, BitVec.toInt_twoPow_sub_one, or_eq_true, decide_eq_true_eq, _root_.eq_iff_iff, and_eq_true]
omega

/- ### umod -/

theorem getElem_umod {n d : BitVec w} (hi : i < w) :
Expand Down
75 changes: 75 additions & 0 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import Init.Data.Nat.Mod
import Init.Data.Nat.Div.Lemmas
import Init.Data.Int.Bitwise.Lemmas
import Init.Data.Int.Pow
import Init.Data.Int.LemmasAux

set_option linter.missingDocs true

Expand Down Expand Up @@ -301,6 +302,7 @@ theorem getLsbD_ofNat (n : Nat) (x : Nat) (i : Nat) :

@[simp] theorem getMsbD_zero : (0#w).getMsbD i = false := by simp [getMsbD]


@[simp] theorem getLsbD_one : (1#w).getLsbD i = (decide (0 < w) && decide (i = 0)) := by
simp only [getLsbD, toNat_ofNat, Nat.testBit_mod_two_pow]
by_cases h : i = 0
Expand Down Expand Up @@ -497,6 +499,14 @@ theorem toInt_eq_toNat_bmod (x : BitVec n) : x.toInt = Int.bmod x.toNat (2^n) :=
rw [Int.bmod_neg] <;> simp only [←Int.ofNat_emod, toNat_mod_cancel]
omega

@[simp] theorem toInt_one_of_lt {w : Nat} (h : 1 < w) : (1#w).toInt = 1 := by
rw [toInt_eq_msb_cond]
simp only [msb_one, show w ≠ 1 by omega, decide_false, Bool.false_eq_true, ↓reduceIte,
toNat_ofNat, Int.ofNat_emod, Int.Nat.cast_ofNat_Int]
norm_cast
apply Nat.mod_eq_of_lt
apply Nat.one_lt_two_pow (by omega)

/-- Prove equality of bitvectors in terms of nat operations. -/
theorem eq_of_toInt_eq {x y : BitVec n} : x.toInt = y.toInt → x = y := by
intro eq
Expand Down Expand Up @@ -3682,6 +3692,67 @@ theorem udiv_twoPow_eq_of_lt {w : Nat} {x : BitVec w} {k : Nat} (hk : k < w) : x
have : 2^k < 2^w := Nat.pow_lt_pow_of_lt (by decide) hk
simp [bv_toNat, Nat.shiftRight_eq_div_pow, Nat.mod_eq_of_lt this]

theorem toInt_twoPow {w i : Nat} :
(BitVec.twoPow w i).toInt = if w ≤ i then (0 : Int) else (if i + 1 = w then -(1 <<< i:Int) else (1 <<< i)) := by
simp only [BitVec.twoPow, BitVec.toInt]
rcases w with _|w'
· simp
· by_cases h : w' + 1 ≤ i
· simp [h]; norm_cast; omega
· simp only [toNat_shiftLeft, toNat_ofNat, Nat.zero_lt_succ, Nat.one_mod_two_pow, Int.ofNat_emod, h,
↓reduceIte, Nat.add_right_cancel_iff]
have hy : (2 ^ i % 2 ^ (w' + 1)) = 2 ^ i := by rw [Nat.mod_eq_of_lt (by rw [Nat.pow_lt_pow_iff_right (by omega)]; omega)]
have hj : 2 * 2 ^ i = 2 ^ (i + 1) := by rw [Nat.pow_add, Nat.mul_comm]
norm_cast
simp only [Nat.shiftLeft_eq, Nat.one_mul, hy, hj]
by_cases i + 1 = (w' + 1)
· simp only [show i = w' by omega, Nat.lt_irrefl, ↓reduceIte]; omega
· simp only [show 2 ^ (i + 1) < 2 ^ (w' + 1) by rw [Nat.pow_lt_pow_iff_right (by omega)]; omega, ↓reduceIte]
omega

theorem toInt_twoPow_of_eq {w i : Nat} (h : i + 1 = w) :
(BitVec.twoPow w i).toInt = -(2 ^ i) := by
simp only [toInt_twoPow, show ¬(w ≤ i) by omega, ↓reduceIte, h, Nat.shiftLeft_eq, Nat.one_mul]
norm_cast

theorem toInt_mul_toInt_lt {x y : BitVec w} : x.toInt * y.toInt ≤ 2 ^ (w * 2 - 2) := by
rcases w with _|w
· simp [of_length_zero]
· have xlt := toInt_lt (x := x); have xle := le_toInt (x := x)
have ylt := toInt_lt (x := y); have yle := le_toInt (x := y)
have h : 2 ^ ((w + 1) * 2 - 2) = 2 ^ ((w + 1) - 1) * 2 ^ ((w + 1) - 1) := by
rw [← Nat.pow_add, ←Nat.mul_two, Nat.mul_comm (m := 2) (n := ((w + 1) - 1)), Nat.mul_sub_one, Nat.mul_comm]
rw_mod_cast [h]
rw [← Nat.two_pow_pred_mul_two (by omega), Int.natCast_mul] at xlt ylt xle yle
exact Int.mul_le_mul_self (by omega) (by omega)

theorem le_toInt_mul_toInt {x y : BitVec w} : - (2 ^ (w * 2 - 2)) ≤ x.toInt * y.toInt := by
rcases w with _|w
· simp [of_length_zero]
· have xlt := toInt_lt (x := x); have xle := le_toInt (x := x)
have ylt := toInt_lt (x := y); have yle := le_toInt (x := y)
have h : 2 ^ ((w + 1) * 2 - 2) = 2 ^ ((w + 1) - 1) * 2 ^ ((w + 1) - 1) := by
rw [← Nat.pow_add, ←Nat.mul_two, Nat.mul_comm (m := 2) (n := ((w + 1) - 1)), Nat.mul_sub_one, Nat.mul_comm]
rw_mod_cast [h]
rw [← Nat.two_pow_pred_mul_two (by omega), Int.natCast_mul] at xlt ylt xle yle
exact Int.neg_mul_self_le_mul (by omega) (by omega) (by omega) (by omega)

@[simp]
theorem toInt_twoPow_sub_one : (BitVec.twoPow w (w - 1) - 1#w).toInt = 2 ^ (w - 1) - 1 := by
rcases w with _|_|w
· decide
· decide
· have : 1 < 2 ^ (w + 1 + 1) := Nat.one_lt_two_pow (by omega)
rw_mod_cast [BitVec.twoPow, BitVec.toInt_sub, BitVec.toInt_shiftLeft, BitVec.toNat_ofNat,
Int.bmod_sub_bmod_congr, toInt_one_of_lt (by omega), Nat.shiftLeft_eq,
Nat.mod_eq_of_lt (by omega), Int.bmod_eq_of_le_of_lt]
simp only [Nat.add_one_sub_one, Nat.one_mul]
· have : 0 < (2 ^ (w + 1 + 1 - 1) - 1) * 2 := by simp; omega
norm_cast
omega
· rw [← Nat.two_pow_pred_add_two_pow_pred (w := w + 1 + 1) (by omega)]
omega

/- ### cons -/

@[simp] theorem true_cons_zero : cons true 0#w = twoPow (w + 1) w := by
Expand Down Expand Up @@ -3959,6 +4030,10 @@ theorem toNat_mul_of_lt {w} {x y : BitVec w} (h : x.toNat * y.toNat < 2^w) :
(x * y).toNat = x.toNat * y.toNat := by
rw [BitVec.toNat_mul, Nat.mod_eq_of_lt h]

theorem toNat_mul_toNat_lt {x y : BitVec w} : x.toNat * y.toNat < 2 ^ (w * 2) := by
have := BitVec.isLt x; have := BitVec.isLt y
simp only [Nat.mul_two, Nat.pow_add]
exact Nat.mul_lt_mul_of_le_of_lt (by omega) (by omega) (by omega)

/--
`x ≤ y + z` if and only if `x - z ≤ y`
Expand Down
7 changes: 7 additions & 0 deletions src/Init/Data/Int/DivModLemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,9 @@ theorem lt_mul_ediv_self_add {x k : Int} (h : 0 < k) : x < k * (x / k) + k :=
@[simp] theorem add_emod_self_left {a b : Int} : (a + b) % a = b % a := by
rw [Int.add_comm, Int.add_emod_self]

theorem Int.emod_eq_add_self_emod {a b : Int} : a % b = (a + b) % b :=
Int.add_emod_self.symm

theorem neg_emod {a b : Int} : -a % b = (b - a) % b := by
rw [← add_emod_self_left]; rfl

Expand Down Expand Up @@ -820,6 +823,10 @@ theorem ediv_eq_ediv_of_mul_eq_mul {a b c d : Int}
Int.ediv_eq_of_eq_mul_right H3 <| by
rw [← Int.mul_ediv_assoc _ H2]; exact (Int.ediv_eq_of_eq_mul_left H4 H5.symm).symm

theorem mul_le_mul_neg {a b c d : Int}
(hac : a ≤ c) (hbd : d ≤ b) (hb : 0 ≤ b) (hc : c ≤ 0) : a * b ≤ c * d :=
Int.le_trans (Int.mul_le_mul_of_nonneg_right hac hb) (Int.mul_le_mul_of_nonpos_left hc hbd)

/-! ### tdiv -/

@[simp] protected theorem tdiv_one : ∀ a : Int, a.tdiv 1 = a
Expand Down
28 changes: 28 additions & 0 deletions src/Init/Data/Int/LemmasAux.lean
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,32 @@ theorem bmod_neg_iff {m : Nat} {x : Int} (h2 : -m ≤ x) (h1 : x < m) :
· rw [Int.emod_eq_of_lt xpos (by omega)]; omega
· rw [Int.add_emod_self.symm, Int.emod_eq_of_lt (by omega) (by omega)]; omega

theorem bmod_eq_of_le_of_lt {x : Int} {y : Nat} (hge : -y ≤ x * 2) (hlt : x * 2 < y) :
x.bmod y = x := by
simp only [Int.bmod_def]
rcases x
· rw [Int.emod_eq_of_lt (by simp only [ofNat_eq_coe]; omega) (by omega)]; omega
· rw [Int.emod_eq_add_self_emod, Int.emod_eq_of_lt (by omega) (by omega)]; omega

theorem mul_le_mul_self {x y : Int} {s : Nat} (hx : x.natAbs ≤ s) (hy : y.natAbs ≤ s) :
x * y ≤ s * s := by
rcases s with _|s
· simp [show x = 0 by omega]
· have := Nat.mul_pos (n := (s + 1)) (m := (s + 1)) (by omega) (by omega)
by_cases hx : 0 < x <;> by_cases hy : 0 < y
· exact Int.mul_le_mul (by omega) (by omega) (by omega) (by omega)
· have : x * y ≤ 0 := Int.mul_nonpos_of_nonneg_of_nonpos (by omega) (by omega); omega
· have : x * y ≤ 0 := Int.mul_nonpos_of_nonpos_of_nonneg (by omega) (by omega); omega
· have : -x * -y ≤ (s + 1) * (s + 1) := Int.mul_le_mul (by omega) (by omega) (by omega) (by omega)
simp_all

theorem neg_mul_self_le_mul {x y : Int} {s : Nat} (lbx : -s ≤ x) (ubx : x < s) (lby : -s ≤ y) (uby : y < s) :
-(s * s) ≤ x * y := by
have := Nat.mul_pos (n := s) (m := s) (by omega) (by omega)
by_cases 0 ≤ x <;> by_cases 0 ≤ y
· have : 0 ≤ x * y := Int.mul_nonneg (by omega) (by omega); omega
· rw [← Int.neg_mul, Int.mul_comm (a := x)]; exact Int.mul_le_mul_neg (by omega) (by omega) (by omega) (by omega)
· rw [← Int.neg_mul]; exact Int.mul_le_mul_neg (by omega) (by omega) (by omega) (by omega)
· have : 0 < x * y := Int.mul_pos_of_neg_of_neg (by omega) (by omega); omega

end Int
7 changes: 7 additions & 0 deletions src/Init/Data/Int/Pow.lean
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,11 @@ protected theorem two_pow_pred_sub_two_pow' {w : Nat} (h : 0 < w) :
rw [← Nat.two_pow_pred_add_two_pow_pred h]
simp [h]

theorem pow_lt_pow {a : Int} {b c : Nat} (ha : (1 : Nat) < a) (hbc : b < c):
a ^ b < a ^ c := by
rw [← Int.toNat_of_nonneg (a := a) (by omega), ← Int.natCast_pow, ← Int.natCast_pow]
have := Nat.pow_lt_pow_of_lt (a := a.toNat) (m := c) (n := b)
simp only [Int.ofNat_lt]
omega

end Int
3 changes: 3 additions & 0 deletions src/Init/Data/Nat/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,9 @@ theorem mul_mod (a b n : Nat) : a * b % n = (a % n) * (b % n) % n := by
Nat.mul_assoc, Nat.mul_assoc, ← Nat.mul_add n, add_mul_mod_self_left,
Nat.mul_comm _ (n * (b / n)), Nat.mul_assoc, add_mul_mod_self_left]

theorem mul_mod_mod (a b c : Nat) : (a * (b % c)) % c = a * b % c := by
rw [mul_mod, mod_mod, ← mul_mod]

@[simp] theorem mod_add_mod (m n k : Nat) : (m % n + k) % n = (m + k) % n := by
have := (add_mul_mod_self_left (m % n + k) n (m / n)).symm
rwa [Nat.add_right_comm, mod_add_div] at this
Expand Down
3 changes: 3 additions & 0 deletions src/Std/Tactic/BVDecide/Normalize/BitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,9 @@ attribute [bv_normalize] BitVec.umod_zero
attribute [bv_normalize] BitVec.umod_one
attribute [bv_normalize] BitVec.umod_eq_and

attribute [bv_normalize] BitVec.umulOverflow_eq
attribute [bv_normalize] BitVec.smulOverflow_eq

attribute [bv_normalize] BitVec.saddOverflow_eq
attribute [bv_normalize] BitVec.uaddOverflow_eq

Expand Down
Loading