From f00aff9560aaec7997389bbfe1ccf073778f002f Mon Sep 17 00:00:00 2001 From: Levent Erkok Date: Tue, 6 Aug 2024 15:55:12 -0700 Subject: [PATCH] make it size polymorphic --- attic/overflow.hs | 111 ++++++++++++++++++++++++++-------------------- 1 file changed, 63 insertions(+), 48 deletions(-) diff --git a/attic/overflow.hs b/attic/overflow.hs index 22aa79ab9..600f82478 100644 --- a/attic/overflow.hs +++ b/attic/overflow.hs @@ -1,13 +1,16 @@ -- 64-bit signed multiplication overflow detection using a 66-bit multiplier -{-# LANGUAGE DataKinds #-} -{-# OPTIONS_GHC -Wall -Werror #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ExplicitNamespaces #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# OPTIONS_GHC -Wall -Werror #-} module Main where import Data.SBV import Data.SBV.Tools.Overflow -import GHC.TypeLits (KnownNat) +import Data.Proxy +import GHC.TypeLits (KnownNat, type (+), natVal) -- Find the position of the first non-sign bit. i.e., the first bit that differs from the msb. -- Position is 0 indexed. Note that if there's no differing bit, then you also get back 0. @@ -30,74 +33,86 @@ position w = case blastBE w of walk check i (b:bs) = ite (check b) i (walk check (i-1) bs) -- Algorithm 1: Builtin overflow detection -spec :: SInt 64 -> SInt 64 -> SBool +spec :: (BVIsNonZero n, KnownNat n) => SInt n -> SInt n -> SBool spec = bvMulO --- Algorithm 2: use a 66 bit multiplier -alg66 :: SInt 64 -> SInt 64 -> SBool -alg66 x y = sNot zeroOut .&& overflow - where m = position x - n = position y +-- Algorithm 2: use a N+2 bit multiplier +algN2 :: forall n. (BVIsNonZero n, KnownNat n, BVIsNonZero (n+2), KnownNat (n+2)) => SInt n -> SInt n -> SBool +algN2 x y = sNot zeroOut .&& overflow + where zeroOut = x .== 0 .|| y .== 0 - zeroOut = x .== 0 .|| y .== 0 - - prod :: SInt 66 + prod :: SInt (n+2) prod = sFromIntegral x * sFromIntegral y - prod65, prod64, prod63 :: SBool - prod65 = prod `sTestBit` 65 - prod64 = prod `sTestBit` 64 - prod63 = prod `sTestBit` 63 + nv :: Int + nv = fromIntegral $ natVal (Proxy @n) - prod6563 :: [SBool] - prod6563 = [prod65, prod64, prod63] + prodNp1, prodN, prodNm1 :: SBool + prodNp1 = prod `sTestBit` (nv+1) + prodN = prod `sTestBit` nv + prodNm1 = prod `sTestBit` (nv-1) - overflow = ((m + n) .> 62) - .|| ( (prod6563 ./= [sFalse, sFalse, sFalse]) - .&& (prod6563 ./= [sTrue, sTrue, sTrue ]) - ) + prodTop3 :: [SBool] + prodTop3 = [prodNp1, prodN, prodNm1] --- Algorithm 3: use a 65 bit multiplier -alg65 :: SInt 64 -> SInt 64 -> SBool -alg65 x y = sNot zeroOut .&& overflow - where m = position x - n = position y + overflow = ((position x + position y) .> literal (fromIntegral (nv - 2))) + .|| ( (prodTop3 ./= [sFalse, sFalse, sFalse]) + .&& (prodTop3 ./= [sTrue, sTrue, sTrue ]) + ) - zeroOut = x .== 0 .|| y .== 0 +-- Algorithm 3: use an N+1 bit multiplier +algN1 :: forall n. (BVIsNonZero n, KnownNat n, BVIsNonZero (n+1), KnownNat (n+1)) => SInt n -> SInt n -> SBool +algN1 x y = sNot zeroOut .&& overflow + where zeroOut = x .== 0 .|| y .== 0 - prod :: SInt 65 + prod :: SInt (n+1) prod = sFromIntegral x * sFromIntegral y - prod64, prod63 :: SBool - prod64 = prod `sTestBit` 64 - prod63 = prod `sTestBit` 63 + nv :: Int + nv = fromIntegral $ natVal (Proxy @n) - prod6463 :: [SBool] - prod6463 = [prod64, prod63] + prodN, prodNm1 :: SBool + prodN = prod `sTestBit` nv + prodNm1 = prod `sTestBit` (nv-1) - overflow = ((m + n) .> 62) - .|| ( (prod6463 ./= [sFalse, sFalse]) - .&& (prod6463 ./= [sTrue, sTrue ]) + prodTop2 :: [SBool] + prodTop2 = [prodN, prodNm1] + + overflow = ((position x + position y) .> literal (fromIntegral (nv - 2))) + .|| ( (prodTop2 ./= [sFalse, sFalse]) + .&& (prodTop2 ./= [sTrue, sTrue ]) ) -- Algorithm 4: Text-book definition -textbook :: SInt 64 -> SInt 64 -> SBool -textbook x y = prod128 ./= sFromIntegral prod64 - where prod128 :: SInt 128 - prod128 = sFromIntegral x * sFromIntegral y +textbook :: forall n. (BVIsNonZero n, KnownNat n, BVIsNonZero (n+n), KnownNat (n+n)) => SInt n -> SInt n -> SBool +textbook x y = prod2N ./= sFromIntegral prodN + where prod2N :: SInt (n+n) + prod2N = sFromIntegral x * sFromIntegral y - prod64 :: SInt 64 - prod64 = x * y + prodN :: SInt n + prodN = x * y -- Comparators -comp :: String -> (SInt 64 -> SInt 64 -> SBool) -> (SInt 64 -> SInt 64 -> SBool) -> IO ThmResult -comp w f g = do putStrLn $ "Proving: " ++ w ++ ":" +comp :: forall n. (BVIsNonZero n, KnownNat n) => String -> (SInt n -> SInt n -> SBool) -> (SInt n -> SInt n -> SBool) -> IO ThmResult +comp w f g = do putStrLn $ "Proving: " ++ w ++ ", N = " ++ show (natVal (Proxy @n)) proveWith bitwuzla{timing = PrintTiming} $ do x <- sInt "x" y <- sInt "x" pure $ f x y .== g x y +runAll :: forall n. (BVIsNonZero n, KnownNat n) => (SInt n -> SInt n -> SBool) -> (SInt n -> SInt n -> SBool) -> (SInt n -> SInt n -> SBool) -> IO () +runAll specF n2F n1F = do print =<< comp ("Spec vs " ++ n2 ++ "-bit mult") specF n2F + print =<< comp (n2 ++ "-bit vs " ++ n1 ++ "-bit mult") n2F n1F + print =<< comp (n1 ++ "-bit vs text-book") n1F specF + where n = natVal (Proxy @n) + n1 = show $ n+1 + n2 = show $ n+2 + +run :: forall proxy n. (KnownNat n, BVIsNonZero n, KnownNat (n+1), BVIsNonZero (n+1), KnownNat (n+2), BVIsNonZero (n+2)) => proxy n -> IO () +run _ = runAll @n spec algN2 algN1 + main :: IO () -main = do print =<< comp "Spec vs 66-bit mult" spec alg66 - print =<< comp "66-bit vs 65-bit mult" alg66 alg65 - print =<< comp "65-bit vs text-book" alg65 textbook +main = do run (Proxy @8) + run (Proxy @16) + run (Proxy @24) + run (Proxy @32)