diff --git a/attic/overflow.hs b/attic/overflow.hs index 284472ecd..81e83b078 100644 --- a/attic/overflow.hs +++ b/attic/overflow.hs @@ -1,8 +1,11 @@ -- N-bit signed multiplication overflow detection using a N+1 bit multiplier -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE ExplicitNamespaces #-} -{-# OPTIONS_GHC -Wall -Werror #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ExplicitNamespaces #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} + +{-# OPTIONS_GHC -Wall -Werror #-} module Main(main) where @@ -10,7 +13,28 @@ import Data.SBV import Data.SBV.Tools.Overflow (bvMulO) import Data.Proxy -import GHC.TypeLits (KnownNat, type (+), natVal) +import Data.Kind +import Data.Type.Bool +import Data.Type.Equality + +import GHC.TypeLits + +type InvalidBVSMULO (n :: Nat) = 'Text "Invalid call to bvsmulO on n: " ':<>: 'ShowType n + ':$$: 'Text "" + ':$$: 'Text "A valid call must have 0 < n <= 32769" + ':$$: 'Text "" + ':$$: 'Text "Given type falls outside of this range, or the size is not a known natural." + +-- We use an SWord16 as the approxLog; and we add two of these. The function computes +-- N-2 at the largest for N bits. Two of them give us 2N-4, and to fit into SWord16, +-- we need 2N-4 <= 2^16-1, which implies N <= 32769, or N < 32770; which should be plenty enough for +-- any practical purpose. Hence the constraint below. +type family BVIsValidSMulO (arg :: Nat) :: Constraint where + BVIsValidSMulO (n :: Nat) = ( BVIsNonZero n + , KnownNat n + , If (n `CmpNat` 32770 == 'LT) + (() :: Constraint) + (TypeError (InvalidBVSMULO n))) -- Find the position of the first non-sign bit. i.e., the first bit that differs from the msb. -- Position is 0 indexed. Note that if there's no differing bit, then you also get back 0. @@ -26,16 +50,16 @@ import GHC.TypeLits (KnownNat, type (+), natVal) -- 101 -> 1 -- 110 -> 0 -- 111 -> 0 (no differing bit from 1; so we get 0) -approxLog :: (BVIsNonZero n, KnownNat n) => SInt n -> SWord8 +approxLog :: BVIsValidSMulO n => SInt n -> SWord 16 approxLog w = case blastBE w of [] -> error $ "Impossible happened: Got no bits after blasing " ++ show w x : xs -> walk (.== sNot x) (literal (fromIntegral (length xs - 1))) xs - where walk :: (SBool -> SBool) -> SWord8 -> [SBool] -> SWord8 + where walk :: (SBool -> SBool) -> SWord 16 -> [SBool] -> SWord 16 walk _check _i [] = 0 walk check i (b:bs) = ite (check b) i (walk check (i-1) bs) -- Algorithm using an N+1 bit multiplier -bvsmulO :: forall n. (BVIsNonZero n, KnownNat n, BVIsNonZero (n+1), KnownNat (n+1)) => SInt n -> SInt n -> SBool +bvsmulO :: forall n. (BVIsValidSMulO n, BVIsNonZero (n+1), KnownNat (n+1)) => SInt n -> SInt n -> SBool bvsmulO x y = sNot zeroOut .&& overflow where zeroOut = x .== 0 .|| y .== 0 @@ -61,7 +85,7 @@ textbook x y = prod2N ./= sFromIntegral prodN prodN :: SInt n prodN = x * y -test :: forall proxy n. (BVIsNonZero n, KnownNat n, BVIsNonZero (n+1), KnownNat (n+1), BVIsNonZero (n+n), KnownNat (n+n)) => proxy n -> IO () +test :: forall proxy n. (BVIsValidSMulO n, BVIsNonZero (n+1), KnownNat (n+1), BVIsNonZero (n+n), KnownNat (n+n)) => proxy n -> IO () test _ = do print =<< check "Against builtin" bvMulO print =<< check "Against textbook" textbook where check w f = do putStrLn $ "Proving: " ++ w ++ ", N = " ++ show (natVal (Proxy @n)) @@ -81,4 +105,5 @@ main = do test (Proxy @1) test (Proxy @8) test (Proxy @16) test (Proxy @24) - -- run (Proxy @32) -- Takes about 2 minutes + test (Proxy @32) + test (Proxy @64)