Skip to content

Commit

Permalink
clean-up further
Browse files Browse the repository at this point in the history
  • Loading branch information
LeventErkok committed Aug 9, 2024
1 parent 0fa642c commit 37f0c1c
Showing 1 changed file with 35 additions and 48 deletions.
83 changes: 35 additions & 48 deletions attic/overflow.hs
Original file line number Diff line number Diff line change
@@ -1,52 +1,27 @@
-- N-bit signed multiplication overflow detection using a N+1 bit multiplier

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE DataKinds #-}

{-# OPTIONS_GHC -Wall -Werror #-}
{-# OPTIONS_GHC -Wall -Werror #-}

module Main(main) where

import Data.SBV
import Data.SBV.Tools.Overflow (bvMulO)

import Data.Proxy
import Data.Kind
import Data.Type.Bool
import Data.Type.Equality

import Control.Monad
import Data.Proxy

import GHC.TypeLits

type InvalidBVSMULO (n :: Nat) = 'Text "Invalid type/size with n: " ':<>: 'ShowType n
':$$: 'Text ""
':$$: 'Text "A valid call must pass `SInt n` argument with 0 < n <= 32769"
':$$: 'Text ""
':$$: 'Text "Given type falls outside of this range, or the size is not a known natural."

-- We use an SWord16 as the nonSignBitPos; and we add two of these. The function computes
-- N-2 at the largest for N bits. Two of them give us 2N-4, and to fit into SWord16,
-- we need 2N-4 <= 2^16-1, which implies N <= 32769, or N < 32770; which should be plenty enough for
-- any practical purpose. Hence the constraint below.
--
-- Alternatively, we can use Integers and not worry about this, alas Bitwuzla (which does really well
-- on bit-vector programs) does not support unbounded integers.
--
-- TODO: See if we can avoid the addition completely and somehow do a position comparision to see if it'll be N-2.
type family BVIsValidSMulO (arg :: Nat) :: Constraint where
BVIsValidSMulO (n :: Nat) = ( BVIsNonZero n
, KnownNat n
, If (n `CmpNat` 32770 == 'LT)
(() :: Constraint)
(TypeError (InvalidBVSMULO n)))
import Data.SBV
import Data.SBV.Tools.Overflow (bvMulO)

-- Find the position of the first non-sign bit. i.e., the first bit that differs from the msb.
-- Position is 0 indexed. Note that if there's no differing bit, then you also get back 0.
-- This is essentially an approximation of the logarithm of the magnitude of the number.
--
-- The result is at most N-2 for an N-bit word. Later we add two of these, so the maximum
-- value we need to represent is 2N-4. This will require 1 + lg(2N-4) = 2 + log(N-1) bits.
-- To suppor the case N=0, we return a (2 + log N) bit word.
--
-- Example for 3 bits:
--
-- 000 -> 0 (no differing bit from 0; so we get 0)
Expand All @@ -57,16 +32,22 @@ type family BVIsValidSMulO (arg :: Nat) :: Constraint where
-- 101 -> 1
-- 110 -> 0
-- 111 -> 0 (no differing bit from 1; so we get 0)
nonSignBitPos :: BVIsValidSMulO n => SInt n -> SWord 16
nonSignBitPos w = case blastBE w of
[] -> error $ "Impossible happened: Got no bits after blasing " ++ show w
x : xs -> walk (.== sNot x) (literal (fromIntegral (length xs - 1))) xs
where walk :: (SBool -> SBool) -> SWord 16 -> [SBool] -> SWord 16
walk _check _i [] = 0
walk check i (b:bs) = ite (check b) i (walk check (i-1) bs)
nonSignBitPos :: ( KnownNat n, BVIsNonZero n
, KnownNat (2 + Log2 n), BVIsNonZero (2+Log2 n))
=> SInt n -> SWord (2 + Log2 n)
nonSignBitPos w = walk 0 rest
where (sign, rest) = case blastBE w of
[] -> error $ "Impossible happened, blastBE returned no bits for " ++ show w
(x:xs) -> (x, zip [0..] (reverse xs))

walk sofar [] = sofar
walk sofar ((i, b):bs) = walk (ite (b ./= sign) i sofar) bs

-- Algorithm using an N+1 bit multiplier
bvsmulO :: forall n. (BVIsValidSMulO n, BVIsNonZero (n+1), KnownNat (n+1)) => SInt n -> SInt n -> SBool
bvsmulO :: forall n. ( KnownNat n, BVIsNonZero n
, KnownNat (n+1), BVIsNonZero (n+1)
, KnownNat (2+Log2 n), BVIsNonZero (2+Log2 n))
=> SInt n -> SInt n -> SBool
bvsmulO x y = sNot zeroOut .&& overflow
where zeroOut = x .== 0 .|| y .== 0

Expand All @@ -80,25 +61,31 @@ bvsmulO x y = sNot zeroOut .&& overflow
prodN = prod `sTestBit` nv
prodNm1 = prod `sTestBit` (nv-1)

overflow = ((nonSignBitPos x + nonSignBitPos y) .> literal (fromIntegral (nv - 2)))
.|| (prodN .<+> prodNm1)
overflow = nonSignBitPos x + nonSignBitPos y .> literal (fromIntegral (nv - 2))
.|| prodN .<+> prodNm1

-- Text-book definition
textbook :: forall n. (BVIsNonZero n, KnownNat n, BVIsNonZero (n+n), KnownNat (n+n)) => SInt n -> SInt n -> SBool
textbook :: forall n. ( KnownNat n, BVIsNonZero n
, KnownNat (n+n), BVIsNonZero (n+n)
) => SInt n -> SInt n -> SBool
textbook x y = prod2N ./= sFromIntegral prodN
where prod2N :: SInt (n+n)
prod2N = sFromIntegral x * sFromIntegral y

prodN :: SInt n
prodN = x * y

test :: forall proxy n. (BVIsValidSMulO n, BVIsNonZero (n+1), KnownNat (n+1), BVIsNonZero (n+n), KnownNat (n+n)) => proxy n -> Bool -> IO ()
test :: forall proxy n. ( KnownNat n, BVIsNonZero n
, KnownNat (n+1), BVIsNonZero (n+1)
, KnownNat (n+n), BVIsNonZero (n+n)
, KnownNat (2+Log2 n), BVIsNonZero (2+Log2 n)
) => proxy n -> Bool -> IO ()
test _ checkTextBook = do print =<< check "Against builtin" bvMulO
when checkTextBook (print =<< check "Against textbook" textbook)
where check w f = do putStrLn $ "Proving: " ++ w ++ ", N = " ++ show (natVal (Proxy @n))
proveWith bitwuzla $ do
proveWith bitwuzla{timing = PrintTiming} $ do
x <- sInt "x"
y <- sInt "x"
y <- sInt "y"
pure $ f x y .== (bvsmulO :: SInt n -> SInt n -> SBool) x y

main :: IO ()
Expand Down

0 comments on commit 37f0c1c

Please sign in to comment.