Skip to content

Commit

Permalink
Use a custom RealFloatF to avoid the OS.Shape constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Apr 26, 2024
1 parent 7914861 commit 3d556e2
Show file tree
Hide file tree
Showing 10 changed files with 122 additions and 131 deletions.
43 changes: 8 additions & 35 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ import Type.Reflection (Typeable, eqTypeRep, typeRep, (:~~:) (HRefl))

import HordeAd.Core.HVector
import HordeAd.Core.Types
import HordeAd.Util.ShapedList (SizedListS (..), IndexS)
import HordeAd.Internal.OrthotopeOrphanInstances
(IntegralF (..), RealFloatF (..))
import HordeAd.Util.ShapedList (IndexS, SizedListS (..))
import HordeAd.Util.SizedList
import HordeAd.Internal.OrthotopeOrphanInstances (IntegralF(..))

-- * Basic type family instances

Expand Down Expand Up @@ -142,10 +143,10 @@ varNameToAstVarId (AstVarName varId) = varId
-- The reverse derivative artifact from step 6) of our full pipeline.
-- The same type can also hold the forward derivative artifact.
data AstArtifact = AstArtifact
{ artVarsDt :: [AstDynamicVarName]
{ artVarsDt :: [AstDynamicVarName]
, artVarsPrimal :: [AstDynamicVarName]
, artDerivative :: HVectorOf (AstRaw PrimalSpan)
, artPrimal :: HVectorOf (AstRaw PrimalSpan)
, artPrimal :: HVectorOf (AstRaw PrimalSpan)
}

-- | This is the (arbitrarily) chosen representation of terms denoting
Expand Down Expand Up @@ -688,15 +689,6 @@ instance (Num (OS.Array sh r), AstSpan s, OS.Shape sh)
-- it's crucial that there is no AstConstant in fromInteger code
-- so that we don't need 4 times the simplification rules

instance (Real (OS.Array sh r), AstSpan s, OS.Shape sh)
=> Real (AstShaped s r sh) where
toRational = undefined
-- very low priority, since these are all extremely not continuous

instance Enum r => Enum (AstShaped s r n) where
toEnum = undefined
fromEnum = undefined -- do we need to define our own Enum class for this?

-- Warning: div and mod operations are very costly (simplifying them
-- requires constructing conditionals, etc). If this error is removed,
-- they are going to work, but slowly.
Expand Down Expand Up @@ -731,28 +723,9 @@ instance (Differentiable r, Floating (OS.Array sh r), AstSpan s, OS.Shape sh)
acosh = AstR1S AcoshOp
atanh = AstR1S AtanhOp

instance (Differentiable r, RealFrac (OS.Array sh r), AstSpan s, OS.Shape sh)
=> RealFrac (AstShaped s r sh) where
properFraction = undefined
-- The integral type doesn't have a Storable constraint,
-- so we can't implement this (nor RealFracB from Boolean package).

instance (Differentiable r, RealFloat (OS.Array sh r), AstSpan s, OS.Shape sh)
=> RealFloat (AstShaped s r sh) where
atan2 = AstR2S Atan2Op
-- We can be selective here and omit the other methods,
-- most of which don't even have a differentiable codomain.
floatRadix = undefined
floatDigits = undefined
floatRange = undefined
decodeFloat = undefined
encodeFloat = undefined
isNaN = undefined
isInfinite = undefined
isDenormalized = undefined
isNegativeZero = undefined
isIEEE = undefined

instance (Differentiable r, Floating (OS.Array sh r), AstSpan s, OS.Shape sh)
=> RealFloatF (AstShaped s r sh) where
atan2F = AstR2S Atan2Op

-- * Unlawful instances of AST for bool; they are lawful modulo evaluation

Expand Down
14 changes: 12 additions & 2 deletions src/HordeAd/Core/AstEnv.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ module HordeAd.Core.AstEnv
, interpretLambdaHsH
-- * Interpretation of arithmetic, boolean and relation operations
, interpretAstN1, interpretAstN2, interpretAstR1, interpretAstR2
, interpretAstR2F
, interpretAstI2, interpretAstI2F, interpretAstB2, interpretAstRelOp
) where

Expand All @@ -31,7 +32,8 @@ import HordeAd.Core.HVector
import HordeAd.Core.HVectorOps
import HordeAd.Core.TensorClass
import HordeAd.Core.Types
import HordeAd.Internal.OrthotopeOrphanInstances (IntegralF (..))
import HordeAd.Internal.OrthotopeOrphanInstances
(IntegralF (..), RealFloatF (..))
import HordeAd.Util.ShapedList (IndexSh, IntSh)
import qualified HordeAd.Util.ShapedList as ShapedList
import HordeAd.Util.SizedList
Expand Down Expand Up @@ -250,7 +252,7 @@ interpretAstN2 :: Num a
interpretAstN2 MinusOp u v = u - v
interpretAstN2 TimesOp u v = u * v

interpretAstR1 :: RealFloat a
interpretAstR1 :: Floating a
=> OpCode1 -> a -> a
{-# INLINE interpretAstR1 #-}
interpretAstR1 RecipOp u = recip u
Expand Down Expand Up @@ -278,6 +280,14 @@ interpretAstR2 PowerOp u v = u ** v
interpretAstR2 LogBaseOp u v = logBase u v
interpretAstR2 Atan2Op u v = atan2 u v

interpretAstR2F :: RealFloatF a
=> OpCode2 -> a -> a -> a
{-# INLINE interpretAstR2F #-}
interpretAstR2F DivideOp u v = u / v
interpretAstR2F PowerOp u v = u ** v
interpretAstR2F LogBaseOp u v = logBase u v
interpretAstR2F Atan2Op u v = atan2F u v

interpretAstI2 :: Integral a
=> OpCodeIntegral2 -> a -> a -> a
{-# INLINE interpretAstI2 #-}
Expand Down
2 changes: 1 addition & 1 deletion src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ interpretAstS !env = \case
AstR2S opCode u v ->
let u2 = interpretAstS env u
v2 = interpretAstS env v
in interpretAstR2 opCode u2 v2
in interpretAstR2F opCode u2 v2
AstI2S opCode u v ->
let u2 = interpretAstS env u
v2 = interpretAstS env v
Expand Down
11 changes: 10 additions & 1 deletion src/HordeAd/Core/DualNumber.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ import HordeAd.Core.HVector
import HordeAd.Core.IsPrimal
import HordeAd.Core.TensorClass
import HordeAd.Core.Types
import HordeAd.Internal.OrthotopeOrphanInstances (IntegralF (..))
import HordeAd.Internal.OrthotopeOrphanInstances
(IntegralF (..), RealFloatF (..))
import HordeAd.Util.ShapedList (IndexSh)
import qualified HordeAd.Util.ShapedList as ShapedList
import HordeAd.Util.SizedList
Expand Down Expand Up @@ -397,6 +398,14 @@ instance (RealFrac (f r z), IsPrimal f r z)
-- The integral type doesn't have a Storable constraint,
-- so we can't implement this (nor RealFracB from Boolean package).

instance (Fractional (f r z), RealFloatF (f r z), IsPrimal f r z)
=> RealFloatF (ADVal f r z) where
atan2F (D ue u') (D ve v') =
let !u = sharePrimal ue in
let !v = sharePrimal ve in
let !t = sharePrimal (recip (u * u + v * v))
in dD (atan2F u v) (dAdd (dScale ((- u) * t) v') (dScale (v * t) u'))

instance (RealFloat (f r z), IsPrimal f r z)
=> RealFloat (ADVal f r z) where
{- TODO: this causes a cyclic dependency:
Expand Down
30 changes: 8 additions & 22 deletions src/HordeAd/Core/TensorAst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ import HordeAd.Core.TensorADVal (unADValHVector)
import HordeAd.Core.TensorClass
import HordeAd.Core.TensorConcrete ()
import HordeAd.Core.Types
import HordeAd.Internal.OrthotopeOrphanInstances (IntegralF (..))
import HordeAd.Internal.OrthotopeOrphanInstances
(IntegralF (..), RealFloatF (..))
import HordeAd.Util.ShapedList (IntSh)
import qualified HordeAd.Util.ShapedList as ShapedList
import HordeAd.Util.SizedList
Expand Down Expand Up @@ -685,19 +686,14 @@ deriving instance AstSpan s => OrdF (AstRawS s)
deriving instance Eq ((AstRawS s) r sh)
deriving instance Ord ((AstRawS s) r sh)
deriving instance Num (AstShaped s r sh) => Num (AstRawS s r sh)
deriving instance (Real (AstShaped s r sh))
=> Real (AstRawS s r sh)
deriving instance Enum (AstShaped s r sh) => Enum (AstRawS s r sh)
deriving instance (IntegralF (AstShaped s r sh))
=> IntegralF (AstRawS s r sh)
deriving instance Fractional (AstShaped s r sh)
=> Fractional (AstRawS s r sh)
deriving instance Floating (AstShaped s r sh)
=> Floating (AstRawS s r sh)
deriving instance (RealFrac (AstShaped s r sh))
=> RealFrac (AstRawS s r sh)
deriving instance (RealFloat (AstShaped s r sh))
=> RealFloat (AstRawS s r sh)
deriving instance (RealFloatF (AstShaped s r sh))
=> RealFloatF (AstRawS s r sh)

type instance BoolOf (AstNoVectorize s) = AstBool

Expand Down Expand Up @@ -729,19 +725,14 @@ deriving instance AstSpan s => OrdF (AstNoVectorizeS s)
deriving instance Eq ((AstNoVectorizeS s) r sh)
deriving instance Ord ((AstNoVectorizeS s) r sh)
deriving instance Num (AstShaped s r sh) => Num (AstNoVectorizeS s r sh)
deriving instance (Real (AstShaped s r sh))
=> Real (AstNoVectorizeS s r sh)
deriving instance Enum (AstShaped s r sh) => Enum (AstNoVectorizeS s r sh)
deriving instance (IntegralF (AstShaped s r sh))
=> IntegralF (AstNoVectorizeS s r sh)
deriving instance Fractional (AstShaped s r sh)
=> Fractional (AstNoVectorizeS s r sh)
deriving instance Floating (AstShaped s r sh)
=> Floating (AstNoVectorizeS s r sh)
deriving instance (RealFrac (AstShaped s r sh))
=> RealFrac (AstNoVectorizeS s r sh)
deriving instance (RealFloat (AstShaped s r sh))
=> RealFloat (AstNoVectorizeS s r sh)
deriving instance (RealFloatF (AstShaped s r sh))
=> RealFloatF (AstNoVectorizeS s r sh)

type instance BoolOf (AstNoSimplify s) = AstBool

Expand Down Expand Up @@ -773,19 +764,14 @@ deriving instance AstSpan s => OrdF (AstNoSimplifyS s)
deriving instance Eq (AstNoSimplifyS s r sh)
deriving instance Ord (AstNoSimplifyS s r sh)
deriving instance Num (AstShaped s r sh) => Num (AstNoSimplifyS s r sh)
deriving instance (Real (AstShaped s r sh))
=> Real (AstNoSimplifyS s r sh)
deriving instance Enum (AstShaped s r sh) => Enum (AstNoSimplifyS s r sh)
deriving instance (IntegralF (AstShaped s r sh))
=> IntegralF (AstNoSimplifyS s r sh)
deriving instance Fractional (AstShaped s r sh)
=> Fractional (AstNoSimplifyS s r sh)
deriving instance Floating (AstShaped s r sh)
=> Floating (AstNoSimplifyS s r sh)
deriving instance (RealFrac (AstShaped s r sh))
=> RealFrac (AstNoSimplifyS s r sh)
deriving instance (RealFloat (AstShaped s r sh))
=> RealFloat (AstNoSimplifyS s r sh)
deriving instance (RealFloatF (AstShaped s r sh))
=> RealFloatF (AstNoSimplifyS s r sh)

instance AstSpan s => RankedTensor (AstRaw s) where
rlet a f = AstRaw $ astLetFunRaw (unAstRaw a) (unAstRaw . f . AstRaw)
Expand Down
7 changes: 5 additions & 2 deletions src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ import Unsafe.Coerce (unsafeCoerce)

import HordeAd.Core.HVector
import HordeAd.Core.Types
import HordeAd.Internal.OrthotopeOrphanInstances (IntegralF (..))
import HordeAd.Internal.OrthotopeOrphanInstances
(IntegralF (..), RealFloatF (..))
import HordeAd.Util.ShapedList
(IndexSh, IntSh, ShapeIntS, consIndex, pattern (:.$), pattern ZIS)
import qualified HordeAd.Util.ShapedList as ShapedList
Expand Down Expand Up @@ -336,7 +337,9 @@ class ( Integral (IntOf ranked), CRanked ranked Num
-- * Shaped tensor class definition

class ( Integral (IntOf shaped), CShaped shaped Num
, TensorSupports RealFloat shaped, TensorSupports2 Integral IntegralF shaped )
, TensorSupports2 RealFloat Floating shaped
, TensorSupports2 RealFloat RealFloatF shaped
, TensorSupports2 Integral IntegralF shaped )
=> ShapedTensor (shaped :: ShapedTensorType) where

slet :: (KnownShape sh, KnownShape sh2, GoodScalar r, GoodScalar r2)
Expand Down
31 changes: 9 additions & 22 deletions src/HordeAd/Internal/OrthotopeOrphanInstances.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ module HordeAd.Internal.OrthotopeOrphanInstances
, Dict(..)
, -- * Numeric instances for tensors
liftVR, liftVR2, liftVS, liftVS2
, IntegralF(..)
, IntegralF(..), RealFloatF(..)
, -- * Assorted orphans and additions
MapSucc, trustMeThisIsAPermutation
) where
Expand Down Expand Up @@ -416,21 +416,13 @@ instance (Real (Vector r), KnownNat n, Numeric r, Show r, Ord r)
=> Real (OR.Array n r) where
toRational = undefined -- TODO

instance (Real (Vector r), KnownShape2 sh, OS.Shape sh, Numeric r, Ord r)
=> Real (OS.Array sh r) where
toRational = undefined -- TODO

instance ( RealFrac (Vector r), KnownNat n, Numeric r, Show r, Fractional r
, Ord r )
=> RealFrac (OR.Array n r) where
properFraction = error "OR.properFraction: can't be implemented"
-- The integral type doesn't have a Storable constraint,
-- so we can't implement this (nor even RealFracB from Boolean package).

instance (RealFrac (Vector r), KnownShape2 sh, OS.Shape sh, Numeric r, Fractional r, Ord r)
=> RealFrac (OS.Array sh r) where
properFraction = error "OS.properFraction: can't be implemented"

instance ( RealFloat (Vector r), KnownNat n, Numeric r, Show r, Floating r
, Ord r )
=> RealFloat (OR.Array n r) where
Expand All @@ -446,19 +438,12 @@ instance ( RealFloat (Vector r), KnownNat n, Numeric r, Show r, Floating r
isNegativeZero = undefined
isIEEE = undefined

instance (RealFloat (Vector r), KnownShape2 sh, OS.Shape sh, Numeric r, Floating r, Ord r)
=> RealFloat (OS.Array sh r) where
atan2 = liftVS2NoAdapt atan2
floatRadix = undefined -- TODO (and below)
floatDigits = undefined
floatRange = undefined
decodeFloat = undefined
encodeFloat = undefined
isNaN = undefined
isInfinite = undefined
isDenormalized = undefined
isNegativeZero = undefined
isIEEE = undefined
class Floating a => RealFloatF a where
atan2F :: a -> a -> a

instance (Floating r, RealFloat (Vector r), KnownShape2 sh, Numeric r)
=> RealFloatF (OS.Array sh r) where
atan2F = liftVS2NoAdapt atan2

deriving instance Num (f a b) => Num (Flip f b a)

Expand All @@ -476,6 +461,8 @@ deriving instance Real (f a b) => Real (Flip f b a)

deriving instance RealFrac (f a b) => RealFrac (Flip f b a)

deriving instance RealFloatF (f a b) => RealFloatF (Flip f b a)

deriving instance RealFloat (f a b) => RealFloat (Flip f b a)

deriving instance NFData (f a b) => NFData (Flip f b a)
Expand Down
Loading

0 comments on commit 3d556e2

Please sign in to comment.