Skip to content

Commit

Permalink
A first shot at more higher-order list functions
Browse files Browse the repository at this point in the history
  • Loading branch information
LeventErkok committed Nov 13, 2024
1 parent e487bdc commit fc2c8a8
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 41 deletions.
50 changes: 30 additions & 20 deletions Data/SBV/Core/Symbolic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -520,21 +520,24 @@ instance Show RegExOp where
show (RegExNEq r1 r2) = "(distinct " ++ regExpToSMTString r1 ++ " " ++ regExpToSMTString r2 ++ ")"

-- | Sequence operations.
data SeqOp = SeqConcat -- ^ See StrConcat
| SeqLen -- ^ See StrLen
| SeqUnit -- ^ See StrUnit
| SeqNth -- ^ See StrNth
| SeqSubseq -- ^ See StrSubseq
| SeqIndexOf -- ^ See StrIndexOf
| SeqContains -- ^ See StrContains
| SeqPrefixOf -- ^ See StrPrefixOf
| SeqSuffixOf -- ^ See StrSuffixOf
| SeqReplace -- ^ See StrReplace
| SeqMap String -- ^ Mapping over sequences
| SeqMapI String -- ^ Mapping over sequences with offset
| SeqFoldLeft String -- ^ Folding of sequences
| SeqFoldLeftI String -- ^ Folding of sequences with offset
| SBVReverse Kind -- ^ Reversal of sequences. NB. Also works for strings; hence the name.
data SeqOp = SeqConcat -- ^ See StrConcat
| SeqLen -- ^ See StrLen
| SeqUnit -- ^ See StrUnit
| SeqNth -- ^ See StrNth
| SeqSubseq -- ^ See StrSubseq
| SeqIndexOf -- ^ See StrIndexOf
| SeqContains -- ^ See StrContains
| SeqPrefixOf -- ^ See StrPrefixOf
| SeqSuffixOf -- ^ See StrSuffixOf
| SeqReplace -- ^ See StrReplace
| SeqMap String -- ^ Mapping over sequences
| SeqMapI String -- ^ Mapping over sequences with offset
| SeqFoldLeft String -- ^ Folding of sequences
| SeqFoldLeftI String -- ^ Folding of sequences with offset
| SBVReverse Kind -- ^ Reversal of sequences. NB. Also works for strings; hence the name.
| SBVSeqFilter Kind String -- ^ filter the list. Kind is the element type
| SBVSeqAll Kind String -- ^ map the function and reduce via and, with base true. Kind is the element type.
| SBVSeqAny Kind String -- ^ map the function and reduce via or, with base false. Kind is the element type.
deriving (Eq, Ord, G.Data, NFData, Generic)

-- | Show instance for SeqOp. Again, mapping is important.
Expand All @@ -554,11 +557,18 @@ instance Show SeqOp where
show (SeqFoldLeft s) = "seq.foldl " ++ s
show (SeqFoldLeftI s) = "seq.foldli " ++ s

-- Note: This isn't part of SMTLib, we explicitly handle it
show (SBVReverse k) = let sk = show k
ssk | any isSpace sk = '(' : sk ++ ")"
| True = sk
in "sbv.reverse @" ++ ssk
-- Note: The followings aren't part of SMTLib, we explicitly handle it
show (SBVReverse k) = funcWithKind "sbv.reverse" k Nothing
show (SBVSeqFilter k s) = funcWithKind "sbv.seqFilter" k (Just s)
show (SBVSeqAll k s) = funcWithKind "sbv.seqAll" k (Just s)
show (SBVSeqAny k s) = funcWithKind "sbv.seqAny" k (Just s)

-- helper for above
funcWithKind :: String -> Kind -> Maybe String -> String
funcWithKind f k mbExtra = f ++ " @" ++ ssk ++ maybe "" (' ':) mbExtra
where sk = show k
ssk | any isSpace sk = '(' : sk ++ ")"
| True = sk

-- | Set operations.
data SetOp = SetEqual
Expand Down
8 changes: 7 additions & 1 deletion Data/SBV/Lambda.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ import Data.SBV.Core.Kind
import Data.SBV.SMT.SMTLib2
import Data.SBV.Utils.PrettyNum

import qualified Data.Map.Strict as M

import Data.SBV.Core.Symbolic hiding (mkNewState)
import qualified Data.SBV.Core.Symbolic as S (mkNewState)

Expand Down Expand Up @@ -336,7 +338,7 @@ toLambda level curProgInfo cfg expectedKind result@Result{resAsgns = SBVPgm asgn
where mkAsgn (sv, e@(SBVApp (Label l) _)) = ((sv, converter e), Just l)
mkAsgn (sv, e) = ((sv, converter e), Nothing)

converter = cvtExp curProgInfo (capabilities (solver cfg)) rm tableMap
converter = cvtExp curProgInfo (capabilities (solver cfg)) rm tableMap funcMap


out :: SV
Expand All @@ -348,6 +350,10 @@ toLambda level curProgInfo cfg expectedKind result@Result{resAsgns = SBVPgm asgn

rm = roundingMode cfg

-- NB. The following isn't really kosher since a lambda might refer to an SBV function
-- like reverse/any/all etc. But let's cross that bridge when we get to it
funcMap = M.empty

-- NB. The following is dead-code, since we ensure tbls is empty
-- We used to support this, but there are issues, so dropping support
-- See, for instance, https://github.com/LeventErkok/sbv/issues/664
Expand Down
38 changes: 32 additions & 6 deletions Data/SBV/List.hs
Original file line number Diff line number Diff line change
Expand Up @@ -561,8 +561,15 @@ concat = foldl (++) []
-- True
-- >>> all isEven [2, 4, 6, 1, 8, 10 :: Integer]
-- False
all :: SymVal a => (SBV a -> SBool) -> SList a -> SBool
all f = foldl (\sofar e -> sofar .&& f e) sTrue
all :: forall a. SymVal a => (SBV a -> SBool) -> SList a -> SBool
all f l
| Just l' <- unliteral l
= sAll f (P.map literal l')
| True
= SBV $ SVal KBool $ Right $ cache r
where r st = do sva <- sbvToSV st l
lam <- lambdaStr st KBool f
newExpr st KBool (SBVApp (SeqOp (SBVSeqAll (kindOf (Proxy @a)) lam)) [sva])

-- | Check some element satisfies the predicate.
-- --
Expand All @@ -571,17 +578,36 @@ all f = foldl (\sofar e -> sofar .&& f e) sTrue
-- False
-- >>> any isEven [2, 4, 6, 1, 8, 10 :: Integer]
-- True
any :: SymVal a => (SBV a -> SBool) -> SList a -> SBool
any f = foldl (\sofar e -> sofar .|| f e) sFalse
any :: forall a. SymVal a => (SBV a -> SBool) -> SList a -> SBool
any f l
| Just l' <- unliteral l
= sAny f (P.map literal l')
| True
= SBV $ SVal KBool $ Right $ cache r
where r st = do sva <- sbvToSV st l
lam <- lambdaStr st KBool f
newExpr st KBool (SBVApp (SeqOp (SBVSeqAny (kindOf (Proxy @a)) lam)) [sva])

-- | @filter f xs@ filters the list with the given predicate.
--
-- >>> filter (\x -> x `sMod` 2 .== 0) [1 .. 10 :: Integer]
-- [2,4,6,8,10] :: [SInteger]
-- >>> filter (\x -> x `sMod` 2 ./= 0) [1 .. 10 :: Integer]
-- [1,3,5,7,9] :: [SInteger]
filter :: SymVal a => (SBV a -> SBool) -> SList a -> SList a
filter f = foldl (\sofar e -> sofar ++ ite (f e) (singleton e) []) []
filter :: forall a. SymVal a => (SBV a -> SBool) -> SList a -> SList a
filter f l
| Just l' <- unliteral l, Just concResult <- concreteFilter l'
= literal concResult
| True
= SBV $ SVal k $ Right $ cache r
where concreteFilter l' = case P.map (unliteral . f . literal) l' of
xs | P.any isNothing xs -> Nothing
| True -> Just [e | (True, e) <- P.zip (catMaybes xs) l']

k = kindOf (Proxy @(SList a))
r st = do sva <- sbvToSV st l
lam <- lambdaStr st KBool f
newExpr st k (SBVApp (SeqOp (SBVSeqFilter (kindOf (Proxy @a)) lam)) [sva])

-- | Lift a unary operator over lists.
lift1 :: forall a b. (SymVal a, SymVal b) => Bool -> SeqOp -> Maybe (a -> b) -> SBV a -> SBV b
Expand Down
80 changes: 66 additions & 14 deletions Data/SBV/SMT/SMTLib2.hs
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,11 @@ cvt ctx curProgInfo kindInfo isSat comments allInputs (_, consts) tbls uis defs
++ [ "; --- uninterpreted constants ---" ]
++ concatMap (declUI curProgInfo) uis
++ [ "; --- SBV Function definitions" | not (null funcMap) ]
++ concat [ declSBVFunc op nm | (op, nm) <- M.toAscList funcMap ]
++ concat [declSBVFunc op nm | (op, nm) <- M.toAscList funcMap]
++ [ "; --- user defined functions ---"]
++ userDefs
++ [ "; --- assignments ---" ]
++ concatMap (declDef curProgInfo cfg tableMap) asgns
++ concatMap (declDef curProgInfo cfg tableMap funcMap) asgns
++ [ "; --- delayedEqualities ---" ]
++ map (\s -> "(assert " ++ s ++ ")") delayedEqualities
++ [ "; --- formula ---" ]
Expand Down Expand Up @@ -290,8 +290,22 @@ cvt ctx curProgInfo kindInfo isSat comments allInputs (_, consts) tbls uis defs
| True = Just $ Left s

-- SBV only functions.
funcMap = M.fromList [(op, "|sbv.reverse_" ++ show k ++ "|") | (op, k) <- revs]
where revs = nub [(op, k) | op@(SeqOp (SBVReverse k)) <- G.universeBi asgnsSeq]
funcMap = M.fromList $ [(op, "|sbv.reverse_" ++ show k ++ "|") | (op@(SeqOp (SBVReverse k )), _) <- specials]
++ [(op, "|sbv.seqFilter_" ++ show k ++ idx i ++ "|") | (op@(SeqOp (SBVSeqFilter k _)), i) <- specials]
++ [(op, "|sbv.seqAll_" ++ show k ++ idx i ++ "|") | (op@(SeqOp (SBVSeqAll k _)), i) <- specials]
++ [(op, "|sbv.seqAny_" ++ show k ++ idx i ++ "|") | (op@(SeqOp (SBVSeqAny k _)), i) <- specials]
where specials = zip (nub [op | op@(SeqOp so) <- G.universeBi asgnsSeq, isSpecial so]) [0..]

isSpecial SBVReverse{} = True
isSpecial SBVSeqFilter{} = True
isSpecial SBVSeqAll{} = True
isSpecial SBVSeqAny{} = True
isSpecial _ = False

-- if index 0, then ignore it; other wise add it. This distinguishes different functions passed to all/any
idx :: Int -> String
idx 0 = ""
idx i = show i

asgns = F.toList asgnsSeq

Expand All @@ -305,6 +319,9 @@ declSBVFunc :: Op -> String -> [String]
declSBVFunc op nm = case op of
SeqOp (SBVReverse KString) -> mkStringRev
SeqOp (SBVReverse (KList k)) -> mkSeqRev (KList k)
SeqOp (SBVSeqFilter ek f) -> mkFilter ek f
SeqOp (SBVSeqAll ek f) -> mkAnyAll True ek f
SeqOp (SBVSeqAny ek f) -> mkAnyAll False ek f
_ -> error $ "Data.SBV.declSBVFunc: Unexpected internal function: " ++ show (op, nm)
where mkStringRev = [ "(define-fun-rec " ++ nm ++ " ((str String)) String"
, " (ite (= str \"\")"
Expand All @@ -321,6 +338,24 @@ declSBVFunc op nm = case op of
]
where t = smtType k

mkAnyAll isAll ek f = [ "(define-fun-rec " ++ nm ++ " ((lst " ++ t ++ ")) Bool"
, " (ite (= lst (as seq.empty " ++ t ++ "))"
, " " ++ base
, " (" ++ conn ++ " (select " ++ f ++ " (seq.nth lst 0)) (" ++ nm ++ " (seq.extract lst 1 (- (seq.len lst) 1))))))"
]
where t = smtType (KList ek)
(base, conn) | isAll = ("true", "and")
| True = ("false", "or")

mkFilter k f = [ "(define-fun-rec " ++ nm ++ " ((lst " ++ t ++ ")) " ++ t
, " (ite (= lst (as seq.empty " ++ t ++ "))"
, " (as seq.empty " ++ t ++ ")"
, " (ite (select " ++ f ++ " (seq.nth lst 0))"
, " (seq.++ (seq.unit (seq.nth lst 0)) (" ++ nm ++ " (seq.extract lst 1 (- (seq.len lst) 1))))"
, " (" ++ nm ++ " (seq.extract lst 1 (- (seq.len lst) 1))))))"
]
where t = smtType (KList k)

-- | Declare new sorts
declSort :: (String, Maybe [String]) -> [String]
declSort (s, _)
Expand Down Expand Up @@ -472,7 +507,7 @@ cvtInc curProgInfo inps newKs (_, consts) tbls uis (SBVPgm asgnsSeq) cstrs cfg =
-- table declarations
++ tableDecls
-- expressions
++ concatMap (declDef curProgInfo cfg tableMap) (F.toList asgnsSeq)
++ concatMap (declDef curProgInfo cfg tableMap funcMap) (F.toList asgnsSeq)
-- table setups
++ concat tableAssigns
-- extra constraints
Expand All @@ -488,6 +523,10 @@ cvtInc curProgInfo inps newKs (_, consts) tbls uis (SBVPgm asgnsSeq) cstrs cfg =

(tableDecls, tableAssigns) = unzip $ map mkTable allTables

-- This isn't super kosher, since we might refer to an internal function in
-- the incremental context. But let's cross that bridge when we come to it.
funcMap = M.empty

-- If we need flattening in models, do emit the required lines if preset
settings
| any needsFlattening newKinds
Expand All @@ -496,11 +535,11 @@ cvtInc curProgInfo inps newKs (_, consts) tbls uis (SBVPgm asgnsSeq) cstrs cfg =
= []
where solverCaps = capabilities (solver cfg)

declDef :: ProgInfo -> SMTConfig -> TableMap -> (SV, SBVExpr) -> [String]
declDef curProgInfo cfg tableMap (s, expr) =
declDef :: ProgInfo -> SMTConfig -> TableMap -> FuncMap -> (SV, SBVExpr) -> [String]
declDef curProgInfo cfg tableMap funcMap (s, expr) =
case expr of
SBVApp (Label m) [e] -> defineFun cfg (s, cvtSV e) (Just m)
e -> defineFun cfg (s, cvtExp curProgInfo caps rm tableMap e) Nothing
SBVApp (Label m) [e] -> defineFun cfg (s, cvtSV e) (Just m)
e -> defineFun cfg (s, cvtExp curProgInfo caps rm tableMap funcMap e) Nothing
where caps = capabilities (solver cfg)
rm = roundingMode cfg

Expand Down Expand Up @@ -684,7 +723,8 @@ cvtType (SBVType []) = error "SBV.SMT.SMTLib2.cvtType: internal: received an emp
cvtType (SBVType xs) = "(" ++ unwords (map smtType body) ++ ") " ++ smtType ret
where (body, ret) = (init xs, last xs)

type TableMap = IM.IntMap String
type TableMap = IM.IntMap String
type FuncMap = M.Map Op String

-- Present an SV, simply show
cvtSV :: SV -> String
Expand All @@ -698,8 +738,8 @@ getTable m i
| Just tn <- i `IM.lookup` m = tn
| True = "table" ++ show i -- constant tables are always named this way

cvtExp :: ProgInfo -> SolverCapabilities -> RoundingMode -> TableMap -> SBVExpr -> String
cvtExp curProgInfo caps rm tableMap expr@(SBVApp _ arguments) = sh expr
cvtExp :: ProgInfo -> SolverCapabilities -> RoundingMode -> TableMap -> FuncMap -> SBVExpr -> String
cvtExp curProgInfo caps rm tableMap funcMap expr@(SBVApp _ arguments) = sh expr
where hasPB = supportsPseudoBooleans caps
hasInt2bv = supportsInt2bv caps
hasDistinct = supportsDistinct caps
Expand Down Expand Up @@ -788,6 +828,15 @@ cvtExp curProgInfo caps rm tableMap expr@(SBVApp _ arguments) = sh expr
= let idx v = "(" ++ s ++ "_constrIndex " ++ v ++ ")" in "(" ++ o ++ " " ++ idx a ++ " " ++ idx b ++ ")"
unintComp o sbvs = error $ "SBV.SMT.SMTLib2.sh.unintComp: Unexpected arguments: " ++ show (o, sbvs, map kindOf arguments)

getFuncName op = case op `M.lookup` funcMap of
Just n -> n
Nothing -> error $ unlines [ ""
, "*** Cannot translate operator: " ++ show op
, "***"
, "*** Note that this operator isn't currently supported in incremental query mode."
, "*** If you are not in query mode, or would like support for this feature, please report!"
]

stringOrChar KString = True
stringOrChar KChar = True
stringOrChar _ = False
Expand Down Expand Up @@ -974,8 +1023,11 @@ cvtExp curProgInfo caps rm tableMap expr@(SBVApp _ arguments) = sh expr
sh (SBVApp (RegExOp o@RegExEq{}) []) = show o
sh (SBVApp (RegExOp o@RegExNEq{}) []) = show o

-- Reverse is special, since we need to generate call to the internally generated function
sh (SBVApp (SeqOp (SBVReverse k)) args) = "(|sbv.reverse_" ++ show k ++ "| " ++ unwords (map cvtSV args) ++ ")"
-- Reverse and higher order functions are special
sh (SBVApp o@(SeqOp SBVReverse{}) args) = "(" ++ getFuncName o ++ " " ++ unwords (map cvtSV args) ++ ")"
sh (SBVApp o@(SeqOp SBVSeqFilter{}) args) = "(" ++ getFuncName o ++ " " ++ unwords (map cvtSV args) ++ ")"
sh (SBVApp o@(SeqOp SBVSeqAll{} ) args) = "(" ++ getFuncName o ++ " " ++ unwords (map cvtSV args) ++ ")"
sh (SBVApp o@(SeqOp SBVSeqAny{} ) args) = "(" ++ getFuncName o ++ " " ++ unwords (map cvtSV args) ++ ")"

sh (SBVApp (SeqOp op) args) = "(" ++ show op ++ " " ++ unwords (map cvtSV args) ++ ")"

Expand Down

0 comments on commit fc2c8a8

Please sign in to comment.