Skip to content

Commit

Permalink
Merge branch 'master' of github.com:leanprover/lean4 into smod-bitblast
Browse files Browse the repository at this point in the history
  • Loading branch information
luisacicolini committed Feb 10, 2025
2 parents 037b339 + d61f506 commit f302b07
Show file tree
Hide file tree
Showing 12 changed files with 524 additions and 63 deletions.
2 changes: 1 addition & 1 deletion src/Init/Data/Array/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3653,7 +3653,7 @@ theorem toListRev_toArray (l : List α) : l.toArray.toListRev = l.reverse := by
l.toArray.mapM f = List.toArray <$> l.mapM f := by
simp only [← mapM'_eq_mapM, mapM_eq_foldlM]
suffices ∀ init : Array β,
foldlM (fun bs a => bs.push <$> f a) init l.toArray = (init ++ toArray ·) <$> mapM' f l by
Array.foldlM (fun bs a => bs.push <$> f a) init l.toArray = (init ++ toArray ·) <$> mapM' f l by
simpa using this #[]
intro init
induction l generalizing init with
Expand Down
162 changes: 160 additions & 2 deletions src/Init/Data/Int/Linear.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Init.ByCases
import Init.Data.Prod
import Init.Data.Int.Lemmas
import Init.Data.Int.LemmasAux
import Init.Data.Int.DivModLemmas
import Init.Data.RArray

namespace Int.Linear
Expand Down Expand Up @@ -97,10 +98,34 @@ def PolyCnstr.denote (ctx : Context) : PolyCnstr → Prop
| .eq p => p.denote ctx = 0
| .le p => p.denote ctx ≤ 0

def Poly.div (k : Int) : Poly → Poly
| .num k' => .num (k'/k)
| .add k' x p => .add (k'/k) x (div k p)

def Poly.divAll (k : Int) : Poly → Bool
| .num k' => (k'/k)*k == k'
| .add k' _ p => (k'/k)*k == k' && divAll k p

def Poly.divCoeffs (k : Int) : Poly → Bool
| .num _ => true
| .add k' _ p => (k'/k)*k == k' && divCoeffs k p

def Poly.getConst : Poly → Int
| .num k => k
| .add _ _ p => getConst p

def PolyCnstr.norm : PolyCnstr → PolyCnstr
| .eq p => .eq p.norm
| .le p => .le p.norm

def PolyCnstr.divAll (k : Int) : PolyCnstr → Bool
| .eq p => p.divAll k
| .le p => p.divAll k

def PolyCnstr.div (k : Int) : PolyCnstr → PolyCnstr
| .eq p => .eq <| p.div k
| .le p => .le <| p.div k

inductive ExprCnstr where
| eq (p₁ p₂ : Expr)
| le (p₁ p₂ : Expr)
Expand All @@ -114,6 +139,10 @@ def ExprCnstr.toPoly : ExprCnstr → PolyCnstr
| .eq e₁ e₂ => .eq (e₁.sub e₂).toPoly.norm
| .le e₁ e₂ => .le (e₁.sub e₂).toPoly.norm

-- Certificate for normalizing the coefficients of a constraint
def divBy (e e' : ExprCnstr) (k : Int) : Bool :=
k > 0 && e.toPoly.divAll k && e'.toPoly == e.toPoly.div k

attribute [local simp] Int.add_comm Int.add_assoc Int.add_left_comm Int.add_mul Int.mul_add
attribute [local simp] Poly.insert Poly.denote Poly.norm Poly.addConst

Expand Down Expand Up @@ -143,7 +172,30 @@ private theorem sub_fold (a b : Int) : a.sub b = a - b := rfl
private theorem neg_fold (a : Int) : a.neg = -a := rfl

attribute [local simp] sub_fold neg_fold
attribute [local simp] ExprCnstr.denote ExprCnstr.toPoly PolyCnstr.denote Expr.denote

attribute [local simp] Poly.div Poly.divAll PolyCnstr.denote

theorem Poly.denote_div_eq_of_divAll (ctx : Context) (p : Poly) (k : Int) : p.divAll k → (p.div k).denote ctx * k = p.denote ctx := by
induction p with
| num _ => simp
| add k' v p ih =>
simp; intro h₁ h₂
have ih := ih h₂
simp [ih]
apply congrArg (denote ctx p + ·)
rw [Int.mul_right_comm, h₁]

attribute [local simp] Poly.divCoeffs Poly.getConst

theorem Poly.denote_div_eq_of_divCoeffs (ctx : Context) (p : Poly) (k : Int) : p.divCoeffs k → (p.div k).denote ctx * k + p.getConst % k = p.denote ctx := by
induction p with
| num k' => simp; rw [Int.add_comm, Int.mul_comm, Int.ediv_add_emod]
| add k' v p ih =>
simp; intro h₁ h₂
rw [← ih h₂]
rw [Int.mul_right_comm, h₁, Int.add_assoc]

attribute [local simp] ExprCnstr.denote ExprCnstr.toPoly Expr.denote

theorem Expr.denote_toPoly'_go (ctx : Context) (e : Expr) :
(toPoly'.go k e p).denote ctx = k * e.denote ctx + p.denote ctx := by
Expand Down Expand Up @@ -172,7 +224,7 @@ theorem Expr.denote_toPoly'_go (ctx : Context) (e : Expr) :
theorem Expr.denote_toPoly (ctx : Context) (e : Expr) : e.toPoly.denote ctx = e.denote ctx := by
simp [toPoly, toPoly', Expr.denote_toPoly'_go]

attribute [local simp] Expr.denote_toPoly
attribute [local simp] Expr.denote_toPoly PolyCnstr.denote

theorem ExprCnstr.denote_toPoly (ctx : Context) (c : ExprCnstr) : c.toPoly.denote ctx = c.denote ctx := by
cases c <;> simp
Expand Down Expand Up @@ -214,6 +266,75 @@ theorem ExprCnstr.eq_of_toPoly_eq (ctx : Context) (c c' : ExprCnstr) (h : c.toPo
rw [denote_toPoly, denote_toPoly] at h
assumption

theorem ExprCnstr.eq_of_toPoly_eq_var (ctx : Context) (x y : Var) (c : ExprCnstr) (h : c.toPoly == .eq (.add 1 x (.add (-1) y (.num 0))))
: c.denote ctx = (x.denote ctx = y.denote ctx) := by
have h := congrArg (PolyCnstr.denote ctx) (eq_of_beq h)
rw [denote_toPoly] at h
rw [h]; simp
rw [← Int.sub_eq_add_neg, Int.sub_eq_zero]

theorem ExprCnstr.eq_of_toPoly_eq_const (ctx : Context) (x : Var) (k : Int) (c : ExprCnstr) (h : c.toPoly == .eq (.add 1 x (.num (-k))))
: c.denote ctx = (x.denote ctx = k) := by
have h := congrArg (PolyCnstr.denote ctx) (eq_of_beq h)
rw [denote_toPoly] at h
rw [h]; simp
rw [Int.add_comm, ← Int.sub_eq_add_neg, Int.sub_eq_zero]

private theorem mul_eq_zero_iff_eq_zero (a b : Int) : b ≠ 0 → (a * b = 0 ↔ a = 0) := by
intro h
constructor
· intro h'
cases Int.mul_eq_zero.mp h'
· assumption
· contradiction
· intro; simp [*]

private theorem eq_mul_le_zero {a b : Int} : 0 < b → (a ≤ 0 ↔ a * b ≤ 0) := by
intro h
have : 0 = 0 * b := by simp
constructor
· intro h'
rw [this]
apply Int.mul_le_mul h' <;> try simp
apply Int.le_of_lt h
· intro h'
rw [this] at h'
exact Int.le_of_mul_le_mul_right h' h

attribute [local simp] PolyCnstr.divAll PolyCnstr.div

theorem ExprCnstr.eq_of_toPoly_eq_of_divBy' (ctx : Context) (e e' : ExprCnstr) (p : PolyCnstr) (k : Int) : k > 0 → p.divAll k → e.toPoly = p → e'.toPoly = p.div k → e.denote ctx = e'.denote ctx := by
intro h₀ h₁ h₂ h₃
have hz : k ≠ 0 := by intro h; simp [h] at h₀
cases p <;> simp at h₁
next p =>
replace h₁ := Poly.denote_div_eq_of_divAll ctx p k h₁
replace h₂ := congrArg (PolyCnstr.denote ctx) h₂
simp only [PolyCnstr.denote.eq_1, ← h₁] at h₂
replace h₃ := congrArg (PolyCnstr.denote ctx) h₃
simp only [PolyCnstr.denote.eq_1, PolyCnstr.div] at h₃
rw [mul_eq_zero_iff_eq_zero _ _ hz] at h₂
have := Eq.trans h₂ h₃.symm
rw [denote_toPoly, denote_toPoly] at this
exact this
next p =>
-- TODO: this is correct but we can simplify `p ≤ 0` if `p.divCoeffs k` and `p.getConst % k > 0`. Here, we are simplifying only the case `p.getConst % k = 0`
replace h₁ := Poly.denote_div_eq_of_divAll ctx p k h₁
replace h₂ := congrArg (PolyCnstr.denote ctx) h₂
simp only [PolyCnstr.denote.eq_2, ← h₁] at h₂
replace h₃ := congrArg (PolyCnstr.denote ctx) h₃
simp only [PolyCnstr.denote.eq_2, PolyCnstr.div] at h₃
rw [eq_mul_le_zero h₀] at h₃
have := Eq.trans h₂ h₃.symm
rw [denote_toPoly, denote_toPoly] at this
exact this

theorem ExprCnstr.eq_of_toPoly_eq_of_divBy (ctx : Context) (e e' : ExprCnstr) (k : Int) : divBy e e' k → e.denote ctx = e'.denote ctx := by
intro h
simp only [divBy, Bool.and_eq_true, bne_iff_ne, ne_eq, beq_iff_eq, decide_eq_true_eq] at h
have ⟨⟨h₁, h₂⟩, h₃⟩ := h
exact ExprCnstr.eq_of_toPoly_eq_of_divBy' ctx e e' e.toPoly k h₁ h₂ rfl h₃

def PolyCnstr.isUnsat : PolyCnstr → Bool
| .eq (.num k) => k != 0
| .eq _ => false
Expand All @@ -229,6 +350,43 @@ theorem ExprCnstr.eq_false_of_isUnsat (ctx : Context) (c : ExprCnstr) (h : c.toP
rw [ExprCnstr.denote_toPoly] at this
assumption

def PolyCnstr.isUnsatCoeff (k : Int) : PolyCnstr → Bool
| .eq p => p.divCoeffs k && k > 0 && p.getConst % k > 0
| .le _ => false

private theorem contra {a b k : Int} (h₀ : 0 < k) (h₁ : 0 < b) (h₂ : b < k) (h₃ : a*k + b = 0) : False := by
have : b = -a*k := by
rw [← Int.neg_eq_of_add_eq_zero h₃, Int.neg_mul]
rw [this] at h₁ h₂
conv at h₂ => rhs; rw [← Int.one_mul k]
have high := Int.lt_of_mul_lt_mul_right h₂ (Int.le_of_lt h₀)
rw [← Int.zero_mul k] at h₁
have low := Int.lt_of_mul_lt_mul_right h₁ (Int.le_of_lt h₀)
replace low : 1 ≤ -a := low
have : (1 : Int) < 1 := Int.lt_of_le_of_lt low high
contradiction

private theorem PolyCnstr.eq_false (ctx : Context) (p : Poly) (k : Int) : p.divCoeffs k → k > 0 → p.getConst % k > 0 → (PolyCnstr.eq p).denote ctx = False := by
simp
intro h₁ h₂ h₃ h
have hnz : k ≠ 0 := by intro h; rw [h] at h₂; contradiction
have := Poly.denote_div_eq_of_divCoeffs ctx p k h₁
rw [h] at this
have low := h₃
have high := Int.emod_lt_of_pos p.getConst h₂
exact contra h₂ low high this

theorem ExprCnstr.eq_false_of_isUnsat_coeff (ctx : Context) (c : ExprCnstr) (k : Int) : c.toPoly.isUnsatCoeff k → c.denote ctx = False := by
intro h
cases c <;> simp [toPoly, PolyCnstr.isUnsatCoeff] at h
next e₁ e₂ =>
have ⟨⟨h₁, h₂⟩, h₃⟩ := h
have := PolyCnstr.eq_false ctx _ _ h₁ h₂ h₃
simp at this
simp
intro he
simp [he] at this

def PolyCnstr.isValid : PolyCnstr → Bool
| .eq (.num k) => k == 0
| .eq _ => false
Expand Down
2 changes: 1 addition & 1 deletion src/Init/Data/List/Control.lean
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ foldlM f x₀ [a, b, c] = do
```
-/
@[specialize]
protected def foldlM {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : (f : s → α → m s) → (init : s) → List α → m s
def foldlM {m : Type u → Type v} [Monad m] {s : Type u} {α : Type w} : (f : s → α → m s) → (init : s) → List α → m s
| _, s, [] => pure s
| f, s, a :: as => do
let s' ← f s a
Expand Down
29 changes: 20 additions & 9 deletions src/Lean/Meta/Tactic/LinearArith/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,24 @@ we abstract over them.
-/
def withAbstractAtoms (atoms : Array Expr) (type : Name) (k : Array Expr → MetaM (Option (Expr × Expr))) :
MetaM (Option (Expr × Expr)) := do
let atoms := atoms
let decls : Array (Name × (Array Expr → MetaM Expr)) ← atoms.mapM fun _ => do
return ((← mkFreshUserName `x), fun _ => pure (mkConst type))
withLocalDeclsD decls fun ctxt => do
let some (r, p) ← k ctxt | return none
let r := (← mkLambdaFVars ctxt r).beta atoms
let p := mkAppN (← mkLambdaFVars ctxt p) atoms
return some (r, p)
let type := mkConst type
let rec go (i : Nat) (atoms' : Array Expr) (xs : Array Expr) (args : Array Expr) : MetaM (Option (Expr × Expr)) := do
if h : i < atoms.size then
let atom := atoms[i]
if atom.isFVar then
go (i+1) (atoms'.push atom) xs args
else
withLocalDeclD (← mkFreshUserName `x) type fun x =>
go (i+1) (atoms'.push x) (xs.push x) (args.push atom)
else
if xs.isEmpty then
k atoms'
else
let some (r, p) ← k atoms' | return none
let r := (← mkLambdaFVars xs r).beta args
let p := mkAppN (← mkLambdaFVars xs p) args
return some (r, p)
go 0 #[] #[] #[]

/-- Quick filter for linear terms. -/
def isLinearTerm (e : Expr) : Bool :=
Expand All @@ -30,7 +40,8 @@ def isLinearTerm (e : Expr) : Bool :=
false
else
let n := f.constName!
n == ``HAdd.hAdd || n == ``HMul.hMul || n == ``HSub.hSub || n == ``Nat.succ
n == ``HAdd.hAdd || n == ``HMul.hMul || n == ``HSub.hSub || n == ``Neg.neg || n == ``Nat.succ
|| n == ``Add.add || n == ``Mul.mul || n == ``Sub.sub

/-- Quick filter for linear constraints. -/
partial def isLinearCnstr (e : Expr) : Bool :=
Expand Down
29 changes: 16 additions & 13 deletions src/Lean/Meta/Tactic/LinearArith/Int/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,6 @@ def addAsVar (e : Expr) : M LinearExpr := do
set { varMap := (← s.varMap.insert e x), vars := s.vars.push e : State }
return var x

private def toInt? (e : Expr) : MetaM (Option Int) := do
let_expr OfNat.ofNat _ n i ← e | return none
unless (← isInstOfNatInt i) do return none
let some n ← evalNat n |>.run | return none
return some (Int.ofNat n)

partial def toLinearExpr (e : Expr) : M LinearExpr := do
match e with
| .mdata _ e => toLinearExpr e
Expand All @@ -119,14 +113,14 @@ partial def toLinearExpr (e : Expr) : M LinearExpr := do
where
visit (e : Expr) : M LinearExpr := do
let mul (a b : Expr) := do
match (← toInt? a) with
match (← getIntValue? a) with
| some k => return .mulL k (← toLinearExpr b)
| none => match (← toInt? b) with
| none => match (← getIntValue? b) with
| some k => return .mulR (← toLinearExpr a) k
| none => addAsVar e
match_expr e with
| OfNat.ofNat _ n i =>
if (← isInstOfNatInt i) then toLinearExpr n
| OfNat.ofNat _ _ _ =>
if let some n ← getIntValue? e then return .num n
else addAsVar e
| Int.neg a => return .neg (← toLinearExpr a)
| Neg.neg _ i a =>
Expand All @@ -144,7 +138,7 @@ where
if (← isInstSubInt i) then return .sub (← toLinearExpr a) (← toLinearExpr b)
else addAsVar e
| HSub.hSub _ _ _ i a b =>
if (← isInstSubInt i) then return .sub (← toLinearExpr a) (← toLinearExpr b)
if (← isInstHSubInt i) then return .sub (← toLinearExpr a) (← toLinearExpr b)
else addAsVar e
| Int.mul a b => mul a b
| Mul.mul _ i a b =>
Expand All @@ -159,13 +153,22 @@ partial def toLinearCnstr? (e : Expr) : M (Option LinearCnstr) := OptionT.run do
match_expr e with
| Eq α a b =>
let_expr Int ← α | failure
return .eq (← toLinearExpr a) (← toLinearExpr b)
let a ← toLinearExpr a
let b ← toLinearExpr b
match a, b with
/-
We do not want to convert `x = y` into `x + -1*y = 0`.
Similarly, we don't want to convert `x = 3` into `x + -3 = 0`.
`grind` and other tactics have better support for this kind of equalities.
-/
| .var _, .var _ | .var _, .num _ | .num _, .var _ => failure
| _, _ => return .eq a b
| Int.le a b =>
return .le (← toLinearExpr a) (← toLinearExpr b)
| Int.lt a b =>
return .le (.add (← toLinearExpr a) (.num 1)) (← toLinearExpr b)
| LE.le _ i a b =>
guard (← isInstLENat i)
guard (← isInstLEInt i)
return .le (← toLinearExpr a) (← toLinearExpr b)
| LT.lt _ i a b =>
guard (← isInstLTInt i)
Expand Down
Loading

0 comments on commit f302b07

Please sign in to comment.