diff --git a/src/HordeAd/Core/Ast.hs b/src/HordeAd/Core/Ast.hs index db1b88365..0177a8ebd 100644 --- a/src/HordeAd/Core/Ast.hs +++ b/src/HordeAd/Core/Ast.hs @@ -40,9 +40,10 @@ import Type.Reflection (Typeable, eqTypeRep, typeRep, (:~~:) (HRefl)) import HordeAd.Core.HVector import HordeAd.Core.Types -import HordeAd.Util.ShapedList (SizedListS (..), IndexS) +import HordeAd.Internal.OrthotopeOrphanInstances + (IntegralF (..), RealFloatF (..)) +import HordeAd.Util.ShapedList (IndexS, SizedListS (..)) import HordeAd.Util.SizedList -import HordeAd.Internal.OrthotopeOrphanInstances (IntegralF(..)) -- * Basic type family instances @@ -142,10 +143,10 @@ varNameToAstVarId (AstVarName varId) = varId -- The reverse derivative artifact from step 6) of our full pipeline. -- The same type can also hold the forward derivative artifact. data AstArtifact = AstArtifact - { artVarsDt :: [AstDynamicVarName] + { artVarsDt :: [AstDynamicVarName] , artVarsPrimal :: [AstDynamicVarName] , artDerivative :: HVectorOf (AstRaw PrimalSpan) - , artPrimal :: HVectorOf (AstRaw PrimalSpan) + , artPrimal :: HVectorOf (AstRaw PrimalSpan) } -- | This is the (arbitrarily) chosen representation of terms denoting @@ -688,15 +689,6 @@ instance (Num (OS.Array sh r), AstSpan s, OS.Shape sh) -- it's crucial that there is no AstConstant in fromInteger code -- so that we don't need 4 times the simplification rules -instance (Real (OS.Array sh r), AstSpan s, OS.Shape sh) - => Real (AstShaped s r sh) where - toRational = undefined - -- very low priority, since these are all extremely not continuous - -instance Enum r => Enum (AstShaped s r n) where - toEnum = undefined - fromEnum = undefined -- do we need to define our own Enum class for this? - -- Warning: div and mod operations are very costly (simplifying them -- requires constructing conditionals, etc). If this error is removed, -- they are going to work, but slowly. @@ -731,28 +723,9 @@ instance (Differentiable r, Floating (OS.Array sh r), AstSpan s, OS.Shape sh) acosh = AstR1S AcoshOp atanh = AstR1S AtanhOp -instance (Differentiable r, RealFrac (OS.Array sh r), AstSpan s, OS.Shape sh) - => RealFrac (AstShaped s r sh) where - properFraction = undefined - -- The integral type doesn't have a Storable constraint, - -- so we can't implement this (nor RealFracB from Boolean package). - -instance (Differentiable r, RealFloat (OS.Array sh r), AstSpan s, OS.Shape sh) - => RealFloat (AstShaped s r sh) where - atan2 = AstR2S Atan2Op - -- We can be selective here and omit the other methods, - -- most of which don't even have a differentiable codomain. - floatRadix = undefined - floatDigits = undefined - floatRange = undefined - decodeFloat = undefined - encodeFloat = undefined - isNaN = undefined - isInfinite = undefined - isDenormalized = undefined - isNegativeZero = undefined - isIEEE = undefined - +instance (Differentiable r, Floating (OS.Array sh r), AstSpan s, OS.Shape sh) + => RealFloatF (AstShaped s r sh) where + atan2F = AstR2S Atan2Op -- * Unlawful instances of AST for bool; they are lawful modulo evaluation diff --git a/src/HordeAd/Core/AstEnv.hs b/src/HordeAd/Core/AstEnv.hs index b19fdc7ff..e027659ef 100644 --- a/src/HordeAd/Core/AstEnv.hs +++ b/src/HordeAd/Core/AstEnv.hs @@ -13,6 +13,7 @@ module HordeAd.Core.AstEnv , interpretLambdaHsH -- * Interpretation of arithmetic, boolean and relation operations , interpretAstN1, interpretAstN2, interpretAstR1, interpretAstR2 + , interpretAstR2F , interpretAstI2, interpretAstI2F, interpretAstB2, interpretAstRelOp ) where @@ -31,7 +32,8 @@ import HordeAd.Core.HVector import HordeAd.Core.HVectorOps import HordeAd.Core.TensorClass import HordeAd.Core.Types -import HordeAd.Internal.OrthotopeOrphanInstances (IntegralF (..)) +import HordeAd.Internal.OrthotopeOrphanInstances + (IntegralF (..), RealFloatF (..)) import HordeAd.Util.ShapedList (IndexSh, IntSh) import qualified HordeAd.Util.ShapedList as ShapedList import HordeAd.Util.SizedList @@ -250,7 +252,7 @@ interpretAstN2 :: Num a interpretAstN2 MinusOp u v = u - v interpretAstN2 TimesOp u v = u * v -interpretAstR1 :: RealFloat a +interpretAstR1 :: Floating a => OpCode1 -> a -> a {-# INLINE interpretAstR1 #-} interpretAstR1 RecipOp u = recip u @@ -278,6 +280,14 @@ interpretAstR2 PowerOp u v = u ** v interpretAstR2 LogBaseOp u v = logBase u v interpretAstR2 Atan2Op u v = atan2 u v +interpretAstR2F :: RealFloatF a + => OpCode2 -> a -> a -> a +{-# INLINE interpretAstR2F #-} +interpretAstR2F DivideOp u v = u / v +interpretAstR2F PowerOp u v = u ** v +interpretAstR2F LogBaseOp u v = logBase u v +interpretAstR2F Atan2Op u v = atan2F u v + interpretAstI2 :: Integral a => OpCodeIntegral2 -> a -> a -> a {-# INLINE interpretAstI2 #-} diff --git a/src/HordeAd/Core/AstInterpret.hs b/src/HordeAd/Core/AstInterpret.hs index f07d12700..6098638cf 100644 --- a/src/HordeAd/Core/AstInterpret.hs +++ b/src/HordeAd/Core/AstInterpret.hs @@ -583,7 +583,7 @@ interpretAstS !env = \case AstR2S opCode u v -> let u2 = interpretAstS env u v2 = interpretAstS env v - in interpretAstR2 opCode u2 v2 + in interpretAstR2F opCode u2 v2 AstI2S opCode u v -> let u2 = interpretAstS env u v2 = interpretAstS env v diff --git a/src/HordeAd/Core/DualNumber.hs b/src/HordeAd/Core/DualNumber.hs index cb6303fd8..cd1b6b8e7 100644 --- a/src/HordeAd/Core/DualNumber.hs +++ b/src/HordeAd/Core/DualNumber.hs @@ -31,7 +31,8 @@ import HordeAd.Core.HVector import HordeAd.Core.IsPrimal import HordeAd.Core.TensorClass import HordeAd.Core.Types -import HordeAd.Internal.OrthotopeOrphanInstances (IntegralF (..)) +import HordeAd.Internal.OrthotopeOrphanInstances + (IntegralF (..), RealFloatF (..)) import HordeAd.Util.ShapedList (IndexSh) import qualified HordeAd.Util.ShapedList as ShapedList import HordeAd.Util.SizedList @@ -397,6 +398,14 @@ instance (RealFrac (f r z), IsPrimal f r z) -- The integral type doesn't have a Storable constraint, -- so we can't implement this (nor RealFracB from Boolean package). +instance (Fractional (f r z), RealFloatF (f r z), IsPrimal f r z) + => RealFloatF (ADVal f r z) where + atan2F (D ue u') (D ve v') = + let !u = sharePrimal ue in + let !v = sharePrimal ve in + let !t = sharePrimal (recip (u * u + v * v)) + in dD (atan2F u v) (dAdd (dScale ((- u) * t) v') (dScale (v * t) u')) + instance (RealFloat (f r z), IsPrimal f r z) => RealFloat (ADVal f r z) where {- TODO: this causes a cyclic dependency: diff --git a/src/HordeAd/Core/TensorAst.hs b/src/HordeAd/Core/TensorAst.hs index 6b2c6b2d0..41e3bf647 100644 --- a/src/HordeAd/Core/TensorAst.hs +++ b/src/HordeAd/Core/TensorAst.hs @@ -44,7 +44,8 @@ import HordeAd.Core.TensorADVal (unADValHVector) import HordeAd.Core.TensorClass import HordeAd.Core.TensorConcrete () import HordeAd.Core.Types -import HordeAd.Internal.OrthotopeOrphanInstances (IntegralF (..)) +import HordeAd.Internal.OrthotopeOrphanInstances + (IntegralF (..), RealFloatF (..)) import HordeAd.Util.ShapedList (IntSh) import qualified HordeAd.Util.ShapedList as ShapedList import HordeAd.Util.SizedList @@ -685,19 +686,14 @@ deriving instance AstSpan s => OrdF (AstRawS s) deriving instance Eq ((AstRawS s) r sh) deriving instance Ord ((AstRawS s) r sh) deriving instance Num (AstShaped s r sh) => Num (AstRawS s r sh) -deriving instance (Real (AstShaped s r sh)) - => Real (AstRawS s r sh) -deriving instance Enum (AstShaped s r sh) => Enum (AstRawS s r sh) deriving instance (IntegralF (AstShaped s r sh)) => IntegralF (AstRawS s r sh) deriving instance Fractional (AstShaped s r sh) => Fractional (AstRawS s r sh) deriving instance Floating (AstShaped s r sh) => Floating (AstRawS s r sh) -deriving instance (RealFrac (AstShaped s r sh)) - => RealFrac (AstRawS s r sh) -deriving instance (RealFloat (AstShaped s r sh)) - => RealFloat (AstRawS s r sh) +deriving instance (RealFloatF (AstShaped s r sh)) + => RealFloatF (AstRawS s r sh) type instance BoolOf (AstNoVectorize s) = AstBool @@ -729,19 +725,14 @@ deriving instance AstSpan s => OrdF (AstNoVectorizeS s) deriving instance Eq ((AstNoVectorizeS s) r sh) deriving instance Ord ((AstNoVectorizeS s) r sh) deriving instance Num (AstShaped s r sh) => Num (AstNoVectorizeS s r sh) -deriving instance (Real (AstShaped s r sh)) - => Real (AstNoVectorizeS s r sh) -deriving instance Enum (AstShaped s r sh) => Enum (AstNoVectorizeS s r sh) deriving instance (IntegralF (AstShaped s r sh)) => IntegralF (AstNoVectorizeS s r sh) deriving instance Fractional (AstShaped s r sh) => Fractional (AstNoVectorizeS s r sh) deriving instance Floating (AstShaped s r sh) => Floating (AstNoVectorizeS s r sh) -deriving instance (RealFrac (AstShaped s r sh)) - => RealFrac (AstNoVectorizeS s r sh) -deriving instance (RealFloat (AstShaped s r sh)) - => RealFloat (AstNoVectorizeS s r sh) +deriving instance (RealFloatF (AstShaped s r sh)) + => RealFloatF (AstNoVectorizeS s r sh) type instance BoolOf (AstNoSimplify s) = AstBool @@ -773,19 +764,14 @@ deriving instance AstSpan s => OrdF (AstNoSimplifyS s) deriving instance Eq (AstNoSimplifyS s r sh) deriving instance Ord (AstNoSimplifyS s r sh) deriving instance Num (AstShaped s r sh) => Num (AstNoSimplifyS s r sh) -deriving instance (Real (AstShaped s r sh)) - => Real (AstNoSimplifyS s r sh) -deriving instance Enum (AstShaped s r sh) => Enum (AstNoSimplifyS s r sh) deriving instance (IntegralF (AstShaped s r sh)) => IntegralF (AstNoSimplifyS s r sh) deriving instance Fractional (AstShaped s r sh) => Fractional (AstNoSimplifyS s r sh) deriving instance Floating (AstShaped s r sh) => Floating (AstNoSimplifyS s r sh) -deriving instance (RealFrac (AstShaped s r sh)) - => RealFrac (AstNoSimplifyS s r sh) -deriving instance (RealFloat (AstShaped s r sh)) - => RealFloat (AstNoSimplifyS s r sh) +deriving instance (RealFloatF (AstShaped s r sh)) + => RealFloatF (AstNoSimplifyS s r sh) instance AstSpan s => RankedTensor (AstRaw s) where rlet a f = AstRaw $ astLetFunRaw (unAstRaw a) (unAstRaw . f . AstRaw) diff --git a/src/HordeAd/Core/TensorClass.hs b/src/HordeAd/Core/TensorClass.hs index 205c7f13d..87233fa8b 100644 --- a/src/HordeAd/Core/TensorClass.hs +++ b/src/HordeAd/Core/TensorClass.hs @@ -37,7 +37,8 @@ import Unsafe.Coerce (unsafeCoerce) import HordeAd.Core.HVector import HordeAd.Core.Types -import HordeAd.Internal.OrthotopeOrphanInstances (IntegralF (..)) +import HordeAd.Internal.OrthotopeOrphanInstances + (IntegralF (..), RealFloatF (..)) import HordeAd.Util.ShapedList (IndexSh, IntSh, ShapeIntS, consIndex, pattern (:.$), pattern ZIS) import qualified HordeAd.Util.ShapedList as ShapedList @@ -336,7 +337,9 @@ class ( Integral (IntOf ranked), CRanked ranked Num -- * Shaped tensor class definition class ( Integral (IntOf shaped), CShaped shaped Num - , TensorSupports RealFloat shaped, TensorSupports2 Integral IntegralF shaped ) + , TensorSupports2 RealFloat Floating shaped + , TensorSupports2 RealFloat RealFloatF shaped + , TensorSupports2 Integral IntegralF shaped ) => ShapedTensor (shaped :: ShapedTensorType) where slet :: (KnownShape sh, KnownShape sh2, GoodScalar r, GoodScalar r2) diff --git a/src/HordeAd/Internal/OrthotopeOrphanInstances.hs b/src/HordeAd/Internal/OrthotopeOrphanInstances.hs index 3d3e372dd..44dffed31 100644 --- a/src/HordeAd/Internal/OrthotopeOrphanInstances.hs +++ b/src/HordeAd/Internal/OrthotopeOrphanInstances.hs @@ -10,7 +10,7 @@ module HordeAd.Internal.OrthotopeOrphanInstances , Dict(..) , -- * Numeric instances for tensors liftVR, liftVR2, liftVS, liftVS2 - , IntegralF(..) + , IntegralF(..), RealFloatF(..) , -- * Assorted orphans and additions MapSucc, trustMeThisIsAPermutation ) where @@ -416,10 +416,6 @@ instance (Real (Vector r), KnownNat n, Numeric r, Show r, Ord r) => Real (OR.Array n r) where toRational = undefined -- TODO -instance (Real (Vector r), KnownShape2 sh, OS.Shape sh, Numeric r, Ord r) - => Real (OS.Array sh r) where - toRational = undefined -- TODO - instance ( RealFrac (Vector r), KnownNat n, Numeric r, Show r, Fractional r , Ord r ) => RealFrac (OR.Array n r) where @@ -427,10 +423,6 @@ instance ( RealFrac (Vector r), KnownNat n, Numeric r, Show r, Fractional r -- The integral type doesn't have a Storable constraint, -- so we can't implement this (nor even RealFracB from Boolean package). -instance (RealFrac (Vector r), KnownShape2 sh, OS.Shape sh, Numeric r, Fractional r, Ord r) - => RealFrac (OS.Array sh r) where - properFraction = error "OS.properFraction: can't be implemented" - instance ( RealFloat (Vector r), KnownNat n, Numeric r, Show r, Floating r , Ord r ) => RealFloat (OR.Array n r) where @@ -446,19 +438,12 @@ instance ( RealFloat (Vector r), KnownNat n, Numeric r, Show r, Floating r isNegativeZero = undefined isIEEE = undefined -instance (RealFloat (Vector r), KnownShape2 sh, OS.Shape sh, Numeric r, Floating r, Ord r) - => RealFloat (OS.Array sh r) where - atan2 = liftVS2NoAdapt atan2 - floatRadix = undefined -- TODO (and below) - floatDigits = undefined - floatRange = undefined - decodeFloat = undefined - encodeFloat = undefined - isNaN = undefined - isInfinite = undefined - isDenormalized = undefined - isNegativeZero = undefined - isIEEE = undefined +class Floating a => RealFloatF a where + atan2F :: a -> a -> a + +instance (Floating r, RealFloat (Vector r), KnownShape2 sh, Numeric r) + => RealFloatF (OS.Array sh r) where + atan2F = liftVS2NoAdapt atan2 deriving instance Num (f a b) => Num (Flip f b a) @@ -476,6 +461,8 @@ deriving instance Real (f a b) => Real (Flip f b a) deriving instance RealFrac (f a b) => RealFrac (Flip f b a) +deriving instance RealFloatF (f a b) => RealFloatF (Flip f b a) + deriving instance RealFloat (f a b) => RealFloat (Flip f b a) deriving instance NFData (f a b) => NFData (Flip f b a) diff --git a/test/simplified/TestAdaptorSimplified.hs b/test/simplified/TestAdaptorSimplified.hs index f5aa7a0e8..c0428749f 100644 --- a/test/simplified/TestAdaptorSimplified.hs +++ b/test/simplified/TestAdaptorSimplified.hs @@ -25,6 +25,7 @@ import HordeAd import HordeAd.Core.AstEnv import HordeAd.Core.AstFreshId (funToAstR, funToAstS, resetVarCounter) import HordeAd.Core.IsPrimal (resetIdCounter) +import HordeAd.Internal.OrthotopeOrphanInstances (RealFloatF (..)) import CrossTesting import EqEpsilon @@ -235,19 +236,19 @@ testZero3S :: Assertion testZero3S = assertEqualUpToEpsilon 1e-9 (Flip $ OS.fromList @'[33, 2] (replicate 66 3.6174114266850617)) - (crev (\x -> bar @(ADVal (Flip OS.Array) Double '[33, 2]) (x, x)) 1) + (crev (\x -> barF @(ADVal (Flip OS.Array) Double '[33, 2]) (x, x)) 1) testCFwdZero3S :: Assertion testCFwdZero3S = assertEqualUpToEpsilon 1e-9 (Flip $ OS.fromList @'[33, 2] (replicate 66 3.9791525693535674)) - (cfwd (\x -> bar @(ADVal (Flip OS.Array) Double '[33, 2]) (x, x)) 1 1.1) + (cfwd (\x -> barF @(ADVal (Flip OS.Array) Double '[33, 2]) (x, x)) 1 1.1) testFwdZero3S :: Assertion testFwdZero3S = assertEqualUpToEpsilon 1e-9 (Flip $ OS.fromList @'[33, 2] (replicate 66 3.9791525693535674)) - (fwd (\x -> bar @(AstShaped FullSpan Double '[33, 2]) (x, x)) 1 1.1) + (fwd (\x -> barF @(AstShaped FullSpan Double '[33, 2]) (x, x)) 1 1.1) testZero4S :: Assertion testZero4S = @@ -271,7 +272,7 @@ testZero6S = assertEqualUpToEpsilon 1e-9 (Flip $ OS.fromList @'[2, 2, 2, 2, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 2, 2, 2, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,11,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,111,1,1,1,1, 2, 2, 2, 2] (replicate (product ([2, 2, 2, 2, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 2, 2, 2, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,11,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,111,1,1,1,1, 2, 2, 2, 2] :: [Int])) 3.6174114266850617)) (rev @Double @'[2, 2, 2, 2, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 2, 2, 2, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,11,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,111,1,1,1,1, 2, 2, 2, 2] - @(AstShaped FullSpan) (\x -> bar (x, x)) 1) + @(AstShaped FullSpan) (\x -> barF (x, x)) 1) testZero7S :: Assertion testZero7S = @@ -500,6 +501,11 @@ foo (x, y, z) = let w = x * sin y in atan2 z w + z * w +fooF :: RealFloatF a => (a, a, a) -> a +fooF (x, y, z) = + let w = x * sin y + in atan2F z w + z * w + testFoo :: Assertion testFoo = do assertEqualUpToEpsilon 1e-10 @@ -510,14 +516,14 @@ testFooS :: Assertion testFooS = do assertEqualUpToEpsilon 1e-10 (2.4396285219055063, -1.953374825727421, 0.9654825811012627) - (rev @Double @'[3, 534, 3] @(AstShaped FullSpan) foo (1.1, 2.2, 3.3)) + (rev @Double @'[3, 534, 3] @(AstShaped FullSpan) fooF (1.1, 2.2, 3.3)) testFooSToFloat :: Assertion testFooSToFloat = do assertEqualUpToEpsilon 1e-10 (2.4396285219055063, -1.953374825727421, 0.9654825811012627) (rev @Float @'[3, 534, 3] @(AstShaped FullSpan) - (scast . foo) + (scast . fooF) (1.1 :: Flip OS.Array Double '[3, 534, 3], 2.2, 3.3)) testFooSBoth :: Assertion @@ -525,7 +531,7 @@ testFooSBoth = do assertEqualUpToEpsilon 1e-10 (2.439628436155373, -1.9533749, 0.9654825479484146) (rev @Float @'[3, 534, 3] @(AstShaped FullSpan) - (scast . foo . (\(d, f, d2) -> (d, scast f, d2))) + (scast . fooF . (\(d, f, d2) -> (d, scast f, d2))) ( 1.1 :: Flip OS.Array Double '[3, 534, 3] , 2.2 :: Flip OS.Array Float '[3, 534, 3] , 3.3 )) @@ -1201,6 +1207,11 @@ bar (x, y) = let w = foo (x, y, x) * sin y in atan2 x w + y * w +barF :: forall a. RealFloatF a => (a, a) -> a +barF (x, y) = + let w = fooF (x, y, x) * sin y + in atan2F x w + y * w + testBar :: Assertion testBar = assertEqualUpToEpsilon 1e-9 @@ -1211,13 +1222,13 @@ testBarS :: Assertion testBarS = assertEqualUpToEpsilon 1e-9 (3.1435239435581166,-1.1053869545195814) - (crev (bar @(ADVal (Flip OS.Array) Double '[])) (1.1, 2.2)) + (crev (barF @(ADVal (Flip OS.Array) Double '[])) (1.1, 2.2)) testBar2S :: Assertion testBar2S = assertEqualUpToEpsilon 1e-9 (3.1435239435581166,-1.1053869545195814) - (rev (bar @(AstShaped FullSpan Double '[52, 2, 2, 1, 1, 3])) (1.1, 2.2)) + (rev (barF @(AstShaped FullSpan Double '[52, 2, 2, 1, 1, 3])) (1.1, 2.2)) testBarCFwd :: Assertion testBarCFwd = @@ -1526,9 +1537,9 @@ reluMaxS = smap0N (maxF 0) barReluMaxS :: ( ADReadyS shaped, GoodScalar r, KnownShape sh, KnownNat (Sh.Rank sh) - , RealFloat (shaped r sh) ) + , RealFloatF (shaped r sh) ) => shaped r sh -> shaped r sh -barReluMaxS x = reluMaxS $ bar (x, reluMaxS x) +barReluMaxS x = reluMaxS $ barF (x, reluMaxS x) -- Previously the shape of FromListR[ZeroR] couldn't be determined -- in buildDerivative, so this was needed. See below that it now works fine. diff --git a/test/simplified/TestHighRankSimplified.hs b/test/simplified/TestHighRankSimplified.hs index 72b0a2450..cf2b0d588 100644 --- a/test/simplified/TestHighRankSimplified.hs +++ b/test/simplified/TestHighRankSimplified.hs @@ -18,6 +18,7 @@ import Test.Tasty.HUnit hiding (assert) import HordeAd import HordeAd.Core.AstFreshId (funToAstR, resetVarCounter) +import HordeAd.Internal.OrthotopeOrphanInstances (RealFloatF (..)) import CrossTesting import EqEpsilon @@ -77,6 +78,11 @@ foo (x,y,z) = let w = x * sin y in atan2 z w + z * w +fooF :: RealFloatF a => (a,a,a) -> a +fooF (x,y,z) = + let w = x * sin y + in atan2F z w + z * w + testFoo :: Assertion testFoo = assertEqualUpToEpsilon 1e-3 @@ -88,6 +94,11 @@ bar (x, y) = let w = foo (x, y, x) * sin y in atan2 x w + y * w +barF :: forall a. RealFloatF a => (a, a) -> a +barF (x, y) = + let w = fooF (x, y, x) * sin y + in atan2F x w + y * w + testBar :: Assertion testBar = assertEqualUpToEpsilon 1e-5 @@ -98,7 +109,7 @@ testBarS :: Assertion testBarS = assertEqualUpToEpsilon 1e-5 (Flip $ OS.fromList @'[3, 1, 2, 2, 1, 2, 2] [304.13867,914.9335,823.0187,1464.4688,5264.3306,1790.0055,1535.4309,3541.6572,304.13867,914.9335,823.0187,1464.4688,6632.4355,6047.113,1535.4309,1346.6815,45.92141,6.4903135,5.5406737,1.4242969,6.4903135,1.1458766,4.6446533,2.3550234,88.783676,27.467598,125.27507,18.177452,647.1917,0.3878851,2177.6152,786.1792,6.4903135,6.4903135,6.4903135,6.4903135,2.3550234,2.3550234,2.3550234,2.3550234,21.783596,2.3550234,2.3550234,2.3550234,21.783596,21.783596,21.783596,21.783596], Flip $ OS.fromList @'[3, 1, 2, 2, 1, 2, 2] [-5728.7617,24965.113,32825.07,-63505.953,-42592.203,145994.88,-500082.5,-202480.06,-5728.7617,24965.113,32825.07,-63505.953,49494.473,-2446.7632,-500082.5,-125885.58,-43.092484,-1.9601002,-98.97709,2.1931143,-1.9601002,1.8243169,-4.0434446,-1.5266153,2020.9731,-538.0603,-84.28137,62.963814,-34987.0,-9.917454,135.30023,17741.998,-1.9601002,-1.9601002,-1.9601002,-1.9601002,-1.5266153,-1.5266153,-1.5266153,-1.5266153,-4029.1775,-1.5266153,-1.5266153,-1.5266153,-4029.1775,-4029.1775,-4029.1775,-4029.1775]) - (crev (bar @(ADVal (Flip OS.Array) Float '[3, 1, 2, 2, 1, 2, 2])) (sfromR t48, sfromR t48)) + (crev (barF @(ADVal (Flip OS.Array) Float '[3, 1, 2, 2, 1, 2, 2])) (sfromR t48, sfromR t48)) -- A dual-number and list-based version of a function that goes -- from `R^3` to `R`. diff --git a/test/simplified/TestRevFwdFold.hs b/test/simplified/TestRevFwdFold.hs index 081e801ba..bae4b7b50 100644 --- a/test/simplified/TestRevFwdFold.hs +++ b/test/simplified/TestRevFwdFold.hs @@ -20,6 +20,7 @@ import Test.Tasty.HUnit hiding (assert) import HordeAd import HordeAd.Core.AstFreshId (resetVarCounter) +import HordeAd.Internal.OrthotopeOrphanInstances (RealFloatF (..)) import HordeAd.Util.ShapedList (pattern (:.$), pattern ZIS) import CrossTesting @@ -635,7 +636,7 @@ testSin0Fold4S = do assertEqualUpToEpsilon' 1e-10 (-0.7053476446727861 :: OR.Array 0 Double) (rev' (let f :: forall f. ADReadyS f => f Double '[] -> f Double '[] - f a0 = sfold (\x a -> atan2 (sin x) (sin a)) + f a0 = sfold (\x a -> atan2F (sin x) (sin a)) (2 * a0) (sreplicate @f @3 a0) in rfromS . f . sfromR) 1.1) @@ -648,9 +649,9 @@ testSin0Fold5S = do => f2 Double '[] -> f2 Double '[2, 5] -> f2 Double '[] g x a = ssum - $ atan2 (sin $ sreplicate @f2 @5 x) - (ssum $ sin $ ssum - $ str $ sreplicate @f2 @7 a) + $ atan2F (sin $ sreplicate @f2 @5 x) + (ssum $ sin $ ssum + $ str $ sreplicate @f2 @7 a) in g) (2 * a0) (sreplicate @f @3 @@ -689,9 +690,9 @@ testSin0Fold8S = do (rev' (let f :: forall f. ADReadyS f => f Double '[] -> f Double '[2, 5] f a0 = sfold @_ @f @Double @Double @'[2, 5] @'[] @3 (\x a -> str $ sreplicate @_ @5 - $ atan2 (ssum (str $ sin x)) - (sreplicate @_ @2 - $ sin (ssum $ sreplicate @_ @7 a))) + $ atan2F (ssum (str $ sin x)) + (sreplicate @_ @2 + $ sin (ssum $ sreplicate @_ @7 a))) (sreplicate @_ @2 (sreplicate @_ @5 (2 * a0))) (sreplicate @_ @3 a0) in rfromS . f . sfromR) 1.1) @@ -728,9 +729,9 @@ testSin0Fold8Srev = do (rrev1 (let f :: forall f. ADReadyS f => f Double '[] -> f Double '[2, 5] f a0 = sfold @_ @f @Double @Double @'[2, 5] @'[] @3 (\x a -> str $ sreplicate @_ @5 - $ atan2 (ssum (str $ sin x)) - (sreplicate @_ @2 - $ sin (ssum $ sreplicate @_ @7 a))) + $ atan2F (ssum (str $ sin x)) + (sreplicate @_ @2 + $ sin (ssum $ sreplicate @_ @7 a))) (sreplicate @_ @2 (sreplicate @_ @5 (2 * a0))) (sreplicate @_ @3 a0) in rfromS . f . sfromR) 1.1) @@ -742,9 +743,9 @@ testSin0Fold8Srev2 = do => f Double '[] -> f Double '[2, 5] f a0 = sfold @_ @f @Double @Double @'[2, 5] @'[] @3 (\x a -> str $ sreplicate @_ @5 - $ atan2 (ssum (str $ sin x)) - (sreplicate @_ @2 - $ sin (ssum $ sreplicate @_ @7 a))) + $ atan2F (ssum (str $ sin x)) + (sreplicate @_ @2 + $ sin (ssum $ sreplicate @_ @7 a))) (sreplicate @_ @2 (sreplicate @_ @5 (2 * a0))) (sreplicate @_ @3 a0) in f) @@ -758,9 +759,9 @@ testSin0Fold182SrevPP = do let a1 = rrev1 @(AstRanked FullSpan) (let f :: forall f. ADReadyS f => f Double '[] -> f Double '[5] f a0 = sfold @_ @f @Double @Double @'[5] @'[] @1 - (\_x a -> atan2 (sreplicate @_ @5 a) - (sreplicate @_ @5 - $ sin (ssum $ sreplicate @_ @7 a))) + (\_x a -> atan2F (sreplicate @_ @5 a) + (sreplicate @_ @5 + $ sin (ssum $ sreplicate @_ @7 a))) (sreplicate @_ @5 a0) (sreplicate @_ @1 a0) in rfromS . f . sfromR) 1.1 @@ -774,9 +775,9 @@ testSin0Fold18Srev = do (rrev1 (let f :: forall f. ADReadyS f => f Double '[] -> f Double '[2, 5] f a0 = sfold @_ @f @Double @Double @'[2, 5] @'[] @2 (\x a -> str $ sreplicate @_ @5 - $ atan2 (ssum (str $ sin x)) - (sreplicate @_ @2 - $ sin (ssum $ sreplicate @_ @7 a))) + $ atan2F (ssum (str $ sin x)) + (sreplicate @_ @2 + $ sin (ssum $ sreplicate @_ @7 a))) (sreplicate @_ @2 (sreplicate @_ @5 (2 * a0))) (sreplicate @_ @2 a0) in rfromS . f . sfromR) 1.1) @@ -813,9 +814,9 @@ testSin0Fold8Sfwd = do (rfwd1 (let f :: forall f. ADReadyS f => f Double '[] -> f Double '[2, 5] f a0 = sfold @_ @f @Double @Double @'[2, 5] @'[] @3 (\x a -> str $ sreplicate @_ @5 - $ atan2 (ssum (str $ sin x)) - (sreplicate @_ @2 - $ sin (ssum $ sreplicate @_ @7 a))) + $ atan2F (ssum (str $ sin x)) + (sreplicate @_ @2 + $ sin (ssum $ sreplicate @_ @7 a))) (sreplicate @_ @2 (sreplicate @_ @5 (2 * a0))) (sreplicate @_ @3 a0) in rfromS . f . sfromR) 1.1) @@ -827,9 +828,9 @@ testSin0Fold8Sfwd2 = do => f Double '[] -> f Double '[2, 5] f a0 = sfold @_ @f @Double @Double @'[2, 5] @'[] @3 (\x a -> str $ sreplicate @_ @5 - $ atan2 (ssum (str $ sin x)) - (sreplicate @_ @2 - $ sin (ssum $ sreplicate @_ @7 a))) + $ atan2F (ssum (str $ sin x)) + (sreplicate @_ @2 + $ sin (ssum $ sreplicate @_ @7 a))) (sreplicate @_ @2 (sreplicate @_ @5 (2 * a0))) (sreplicate @_ @3 a0) in rfromS . f . sfromR) @@ -846,9 +847,9 @@ testSin0Fold5Sfwd = do => f2 Double '[] -> f2 Double '[2, 5] -> f2 Double '[] g x a = ssum - $ atan2 (sin $ sreplicate @f2 @5 x) - (ssum $ sin $ ssum - $ str $ sreplicate @f2 @7 a) + $ atan2F (sin $ sreplicate @f2 @5 x) + (ssum $ sin $ ssum + $ str $ sreplicate @f2 @7 a) in g) (2 * a0) (sreplicate @f @3 @@ -865,9 +866,9 @@ testSin0Fold5Sfwds = do => f2 Double '[] -> f2 Double '[2, 5] -> f2 Double '[] g x a = ssum - $ atan2 (sin $ sreplicate @f2 @5 x) - (ssum $ sin $ ssum - $ str $ sreplicate @f2 @7 a) + $ atan2F (sin $ sreplicate @f2 @5 x) + (ssum $ sin $ ssum + $ str $ sreplicate @f2 @7 a) in g) (2 * a0) (sreplicate @f @3 @@ -2716,7 +2717,7 @@ testSin0rmapAccumRD01SN6 = do $ V.fromList [ DynamicShaped $ sin x - `atan2` smaxIndex + `atan2F` smaxIndex @_ @Double @Double @'[] @2 (sfromD (a V.! 1)) , DynamicShaped @@ -2889,7 +2890,7 @@ testSin0ScanD51S = do -> f2 Double '[1,1,1,1] g x a = ssum - $ atan2 (sin $ sreplicate @f2 @5 x) + $ atan2F (sin $ sreplicate @f2 @5 x) (ssum $ sin $ ssum $ str $ sreplicate @f2 @7 $ sreplicate @f2 @2 $ sreplicate @f2 @5 @@ -4012,7 +4013,7 @@ testSin0FoldNestedSi = do assertEqualUpToEpsilon' 1e-10 (-0.20775612781643243 :: OR.Array 0 Double) (rev' (let f :: forall f. ADReadyS f => f Double '[] -> f Double '[3] - f a0 = sfold (\x a -> atan2 + f a0 = sfold (\x a -> atan2F (sscan (+) (ssum x) (sscan (*) 2 (sreplicate @_ @1 a)))