Skip to content

Commit

Permalink
Simplify away SimpleBoolOf
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Mar 31, 2024
1 parent 5010757 commit 616fb1c
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 31 deletions.
11 changes: 4 additions & 7 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ module HordeAd.Core.Ast
, AstBool(..), OpCodeNum1(..), OpCodeNum2(..), OpCode1(..), OpCode2(..)
, OpCodeIntegral2(..), OpCodeBool(..), OpCodeRel(..)
-- * Boolean definitions and instances
, BoolOf, IfF(..), EqF(..), OrdF(..), minF, maxF
, IfF(..), EqF(..), OrdF(..), minF, maxF
-- * The AstRaw, AstNoVectorize and AstNoSimplify definitions
, AstRaw(..), AstRawS(..), AstRawWrap(..)
, AstNoVectorize(..), AstNoVectorizeS(..), AstNoVectorizeWrap(..)
Expand Down Expand Up @@ -757,22 +757,19 @@ instance Boolean AstBool where

-- * Boolean definitions and instances

type BoolOf :: TensorType ty -> Type
type BoolOf f = SimpleBoolOf f

class Boolean (SimpleBoolOf f) => IfF (f :: TensorType ty) where
class Boolean (BoolOf f) => IfF (f :: TensorType ty) where
ifF :: (GoodScalar r, HasSingletonDict y)
=> BoolOf f -> f r y -> f r y -> f r y

infix 4 ==., /=.
class Boolean (SimpleBoolOf f) => EqF (f :: TensorType ty) where
class Boolean (BoolOf f) => EqF (f :: TensorType ty) where
-- The existential variables here are handled in instances, e.g., via AstRel.
(==.), (/=.) :: (GoodScalar r, HasSingletonDict y)
=> f r y -> f r y -> BoolOf f
u /=. v = notB (u ==. v)

infix 4 <., <=., >=., >.
class Boolean (SimpleBoolOf f) => OrdF (f :: TensorType ty) where
class Boolean (BoolOf f) => OrdF (f :: TensorType ty) where
-- The existential variables here are handled in instances, e.g., via AstRel.
(<.), (<=.), (>.), (>=.) :: (GoodScalar r, HasSingletonDict y)
=> f r y -> f r y -> BoolOf f
Expand Down
10 changes: 5 additions & 5 deletions src/HordeAd/Core/DualNumber.hs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ makeADInputs =

-- * Assorted instances

type instance SimpleBoolOf (ADVal f) = SimpleBoolOf f
type instance BoolOf (ADVal f) = BoolOf f

instance EqF f => EqF (ADVal f) where
D u _ ==. D v _ = u ==. v
Expand All @@ -193,8 +193,8 @@ fromList lu =
(FromListR $ map (\(D _ u') -> u') lu)

instance ( RankedTensor ranked, IfF (RankedOf (PrimalOf ranked))
, Boolean (SimpleBoolOf ranked)
, SimpleBoolOf (RankedOf (PrimalOf ranked)) ~ SimpleBoolOf ranked )
, Boolean (BoolOf ranked)
, BoolOf (RankedOf (PrimalOf ranked)) ~ BoolOf ranked )
=> IfF (ADVal ranked) where
ifF b v w =
let D u u' = indexPrimal (fromList [v, w])
Expand All @@ -216,8 +216,8 @@ fromListS lu = assert (length lu == valueOf @n) $
(FromListS $ map (\(D _ u') -> u') lu)

instance ( ShapedTensor shaped, IfF (RankedOf (PrimalOf shaped))
, Boolean (SimpleBoolOf shaped)
, SimpleBoolOf (RankedOf (PrimalOf shaped)) ~ SimpleBoolOf shaped )
, Boolean (BoolOf shaped)
, BoolOf (RankedOf (PrimalOf shaped)) ~ BoolOf shaped )
=> IfF (ADVal shaped) where
ifF b v w =
let D u u' = indexPrimalS (fromListS @2 [v, w])
Expand Down
16 changes: 8 additions & 8 deletions src/HordeAd/Core/TensorAst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ fwdProduceArtifact g envInit =

-- * Unlawful boolean instances of ranked AST; they are lawful modulo evaluation

type instance SimpleBoolOf (AstRanked s) = AstBool
type instance BoolOf (AstRanked s) = AstBool

instance AstSpan s => EqF (AstRanked s) where
v ==. u = AstRel EqOp (astSpanPrimal v) (astSpanPrimal u)
Expand All @@ -192,7 +192,7 @@ instance IfF (AstRanked s) where

-- * Unlawful boolean instances of shaped AST; they are lawful modulo evaluation

type instance SimpleBoolOf (AstShaped s) = AstBool
type instance BoolOf (AstShaped s) = AstBool

instance AstSpan s => EqF (AstShaped s) where
v ==. u = AstRelS EqOp (astSpanPrimalS v) (astSpanPrimalS u)
Expand Down Expand Up @@ -680,7 +680,7 @@ astBuildHVector1Vectorize k f = build1VectorizeHVector k $ funToAstI f

-- * The AstRaw, AstNoVectorize and AstNoSimplify instances

type instance SimpleBoolOf (AstRaw s) = AstBool
type instance BoolOf (AstRaw s) = AstBool

deriving instance IfF (AstRaw s)
deriving instance AstSpan s => EqF (AstRaw s)
Expand All @@ -702,7 +702,7 @@ deriving instance (RealFrac (AstRanked s r n))
deriving instance (RealFloat (AstRanked s r n))
=> RealFloat (AstRaw s r n)

type instance SimpleBoolOf (AstRawS s) = AstBool
type instance BoolOf (AstRawS s) = AstBool

deriving instance IfF (AstRawS s)
deriving instance AstSpan s => EqF (AstRawS s)
Expand All @@ -724,7 +724,7 @@ deriving instance (RealFrac (AstShaped s r sh))
deriving instance (RealFloat (AstShaped s r sh))
=> RealFloat (AstRawS s r sh)

type instance SimpleBoolOf (AstNoVectorize s) = AstBool
type instance BoolOf (AstNoVectorize s) = AstBool

deriving instance IfF (AstNoVectorize s)
deriving instance AstSpan s => EqF (AstNoVectorize s)
Expand All @@ -746,7 +746,7 @@ deriving instance (RealFrac (AstRanked s r n))
deriving instance (RealFloat (AstRanked s r n))
=> RealFloat (AstNoVectorize s r n)

type instance SimpleBoolOf (AstNoVectorizeS s) = AstBool
type instance BoolOf (AstNoVectorizeS s) = AstBool

deriving instance IfF (AstNoVectorizeS s)
deriving instance AstSpan s => EqF (AstNoVectorizeS s)
Expand All @@ -768,7 +768,7 @@ deriving instance (RealFrac (AstShaped s r sh))
deriving instance (RealFloat (AstShaped s r sh))
=> RealFloat (AstNoVectorizeS s r sh)

type instance SimpleBoolOf (AstNoSimplify s) = AstBool
type instance BoolOf (AstNoSimplify s) = AstBool

deriving instance IfF (AstNoSimplify s)
deriving instance AstSpan s => EqF (AstNoSimplify s)
Expand All @@ -790,7 +790,7 @@ deriving instance (RealFrac (AstRanked s r n))
deriving instance (RealFloat (AstRanked s r n))
=> RealFloat (AstNoSimplify s r n)

type instance SimpleBoolOf (AstNoSimplifyS s) = AstBool
type instance BoolOf (AstNoSimplifyS s) = AstBool

deriving instance IfF (AstNoSimplifyS s)
deriving instance AstSpan s => EqF (AstNoSimplifyS s)
Expand Down
14 changes: 7 additions & 7 deletions src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1148,13 +1148,13 @@ type ADReadySmall ranked shaped =
, PrimalOf ranked ~ RankedOf (PrimalOf shaped)
, ShapedOf (PrimalOf ranked) ~ PrimalOf shaped
, PrimalOf shaped ~ ShapedOf (PrimalOf ranked)
, SimpleBoolOf ranked ~ SimpleBoolOf shaped
, SimpleBoolOf shaped ~ SimpleBoolOf ranked
, SimpleBoolOf ranked ~ SimpleBoolOf (PrimalOf ranked)
, SimpleBoolOf (PrimalOf ranked) ~ SimpleBoolOf ranked
, SimpleBoolOf shaped ~ SimpleBoolOf (PrimalOf shaped)
, SimpleBoolOf (PrimalOf shaped) ~ SimpleBoolOf shaped
, Boolean (SimpleBoolOf ranked)
, BoolOf ranked ~ BoolOf shaped
, BoolOf shaped ~ BoolOf ranked
, BoolOf ranked ~ BoolOf (PrimalOf ranked)
, BoolOf (PrimalOf ranked) ~ BoolOf ranked
, BoolOf shaped ~ BoolOf (PrimalOf shaped)
, BoolOf (PrimalOf shaped) ~ BoolOf shaped
, Boolean (BoolOf ranked)
, IfF ranked, IfF shaped, IfF (PrimalOf ranked), IfF (PrimalOf shaped)
, EqF ranked, EqF shaped, EqF (PrimalOf ranked), EqF (PrimalOf shaped)
, OrdF ranked, OrdF shaped, OrdF (PrimalOf ranked), OrdF (PrimalOf shaped)
Expand Down
4 changes: 2 additions & 2 deletions src/HordeAd/Core/TensorConcrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import HordeAd.Core.Types
import HordeAd.Internal.BackendConcrete
import HordeAd.Util.ShapedList (shapedNat, unShapedNat)

type instance SimpleBoolOf (Flip OR.Array) = Bool
type instance BoolOf (Flip OR.Array) = Bool

instance EqF (Flip OR.Array) where
u ==. v = u == v
Expand Down Expand Up @@ -117,7 +117,7 @@ instance RankedTensor (Flip OR.Array) where
rD u _ = u
rScale _ _ = DummyDual

type instance SimpleBoolOf (Flip OS.Array) = Bool
type instance BoolOf (Flip OS.Array) = Bool

instance EqF (Flip OS.Array) where
u ==. v = u == v
Expand Down
4 changes: 2 additions & 2 deletions src/HordeAd/Core/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ module HordeAd.Core.Types
-- * Generic types of indexes used in tensor operations
, IntOf, IndexOf, IntSh, IndexSh
-- * Generic types of booleans used in tensor operations
, SimpleBoolOf, Boolean(..)
, BoolOf, Boolean(..)
-- * Definitions to help express and manipulate type-level natural numbers
, SNat, pattern SNat, withSNat, sNatValue, proxyFromSNat
) where
Expand Down Expand Up @@ -139,7 +139,7 @@ type IndexSh (f :: TensorType ty) (sh :: [Nat]) = IndexS sh (IntOf f)

-- * Generic types of booleans used in tensor operations

type family SimpleBoolOf (t :: ty) :: Type
type family BoolOf (t :: ty) :: Type


-- * Definitions to help express and manipulate type-level natural numbers
Expand Down

0 comments on commit 616fb1c

Please sign in to comment.