Skip to content

Commit

Permalink
cleanup overflow with good constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
LeventErkok committed Aug 7, 2024
1 parent 05539b3 commit 2d79ead
Showing 1 changed file with 34 additions and 9 deletions.
43 changes: 34 additions & 9 deletions attic/overflow.hs
Original file line number Diff line number Diff line change
@@ -1,16 +1,40 @@
-- N-bit signed multiplication overflow detection using a N+1 bit multiplier

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# OPTIONS_GHC -Wall -Werror #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

{-# OPTIONS_GHC -Wall -Werror #-}

module Main(main) where

import Data.SBV
import Data.SBV.Tools.Overflow (bvMulO)

import Data.Proxy
import GHC.TypeLits (KnownNat, type (+), natVal)
import Data.Kind
import Data.Type.Bool
import Data.Type.Equality

import GHC.TypeLits

type InvalidBVSMULO (n :: Nat) = 'Text "Invalid call to bvsmulO on n: " ':<>: 'ShowType n
':$$: 'Text ""
':$$: 'Text "A valid call must have 0 < n <= 32769"
':$$: 'Text ""
':$$: 'Text "Given type falls outside of this range, or the size is not a known natural."

-- We use an SWord16 as the approxLog; and we add two of these. The function computes
-- N-2 at the largest for N bits. Two of them give us 2N-4, and to fit into SWord16,
-- we need 2N-4 <= 2^16-1, which implies N <= 32769, or N < 32770; which should be plenty enough for
-- any practical purpose. Hence the constraint below.
type family BVIsValidSMulO (arg :: Nat) :: Constraint where
BVIsValidSMulO (n :: Nat) = ( BVIsNonZero n
, KnownNat n
, If (n `CmpNat` 32770 == 'LT)
(() :: Constraint)
(TypeError (InvalidBVSMULO n)))

-- Find the position of the first non-sign bit. i.e., the first bit that differs from the msb.
-- Position is 0 indexed. Note that if there's no differing bit, then you also get back 0.
Expand All @@ -26,16 +50,16 @@ import GHC.TypeLits (KnownNat, type (+), natVal)
-- 101 -> 1
-- 110 -> 0
-- 111 -> 0 (no differing bit from 1; so we get 0)
approxLog :: (BVIsNonZero n, KnownNat n) => SInt n -> SWord8
approxLog :: BVIsValidSMulO n => SInt n -> SWord 16
approxLog w = case blastBE w of
[] -> error $ "Impossible happened: Got no bits after blasing " ++ show w
x : xs -> walk (.== sNot x) (literal (fromIntegral (length xs - 1))) xs
where walk :: (SBool -> SBool) -> SWord8 -> [SBool] -> SWord8
where walk :: (SBool -> SBool) -> SWord 16 -> [SBool] -> SWord 16
walk _check _i [] = 0
walk check i (b:bs) = ite (check b) i (walk check (i-1) bs)

-- Algorithm using an N+1 bit multiplier
bvsmulO :: forall n. (BVIsNonZero n, KnownNat n, BVIsNonZero (n+1), KnownNat (n+1)) => SInt n -> SInt n -> SBool
bvsmulO :: forall n. (BVIsValidSMulO n, BVIsNonZero (n+1), KnownNat (n+1)) => SInt n -> SInt n -> SBool
bvsmulO x y = sNot zeroOut .&& overflow
where zeroOut = x .== 0 .|| y .== 0

Expand All @@ -61,7 +85,7 @@ textbook x y = prod2N ./= sFromIntegral prodN
prodN :: SInt n
prodN = x * y

test :: forall proxy n. (BVIsNonZero n, KnownNat n, BVIsNonZero (n+1), KnownNat (n+1), BVIsNonZero (n+n), KnownNat (n+n)) => proxy n -> IO ()
test :: forall proxy n. (BVIsValidSMulO n, BVIsNonZero (n+1), KnownNat (n+1), BVIsNonZero (n+n), KnownNat (n+n)) => proxy n -> IO ()
test _ = do print =<< check "Against builtin" bvMulO
print =<< check "Against textbook" textbook
where check w f = do putStrLn $ "Proving: " ++ w ++ ", N = " ++ show (natVal (Proxy @n))
Expand All @@ -81,4 +105,5 @@ main = do test (Proxy @1)
test (Proxy @8)
test (Proxy @16)
test (Proxy @24)
-- run (Proxy @32) -- Takes about 2 minutes
test (Proxy @32)
test (Proxy @64)

0 comments on commit 2d79ead

Please sign in to comment.