Skip to content

Commit

Permalink
Explicitly detect support for lambda expression support
Browse files Browse the repository at this point in the history
  • Loading branch information
LeventErkok committed Nov 14, 2024
1 parent fc2c8a8 commit 6275389
Show file tree
Hide file tree
Showing 13 changed files with 78 additions and 34 deletions.
14 changes: 12 additions & 2 deletions Data/SBV/Core/Symbolic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ module Data.SBV.Core.Symbolic
, getTableIndex, sObserve
, SBVPgm(..), MonadSymbolic(..), SymbolicT, Symbolic, runSymbolic, mkNewState, runSymbolicInState, State(..), SMTDef(..), smtDefGivenName, withNewIncState, IncState(..), incrementInternalCounter
, inSMTMode, SBVRunMode(..), IStage(..), Result(..), ResultInp(..), UICodeKind(..)
, registerKind, registerLabel, recordObservable
, registerKind, registerLabel, registerSpecialFunction, recordObservable
, addAssertion, addNewSMTOption, imposeConstraint, internalConstraint, internalVariable, lambdaVar, quantVar
, SMTLibPgm(..), SMTLibVersion(..), smtLibVersionExtension
, SolverCapabilities(..)
Expand Down Expand Up @@ -872,11 +872,12 @@ instance NFData ResultInp where
data ProgInfo = ProgInfo { hasQuants :: Bool
, progSpecialRels :: [SpecialRelOp]
, progTransClosures :: [(String, String)]
, progSpecialFuncs :: [Op] -- functions that need to be generated, like list reverse/all/any/filter
}
deriving G.Data

instance NFData ProgInfo where
rnf (ProgInfo a b c) = rnf a `seq` rnf b `seq` rnf c
rnf (ProgInfo a b c d) = rnf a `seq` rnf b `seq` rnf c `seq` rnf d

deriving instance G.Data CallStack
deriving instance G.Data SrcLoc
Expand Down Expand Up @@ -1528,6 +1529,13 @@ registerLabel whence st nm

where err w = error $ "SBV (" ++ whence ++ "): " ++ show nm ++ " " ++ w

-- We need to auto-generate certain functions, so keep track of them here
registerSpecialFunction :: State -> Op -> IO ()
registerSpecialFunction st o =
do progInfo <- readIORef (rProgInfo st)
let upd p@ProgInfo{progSpecialFuncs} = p{progSpecialFuncs = o : progSpecialFuncs}
when (o `notElem` progSpecialFuncs progInfo) $ modifyState st rProgInfo upd (pure ())

-- | Create a new constant; hash-cons as necessary
newConst :: State -> CV -> IO SV
newConst st c = do
Expand Down Expand Up @@ -1827,6 +1835,7 @@ mkNewState cfg currentRunMode = liftIO $ do
progInfo <- newIORef ProgInfo { hasQuants = False
, progSpecialRels = []
, progTransClosures = []
, progSpecialFuncs = []
}
rm <- newIORef currentRunMode
ctr <- newIORef (-2) -- start from -2; False and True will always occupy the first two elements
Expand Down Expand Up @@ -2213,6 +2222,7 @@ data SolverCapabilities = SolverCapabilities {
, supportsGlobalDecls :: Bool -- ^ Supports global declarations? (Needed for push-pop.)
, supportsDataTypes :: Bool -- ^ Supports datatypes?
, supportsFoldAndMap :: Bool -- ^ Does it support fold and map?
, supportsLambdas :: Bool -- ^ Does it support lambdas?
, supportsSpecialRels :: Bool -- ^ Does it support special relations (orders, transitive closure etc.)
, supportsDirectAccessors :: Bool -- ^ Supports data-type accessors without full ascription?
, supportsFlattenedModels :: Maybe [String] -- ^ Supports flattened model output? (With given config lines.)
Expand Down
17 changes: 13 additions & 4 deletions Data/SBV/List.hs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import qualified Prelude as P
import Data.SBV.Core.Kind
import Data.SBV.Core.Data hiding (StrOp(..))
import Data.SBV.Core.Model
import Data.SBV.Core.Symbolic (registerSpecialFunction)

import Data.SBV.Lambda
import Data.SBV.Tuple
Expand Down Expand Up @@ -378,7 +379,9 @@ reverse l
= SBV $ SVal k $ Right $ cache r
where k = kindOf l
r st = do sva <- sbvToSV st l
newExpr st k (SBVApp (SeqOp (SBVReverse k)) [sva])
let op = SeqOp (SBVReverse k)
registerSpecialFunction st op
newExpr st k (SBVApp op [sva])

-- | @`map` op s@ maps the operation on to sequence.
--
Expand Down Expand Up @@ -569,7 +572,9 @@ all f l
= SBV $ SVal KBool $ Right $ cache r
where r st = do sva <- sbvToSV st l
lam <- lambdaStr st KBool f
newExpr st KBool (SBVApp (SeqOp (SBVSeqAll (kindOf (Proxy @a)) lam)) [sva])
let op = SeqOp (SBVSeqAll (kindOf (Proxy @a)) lam)
registerSpecialFunction st op
newExpr st KBool (SBVApp op [sva])

-- | Check some element satisfies the predicate.
-- --
Expand All @@ -586,7 +591,9 @@ any f l
= SBV $ SVal KBool $ Right $ cache r
where r st = do sva <- sbvToSV st l
lam <- lambdaStr st KBool f
newExpr st KBool (SBVApp (SeqOp (SBVSeqAny (kindOf (Proxy @a)) lam)) [sva])
let op = SeqOp (SBVSeqAny (kindOf (Proxy @a)) lam)
registerSpecialFunction st op
newExpr st KBool (SBVApp op [sva])

-- | @filter f xs@ filters the list with the given predicate.
--
Expand All @@ -607,7 +614,9 @@ filter f l
k = kindOf (Proxy @(SList a))
r st = do sva <- sbvToSV st l
lam <- lambdaStr st KBool f
newExpr st k (SBVApp (SeqOp (SBVSeqFilter (kindOf (Proxy @a)) lam)) [sva])
let op = SeqOp (SBVSeqFilter (kindOf (Proxy @a)) lam)
registerSpecialFunction st op
newExpr st k (SBVApp op [sva])

-- | Lift a unary operator over lists.
lift1 :: forall a b. (SymVal a, SymVal b) => Bool -> SeqOp -> Maybe (a -> b) -> SBV a -> SBV b
Expand Down
1 change: 1 addition & 0 deletions Data/SBV/Provers/ABC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ abc = SMTSolver {
, supportsGlobalDecls = False
, supportsDataTypes = False
, supportsFoldAndMap = False
, supportsLambdas = False
, supportsSpecialRels = False
, supportsDirectAccessors = False
, supportsFlattenedModels = Nothing
Expand Down
1 change: 1 addition & 0 deletions Data/SBV/Provers/Bitwuzla.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ bitwuzla = SMTSolver {
, supportsGlobalDecls = True
, supportsDataTypes = False
, supportsFoldAndMap = False
, supportsLambdas = False
, supportsSpecialRels = False
, supportsDirectAccessors = False
, supportsFlattenedModels = Nothing
Expand Down
1 change: 1 addition & 0 deletions Data/SBV/Provers/Boolector.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ boolector = SMTSolver {
, supportsGlobalDecls = True
, supportsDataTypes = False
, supportsFoldAndMap = False
, supportsLambdas = False
, supportsSpecialRels = False
, supportsDirectAccessors = False
, supportsFlattenedModels = Nothing
Expand Down
1 change: 1 addition & 0 deletions Data/SBV/Provers/CVC4.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ cvc4 = SMTSolver {
, supportsGlobalDecls = True
, supportsDataTypes = True
, supportsFoldAndMap = False
, supportsLambdas = False
, supportsSpecialRels = False
, supportsDirectAccessors = True
, supportsFlattenedModels = Nothing
Expand Down
1 change: 1 addition & 0 deletions Data/SBV/Provers/CVC5.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ cvc5 = SMTSolver {
, supportsGlobalDecls = True
, supportsDataTypes = True
, supportsFoldAndMap = False
, supportsLambdas = True
, supportsSpecialRels = False
, supportsDirectAccessors = True
, supportsFlattenedModels = Nothing
Expand Down
1 change: 1 addition & 0 deletions Data/SBV/Provers/DReal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dReal = SMTSolver {
, supportsGlobalDecls = False
, supportsDataTypes = False
, supportsFoldAndMap = False
, supportsLambdas = False
, supportsSpecialRels = False
, supportsDirectAccessors = False
, supportsFlattenedModels = Nothing
Expand Down
1 change: 1 addition & 0 deletions Data/SBV/Provers/MathSAT.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ mathSAT = SMTSolver {
, supportsGlobalDecls = True
, supportsDataTypes = True
, supportsFoldAndMap = False
, supportsLambdas = False
, supportsSpecialRels = False
, supportsDirectAccessors = True
, supportsFlattenedModels = Nothing
Expand Down
1 change: 1 addition & 0 deletions Data/SBV/Provers/OpenSMT.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ openSMT = SMTSolver {
, supportsGlobalDecls = True
, supportsDataTypes = False
, supportsFoldAndMap = False
, supportsLambdas = False
, supportsSpecialRels = False
, supportsDirectAccessors = False
, supportsFlattenedModels = Nothing
Expand Down
1 change: 1 addition & 0 deletions Data/SBV/Provers/Yices.hs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ yices = SMTSolver {
, supportsGlobalDecls = True
, supportsDataTypes = False
, supportsFoldAndMap = False
, supportsLambdas = False
, supportsSpecialRels = False
, supportsDirectAccessors = False
, supportsFlattenedModels = Nothing
Expand Down
1 change: 1 addition & 0 deletions Data/SBV/Provers/Z3.hs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ z3 = SMTSolver {
, supportsGlobalDecls = True
, supportsDataTypes = True
, supportsFoldAndMap = True
, supportsLambdas = True
, supportsSpecialRels = True
, supportsDirectAccessors = False -- Needs ascriptions. (See the CVC4 version of this)
, supportsFlattenedModels = Just [ "(set-option :pp.max_depth 4294967295)"
Expand Down
71 changes: 43 additions & 28 deletions Data/SBV/SMT/SMTLib2.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ParallelListComp #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}

Expand Down Expand Up @@ -79,15 +80,21 @@ cvt ctx curProgInfo kindInfo isSat comments allInputs (_, consts) tbls uis defs
isFoldMap SeqFoldLeftI{} = True
isFoldMap _ = False
in (not . null) [ () | o :: SeqOp <- G.universeBi asgnsSeq, isFoldMap o]
hasLambdas = let needsLambda SBVSeqFilter{} = True
needsLambda SBVSeqAll{} = True
needsLambda SBVSeqAny{} = True
needsLambda _ = False
in (not . null) [ () | o :: SeqOp <- G.universeBi asgnsSeq, needsLambda o]

(needsQuantifiers, needsSpecialRels) = case curProgInfo of
ProgInfo hasQ srs tcs -> (hasQ, not (null srs && null tcs))
(needsQuantifiers, needsSpecialRels, specialFuncs) = case curProgInfo of
ProgInfo hasQ srs tcs sf -> (hasQ, not (null srs && null tcs), sf)

-- Is there a reason why we can't handle this problem?
-- NB. There's probably a lot more checking we can do here, but this is a start:
doesntHandle = listToMaybe [nope w | (w, have, need) <- checks, need && not (have solverCaps)]
where checks = [ ("data types", supportsDataTypes, hasTuples || hasEither || hasMaybe)
, ("folds and maps", supportsFoldAndMap, hasFoldMap)
, ("needs lambds", supportsLambdas, hasLambdas)
, ("set operations", supportsSets, hasSets)
, ("bit vectors", supportsBitVectors, hasBVs)
, ("special relations", supportsSpecialRels, needsSpecialRels)
Expand All @@ -102,8 +109,14 @@ cvt ctx curProgInfo kindInfo isSat comments allInputs (_, consts) tbls uis defs
, "*** But the chosen solver (" ++ show (name (solver cfg)) ++ ") doesn't support this feature."
]

-- Some cases require all, some require none. Sigh..
setAll reason = ["(set-logic ALL) ; " ++ reason ++ ", using catch-all."]
-- Some cases require all, some require none. Sigh.. Also, if there're lambdas CVC5 needs HO_ALL
setAll reason = ["(set-logic " ++ allName ++ ") ; " ++ reason ++ ", using catch-all."]
where allName | hasLambdas && isCVC5 = "HO_ALL"
| True = "ALL"

isCVC5 = case name (solver cfg) of
CVC5 -> True
_ -> False

-- Determining the logic is surprisingly tricky!
logic
Expand Down Expand Up @@ -138,6 +151,7 @@ cvt ctx curProgInfo kindInfo isSat comments allInputs (_, consts) tbls uis defs
| needsSpecialRels = ["; has special relations, no logic set."]

-- Things that require ALL
| hasLambdas = setAll "has lambda expressions"
| hasInteger = setAll "has unbounded values"
| hasRational = setAll "has rational values"
| hasReal = setAll "has algebraic reals"
Expand Down Expand Up @@ -232,7 +246,7 @@ cvt ctx curProgInfo kindInfo isSat comments allInputs (_, consts) tbls uis defs
++ [ "; --- uninterpreted constants ---" ]
++ concatMap (declUI curProgInfo) uis
++ [ "; --- SBV Function definitions" | not (null funcMap) ]
++ concat [declSBVFunc op nm | (op, nm) <- M.toAscList funcMap]
++ concat [declSBVFunc cfg op nm | (op, nm) <- M.toAscList funcMap]
++ [ "; --- user defined functions ---"]
++ userDefs
++ [ "; --- assignments ---" ]
Expand Down Expand Up @@ -290,19 +304,11 @@ cvt ctx curProgInfo kindInfo isSat comments allInputs (_, consts) tbls uis defs
| True = Just $ Left s

-- SBV only functions.
funcMap = M.fromList $ [(op, "|sbv.reverse_" ++ show k ++ "|") | (op@(SeqOp (SBVReverse k )), _) <- specials]
++ [(op, "|sbv.seqFilter_" ++ show k ++ idx i ++ "|") | (op@(SeqOp (SBVSeqFilter k _)), i) <- specials]
++ [(op, "|sbv.seqAll_" ++ show k ++ idx i ++ "|") | (op@(SeqOp (SBVSeqAll k _)), i) <- specials]
++ [(op, "|sbv.seqAny_" ++ show k ++ idx i ++ "|") | (op@(SeqOp (SBVSeqAny k _)), i) <- specials]
where specials = zip (nub [op | op@(SeqOp so) <- G.universeBi asgnsSeq, isSpecial so]) [0..]

isSpecial SBVReverse{} = True
isSpecial SBVSeqFilter{} = True
isSpecial SBVSeqAll{} = True
isSpecial SBVSeqAny{} = True
isSpecial _ = False

-- if index 0, then ignore it; other wise add it. This distinguishes different functions passed to all/any
funcMap = M.fromList $ [(op, "|sbv.reverse_" ++ show k ++ "|") | op@(SeqOp (SBVReverse k )) <- specialFuncs]
++ [(op, "|sbv.seqFilter_" ++ show k ++ idx i ++ "|") | op@(SeqOp (SBVSeqFilter k _)) <- specialFuncs | i <- [0..]]
++ [(op, "|sbv.seqAll_" ++ show k ++ idx i ++ "|") | op@(SeqOp (SBVSeqAll k _)) <- specialFuncs | i <- [0..]]
++ [(op, "|sbv.seqAny_" ++ show k ++ idx i ++ "|") | op@(SeqOp (SBVSeqAny k _)) <- specialFuncs | i <- [0..]]
where -- if index 0, then ignore it; other wise add it. This distinguishes different functions passed to all/any
idx :: Int -> String
idx 0 = ""
idx i = show i
Expand All @@ -315,14 +321,14 @@ cvt ctx curProgInfo kindInfo isSat comments allInputs (_, consts) tbls uis defs
_ -> Nothing

-- Declare "known" SBV functions here
declSBVFunc :: Op -> String -> [String]
declSBVFunc op nm = case op of
SeqOp (SBVReverse KString) -> mkStringRev
SeqOp (SBVReverse (KList k)) -> mkSeqRev (KList k)
SeqOp (SBVSeqFilter ek f) -> mkFilter ek f
SeqOp (SBVSeqAll ek f) -> mkAnyAll True ek f
SeqOp (SBVSeqAny ek f) -> mkAnyAll False ek f
_ -> error $ "Data.SBV.declSBVFunc: Unexpected internal function: " ++ show (op, nm)
declSBVFunc :: SMTConfig -> Op -> String -> [String]
declSBVFunc cfg op nm = case op of
SeqOp (SBVReverse KString) -> mkStringRev
SeqOp (SBVReverse (KList k)) -> mkSeqRev (KList k)
SeqOp (SBVSeqFilter ek f) -> mkFilter ek f
SeqOp (SBVSeqAll ek f) -> mkAnyAll True ek f
SeqOp (SBVSeqAny ek f) -> mkAnyAll False ek f
_ -> error $ "Data.SBV.declSBVFunc: Unexpected internal function: " ++ show (op, nm)
where mkStringRev = [ "(define-fun-rec " ++ nm ++ " ((str String)) String"
, " (ite (= str \"\")"
, " \"\""
Expand All @@ -338,10 +344,19 @@ declSBVFunc op nm = case op of
]
where t = smtType k

-- in Z3, lambdas are applied with select. In CVC5, it's @. This might change with higher-order features being added to SMTLib in v3
hoApply
| isCVC5 = "@"
| True = "select"

isCVC5 = case name (solver cfg) of
CVC5 -> True
_ -> False

mkAnyAll isAll ek f = [ "(define-fun-rec " ++ nm ++ " ((lst " ++ t ++ ")) Bool"
, " (ite (= lst (as seq.empty " ++ t ++ "))"
, " " ++ base
, " (" ++ conn ++ " (select " ++ f ++ " (seq.nth lst 0)) (" ++ nm ++ " (seq.extract lst 1 (- (seq.len lst) 1))))))"
, " (" ++ conn ++ " (" ++ hoApply ++ " " ++ f ++ " (seq.nth lst 0)) (" ++ nm ++ " (seq.extract lst 1 (- (seq.len lst) 1))))))"
]
where t = smtType (KList ek)
(base, conn) | isAll = ("true", "and")
Expand All @@ -350,7 +365,7 @@ declSBVFunc op nm = case op of
mkFilter k f = [ "(define-fun-rec " ++ nm ++ " ((lst " ++ t ++ ")) " ++ t
, " (ite (= lst (as seq.empty " ++ t ++ "))"
, " (as seq.empty " ++ t ++ ")"
, " (ite (select " ++ f ++ " (seq.nth lst 0))"
, " (ite (" ++ hoApply ++ " " ++ f ++ " (seq.nth lst 0))"
, " (seq.++ (seq.unit (seq.nth lst 0)) (" ++ nm ++ " (seq.extract lst 1 (- (seq.len lst) 1))))"
, " (" ++ nm ++ " (seq.extract lst 1 (- (seq.len lst) 1))))))"
]
Expand Down

0 comments on commit 6275389

Please sign in to comment.