Skip to content

Commit

Permalink
make it size polymorphic
Browse files Browse the repository at this point in the history
  • Loading branch information
LeventErkok committed Aug 6, 2024
1 parent da8a564 commit f00aff9
Showing 1 changed file with 63 additions and 48 deletions.
111 changes: 63 additions & 48 deletions attic/overflow.hs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
-- 64-bit signed multiplication overflow detection using a 66-bit multiplier

{-# LANGUAGE DataKinds #-}
{-# OPTIONS_GHC -Wall -Werror #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -Wall -Werror #-}

module Main where

import Data.SBV
import Data.SBV.Tools.Overflow
import GHC.TypeLits (KnownNat)
import Data.Proxy
import GHC.TypeLits (KnownNat, type (+), natVal)

-- Find the position of the first non-sign bit. i.e., the first bit that differs from the msb.
-- Position is 0 indexed. Note that if there's no differing bit, then you also get back 0.
Expand All @@ -30,74 +33,86 @@ position w = case blastBE w of
walk check i (b:bs) = ite (check b) i (walk check (i-1) bs)

-- Algorithm 1: Builtin overflow detection
spec :: SInt 64 -> SInt 64 -> SBool
spec :: (BVIsNonZero n, KnownNat n) => SInt n -> SInt n -> SBool
spec = bvMulO

-- Algorithm 2: use a 66 bit multiplier
alg66 :: SInt 64 -> SInt 64 -> SBool
alg66 x y = sNot zeroOut .&& overflow
where m = position x
n = position y
-- Algorithm 2: use a N+2 bit multiplier
algN2 :: forall n. (BVIsNonZero n, KnownNat n, BVIsNonZero (n+2), KnownNat (n+2)) => SInt n -> SInt n -> SBool
algN2 x y = sNot zeroOut .&& overflow
where zeroOut = x .== 0 .|| y .== 0

zeroOut = x .== 0 .|| y .== 0

prod :: SInt 66
prod :: SInt (n+2)
prod = sFromIntegral x * sFromIntegral y

prod65, prod64, prod63 :: SBool
prod65 = prod `sTestBit` 65
prod64 = prod `sTestBit` 64
prod63 = prod `sTestBit` 63
nv :: Int
nv = fromIntegral $ natVal (Proxy @n)

prod6563 :: [SBool]
prod6563 = [prod65, prod64, prod63]
prodNp1, prodN, prodNm1 :: SBool
prodNp1 = prod `sTestBit` (nv+1)
prodN = prod `sTestBit` nv
prodNm1 = prod `sTestBit` (nv-1)

overflow = ((m + n) .> 62)
.|| ( (prod6563 ./= [sFalse, sFalse, sFalse])
.&& (prod6563 ./= [sTrue, sTrue, sTrue ])
)
prodTop3 :: [SBool]
prodTop3 = [prodNp1, prodN, prodNm1]

-- Algorithm 3: use a 65 bit multiplier
alg65 :: SInt 64 -> SInt 64 -> SBool
alg65 x y = sNot zeroOut .&& overflow
where m = position x
n = position y
overflow = ((position x + position y) .> literal (fromIntegral (nv - 2)))
.|| ( (prodTop3 ./= [sFalse, sFalse, sFalse])
.&& (prodTop3 ./= [sTrue, sTrue, sTrue ])
)

zeroOut = x .== 0 .|| y .== 0
-- Algorithm 3: use an N+1 bit multiplier
algN1 :: forall n. (BVIsNonZero n, KnownNat n, BVIsNonZero (n+1), KnownNat (n+1)) => SInt n -> SInt n -> SBool
algN1 x y = sNot zeroOut .&& overflow
where zeroOut = x .== 0 .|| y .== 0

prod :: SInt 65
prod :: SInt (n+1)
prod = sFromIntegral x * sFromIntegral y

prod64, prod63 :: SBool
prod64 = prod `sTestBit` 64
prod63 = prod `sTestBit` 63
nv :: Int
nv = fromIntegral $ natVal (Proxy @n)

prod6463 :: [SBool]
prod6463 = [prod64, prod63]
prodN, prodNm1 :: SBool
prodN = prod `sTestBit` nv
prodNm1 = prod `sTestBit` (nv-1)

overflow = ((m + n) .> 62)
.|| ( (prod6463 ./= [sFalse, sFalse])
.&& (prod6463 ./= [sTrue, sTrue ])
prodTop2 :: [SBool]
prodTop2 = [prodN, prodNm1]

overflow = ((position x + position y) .> literal (fromIntegral (nv - 2)))
.|| ( (prodTop2 ./= [sFalse, sFalse])
.&& (prodTop2 ./= [sTrue, sTrue ])
)

-- Algorithm 4: Text-book definition
textbook :: SInt 64 -> SInt 64 -> SBool
textbook x y = prod128 ./= sFromIntegral prod64
where prod128 :: SInt 128
prod128 = sFromIntegral x * sFromIntegral y
textbook :: forall n. (BVIsNonZero n, KnownNat n, BVIsNonZero (n+n), KnownNat (n+n)) => SInt n -> SInt n -> SBool
textbook x y = prod2N ./= sFromIntegral prodN
where prod2N :: SInt (n+n)
prod2N = sFromIntegral x * sFromIntegral y

prod64 :: SInt 64
prod64 = x * y
prodN :: SInt n
prodN = x * y

-- Comparators
comp :: String -> (SInt 64 -> SInt 64 -> SBool) -> (SInt 64 -> SInt 64 -> SBool) -> IO ThmResult
comp w f g = do putStrLn $ "Proving: " ++ w ++ ":"
comp :: forall n. (BVIsNonZero n, KnownNat n) => String -> (SInt n -> SInt n -> SBool) -> (SInt n -> SInt n -> SBool) -> IO ThmResult
comp w f g = do putStrLn $ "Proving: " ++ w ++ ", N = " ++ show (natVal (Proxy @n))
proveWith bitwuzla{timing = PrintTiming} $ do
x <- sInt "x"
y <- sInt "x"
pure $ f x y .== g x y

runAll :: forall n. (BVIsNonZero n, KnownNat n) => (SInt n -> SInt n -> SBool) -> (SInt n -> SInt n -> SBool) -> (SInt n -> SInt n -> SBool) -> IO ()
runAll specF n2F n1F = do print =<< comp ("Spec vs " ++ n2 ++ "-bit mult") specF n2F
print =<< comp (n2 ++ "-bit vs " ++ n1 ++ "-bit mult") n2F n1F
print =<< comp (n1 ++ "-bit vs text-book") n1F specF
where n = natVal (Proxy @n)
n1 = show $ n+1
n2 = show $ n+2

run :: forall proxy n. (KnownNat n, BVIsNonZero n, KnownNat (n+1), BVIsNonZero (n+1), KnownNat (n+2), BVIsNonZero (n+2)) => proxy n -> IO ()
run _ = runAll @n spec algN2 algN1

main :: IO ()
main = do print =<< comp "Spec vs 66-bit mult" spec alg66
print =<< comp "66-bit vs 65-bit mult" alg66 alg65
print =<< comp "65-bit vs text-book" alg65 textbook
main = do run (Proxy @8)
run (Proxy @16)
run (Proxy @24)
run (Proxy @32)

0 comments on commit f00aff9

Please sign in to comment.