diff --git a/bench/common/BenchMnistTools.hs b/bench/common/BenchMnistTools.hs index dd45076d5..f0ffcda79 100644 --- a/bench/common/BenchMnistTools.hs +++ b/bench/common/BenchMnistTools.hs @@ -210,7 +210,7 @@ mnistTrainBench2VTA extraPrefix chunkLength xs widthHidden widthHidden2 Just (SomeNat @widthHidden _) -> case someNatVal $ toInteger widthHidden2 of Just (SomeNat @widthHidden2 _) -> - shapedToRanked $ fst + forgetShape $ fst $ randomVals @(MnistFcnnRanked2.ADFcnnMnist2ParametersShaped (Flip OS.Array) widthHidden widthHidden2 r) @@ -240,7 +240,7 @@ mnistTestBench2VTA extraPrefix chunkLength xs widthHidden widthHidden2 = do Just (SomeNat @widthHidden _) -> case someNatVal $ toInteger widthHidden2 of Just (SomeNat @widthHidden2 _) -> - shapedToRanked $ fst + forgetShape $ fst $ randomVals @(MnistFcnnRanked2.ADFcnnMnist2ParametersShaped (Flip OS.Array) widthHidden widthHidden2 r) @@ -296,7 +296,7 @@ mnistTrainBench2VTO extraPrefix chunkLength xs widthHidden widthHidden2 Just (SomeNat @widthHidden _) -> case someNatVal $ toInteger widthHidden2 of Just (SomeNat @widthHidden2 _) -> - shapedToRanked $ fst + forgetShape $ fst $ randomVals @(MnistFcnnRanked2.ADFcnnMnist2ParametersShaped (Flip OS.Array) widthHidden widthHidden2 r) diff --git a/example/MnistData.hs b/example/MnistData.hs index 491ba4feb..b1c9081d9 100644 --- a/example/MnistData.hs +++ b/example/MnistData.hs @@ -178,7 +178,7 @@ chunksOf n = go where :: KnownNat y => Double -> (MnistData Double - -> Domains (DynamicOf (ADVal (Flip OR.Array))) + -> Domains (ADVal (Flip OR.Array)) -> ADVal (Flip OR.Array) Double y) -> [MnistData Double] -> DomainsOD @@ -186,7 +186,7 @@ chunksOf n = go where {-# SPECIALIZE sgdAdam :: KnownNat y - => (MnistDataBatchR Double -> Domains (DynamicOf (ADVal (Flip OR.Array))) + => (MnistDataBatchR Double -> Domains (ADVal (Flip OR.Array)) -> ADVal (Flip OR.Array) Double y) -> [MnistDataBatchR Double] -> DomainsOD @@ -196,7 +196,7 @@ chunksOf n = go where {-# SPECIALIZE sgdAdamArgs 
:: KnownNat y => ArgsAdam - -> (MnistDataBatchR Double -> Domains (DynamicOf (ADVal (Flip OR.Array))) + -> (MnistDataBatchR Double -> Domains (ADVal (Flip OR.Array)) -> ADVal (Flip OR.Array) Double y) -> [MnistDataBatchR Double] -> DomainsOD diff --git a/src/HordeAd/Core/Adaptor.hs b/src/HordeAd/Core/Adaptor.hs index e29803420..d7f9b95d6 100644 --- a/src/HordeAd/Core/Adaptor.hs +++ b/src/HordeAd/Core/Adaptor.hs @@ -7,7 +7,6 @@ module HordeAd.Core.Adaptor import Prelude import Control.Exception (assert) -import Data.Kind (Type) import qualified Data.Strict.Vector as Data.Vector import qualified Data.Vector.Generic as V import System.Random @@ -23,14 +22,14 @@ import HordeAd.Core.Types -- * Adaptor classes -- Inspired by adaptors from @tomjaguarpaw's branch. -class AdaptableDomains (dynamic :: Type -> Type) vals where +class AdaptableDomains (ranked :: RankedTensorKind) vals where type Value vals -- ^ a helper type, with the same general shape, -- but possibly more concrete, e.g., arrays instead of terms - toDomains :: vals -> Domains dynamic + toDomains :: vals -> Domains ranked -- ^ represent a value of the domain of objective function -- in a canonical, much less typed way common to all possible types - fromDomains :: Value vals -> Domains dynamic - -> Maybe (vals, Domains dynamic) + fromDomains :: Value vals -> Domains ranked + -> Maybe (vals, Domains ranked) -- ^ recovers a value of the domain of objective function -- from its canonical representation, using the general shape -- recorded in a value of a more concrete type; the remainder @@ -40,8 +39,8 @@ class AdaptableDomains (dynamic :: Type -> Type) vals where -- there is no remainder. This is the main call of the recursive -- procedure where @fromDomains@ calls itself for sub-values. 
parseDomains - :: AdaptableDomains dynamic vals - => Value vals -> Domains dynamic -> vals + :: AdaptableDomains ranked vals + => Value vals -> Domains ranked -> vals parseDomains aInit domains = case fromDomains aInit domains of Just (vals, rest) -> assert (V.null rest) vals @@ -62,8 +61,8 @@ class RandomDomains vals where -- * Basic Adaptor class instances {- This is temporarily moved to TensorADVal in order to specialize manually -instance AdaptableDomains dynamic a - => AdaptableDomains dynamic [a] where +instance AdaptableDomains ranked a + => AdaptableDomains ranked [a] where {-# SPECIALIZE instance (KnownNat n, AdaptableDomains OD.Array (OR.Array n Double)) => AdaptableDomains OD.Array @@ -96,8 +95,8 @@ instance ForgetShape a type NoShape [a] = [NoShape a] forgetShape = map forgetShape -instance AdaptableDomains dynamic a - => AdaptableDomains dynamic (Data.Vector.Vector a) where +instance AdaptableDomains ranked a + => AdaptableDomains ranked (Data.Vector.Vector a) where type Value (Data.Vector.Vector a) = Data.Vector.Vector (Value a) toDomains = V.concatMap toDomains fromDomains lInit source = @@ -115,8 +114,8 @@ instance ForgetShape a type NoShape (Data.Vector.Vector a) = Data.Vector.Vector (NoShape a) forgetShape = V.map forgetShape -instance ( AdaptableDomains dynamic a - , AdaptableDomains dynamic b ) => AdaptableDomains dynamic (a, b) where +instance ( AdaptableDomains ranked a + , AdaptableDomains ranked b ) => AdaptableDomains ranked (a, b) where type Value (a, b) = (Value a, Value b) toDomains (a, b) = let a1 = toDomains a @@ -139,10 +138,10 @@ instance ( RandomDomains a (v2, g2) = randomVals range g1 in ((v1, v2), g2) -instance ( AdaptableDomains dynamic a - , AdaptableDomains dynamic b - , AdaptableDomains dynamic c ) - => AdaptableDomains dynamic (a, b, c) where +instance ( AdaptableDomains ranked a + , AdaptableDomains ranked b + , AdaptableDomains ranked c ) + => AdaptableDomains ranked (a, b, c) where type Value (a, b, c) = (Value a, Value b, 
Value c) toDomains (a, b, c) = let a1 = toDomains a @@ -170,11 +169,11 @@ instance ( RandomDomains a (v3, g3) = randomVals range g2 in ((v1, v2, v3), g3) -instance ( AdaptableDomains dynamic a - , AdaptableDomains dynamic b - , AdaptableDomains dynamic c - , AdaptableDomains dynamic d ) - => AdaptableDomains dynamic (a, b, c, d) where +instance ( AdaptableDomains ranked a + , AdaptableDomains ranked b + , AdaptableDomains ranked c + , AdaptableDomains ranked d ) + => AdaptableDomains ranked (a, b, c, d) where type Value (a, b, c, d) = (Value a, Value b, Value c, Value d) toDomains (a, b, c, d) = let a1 = toDomains a @@ -209,12 +208,12 @@ instance ( RandomDomains a (v4, g4) = randomVals range g3 in ((v1, v2, v3, v4), g4) -instance ( AdaptableDomains dynamic a - , AdaptableDomains dynamic b - , AdaptableDomains dynamic c - , AdaptableDomains dynamic d - , AdaptableDomains dynamic e ) - => AdaptableDomains dynamic (a, b, c, d, e) where +instance ( AdaptableDomains ranked a + , AdaptableDomains ranked b + , AdaptableDomains ranked c + , AdaptableDomains ranked d + , AdaptableDomains ranked e ) + => AdaptableDomains ranked (a, b, c, d, e) where type Value (a, b, c, d, e) = (Value a, Value b, Value c, Value d, Value e) toDomains (a, b, c, d, e) = let a1 = toDomains a @@ -254,8 +253,8 @@ instance ( RandomDomains a (v5, g5) = randomVals range g4 in ((v1, v2, v3, v4, v5), g5) -instance ( AdaptableDomains dynamic a, AdaptableDomains dynamic b ) - => AdaptableDomains dynamic (Either a b) where +instance ( AdaptableDomains ranked a, AdaptableDomains ranked b ) + => AdaptableDomains ranked (Either a b) where type Value (Either a b) = Either (Value a) (Value b) toDomains e = case e of Left a -> toDomains a @@ -275,8 +274,8 @@ instance ( ForgetShape a Left a -> Left $ forgetShape a Right b -> Right $ forgetShape b -instance AdaptableDomains dynamic a - => AdaptableDomains dynamic (Maybe a) where +instance AdaptableDomains ranked a + => AdaptableDomains ranked (Maybe a) where type 
Value (Maybe a) = Maybe (Value a) toDomains e = case e of Nothing -> V.concat [] diff --git a/src/HordeAd/Core/Ast.hs b/src/HordeAd/Core/Ast.hs index ef3e273b2..f44b205c4 100644 --- a/src/HordeAd/Core/Ast.hs +++ b/src/HordeAd/Core/Ast.hs @@ -16,7 +16,7 @@ module HordeAd.Core.Ast , AstArtifactRev, AstArtifactFwd , AstIndex, AstVarList, AstIndexS, AstVarListS -- * ASTs - , AstRanked(..), AstShaped(..), AstDynamic(..), AstDomains(..) + , AstRanked(..), AstShaped(..), AstDynamic, AstDomains(..) , AstBool(..), OpCodeNum1(..), OpCodeNum2(..), OpCode1(..), OpCode2(..) , OpCodeIntegral2(..), OpCodeBool(..), OpCodeRel(..) -- * Boolean definitions and instances @@ -31,7 +31,6 @@ import Prelude hiding (foldl') import qualified Data.Array.RankedS as OR import qualified Data.Array.Shape as Sh import qualified Data.Array.ShapedS as OS -import Data.Bifunctor.Clown import Data.Bifunctor.Flip import Data.Int (Int64) import Data.Kind (Type) @@ -88,21 +87,14 @@ sameAstSpan = case eqTypeRep (typeRep @s1) (typeRep @s2) of -- * Basic type family instances -type instance RankedOf (Clown (AstDynamic s)) = AstRanked s -type instance ShapedOf (Clown (AstDynamic s)) = AstShaped s -type instance DynamicOf (Clown (AstDynamic s)) = AstDynamic s -type instance DomainsOf (Clown (AstDynamic s)) = AstDomains s - type instance RankedOf (AstRanked s) = AstRanked s type instance ShapedOf (AstRanked s) = AstShaped s -type instance DynamicOf (AstRanked s) = AstDynamic s type instance DomainsOf (AstRanked s) = AstDomains s type instance PrimalOf (AstRanked s) = AstRanked PrimalSpan type instance DualOf (AstRanked s) = AstRanked DualSpan type instance RankedOf (AstShaped s) = AstRanked s type instance ShapedOf (AstShaped s) = AstShaped s -type instance DynamicOf (AstShaped s) = AstDynamic s type instance DomainsOf (AstShaped s) = AstDomains s type instance PrimalOf (AstShaped s) = AstShaped PrimalSpan type instance DualOf (AstShaped s) = AstShaped DualSpan @@ -138,8 +130,8 @@ type family ConcreteOf f = 
result | result -> f where ConcreteOf (AstRanked FullSpan) = Flip OR.Array ConcreteOf (AstShaped FullSpan) = Flip OS.Array -type AstBindings = AstBindingsD (AstDynamic PrimalSpan) -type ADShare = ADShareD (AstDynamic PrimalSpan) +type AstBindings = AstBindingsD (AstRanked PrimalSpan) +type ADShare = ADShareD (AstRanked PrimalSpan) -- * More and less typed variables and type synonyms containing them @@ -297,8 +289,8 @@ data AstRanked :: AstSpanType -> RankedTensorKind where -> AstRanked s2 r n AstFwd :: (GoodScalar r, KnownNat n) => ([AstDynamicVarName], AstRanked s r n) - -> Domains (AstDynamic s) - -> Domains (AstDynamic s) + -> Domains (AstRanked s) + -> Domains (AstRanked s) -> AstRanked s r n AstFold :: forall rn rm n m s. (GoodScalar rm, KnownNat m) => ( AstVarName (AstRanked PrimalSpan) rn n @@ -438,8 +430,8 @@ data AstShaped :: AstSpanType -> ShapedTensorKind where -> AstShaped s2 r sh AstFwdS :: (GoodScalar r, Sh.Shape sh) => ([AstDynamicVarName], AstShaped s r sh) - -> Domains (AstDynamic s) - -> Domains (AstDynamic s) + -> Domains (AstRanked s) + -> Domains (AstRanked s) -> AstShaped s r sh AstFoldS :: forall rn rm sh shm k s. (GoodScalar rm, Sh.Shape shm, KnownNat k) => ( AstVarName (AstShaped PrimalSpan) rn sh @@ -468,18 +460,12 @@ data AstShaped :: AstSpanType -> ShapedTensorKind where deriving instance (GoodScalar r, Sh.Shape sh) => Show (AstShaped s r sh) -type role AstDynamic nominal nominal -data AstDynamic :: AstSpanType -> Type -> Type where - AstRToD :: forall n r s. KnownNat n - => AstRanked s r n -> AstDynamic s r - AstSToD :: forall sh r s. Sh.Shape sh - => AstShaped s r sh -> AstDynamic s r -deriving instance GoodScalar r => Show (AstDynamic s r) +type AstDynamic (s :: AstSpanType) = DynamicTensor (AstRanked s) type role AstDomains nominal data AstDomains s where -- There are existential variables inside DynamicExists here. 
- AstDomains :: Domains (AstDynamic s) -> AstDomains s + AstDomains :: Domains (AstRanked s) -> AstDomains s -- This operation is why we need AstDomains and so DomainsOf. -- If we kept a vector of terms instead, we'd need to let-bind in each -- of the terms separately, duplicating the let-bound term. @@ -496,7 +482,7 @@ data AstDomains s where -> AstDomains s2 AstRev :: (GoodScalar r, KnownNat n) => ([AstDynamicVarName], AstRanked s r n) - -> Domains (AstDynamic s) + -> Domains (AstRanked s) -> AstDomains s -- ^ the function body can't have any free variables outside those -- listed in the first component of the pair; this reflects @@ -504,16 +490,16 @@ data AstDomains s where -- the same holds for the similar operations below AstRevDt :: (GoodScalar r, KnownNat n) => ([AstDynamicVarName], AstRanked s r n) - -> Domains (AstDynamic s) + -> Domains (AstRanked s) -> AstRanked s r n -> AstDomains s AstRevS :: (GoodScalar r, Sh.Shape sh) => ([AstDynamicVarName], AstShaped s r sh) - -> Domains (AstDynamic s) + -> Domains (AstRanked s) -> AstDomains s AstRevDtS :: (GoodScalar r, Sh.Shape sh) => ([AstDynamicVarName], AstShaped s r sh) - -> Domains (AstDynamic s) + -> Domains (AstRanked s) -> AstShaped s r sh -> AstDomains s @@ -877,25 +863,21 @@ maxF u v = ifF (u >=. 
v) u v type instance RankedOf (AstNoVectorize s) = AstNoVectorize s type instance ShapedOf (AstNoVectorize s) = AstNoVectorizeS s -type instance DynamicOf (AstNoVectorize s) = AstDynamic s type instance DomainsOf (AstNoVectorize s) = AstDomains s type instance PrimalOf (AstNoVectorize s) = AstRanked PrimalSpan type instance DualOf (AstNoVectorize s) = AstRanked DualSpan type instance RankedOf (AstNoVectorizeS s) = AstNoVectorize s type instance ShapedOf (AstNoVectorizeS s) = AstNoVectorizeS s -type instance DynamicOf (AstNoVectorizeS s) = AstDynamic s type instance DomainsOf (AstNoVectorizeS s) = AstDomains s type instance PrimalOf (AstNoVectorizeS s) = AstShaped PrimalSpan type instance DualOf (AstNoVectorizeS s) = AstShaped DualSpan type instance RankedOf (AstNoSimplify s) = AstNoSimplify s type instance ShapedOf (AstNoSimplify s) = AstNoSimplifyS s -type instance DynamicOf (AstNoSimplify s) = AstDynamic s type instance DomainsOf (AstNoSimplify s) = AstDomains s type instance PrimalOf (AstNoSimplify s) = AstRanked PrimalSpan type instance DualOf (AstNoSimplify s) = AstRanked DualSpan type instance RankedOf (AstNoSimplifyS s) = AstNoSimplify s type instance ShapedOf (AstNoSimplifyS s) = AstNoSimplifyS s -type instance DynamicOf (AstNoSimplifyS s) = AstDynamic s type instance DomainsOf (AstNoSimplifyS s) = AstDomains s type instance PrimalOf (AstNoSimplifyS s) = AstShaped PrimalSpan type instance DualOf (AstNoSimplifyS s) = AstShaped DualSpan diff --git a/src/HordeAd/Core/AstEnv.hs b/src/HordeAd/Core/AstEnv.hs index f92856b51..83811c324 100644 --- a/src/HordeAd/Core/AstEnv.hs +++ b/src/HordeAd/Core/AstEnv.hs @@ -23,14 +23,17 @@ import Prelude import qualified Data.Array.Shape as Sh import qualified Data.EnumMap.Strict as EM import Data.Kind (Type) +import Data.Proxy (Proxy (Proxy)) import Data.Type.Equality (testEquality, (:~:) (Refl)) import qualified Data.Vector.Generic as V -import GHC.TypeLits (KnownNat, Nat) +import GHC.TypeLits (KnownNat, Nat, sameNat) import 
Type.Reflection (typeRep) import HordeAd.Core.Ast import HordeAd.Core.TensorClass import HordeAd.Core.Types +import HordeAd.Internal.OrthotopeOrphanInstances + (matchingRank, sameShape) import qualified HordeAd.Util.ShapedList as ShapedList import HordeAd.Util.SizedIndex import HordeAd.Util.SizedList @@ -68,29 +71,43 @@ extendEnvS (AstVarName var) !t !env = EM.insertWithKey (\_ _ _ -> error $ "extendEnvS: duplicate " ++ show var) var (AstEnvElemS t) env -extendEnvD :: forall ranked shaped. ConvertTensor ranked shaped - => (AstDynamicVarName, DynamicExists (DynamicOf ranked)) +extendEnvD :: forall ranked shaped. + ( RankedTensor ranked, ShapedTensor shaped + , shaped ~ ShapedOf ranked ) + => (AstDynamicVarName, DynamicTensor ranked) -> AstEnv ranked shaped -> AstEnv ranked shaped -extendEnvD ( AstDynamicVarName @k @r @sh @n (AstVarName var) - , DynamicExists @r2 d ) - !env - | Just Refl <- testEquality (typeRep @k) (typeRep @Nat) = - -- We don't need to manually pick a specialization for the existential - -- variable r2, because rfromD does not depend on r2. - case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> extendEnvR @_ @_ @_ @n (AstVarName var) (rfromD d) env - _ -> error "extendEnvD: type mismatch" -extendEnvD ( AstDynamicVarName @k @r @sh @sh2 (AstVarName var) - , DynamicExists @r2 d ) - env - | Just Refl <- testEquality (typeRep @k) (typeRep @[Nat]) = - -- We don't need to manually pick a specialization for the existential - -- variable r2, because sfromD does not depend on r2. 
- case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> extendEnvS @_ @_ @_ @sh (AstVarName var) (sfromD d) env - _ -> error "extendEnvD: type mismatch" -extendEnvD _ _ = error "extendEnvD: kind mismatch" +extendEnvD (AstDynamicVarName @k @r @sh @n (AstVarName var), d) !env + | Just Refl <- testEquality (typeRep @k) (typeRep @Nat) = case d of + DynamicRanked @r2 @n2 t -> case sameNat (Proxy @n2) (Proxy @n) of + Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of + Just Refl -> extendEnvR @_ @_ @_ @n (AstVarName var) t env + _ -> error "extendEnvD: type mismatch" + _ -> error "extendEnvD: rank mismatch" + DynamicShaped{} -> error "extendEnvD: ranked from shaped" + DynamicRankedDummy @r2 @sh2 _ _ -> case matchingRank @sh2 @n of + Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of + Just Refl -> + let sh2 = listShapeToShape (Sh.shapeT @sh2) + in extendEnvR @_ @_ @r @n (AstVarName var) (rzero sh2) env + _ -> error "extendEnvD: type mismatch" + _ -> error "extendEnvD: rank mismatch" + DynamicShapedDummy{} -> error "extendEnvD: ranked from shaped" +extendEnvD (AstDynamicVarName @k @r @sh (AstVarName var), d) env + | Just Refl <- testEquality (typeRep @k) (typeRep @[Nat]) = case d of + DynamicRanked{} -> error "extendEnvD: shaped from ranked" + DynamicShaped @r2 @sh2 t -> case sameShape @sh2 @sh of + Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of + Just Refl -> extendEnvS @_ @_ @_ @sh (AstVarName var) t env + _ -> error "extendEnvD: type mismatch" + _ -> error "extendEnvD: shape mismatch" + DynamicRankedDummy{} -> error "extendEnvD: shaped from ranked" + DynamicShapedDummy @r2 @sh2 _ _ -> case sameShape @sh2 @sh of + Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of + Just Refl -> extendEnvS @_ @_ @r @sh (AstVarName var) 0 env + _ -> error "extendEnvD: type mismatch" + _ -> error "extendEnvD: shape mismatch" +extendEnvD _ _ = error "extendEnvD: unexpected kind" extendEnvI :: ( RankedTensor ranked , RankedOf (PrimalOf 
ranked) ~ PrimalOf ranked ) @@ -119,8 +136,10 @@ extendEnvVarsS vars !ix !env = (ShapedList.sizedListToList ix) in foldr (uncurry extendEnvI) env assocs -extendEnvPars :: forall ranked shaped. ConvertTensor ranked shaped - => [AstDynamicVarName] -> Domains (DynamicOf ranked) +extendEnvPars :: forall ranked shaped. + ( RankedTensor ranked, ShapedTensor shaped + , shaped ~ ShapedOf ranked ) + => [AstDynamicVarName] -> Domains ranked -> AstEnv ranked shaped -> AstEnv ranked shaped extendEnvPars vars !pars !env = @@ -203,22 +222,26 @@ interpretLambdaIndexToIndexS f !env (!vars, !asts) = \ix -> f (extendEnvVarsS vars ix env) <$> asts interpretLambdaDomains - :: forall s ranked shaped r n. ConvertTensor ranked shaped + :: forall s ranked shaped r n. + ( RankedTensor ranked, ShapedTensor shaped + , shaped ~ ShapedOf ranked ) => (AstEnv ranked shaped -> AstRanked s r n -> ranked r n) -> AstEnv ranked shaped -> ([AstDynamicVarName], AstRanked s r n) - -> Domains (DynamicOf ranked) + -> Domains ranked -> ranked r n {-# INLINE interpretLambdaDomains #-} interpretLambdaDomains f !env (!vars, !ast) = \pars -> f (extendEnvPars vars pars env) ast interpretLambdaDomainsS - :: forall s ranked shaped r sh. ConvertTensor ranked shaped + :: forall s ranked shaped r sh. 
+ ( RankedTensor ranked, ShapedTensor shaped + , shaped ~ ShapedOf ranked ) => (AstEnv ranked shaped -> AstShaped s r sh -> shaped r sh) -> AstEnv ranked shaped -> ([AstDynamicVarName], AstShaped s r sh) - -> Domains (DynamicOf ranked) + -> Domains ranked -> shaped r sh {-# INLINE interpretLambdaDomainsS #-} interpretLambdaDomainsS f !env (!vars, !ast) = diff --git a/src/HordeAd/Core/AstFreshId.hs b/src/HordeAd/Core/AstFreshId.hs index f832953db..66a584e7f 100644 --- a/src/HordeAd/Core/AstFreshId.hs +++ b/src/HordeAd/Core/AstFreshId.hs @@ -6,7 +6,7 @@ module HordeAd.Core.AstFreshId ( astRegisterFun, astRegisterADShare, astRegisterADShareS , funToAstIOR, funToAstR, fun2ToAstR, fun2ToAstS, fun3ToAstR, fun3ToAstS - , fun4ToAstR, fun4ToAstS, funToAstDomains, funToAstDomainsS + , fun4ToAstR, fun4ToAstS, funToAstDomains , funToAstRevIO, funToAstRev, funToAstFwdIO, funToAstFwd , funToAstIOI, funToAstI, funToAstIndexIO, funToAstIndex , funToAstIOS, funToAstS, astRegisterFunS, funToAstIndexIOS, funToAstIndexS @@ -16,9 +16,10 @@ module HordeAd.Core.AstFreshId import Prelude import Control.Monad (replicateM) -import qualified Data.Array.DynamicS as OD import Data.Array.Internal (valueOf) +import qualified Data.Array.RankedS as OR import qualified Data.Array.Shape as Sh +import Data.Bifunctor.Flip import Data.IORef.Unboxed (Counter, atomicAddCounter_, newCounter, writeIORefU) import Data.List (unzip4, unzip6) @@ -29,6 +30,7 @@ import System.IO.Unsafe (unsafePerformIO) import HordeAd.Core.Ast import HordeAd.Core.AstTools +import HordeAd.Core.TensorClass (DomainsOD) import HordeAd.Core.Types import qualified HordeAd.Util.ShapedList as ShapedList import HordeAd.Util.SizedIndex @@ -56,26 +58,26 @@ unsafeGetFreshAstVarName = astRegisterFun :: (GoodScalar r, KnownNat n) - => AstRanked s r n -> AstBindingsD (AstDynamic s) - -> (AstBindingsD (AstDynamic s), AstRanked s r n) + => AstRanked s r n -> AstBindingsD (AstRanked s) + -> (AstBindingsD (AstRanked s), AstRanked s r n) {-# 
NOINLINE astRegisterFun #-} astRegisterFun !r !l | astIsSmall True r = (l, r) astRegisterFun r l = unsafePerformIO $ do !freshId <- unsafeGetFreshAstVarId let !r2 = AstVar (shapeAst r) $ AstVarName freshId - !d = DynamicExists $ AstRToD r + !d = DynamicRanked r return ((freshId, d) : l, r2) astRegisterFunS :: (Sh.Shape sh, GoodScalar r) - => AstShaped s r sh -> AstBindingsD (AstDynamic s) - -> (AstBindingsD (AstDynamic s), AstShaped s r sh) + => AstShaped s r sh -> AstBindingsD (AstRanked s) + -> (AstBindingsD (AstRanked s), AstShaped s r sh) {-# NOINLINE astRegisterFunS #-} astRegisterFunS !r !l | astIsSmallS True r = (l, r) astRegisterFunS r l = unsafePerformIO $ do !freshId <- unsafeGetFreshAstVarId let !r2 = AstVarS $ AstVarName freshId - !d = DynamicExists $ AstSToD r + !d = DynamicShaped r return ((freshId, d) : l, r2) astRegisterADShare :: (GoodScalar r, KnownNat n) @@ -85,7 +87,7 @@ astRegisterADShare :: (GoodScalar r, KnownNat n) astRegisterADShare !r !l | astIsSmall True r = (l, r) astRegisterADShare r l = unsafePerformIO $ do freshId <- unsafeGetFreshAstVarId - let !l2 = insertADShare freshId (AstRToD r) l + let !l2 = insertADShare freshId (DynamicRanked r) l !r2 = AstVar (shapeAst r) $ AstVarName freshId return (l2, r2) @@ -96,7 +98,7 @@ astRegisterADShareS :: (GoodScalar r, Sh.Shape sh) astRegisterADShareS !r !l | astIsSmallS True r = (l, r) astRegisterADShareS r l = unsafePerformIO $ do freshId <- unsafeGetFreshAstVarId - let !l2 = insertADShare freshId (AstSToD r) l + let !l2 = insertADShare freshId (DynamicShaped r) l !r2 = AstVarS $ AstVarName freshId return (l2, r2) @@ -292,80 +294,98 @@ fun4ToAstS :: (AstShaped s rn shn -> AstShaped s rm shm fun4ToAstS f = unsafePerformIO $ fun4ToAstIOS f funToAstDomainsIO - :: (Domains (AstDynamic s) -> AstRanked s r n) + :: (Domains (AstRanked s) -> AstRanked s r n) -> DomainsOD -> IO ([AstDynamicVarName], AstRanked s r n) {-# INLINE funToAstDomainsIO #-} funToAstDomainsIO g parameters0 = do - let f 
(DynamicExists @r2 e) = do - let sh = OD.shapeL e + let f (DynamicRanked @r2 @n2 e) = do + let sh3 = OR.shapeL $ runFlip e freshId <- unsafeGetFreshAstVarId - return $! Sh.withShapeP sh $ \(Proxy :: Proxy p_sh) -> - withListShape sh $ \ (_ :: Shape n Int) -> - let !varE = AstDynamicVarName @Nat @r2 @p_sh @n (AstVarName freshId) - dynE :: DynamicExists (AstDynamic s) - !dynE = DynamicExists @r2 - $ AstRToD @n (AstVar (listShapeToShape sh) - (AstVarName freshId)) - in (varE, dynE) + return $! Sh.withShapeP sh3 $ \(Proxy :: Proxy sh2) -> + let !varE = AstDynamicVarName @Nat @r2 @sh2 @n2 (AstVarName freshId) + dynE :: AstDynamic s + !dynE = DynamicRanked @r2 @n2 + (AstVar (listShapeToShape sh3) + (AstVarName freshId)) + in (varE, dynE) + f (DynamicShaped @r2 @sh2 _) = do + freshId <- unsafeGetFreshAstVarId + return $! + let !varE = AstDynamicVarName @[Nat] @r2 @sh2 @sh2 + (AstVarName freshId) + dynE :: AstDynamic s + !dynE = DynamicShaped @r2 @sh2 + (AstVarS (AstVarName freshId)) + in (varE, dynE) + f (DynamicRankedDummy @r2 @sh2 _ _) = do + let sh3 = Sh.shapeT @sh2 + freshId <- unsafeGetFreshAstVarId + return $! withListShape sh3 $ \ (sh4 :: Shape n2 Int) -> + let !varE = AstDynamicVarName @Nat @r2 @sh2 @n2 (AstVarName freshId) + dynE :: AstDynamic s + !dynE = DynamicRanked @r2 @n2 (AstVar sh4 (AstVarName freshId)) + in (varE, dynE) + f (DynamicShapedDummy @r2 @sh2 _ _) = do + freshId <- unsafeGetFreshAstVarId + return $! 
+ let !varE = AstDynamicVarName @[Nat] @r2 @sh2 @sh2 + (AstVarName freshId) + dynE :: AstDynamic s + !dynE = DynamicShaped @r2 @sh2 + (AstVarS (AstVarName freshId)) + in (varE, dynE) (!vars, !asts) <- V.unzip <$> V.mapM f parameters0 let !x = g asts return (V.toList vars, x) funToAstDomains - :: (Domains (AstDynamic s) -> AstRanked s r n) + :: (Domains (AstRanked s) -> AstRanked s r n) -> DomainsOD -> ([AstDynamicVarName], AstRanked s r n) {-# NOINLINE funToAstDomains #-} funToAstDomains g parameters0 = unsafePerformIO $ funToAstDomainsIO g parameters0 -funToAstDomainsIOS - :: (Domains (AstDynamic s) -> AstShaped s r sh) - -> DomainsOD - -> IO ([AstDynamicVarName], AstShaped s r sh) -{-# INLINE funToAstDomainsIOS #-} -funToAstDomainsIOS g parameters0 = do - let f (DynamicExists @r2 e) = do - let sh = OD.shapeL e - freshId <- unsafeGetFreshAstVarId - return $! Sh.withShapeP sh $ \(Proxy :: Proxy p_sh) -> - let !varE = AstDynamicVarName @[Nat] @r2 @p_sh @p_sh - (AstVarName freshId) - dynE :: DynamicExists (AstDynamic s) - !dynE = DynamicExists @r2 - $ AstSToD (AstVarS @p_sh (AstVarName freshId)) - in (varE, dynE) - (!vars, !asts) <- V.unzip <$> V.mapM f parameters0 - let !x = g asts - return (V.toList vars, x) - -funToAstDomainsS - :: (Domains (AstDynamic s) -> AstShaped s r sh) - -> DomainsOD - -> ([AstDynamicVarName], AstShaped s r sh) -{-# NOINLINE funToAstDomainsS #-} -funToAstDomainsS g parameters0 = - unsafePerformIO $ funToAstDomainsIOS g parameters0 - funToAstRevIO :: DomainsOD -> IO ( [AstDynamicVarName] - , Domains (AstDynamic PrimalSpan) + , Domains (AstRanked PrimalSpan) , [AstDynamicVarName] - , Domains (AstDynamic FullSpan) ) + , Domains (AstRanked FullSpan) ) {-# INLINE funToAstRevIO #-} funToAstRevIO parameters0 = do - let f (DynamicExists @r2 e) = do - let sh = OD.shapeL e + let f (DynamicRanked @r @n e) = do + let sh2 = OR.shapeL $ runFlip e + freshId <- unsafeGetFreshAstVarId + return $! 
Sh.withShapeP sh2 $ \(Proxy :: Proxy sh) -> + let !varE = AstDynamicVarName @Nat @r @sh @n (AstVarName freshId) + dynE :: AstDynamic s + !dynE = DynamicRanked @r @n + (AstVar (listShapeToShape sh2) + (AstVarName freshId)) + in (varE, dynE, varE, dynE) + f (DynamicShaped @r @sh _) = do + freshId <- unsafeGetFreshAstVarId + return $! + let !varE = AstDynamicVarName @[Nat] @r @sh @sh (AstVarName freshId) + dynE :: AstDynamic s + !dynE = DynamicShaped @r @sh (AstVarS (AstVarName freshId)) + in (varE, dynE, varE, dynE) + f (DynamicRankedDummy @r @sh _ _) = do + let sh2 = Sh.shapeT @sh freshId <- unsafeGetFreshAstVarId - return $! Sh.withShapeP sh $ \(Proxy :: Proxy p_sh) -> - withListShape sh $ \ (_ :: Shape n Int) -> - let !varE = AstDynamicVarName @Nat @r2 @p_sh @n (AstVarName freshId) - dynE :: DynamicExists (AstDynamic s) - !dynE = DynamicExists @r2 - $ AstRToD @n (AstVar (listShapeToShape sh) - (AstVarName freshId)) - in (varE, dynE, varE, dynE) + return $! withListShape sh2 $ \ (sh :: Shape n Int) -> + let !varE = AstDynamicVarName @Nat @r @sh @n (AstVarName freshId) + dynE :: AstDynamic s + !dynE = DynamicRanked @r @n (AstVar sh (AstVarName freshId)) + in (varE, dynE, varE, dynE) + f (DynamicShapedDummy @r @sh _ _) = do + freshId <- unsafeGetFreshAstVarId + return $! 
+ let !varE = AstDynamicVarName @[Nat] @r @sh @sh (AstVarName freshId) + dynE :: AstDynamic s + !dynE = DynamicShaped @r @sh (AstVarS (AstVarName freshId)) + in (varE, dynE, varE, dynE) (!varsPrimal, !astsPrimal, !vars, !asts) <- unzip4 <$> mapM f (V.toList parameters0) let !vp = V.fromList astsPrimal @@ -377,9 +397,9 @@ funToAstRevIO parameters0 = do funToAstRev :: DomainsOD -> ( AstVarId , [AstDynamicVarName] - , Domains (AstDynamic PrimalSpan) + , Domains (AstRanked PrimalSpan) , [AstDynamicVarName] - , Domains (AstDynamic FullSpan) ) + , Domains (AstRanked FullSpan) ) {-# NOINLINE funToAstRev #-} funToAstRev parameters0 = unsafePerformIO $ do freshId <- unsafeGetFreshAstVarId @@ -388,31 +408,77 @@ funToAstRev parameters0 = unsafePerformIO $ do funToAstFwdIO :: DomainsOD -> IO ( [AstDynamicVarName] - , Domains (AstDynamic PrimalSpan) + , Domains (AstRanked PrimalSpan) , [AstDynamicVarName] - , Domains (AstDynamic PrimalSpan) + , Domains (AstRanked PrimalSpan) , [AstDynamicVarName] - , Domains (AstDynamic FullSpan) ) + , Domains (AstRanked FullSpan) ) {-# INLINE funToAstFwdIO #-} funToAstFwdIO parameters0 = do - let f (DynamicExists @r2 e) = do - let sh = OD.shapeL e + let f :: DynamicTensor (Flip OR.Array) + -> IO ( AstDynamicVarName, AstDynamic PrimalSpan + , AstDynamicVarName, AstDynamic PrimalSpan + , AstDynamicVarName, AstDynamic FullSpan ) + f (DynamicRanked @r @n e) = do + let sh2 = OR.shapeL $ runFlip e + freshIdDs <- unsafeGetFreshAstVarId + freshId <- unsafeGetFreshAstVarId + return $! 
Sh.withShapeP sh2 $ \(Proxy :: Proxy sh) -> + let varE :: AstVarId -> AstDynamicVarName + varE v = AstDynamicVarName @Nat @r @sh @n (AstVarName v) + dynE :: AstVarId -> AstDynamic s + dynE v = DynamicRanked @r @n + (AstVar (listShapeToShape sh2) + (AstVarName v)) + !vd = varE freshIdDs + !dd = dynE freshIdDs + !vi = varE freshId + di :: AstDynamic s + !di = dynE freshId + in (vd, dd, vi, di, vi, di) + f (DynamicShaped @r @sh _) = do + freshIdDs <- unsafeGetFreshAstVarId + freshId <- unsafeGetFreshAstVarId + return $! + let varE :: AstVarId -> AstDynamicVarName + varE v = AstDynamicVarName @[Nat] @r @sh @sh (AstVarName v) + dynE :: AstVarId -> AstDynamic s + dynE v = DynamicShaped @r @sh (AstVarS (AstVarName v)) + !vd = varE freshIdDs + !dd = dynE freshIdDs + !vi = varE freshId + di :: AstDynamic s + !di = dynE freshId + in (vd, dd, vi, di, vi, di) + f (DynamicRankedDummy @r @sh _ _) = do + let sh2 = Sh.shapeT @sh + freshIdDs <- unsafeGetFreshAstVarId + freshId <- unsafeGetFreshAstVarId + return $! withListShape sh2 $ \ (sh :: Shape n Int) -> + let varE :: AstVarId -> AstDynamicVarName + varE v = AstDynamicVarName @Nat @r @sh @n (AstVarName v) + dynE :: AstVarId -> AstDynamic s + dynE v = DynamicRanked @r (AstVar sh (AstVarName v)) + !vd = varE freshIdDs + !dd = dynE freshIdDs + !vi = varE freshId + di :: AstDynamic s + !di = dynE freshId + in (vd, dd, vi, di, vi, di) + f (DynamicShapedDummy @r @sh _ _) = do freshIdDs <- unsafeGetFreshAstVarId freshId <- unsafeGetFreshAstVarId - return $! 
Sh.withShapeP sh $ \(Proxy :: Proxy p_sh) -> - withListShape sh $ \ (_ :: Shape n Int) -> - let varE :: AstVarId -> AstDynamicVarName - varE v = AstDynamicVarName @Nat @r2 @p_sh @n (AstVarName v) - dynE :: AstVarId -> DynamicExists (AstDynamic s) - dynE v = DynamicExists @r2 - $ AstRToD @n (AstVar (listShapeToShape sh) - (AstVarName v)) - !vd = varE freshIdDs - !dd = dynE freshIdDs - !vi = varE freshId - di :: DynamicExists (AstDynamic s) - !di = dynE freshId - in (vd, dd, vi, di, vi, di) + return $! + let varE :: AstVarId -> AstDynamicVarName + varE v = AstDynamicVarName @[Nat] @r @sh @sh (AstVarName v) + dynE :: AstVarId -> AstDynamic s + dynE v = DynamicShaped @r @sh (AstVarS (AstVarName v)) + !vd = varE freshIdDs + !dd = dynE freshIdDs + !vi = varE freshId + di :: AstDynamic s + !di = dynE freshId + in (vd, dd, vi, di, vi, di) (!varsPrimalDs, !astsPrimalDs, !varsPrimal, !astsPrimal, !vars, !asts) <- unzip6 <$> mapM f (V.toList parameters0) let !vd = V.fromList astsPrimalDs @@ -424,11 +490,11 @@ funToAstFwdIO parameters0 = do -- compared with a bare AstVarId, so let's keep it. 
funToAstFwd :: DomainsOD -> ( [AstDynamicVarName] - , Domains (AstDynamic PrimalSpan) + , Domains (AstRanked PrimalSpan) , [AstDynamicVarName] - , Domains (AstDynamic PrimalSpan) + , Domains (AstRanked PrimalSpan) , [AstDynamicVarName] - , Domains (AstDynamic FullSpan) ) + , Domains (AstRanked FullSpan) ) {-# NOINLINE funToAstFwd #-} funToAstFwd parameters0 = unsafePerformIO $ funToAstFwdIO parameters0 diff --git a/src/HordeAd/Core/AstInline.hs b/src/HordeAd/Core/AstInline.hs index 81eeb8e7b..5240229c6 100644 --- a/src/HordeAd/Core/AstInline.hs +++ b/src/HordeAd/Core/AstInline.hs @@ -227,13 +227,13 @@ inlineAst memo v0 = case v0 of inlineAstDynamic :: AstSpan s - => AstMemo -> DynamicExists (AstDynamic s) - -> (AstMemo, DynamicExists (AstDynamic s)) + => AstMemo -> AstDynamic s + -> (AstMemo, AstDynamic s) inlineAstDynamic memo = \case - DynamicExists (AstRToD w) -> - second (DynamicExists . AstRToD) $ inlineAst memo w - DynamicExists (AstSToD w) -> - second (DynamicExists . AstSToD) $ inlineAstS memo w + DynamicRanked w -> second DynamicRanked $ inlineAst memo w + DynamicShaped w -> second DynamicShaped $ inlineAstS memo w + u@DynamicRankedDummy{} -> (memo, u) + u@DynamicShapedDummy{} -> (memo, u) inlineAstDomains :: AstSpan s @@ -573,10 +573,12 @@ unletAst env t = case t of unletAstDynamic :: AstSpan s - => UnletEnv -> DynamicExists (AstDynamic s) -> DynamicExists (AstDynamic s) + => UnletEnv -> AstDynamic s -> AstDynamic s unletAstDynamic env = \case - DynamicExists (AstRToD u) -> DynamicExists $ AstRToD $ unletAst env u - DynamicExists (AstSToD u) -> DynamicExists $ AstSToD $ unletAstS env u + DynamicRanked u -> DynamicRanked $ unletAst env u + DynamicShaped u -> DynamicShaped $ unletAstS env u + u@DynamicRankedDummy{} -> u + u@DynamicShapedDummy{} -> u unletAstDomains :: AstSpan s => UnletEnv -> AstDomains s -> AstDomains s diff --git a/src/HordeAd/Core/AstInterpret.hs b/src/HordeAd/Core/AstInterpret.hs index 0e0ba8a31..6d11277d5 100644 --- 
a/src/HordeAd/Core/AstInterpret.hs +++ b/src/HordeAd/Core/AstInterpret.hs @@ -17,7 +17,6 @@ module HordeAd.Core.AstInterpret import Prelude import Control.Exception.Assert.Sugar -import qualified Data.Array.DynamicS as OD import Data.Array.Internal (valueOf) import qualified Data.Array.RankedS as OR import qualified Data.Array.Shape as Sh @@ -469,38 +468,43 @@ interpretAst !env = \case t2 = interpretAstDual env u' in rD t1 t2 AstLetDomainsIn vars l v -> - let odFromVar :: AstDynamicVarName -> DynamicExists OD.Array - odFromVar (AstDynamicVarName @_ @rD @shD _) = - DynamicExists $ OD.constant @rD (Sh.shapeT @shD) 0 - lt0 = V.fromList $ map odFromVar vars + let lt0 = V.fromList $ map odFromVar vars lt = interpretAstDomains env l -- We don't need to manually pick a specialization for the existential -- variable r2, because the operations do not depend on r2. - f :: (AstDynamicVarName, DynamicExists (DynamicOf ranked)) + f :: (AstDynamicVarName, DynamicTensor ranked) -> AstEnv ranked shaped -> AstEnv ranked shaped - f ( AstDynamicVarName @k @r2 @sh2 @y (AstVarName varId) - , DynamicExists @r3 d ) - | Just Refl <- testEquality (typeRep @r2) (typeRep @r3) = - case testEquality (typeRep @k) (typeRep @Nat) of - Just Refl -> - extendEnvR @ranked @shaped @r2 @y - (AstVarName varId) (rfromD d) - _ -> case testEquality (typeRep @k) (typeRep @[Nat]) of - Just Refl -> - extendEnvS @ranked @shaped @r2 @y - (AstVarName varId) (sfromD d) - _ -> error "interpretAst: impossible kind" - f _ = error "interpretAst: type mismatch" + f (AstDynamicVarName @k @r2 @sh2 @y (AstVarName varId), d) = case d of + DynamicRanked @r3 @n3 u + | Just Refl <- testEquality (typeRep @k) (typeRep @Nat) + , Just Refl <- sameNat (Proxy @n3) (Proxy @y) + , Just Refl <- testEquality (typeRep @r2) (typeRep @r3) -> + extendEnvR @ranked @shaped @r2 @y (AstVarName varId) u + DynamicShaped @r3 @sh3 u + | Just Refl <- testEquality (typeRep @k) (typeRep @[Nat]) + , Just Refl <- sameShape @sh3 @sh2 + , Just Refl <- 
testEquality (typeRep @r2) (typeRep @r3) -> + extendEnvS @ranked @shaped @r2 @sh2 (AstVarName varId) u + DynamicRankedDummy @r3 @sh3 _ _ + | Just Refl <- testEquality (typeRep @k) (typeRep @Nat) + , Just Refl <- sameShape @sh3 @sh2 + , Just Refl <- testEquality (typeRep @r2) (typeRep @r3) -> + let sh2 = listShapeToShape (Sh.shapeT @sh2) + in extendEnvR @ranked @shaped @r2 @y (AstVarName varId) + $ rzero sh2 + DynamicShapedDummy @r3 @sh3 _ _ + | Just Refl <- testEquality (typeRep @k) (typeRep @[Nat]) + , Just Refl <- sameShape @sh3 @sh2 + , Just Refl <- testEquality (typeRep @r2) (typeRep @r3) -> + extendEnvS @ranked @shaped @r2 @sh2 (AstVarName varId) 0 + _ -> error "interpretAst: impossible kind" env2 lw = foldr f env (zip vars (V.toList lw)) in rletDomainsIn lt0 lt (\lw -> interpretAst (env2 lw) v) AstFwd (vars, ast) parameters ds -> - let g :: forall f. ADReady f => Domains (DynamicOf f) -> f r n + let g :: forall f. ADReady f => Domains f -> f r n g = interpretLambdaDomains interpretAst EM.empty (vars, ast) -- interpretation in empty environment makes sense only -- if there are no free variables outside of those listed - odFromVar :: AstDynamicVarName -> DynamicExists OD.Array - odFromVar (AstDynamicVarName @_ @rD @shD _) = - DynamicExists $ OD.constant @rD (Sh.shapeT @shD) 0 parameters0 = V.fromList $ map odFromVar vars pars = interpretAstDynamic @ranked env <$> parameters d = interpretAstDynamic @ranked env <$> ds @@ -531,17 +535,15 @@ interpretAst !env = \case interpretAstDynamic :: forall ranked shaped s. 
(ADReadyBoth ranked shaped, AstSpan s) => AstEnv ranked shaped - -> DynamicExists (AstDynamic s) -> DynamicExists (DynamicOf ranked) + -> AstDynamic s -> DynamicTensor ranked {-# INLINE interpretAstDynamic #-} interpretAstDynamic !env = \case - DynamicExists @r (AstRToD AstIota) -> - DynamicExists $ ddummy @ranked @shaped @r - DynamicExists (AstRToD w) -> - DynamicExists $ dfromR $ interpretAstRuntimeSpecialized env w - DynamicExists @r (AstSToD AstIotaS) -> - DynamicExists $ ddummy @ranked @shaped @r - DynamicExists (AstSToD w) -> - DynamicExists $ dfromS $ interpretAstSRuntimeSpecialized env w + DynamicRanked w -> + DynamicRanked $ interpretAstRuntimeSpecialized env w + DynamicShaped w -> + DynamicShaped $ interpretAstSRuntimeSpecialized env w + DynamicRankedDummy p1 p2 -> DynamicRankedDummy p1 p2 + DynamicShapedDummy p1 p2 -> DynamicShapedDummy p1 p2 interpretAstDomains :: forall ranked shaped s. (ADReadyBoth ranked shaped, AstSpan s) @@ -560,42 +562,30 @@ interpretAstDomains !env = \case env2 w = extendEnvS var w env in sletInDomains t (\w -> interpretAstDomains (env2 w) v) AstRev @r @n (vars, ast) parameters -> - let g :: forall f. ADReady f => Domains (DynamicOf f) -> f r n + let g :: forall f. ADReady f => Domains f -> f r n g = interpretLambdaDomains interpretAst EM.empty (vars, ast) -- interpretation in empty environment; makes sense only -- if there are no free variables outside of those listed; -- the same below - odFromVar :: AstDynamicVarName -> DynamicExists OD.Array - odFromVar (AstDynamicVarName @_ @rD @shD _) = - DynamicExists $ OD.constant @rD (Sh.shapeT @shD) 0 parameters0 = V.fromList $ map odFromVar vars pars = interpretAstDynamic @ranked env <$> parameters in rrev @ranked g parameters0 pars AstRevDt @r @n (vars, ast) parameters dt -> - let g :: forall f. ADReady f => Domains (DynamicOf f) -> f r n + let g :: forall f. 
ADReady f => Domains f -> f r n g = interpretLambdaDomains interpretAst EM.empty (vars, ast) - odFromVar :: AstDynamicVarName -> DynamicExists OD.Array - odFromVar (AstDynamicVarName @_ @rD @shD _) = - DynamicExists $ OD.constant @rD (Sh.shapeT @shD) 0 parameters0 = V.fromList $ map odFromVar vars pars = interpretAstDynamic @ranked env <$> parameters d = interpretAst env dt in rrevDt @ranked g parameters0 pars d AstRevS @r @sh (vars, ast) parameters -> - let g :: forall f. ADReadyS f => Domains (DynamicOf f) -> f r sh + let g :: forall f. ADReadyS f => Domains (RankedOf f) -> f r sh g = interpretLambdaDomainsS interpretAstS EM.empty (vars, ast) - odFromVar :: AstDynamicVarName -> DynamicExists OD.Array - odFromVar (AstDynamicVarName @_ @rD @shD _) = - DynamicExists $ OD.constant @rD (Sh.shapeT @shD) 0 parameters0 = V.fromList $ map odFromVar vars pars = interpretAstDynamic @ranked env <$> parameters in srev @ranked g parameters0 pars AstRevDtS @r @sh (vars, ast) parameters dt -> - let g :: forall f. ADReadyS f => Domains (DynamicOf f) -> f r sh + let g :: forall f. 
ADReadyS f => Domains (RankedOf f) -> f r sh g = interpretLambdaDomainsS interpretAstS EM.empty (vars, ast) - odFromVar :: AstDynamicVarName -> DynamicExists OD.Array - odFromVar (AstDynamicVarName @_ @rD @shD _) = - DynamicExists $ OD.constant @rD (Sh.shapeT @shD) 0 parameters0 = V.fromList $ map odFromVar vars pars = interpretAstDynamic @ranked env <$> parameters d = interpretAstS env dt @@ -989,38 +979,43 @@ interpretAstS !env = \case t2 = interpretAstDualS env u' in sD t1 t2 AstLetDomainsInS vars l v -> - let odFromVar :: AstDynamicVarName -> DynamicExists OD.Array - odFromVar (AstDynamicVarName @_ @rD @shD _) = - DynamicExists $ OD.constant @rD (Sh.shapeT @shD) 0 - lt0 = V.fromList $ map odFromVar vars + let lt0 = V.fromList $ map odFromVar vars lt = interpretAstDomains env l -- We don't need to manually pick a specialization for the existential -- variable r2, because the operations do not depend on r2. - f :: (AstDynamicVarName, DynamicExists (DynamicOf ranked)) + f :: (AstDynamicVarName, DynamicTensor ranked) -> AstEnv ranked shaped -> AstEnv ranked shaped - f ( AstDynamicVarName @k @r2 @sh2 @y (AstVarName varId) - , DynamicExists @r3 d ) - | Just Refl <- testEquality (typeRep @r2) (typeRep @r3) = - case testEquality (typeRep @k) (typeRep @Nat) of - Just Refl -> - extendEnvR @ranked @shaped @r2 @y - (AstVarName varId) (rfromD d) - _ -> case testEquality (typeRep @k) (typeRep @[Nat]) of - Just Refl -> - extendEnvS @ranked @shaped @r2 @y - (AstVarName varId) (sfromD d) - _ -> error "interpretAstS: impossible kind" - f _ = error "interpretAstS: type mismatch" + f (AstDynamicVarName @k @r2 @sh2 @y (AstVarName varId), d) = case d of + DynamicRanked @r3 @n3 u + | Just Refl <- testEquality (typeRep @k) (typeRep @Nat) + , Just Refl <- sameNat (Proxy @n3) (Proxy @y) + , Just Refl <- testEquality (typeRep @r2) (typeRep @r3) -> + extendEnvR @ranked @shaped @r2 @y (AstVarName varId) u + DynamicShaped @r3 @sh3 u + | Just Refl <- testEquality (typeRep @k) (typeRep 
@[Nat]) + , Just Refl <- sameShape @sh3 @sh2 + , Just Refl <- testEquality (typeRep @r2) (typeRep @r3) -> + extendEnvS @ranked @shaped @r2 @sh2 (AstVarName varId) u + DynamicRankedDummy @r3 @sh3 _ _ + | Just Refl <- testEquality (typeRep @k) (typeRep @Nat) + , Just Refl <- sameShape @sh3 @sh2 + , Just Refl <- testEquality (typeRep @r2) (typeRep @r3) -> + let sh2 = listShapeToShape (Sh.shapeT @sh2) + in extendEnvR @ranked @shaped @r2 @y (AstVarName varId) + $ rzero sh2 + DynamicShapedDummy @r3 @sh3 _ _ + | Just Refl <- testEquality (typeRep @k) (typeRep @[Nat]) + , Just Refl <- sameShape @sh3 @sh2 + , Just Refl <- testEquality (typeRep @r2) (typeRep @r3) -> + extendEnvS @ranked @shaped @r2 @sh2 (AstVarName varId) 0 + _ -> error "interpretAstS: impossible kind" env2 lw = foldr f env (zip vars (V.toList lw)) in sletDomainsIn lt0 lt (\lw -> interpretAstS (env2 lw) v) AstFwdS (vars, ast) parameters ds -> - let g :: forall f. ADReadyS f => Domains (DynamicOf f) -> f r sh + let g :: forall f. ADReadyS f => Domains (RankedOf f) -> f r sh g = interpretLambdaDomainsS interpretAstS EM.empty (vars, ast) -- interpretation in empty environment makes sense only -- if there are no free variables outside of those listed - odFromVar :: AstDynamicVarName -> DynamicExists OD.Array - odFromVar (AstDynamicVarName @_ @rD @shD _) = - DynamicExists $ OD.constant @rD (Sh.shapeT @shD) 0 parameters0 = V.fromList $ map odFromVar vars pars = interpretAstDynamic @ranked env <$> parameters d = interpretAstDynamic @ranked env <$> ds diff --git a/src/HordeAd/Core/AstPrettyPrint.hs b/src/HordeAd/Core/AstPrettyPrint.hs index db1764d35..d9cc53728 100644 --- a/src/HordeAd/Core/AstPrettyPrint.hs +++ b/src/HordeAd/Core/AstPrettyPrint.hs @@ -408,30 +408,32 @@ showCollectionWith start end showx (x:xs) s = start ++ showx x (showl xs) showl [] = end ++ s showl (y:ys) = ", " ++ showx y (showl ys) -printAstDynamic :: (GoodScalar r, AstSpan s) - => PrintConfig -> Int -> AstDynamic s r -> ShowS +printAstDynamic 
:: AstSpan s + => PrintConfig -> Int -> AstDynamic s -> ShowS printAstDynamic cfg d = \case - AstRToD v -> printPrefixOp printAst cfg d "dfromR" [v] - AstSToD v -> printPrefixOp printAstS cfg d "dfromS" [v] + DynamicRanked v -> printPrefixOp printAst cfg d "dfromR" [v] + DynamicShaped v -> printPrefixOp printAstS cfg d "dfromS" [v] + DynamicRankedDummy{} -> showString "dfromR 0" + DynamicShapedDummy{} -> showString "dfromS 0" -printAstUnDynamic :: (GoodScalar r, AstSpan s) - => PrintConfig -> Int -> AstDynamic s r -> ShowS +printAstUnDynamic :: AstSpan s + => PrintConfig -> Int -> AstDynamic s -> ShowS printAstUnDynamic cfg d = \case - AstRToD v -> printAst cfg d v - AstSToD v -> printAstS cfg d v + DynamicRanked v -> printAst cfg d v + DynamicShaped v -> printAstS cfg d v + DynamicRankedDummy{} -> showString "0" + DynamicShapedDummy{} -> showString "0" printDomainsAst :: forall s. AstSpan s - => PrintConfig -> Domains (AstDynamic s) -> ShowS + => PrintConfig -> Domains (AstRanked s) -> ShowS printDomainsAst cfg l = if prettifyLosingSharing cfg then - showCollectionWith "(" ")" (\(DynamicExists e) -> - printAstUnDynamic cfg 0 e) (V.toList l) + showCollectionWith "(" ")" (\e -> printAstUnDynamic cfg 0 e) (V.toList l) else showParen True $ showString "fromList " - . showListWith (\(DynamicExists e) -> - printAstDynamic cfg 0 e) (V.toList l) + . showListWith (\e -> printAstDynamic cfg 0 e) (V.toList l) printAstDomains :: forall s. 
AstSpan s => PrintConfig -> Int -> AstDomains s -> ShowS diff --git a/src/HordeAd/Core/AstSimplify.hs b/src/HordeAd/Core/AstSimplify.hs index 403c30890..e72cea755 100644 --- a/src/HordeAd/Core/AstSimplify.hs +++ b/src/HordeAd/Core/AstSimplify.hs @@ -27,8 +27,7 @@ module HordeAd.Core.AstSimplify , astReplicate, astReplicateS, astAppend, astAppendS, astSlice, astSliceS , astReverse, astReverseS , astTranspose, astTransposeS, astReshape, astReshapeS - , astCast, astCastS, astFromIntegral, astFromIntegralS - , astSToR, astRToS, astFromDynamic, astFromDynamicS + , astCast, astCastS, astFromIntegral, astFromIntegralS, astSToR, astRToS , astPrimalPart, astPrimalPartS, astDualPart, astDualPartS , astLetDomainsIn, astLetDomainsInS, astLetInDomains, astLetInDomainsS -- * The simplifying bottom-up pass @@ -1463,34 +1462,6 @@ astRToS (Ast.AstSToR @sh1 v) = _ -> error "astRToS: different ranks in RToS(SToR)" astRToS v = Ast.AstRToS v -astFromDynamic :: forall n s r. KnownNat n - => AstDynamic s r -> AstRanked s r n -astFromDynamic (AstRToD Ast.AstIota) = error "astFromDynamic: dummy" -astFromDynamic (AstRToD (Ast.AstLetADShare l v)) = - Ast.AstLetADShare l $ astFromDynamic (AstRToD v) -astFromDynamic (AstRToD @n2 v) = - case sameNat (Proxy @n) (Proxy @n2) of - Just Refl -> v - _ -> error "astFromDynamic: different rank expected and uncovered" -astFromDynamic (AstSToD @sh2 v) = - case matchingRank @sh2 @n of - Just Refl -> astSToR v - _ -> error "astFromDynamic: different rank expected and uncovered" - -astFromDynamicS :: forall sh s r. 
Sh.Shape sh - => AstDynamic s r -> AstShaped s r sh -astFromDynamicS (AstSToD Ast.AstIotaS) = error "astFromDynamicS: dummy" -astFromDynamicS (AstSToD (Ast.AstLetADShareS l v)) = - Ast.AstLetADShareS l $ astFromDynamicS (AstSToD v) -astFromDynamicS (AstSToD @sh2 v) = - case sameShape @sh @sh2 of - Just Refl -> v - _ -> error "astFromDynamicS: different shape expected and uncovered" -astFromDynamicS (AstRToD @n2 v) = - case matchingRank @sh @n2 of - Just Refl -> astRToS v - _ -> error "astFromDynamicS: different rank expected and uncovered" - astPrimalPart :: (GoodScalar r, KnownNat n) => AstRanked FullSpan r n -> AstRanked PrimalSpan r n astPrimalPart t = case t of @@ -1660,11 +1631,11 @@ astLetDomainsIn vars l v = in Sh.withShapeP (shapeToList sh) $ \proxy -> case proxy of Proxy @sh | Just Refl <- matchingRank @sh @n -> case l of Ast.AstDomains l3 -> -- TODO: other cases: collect AstLetInDomains - let f :: (AstDynamicVarName, DynamicExists (AstDynamic s)) + let f :: (AstDynamicVarName, AstDynamic s) -> AstRanked s2 r n -> AstRanked s2 r n f ( AstDynamicVarName @_ @r3 @sh3 (AstVarName varId) - , DynamicExists @r4 (Ast.AstRToD @n4 v3) ) + , DynamicRanked @r4 @n4 v3 ) acc | Just Refl <- matchingRank @sh3 @n4 -- To impose such checks, we'd need to switch from OD tensors @@ -1674,12 +1645,29 @@ astLetDomainsIn vars l v = , Just Refl <- testEquality (typeRep @r3) (typeRep @r4) = Ast.AstLet (AstVarName varId) v3 acc f ( AstDynamicVarName @_ @r3 @sh3 (AstVarName varId) - , DynamicExists @r4 (Ast.AstSToD @sh4 v3) ) + , DynamicShaped @r4 @sh4 v3 ) acc | Just Refl <- sameShape @sh3 @sh4 , Just Refl <- testEquality (typeRep @r3) (typeRep @r4) = Ast.AstSToR @sh $ Ast.AstLetS (AstVarName varId) v3 $ Ast.AstRToS acc + f ( AstDynamicVarName @_ @r3 @sh3 (AstVarName varId) + , DynamicRankedDummy @r4 @sh4 _ _ ) + acc + | Just Refl <- sameShape @sh3 @sh4 + , Just Refl <- testEquality (typeRep @r3) (typeRep @r4) = + withListShape (Sh.shapeT @sh3) $ \(_ :: Shape m Int) -> + gcastWith 
(unsafeCoerce Refl :: m :~: Sh.Rank sh3) $ + Ast.AstLet @m + (AstVarName varId) (Ast.AstSToR @sh3 @s @r3 0) acc + f ( AstDynamicVarName @_ @r3 @sh3 (AstVarName varId) + , DynamicShapedDummy @r4 @sh4 _ _ ) + acc + | Just Refl <- sameShape @sh3 @sh4 + , Just Refl <- testEquality (typeRep @r3) (typeRep @r4) = + Ast.AstSToR + $ Ast.AstLetS @sh4 @sh @r4 @s @s2 (AstVarName varId) 0 + $ Ast.AstRToS acc f _ _ = error "astLetDomainsIn: corrupted arguments" in foldr f v (zip vars (V.toList l3)) _ -> Ast.AstLetDomainsIn vars l v @@ -1695,22 +1683,38 @@ astLetDomainsInS vars l v = Just (SomeNat @n _) -> gcastWith (unsafeCoerce Refl :: n :~: Sh.Rank sh) $ case l of Ast.AstDomains l3 -> -- TODO: other cases: collect AstLetInDomainsS - let f :: (AstDynamicVarName, DynamicExists (AstDynamic s)) + let f :: (AstDynamicVarName, AstDynamic s) -> AstShaped s2 r sh -> AstShaped s2 r sh f ( AstDynamicVarName @_ @r3 @sh3 (AstVarName varId) - , DynamicExists @r4 (Ast.AstRToD @n4 v3) ) + , DynamicRanked @r4 @n4 v3 ) acc | Just Refl <- matchingRank @sh3 @n4 , Just Refl <- testEquality (typeRep @r3) (typeRep @r4) = Ast.AstRToS @sh $ Ast.AstLet (AstVarName varId) v3 $ Ast.AstSToR acc f ( AstDynamicVarName @_ @r3 @sh3 (AstVarName varId) - , DynamicExists @r4 (Ast.AstSToD @sh4 v3) ) + , DynamicShaped @r4 @sh4 v3 ) acc | Just Refl <- sameShape @sh3 @sh4 , Just Refl <- testEquality (typeRep @r3) (typeRep @r4) = Ast.AstLetS (AstVarName varId) v3 acc + f ( AstDynamicVarName @_ @r3 @sh3 (AstVarName varId) + , DynamicRankedDummy @r4 @sh4 _ _ ) + acc + | Just Refl <- sameShape @sh3 @sh4 + , Just Refl <- testEquality (typeRep @r3) (typeRep @r4) = + withListShape (Sh.shapeT @sh3) $ \(_ :: Shape m Int) -> + gcastWith (unsafeCoerce Refl :: m :~: Sh.Rank sh3) $ + Ast.AstRToS @sh + $ Ast.AstLet @m (AstVarName varId) (Ast.AstSToR @sh3 @s @r3 0) + $ Ast.AstSToR acc + f ( AstDynamicVarName @_ @r3 @sh3 (AstVarName varId) + , DynamicShapedDummy @r4 @sh4 _ _ ) + acc + | Just Refl <- sameShape @sh3 @sh4 + , Just 
Refl <- testEquality (typeRep @r3) (typeRep @r4) = + Ast.AstLetS @sh4 @sh @r4 @s @s2 (AstVarName varId) 0 acc f _ _ = error "astLetDomainsInS: corrupted arguments" in foldr f v (zip vars (V.toList l3)) _ -> Ast.AstLetDomainsInS vars l v @@ -1831,11 +1835,13 @@ simplifyAst t = case t of simplifyAstDynamic :: AstSpan s - => DynamicExists (AstDynamic s) -> DynamicExists (AstDynamic s) -simplifyAstDynamic (DynamicExists (AstRToD u)) = - DynamicExists $ AstRToD $ simplifyAst u -simplifyAstDynamic (DynamicExists (AstSToD u)) = - DynamicExists $ AstSToD $ simplifyAstS u + => AstDynamic s -> AstDynamic s +simplifyAstDynamic (DynamicRanked u) = + DynamicRanked $ simplifyAst u +simplifyAstDynamic (DynamicShaped u) = + DynamicShaped $ simplifyAstS u +simplifyAstDynamic u@DynamicRankedDummy{} = u +simplifyAstDynamic u@DynamicShapedDummy{} = u simplifyAstDomains :: AstSpan s => AstDomains s -> AstDomains s @@ -2383,13 +2389,11 @@ substitute1Ast i var v1 = case v1 of Just $ astLetDomainsIn vars (fromMaybe l ml) (fromMaybe v mv) Ast.AstFwd f args ds -> -- No other free variables in v and var is not among vars. 
- let margs = V.map (\(DynamicExists d) -> - DynamicExists <$> substitute1AstDynamic i var d) args + let margs = V.map (substitute1AstDynamic i var) args marg = if V.any isJust margs then Just $ V.zipWith fromMaybe args margs else Nothing - mds = V.map (\(DynamicExists d) -> - DynamicExists <$> substitute1AstDynamic i var d) ds + mds = V.map (substitute1AstDynamic i var) ds md = if V.any isJust mds then Just $ V.zipWith fromMaybe ds mds else Nothing @@ -2417,12 +2421,14 @@ substitute1AstIndex i var ix = else Nothing substitute1AstDynamic - :: (GoodScalar r, GoodScalar r2, AstSpan s, AstSpan s2) - => SubstitutionPayload s2 r2 -> AstVarId -> AstDynamic s r - -> Maybe (AstDynamic s r) + :: (GoodScalar r2, AstSpan s, AstSpan s2) + => SubstitutionPayload s2 r2 -> AstVarId -> AstDynamic s + -> Maybe (AstDynamic s) substitute1AstDynamic i var = \case - Ast.AstRToD t -> Ast.AstRToD <$> substitute1Ast i var t - Ast.AstSToD t -> Ast.AstSToD <$> substitute1AstS i var t + DynamicRanked t -> DynamicRanked <$> substitute1Ast i var t + DynamicShaped t -> DynamicShaped <$> substitute1AstS i var t + DynamicRankedDummy{} -> Nothing + DynamicShapedDummy{} -> Nothing substitute1AstDomains :: (GoodScalar r2, AstSpan s, AstSpan s2) @@ -2430,8 +2436,7 @@ substitute1AstDomains -> Maybe (AstDomains s) substitute1AstDomains i var = \case Ast.AstDomains args -> - let margs = V.map (\(DynamicExists d) -> - DynamicExists <$> substitute1AstDynamic i var d) args + let margs = V.map (substitute1AstDynamic i var) args in if V.any isJust margs then Just $ Ast.AstDomains $ V.zipWith fromMaybe args margs else Nothing @@ -2446,15 +2451,13 @@ substitute1AstDomains i var = \case Ast.AstRev (vars, v) args -> -- No other free variables in v and var is not among vars. 
Ast.AstRev (vars, v) <$> - let margs = V.map (\(DynamicExists d) -> - DynamicExists <$> substitute1AstDynamic i var d) args + let margs = V.map (substitute1AstDynamic i var) args in if V.any isJust margs then Just $ V.zipWith fromMaybe args margs else Nothing Ast.AstRevDt (vars, v) args dt -> -- No other free variables in v and var is not among vars. - let margs = V.map (\(DynamicExists d) -> - DynamicExists <$> substitute1AstDynamic i var d) args + let margs = V.map (substitute1AstDynamic i var) args marg = if V.any isJust margs then Just $ V.zipWith fromMaybe args margs else Nothing @@ -2465,15 +2468,13 @@ substitute1AstDomains i var = \case Ast.AstRevS (vars, v) args -> -- No other free variables in v and var is not among vars. Ast.AstRevS (vars, v) <$> - let margs = V.map (\(DynamicExists d) -> - DynamicExists <$> substitute1AstDynamic i var d) args + let margs = V.map (substitute1AstDynamic i var) args in if V.any isJust margs then Just $ V.zipWith fromMaybe args margs else Nothing Ast.AstRevDtS (vars, v) args dt -> -- No other free variables in v and var is not among vars. - let margs = V.map (\(DynamicExists d) -> - DynamicExists <$> substitute1AstDynamic i var d) args + let margs = V.map (substitute1AstDynamic i var) args marg = if V.any isJust margs then Just $ V.zipWith fromMaybe args margs else Nothing @@ -2636,13 +2637,11 @@ substitute1AstS i var = \case Just $ astLetDomainsInS vars (fromMaybe l ml) (fromMaybe v mv) Ast.AstFwdS (vars, v) args ds -> -- No other free variables in v and var is not among vars. 
- let margs = V.map (\(DynamicExists d) -> - DynamicExists <$> substitute1AstDynamic i var d) args + let margs = V.map (substitute1AstDynamic i var) args marg = if V.any isJust margs then Just $ V.zipWith fromMaybe args margs else Nothing - mds = V.map (\(DynamicExists d) -> - DynamicExists <$> substitute1AstDynamic i var d) ds + mds = V.map (substitute1AstDynamic i var) ds md = if V.any isJust mds then Just $ V.zipWith fromMaybe ds mds else Nothing diff --git a/src/HordeAd/Core/AstTools.hs b/src/HordeAd/Core/AstTools.hs index 5dea915b3..54bfc8f43 100644 --- a/src/HordeAd/Core/AstTools.hs +++ b/src/HordeAd/Core/AstTools.hs @@ -27,7 +27,9 @@ import GHC.TypeLits (KnownNat, sameNat, type (+)) import Unsafe.Coerce (unsafeCoerce) import HordeAd.Core.Ast +import HordeAd.Core.TensorClass import HordeAd.Core.Types +import HordeAd.Internal.OrthotopeOrphanInstances (matchingRank) import HordeAd.Util.SizedIndex -- * Shape calculation @@ -148,7 +150,7 @@ varInAst var = \case AstD u u' -> varInAst var u || varInAst var u' AstLetDomainsIn _vars l v -> varInAstDomains var l || varInAst var v AstFwd _f l ds -> -- _f has no non-bound variables - let f (DynamicExists d) = varInAstDynamic var d + let f = varInAstDynamic var in any f l || any f ds AstFold _f x0 as -> varInAst var x0 || varInAst var as AstFoldDer _f _df _rf x0 as -> varInAst var x0 || varInAst var as @@ -156,28 +158,27 @@ varInAst var = \case varInAstDomains :: AstSpan s => AstVarId -> AstDomains s -> Bool varInAstDomains var = \case - AstDomains l -> let f (DynamicExists d) = varInAstDynamic var d - in any f l + AstDomains l -> any (varInAstDynamic var) l AstLetInDomains _var2 u v -> varInAst var u || varInAstDomains var v AstLetInDomainsS _var2 u v -> varInAstS var u || varInAstDomains var v AstRev _f l -> -- _f has no non-bound variables - let f (DynamicExists d) = varInAstDynamic var d - in any f l + any (varInAstDynamic var) l AstRevDt _f l dt -> -- _f has no non-bound variables - let f (DynamicExists d) = 
varInAstDynamic var d + let f = varInAstDynamic var in any f l || varInAst var dt AstRevS _f l -> -- _f has no non-bound variables - let f (DynamicExists d) = varInAstDynamic var d - in any f l + any (varInAstDynamic var) l AstRevDtS _f l dt -> -- _f has no non-bound variables - let f (DynamicExists d) = varInAstDynamic var d + let f = varInAstDynamic var in any f l || varInAstS var dt varInAstDynamic :: AstSpan s - => AstVarId -> AstDynamic s r -> Bool + => AstVarId -> AstDynamic s -> Bool varInAstDynamic var = \case - AstRToD t -> varInAst var t - AstSToD t -> varInAstS var t + DynamicRanked t -> varInAst var t + DynamicShaped t -> varInAstS var t + DynamicRankedDummy{} -> False + DynamicShapedDummy{} -> False varInAstBool :: AstVarId -> AstBool -> Bool varInAstBool var = \case @@ -231,7 +232,7 @@ varInAstS var = \case AstDS u u' -> varInAstS var u || varInAstS var u' AstLetDomainsInS _vars l v -> varInAstDomains var l || varInAstS var v AstFwdS _f l ds -> -- _f has no non-bound variables - let f (DynamicExists d) = varInAstDynamic var d + let f = varInAstDynamic var in any f l || any f ds AstFoldS _f x0 as -> varInAstS var x0 || varInAstS var as AstFoldDerS _f _df _rf x0 as -> varInAstS var x0 || varInAstS var as @@ -289,44 +290,68 @@ astIsSmallS relaxed = \case -- * Odds and ends -bindsToLet :: forall n s r. (KnownNat n, GoodScalar r) +bindsToLet :: forall n s r. 
(AstSpan s, KnownNat n, GoodScalar r) => AstRanked s r n -> AstBindings -> AstRanked s r n {-# INLINE bindsToLet #-} -- help list fusion bindsToLet = foldl' bindToLet where bindToLet :: AstRanked s r n - -> (AstVarId, DynamicExists (AstDynamic PrimalSpan)) + -> (AstVarId, AstDynamic PrimalSpan) -> AstRanked s r n - bindToLet !u (var, DynamicExists d) = case d of - AstRToD w -> AstLet (AstVarName var) w u - AstSToD w -> - let shList = shapeToList $ shapeAst u - in if valueOf @n == length shList - then Sh.withShapeP shList $ \(_proxy :: Proxy sh) -> - gcastWith (unsafeCoerce Refl :: Sh.Rank sh :~: n) $ - AstSToR @sh $ AstLetS (AstVarName var) w (AstRToS u) - else error "bindsToLet: rank mismatch" + bindToLet !u (var, d) = + let convertShaped :: (GoodScalar r2, Sh.Shape sh2) + => AstShaped PrimalSpan r2 sh2 -> AstRanked s r n + convertShaped t = + Sh.withShapeP (shapeToList $ shapeAst u) $ \proxy -> case proxy of + Proxy @sh | Just Refl <- matchingRank @sh @n -> + AstSToR @sh $ AstLetS (AstVarName var) t (AstRToS u) + _ -> error "bindToLet: wrong rank" + in case d of + DynamicRanked w -> AstLet (AstVarName var) w u + DynamicShaped w -> convertShaped w + DynamicRankedDummy @r2 @sh2 _ _ -> + withListShape (Sh.shapeT @sh2) $ \(_ :: Shape n3 Int) -> + gcastWith (unsafeCoerce Refl :: n3 :~: Sh.Rank sh2) $ + AstLet @n3 @n @r2 @s (AstVarName var) (AstSToR @sh2 @s @r2 0) u + DynamicShapedDummy @r2 @sh2 _ _ -> convertShaped @r2 @sh2 0 -bindsToLetS :: forall sh s r. Sh.Shape sh +bindsToLetS :: forall sh s r. 
(AstSpan s, Sh.Shape sh) => AstShaped s r sh -> AstBindings -> AstShaped s r sh {-# INLINE bindsToLetS #-} -- help list fusion bindsToLetS = foldl' bindToLetS where bindToLetS :: AstShaped s r sh - -> (AstVarId, DynamicExists (AstDynamic PrimalSpan)) + -> (AstVarId, AstDynamic PrimalSpan) -> AstShaped s r sh - bindToLetS !u (var, DynamicExists d) = case d of - AstRToD w -> - withListShape (Sh.shapeT @sh) $ \ (_ :: Shape n Int) -> - gcastWith (unsafeCoerce Refl :: Sh.Rank sh :~: n) - $ AstRToS $ AstLet (AstVarName var) w (AstSToR u) - AstSToD w -> AstLetS (AstVarName var) w u + bindToLetS !u (var, d) = case d of + DynamicRanked w -> + withListShape (Sh.shapeT @sh) $ \sh -> case sh of + (_ :: Shape n Int) | Just Refl <- matchingRank @sh @n -> + AstRToS $ AstLet (AstVarName var) w (AstSToR u) + _ -> error "bindToLetS: wrong rank" + DynamicShaped w -> AstLetS (AstVarName var) w u + DynamicRankedDummy @r2 @sh2 _ _ -> + withListShape (Sh.shapeT @sh2) $ \(_ :: Shape n3 Int) -> + gcastWith (unsafeCoerce Refl :: n3 :~: Sh.Rank sh2) $ + withListShape (Sh.shapeT @sh) $ \(_ :: Shape m Int) -> + gcastWith (unsafeCoerce Refl :: m :~: Sh.Rank sh) $ + AstRToS $ AstLet @n3 @m @r2 @s + (AstVarName var) (AstSToR @sh2 @s @r2 0) (AstSToR u) + DynamicShapedDummy @r2 @sh2 _ _ -> + AstLetS @sh2 @sh @r2 @s (AstVarName var) 0 u bindsToDomainsLet - :: AstDomains s -> AstBindings -> AstDomains s + :: forall s. 
AstSpan s + => AstDomains s -> AstBindings -> AstDomains s {-# INLINE bindsToDomainsLet #-} -- help list fusion bindsToDomainsLet = foldl' bindToDomainsLet where - bindToDomainsLet !u (var, DynamicExists d) = case d of - AstRToD w -> AstLetInDomains (AstVarName var) w u - AstSToD w -> AstLetInDomainsS (AstVarName var) w u + bindToDomainsLet !u (var, d) = case d of + DynamicRanked w -> AstLetInDomains (AstVarName var) w u + DynamicShaped w -> AstLetInDomainsS (AstVarName var) w u + DynamicRankedDummy @r2 @sh2 _ _ -> + withListShape (Sh.shapeT @sh2) $ \(_ :: Shape n Int) -> + gcastWith (unsafeCoerce Refl :: n :~: Sh.Rank sh2) $ + AstLetInDomains @n @r2 @s (AstVarName var) (AstSToR @sh2 @s @r2 0) u + DynamicShapedDummy @r2 @sh2 _ _ -> + AstLetInDomainsS @sh2 @r2 @s (AstVarName var) 0 u diff --git a/src/HordeAd/Core/AstVectorize.hs b/src/HordeAd/Core/AstVectorize.hs index 1bcbf2ac8..11876ba87 100644 --- a/src/HordeAd/Core/AstVectorize.hs +++ b/src/HordeAd/Core/AstVectorize.hs @@ -321,17 +321,20 @@ build1V k (var, v00) = (astTr $ build1VOccurenceUnknown k (var, as)) build1VOccurenceUnknownDynamic - :: AstSpan s - => Int -> (IntVarName, DynamicExists (AstDynamic s)) - -> DynamicExists (AstDynamic s) + :: forall s. 
AstSpan s + => Int -> (IntVarName, AstDynamic s) -> AstDynamic s build1VOccurenceUnknownDynamic k (var, d) = case d of - DynamicExists (AstRToD u) -> - DynamicExists $ AstRToD $ build1VOccurenceUnknown k (var, u) - DynamicExists (AstSToD u) -> case someNatVal $ toInteger k of + DynamicRanked u -> DynamicRanked $ build1VOccurenceUnknown k (var, u) + DynamicShaped u -> case someNatVal $ toInteger k of Just (SomeNat @k _proxy) -> - DynamicExists $ AstSToD $ build1VOccurenceUnknownS @k (var, u) + DynamicShaped $ build1VOccurenceUnknownS @k (var, u) Nothing -> error "build1VOccurenceUnknownDynamic: impossible someNatVal error" + DynamicRankedDummy @r @sh _ _ -> + withListShape (Sh.shapeT @sh) $ \(_ :: Shape n3 Int) -> + gcastWith (unsafeCoerce Refl :: n3 :~: Sh.Rank sh) $ + DynamicRanked @r (Ast.AstSToR @sh @s @r 0) + DynamicShapedDummy @r @sh _ _ -> DynamicShaped @r @sh 0 build1VOccurenceUnknownDomains :: forall s. AstSpan s diff --git a/src/HordeAd/Core/Delta.hs b/src/HordeAd/Core/Delta.hs index 0d73e6c95..935897bd8 100644 --- a/src/HordeAd/Core/Delta.hs +++ b/src/HordeAd/Core/Delta.hs @@ -38,7 +38,7 @@ -- to understand. module HordeAd.Core.Delta ( -- * Abstract syntax trees of the delta expressions - DeltaR (..), DeltaS (..), DeltaD (..) + DeltaR (..), DeltaS (..) 
, -- * Delta expression identifiers NodeId (..), InputId, toInputId , -- * Evaluation of the delta expressions @@ -50,12 +50,10 @@ import Prelude import Control.Exception.Assert.Sugar import Control.Monad (liftM2) import Control.Monad.ST.Strict (ST, runST) -import qualified Data.Array.DynamicS as OD import Data.Array.Internal (valueOf) import qualified Data.Array.RankedS as OR import qualified Data.Array.Shape as Sh import qualified Data.Array.ShapedS as OS -import Data.Bifunctor.Clown import Data.Bifunctor.Flip import qualified Data.EnumMap.Strict as EM import Data.Int (Int64) @@ -77,7 +75,7 @@ import Unsafe.Coerce (unsafeCoerce) import HordeAd.Core.TensorClass import HordeAd.Core.Types import HordeAd.Internal.OrthotopeOrphanInstances - (matchingRank, sameShape, trustMeThisIsAPermutation) + (sameShape, trustMeThisIsAPermutation) import HordeAd.Util.ShapedList (ShapedList (..)) import HordeAd.Util.SizedIndex @@ -246,10 +244,6 @@ data DeltaR :: RankedTensorKind -> ShapedTensorKind -> RankedTensorKind where -- in the first argument, should be strict in the accumulator. CastR :: (GoodScalar r1, RealFrac r1, RealFrac r2) => DeltaR ranked shaped r1 n -> DeltaR ranked shaped r2 n - - DToR :: forall n r ranked shaped. - DeltaD (Clown (DynamicOf ranked)) ranked shaped r '() - -> DeltaR ranked shaped r n SToR :: forall sh r ranked shaped. Sh.Shape sh => DeltaS ranked shaped r sh -> DeltaR ranked shaped r (Sh.Rank sh) @@ -373,10 +367,6 @@ data DeltaS :: RankedTensorKind -> ShapedTensorKind -> ShapedTensorKind where -> DeltaS ranked shaped rn sh CastS :: (GoodScalar r1, RealFrac r1, RealFrac r2) => DeltaS ranked shaped r1 sh -> DeltaS ranked shaped r2 sh - - DToS :: forall sh r ranked shaped. - DeltaD (Clown (DynamicOf ranked)) ranked shaped r '() - -> DeltaS ranked shaped r sh RToS :: forall sh r ranked shaped. 
KnownNat (Sh.Rank sh) => DeltaR ranked shaped r (Sh.Rank sh) -> DeltaS ranked shaped r sh @@ -390,25 +380,6 @@ deriving instance ( Sh.Shape sh0, GoodScalar r0 , Show (IntOf shaped) ) => Show (DeltaS ranked shaped r0 sh0) -type role DeltaD nominal nominal nominal nominal nominal -data DeltaD :: TensorKind () -> RankedTensorKind -> ShapedTensorKind - -> TensorKind () where - RToD :: forall n r ranked shaped. KnownNat n - => DeltaR ranked shaped r n - -> DeltaD (Clown (DynamicOf ranked)) ranked shaped r '() - SToD :: forall sh r ranked shaped. Sh.Shape sh - => DeltaS ranked shaped r sh - -> DeltaD (Clown (DynamicOf ranked)) ranked shaped r '() - -deriving instance ( GoodScalar r0 - , (forall nn4 rr. (KnownNat nn4, GoodScalar rr) - => Show (ranked rr nn4)) - , (forall sh r. (Sh.Shape sh, GoodScalar r) - => Show (shaped r sh)) - , Show (IntOf ranked) - , Show (IntOf shaped) ) - => Show (DeltaD clownDynamic ranked shaped r0 '()) - shapeDelta :: forall ranked shaped r n. (GoodScalar r, KnownNat n, RankedTensor ranked) => DeltaR ranked shaped r n -> ShapeInt n @@ -446,11 +417,6 @@ shapeDelta = \case GatherR sh _ _ -> sh FoldR _f x0 _as _df _rf _x0' _as' -> rshape x0 CastR d -> shapeDelta d - DToR (RToD @n2 d) -> - case sameNat (Proxy @n) (Proxy @n2) of - Just Refl -> shapeDelta d - _ -> error "shapeDelta: different ranks in DToR(RToD)" - DToR (SToD @sh _) -> listShapeToShape $ Sh.shapeT @sh SToR @sh _ -> listShapeToShape $ Sh.shapeT @sh lengthDelta :: forall ranked shaped r n. 
@@ -460,6 +426,10 @@ lengthDelta d = case shapeDelta d of ZS -> error "lengthDelta: impossible pattern needlessly required" k :$ _ -> k +type instance RankedOf (DeltaS ranked shaped) = DeltaR ranked shaped + +type instance ShapedOf (DeltaR ranked shaped) = DeltaS ranked shaped + -- * Delta expression identifiers @@ -488,60 +458,16 @@ class DualPart (f :: TensorKind k) where reverseDervative :: (HasSingletonDict y, GoodScalar r) => Bool -> DomainsOD -> f r y -> Maybe (f r y) -> Dual f r y - -> (AstBindingsD (DynamicOf f), Domains (DynamicOf f)) + -> (AstBindingsD (RankedOf f), Domains (RankedOf f)) forwardDerivative :: (HasSingletonDict y, GoodScalar r) - => Int -> Dual f r y -> Domains (DynamicOf f) - -> (AstBindingsD (DynamicOf f), f r y) - --- clownDynamic is, e.g., Clown OD.Array or Clown (AstDynamic s) -instance ( clownDynamic ~ Clown (DynamicOf (RankedOf clownDynamic)) - , clownDynamic ~ Clown (DynamicOf clownDynamic) - , RankedTensor (RankedOf clownDynamic) - , ShapedTensor (ShapedOf clownDynamic) - , ConvertTensor (RankedOf clownDynamic) (ShapedOf clownDynamic) ) - => DualPart @() clownDynamic where - type Dual clownDynamic = - DeltaD clownDynamic (RankedOf clownDynamic) (ShapedOf clownDynamic) - reverseDervative = gradientDtD - forwardDerivative = derivativeFromDeltaD - -gradientDtD - :: forall clownDynamic ranked shaped r (y :: ()). 
- ( clownDynamic ~ Clown (DynamicOf ranked) - , GoodScalar r - , RankedTensor ranked, ShapedTensor shaped, ConvertTensor ranked shaped ) - => Bool -> DomainsOD - -> clownDynamic r y - -> Maybe (clownDynamic r y) - -> DeltaD clownDynamic ranked shaped r y - -> ( AstBindingsD (DynamicOf ranked) - , Domains (DynamicOf ranked) ) -gradientDtD useDummies !parameters0 !value !mdt !deltaTopLevel = - withListShape (dshape @ranked (runClown value)) $ \sh -> - let dt = maybe (dfromR @ranked $ rreplicate0N sh 1) runClown mdt - deltaDt = DeltaDtD dt deltaTopLevel - in gradientFromDelta useDummies parameters0 deltaDt - -derivativeFromDeltaD - :: forall clownDynamic ranked shaped r (y :: ()). - ( clownDynamic ~ Clown (DynamicOf ranked) - , GoodScalar r - , RankedTensor ranked, ShapedTensor shaped, ConvertTensor ranked shaped ) - => Int - -> DeltaD clownDynamic ranked shaped r y - -> Domains (DynamicOf ranked) - -> ( AstBindingsD (DynamicOf ranked) - , clownDynamic r y ) -derivativeFromDeltaD !dim !deltaTopLevel !ds = - case runST $ buildDerivative dim (DeltaDtD (dfromR @ranked @shaped @r @0 0) - deltaTopLevel) ds of - (l, DeltaDtD res _) -> (l, Clown res) - (_, DeltaDtR{}) -> error "derivativeFromDeltaD" - (_, DeltaDtS{}) -> error "derivativeFromDeltaD" + => Int -> Dual f r y -> Domains (RankedOf f) + -> (AstBindingsD (RankedOf f), f r y) instance ( RankedTensor ranked, ShapedTensor (ShapedOf ranked) - , ConvertTensor ranked (ShapedOf ranked) ) + , ConvertTensor ranked (ShapedOf ranked) + , ShapedOf ranked ~ shaped, RankedOf shaped ~ ranked + , RankedOf ranked ~ ranked ) => DualPart @Nat ranked where type Dual ranked = DeltaR ranked (ShapedOf ranked) reverseDervative = gradientDtR @@ -549,11 +475,11 @@ instance ( RankedTensor ranked, ShapedTensor (ShapedOf ranked) gradientDtR :: ( KnownNat y, GoodScalar r - , RankedTensor ranked, ShapedTensor shaped, ConvertTensor ranked shaped ) + , RankedTensor ranked, ShapedTensor shaped, ConvertTensor ranked shaped + , ShapedOf ranked ~ 
shaped, RankedOf shaped ~ ranked ) => Bool -> DomainsOD -> ranked r y -> Maybe (ranked r y) -> DeltaR ranked shaped r y - -> ( AstBindingsD (DynamicOf ranked) - , Domains (DynamicOf ranked) ) + -> (AstBindingsD ranked, Domains ranked) gradientDtR useDummies !parameters0 value !mdt !deltaTopLevel = let dt = fromMaybe (rreplicate0N (rshape value) 1) mdt deltaDt = DeltaDtR dt deltaTopLevel @@ -562,25 +488,24 @@ gradientDtR useDummies !parameters0 value !mdt !deltaTopLevel = :: KnownNat y => Bool -> DomainsOD -> Flip OR.Array Double y -> Maybe (Flip OR.Array Double y) -> DeltaR (Flip OR.Array) (Flip OS.Array) Double y - -> ( AstBindingsD (DynamicOf (Flip OR.Array)) - , Domains (DynamicOf (Flip OR.Array)) ) #-} + -> (AstBindingsD (Flip OR.Array), Domains (Flip OR.Array) ) #-} {- TODO: this causes a cyclic dependency: {-# SPECIALIZE gradientDtR :: KnownNat y => Bool -> DomainsOD -> AstRanked PrimalSpan Double y -> Maybe (AstRanked PrimalSpan Double y) -> DeltaR (AstRanked PrimalSpan) (AstShaped PrimalSpan) Double y - -> ( AstBindingsD (DynamicOf (AstRanked PrimalSpan)) - , Domains (DynamicOf (AstRanked PrimalSpan)) ) #-} + -> ( AstBindingsD (DynamicTensor (AstRanked PrimalSpan)) + , Domains (DynamicTensor (AstRanked PrimalSpan)) ) #-} -} derivativeFromDeltaR :: forall ranked shaped r n. 
- ( KnownNat n, GoodScalar r - , RankedTensor ranked, ShapedTensor shaped, ConvertTensor ranked shaped ) - => Int -> DeltaR ranked shaped r n -> Domains (DynamicOf ranked) - -> ( AstBindingsD (DynamicOf ranked) - , ranked r n ) + ( KnownNat n, GoodScalar r + , RankedTensor ranked, ShapedTensor shaped, ConvertTensor ranked shaped + , ShapedOf ranked ~ shaped, RankedOf shaped ~ ranked ) + => Int -> DeltaR ranked shaped r n -> Domains ranked + -> (AstBindingsD ranked, ranked r n) derivativeFromDeltaR dim deltaTopLevel ds = let dummyZero = rzero $ listShapeToShape $ replicate (valueOf @n) 1 in case runST $ buildDerivative dim (DeltaDtR dummyZero deltaTopLevel) ds of @@ -588,10 +513,10 @@ derivativeFromDeltaR dim deltaTopLevel ds = Just Refl -> (l, res) _ -> error "derivativeFromDeltaR" (_, DeltaDtS{}) -> error "derivativeFromDeltaR" - (_, DeltaDtD{}) -> error "derivativeFromDeltaR" instance ( RankedTensor (RankedOf shaped), ShapedTensor shaped - , ConvertTensor (RankedOf shaped) shaped ) + , ConvertTensor (RankedOf shaped) shaped + , ShapedOf ranked ~ shaped, RankedOf shaped ~ ranked ) => DualPart @[Nat] shaped where type Dual shaped = DeltaS (RankedOf shaped) shaped reverseDervative useDummies parameters0 _ = gradientDtS useDummies parameters0 @@ -600,11 +525,11 @@ instance ( RankedTensor (RankedOf shaped), ShapedTensor shaped gradientDtS :: forall ranked shaped r y. 
( Sh.Shape y, GoodScalar r - , RankedTensor ranked, ShapedTensor shaped, ConvertTensor ranked shaped ) + , RankedTensor ranked, ShapedTensor shaped, ConvertTensor ranked shaped + , ShapedOf ranked ~ shaped, RankedOf shaped ~ ranked ) => Bool -> DomainsOD -> Maybe (shaped r y) -> DeltaS ranked shaped r y - -> ( AstBindingsD (DynamicOf shaped) - , Domains (DynamicOf shaped) ) + -> (AstBindingsD ranked, Domains ranked) gradientDtS useDummies !parameters0 !mdt !deltaTopLevel = let dt = fromMaybe 1 mdt deltaDt = DeltaDtS dt deltaTopLevel @@ -613,37 +538,35 @@ gradientDtS useDummies !parameters0 !mdt !deltaTopLevel = :: Sh.Shape y => Bool -> DomainsOD -> Maybe (Flip OS.Array Double y) -> DeltaS (Flip OR.Array) (Flip OS.Array) Double y - -> ( AstBindingsD (DynamicOf (Flip OS.Array)) - , Domains (DynamicOf (Flip OS.Array)) ) #-} + -> (AstBindingsD (Flip OR.Array), Domains (Flip OR.Array)) #-} {- TODO: this causes a cyclic dependency: {-# SPECIALIZE gradientDtS :: Sh.Shape y => Bool -> DomainsOD -> Maybe (AstShaped PrimalSpan Double y) -> DeltaS (AstRanked PrimalSpan) (AstShaped PrimalSpan) Double y - -> ( AstBindingsD (DynamicOf (AstShaped PrimalSpan)) - , Domains (DynamicOf (AstShaped PrimalSpan)) ) #-} + -> ( AstBindingsD (DynamicTensor (AstShaped PrimalSpan)) + , Domains (DynamicTensor (AstShaped PrimalSpan)) ) #-} -} derivativeFromDeltaS :: forall ranked shaped r sh. 
( Sh.Shape sh, GoodScalar r - , RankedTensor ranked, ShapedTensor shaped, ConvertTensor ranked shaped ) - => Int -> DeltaS ranked shaped r sh -> Domains (DynamicOf shaped) - -> ( AstBindingsD (DynamicOf shaped) - , shaped r sh ) + , RankedTensor ranked, ShapedTensor shaped, ConvertTensor ranked shaped + , ShapedOf ranked ~ shaped, RankedOf shaped ~ ranked ) + => Int -> DeltaS ranked shaped r sh -> Domains ranked + -> (AstBindingsD ranked, shaped r sh) derivativeFromDeltaS !dim !deltaTopLevel !ds = case runST $ buildDerivative dim (DeltaDtS 0 deltaTopLevel) ds of (l, DeltaDtS @_ @sh2 res _) -> case sameShape @sh @sh2 of Just Refl -> (l, res) _ -> error "derivativeFromDeltaS" (_, DeltaDtR{}) -> error "derivativeFromDeltaS" - (_, DeltaDtD{}) -> error "derivativeFromDeltaS" -- | The main input of the differentiation functions: -- the delta expression to be differentiated and the dt perturbation -- (small change) of the objective function codomain, for which we compute -- the gradient. -type role DeltaDt nominal nominal nominal -- nominal due to DynamicOf family +type role DeltaDt nominal nominal nominal data DeltaDt :: RankedTensorKind -> ShapedTensorKind -> Type -> Type where DeltaDtR :: forall r n ranked shaped. KnownNat n => ranked r n -> DeltaR ranked shaped r n @@ -651,10 +574,6 @@ data DeltaDt :: RankedTensorKind -> ShapedTensorKind -> Type -> Type where DeltaDtS :: forall r sh ranked shaped. Sh.Shape sh => shaped r sh -> DeltaS ranked shaped r sh -> DeltaDt ranked shaped r - DeltaDtD :: forall r (y :: ()) ranked shaped. - DynamicOf ranked r - -> DeltaD (Clown (DynamicOf ranked)) ranked shaped r y - -> DeltaDt ranked shaped r -- * Reverse pass, transpose/evaluation of the delta expressions @@ -670,18 +589,17 @@ data DeltaDt :: RankedTensorKind -> ShapedTensorKind -> Type -> Type where -- 2. 
key `member` dMap == nMap!key is DeltaBindingR type role EvalState nominal nominal data EvalState ranked shaped = EvalState - { iMap :: EM.EnumMap (InputId ranked) - (DynamicExists (DynamicOf ranked)) + { iMap :: EM.EnumMap (InputId ranked) (DynamicTensor ranked) -- ^ eventually, cotangents of objective function inputs -- (eventually copied to the vector representing the gradient -- of the objective function); -- the identifiers need to be contiguous and start at 0 - , dMap :: EM.EnumMap (NodeId ranked) (DynamicExists (DynamicOf ranked)) + , dMap :: EM.EnumMap (NodeId ranked) (DynamicTensor ranked) -- ^ eventually, cotangents of non-input subterms indexed -- by their node identifiers , nMap :: EM.EnumMap (NodeId ranked) (DeltaBinding ranked shaped) -- ^ nodes left to be evaluated - , astBindings :: AstBindingsD (DynamicOf ranked) + , astBindings :: AstBindingsD ranked } -- | Nodes left to be evaluated. @@ -751,11 +669,11 @@ data DeltaBinding :: RankedTensorKind -> ShapedTensorKind -> Type where -- value (usually set to @1@) is given in the @DeltaDt ranked r@ parameter. gradientFromDelta :: forall ranked shaped r. - ( GoodScalar r, RankedTensor ranked, ShapedTensor shaped - , ConvertTensor ranked shaped ) + ( GoodScalar r, RankedTensor ranked, ShapedTensor shaped + , ConvertTensor ranked shaped + , ShapedOf ranked ~ shaped, RankedOf shaped ~ ranked ) => Bool -> DomainsOD -> DeltaDt ranked shaped r - -> ( AstBindingsD (DynamicOf ranked) - , Domains (DynamicOf ranked) ) + -> (AstBindingsD ranked, Domains ranked) gradientFromDelta useDummies !parameters0 !deltaDt = -- Create finite maps that hold values associated with inputs -- and with (possibly shared) term tree nodes. @@ -768,18 +686,18 @@ gradientFromDelta useDummies !parameters0 !deltaDt = -- but a shape is not preserved in a dummy, so it's not shape-correct. 
let s0 = let iMap = - if useDummies - then let f (DynamicExists @re _) = - DynamicExists $ ddummy @ranked @shaped @re - in EM.fromDistinctAscList - $ zip [toInputId 0 ..] $ map f $ V.toList parameters0 - else let f :: DynamicExists OD.Array - -> DynamicExists (DynamicOf ranked) - f (DynamicExists @re d) = - withListShape (dshape @(Flip OR.Array) d) $ \sh -> - DynamicExists $ dfromR $ rzero @ranked @re sh - in EM.fromDistinctAscList - $ zip [toInputId 0 ..] $ map f $ V.toList parameters0 + -- The first two cases are permitted for when the normal main + -- parameters are used as parameters0. + let f (DynamicRanked @r2 @n2 t) = + let sh = rshape @(Flip OR.Array) t + in Sh.withShapeP (shapeToList sh) $ \(Proxy @sh2) -> + DynamicRankedDummy @r2 @sh2 Proxy Proxy + f (DynamicShaped @r2 @sh2 _) = + DynamicShapedDummy @r2 @sh2 Proxy Proxy + f (DynamicRankedDummy p1 p2) = DynamicRankedDummy p1 p2 + f (DynamicShapedDummy p1 p2) = DynamicShapedDummy p1 p2 + in EM.fromDistinctAscList + $ zip [toInputId 0 ..] $ map f $ V.toList parameters0 dMap = EM.empty nMap = EM.empty astBindings = [] @@ -792,17 +710,18 @@ gradientFromDelta useDummies !parameters0 !deltaDt = -- The warnings in the following seems spurious. A GHC issue to be opened. {-# SPECIALIZE gradientFromDelta :: Bool -> DomainsOD -> DeltaDt (Flip OR.Array) (Flip OS.Array) Double - -> (AstBindingsD OD.Array, DomainsOD) #-} + -> (AstBindingsD (Flip OR.Array), DomainsOD) #-} {- TODO: this causes a cyclic dependency: {-# SPECIALIZE gradientFromDelta :: Bool -> DomainsOD -> DeltaDt (AstRanked PrimalSpan) (AstShaped PrimalSpan) Double - -> (AstBindingsD (DynamicOf (AstRanked PrimalSpan)), Domains (AstDynamic PrimalSpan)) #-} + -> (AstBindingsD (DynamicTensor (AstRanked PrimalSpan)), Domains (AstDynamic PrimalSpan)) #-} -} buildFinMaps :: forall ranked shaped r0. 
( GoodScalar r0, RankedTensor ranked, ShapedTensor shaped - , ConvertTensor ranked shaped ) + , ConvertTensor ranked shaped + , ShapedOf ranked ~ shaped, RankedOf shaped ~ ranked ) => EvalState ranked shaped -> DeltaDt ranked shaped r0 -> EvalState ranked shaped buildFinMaps s0 deltaDt = @@ -878,14 +797,13 @@ buildFinMaps s0 deltaDt = -- result of the evaluation. assert (case d of ZeroR{} -> False - DToR{} -> False LetR{} -> False -- wasteful and nonsensical _ -> True) $ case EM.lookup n $ nMap s of Just (DeltaBindingR _) -> s {dMap = EM.adjust (raddDynamic c) n $ dMap s} Nothing -> - let cs = DynamicExists $ dfromR c + let cs = DynamicRanked c in s { nMap = EM.insert n (DeltaBindingR d) $ nMap s , dMap = EM.insert n cs $ dMap s } _ -> error "buildFinMaps: corrupted nMap" @@ -933,14 +851,6 @@ buildFinMaps s0 deltaDt = in evalR s2 (rfromList cas) as' CastR d -> evalRRuntimeSpecialized s (rcast c) d - DToR (RToD @n2 d) -> - case sameNat (Proxy @n) (Proxy @n2) of - Just Refl -> evalR s c d - _ -> error "buildFinMaps: different ranks in DToR(RToD)" - DToR (SToD @sh2 d) -> - case matchingRank @sh2 @n of - Just Refl -> evalR s c (SToR d) - _ -> error "buildFinMaps: different ranks in DToR(SToD)" SToR (RToS d) -> evalR s c d -- no information lost, so no checks SToR d -> evalS s (sfromR c) d @@ -974,14 +884,13 @@ buildFinMaps s0 deltaDt = LetS n d -> assert (case d of ZeroS -> False - DToS{} -> False LetS{} -> False -- wasteful and nonsensical _ -> True) $ case EM.lookup n $ nMap s of Just (DeltaBindingS _) -> s {dMap = EM.adjust (saddDynamic c) n $ dMap s} Nothing -> - let cs = DynamicExists $ dfromS c + let cs = DynamicShaped c in s { nMap = EM.insert n (DeltaBindingS d) $ nMap s , dMap = EM.insert n cs $ dMap s } _ -> error "buildFinMaps: corrupted nMap" @@ -1019,7 +928,7 @@ buildFinMaps s0 deltaDt = -- in the other direction? What if backend don't have it? let perm = Sh.shapeT @perm permRev = map snd $ sort $ zip perm [0 .. 
length perm - 1] - in Sh.withShapeP permRev $ \(_proxy :: Proxy permR) -> + in Sh.withShapeP permRev $ \(Proxy @permR) -> gcastWith (unsafeCoerce Refl :: Sh.Permute permR sh :~: sh2) $ gcastWith (unsafeCoerce Refl @@ -1040,15 +949,6 @@ buildFinMaps s0 deltaDt = s2 = evalS sShared cx0 x0' in evalS s2 (sfromList cas) as' CastS d -> evalSRuntimeSpecialized s (scast c) d - - DToS (SToD @sh2 d) -> - case sameShape @sh @sh2 of - Just Refl -> evalS s c d - _ -> error "buildFinMaps: different shapes in DToS(SToD)" - DToS (RToD @n2 d) -> - case matchingRank @sh @n2 of - Just Refl -> evalS s c (RToS d) - _ -> error "buildFinMaps: different ranks in DToS(RToD)" RToS (SToR @sh2 d) -> case sameShape @sh @sh2 of Just Refl -> evalS s c d @@ -1085,39 +985,43 @@ buildFinMaps s0 deltaDt = _ -> error "buildFinMaps: corrupted nMap" -} - evalD - :: GoodScalar r - => EvalState ranked shaped - -> DynamicOf ranked r - -> DeltaD (Clown (DynamicOf ranked)) ranked shaped r y - -> EvalState ranked shaped - evalD s !c = \case - RToD d -> evalR s (rfromD c) d - SToD d -> evalS s (sfromD c) d - evalFromnMap :: EvalState ranked shaped -> EvalState ranked shaped evalFromnMap s@EvalState{nMap, dMap} = case EM.maxViewWithKey nMap of Just ((n, b), nMap2) -> let s2 = s {nMap = nMap2} s3 = case b of - DeltaBindingR @_ @r1 d -> case dMap EM.! n of - DynamicExists @r2 e -> - case testEquality (typeRep @r1) (typeRep @r2) of - Just Refl -> let c = rfromD e - in evalRRuntimeSpecialized s2 c d + DeltaBindingR @n1 @r1 d -> case dMap EM.! n of + DynamicRanked @r2 @n2 c -> case sameNat (Proxy @n2) + (Proxy @n1) of + Just Refl -> case testEquality (typeRep @r1) + (typeRep @r2) of + Just Refl -> evalRRuntimeSpecialized s2 c d _ -> error "buildFinMaps: type mismatch" - DeltaBindingS @_ @r1 d -> case dMap EM.! 
n of - DynamicExists @r2 e -> - case testEquality (typeRep @r1) (typeRep @r2) of - Just Refl -> let c = sfromD e - in evalSRuntimeSpecialized s2 c d + _ -> error "buildFinMaps: rank mismatch" + DynamicShaped{} -> + error "evalFromnMap: DynamicShaped" + DynamicRankedDummy{} -> + error "evalFromnMap: DynamicRankedDummy" + DynamicShapedDummy{} -> + error "evalFromnMap: DynamicShapedDummy" + DeltaBindingS @sh1 @r1 d -> case dMap EM.! n of + DynamicRanked{} -> + error "evalFromnMap: DynamicRanked" + DynamicShaped @r2 @sh2 c -> case sameShape @sh2 @sh1 of + Just Refl -> case testEquality (typeRep @r1) + (typeRep @r2) of + Just Refl -> evalSRuntimeSpecialized s2 c d _ -> error "buildFinMaps: type mismatch" + _ -> error "buildFinMaps: shape mismatch" + DynamicRankedDummy{} -> + error "evalFromnMap: DynamicRankedDummy" + DynamicShapedDummy{} -> + error "evalFromnMap: DynamicShapedDummy" in evalFromnMap s3 Nothing -> s -- loop ends s1 = case deltaDt of - DeltaDtD dt deltaTopLevel -> evalD s0 dt deltaTopLevel DeltaDtR dt deltaTopLevel -> evalR s0 dt deltaTopLevel DeltaDtS dt deltaTopLevel -> evalS s0 dt deltaTopLevel in evalFromnMap s1 @@ -1151,10 +1055,10 @@ buildFinMaps s0 deltaDt = buildDerivative :: forall ranked shaped r0 s. ( GoodScalar r0, RankedTensor ranked, ShapedTensor shaped - , ConvertTensor ranked shaped ) - => Int -> DeltaDt ranked shaped r0 -> Domains (DynamicOf ranked) - -> ST s ( AstBindingsD (DynamicOf ranked) - , DeltaDt ranked shaped r0 ) + , ConvertTensor ranked shaped + , ShapedOf ranked ~ shaped, RankedOf shaped ~ ranked ) + => Int -> DeltaDt ranked shaped r0 -> Domains ranked + -> ST s (AstBindingsD ranked, DeltaDt ranked shaped r0) buildDerivative dimR deltaDt params = do dMap <- newSTRef EM.empty nMap <- newSTRef EM.empty @@ -1167,10 +1071,14 @@ buildDerivative dimR deltaDt params = do InputR _ (InputId i) -> if i < dimR then case params V.! i of - DynamicExists @r2 e -> - case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> return $! 
rfromD @ranked @shaped @r e + DynamicRanked @r2 @n2 e -> case sameNat (Proxy @n2) (Proxy @n) of + Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of + Just Refl -> return e _ -> error "buildDerivative: type mismatch" + _ -> error "buildDerivative: rank mismatch" + DynamicShaped{} -> error "buildDerivative: DynamicShaped" + DynamicRankedDummy{} -> error "buildDerivative: DynamicRankedDummy" + DynamicShapedDummy{} -> error "buildDerivative: DynamicShapedDummy" else error "buildDerivative': wrong index for an input" ScaleR k d -> (* k) <$> evalR d AddR d e -> liftM2 (+) (evalR d) (evalR e) @@ -1180,10 +1088,17 @@ buildDerivative dimR deltaDt params = do Just (DeltaBindingR _) -> do dm <- readSTRef dMap case dm EM.! n of - DynamicExists @r2 t -> - case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> return $! rfromD @ranked @shaped @r t + DynamicRanked @r2 @n2 e -> case sameNat (Proxy @n2) + (Proxy @n) of + Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of + Just Refl -> return e _ -> error "buildDerivative: type mismatch" + _ -> error "buildDerivative: rank mismatch" + DynamicShaped{} -> error "buildDerivative: DynamicShaped" + DynamicRankedDummy{} -> + error "buildDerivative: DynamicRankedDummy" + DynamicShapedDummy{} -> + error "buildDerivative: DynamicShapedDummy" Nothing -> do cRaw <- evalR d ab <- readSTRef astBindings @@ -1192,7 +1107,7 @@ buildDerivative dimR deltaDt params = do nmNew <- readSTRef nMap writeSTRef nMap $! EM.insert n (DeltaBindingR d) nmNew dm <- readSTRef dMap - writeSTRef dMap $! EM.insert n (DynamicExists $ dfromR cShared) dm + writeSTRef dMap $! EM.insert n (DynamicRanked cShared) dm return cShared _ -> error "buildDerivative: corrupted nMap" @@ -1234,14 +1149,6 @@ buildDerivative dimR deltaDt params = do t <- evalR d return $! 
rcast t - DToR (RToD @n2 d) -> - case sameNat (Proxy @n) (Proxy @n2) of - Just Refl -> evalR d - _ -> error "buildDerivative: different ranks in DToR(RToD)" - DToR (SToD @sh2 d) -> - case matchingRank @sh2 @n of - Just Refl -> evalR (SToR d) - _ -> error "buildDerivative: different ranks in DToR(SToD)" SToR (RToS d) -> evalR d -- no information lost, so no checks SToR d -> rfromS <$> evalS d @@ -1253,10 +1160,14 @@ buildDerivative dimR deltaDt params = do InputS (InputId i) -> if i < dimR then case params V.! i of - DynamicExists @r2 e -> - case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> return $! sfromD @ranked @shaped @r e + DynamicRanked{} -> error "buildDerivative: DynamicRanked" + DynamicShaped @r2 @sh2 e -> case sameShape @sh2 @sh of + Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of + Just Refl -> return e _ -> error "buildDerivative: type mismatch" + _ -> error "buildDerivative: shape mismatch" + DynamicRankedDummy{} -> error "buildDerivative: DynamicRankedDummy" + DynamicShapedDummy{} -> error "buildDerivative: DynamicShapedDummy" else error "buildDerivative: wrong index for an input" ScaleS k d -> (* k) <$> evalS d AddS d e -> liftM2 (+) (evalS d) (evalS e) @@ -1266,10 +1177,16 @@ buildDerivative dimR deltaDt params = do Just (DeltaBindingS _) -> do dm <- readSTRef dMap case dm EM.! n of - DynamicExists @r2 t -> - case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> return $! 
sfromD @ranked @shaped @r t + DynamicRanked{} -> error "buildDerivative: DynamicRanked" + DynamicShaped @r2 @sh2 e -> case sameShape @sh2 @sh of + Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of + Just Refl -> return e _ -> error "buildDerivative: type mismatch" + _ -> error "buildDerivative: shape mismatch" + DynamicRankedDummy{} -> + error "buildDerivative: DynamicRankedDummy" + DynamicShapedDummy{} -> + error "buildDerivative: DynamicShapedDummy" Nothing -> do cRaw <- evalS d ab <- readSTRef astBindings @@ -1278,7 +1195,7 @@ buildDerivative dimR deltaDt params = do nmNew <- readSTRef nMap writeSTRef nMap $! EM.insert n (DeltaBindingS d) nmNew dm <- readSTRef dMap - writeSTRef dMap $! EM.insert n (DynamicExists $ dfromS cShared) dm + writeSTRef dMap $! EM.insert n (DynamicShaped cShared) dm return cShared _ -> error "buildDerivative: corrupted nMap" @@ -1320,28 +1237,12 @@ buildDerivative dimR deltaDt params = do t <- evalS d return $! scast t - DToS (SToD @sh2 d) -> - case sameShape @sh @sh2 of - Just Refl -> evalS d - _ -> error "buildDerivative: different ranks in DToR(RToD)" - DToS (RToD @n2 d) -> - case matchingRank @sh @n2 of - Just Refl -> evalS (RToS d) - _ -> error "buildDerivative: different ranks in DToR(SToD)" RToS (SToR @sh2 d) -> case sameShape @sh @sh2 of Just Refl -> evalS d _ -> error "buildDerivative: different shapes in RToS(SToR)" RToS d -> sfromR <$> evalR d - evalD - :: GoodScalar r - => DeltaD (Clown (DynamicOf ranked)) ranked shaped r y - -> ST s (DynamicOf ranked r) - evalD = \case - RToD d -> dfromR <$> evalR d - SToD d -> dfromS <$> evalS d - -- A hack to fit both argument delta and, afterwards, the result in a type -- that does not reflect either. 
case deltaDt of @@ -1356,8 +1257,3 @@ buildDerivative dimR deltaDt params = do let !cDelta = DeltaDtS c ZeroS ab <- readSTRef astBindings return (ab, cDelta) - DeltaDtD _dt deltaTopLevel -> do - c <- evalD deltaTopLevel - let !cDelta = DeltaDtD c (SToD @'[] ZeroS) - ab <- readSTRef astBindings - return (ab, cDelta) diff --git a/src/HordeAd/Core/DualClass.hs b/src/HordeAd/Core/DualClass.hs index 99b46b629..e8a01722b 100644 --- a/src/HordeAd/Core/DualClass.hs +++ b/src/HordeAd/Core/DualClass.hs @@ -29,11 +29,9 @@ module HordeAd.Core.DualClass import Prelude -import qualified Data.Array.DynamicS as OD import qualified Data.Array.RankedS as OR import qualified Data.Array.Shape as Sh import qualified Data.Array.ShapedS as OS -import Data.Bifunctor.Clown import Data.Bifunctor.Flip import Data.IORef.Unboxed (Counter, atomicAddCounter_, newCounter, writeIORefU) @@ -45,7 +43,6 @@ import HordeAd.Core.Ast import HordeAd.Core.Delta import HordeAd.Core.TensorClass import HordeAd.Core.Types -import HordeAd.Util.SizedIndex -- * The class and its instances @@ -108,7 +105,6 @@ instance (GoodScalar r, KnownNat n) => IsPrimal (Flip OR.Array) r n where recordSharing d = case d of ZeroR{} -> d InputR{} -> d - DToR{} -> d SToR{} -> d LetR{} -> d -- should not happen, but older/lower id is safer anyway _ -> wrapDeltaR d @@ -126,21 +122,10 @@ instance (GoodScalar r, Sh.Shape sh) => IsPrimal (Flip OS.Array) r sh where recordSharing d = case d of ZeroS -> d InputS{} -> d - DToS{} -> d RToS{} -> d LetS{} -> d -- should not happen, but older/lower id is safer anyway _ -> wrapDeltaS d -instance GoodScalar r => IsPrimal (Clown OD.Array) r '() where - dZeroOfShape (Clown tsh) = - withListShape (dshape @(Flip OR.Array) tsh) $ \ (sh :: Shape n Int) -> - RToD @n (ZeroR sh) - dScale = undefined - dAdd = undefined - intOfShape = undefined - recordSharingPrimal = undefined - recordSharing = undefined - -- * Counter handling diff --git a/src/HordeAd/Core/DualNumber.hs b/src/HordeAd/Core/DualNumber.hs 
index 98be7b394..9badfcc46 100644 --- a/src/HordeAd/Core/DualNumber.hs +++ b/src/HordeAd/Core/DualNumber.hs @@ -7,7 +7,7 @@ -- the safely impure "HordeAd.Core.DualClass". module HordeAd.Core.DualNumber ( -- * The main dual number type - ADVal, dD, pattern D, dDnotShared, constantADVal, ADValClown + ADVal, dD, pattern D, dDnotShared, constantADVal -- * Auxiliary definitions , CRankedIP, indexPrimal, fromList, CRankedIPSh, indexPrimalS, fromListS , ensureToplevelSharing, scaleNotShared, addNotShared, multNotShared @@ -15,13 +15,12 @@ module HordeAd.Core.DualNumber -- * Reverse and forward derivative stages class and instances , DerivativeStages (..), UnletGradient (..) , crevOnADInputs, crevOnDomains, cfwdOnADInputs, cfwdOnDomains - , generateDeltaInputsOD, generateDeltaInputsAst, makeADInputs + , generateDeltaInputs, makeADInputs ) where import Prelude import Control.Exception.Assert.Sugar -import qualified Data.Array.DynamicS as OD import qualified Data.Array.RankedS as OR import qualified Data.Array.Shape as Sh import qualified Data.Array.ShapedS as OS @@ -30,18 +29,19 @@ import Data.Bifunctor.Flip import Data.Bifunctor.Product import Data.Functor.Const import Data.Kind (Constraint, Type) +import Data.Proxy (Proxy (Proxy)) import Data.Type.Equality (testEquality, (:~:) (Refl)) import qualified Data.Vector.Generic as V -import GHC.TypeLits (KnownNat, Nat, type (+)) +import GHC.TypeLits (KnownNat, Nat, sameNat, type (+)) import Type.Reflection (typeRep) import HordeAd.Core.Ast import HordeAd.Core.AstEnv -import HordeAd.Core.AstTools import HordeAd.Core.Delta import HordeAd.Core.DualClass import HordeAd.Core.TensorClass import HordeAd.Core.Types +import HordeAd.Internal.OrthotopeOrphanInstances (sameShape) import HordeAd.Util.ShapedList (singletonShaped) import HordeAd.Util.SizedIndex @@ -87,8 +87,6 @@ dDnotShared = D constantADVal :: IsPrimal f r z => f r z -> ADVal f r z constantADVal a = dDnotShared emptyADShare a (dZeroOfShape a) -type ADValClown dynamic = Flip 
(ADVal (Clown dynamic)) '() - -- * Assorted instances @@ -181,25 +179,11 @@ instance IfF (ADVal (Flip OS.Array)) where ifF (_, b) v w = if b then v else w -} -type instance RankedOf (Clown (ADValClown dynamic)) = - ADVal (RankedOf @() (Clown dynamic)) - -type instance ShapedOf (Clown (ADValClown dynamic)) = - ADVal (ShapedOf @() (Clown dynamic)) - -type instance DynamicOf (Clown (ADValClown dynamic)) = - ADValClown dynamic - -type instance DomainsOf (Clown (ADValClown dynamic)) = - Domains (ADValClown dynamic) - type instance RankedOf (ADVal f) = ADVal (RankedOf f) type instance ShapedOf (ADVal f) = ADVal (ShapedOf f) -type instance DynamicOf (ADVal f) = ADValClown (DynamicOf f) - -type instance DomainsOf (ADVal f) = Domains (DynamicOf (ADVal f)) +type instance DomainsOf (ADVal f) = Domains (ADVal (RankedOf f)) type instance PrimalOf (ADVal f) = f @@ -219,7 +203,6 @@ instance (GoodScalar r, KnownNat n, RankedTensor (ADVal ranked)) recordSharing d = case d of ZeroR{} -> d InputR{} -> d - DToR{} -> d SToR{} -> d LetR{} -> d -- should not happen, but older/lower id is safer anyway _ -> wrapDeltaR d @@ -238,26 +221,10 @@ instance (GoodScalar r, Sh.Shape sh, ShapedTensor (ADVal shaped)) recordSharing d = case d of ZeroS -> d InputS{} -> d - DToS{} -> d RToS{} -> d LetS{} -> d -- should not happen, but older/lower id is safer anyway _ -> wrapDeltaS d -instance ( GoodScalar r - , dynamic ~ DynamicOf (ShapedOf @() (Clown dynamic)) - , ConvertTensor (RankedOf @() (Clown dynamic)) - (ShapedOf @() (Clown dynamic)) ) - => IsPrimal (Clown (Flip (ADVal (Clown dynamic)) '())) r '() where - dZeroOfShape (Clown (Flip (D _ (Clown tsh) _))) = - withListShape (dshape @(RankedOf @() (Clown dynamic)) tsh) - $ \ (sh :: Shape n Int) -> - RToD @n (ZeroR sh) - dScale = undefined - dAdd = undefined - intOfShape = undefined - recordSharingPrimal = undefined - recordSharing = undefined - -- * Auxiliary definitions @@ -299,15 +266,26 @@ dotParameters (Domains a0 a1) (Domains b0 b1) = else 
OD.toVector v1 LA.<.> OD.toVector u1) a1 b1) -} +zeroParameters :: forall ranked. RankedTensor ranked + => Domains ranked -> DomainsOD +zeroParameters = + let f (DynamicRanked @r @n t) = + let sh = rshape @ranked t + in Sh.withShapeP (shapeToList sh) $ \(Proxy @sh) -> + DynamicRankedDummy @r @sh Proxy Proxy + f (DynamicShaped @r @sh _) = DynamicShapedDummy @r @sh Proxy Proxy + f (DynamicRankedDummy p1 p2) = DynamicRankedDummy p1 p2 + f (DynamicShapedDummy p1 p2) = DynamicShapedDummy p1 p2 + in V.map f + crevOnADInputs :: forall k (f :: TensorKind k) r y. - ( DynamicOf f ~ DynamicOf (RankedOf f) - , ConvertTensor (RankedOf f) (ShapedOf f) - , DualPart f, UnletGradient f, GoodScalar r, HasSingletonDict y ) + ( RankedTensor (ADVal (RankedOf f)) + , DualPart f, UnletGradient f, GoodScalar r, HasSingletonDict y) => Bool -> Maybe (f r y) - -> (Domains (DynamicOf (ADVal f)) -> ADVal f r y) - -> Domains (DynamicOf (ADVal f)) - -> (DomainsOf f, f r y) + -> (Domains (ADVal (RankedOf f)) -> ADVal f r y) + -> Domains (ADVal (RankedOf f)) + -> (DomainsOf (RankedOf f), f r y) -- The functions in which @revOnADInputs@ inlines are not inlined themselves -- in client code, so the bloat is limited. {-# INLINE crevOnADInputs #-} @@ -315,35 +293,30 @@ crevOnADInputs useDummies mdt f inputs = let -- Evaluate completely after terms constructed, to free memory -- before evaluation allocates new memory and new FFI is started. !(D l v deltaTopLevel) = f inputs in - let dToZero :: DynamicExists (DynamicOf (ADVal f)) -> DynamicExists OD.Array - dToZero (DynamicExists @re (Flip (D _ (Clown d) _))) = - DynamicExists @re $ OD.constant (dshape @(RankedOf f) d) 0 - parameters0 = V.map dToZero inputs + let parameters0 = zeroParameters inputs (!astBindings, !gradient) = reverseDervative useDummies parameters0 v mdt deltaTopLevel in (unletGradient @k @f l astBindings gradient, unletValue l [] v) crevOnDomains :: forall r y f. 
- ( DynamicOf f ~ DynamicOf (RankedOf f) - , ConvertTensor (RankedOf f) (ShapedOf f) - , Dual (Clown (DynamicOf f)) - ~ DeltaD (Clown (DynamicOf f)) (RankedOf f) (ShapedOf f) - , DualPart f, UnletGradient f, GoodScalar r, HasSingletonDict y) + ( RankedTensor (RankedOf f), RankedTensor (ADVal (RankedOf f)) + , DualPart f, UnletGradient f, GoodScalar r, HasSingletonDict y + , RankedOf (ShapedOf f) ~ RankedOf f, ShapedOf (RankedOf f) ~ ShapedOf f ) => Bool -> Maybe (f r y) - -> (Domains (DynamicOf (ADVal f)) -> ADVal f r y) - -> Domains (DynamicOf f) - -> (DomainsOf f, f r y) + -> (Domains (ADVal (RankedOf f)) -> ADVal f r y) + -> Domains (RankedOf f) + -> (DomainsOf (RankedOf f), f r y) crevOnDomains useDummies mdt f parameters = - let deltaInputs = generateDeltaInputsOD parameters + let deltaInputs = generateDeltaInputs parameters inputs = makeADInputs parameters deltaInputs in crevOnADInputs useDummies mdt f inputs cfwdOnADInputs :: (DualPart f, UnletGradient f, GoodScalar r, HasSingletonDict y) - => Domains (DynamicOf (ADVal f)) - -> (Domains (DynamicOf (ADVal f)) -> ADVal f r y) - -> Domains (DynamicOf f) + => Domains (ADVal (RankedOf f)) + -> (Domains (ADVal (RankedOf f)) -> ADVal f r y) + -> Domains (RankedOf f) -> (f r y, f r y) {-# INLINE cfwdOnADInputs #-} cfwdOnADInputs inputs f ds = @@ -354,81 +327,63 @@ cfwdOnADInputs inputs f ds = cfwdOnDomains :: forall r y f. 
- ( DynamicOf f ~ DynamicOf (RankedOf f) - , ConvertTensor (RankedOf f) (ShapedOf f) - , Dual (Clown (DynamicOf f)) - ~ DeltaD (Clown (DynamicOf f)) (RankedOf f) (ShapedOf f) - , DualPart f, UnletGradient f, GoodScalar r, HasSingletonDict y ) - => Domains (DynamicOf f) - -> (Domains (DynamicOf (ADVal f)) -> ADVal f r y) - -> Domains (DynamicOf f) + ( RankedTensor (RankedOf f) + , DualPart f, UnletGradient f, GoodScalar r, HasSingletonDict y + , RankedOf (ShapedOf f) ~ RankedOf f, ShapedOf (RankedOf f) ~ ShapedOf f ) + => Domains (RankedOf f) + -> (Domains (ADVal (RankedOf f)) -> ADVal f r y) + -> Domains (RankedOf f) -> (f r y, f r y) cfwdOnDomains parameters f ds = - let deltaInputs = generateDeltaInputsOD parameters + let deltaInputs = generateDeltaInputs parameters inputs = makeADInputs parameters deltaInputs in cfwdOnADInputs inputs f ds -type DualClown dynamic = Flip (Dual (Clown dynamic)) '() - --- Actually, this is fully general, not only working for DomainsOD. -generateDeltaInputsOD - :: forall ranked shaped dynamic. - ( dynamic ~ DynamicOf ranked, ConvertTensor ranked shaped - , Dual (Clown dynamic) ~ DeltaD (Clown dynamic) ranked shaped ) - => Domains dynamic - -> Domains (DualClown dynamic) -{-# INLINE generateDeltaInputsOD #-} -generateDeltaInputsOD params = - let arrayToInput :: Int - -> DynamicExists dynamic - -> DynamicExists (DualClown dynamic) - arrayToInput i (DynamicExists @r t) = - withListShape (dshape @ranked t) $ \ (sh :: Shape n Int) -> - DynamicExists $ Flip $ RToD $ InputR @ranked @shaped @r @n - sh (toInputId i) - in V.imap arrayToInput params -{- TODO: this can't be specified without a proxy, so we inline instead -{-# SPECIALIZE generateDeltaInputs - :: DomainsOD -> Data.Vector.Vector (Dual OD.Array Double) #-} --} - --- This is preferred for AstDynamic, because it results in shorter terms. -generateDeltaInputsAst - :: forall ranked shaped dynamic s. 
- ( dynamic ~ AstDynamic s - , Dual (Clown dynamic) ~ DeltaD (Clown dynamic) ranked shaped ) - => Domains dynamic - -> Domains (DualClown dynamic) -{-# INLINE generateDeltaInputsAst #-} -generateDeltaInputsAst params = - let arrayToInput :: Int - -> DynamicExists dynamic - -> DynamicExists (DualClown dynamic) - arrayToInput i (DynamicExists @r d) = case d of - AstRToD @n w -> - DynamicExists $ Flip $ RToD $ InputR @ranked @shaped @r @n - (shapeAst w) (toInputId i) - AstSToD @sh _w -> - DynamicExists $ Flip $ SToD $ InputS @ranked @shaped @r @sh - (toInputId i) - in V.imap arrayToInput params +generateDeltaInputs + :: forall ranked shaped. + (RankedTensor ranked, shaped ~ ShapedOf ranked) + => Domains ranked + -> Domains (Dual ranked) +{-# INLINE generateDeltaInputs #-} +generateDeltaInputs = + let f :: Int -> DynamicTensor ranked -> DynamicTensor (Dual ranked) + f i (DynamicRanked @r @n t) = + case rshape t of + (sh :: Shape n2 Int) | Just Refl <- sameNat (Proxy @n) (Proxy @n2) -> + DynamicRanked $ InputR @ranked @shaped @r @n sh (toInputId i) + _ -> error "generateDeltaInputs: wrong rank" + f i (DynamicShaped @r @sh _) = + DynamicShaped $ InputS @ranked @shaped @r @sh (toInputId i) + f i (DynamicRankedDummy @r @sh _ _) = + withListShape (Sh.shapeT @sh) $ \(sh :: Shape n Int) -> + DynamicRanked $ InputR @ranked @shaped @r @n sh (toInputId i) + f i (DynamicShapedDummy @r @sh _ _) = + DynamicShaped $ InputS @ranked @shaped @r @sh (toInputId i) + in V.imap f {- TODO: this can't be specified without a proxy, so we inline instead {-# SPECIALIZE generateDeltaInputs :: DomainsOD -> Data.Vector.Vector (Dual OD.Array Double) #-} -} makeADInputs - :: Domains dynamic - -> Domains (DualClown dynamic) - -> Domains (ADValClown dynamic) + :: forall ranked. 
ShapedOf (Dual ranked) ~ Dual (ShapedOf ranked) + => Domains ranked + -> Domains (Dual ranked) + -> Domains (ADVal ranked) {-# INLINE makeADInputs #-} makeADInputs = - V.zipWith (\(DynamicExists @r p) - (DynamicExists @r2 d) -> - case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> DynamicExists - $ Flip $ dDnotShared emptyADShare (Clown p) $ runFlip d - _ -> error "makeADInputs: type mismatch") + let f :: DynamicTensor ranked -> DynamicTensor (Dual ranked) + -> DynamicTensor (ADVal ranked) + f (DynamicRanked @r @n t) (DynamicRanked @r2 @n2 d) + | Just Refl <- sameNat (Proxy @n) (Proxy @n2) + , Just Refl <- testEquality (typeRep @r) (typeRep @r2) = + DynamicRanked $ dDnotShared emptyADShare t d + f (DynamicShaped @r @sh t) (DynamicShaped @r2 @sh2 d) + | Just Refl <- sameShape @sh @sh2 + , Just Refl <- testEquality (typeRep @r) (typeRep @r2) = + DynamicShaped $ dDnotShared emptyADShare t d + f _ _ = error "makeADInputs: non-matching arguments" + in V.zipWith f -- * Reverse and forward derivative stages class and instances @@ -437,21 +392,21 @@ type DerivativeStages :: forall k. 
TensorKind k -> Constraint class DerivativeStages g where forwardPassByInterpretation :: (GoodScalar r, HasSingletonDict y) - => (Domains (DynamicOf g) -> g r y) + => (Domains (RankedOf g) -> g r y) -> AstEnv (ADVal (RankedOf (PrimalOf g))) (ADVal (ShapedOf (PrimalOf g))) - -> Domains (DynamicOf (PrimalOf g)) + -> Domains (RankedOf (PrimalOf g)) -> [AstDynamicVarName] - -> Domains (DynamicOf g) + -> Domains (RankedOf g) -> ADVal (PrimalOf g) r y revArtifactFromForwardPass :: (GoodScalar r, HasSingletonDict y) => TensorFunctor g -> Bool -> Bool - -> (Domains (DynamicOf (PrimalOf g)) + -> (Domains (RankedOf (PrimalOf g)) -> [AstDynamicVarName] - -> Domains (DynamicOf g) + -> Domains (RankedOf g) -> ADVal (PrimalOf g) r y) -> DomainsOD -> (AstArtifactRev (PrimalOf g) r y, Dual (PrimalOf g) r y) @@ -459,7 +414,7 @@ class DerivativeStages g where revProduceArtifact :: (GoodScalar r, HasSingletonDict y) => TensorFunctor g -> Bool -> Bool - -> (Domains (DynamicOf g) -> g r y) + -> (Domains (RankedOf g) -> g r y) -> AstEnv (ADVal (RankedOf (PrimalOf g))) (ADVal (ShapedOf (PrimalOf g))) -> DomainsOD @@ -477,9 +432,9 @@ class DerivativeStages g where fwdArtifactFromForwardPass :: forall r y. (GoodScalar r, HasSingletonDict y) => TensorFunctor g - -> (Domains (DynamicOf (PrimalOf g)) + -> (Domains (RankedOf (PrimalOf g)) -> [AstDynamicVarName] - -> Domains (DynamicOf g) + -> Domains (RankedOf g) -> ADVal (PrimalOf g) r y) -> DomainsOD -> (AstArtifactFwd (PrimalOf g) r y, Dual (PrimalOf g) r y) @@ -491,7 +446,7 @@ class DerivativeStages g where fwdProduceArtifact :: (DerivativeStages g, GoodScalar r, HasSingletonDict y) - => TensorFunctor g -> (Domains (DynamicOf g) -> g r y) + => TensorFunctor g -> (Domains (RankedOf g) -> g r y) -> AstEnv (ADVal (RankedOf (PrimalOf g))) (ADVal (ShapedOf (PrimalOf g))) -> DomainsOD @@ -504,14 +459,15 @@ class DerivativeStages g where type UnletGradient :: forall k. 
TensorKind k -> Constraint class UnletGradient g where unletGradient - :: ADShare -> AstBindingsD (DynamicOf g) -> Domains (DynamicOf g) - -> DomainsOf g + :: ADShare -> AstBindingsD (RankedOf g) -> Domains (RankedOf g) + -> DomainsOf (RankedOf g) unletValue :: (GoodScalar r, HasSingletonDict y) - => ADShare -> AstBindingsD (DynamicOf g) -> g r y + => ADShare -> AstBindingsD (RankedOf g) -> g r y -> g r y -instance UnletGradient (ADVal f) where +instance RankedOf (RankedOf f) ~ RankedOf f + => UnletGradient (ADVal f) where unletGradient l astBindings gradient = assert (nullADShare l && null astBindings) gradient unletValue l astBindings primalBody = diff --git a/src/HordeAd/Core/Engine.hs b/src/HordeAd/Core/Engine.hs index e67587164..e4e2204d3 100644 --- a/src/HordeAd/Core/Engine.hs +++ b/src/HordeAd/Core/Engine.hs @@ -15,15 +15,12 @@ module HordeAd.Core.Engine , crev, crevDt -- * Old derivative adaptors , cfwd - -- * Additional common mechanisms - , shapedToRanked -- * Re-exported for tests , interpretAst ) where import Prelude -import qualified Data.Array.DynamicS as OD import qualified Data.Array.RankedS as OR import qualified Data.Array.Shape as Sh import qualified Data.Array.ShapedS as OS @@ -61,16 +58,16 @@ import HordeAd.Core.Types rev :: forall r y g vals astvals. 
( DerivativeStages g, GoodScalar r, HasSingletonDict y - , AdaptableDomains (AstDynamic FullSpan) astvals - , AdaptableDomains OD.Array vals + , AdaptableDomains (AstRanked FullSpan) astvals + , AdaptableDomains (Flip OR.Array) vals , vals ~ Value astvals, Value vals ~ vals ) => (astvals -> g r y) -> vals -> vals rev f vals = revDtMaybe f vals Nothing {- TODO: RULE left-hand side too complicated to desugar {-# SPECIALIZE rev :: ( HasSingletonDict y - , AdaptableDomains (AstDynamic FullSpan) astvals - , AdaptableDomains OD.Array vals + , AdaptableDomains (AstRanked FullSpan) astvals + , AdaptableDomains (Flip OR.Array) vals , vals ~ Value astvals, Value vals ~ vals ) => (astvals -> AstRanked FullSpan Double y) -> vals -> vals #-} @@ -80,16 +77,16 @@ rev f vals = revDtMaybe f vals Nothing revDt :: forall r y g vals astvals. ( DerivativeStages g, GoodScalar r, HasSingletonDict y - , AdaptableDomains (AstDynamic FullSpan) astvals - , AdaptableDomains OD.Array vals + , AdaptableDomains (AstRanked FullSpan) astvals + , AdaptableDomains (Flip OR.Array) vals , vals ~ Value astvals, Value vals ~ vals ) => (astvals -> g r y) -> vals -> ConcreteOf g r y -> vals revDt f vals dt = revDtMaybe f vals (Just dt) {- TODO: RULE left-hand side too complicated to desugar {-# SPECIALIZE revDt :: ( HasSingletonDict y - , AdaptableDomains (AstDynamic FullSpan) astvals - , AdaptableDomains OD.Array vals + , AdaptableDomains (AstRanked FullSpan) astvals + , AdaptableDomains (Flip OR.Array) vals , vals ~ Value astvals, Value vals ~ vals ) => (astvals -> AstRanked FullSpan Double y) -> vals -> Flip OR.Array Double y -> vals #-} @@ -98,8 +95,8 @@ revDt f vals dt = revDtMaybe f vals (Just dt) revDtMaybe :: forall r y g vals astvals. 
( DerivativeStages g, GoodScalar r, HasSingletonDict y - , AdaptableDomains (AstDynamic FullSpan) astvals - , AdaptableDomains OD.Array vals + , AdaptableDomains (AstRanked FullSpan) astvals + , AdaptableDomains (Flip OR.Array) vals , vals ~ Value astvals, Value vals ~ vals ) => (astvals -> g r y) -> vals -> Maybe (ConcreteOf g) r y) -> vals {-# INLINE revDtMaybe #-} @@ -115,14 +112,14 @@ revDtMaybe f vals mdt = rev :: forall r y g astvals. ( DerivativeStages g, GoodScalar r, HasSingletonDict y - , AdaptableDomains (DynamicOf g) astvals - , AdaptableDomains OD.Array (Value astvals) ) + , AdaptableDomains (RankedOf g) astvals + , AdaptableDomains (Flip OR.Array) (Value astvals) ) => (astvals -> g r y) -> Value astvals -> Value astvals rev f vals = revDtMaybe f vals Nothing {-# SPECIALIZE rev :: ( HasSingletonDict y - , AdaptableDomains (AstDynamic FullSpan) astvals - , AdaptableDomains OD.Array (Value astvals) ) + , AdaptableDomains (AstRanked FullSpan) astvals + , AdaptableDomains (Flip OR.Array) (Value astvals) ) => (astvals -> AstRanked FullSpan Double y) -> Value astvals -> Value astvals #-} @@ -130,15 +127,15 @@ rev f vals = revDtMaybe f vals Nothing revDt :: forall r y g astvals. 
( DerivativeStages g, GoodScalar r, HasSingletonDict y - , AdaptableDomains (DynamicOf g) astvals - , AdaptableDomains OD.Array (Value astvals) ) + , AdaptableDomains (RankedOf g) astvals + , AdaptableDomains (Flip OR.Array) (Value astvals) ) => (astvals -> g r y) -> Value astvals -> ConcreteOf g r y -> Value astvals revDt f vals dt = revDtMaybe f vals (Just dt) {-# SPECIALIZE revDt :: ( HasSingletonDict y - , AdaptableDomains (AstDynamic FullSpan) astvals - , AdaptableDomains OD.Array (Value astvals) ) + , AdaptableDomains (AstRanked FullSpan) astvals + , AdaptableDomains (Flip OR.Array) (Value astvals) ) => (astvals -> AstRanked FullSpan Double y) -> Value astvals -> Flip OR.Array Double y -> Value astvals #-} @@ -146,8 +143,8 @@ revDt f vals dt = revDtMaybe f vals (Just dt) revDtMaybe :: forall r y g vals astvals. ( DerivativeStages g, GoodScalar r, HasSingletonDict y - , AdaptableDomains (DynamicOf g) astvals - , AdaptableDomains OD.Array vals + , AdaptableDomains (RankedOf g) astvals + , AdaptableDomains (Flip OR.Array) vals , vals ~ Value astvals ) => (astvals -> g r y) -> vals -> Maybe (ConcreteOf g r y) -> vals {-# INLINE revDtMaybe #-} @@ -163,8 +160,8 @@ revDtMaybe f vals mdt = revArtifactAdapt :: forall r y g vals astvals. 
( DerivativeStages g, GoodScalar r, HasSingletonDict y - , AdaptableDomains (DynamicOf g) astvals - , AdaptableDomains OD.Array vals + , AdaptableDomains (RankedOf g) astvals + , AdaptableDomains (Flip OR.Array) vals , vals ~ Value astvals ) => Bool -> (astvals -> g r y) -> vals -> (AstArtifactRev (PrimalOf g) r y, Dual (PrimalOf g) r y) @@ -174,8 +171,8 @@ revArtifactAdapt hasDt f vals = in revProduceArtifact TensorFunctor True hasDt g EM.empty domainsOD {-# SPECIALIZE revArtifactAdapt :: ( HasSingletonDict y - , AdaptableDomains (AstDynamic FullSpan) astvals - , AdaptableDomains OD.Array (Value astvals) ) + , AdaptableDomains (AstRanked FullSpan) astvals + , AdaptableDomains (Flip OR.Array) (Value astvals) ) => Bool -> (astvals -> AstRanked FullSpan Double y) -> Value astvals -> ( AstArtifactRev (AstRanked PrimalSpan) Double y , Dual (AstRanked PrimalSpan) Double y ) #-} @@ -185,7 +182,7 @@ revProduceArtifactWithoutInterpretation ( g ~ AstRanked FullSpan -- needed, because PrimalOf not injective , DerivativeStages g, GoodScalar r, HasSingletonDict y ) => TensorFunctor g -> Bool - -> (Domains (DynamicOf (ADVal (PrimalOf g))) + -> (Domains (ADVal (RankedOf (PrimalOf g))) -> ADVal (PrimalOf g) r y) -> DomainsOD -> (AstArtifactRev (PrimalOf g) r y, Dual (PrimalOf g) r y) @@ -194,25 +191,21 @@ revProduceArtifactWithoutInterpretation tf hasDt g = revArtifactFromForwardPass @Nat @g TensorFunctor True hasDt (forwardPassByApplication tf g) --- The commented out version is more general, but less performant. forwardPassByApplication - :: forall g r y dynamic. - ( -- dynamic ~ DynamicOf (PrimalOf g) - -- , ConvertTensor (PrimalOf g) (ShapedOf (PrimalOf g)) - dynamic ~ AstDynamic PrimalSpan -- needed for generateDeltaInputsAst - , Dual (Clown dynamic) - ~ DeltaD (Clown dynamic) (PrimalOf g) (ShapedOf (PrimalOf g)) ) + :: forall g r y. 
+ ( RankedTensor (RankedOf (PrimalOf g)) + , ShapedOf (RankedOf (PrimalOf g)) ~ ShapedOf (PrimalOf g) + , RankedOf (ShapedOf (PrimalOf g)) ~ RankedOf (PrimalOf g) ) => TensorFunctor g - -> (Domains (DynamicOf (ADVal (PrimalOf g))) + -> (Domains (ADVal (RankedOf (PrimalOf g))) -> ADVal (PrimalOf g) r y) - -> Domains (DynamicOf (PrimalOf g)) + -> Domains (RankedOf (PrimalOf g)) -> [AstDynamicVarName] - -> Domains (DynamicOf g) + -> Domains (RankedOf g) -> ADVal (PrimalOf g) r y {-# INLINE forwardPassByApplication #-} forwardPassByApplication _ g domainsPrimal _ _ = --- let deltaInputs = generateDeltaInputsOD @(PrimalOf g) domainsPrimal - let deltaInputs = generateDeltaInputsAst domainsPrimal + let deltaInputs = generateDeltaInputs domainsPrimal varInputs = makeADInputs domainsPrimal deltaInputs in g varInputs @@ -228,8 +221,8 @@ forwardPassByApplication _ g domainsPrimal _ _ = fwd :: forall r y g vals astvals. ( DerivativeStages g, GoodScalar r, HasSingletonDict y - , AdaptableDomains (DynamicOf g) astvals - , AdaptableDomains OD.Array vals + , AdaptableDomains (RankedOf g) astvals + , AdaptableDomains (Flip OR.Array) vals , vals ~ Value astvals ) => (astvals -> g r y) -> vals -> vals -> ConcreteOf g r y fwd f x ds = @@ -241,8 +234,8 @@ fwd f x ds = fwdArtifactAdapt :: forall r y g vals astvals. ( DerivativeStages g, GoodScalar r, HasSingletonDict y - , AdaptableDomains (DynamicOf g) astvals - , AdaptableDomains OD.Array vals + , AdaptableDomains (RankedOf g) astvals + , AdaptableDomains (Flip OR.Array) vals , vals ~ Value astvals ) => (astvals -> g r y) -> vals -> (AstArtifactFwd (PrimalOf g) r y, Dual (PrimalOf g) r y) @@ -263,9 +256,8 @@ fwdArtifactAdapt f vals = crev :: forall r y f vals advals. 
( DualPart f, GoodScalar r, HasSingletonDict y - , DynamicOf f ~ OD.Array , AdaptableDomains (DynamicOf (ADVal f)) advals - , AdaptableDomains OD.Array vals + , AdaptableDomains (Flip OR.Array) vals , vals ~ Value advals, Value vals ~ vals ) => (advals -> ADVal f r y) -> vals -> vals crev f vals = crevDtMaybe f vals Nothing @@ -274,9 +266,9 @@ crev f vals = crevDtMaybe f vals Nothing crevDt :: forall r y f vals advals. ( DualPart f, GoodScalar r, HasSingletonDict y - , DynamicOf f ~ OD.Array + , DynamicOf f ~ (Flip OR.Array) , AdaptableDomains (DynamicOf (ADVal f)) advals - , AdaptableDomains OD.Array vals + , AdaptableDomains (Flip OR.Array) vals , vals ~ Value advals, Value vals ~ vals ) => (advals -> ADVal f r y) -> vals -> f r y -> vals crevDt f vals dt = crevDtMaybe f vals (Just dt) @@ -284,9 +276,9 @@ crevDt f vals dt = crevDtMaybe f vals (Just dt) crevDtMaybe :: forall r y f vals advals. ( DualPart f, GoodScalar r, HasSingletonDict y - , DynamicOf f ~ OD.Array + , DynamicOf f ~ (Flip OR.Array) , AdaptableDomains (DynamicOf (ADVal f)) advals - , AdaptableDomains OD.Array vals + , AdaptableDomains (Flip OR.Array) vals , vals ~ Value advals, Value vals ~ vals ) => (advals -> ADVal f r y) -> vals -> Maybe (f r y) -> vals {-# INLINE crevDtMaybe #-} @@ -302,17 +294,15 @@ crevDtMaybe f vals mdt = crev :: forall r y f advals. 
( DualPart f, UnletGradient f, GoodScalar r, HasSingletonDict y - , DynamicOf f ~ OD.Array - , DomainsOf f ~ Domains (DynamicOf f) , RankedOf f ~ Flip OR.Array, ShapedOf f ~ Flip OS.Array - , AdaptableDomains (DynamicOf (ADVal f)) advals - , AdaptableDomains OD.Array (Value advals) ) + , AdaptableDomains (ADVal (RankedOf f)) advals + , AdaptableDomains (Flip OR.Array) (Value advals) ) => (advals -> ADVal f r y) -> Value advals -> Value advals crev f vals = crevDtMaybe f vals Nothing {-# SPECIALIZE crev :: ( HasSingletonDict y - , AdaptableDomains (DynamicOf (ADVal (Flip OR.Array))) advals - , AdaptableDomains OD.Array (Value advals) ) + , AdaptableDomains (ADVal (Flip OR.Array)) advals + , AdaptableDomains (Flip OR.Array) (Value advals) ) => (advals -> ADVal (Flip OR.Array) Double y) -> Value advals -> Value advals #-} @@ -320,20 +310,19 @@ crev f vals = crevDtMaybe f vals Nothing -- | This version additionally takes the sensitivity parameter. crevDt :: forall r y f advals. - ( DynamicOf f ~ DynamicOf (RankedOf f) - , ConvertTensor (RankedOf f) (ShapedOf f) - , Dual (Clown (DynamicOf f)) - ~ DeltaD (Clown (DynamicOf f)) (RankedOf f) (ShapedOf f) - , DomainsOf f ~ Domains (DynamicOf f) + ( RankedTensor (RankedOf f), RankedTensor (ADVal (RankedOf f)) + , RankedOf (ShapedOf f) ~ RankedOf f + , ShapedOf (RankedOf f) ~ ShapedOf f , DualPart f, UnletGradient f, GoodScalar r, HasSingletonDict y - , AdaptableDomains (DynamicOf (ADVal f)) advals - , AdaptableDomains (DynamicOf f) (Value advals) ) + , DomainsOf (RankedOf f) ~ Domains (RankedOf f) + , AdaptableDomains (ADVal (RankedOf f)) advals + , AdaptableDomains (RankedOf f) (Value advals) ) => (advals -> ADVal f r y) -> Value advals -> f r y -> Value advals crevDt f vals dt = crevDtMaybe f vals (Just dt) {-# SPECIALIZE crevDt :: ( HasSingletonDict y - , AdaptableDomains (DynamicOf (ADVal (Flip OR.Array))) advals - , AdaptableDomains OD.Array (Value advals) ) + , AdaptableDomains (ADVal (Flip OR.Array)) advals + , 
AdaptableDomains (Flip OR.Array) (Value advals) ) => (advals -> ADVal (Flip OR.Array) Double y) -> Value advals -> Flip OR.Array Double y @@ -341,36 +330,34 @@ crevDt f vals dt = crevDtMaybe f vals (Just dt) crevDtMaybe :: forall r y f vals advals. - ( DynamicOf f ~ DynamicOf (RankedOf f) - , ConvertTensor (RankedOf f) (ShapedOf f) - , Dual (Clown (DynamicOf f)) - ~ DeltaD (Clown (DynamicOf f)) (RankedOf f) (ShapedOf f) - , DomainsOf f ~ Domains (DynamicOf f) + ( RankedTensor (RankedOf f), RankedTensor (ADVal (RankedOf f)) + , RankedOf (ShapedOf f) ~ RankedOf f + , ShapedOf (RankedOf f) ~ ShapedOf f , DualPart f, UnletGradient f, GoodScalar r, HasSingletonDict y - , AdaptableDomains (DynamicOf (ADVal f)) advals - , AdaptableDomains (DynamicOf f) vals + , DomainsOf (RankedOf f) ~ Domains (RankedOf f) + , AdaptableDomains (ADVal (RankedOf f)) advals + , AdaptableDomains (RankedOf f) vals , vals ~ Value advals ) => (advals -> ADVal f r y) -> vals -> Maybe (f r y) -> vals {-# INLINE crevDtMaybe #-} crevDtMaybe f vals mdt = gcastWith (unsafeCoerce Refl :: Value vals :~: vals) $ -- !!! 
let g inputs = f $ parseDomains vals inputs - in parseDomains vals $ fst $ crevOnDomains True mdt g (toDomains vals) + in parseDomains vals + $ fst $ crevOnDomains True mdt g (toDomains vals) {-# SPECIALIZE crevOnDomains :: HasSingletonDict y => Bool -> Maybe (Flip OR.Array Double y) - -> (Domains (DynamicOf (ADVal (Flip OR.Array))) - -> ADVal (Flip OR.Array) Double y) + -> (Domains (ADVal (Flip OR.Array)) -> ADVal (Flip OR.Array) Double y) -> DomainsOD -> (DomainsOD, Flip OR.Array Double y) #-} {-# SPECIALIZE crevOnADInputs :: HasSingletonDict y => Bool -> Maybe (Flip OR.Array Double y) - -> (Domains (DynamicOf (ADVal (Flip OR.Array))) - -> ADVal (Flip OR.Array) Double y) - -> Domains (DynamicOf (ADVal (Flip OR.Array))) + -> (Domains (ADVal (Flip OR.Array)) -> ADVal (Flip OR.Array) Double y) + -> Domains (ADVal (Flip OR.Array)) -> (DomainsOD, Flip OR.Array Double y) #-} @@ -380,10 +367,9 @@ crevDtMaybe f vals mdt = cfwd :: forall r y f vals advals. ( DualPart f, UnletGradient f, GoodScalar r, HasSingletonDict y - , DynamicOf f ~ OD.Array , RankedOf f ~ Flip OR.Array, ShapedOf f ~ Flip OS.Array - , AdaptableDomains (DynamicOf (ADVal f)) advals - , AdaptableDomains OD.Array vals + , AdaptableDomains (ADVal (RankedOf f)) advals + , AdaptableDomains (Flip OR.Array) vals , vals ~ Value advals ) => (advals -> ADVal f r y) -> vals -> vals -> f r y @@ -392,17 +378,6 @@ cfwd f x ds = in fst $ cfwdOnDomains (toDomains x) g (toDomains ds) --- * Additional common mechanisms - -shapedToRanked - :: forall vals svals dynamic. 
- ( dynamic ~ OD.Array, NoShape svals ~ vals, Value vals ~ vals - , AdaptableDomains dynamic vals - , AdaptableDomains dynamic svals, ForgetShape svals ) - => svals -> vals -shapedToRanked svals = - parseDomains @dynamic (forgetShape svals) $ toDomains @dynamic svals - @@ -853,12 +828,12 @@ shapedToRanked svals = :: AstSpan s => AstEnv (ADVal (Flip OR.Array)) (ADVal (Flip OS.Array)) -> AstDomains s - -> Domains (DynamicOf (ADVal (Flip OR.Array))) #-} + -> Domains (ADVal (Flip OR.Array)) #-} {-# SPECIALIZE interpretAstDomains :: AstSpan s => AstEnv (ADVal (AstRanked PrimalSpan)) (ADVal (AstShaped PrimalSpan)) -> AstDomains s - -> Domains (DynamicOf (ADVal (AstRanked PrimalSpan))) #-} + -> Domains (ADVal (AstRanked PrimalSpan)) #-} {-# SPECIALIZE interpretAstDomains :: AstSpan s => AstEnv (Flip OR.Array) (Flip OS.Array) diff --git a/src/HordeAd/Core/TensorADVal.hs b/src/HordeAd/Core/TensorADVal.hs index 671dd8e80..4df28f43a 100644 --- a/src/HordeAd/Core/TensorADVal.hs +++ b/src/HordeAd/Core/TensorADVal.hs @@ -9,12 +9,11 @@ -- a middle layer such as "DualClass", separate instances are given -- for ranked tensors and shaped tensors. 
module HordeAd.Core.TensorADVal - ( CRankedIP, CRankedIPSh, CRankedIPU + ( CRankedIP, CRankedIPSh ) where import Prelude hiding (foldl') -import qualified Data.Array.DynamicS as OD import Data.Array.Internal (valueOf) import qualified Data.Array.RankedS as OR import qualified Data.Array.Shape as Sh @@ -24,14 +23,12 @@ import Data.Bifunctor.Flip import Data.Bifunctor.Product import Data.Function ((&)) import Data.Functor.Const -import Data.Kind (Constraint, Type) import Data.List (foldl') import Data.List.Index (imap) import Data.Proxy (Proxy (Proxy)) -import Data.Type.Equality (testEquality, (:~:) (Refl)) +import Data.Type.Equality ((:~:) (Refl)) import qualified Data.Vector.Generic as V import GHC.TypeLits (KnownNat, sameNat, type (+)) -import Type.Reflection (typeRep) import HordeAd.Core.Adaptor import HordeAd.Core.Ast @@ -40,8 +37,7 @@ import HordeAd.Core.DualClass import HordeAd.Core.DualNumber import HordeAd.Core.TensorClass import HordeAd.Core.Types -import HordeAd.Internal.OrthotopeOrphanInstances - (matchingRank, sameShape) +import HordeAd.Internal.OrthotopeOrphanInstances (sameShape) import HordeAd.Internal.TensorOps import HordeAd.Util.ShapedList (singletonShaped) import qualified HordeAd.Util.ShapedList as ShapedList @@ -50,12 +46,10 @@ import HordeAd.Util.SizedIndex -- * Ranked tensor instances instance ( KnownNat n, GoodScalar r - , dynamic ~ DynamicOf ranked, RankedOf shaped ~ ranked + , RankedOf shaped ~ ranked , Dual ranked ~ DeltaR ranked shaped - , Dual (Clown dynamic) ~ DeltaD (Clown dynamic) ranked shaped - , RankedTensor (ADVal ranked), ConvertTensor ranked shaped - , CRankedIPU (Clown dynamic) IsPrimal ) - => AdaptableDomains (ADValClown dynamic) + , RankedTensor (ADVal ranked), ConvertTensor ranked shaped ) + => AdaptableDomains (ADVal ranked) (ADVal ranked r n) where {- TODO: RULE left-hand side too complicated to desugar {-# SPECIALIZE instance @@ -71,29 +65,20 @@ instance ( KnownNat n, GoodScalar r -} type Value (ADVal ranked r n) = Flip 
OR.Array r n -- ! not Value(ranked) toDomains = undefined - fromDomains aInit inputs = case V.uncons inputs of - Just (DynamicExists @r2 a, rest) -> - if dIsDummy @(ADVal ranked) a - then Just (rzero (rshape aInit), rest) - else - case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> let !aR = dToR @r (runFlip a) - in Just (aR, rest) - _ -> error "fromDomains: type mismatch" - Nothing -> Nothing + fromDomains _aInit params = fromDomainsR @r @n params -- This is temporarily moved from Adaptor in order to specialize manually -instance AdaptableDomains dynamic a - => AdaptableDomains dynamic [a] where +instance AdaptableDomains ranked a + => AdaptableDomains ranked [a] where {-# SPECIALIZE instance - (KnownNat n, AdaptableDomains OD.Array (OR.Array n Double)) - => AdaptableDomains OD.Array + (KnownNat n, AdaptableDomains (Flip OR.Array) (OR.Array n Double)) + => AdaptableDomains (Flip OR.Array) [OR.Array n Double] #-} {-# SPECIALIZE instance ( KnownNat n, AstSpan s - , AdaptableDomains (AstDynamic s) + , AdaptableDomains (AstRanked s) (AstRanked s Double n) ) - => AdaptableDomains (AstDynamic s) + => AdaptableDomains (AstRanked s) [AstRanked s Double n] #-} {- TODO: RULE left-hand side too complicated to desugar {-# SPECIALIZE instance @@ -122,24 +107,6 @@ instance AdaptableDomains dynamic a -- > let f = swap . flip fromDomains -- > in swap $ mapAccumL f source lInit -dToR :: forall r ranked shaped n. 
- ( ConvertTensor ranked shaped - , Dual ranked ~ DeltaR ranked shaped - , Dual (Clown (DynamicOf ranked)) - ~ DeltaD (Clown (DynamicOf ranked)) ranked shaped - , KnownNat n, GoodScalar r ) - => ADVal (Clown (DynamicOf ranked)) r '() -> ADVal ranked r n -dToR (D l u u') = dDnotShared l (rfromD $ runClown u) (dDToR u') - where - dDToR (RToD @n1 d) = - case sameNat (Proxy @n1) (Proxy @n) of - Just Refl -> d - _ -> error "dToR: different ranks in DToR(RToD)" - dDToR (SToD @sh1 d) = - case matchingRank @sh1 @n of - Just Refl -> SToR d - _ -> error "dToR: different ranks in DToR(SToD)" - -- Note that these instances don't do vectorization. To enable it, -- use the Ast instance and only then interpret in ADVal. -- In any case, only the Ast instantiation of this instance @@ -155,12 +122,7 @@ instance ( Dual ranked ~ DeltaR ranked shaped , PrimalOf ranked ~ RankedOf (PrimalOf ranked) , RankedOf shaped ~ ranked , ranked ~ RankedOf shaped - , RankedOf @() (Clown (DynamicOf ranked)) ~ ranked - , ranked ~ RankedOf @() (Clown (DynamicOf ranked)) - , ShapedOf @() (Clown (DynamicOf ranked)) ~ shaped - , shaped ~ ShapedOf @() (Clown (DynamicOf ranked)) , CRankedIP ranked IsPrimal - , CRankedIPU (Clown (DynamicOf ranked)) IsPrimal , RankedTensor ranked, ConvertTensor ranked shaped ) => RankedTensor (ADVal ranked) where rlet (D l u u') f = @@ -237,17 +199,6 @@ instance ( Dual ranked ~ DeltaR ranked shaped rconst t = constantADVal (rconst t) rletDomainsIn _ = (&) - raddDynamic :: forall r n. 
(GoodScalar r, KnownNat n) - => ADVal ranked r n - -> DynamicExists (ADValClown (DynamicOf ranked)) - -> DynamicExists (ADValClown (DynamicOf ranked)) - raddDynamic r (DynamicExists - @r2 d@(Flip (D _ (Clown dd) _))) = DynamicExists $ - if dIsDummy @ranked dd then dfromR r - else case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> dfromR $ r + rfromD d - _ -> error "raddDynamic: type mismatch" - rconstant t = let (l, r) = rletUnwrap t in dDnotShared l r (dZeroOfShape r) rprimalPart (D l u _) = rletWrap l u rdualPart (D l _ u') = Pair (Clown (Const l)) u' @@ -262,41 +213,14 @@ instance ( Dual ranked ~ DeltaR ranked shaped -- * Shaped tensor instances instance ( Sh.Shape sh, GoodScalar r - , dynamic ~ DynamicOf shaped, ShapedOf ranked ~ shaped + , ShapedOf ranked ~ shaped , Dual shaped ~ DeltaS ranked shaped - , Dual (Clown dynamic) ~ DeltaD (Clown dynamic) ranked shaped - , ShapedTensor (ADVal shaped), ConvertTensor ranked shaped - , CRankedIPU (Clown dynamic) IsPrimal ) - => AdaptableDomains (ADValClown dynamic) + , ShapedTensor (ADVal shaped), ConvertTensor ranked shaped ) + => AdaptableDomains (ADVal ranked) (ADVal shaped r sh) where type Value (ADVal shaped r sh) = Flip OS.Array r sh -- ! not Value(shaped) toDomains = undefined - fromDomains _aInit inputs = case V.uncons inputs of - Just (DynamicExists @r2 a, rest) -> - if dIsDummy @(ADVal ranked) a then Just (0, rest) else - case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> let !aS = dToS @r (runFlip a) - in Just (aS, rest) - _ -> error "fromDomains: type mismatch" - Nothing -> Nothing - -dToS :: forall r ranked shaped sh. 
- ( ConvertTensor ranked shaped - , Dual shaped ~ DeltaS ranked shaped - , Dual (Clown (DynamicOf ranked)) - ~ DeltaD (Clown (DynamicOf ranked)) ranked shaped - , Sh.Shape sh, GoodScalar r ) - => ADVal (Clown (DynamicOf ranked)) r '() -> ADVal shaped r sh -dToS (D l u u') = dDnotShared l (sfromD $ runClown u) (dDToS u') - where - dDToS (SToD @sh1 d) = - case sameShape @sh1 @sh of - Just Refl -> d - _ -> error "dToS: different ranks in DToS(SToD)" - dDToS (RToD @n1 d) = - case matchingRank @sh @n1 of - Just Refl -> RToS d - _ -> error "dToS: different ranks in DToS(RToD)" + fromDomains _aInit params = fromDomainsS @r @sh params -- Note that these instances don't do vectorization. To enable it, -- use the Ast instance and only then interpret in ADVal. @@ -309,16 +233,9 @@ instance ( Dual shaped ~ DeltaS ranked shaped , DeltaS ranked shaped ~ Dual shaped , RankedOf (PrimalOf shaped) ~ PrimalOf ranked , PrimalOf ranked ~ RankedOf (PrimalOf shaped) - , DynamicOf ranked ~ DynamicOf shaped - , DynamicOf shaped ~ DynamicOf ranked , ShapedOf ranked ~ shaped , shaped ~ ShapedOf ranked - , RankedOf @() (Clown (DynamicOf ranked)) ~ ranked - , ranked ~ RankedOf @() (Clown (DynamicOf ranked)) - , ShapedOf @() (Clown (DynamicOf ranked)) ~ shaped - , shaped ~ ShapedOf @() (Clown (DynamicOf ranked)) , CRankedIPSh shaped IsPrimal - , CRankedIPU (Clown (DynamicOf ranked)) IsPrimal , RankedTensor ranked, ShapedTensor shaped , ConvertTensor ranked shaped ) => ShapedTensor (ADVal shaped) where @@ -398,17 +315,6 @@ instance ( Dual shaped ~ DeltaS ranked shaped sconst t = constantADVal (sconst t) sletDomainsIn _ = (&) - saddDynamic :: forall r sh. 
(GoodScalar r, Sh.Shape sh) - => ADVal shaped r sh - -> DynamicExists (ADValClown (DynamicOf shaped)) - -> DynamicExists (ADValClown (DynamicOf shaped)) - saddDynamic r (DynamicExists - @r2 d@(Flip (D _ (Clown dd) _))) = DynamicExists $ - if dIsDummy @ranked dd then dfromS r - else case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> dfromS $ r + sfromD d - _ -> error "saddDynamic: type mismatch" - sconstant t = let (l, r) = sletUnwrap t in dDnotShared l r (dZeroOfShape t) sprimalPart (D l u _) = sletWrap l u sdualPart (D l _ u') = Pair (Clown (Const l)) u' @@ -422,22 +328,11 @@ instance ( Dual shaped ~ DeltaS ranked shaped -- * ConvertTensor and DomainsTensor instances -type CRankedIPU :: TensorKind () - -> (TensorKind () -> Type -> () -> Constraint) - -> Constraint -class (forall r17. GoodScalar r17 => c ranked r17 '()) - => CRankedIPU ranked c where -instance (forall r17. GoodScalar r17 => c ranked r17 '()) - => CRankedIPU ranked c where - instance ( Dual ranked ~ DeltaR ranked shaped , Dual shaped ~ DeltaS ranked shaped - , Dual (Clown (DynamicOf ranked)) - ~ DeltaD (Clown (DynamicOf ranked)) ranked shaped - , ConvertTensor ranked shaped - , CRankedIPU (Clown (DynamicOf ranked)) IsPrimal ) + , RankedTensor (ADVal ranked), ShapedTensor (ADVal shaped) + , ConvertTensor ranked shaped ) => ConvertTensor (ADVal ranked) (ADVal shaped) where - rfromD = dToR . runFlip rfromS = sToR where sToR :: forall r sh. (GoodScalar r, Sh.Shape sh) @@ -446,23 +341,8 @@ instance ( Dual ranked ~ DeltaR ranked shaped where dSToR (RToS d) = d -- no information lost, so no checks dSToR d = SToR d - dfromR = Flip . rToD - where - rToD :: forall r n. (GoodScalar r, KnownNat n) - => ADVal ranked r n -> ADVal (Clown (DynamicOf ranked)) r '() - rToD (D l u u') = dDnotShared l (Clown $ dfromR u) (dRToD u') - where - dRToD (DToR d) = d -- no information lost, so no checks - dRToD d = RToD d - dfromS = Flip . sToD - where - sToD :: forall r sh. 
(GoodScalar r, Sh.Shape sh) - => ADVal shaped r sh -> ADVal (Clown (DynamicOf ranked)) r '() - sToD (D l u u') = dDnotShared l (Clown $ dfromS u) (dSToD u') - where - dSToD (DToS d) = d -- no information lost, so no checks - dSToD d = SToD d - sfromD = dToS . runFlip + dfromR = DynamicRanked + dfromS = DynamicShaped sfromR = rToS where rToS :: forall r sh. (GoodScalar r, Sh.Shape sh, KnownNat (Sh.Rank sh)) @@ -474,15 +354,13 @@ instance ( Dual ranked ~ DeltaR ranked shaped Just Refl -> d _ -> error "rToS: different shapes in RToS(SToR)" dRToS d = RToS d - ddummy = Flip (constantADVal (Clown (ddummy @ranked))) - dIsDummy (Flip (D _ (Clown d) _)) = dIsDummy @ranked d - dshape (Flip (D _ u _)) = dshape @ranked (runClown u) + dIsDummy DynamicRankedDummy{} = True + dIsDummy DynamicShapedDummy{} = True + dIsDummy _ = False + dshape = shapeDynamic instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped) , UnletGradient ranked, UnletGradient shaped - , Dual (Clown (DynamicOf (ADVal ranked))) - ~ DeltaD (Clown (DynamicOf (ADVal ranked))) - (ADVal ranked) (ADVal shaped) , ShapedOf shaped ~ shaped ) => DomainsTensor (ADVal ranked) (ADVal shaped) where dmkDomains = id @@ -490,7 +368,7 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped) rletInDomains = (&) sletInDomains = (&) rrev :: (GoodScalar r, KnownNat n) - => (forall f. ADReady f => Domains (DynamicOf f) -> f r n) + => (forall f. ADReady f => Domains f -> f r n) -> DomainsOD -> DomainsOf (ADVal ranked) -> DomainsOf (ADVal ranked) @@ -498,7 +376,7 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped) -- This computes the derivative of f again for each new @parmeters@. fst $ crevOnDomains False Nothing (f @(ADVal (ADVal ranked))) parameters rrevDt :: (GoodScalar r, KnownNat n) - => (forall f. ADReady f => Domains (DynamicOf f) -> f r n) + => (forall f. 
ADReady f => Domains f -> f r n) -> DomainsOD -> DomainsOf (ADVal ranked) -> ADVal ranked r n @@ -506,7 +384,7 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped) rrevDt f _parameters0 parameters dt = fst $ crevOnDomains False (Just dt) (f @(ADVal (ADVal ranked))) parameters rfwd :: (GoodScalar r, KnownNat n) - => (forall f. ADReady f => Domains (DynamicOf f) -> f r n) + => (forall f. ADReady f => Domains f -> f r n) -> DomainsOD -> DomainsOf (ADVal ranked) -> DomainsOf (ADVal ranked) @@ -528,25 +406,11 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped) rfold f (D l1 x0 x0') (D l2 as as') = let shn = rshape x0 shm = tailShape $ rshape as - odFromSh :: forall rk k. GoodScalar rk - => ShapeInt k -> DynamicExists OD.Array - odFromSh sh = DynamicExists @rk $ OD.constant (shapeToList sh) 0 domsOD = V.fromList [odFromSh @rn shn, odFromSh @rm shm] domsToPair :: forall f. ADReady f - => Domains (DynamicOf f) -> (f rn n, f rm m) - domsToPair doms = - let d0 = case doms V.! 0 of - DynamicExists @rn2 ex - | Just Refl <- testEquality (typeRep @rn) (typeRep @rn2) -> - rfromD ex - _ -> error "rfold: type mismatch" - d1 = case doms V.! 1 of - DynamicExists @rm2 ex - | Just Refl <- testEquality (typeRep @rm) (typeRep @rm2) -> - rfromD ex - _ -> error "rfold: type mismatch" - in (d0, d1) - g :: Domains (DynamicOf (ADVal ranked)) -> ADVal ranked rn n + => Domains f -> (f rn n, f rm m) + domsToPair doms = (rfromD $ doms V.! 0, rfromD $ doms V.! 1) + g :: Domains (ADVal ranked) -> ADVal ranked rn n g doms = uncurry (f @(ADVal ranked)) (domsToPair doms) -- This computes the derivative of f again for each new @x@ and @a@ -- (not even once for @as@, but for each @a@ separately). 
@@ -560,19 +424,13 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped) df :: ranked rn n -> (ranked rm m, ranked rn n, ranked rm m) -> ranked rn n df cx (ca, x, a) = - fst $ cfwdOnDomains - (V.fromList [ DynamicExists @rn (dfromR x) - , DynamicExists @rm (dfromR a) ]) - g - (V.fromList [ DynamicExists @rn (dfromR cx) - , DynamicExists @rm (dfromR ca) ]) + fst $ cfwdOnDomains (V.fromList [dfromR x, dfromR a]) + g (V.fromList [dfromR cx, dfromR ca]) rf :: ranked rn n -> (ranked rn n, ranked rm m) -> (ranked rn n, ranked rm m) rf dt (x, a) = domsToPair $ dunDomains @ranked domsOD $ fst - $ crevOnDomains False (Just dt) g - (V.fromList [ DynamicExists @rn (dfromR x) - , DynamicExists @rm (dfromR a) ]) + $ crevOnDomains False (Just dt) g (V.fromList [dfromR x, dfromR a]) in D (l1 `mergeADShare` l2) (rfold @ranked f x0 as) (FoldR f x0 as df rf x0' as') @@ -592,9 +450,6 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped) -- on @DomainsOf f@. let shn = rshape x0 shm = tailShape $ rshape as - odFromSh :: forall rk k. GoodScalar rk - => ShapeInt k -> DynamicExists OD.Array - odFromSh sh = DynamicExists @rk $ OD.constant (shapeToList sh) 0 domsOD = V.fromList [odFromSh @rn shn, odFromSh @rm shm] -- Note that this function, and similarly @f@ and @rf@ instantiated -- and passed to FoldR, is not a function on dual numbers. @@ -607,18 +462,10 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped) let res = rf0 cx x a -- non-explicit sharing, so helps little in ( rletDomainsIn domsOD res - (\doms -> case doms V.! 0 of - DynamicExists @rn2 ex - | Just Refl <- testEquality (typeRep @rn) (typeRep @rn2) -> - rfromD ex - _ -> error "rfoldDer: type mismatch") + (\doms -> rfromD $ doms V.! 0) , rletDomainsIn domsOD res - (\doms -> case doms V.! 1 of - DynamicExists @rm2 ea - | Just Refl <- testEquality (typeRep @rm) (typeRep @rm2) -> - rfromD ea - _ -> error "rfoldDer: type mismatch") + (\doms -> rfromD $ doms V.! 
1) ) in D (l1 `mergeADShare` l2) (rfold @ranked f x0 as) @@ -630,42 +477,22 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped) -> ADVal shaped rm (k ': shm) -> ADVal shaped rn sh sfold f (D l1 x0 x0') (D l2 as as') = - let odFromSh :: forall rk sh1. (GoodScalar rk, Sh.Shape sh1) - => DynamicExists OD.Array - odFromSh = DynamicExists @rk $ OD.constant (Sh.shapeT @sh1) 0 - domsOD = V.fromList [odFromSh @rn @sh, odFromSh @rm @shm] + let domsOD = V.fromList [odFromShS @rn @sh, odFromShS @rm @shm] domsToPair :: forall f. ADReadyS f - => Domains (DynamicOf f) -> (f rn sh, f rm shm) - domsToPair doms = - let d0 = case doms V.! 0 of - DynamicExists @rn2 ex - | Just Refl <- testEquality (typeRep @rn) (typeRep @rn2) -> - sfromD ex - _ -> error "rfold: type mismatch" - d1 = case doms V.! 1 of - DynamicExists @rm2 ex - | Just Refl <- testEquality (typeRep @rm) (typeRep @rm2) -> - sfromD ex - _ -> error "rfold: type mismatch" - in (d0, d1) - g :: Domains (DynamicOf (ADVal shaped)) -> ADVal shaped rn sh + => Domains (RankedOf f) -> (f rn sh, f rm shm) + domsToPair doms = (sfromD $ doms V.! 0, sfromD $ doms V.! 
1) + g :: Domains (ADVal (RankedOf shaped)) -> ADVal shaped rn sh g doms = uncurry (f @(ADVal shaped)) (domsToPair doms) df :: shaped rn sh -> (shaped rm shm, shaped rn sh, shaped rm shm) -> shaped rn sh df cx (ca, x, a) = - fst $ cfwdOnDomains - (V.fromList [ DynamicExists @rn (dfromS x) - , DynamicExists @rm (dfromS a) ]) - g - (V.fromList [ DynamicExists @rn (dfromS cx) - , DynamicExists @rm (dfromS ca) ]) + fst $ cfwdOnDomains (V.fromList [dfromS x, dfromS a]) + g (V.fromList [dfromS cx, dfromS ca]) rf :: shaped rn sh -> (shaped rn sh, shaped rm shm) -> (shaped rn sh, shaped rm shm) rf dt (x, a) = domsToPair $ dunDomains @ranked domsOD $ fst - $ crevOnDomains False (Just dt) g - (V.fromList [ DynamicExists @rn (dfromS x) - , DynamicExists @rm (dfromS a) ]) + $ crevOnDomains False (Just dt) g (V.fromList [dfromS x, dfromS a]) in D (l1 `mergeADShare` l2) (sfold @ranked f x0 as) (FoldS f x0 as df rf x0' as') @@ -682,10 +509,7 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped) -> ADVal shaped rm (k ': shm) -> ADVal shaped rn sh sfoldDer f df0 rf0 (D l1 x0 x0') (D l2 as as') = - let odFromSh :: forall rk sh1. (GoodScalar rk, Sh.Shape sh1) - => DynamicExists OD.Array - odFromSh = DynamicExists @rk $ OD.constant (Sh.shapeT @sh1) 0 - domsOD = V.fromList [odFromSh @rn @sh, odFromSh @rm @shm] + let domsOD = V.fromList [odFromShS @rn @sh, odFromShS @rm @shm] -- Note that this function, and similarly @f@ and @rf@ instantiated -- and passed to FoldR, is not a function on dual numbers. df :: shaped rn sh -> (shaped rm shm, shaped rn sh, shaped rm shm) @@ -697,18 +521,10 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped) let res = rf0 cx x a -- non-explicit sharing, so helps little in ( sletDomainsIn domsOD res - (\doms -> case doms V.! 0 of - DynamicExists @rn2 ex - | Just Refl <- testEquality (typeRep @rn) (typeRep @rn2) -> - sfromD ex - _ -> error "rfoldDer: type mismatch") + (\doms -> sfromD $ doms V.! 
0) , sletDomainsIn domsOD res - (\doms -> case doms V.! 1 of - DynamicExists @rm2 ea - | Just Refl <- testEquality (typeRep @rm) (typeRep @rm2) -> - sfromD ea - _ -> error "rfoldDer: type mismatch") + (\doms -> sfromD $ doms V.! 1) ) in D (l1 `mergeADShare` l2) (sfold @ranked f x0 as) @@ -723,14 +539,14 @@ instance DomainsTensor (Flip OR.Array) (Flip OS.Array) where rletInDomains = (&) sletInDomains = (&) rrev :: (GoodScalar r, KnownNat n) - => (forall f. ADReady f => Domains (DynamicOf f) -> f r n) + => (forall f. ADReady f => Domains f -> f r n) -> DomainsOD -> DomainsOD -> DomainsOD rrev f _parameters0 parameters = fst $ crevOnDomains False Nothing (f @(ADVal (Flip OR.Array))) parameters rrevDt :: (GoodScalar r, KnownNat n) - => (forall f. ADReady f => Domains (DynamicOf f) -> f r n) + => (forall f. ADReady f => Domains f -> f r n) -> DomainsOD -> DomainsOD -> Flip OR.Array r n @@ -738,7 +554,7 @@ instance DomainsTensor (Flip OR.Array) (Flip OS.Array) where rrevDt f _parameters0 parameters dt = fst $ crevOnDomains False (Just dt) (f @(ADVal (Flip OR.Array))) parameters rfwd :: (GoodScalar r, KnownNat n) - => (forall f. ADReady f => Domains (DynamicOf f) -> f r n) + => (forall f. 
ADReady f => Domains f -> f r n) -> DomainsOD -> DomainsOD -> DomainsOD diff --git a/src/HordeAd/Core/TensorAst.hs b/src/HordeAd/Core/TensorAst.hs index 91a96910c..e0195c3c5 100644 --- a/src/HordeAd/Core/TensorAst.hs +++ b/src/HordeAd/Core/TensorAst.hs @@ -12,20 +12,17 @@ module HordeAd.Core.TensorAst import Prelude -import qualified Data.Array.DynamicS as OD import qualified Data.Array.RankedS as OR import qualified Data.Array.Shape as Sh import qualified Data.Array.ShapedS as OS -import Data.Bifunctor.Clown import Data.Bifunctor.Flip import qualified Data.EnumMap.Strict as EM import Data.Maybe (fromMaybe) import Data.Proxy (Proxy (Proxy)) -import Data.Type.Equality (testEquality, (:~:) (Refl)) +import Data.Type.Equality ((:~:) (Refl)) import qualified Data.Vector.Generic as V -import GHC.TypeLits (KnownNat, Nat, sameNat, type (+)) +import GHC.TypeLits (KnownNat, Nat, type (+)) import System.IO.Unsafe (unsafePerformIO) -import Type.Reflection (typeRep) import HordeAd.Core.Adaptor import HordeAd.Core.Ast @@ -42,8 +39,6 @@ import HordeAd.Core.DualNumber import HordeAd.Core.TensorADVal () import HordeAd.Core.TensorClass import HordeAd.Core.Types -import HordeAd.Internal.OrthotopeOrphanInstances - (matchingRank, sameShape) import HordeAd.Util.ShapedList (singletonShaped) import qualified HordeAd.Util.ShapedList as ShapedList import HordeAd.Util.SizedIndex @@ -63,7 +58,6 @@ instance (GoodScalar r, KnownNat n) => IsPrimal (AstRanked PrimalSpan) r n where recordSharing d = case d of ZeroR{} -> d InputR{} -> d - DToR{} -> d SToR{} -> d LetR{} -> d -- should not happen, but older/lower id is safer anyway _ -> wrapDeltaR d @@ -82,22 +76,10 @@ instance (GoodScalar r, Sh.Shape sh) recordSharing d = case d of ZeroS -> d InputS{} -> d - DToS{} -> d RToS{} -> d LetS{} -> d -- should not happen, but older/lower id is safer anyway _ -> wrapDeltaS d -instance GoodScalar r => IsPrimal (Clown (AstDynamic PrimalSpan)) r '() where - dZeroOfShape (Clown tsh) = - withListShape (dshape 
@(AstRanked PrimalSpan) tsh) - $ \ (sh :: Shape n Int) -> - RToD @n (ZeroR sh) - dScale = undefined - dAdd = undefined - intOfShape = undefined - recordSharingPrimal = undefined - recordSharing = undefined - -- * Reverse and forward derivative stages instances @@ -112,15 +94,15 @@ instance GoodScalar r => IsPrimal (Clown (AstDynamic PrimalSpan)) r '() where instance DerivativeStages (AstRanked FullSpan) where forwardPassByInterpretation :: (GoodScalar r, KnownNat n) - => (Domains (AstDynamic FullSpan) -> AstRanked FullSpan r n) + => (Domains (AstRanked FullSpan) -> AstRanked FullSpan r n) -> AstEnv (ADVal (AstRanked PrimalSpan)) (ADVal (AstShaped PrimalSpan)) - -> Domains (AstDynamic PrimalSpan) + -> Domains (AstRanked PrimalSpan) -> [AstDynamicVarName] - -> Domains (AstDynamic FullSpan) + -> Domains (AstRanked FullSpan) -> ADVal (AstRanked PrimalSpan) r n {-# INLINE forwardPassByInterpretation #-} forwardPassByInterpretation g envInit domainsPrimal vars domains = - let deltaInputs = generateDeltaInputsAst domainsPrimal + let deltaInputs = generateDeltaInputs domainsPrimal varInputs = makeADInputs domainsPrimal deltaInputs ast = g domains env = foldr extendEnvD envInit $ zip vars $ V.toList varInputs @@ -129,9 +111,9 @@ instance DerivativeStages (AstRanked FullSpan) where revArtifactFromForwardPass :: forall r n. (GoodScalar r, KnownNat n) => TensorFunctor (AstRanked FullSpan) -> Bool -> Bool - -> (Domains (AstDynamic PrimalSpan) + -> (Domains (AstRanked PrimalSpan) -> [AstDynamicVarName] - -> Domains (AstDynamic FullSpan) + -> Domains (AstRanked FullSpan) -> ADVal (AstRanked PrimalSpan) r n) -> DomainsOD -> ( AstArtifactRev (AstRanked PrimalSpan) r n @@ -173,9 +155,9 @@ instance DerivativeStages (AstRanked FullSpan) where fwdArtifactFromForwardPass :: forall r n. 
(GoodScalar r, KnownNat n) - => TensorFunctor (AstRanked FullSpan) -> (Domains (AstDynamic PrimalSpan) + => TensorFunctor (AstRanked FullSpan) -> (Domains (AstRanked PrimalSpan) -> [AstDynamicVarName] - -> Domains (AstDynamic FullSpan) + -> Domains (AstRanked FullSpan) -> ADVal (AstRanked PrimalSpan) r n) -> DomainsOD -> ( AstArtifactFwd (AstRanked PrimalSpan) r n @@ -204,7 +186,7 @@ instance DerivativeStages (AstRanked FullSpan) where instance UnletGradient (AstRanked PrimalSpan) where unletGradient - :: ADShare -> AstBindings -> Domains (AstDynamic PrimalSpan) + :: ADShare -> AstBindings -> Domains (AstRanked PrimalSpan) -> AstDomains PrimalSpan unletGradient l astBindings gradient = unletAstDomains6 astBindings l @@ -219,15 +201,15 @@ instance UnletGradient (AstRanked PrimalSpan) where instance DerivativeStages (AstShaped FullSpan) where forwardPassByInterpretation :: (GoodScalar r, Sh.Shape sh) - => (Domains (AstDynamic FullSpan) -> AstShaped FullSpan r sh) + => (Domains (AstRanked FullSpan) -> AstShaped FullSpan r sh) -> AstEnv (ADVal (AstRanked PrimalSpan)) (ADVal (AstShaped PrimalSpan)) - -> Domains (AstDynamic PrimalSpan) + -> Domains (AstRanked PrimalSpan) -> [AstDynamicVarName] - -> Domains (AstDynamic FullSpan) + -> Domains (AstRanked FullSpan) -> ADVal (AstShaped PrimalSpan) r sh {-# INLINE forwardPassByInterpretation #-} forwardPassByInterpretation g envInit domainsPrimal vars domains = - let deltaInputs = generateDeltaInputsAst domainsPrimal + let deltaInputs = generateDeltaInputs domainsPrimal varInputs = makeADInputs domainsPrimal deltaInputs ast = g domains env = foldr extendEnvD envInit $ zip vars $ V.toList varInputs @@ -236,9 +218,9 @@ instance DerivativeStages (AstShaped FullSpan) where revArtifactFromForwardPass :: forall r sh. 
(GoodScalar r, Sh.Shape sh) => TensorFunctor (AstShaped FullSpan) -> Bool -> Bool - -> (Domains (AstDynamic PrimalSpan) + -> (Domains (AstRanked PrimalSpan) -> [AstDynamicVarName] - -> Domains (AstDynamic FullSpan) + -> Domains (AstRanked FullSpan) -> ADVal (AstShaped PrimalSpan) r sh) -> DomainsOD -> ( AstArtifactRev (AstShaped PrimalSpan) r sh @@ -270,9 +252,9 @@ instance DerivativeStages (AstShaped FullSpan) where fwdArtifactFromForwardPass :: forall r sh. (GoodScalar r, Sh.Shape sh) - => TensorFunctor (AstShaped FullSpan) -> (Domains (AstDynamic PrimalSpan) + => TensorFunctor (AstShaped FullSpan) -> (Domains (AstRanked PrimalSpan) -> [AstDynamicVarName] - -> Domains (AstDynamic FullSpan) + -> Domains (AstRanked FullSpan) -> ADVal (AstShaped PrimalSpan) r sh) -> DomainsOD -> ( AstArtifactFwd (AstShaped PrimalSpan) r sh @@ -299,7 +281,7 @@ instance DerivativeStages (AstShaped FullSpan) where instance UnletGradient (AstShaped PrimalSpan) where unletGradient - :: ADShare -> AstBindings -> Domains (AstDynamic PrimalSpan) + :: ADShare -> AstBindings -> Domains (AstRanked PrimalSpan) -> AstDomains PrimalSpan unletGradient l astBindings gradient = unletAstDomains6 astBindings l @@ -413,22 +395,6 @@ instance AstSpan s AstLetADShare l t -> (l, t) AstConstant (AstLetADShare l t) -> (l, AstConstant t) _ -> (emptyADShare, u) - raddDynamic :: forall n r. 
(GoodScalar r, KnownNat n) - => AstRanked s r n -> DynamicExists (AstDynamic s) - -> DynamicExists (AstDynamic s) - raddDynamic r (DynamicExists @r2 d) = DynamicExists @r $ - case d of - _ | isTensorDummyAst d -> AstRToD r - AstRToD @n2 v -> - case ( sameNat (Proxy @n) (Proxy @n2) - , testEquality (typeRep @r) (typeRep @r2) ) of - (Just Refl, Just Refl) -> AstRToD @n2 @r (r + v) - _ -> error "raddDynamic: type mismatch" - AstSToD @sh2 v -> - case ( matchingRank @sh2 @n - , testEquality (typeRep @r) (typeRep @r2) ) of - (Just Refl, Just Refl) -> AstSToD @sh2 @r (astRToS r + v) - _ -> error "raddDynamic: type mismatch" rregister = astRegisterFun rconstant = fromPrimal @@ -440,48 +406,34 @@ instance AstSpan s instance ( GoodScalar r, KnownNat n , RankedTensor (AstRanked s) , ConvertTensor (AstRanked s) (AstShaped s) ) - => AdaptableDomains (AstDynamic s) (AstRanked s r n) where + => AdaptableDomains (AstRanked s) (AstRanked s r n) where {-# SPECIALIZE instance (KnownNat n, AstSpan s) - => AdaptableDomains (AstDynamic s) (AstRanked s Double n) #-} + => AdaptableDomains (AstRanked s) (AstRanked s Double n) #-} type Value (AstRanked s r n) = Flip OR.Array r n toDomains = undefined - fromDomains aInit params = case V.uncons params of - Just (DynamicExists @r2 a, rest) -> - if isTensorDummyAst a then Just (rzero (rshape aInit), rest) else - case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> let !t = rfromD @(AstRanked s) @(AstShaped s) @r a - in Just (t, rest) - _ -> error $ "fromDomains: type mismatch: " - ++ show (typeRep @r) ++ " " ++ show (typeRep @r2) - Nothing -> Nothing - -isTensorDummyAst :: AstDynamic s r -> Bool -isTensorDummyAst t = case t of - AstRToD AstIota -> True - AstRToD (AstConstant AstIota) -> True - AstRToD (AstDualPart (AstConstant AstIota)) -> True - AstSToD AstIotaS -> True - AstSToD (AstConstantS AstIotaS) -> True - AstSToD (AstDualPartS (AstConstantS AstIotaS)) -> True - _ -> False + fromDomains _aInit params = fromDomainsR @r @n 
params -- TODO: move the impure part to AstFreshId astLetDomainsInFun :: forall n s r. (AstSpan s, GoodScalar r, KnownNat n) - => DomainsOD -> AstDomains s -> (Domains (AstDynamic s) -> AstRanked s r n) + => DomainsOD -> AstDomains s -> (Domains (AstRanked s) -> AstRanked s r n) -> AstRanked s r n {-# NOINLINE astLetDomainsInFun #-} astLetDomainsInFun a0 a f = unsafePerformIO $ do - let genVar :: DynamicExists OD.Array - -> IO (AstDynamicVarName, DynamicExists (AstDynamic s)) - genVar (DynamicExists @r2 t) = do - let sh2 = OD.shapeL t - Sh.withShapeP sh2 $ \(Proxy @p_sh2) -> - withListShape sh2 $ \ (sh3 :: Shape n3 Int) -> do - (var, _, ast) <- funToAstIOR @n3 sh3 id - return ( AstDynamicVarName @Nat @r2 @p_sh2 var - , DynamicExists $ AstRToD ast ) + let genVar :: DynamicTensor (Flip OR.Array) + -> IO (AstDynamicVarName, DynamicTensor (AstRanked s)) + genVar (DynamicRankedDummy @r2 @sh2 _ _) = do + let sh2 = Sh.shapeT @sh2 + withListShape sh2 $ \ (sh3 :: Shape n2 Int) -> do + (var, _, ast) <- funToAstIOR @n2 sh3 id + return ( AstDynamicVarName @Nat @r2 @sh2 var + , DynamicRanked ast ) + genVar (DynamicShapedDummy @r2 @sh2 _ _) = do + (var, _, ast) <- funToAstIOS @sh2 id + return ( AstDynamicVarName @[Nat] @r2 @sh2 var + , DynamicShaped ast ) + genVar _ = error "genVar: unexpected OD value" (vars, asts) <- unzip <$> mapM genVar (V.toList a0) return $! astLetDomainsIn vars a (f $ V.fromList asts) @@ -580,22 +532,6 @@ instance AstSpan s AstLetADShareS l t -> (l, t) AstConstantS (AstLetADShareS l t) -> (l, AstConstantS t) _ -> (emptyADShare, u) - saddDynamic :: forall sh r. 
(GoodScalar r, Sh.Shape sh) - => AstShaped s r sh -> DynamicExists (AstDynamic s) - -> DynamicExists (AstDynamic s) - saddDynamic r (DynamicExists @r2 d) = DynamicExists @r $ - case d of - _ | isTensorDummyAst d -> AstSToD r - AstSToD @sh2 v -> - case ( sameShape @sh @sh2 - , testEquality (typeRep @r) (typeRep @r2) ) of - (Just Refl, Just Refl) -> AstSToD @sh2 @r (r + v) - _ -> error "saddDynamic: type mismatch" - AstRToD @n2 v -> - case ( matchingRank @sh @n2 - , testEquality (typeRep @r) (typeRep @r2) ) of - (Just Refl, Just Refl) -> AstRToD @n2 @r (astSToR r + v) - _ -> error "saddDynamic: type mismatch" sregister = astRegisterFunS sconstant = fromPrimalS @@ -607,32 +543,31 @@ instance AstSpan s instance ( GoodScalar r, Sh.Shape sh , ShapedTensor (AstShaped s) , ConvertTensor (AstRanked s) (AstShaped s) ) - => AdaptableDomains (AstDynamic s) (AstShaped s r sh) where + => AdaptableDomains (AstRanked s) (AstShaped s r sh) where type Value (AstShaped s r sh) = Flip OS.Array r sh toDomains = undefined - fromDomains _aInit params = case V.uncons params of - Just (DynamicExists @r2 a, rest) -> - if isTensorDummyAst a then Just (0, rest) else - case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> let !t = sfromD @(AstRanked s) @(AstShaped s) @r a - in Just (t, rest) - _ -> error "fromDomains: type mismatch" - Nothing -> Nothing + fromDomains _aInit params = fromDomainsS @r @sh params +-- TODO: dedup with astLetDomainsInFun astLetDomainsInFunS :: forall sh s r. 
(AstSpan s, Sh.Shape sh) - => DomainsOD -> AstDomains s -> (Domains (AstDynamic s) -> AstShaped s r sh) + => DomainsOD -> AstDomains s -> (Domains (AstRanked s) -> AstShaped s r sh) -> AstShaped s r sh {-# NOINLINE astLetDomainsInFunS #-} astLetDomainsInFunS a0 a f = unsafePerformIO $ do - let genVar :: DynamicExists OD.Array - -> IO (AstDynamicVarName, DynamicExists (AstDynamic s)) - genVar (DynamicExists @r2 t) = do - let sh2 = OD.shapeL t - Sh.withShapeP sh2 $ \(Proxy @p_sh2) -> do - (var, _, ast) <- funToAstIOS @p_sh2 id - return ( AstDynamicVarName @[Nat] @r2 @p_sh2 var - , DynamicExists $ AstSToD ast ) + let genVar :: DynamicTensor (Flip OR.Array) + -> IO (AstDynamicVarName, DynamicTensor (AstRanked s)) + genVar (DynamicRankedDummy @r2 @sh2 _ _) = do + let sh2 = Sh.shapeT @sh2 + withListShape sh2 $ \ (sh3 :: Shape n2 Int) -> do + (var, _, ast) <- funToAstIOR @n2 sh3 id + return ( AstDynamicVarName @Nat @r2 @sh2 var + , DynamicRanked ast ) + genVar (DynamicShapedDummy @r2 @sh2 _ _) = do + (var, _, ast) <- funToAstIOS @sh2 id + return ( AstDynamicVarName @[Nat] @r2 @sh2 var + , DynamicShaped ast ) + genVar _ = error "genVar: unexpected OD value" (vars, asts) <- unzip <$> mapM genVar (V.toList a0) return $! 
astLetDomainsInS vars a (f $ V.fromList asts) @@ -681,40 +616,35 @@ astBuild1VectorizeS f = -- * ConvertTensor and DomainsTensor instances instance AstSpan s => ConvertTensor (AstRanked s) (AstShaped s) where - rfromD = astFromDynamic rfromS = astSToR - dfromR = AstRToD - dfromS = AstSToD + dfromR = DynamicRanked + dfromS = DynamicShaped sfromR = astRToS - sfromD = astFromDynamicS - ddummy = AstRToD $ fromPrimal AstIota - dIsDummy = isTensorDummyAst - dshape (AstRToD v) = shapeToList $ shapeAst v - dshape (AstSToD @sh _) = Sh.shapeT @sh + dIsDummy DynamicRankedDummy{} = True + dIsDummy DynamicShapedDummy{} = True + dIsDummy _ = False + dshape = shapeDynamic instance AstSpan s => DomainsTensor (AstRanked s) (AstShaped s) where dmkDomains = AstDomains dunDomains od domainsOf = - let f :: forall r n. (GoodScalar r, KnownNat n) - => Int -> Domains (AstDynamic s) -> AstRanked s r n - f i d = case d V.! i of - DynamicExists (AstRToD @n2 @r2 w) - | Just Refl <- testEquality (typeRep @r2) (typeRep @r) - , Just Refl <- sameNat (Proxy @n2) (Proxy @n) -> w - DynamicExists (AstSToD @sh2 @r2 w) - | Just Refl <- testEquality (typeRep @r2) (typeRep @r) - , Just Refl <- matchingRank @sh2 @n -> rfromS w - _ -> error "dunDomains: type mismatch with od" - in V.imap (\i (DynamicExists @r a) -> - withListShape (dshape @(Flip OR.Array) a) $ \ (_ :: Shape n Int) -> - DynamicExists $ dfromR @(AstRanked s) @(AstShaped s) @r @n - $ rletDomainsIn @(AstRanked s) od domainsOf (f i)) od + let f :: Int -> DynamicTensor (Flip OR.Array) -> DynamicTensor (AstRanked s) + f i = \case + DynamicRankedDummy @r @sh _ _ -> + withListShape (Sh.shapeT @sh) $ \(_ :: Shape n Int) -> + DynamicRanked @r @n + $ rletDomainsIn @(AstRanked s) od domainsOf (rfromD . (V.! i)) + DynamicShapedDummy @r @sh _ _ -> + DynamicShaped @r @sh + $ sletDomainsIn @(AstShaped s) od domainsOf (sfromD . (V.! 
i)) + _ -> error "dunDomains: unexpected OD value" + in V.imap f od rletInDomains = astLetInDomainsFun sletInDomains = astLetInDomainsFunS rrev :: (GoodScalar r, KnownNat n) - => (forall f. ADReady f => Domains (DynamicOf f) -> f r n) + => (forall f. ADReady f => Domains f -> f r n) -> DomainsOD - -> Domains (AstDynamic s) + -> Domains (AstRanked s) -> AstDomains s rrev f parameters0 = -- This computes the (AST of) derivative of f once and interprets it again @@ -730,9 +660,9 @@ instance AstSpan s => DomainsTensor (AstRanked s) (AstShaped s) where -- we could shortcut when @s@ is @PrimalSpan@ and @parameters@ -- are the same variables, but it's a very special case rrevDt :: (GoodScalar r, KnownNat n) - => (forall f. ADReady f => Domains (DynamicOf f) -> f r n) + => (forall f. ADReady f => Domains f -> f r n) -> DomainsOD - -> Domains (AstDynamic s) + -> Domains (AstRanked s) -> AstRanked s r n -> AstDomains s rrevDt f parameters0 = @@ -744,10 +674,10 @@ instance AstSpan s => DomainsTensor (AstRanked s) (AstShaped s) where envDt = extendEnvR varDt dt env in interpretAstDomains envDt gradient rfwd :: (GoodScalar r, KnownNat n) - => (forall f. ADReady f => Domains (DynamicOf f) -> f r n) + => (forall f. ADReady f => Domains f -> f r n) -> DomainsOD - -> Domains (AstDynamic s) - -> Domains (AstDynamic s) + -> Domains (AstRanked s) + -> Domains (AstRanked s) -> AstRanked s r n rfwd f parameters0 = let (((varsDt, vars), derivative, _primal), _delta) = @@ -787,21 +717,12 @@ instance AstSpan s => DomainsTensor (AstRanked s) (AstShaped s) where -> AstRanked s rm (1 + m) -> AstRanked s rn n rfold f x0 as = - let domsToPair :: forall f. ADReady f - => Domains (DynamicOf f) -> (f rn n, f rm m) - domsToPair doms = case (doms V.! 0, doms V.! 
1) of - (DynamicExists @rn2 ex, DynamicExists @rm2 ea) - | Just Refl <- testEquality (typeRep @rn) (typeRep @rn2) - , Just Refl <- testEquality (typeRep @rm) (typeRep @rm2) -> - (rfromD ex, rfromD ea) - _ -> error "rfold: type mismatch" - g :: Domains (DynamicOf (AstRanked FullSpan)) -> AstRanked FullSpan rn n + let domsToPair :: forall f. ADReady f => Domains f -> (f rn n, f rm m) + domsToPair doms = (rfromD $ doms V.! 0, rfromD $ doms V.! 1) + g :: Domains (AstRanked FullSpan) -> AstRanked FullSpan rn n g doms = uncurry f (domsToPair doms) shn = rshape x0 shm = tailShape $ rshape as - odFromSh :: forall rk k. GoodScalar rk - => ShapeInt k -> DynamicExists OD.Array - odFromSh sh = DynamicExists @rk $ OD.constant (shapeToList sh) 0 parameters0 = V.fromList [odFromSh @rn shn, odFromSh @rm shm] in -- This computes the (AST of) derivative of f once for each @x0@ -- and @as@, which is better than once for each @a@. We could compute @@ -854,20 +775,11 @@ instance AstSpan s => DomainsTensor (AstRanked s) (AstShaped s) where -> AstShaped s rn sh sfold f x0 as = let domsToPair :: forall f. ADReadyS f - => Domains (DynamicOf f) -> (f rn sh, f rm shm) - domsToPair doms = case (doms V.! 0, doms V.! 1) of - (DynamicExists @rn2 ex, DynamicExists @rm2 ea) - | Just Refl <- testEquality (typeRep @rn) (typeRep @rn2) - , Just Refl <- testEquality (typeRep @rm) (typeRep @rm2) -> - (sfromD ex, sfromD ea) - _ -> error "sfold: type mismatch" - g :: Domains (DynamicOf (AstShaped FullSpan)) - -> AstShaped FullSpan rn sh + => Domains (RankedOf f) -> (f rn sh, f rm shm) + domsToPair doms = (sfromD $ doms V.! 0, sfromD $ doms V.! 1) + g :: Domains (AstRanked FullSpan) -> AstShaped FullSpan rn sh g doms = uncurry f (domsToPair doms) - odFromSh :: forall rk sh1. 
(GoodScalar rk, Sh.Shape sh1) - => DynamicExists OD.Array - odFromSh = DynamicExists @rk $ OD.constant (Sh.shapeT @sh1) 0 - domsOD = V.fromList [odFromSh @rn @sh, odFromSh @rm @shm] + domsOD = V.fromList [odFromShS @rn @sh, odFromShS @rm @shm] in case revProduceArtifact TensorFunctor False True g EM.empty domsOD of ( ( ( varDt , [ AstDynamicVarName (AstVarName nid) @@ -1059,8 +971,8 @@ instance AstSpan s rconst = AstNoVectorize . fromPrimal . AstConst rletDomainsIn a0 a f = - AstNoVectorize $ astLetDomainsInFun a0 a (unAstNoVectorize . f) - raddDynamic = undefined + AstNoVectorize $ astLetDomainsInFun + a0 a (unAstNoVectorize . f . noVectorizeDomains) rconstant = AstNoVectorize . fromPrimal rprimalPart = astSpanPrimal . unAstNoVectorize @@ -1115,8 +1027,8 @@ instance AstSpan s => ShapedTensor (AstNoVectorizeS s) where sconst = AstNoVectorizeS . fromPrimalS . AstConstS sletDomainsIn a0 a f = - AstNoVectorizeS $ astLetDomainsInFunS a0 a (unAstNoVectorizeS . f) - saddDynamic = undefined + AstNoVectorizeS $ astLetDomainsInFunS + a0 a (unAstNoVectorizeS . f . noVectorizeDomains) sconstant = AstNoVectorizeS . fromPrimalS -- exceptionally we do simplify AstConstant to avoid long boring chains @@ -1126,61 +1038,71 @@ instance AstSpan s => ShapedTensor (AstNoVectorizeS s) where sScale s t = astDualPartS $ AstConstantS s * AstDS 0 t instance AstSpan s => ConvertTensor (AstNoVectorize s) (AstNoVectorizeS s) where - rfromD = AstNoVectorize . rfromD @(AstRanked s) rfromS = AstNoVectorize . rfromS @(AstRanked s) . unAstNoVectorizeS - dfromR = dfromR @(AstRanked s) . unAstNoVectorize - dfromS = dfromS @(AstRanked s) . unAstNoVectorizeS + dfromR = DynamicRanked + dfromS = DynamicShaped sfromR = AstNoVectorizeS . sfromR @(AstRanked s) . unAstNoVectorize - sfromD = AstNoVectorizeS . 
sfromD @(AstRanked s) - ddummy = ddummy @(AstRanked s) - dIsDummy = dIsDummy @(AstRanked s) - dshape = dshape @(AstRanked s) + dIsDummy DynamicRankedDummy{} = True + dIsDummy DynamicShapedDummy{} = True + dIsDummy _ = False + dshape = shapeDynamic instance AstSpan s => DomainsTensor (AstNoVectorize s) (AstNoVectorizeS s) where - dmkDomains = dmkDomains @(AstRanked s) - dunDomains = dunDomains @(AstRanked s) + dmkDomains domains = dmkDomains @(AstRanked s) (unNoVectorizeDomains domains) + dunDomains parameters0 doms = + noVectorizeDomains $ dunDomains @(AstRanked s) parameters0 doms rletInDomains u f = rletInDomains @(AstRanked s) (unAstNoVectorize u) (f . AstNoVectorize) sletInDomains u f = sletInDomains @(AstRanked s) (unAstNoVectorizeS u) (f . AstNoVectorizeS) - rrev f parameters0 domains = AstRev (funToAstDomains f parameters0) domains + rrev f parameters0 domains = + rrev @(AstRanked s) f parameters0 (unNoVectorizeDomains domains) rrevDt f parameters0 domains dt = - AstRevDt (funToAstDomains f parameters0) domains (unAstNoVectorize dt) + rrevDt @(AstRanked s) f parameters0 + (unNoVectorizeDomains domains) (unAstNoVectorize dt) rfwd f parameters0 domains ds = - AstNoVectorize $ AstFwd (funToAstDomains f parameters0) domains ds - srev f parameters0 domains = AstRevS (funToAstDomainsS f parameters0) domains + AstNoVectorize + $ rfwd @(AstRanked s) f parameters0 + (unNoVectorizeDomains domains) (unNoVectorizeDomains ds) + srev f parameters0 domains = + srev @(AstRanked s) f parameters0 (unNoVectorizeDomains domains) srevDt f parameters0 domains dt = - AstRevDtS (funToAstDomainsS f parameters0) domains (unAstNoVectorizeS dt) + srevDt @(AstRanked s) f parameters0 + (unNoVectorizeDomains domains) (unAstNoVectorizeS dt) sfwd f parameters0 domains ds = - AstNoVectorizeS $ AstFwdS (funToAstDomainsS f parameters0) domains ds + AstNoVectorizeS + $ sfwd @(AstRanked s) f parameters0 + (unNoVectorizeDomains domains) (unNoVectorizeDomains ds) rfold f x0 as = - let shn = rshape 
(unAstNoVectorize x0) - shm = tailShape $ rshape (unAstNoVectorize as) - in AstNoVectorize - $ AstFold (fun2ToAstR shn shm f) - (unAstNoVectorize x0) - (unAstNoVectorize as) + AstNoVectorize + $ rfold @(AstRanked s) f (unAstNoVectorize x0) (unAstNoVectorize as) rfoldDer f df rf x0 as = - let shn = rshape (unAstNoVectorize x0) - shm = tailShape $ rshape (unAstNoVectorize as) - in AstNoVectorize - $ AstFoldDer (fun2ToAstR shn shm f) - (fun4ToAstR shn shm df) - (fun3ToAstR shn shm rf) - (unAstNoVectorize x0) - (unAstNoVectorize as) + AstNoVectorize + $ rfoldDer @(AstRanked s) + f df rf (unAstNoVectorize x0) (unAstNoVectorize as) sfold f x0 as = AstNoVectorizeS - $ AstFoldS (fun2ToAstS f) - (unAstNoVectorizeS x0) - (unAstNoVectorizeS as) + $ sfold @(AstRanked s) f (unAstNoVectorizeS x0) (unAstNoVectorizeS as) sfoldDer f df rf x0 as = AstNoVectorizeS - $ AstFoldDerS (fun2ToAstS f) - (fun4ToAstS df) - (fun3ToAstS rf) - (unAstNoVectorizeS x0) - (unAstNoVectorizeS as) + $ sfoldDer @(AstRanked s) + f df rf (unAstNoVectorizeS x0) (unAstNoVectorizeS as) + +unNoVectorizeDomains :: Domains (AstNoVectorize s) -> Domains (AstRanked s) +unNoVectorizeDomains = + let f (DynamicRanked (AstNoVectorize t)) = DynamicRanked t + f (DynamicShaped (AstNoVectorizeS t)) = DynamicShaped t + f (DynamicRankedDummy p1 p2) = DynamicRankedDummy p1 p2 + f (DynamicShapedDummy p1 p2) = DynamicShapedDummy p1 p2 + in V.map f + +noVectorizeDomains :: Domains (AstRanked s) -> Domains (AstNoVectorize s) +noVectorizeDomains = + let f (DynamicRanked t) = DynamicRanked $ AstNoVectorize t + f (DynamicShaped t) = DynamicShaped $ AstNoVectorizeS t + f (DynamicRankedDummy p1 p2) = DynamicRankedDummy p1 p2 + f (DynamicShapedDummy p1 p2) = DynamicShapedDummy p1 p2 + in V.map f instance AstSpan s => RankedTensor (AstNoSimplify s) where rlet a f = @@ -1225,8 +1147,8 @@ instance AstSpan s => RankedTensor (AstNoSimplify s) where rconst = AstNoSimplify . fromPrimal . 
AstConst rletDomainsIn a0 a f = - AstNoSimplify $ astLetDomainsInFun a0 a (unAstNoSimplify . f) - raddDynamic = undefined + AstNoSimplify $ astLetDomainsInFun + a0 a (unAstNoSimplify . f . noSimplifyDomains) rconstant = AstNoSimplify . fromPrimal -- exceptionally we do simplify AstConstant to avoid long boring chains @@ -1296,8 +1218,8 @@ instance AstSpan s => ShapedTensor (AstNoSimplifyS s) where sconst = AstNoSimplifyS . fromPrimalS . AstConstS sletDomainsIn a0 a f = - AstNoSimplifyS $ astLetDomainsInFunS a0 a (unAstNoSimplifyS . f) - saddDynamic = undefined + AstNoSimplifyS $ astLetDomainsInFunS + a0 a (unAstNoSimplifyS . f . noSimplifyDomains) sconstant = AstNoSimplifyS . fromPrimalS -- exceptionally we do simplify AstConstant to avoid long boring chains @@ -1307,33 +1229,39 @@ instance AstSpan s => ShapedTensor (AstNoSimplifyS s) where sScale s t = astDualPartS $ AstConstantS s * AstDS 0 t instance AstSpan s => ConvertTensor (AstNoSimplify s) (AstNoSimplifyS s) where - rfromD = AstNoSimplify . rfromD @(AstRanked s) rfromS = AstNoSimplify . rfromS @(AstRanked s) . unAstNoSimplifyS - dfromR = dfromR @(AstRanked s) . unAstNoSimplify - dfromS = dfromS @(AstRanked s) . unAstNoSimplifyS + dfromR = DynamicRanked + dfromS = DynamicShaped sfromR = AstNoSimplifyS . sfromR @(AstRanked s) . unAstNoSimplify - sfromD = AstNoSimplifyS . 
sfromD @(AstRanked s) - ddummy = ddummy @(AstRanked s) - dIsDummy = dIsDummy @(AstRanked s) - dshape = dshape @(AstRanked s) + dIsDummy DynamicRankedDummy{} = True + dIsDummy DynamicShapedDummy{} = True + dIsDummy _ = False + dshape = shapeDynamic instance AstSpan s => DomainsTensor (AstNoSimplify s) (AstNoSimplifyS s) where - dmkDomains = dmkDomains @(AstRanked s) - dunDomains = dunDomains @(AstRanked s) + dmkDomains domains = dmkDomains @(AstRanked s) (unNoSimplifyDomains domains) + dunDomains parameters0 doms = + noSimplifyDomains $ dunDomains @(AstRanked s) parameters0 doms rletInDomains u f = rletInDomains @(AstRanked s) (unAstNoSimplify u) (f . AstNoSimplify) sletInDomains u f = sletInDomains @(AstRanked s) (unAstNoSimplifyS u) (f . AstNoSimplifyS) - rrev = rrev @(AstRanked s) + rrev f parameters0 domains = + rrev @(AstRanked s) f parameters0 (unNoSimplifyDomains domains) rrevDt f parameters0 domains dt = - rrevDt @(AstRanked s) f parameters0 domains (unAstNoSimplify dt) + rrevDt @(AstRanked s) f parameters0 + (unNoSimplifyDomains domains) (unAstNoSimplify dt) rfwd f parameters0 domains ds = - AstNoSimplify $ rfwd @(AstRanked s) f parameters0 domains ds - srev = srev @(AstRanked s) + AstNoSimplify $ rfwd @(AstRanked s) f parameters0 + (unNoSimplifyDomains domains) (unNoSimplifyDomains ds) + srev f parameters0 domains = + srev @(AstRanked s) f parameters0 (unNoSimplifyDomains domains) srevDt f parameters0 domains dt = - srevDt @(AstRanked s) f parameters0 domains (unAstNoSimplifyS dt) + srevDt @(AstRanked s) f parameters0 + (unNoSimplifyDomains domains) (unAstNoSimplifyS dt) sfwd f parameters0 domains ds = - AstNoSimplifyS $ sfwd @(AstRanked s) f parameters0 domains ds + AstNoSimplifyS $ sfwd @(AstRanked s) f parameters0 + (unNoSimplifyDomains domains) (unNoSimplifyDomains ds) rfold f x0 as = AstNoSimplify $ rfold @(AstRanked s) f (unAstNoSimplify x0) (unAstNoSimplify as) @@ -1347,3 +1275,19 @@ instance AstSpan s => DomainsTensor (AstNoSimplify s) 
(AstNoSimplifyS s) where AstNoSimplifyS $ sfoldDer @(AstRanked s) f df rf (unAstNoSimplifyS x0) (unAstNoSimplifyS as) + +unNoSimplifyDomains :: Domains (AstNoSimplify s) -> Domains (AstRanked s) +unNoSimplifyDomains = + let f (DynamicRanked (AstNoSimplify t)) = DynamicRanked t + f (DynamicShaped (AstNoSimplifyS t)) = DynamicShaped t + f (DynamicRankedDummy p1 p2) = DynamicRankedDummy p1 p2 + f (DynamicShapedDummy p1 p2) = DynamicShapedDummy p1 p2 + in V.map f + +noSimplifyDomains :: Domains (AstRanked s) -> Domains (AstNoSimplify s) +noSimplifyDomains = + let f (DynamicRanked t) = DynamicRanked $ AstNoSimplify t + f (DynamicShaped t) = DynamicShaped $ AstNoSimplifyS t + f (DynamicRankedDummy p1 p2) = DynamicRankedDummy p1 p2 + f (DynamicShapedDummy p1 p2) = DynamicShapedDummy p1 p2 + in V.map f diff --git a/src/HordeAd/Core/TensorClass.hs b/src/HordeAd/Core/TensorClass.hs index cec964f20..fa2246169 100644 --- a/src/HordeAd/Core/TensorClass.hs +++ b/src/HordeAd/Core/TensorClass.hs @@ -13,19 +13,21 @@ module HordeAd.Core.TensorClass ShapeInt, ShapeSh -- * The tensor classes , RankedTensor(..), ShapedTensor(..), ConvertTensor(..), DomainsTensor(..) 
+ , raddDynamic, saddDynamic, rfromD, sfromD -- * The related constraints , ADReady, ADReadyR, ADReadyS, ADReadySmall, ADReadyBoth + -- * Concrete array instances auxiliary definitions + , DomainsOD, sizeDomainsOD, sameShapesDomainsOD, shapeDynamic + , odFromVar, odFromSh, odFromShS, fromDomainsR, fromDomainsS ) where import Prelude import qualified Data.Array.Convert -import qualified Data.Array.DynamicS as OD import Data.Array.Internal (valueOf) import qualified Data.Array.RankedS as OR import qualified Data.Array.Shape as Sh import qualified Data.Array.ShapedS as OS -import Data.Bifunctor.Clown import Data.Bifunctor.Flip import Data.Function ((&)) import Data.Kind (Constraint, Type) @@ -34,7 +36,15 @@ import qualified Data.Strict.Vector as Data.Vector import Data.Type.Equality (gcastWith, testEquality, (:~:) (Refl)) import qualified Data.Vector.Generic as V import GHC.TypeLits - (KnownNat, OrderingI (..), cmpNat, type (+), type (-), type (<=)) + ( KnownNat + , Nat + , OrderingI (..) + , cmpNat + , sameNat + , type (+) + , type (-) + , type (<=) + ) import Numeric.LinearAlgebra (Numeric, Vector) import qualified Numeric.LinearAlgebra as LA import System.Random @@ -44,6 +54,8 @@ import Unsafe.Coerce (unsafeCoerce) import HordeAd.Core.Adaptor import HordeAd.Core.Ast import HordeAd.Core.Types +import HordeAd.Internal.OrthotopeOrphanInstances + (matchingRank, sameShape) import HordeAd.Internal.TensorOps import HordeAd.Util.ShapedList (ShapeSh, ShapedList (..), consShaped, shapedNat, unShapedNat) @@ -234,7 +246,7 @@ class ( Integral (IntOf ranked), CRanked ranked Num rletDomainsIn :: (KnownNat n, GoodScalar r) => DomainsOD -> DomainsOf ranked - -> (Domains (DynamicOf ranked) -> ranked r n) + -> (Domains ranked -> ranked r n) -> ranked r n -- ** No serviceable parts beyond this point ** -- @@ -253,12 +265,9 @@ class ( Integral (IntOf ranked), CRanked ranked Num rletWrap _l u = u rletUnwrap :: ranked r n -> (ADShare, ranked r n) rletUnwrap u = (emptyADShare, u) - 
raddDynamic :: forall r n. (GoodScalar r, KnownNat n) - => ranked r n -> DynamicExists (DynamicOf ranked) - -> DynamicExists (DynamicOf ranked) rregister :: (GoodScalar r, KnownNat n) - => ranked r n -> AstBindingsD (DynamicOf ranked) - -> (AstBindingsD (DynamicOf ranked), ranked r n) + => ranked r n -> AstBindingsD ranked + -> (AstBindingsD ranked, ranked r n) rregister r l = (l, r) -- Primal/dual things. @@ -277,6 +286,33 @@ class ( Integral (IntOf ranked), CRanked ranked Num -- TODO: if DualOf is supposed to be user-visible, we needed -- a better name for it; TangentOf? CotangentOf? SecondaryOf? +raddDynamic :: forall ranked r n. + ( RankedTensor ranked, ConvertTensor ranked (ShapedOf ranked) + , GoodScalar r, KnownNat n ) + => ranked r n -> DynamicTensor ranked + -> DynamicTensor ranked +raddDynamic r (DynamicRanked @r2 @n2 t) = case sameNat (Proxy @n2) + (Proxy @n) of + Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of + Just Refl -> DynamicRanked @r $ r + t + _ -> error "raddDynamic: type mismatch" + _ -> error "raddDynamic: rank mismatch" +raddDynamic r (DynamicShaped @r2 @sh2 t) = case matchingRank @sh2 @n of + Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of + Just Refl -> DynamicRanked @r $ r + rfromS t + _ -> error "raddDynamic: type mismatch" + _ -> error "raddDynamic: rank mismatch" +raddDynamic r (DynamicRankedDummy @r2 @sh2 _ _) = case matchingRank @sh2 @n of + Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of + Just Refl -> DynamicRanked (r :: ranked r2 (Sh.Rank sh2)) + _ -> error "raddDynamic: type mismatch" + _ -> error "raddDynamic: rank mismatch" +raddDynamic r (DynamicShapedDummy @r2 @sh2 _ _) = case matchingRank @sh2 @n of + Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of + Just Refl -> DynamicRanked (r :: ranked r2 (Sh.Rank sh2)) + _ -> error "raddDynamic: type mismatch" + _ -> error "raddDynamic: rank mismatch" + -- * Shaped tensor class definition @@ -529,17 +565,14 @@ class ( Integral (IntOf 
shaped), CShaped shaped Num sletWrap _l u = u sletUnwrap :: shaped r sh -> (ADShare, shaped r sh) sletUnwrap u = (emptyADShare, u) - saddDynamic :: forall sh r. (GoodScalar r, Sh.Shape sh) - => shaped r sh -> DynamicExists (DynamicOf shaped) - -> DynamicExists (DynamicOf shaped) sregister :: (GoodScalar r, Sh.Shape sh) - => shaped r sh -> AstBindingsD (DynamicOf shaped) - -> (AstBindingsD (DynamicOf shaped), shaped r sh) + => shaped r sh -> AstBindingsD (RankedOf shaped) + -> (AstBindingsD (RankedOf shaped), shaped r sh) sregister r l = (l, r) sletDomainsIn :: Sh.Shape sh => DomainsOD -> DomainsOf shaped - -> (Domains (DynamicOf shaped) -> shaped r sh) + -> (Domains (RankedOf shaped) -> shaped r sh) -> shaped r sh -- Primal/dual things. @@ -554,35 +587,99 @@ class ( Integral (IntOf shaped), CShaped shaped Num sScale :: (GoodScalar r, Sh.Shape sh) => PrimalOf shaped r sh -> DualOf shaped r sh -> DualOf shaped r sh +saddDynamic :: forall shaped sh r. + ( ShapedTensor shaped, ConvertTensor (RankedOf shaped) shaped + , GoodScalar r, Sh.Shape sh + , ShapedOf (RankedOf shaped) ~ shaped ) + => shaped r sh -> DynamicTensor (RankedOf shaped) + -> DynamicTensor (RankedOf shaped) +saddDynamic r (DynamicRanked @r2 @n2 t) = case matchingRank @sh @n2 of + Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of + Just Refl -> DynamicShaped @r $ r + sfromR t + _ -> error "saddDynamic: type mismatch" + _ -> error "saddDynamic: rank mismatch" +saddDynamic r (DynamicShaped @r2 @sh2 t) = case sameShape @sh2 @sh of + Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of + Just Refl -> DynamicShaped @r $ r + t + _ -> error "saddDynamic: type mismatch" + _ -> error "saddDynamic: shape mismatch" +saddDynamic r (DynamicRankedDummy @r2 @sh2 _ _) = case sameShape @sh2 @sh of + Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of + Just Refl -> DynamicShaped (r :: shaped r2 sh2) + _ -> error "saddDynamic: type mismatch" + _ -> error "saddDynamic: shape mismatch" 
+saddDynamic r (DynamicShapedDummy @r2 @sh2 _ _) = case sameShape @sh2 @sh of + Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of + Just Refl -> DynamicShaped (r :: shaped r2 sh2) + _ -> error "saddDynamic: type mismatch" + _ -> error "saddDynamic: shape mismatch" + -- * ConvertTensor and DomainsTensor class definitions -class ( DynamicOf ranked ~ DynamicOf shaped - , DynamicOf shaped ~ DynamicOf ranked ) - => ConvertTensor (ranked :: RankedTensorKind) - (shaped :: ShapedTensorKind) - | ranked -> shaped, shaped -> ranked where - rfromD :: (GoodScalar r, KnownNat n) - => DynamicOf ranked r -> ranked r n +class ConvertTensor (ranked :: RankedTensorKind) + (shaped :: ShapedTensorKind) + | ranked -> shaped, shaped -> ranked where rfromS :: (GoodScalar r, Sh.Shape sh) => shaped r sh -> ranked r (Sh.Rank sh) dfromR :: (GoodScalar r, KnownNat n) - => ranked r n -> DynamicOf ranked r + => ranked r n -> DynamicTensor ranked dfromS :: (GoodScalar r, Sh.Shape sh) - => shaped r sh -> DynamicOf shaped r + => shaped r sh -> DynamicTensor ranked sfromR :: (GoodScalar r, Sh.Shape sh, KnownNat (Sh.Rank sh)) => ranked r (Sh.Rank sh) -> shaped r sh - sfromD :: (GoodScalar r, Sh.Shape sh) - => DynamicOf shaped r -> shaped r sh - ddummy :: GoodScalar r => DynamicOf ranked r - dIsDummy :: DynamicOf ranked r -> Bool - dshape :: GoodScalar r => DynamicOf ranked r -> [Int] + dIsDummy :: DynamicTensor ranked -> Bool + dshape :: GoodScalar r => DynamicTensor ranked -> [Int] + +rfromD :: forall ranked r n. 
+ ( ShapedTensor (ShapedOf ranked) + , ConvertTensor ranked (ShapedOf ranked) + , GoodScalar r, KnownNat n ) + => DynamicTensor ranked -> ranked r n +rfromD (DynamicRanked @r2 @n2 t) = case sameNat (Proxy @n2) (Proxy @n) of + Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of + Just Refl -> t + _ -> error "rfromD: type mismatch" + _ -> error "rfromD: rank mismatch" +rfromD (DynamicShaped @r2 @sh2 t) = case matchingRank @sh2 @n of + Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of + Just Refl -> rfromS t + _ -> error "rfromD: type mismatch" + _ -> error "rfromD: rank mismatch" +rfromD (DynamicRankedDummy @r2 @sh2 _ _) = case matchingRank @sh2 @n of + Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of + Just Refl -> rfromS @_ @_ @r2 @sh2 0 + _ -> error "rfromD: type mismatch" + _ -> error "rfromD: rank mismatch" +rfromD (DynamicShapedDummy @r2 @sh2 _ _) = case matchingRank @sh2 @n of + Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of + Just Refl -> rfromS @ranked @_ @r2 @sh2 0 + _ -> error "rfromD: type mismatch" + _ -> error "rfromD: rank mismatch" + +sfromD :: forall shaped r sh. 
+ ( ShapedTensor shaped, ConvertTensor (RankedOf shaped) shaped + , GoodScalar r, Sh.Shape sh + , ShapedOf (RankedOf shaped) ~ shaped ) + => DynamicTensor (RankedOf shaped) -> shaped r sh +sfromD (DynamicRanked @r2 @n2 t) = case matchingRank @sh @n2 of + Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of + Just Refl -> sfromR t + _ -> error "sfromD: type mismatch" + _ -> error "sfromD: rank mismatch" +sfromD (DynamicShaped @r2 @sh2 t) = case sameShape @sh2 @sh of + Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of + Just Refl -> t + _ -> error "sfromD: type mismatch" + _ -> error "sfromD: shape mismatch" +sfromD DynamicRankedDummy{} = 0 +sfromD DynamicShapedDummy{} = 0 class DomainsTensor (ranked :: RankedTensorKind) (shaped :: ShapedTensorKind) | ranked -> shaped, shaped -> ranked where - dmkDomains :: Domains (DynamicOf ranked) -> DomainsOf ranked - dunDomains :: DomainsOD -> DomainsOf ranked -> Domains (DynamicOf ranked) + dmkDomains :: Domains ranked -> DomainsOf ranked + dunDomains :: DomainsOD -> DomainsOf ranked -> Domains ranked -- ^ Warning: this operation easily breaks sharing. rletInDomains :: (GoodScalar r, KnownNat n) => ranked r n @@ -599,38 +696,38 @@ class DomainsTensor (ranked :: RankedTensorKind) -- because otherwise in the ADVal instance one could put an illegal -- InputR there, confusing two levels of contangents. rrev :: (GoodScalar r, KnownNat n) - => (forall f. ADReady f => Domains (DynamicOf f) -> f r n) + => (forall f. ADReady f => Domains f -> f r n) -> DomainsOD - -> Domains (DynamicOf ranked) + -> Domains ranked -> DomainsOf ranked rrevDt :: (GoodScalar r, KnownNat n) - => (forall f. ADReady f => Domains (DynamicOf f) -> f r n) + => (forall f. ADReady f => Domains f -> f r n) -> DomainsOD - -> Domains (DynamicOf ranked) + -> Domains ranked -> ranked r n -- ^ incoming cotangent (dt) -> DomainsOf ranked rfwd :: (GoodScalar r, KnownNat n) - => (forall f. ADReady f => Domains (DynamicOf f) -> f r n) + => (forall f. 
ADReady f => Domains f -> f r n) -> DomainsOD - -> Domains (DynamicOf ranked) - -> Domains (DynamicOf ranked) -- ^ incoming tangent (ds) + -> Domains ranked + -> Domains ranked -- ^ incoming tangent (ds) -> ranked r n srev :: (GoodScalar r, Sh.Shape sh) - => (forall f. ADReadyS f => Domains (DynamicOf f) -> f r sh) + => (forall f. ADReadyS f => Domains (RankedOf f) -> f r sh) -> DomainsOD - -> Domains (DynamicOf ranked) + -> Domains ranked -> DomainsOf ranked srevDt :: (GoodScalar r, Sh.Shape sh) - => (forall f. ADReadyS f => Domains (DynamicOf f) -> f r sh) + => (forall f. ADReadyS f => Domains (RankedOf f) -> f r sh) -> DomainsOD - -> Domains (DynamicOf ranked) + -> Domains ranked -> shaped r sh -> DomainsOf ranked sfwd :: (GoodScalar r, Sh.Shape sh) - => (forall f. ADReadyS f => Domains (DynamicOf f) -> f r sh) + => (forall f. ADReadyS f => Domains (RankedOf f) -> f r sh) -> DomainsOD - -> Domains (DynamicOf ranked) - -> Domains (DynamicOf ranked) + -> Domains ranked + -> Domains ranked -> shaped r sh -- The type mentions ADReady, so it's hard to put this into RankedTensor, -- which doesn't know about ConvertTensor and DomainsTensor. @@ -698,7 +795,6 @@ type ADReadySmall ranked shaped = , ConvertTensor (PrimalOf ranked) (PrimalOf shaped) , CRanked ranked Show, CRanked (PrimalOf ranked) Show , CShaped shaped Show, CShaped (PrimalOf shaped) Show - , CDynamic (DynamicOf ranked) Show, Show (DomainsOf ranked) , DomainsOf ranked ~ DomainsOf shaped , DomainsOf shaped ~ DomainsOf ranked ) @@ -708,15 +804,92 @@ type ADReadyBoth ranked shaped = , DomainsTensor ranked shaped , DomainsTensor (PrimalOf ranked) (PrimalOf shaped) ) -type CDynamic :: (Type -> Type) -> (Type -> Constraint) -> Constraint -class (forall r20. GoodScalar r20 => c (dynamic r20)) - => CDynamic dynamic c where -instance (forall r20. 
GoodScalar r20 => c (dynamic r20)) - => CDynamic dynamic c where - -- * Instances for concrete arrays +type DomainsOD = Domains (Flip OR.Array) + +sizeDomainsOD :: DomainsOD -> Int +sizeDomainsOD = let f (DynamicRanked (Flip t)) = OR.size t + f (DynamicShaped (Flip t)) = OS.size t + f (DynamicRankedDummy _ proxy_sh) = Sh.sizeP proxy_sh + f (DynamicShapedDummy _ proxy_sh) = Sh.sizeP proxy_sh + in V.sum . V.map f + +shapeDynamic :: (RankedTensor ranked, ShapedTensor (ShapedOf ranked)) + => DynamicTensor ranked -> [Int] +shapeDynamic (DynamicRanked t) = shapeToList $ rshape t +shapeDynamic (DynamicShaped t) = ShapedList.sizedListToList $ sshape t +shapeDynamic (DynamicRankedDummy _ proxy_sh) = Sh.shapeP proxy_sh +shapeDynamic (DynamicShapedDummy _ proxy_sh) = Sh.shapeP proxy_sh + +-- TODO: also check scalars are same +sameShapesDomainsOD :: DomainsOD -> DomainsOD -> Bool +sameShapesDomainsOD v1 v2 = + let sameExShape t u = + shapeDynamic @(Flip OR.Array) t == shapeDynamic @(Flip OR.Array) u + in V.and $ V.zipWith sameExShape v1 v2 + +odFromVar :: AstDynamicVarName -> DynamicTensor (Flip OR.Array) +odFromVar (AstDynamicVarName @k @rD @shD _) = + case testEquality (typeRep @k) (typeRep @Nat) of + Just Refl -> DynamicRankedDummy @rD @shD Proxy Proxy + _ -> DynamicShapedDummy @rD @shD Proxy Proxy + +odFromSh :: forall r n. GoodScalar r + => ShapeInt n -> DynamicTensor (Flip OR.Array) +odFromSh sh = Sh.withShapeP (shapeToList sh) $ \proxySh -> + DynamicRankedDummy (Proxy @r) proxySh + +odFromShS :: forall r sh. (GoodScalar r, Sh.Shape sh) + => DynamicTensor (Flip OR.Array) +odFromShS = DynamicShapedDummy @r @sh Proxy Proxy + +fromDomainsR :: forall r n ranked. 
+ (RankedTensor ranked, GoodScalar r, KnownNat n) + => Domains ranked + -> Maybe (ranked r n, Domains ranked) +fromDomainsR params = case V.uncons params of + Just (DynamicRanked @r2 @n2 t, rest) -> case sameNat (Proxy @n2) + (Proxy @n) of + Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of + Just Refl -> Just (t, rest) + _ -> error $ "fromDomainsR: type mismatch in " + ++ show (typeRep @r2, typeRep @r) + _ -> error "fromDomainsR: rank mismatch" + Just (DynamicShaped{}, _) -> error "fromDomainsR: ranked from shaped" + Just (DynamicRankedDummy @r2 @sh2 _ _, rest) -> case matchingRank @sh2 @n of + Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of + Just Refl -> + let sh2 = listShapeToShape (Sh.shapeT @sh2) + in Just (rzero sh2 :: ranked r2 (Sh.Rank sh2), rest) + _ -> error "fromDomainsR: type mismatch" + _ -> error "fromDomainsR: shape mismatch" + Just (DynamicShapedDummy{}, _) -> error "fromDomainsR: ranked from shaped" + Nothing -> Nothing + +fromDomainsS :: forall r sh shaped + . ( ShapedTensor shaped, GoodScalar r, Sh.Shape sh + , ShapedOf (RankedOf shaped) ~ shaped ) + => Domains (RankedOf shaped) + -> Maybe (shaped r sh, Domains (RankedOf shaped)) +fromDomainsS params = case V.uncons params of + Just (DynamicRanked{}, _) -> error "fromDomainsS: shaped from ranked" + Just (DynamicShaped @r2 @sh2 t, rest) -> case sameShape @sh2 @sh of + Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of + Just Refl -> Just (t, rest) + _ -> error "fromDomainsS: type mismatch" + _ -> error "fromDomainsS: shape mismatch" + Just (DynamicRankedDummy{}, _) -> error "fromDomainsS: shaped from ranked" + Just (DynamicShapedDummy @r2 @sh2 _ _, rest) -> case sameShape @sh2 @sh of + Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of + Just Refl -> + -- The dummy gets removed, so we verify its types before it does. 
+ Just (0 :: shaped r2 sh2, rest) + _ -> error "fromDomainsS: type mismatch" + _ -> error "fromDomainsS: shape mismatch" + Nothing -> Nothing + type instance SimpleBoolOf (Flip OR.Array) = Bool instance EqF (Flip OR.Array) where @@ -732,20 +905,10 @@ instance OrdF (Flip OR.Array) where instance IfF (Flip OR.Array) where ifF (_, b) v w = if b then v else w -type instance RankedOf (Clown OD.Array) = Flip OR.Array - -type instance ShapedOf (Clown OD.Array) = Flip OS.Array - -type instance DynamicOf (Clown OD.Array) = OD.Array - -type instance DomainsOf (Clown OD.Array) = DomainsOD - type instance RankedOf (Flip OR.Array) = Flip OR.Array type instance ShapedOf (Flip OR.Array) = Flip OS.Array -type instance DynamicOf (Flip OR.Array) = OD.Array - type instance DomainsOf (Flip OR.Array) = DomainsOD type instance PrimalOf (Flip OR.Array) = Flip OR.Array @@ -798,15 +961,6 @@ instance RankedTensor (Flip OR.Array) where rconst = Flip rletDomainsIn _ = (&) - raddDynamic :: forall r n. (GoodScalar r, KnownNat n) - => Flip OR.Array r n -> DynamicExists OD.Array - -> DynamicExists OD.Array - raddDynamic r (DynamicExists @r2 d) = DynamicExists @r $ - if dIsDummy @(Flip OR.Array) d then dfromR r - else case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> dfromR r + d - _ -> error "raddDynamic: type mismatch" - rconstant = id rprimalPart = id rdualPart _ = DummyDual @@ -814,21 +968,13 @@ instance RankedTensor (Flip OR.Array) where rScale _ _ = DummyDual instance (GoodScalar r, KnownNat n) - => AdaptableDomains OD.Array (Flip OR.Array r n) where + => AdaptableDomains (Flip OR.Array) (Flip OR.Array r n) where {-# SPECIALIZE instance KnownNat n - => AdaptableDomains OD.Array (Flip OR.Array Double n) #-} + => AdaptableDomains (Flip OR.Array) (Flip OR.Array Double n) #-} type Value (Flip OR.Array r n) = Flip OR.Array r n - toDomains a = V.singleton $ DynamicExists $ dfromR a - fromDomains aInit params = case V.uncons params of - Just (DynamicExists @r2 a, rest) -> - if 
isTensorDummyD a then Just (rzero (rshape aInit), rest) else - case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> let !aR = rfromD @(Flip OR.Array) @(Flip OS.Array) @r a - in Just (aR, rest) - _ -> error $ "fromDomains: type mismatch: " - ++ show (typeRep @r) ++ " " ++ show (typeRep @r2) - Nothing -> Nothing + toDomains = V.singleton . DynamicRanked + fromDomains _aInit params = fromDomainsR @r @n params instance ForgetShape (Flip OR.Array r n) where type NoShape (Flip OR.Array r n) = Flip OR.Array r n @@ -856,8 +1002,6 @@ type instance RankedOf (Flip OS.Array) = Flip OR.Array type instance ShapedOf (Flip OS.Array) = Flip OS.Array -type instance DynamicOf (Flip OS.Array) = OD.Array - type instance DomainsOf (Flip OS.Array) = DomainsOD type instance PrimalOf (Flip OS.Array) = Flip OS.Array @@ -915,14 +1059,6 @@ instance ShapedTensor (Flip OS.Array) where sdot1In u v = Flip $ tdot1InS (runFlip u) (runFlip v) sconst = Flip sletDomainsIn _ = (&) - saddDynamic :: forall r sh. (GoodScalar r, Sh.Shape sh) - => Flip OS.Array r sh -> DynamicExists OD.Array - -> DynamicExists OD.Array - saddDynamic r (DynamicExists @r2 d) = DynamicExists @r $ - if dIsDummy @(Flip OR.Array) d then dfromS r - else case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> dfromS r + d - _ -> error "saddDynamic: type mismatch" sconstant = id sprimalPart = id @@ -931,17 +1067,10 @@ instance ShapedTensor (Flip OS.Array) where sScale _ _ = DummyDual instance (GoodScalar r, Sh.Shape sh) - => AdaptableDomains OD.Array (Flip OS.Array r sh) where + => AdaptableDomains (Flip OR.Array) (Flip OS.Array r sh) where type Value (Flip OS.Array r sh) = Flip OS.Array r sh - toDomains a = V.singleton $ DynamicExists $ dfromS a - fromDomains _aInit params = case V.uncons params of - Just (DynamicExists @r2 a, rest) -> - if isTensorDummyD a then Just (0, rest) else - case testEquality (typeRep @r) (typeRep @r2) of - Just Refl -> let !aS = sfromD @(Flip OR.Array) @(Flip OS.Array) @r a - in Just (aS, 
rest) - _ -> error "fromDomains: type mismatch" - Nothing -> Nothing + toDomains = V.singleton . DynamicShaped + fromDomains _aInit params = fromDomainsS @r @sh @(Flip OS.Array) params instance Sh.Shape sh => ForgetShape (Flip OS.Array r sh) where @@ -982,12 +1111,11 @@ instance {-# OVERLAPS #-} {-# OVERLAPPING #-} -- The DomainsTensor instance requires ADVal instance, so it's given elsewhere. instance ConvertTensor (Flip OR.Array) (Flip OS.Array) where - rfromD = Flip . Data.Array.Convert.convert rfromS = Flip . Data.Array.Convert.convert . runFlip - dfromR = Data.Array.Convert.convert . runFlip - dfromS = Data.Array.Convert.convert . runFlip + dfromR = DynamicRanked + dfromS = DynamicShaped sfromR = Flip . Data.Array.Convert.convert . runFlip - sfromD = Flip . Data.Array.Convert.convert - ddummy = dummyTensorD - dIsDummy = isTensorDummyD - dshape = OD.shapeL + dIsDummy DynamicRankedDummy{} = True + dIsDummy DynamicShapedDummy{} = True + dIsDummy _ = False + dshape = shapeDynamic diff --git a/src/HordeAd/Core/Types.hs b/src/HordeAd/Core/Types.hs index 47e620eb1..d02729344 100644 --- a/src/HordeAd/Core/Types.hs +++ b/src/HordeAd/Core/Types.hs @@ -6,10 +6,9 @@ module HordeAd.Core.Types -- * Some fundamental constraints , GoodScalar, HasSingletonDict, Differentiable, IfDifferentiable(..) -- * Type definitions for dynamic tensors and tensor collections - , DynamicExists(..), DynamicTensor(..), CRanked, CShaped - , Domains, DomainsOD, sizeDomainsOD, sameShapesDomainsOD + , DynamicTensor(..), CRanked, CShaped, Domains -- * Type families that tensors will belong to - , RankedOf, ShapedOf, DynamicOf, DomainsOf, PrimalOf, DualOf, DummyDual(..) + , RankedOf, ShapedOf, DomainsOf, PrimalOf, DualOf, DummyDual(..) 
-- * Generic types of indexes used in tensor operations , IntOf, IndexOf, IntSh, IndexSh -- * Generic types of booleans used in tensor operations @@ -26,7 +25,6 @@ module HordeAd.Core.Types import Prelude import Control.DeepSeq (NFData (..)) -import qualified Data.Array.DynamicS as OD import qualified Data.Array.Shape as Sh import Data.Boolean (Boolean (..)) import Data.Int (Int64) @@ -36,7 +34,6 @@ import Data.List (foldl') import Data.Maybe (fromMaybe) import Data.Proxy (Proxy (Proxy)) import qualified Data.Strict.Vector as Data.Vector -import qualified Data.Vector.Generic as V import GHC.TypeLits (KnownNat, Nat, SomeNat (..), natVal, someNatVal) import Numeric.LinearAlgebra (Numeric, Vector) import System.IO.Unsafe (unsafePerformIO) @@ -99,29 +96,22 @@ instance IfDifferentiable Float where -- * Type definitions for dynamic tensors and tensor collections --- Warning: r is an existential variable, a proper specialization needs --- to be picked explicitly at runtime. -type role DynamicExists representational -data DynamicExists :: (Type -> Type) -> Type where - DynamicExists :: forall r dynamic. GoodScalar r - => dynamic r -> DynamicExists dynamic -deriving instance (forall r. GoodScalar r => Show (dynamic r)) - => Show (DynamicExists dynamic) -instance (forall r. NFData r => NFData (dynamic r)) - => NFData (DynamicExists dynamic) where - rnf (DynamicExists x) = rnf x - +-- For thousands of tensor parameters, orthotope's dynamic tensors +-- are faster than the datatype below and the special dummy values are faster +-- than ordinary zero values. However, the library has become complex enough +-- that simplicity is the bottlenet, not speed. +-- -- Warning: r is an existential variable, a proper specialization needs -- to be picked explicitly at runtime. type role DynamicTensor nominal data DynamicTensor (ranked :: RankedTensorKind) where - DynamicRanked :: forall r n ranked. 
(GoodScalar r, KnownNat n) + DynamicRanked :: (GoodScalar r, KnownNat n) => ranked r n -> DynamicTensor ranked - DynamicShaped :: forall r sh ranked. (GoodScalar r, Sh.Shape sh) + DynamicShaped :: (GoodScalar r, Sh.Shape sh) => ShapedOf ranked r sh -> DynamicTensor ranked - DynamicRankedDummy :: forall r sh ranked. (GoodScalar r, Sh.Shape sh) + DynamicRankedDummy :: (GoodScalar r, Sh.Shape sh) => Proxy r -> Proxy sh -> DynamicTensor ranked - DynamicShapedDummy :: forall r sh ranked. (GoodScalar r, Sh.Shape sh) + DynamicShapedDummy :: (GoodScalar r, Sh.Shape sh) => Proxy r -> Proxy sh -> DynamicTensor ranked deriving instance @@ -155,19 +145,7 @@ instance -- DomainsOf is used for that and the only reasons DomainsOf exists -- is to prevent mixing up the two (and complicating the definition -- below with errors in the AstDomainsLet case). -type Domains dynamic = Data.Vector.Vector (DynamicExists dynamic) - -type DomainsOD = Domains OD.Array - -sizeDomainsOD :: DomainsOD -> Int -sizeDomainsOD d = let f (DynamicExists t) = OD.size t - in V.sum (V.map f d) - -sameShapesDomainsOD :: DomainsOD -> DomainsOD -> Bool -sameShapesDomainsOD v1 v2 = - let sameExShape (DynamicExists arr1, DynamicExists arr2) = - OD.shapeL arr1 == OD.shapeL arr2 - in V.all sameExShape $ V.zip v1 v2 +type Domains ranked = Data.Vector.Vector (DynamicTensor ranked) -- * Type families that tensors will belong to @@ -177,8 +155,6 @@ type family RankedOf (f :: TensorKind k) :: RankedTensorKind type family ShapedOf (f :: TensorKind k) :: ShapedTensorKind -type family DynamicOf (f :: TensorKind k) :: Type -> Type - type family DomainsOf (f :: TensorKind k) :: Type type family PrimalOf (f :: TensorKind k) :: TensorKind k @@ -262,8 +238,8 @@ newtype AstVarId = AstVarId Int intToAstVarId :: Int -> AstVarId intToAstVarId = AstVarId -type AstBindingsD :: (Type -> Type) -> Type -type AstBindingsD dynamic = [(AstVarId, DynamicExists dynamic)] +type AstBindingsD (ranked :: RankedTensorKind) = + [(AstVarId, 
DynamicTensor ranked)] unsafeGlobalCounter :: Counter {-# NOINLINE unsafeGlobalCounter #-} @@ -284,17 +260,18 @@ unsafeGetFreshId = atomicAddCounter_ unsafeGlobalCounter 1 -- are rarely called and relatively cheap, so no picking specializations -- at runtime is needed. type role ADShareD nominal -type ADShareD :: (Type -> Type) -> Type -data ADShareD d = ADShareNil - | forall r. GoodScalar r - => ADShareCons Int AstVarId (d r) (ADShareD d) -deriving instance (forall r. GoodScalar r => Show (d r)) => Show (ADShareD d) +type ADShareD :: RankedTensorKind -> Type +data ADShareD ranked = + ADShareNil + | ADShareCons Int AstVarId (DynamicTensor ranked) (ADShareD ranked) +deriving instance (CRanked ranked Show, CShaped (ShapedOf ranked) Show) + => Show (ADShareD ranked) emptyADShare :: ADShareD d emptyADShare = ADShareNil -insertADShare :: forall r d. GoodScalar r - => AstVarId -> d r -> ADShareD d -> ADShareD d +insertADShare :: forall d. + AstVarId -> DynamicTensor d -> ADShareD d -> ADShareD d insertADShare !key !t !s = -- The Maybe over-engineering ensures that we never refresh an id -- unnecessarily. 
In theory, when merging alternating equal lists @@ -314,8 +291,7 @@ insertADShare !key !t !s = GT -> Just $ freshInsertADShare key t l2 in fromMaybe s (insertAD s) -freshInsertADShare :: GoodScalar r - => AstVarId -> d r -> ADShareD d -> ADShareD d +freshInsertADShare :: AstVarId -> DynamicTensor d -> ADShareD d -> ADShareD d {-# NOINLINE freshInsertADShare #-} freshInsertADShare !key !t !s = unsafePerformIO $ do id0 <- unsafeGetFreshId @@ -363,7 +339,7 @@ subtractADShare !s1 !s2 = else case compare key1 key2 of EQ -> subAD rest1 rest2 LT -> subAD l1 rest2 - GT -> (key1, DynamicExists t1) : subAD rest1 l2 + GT -> (key1, t1) : subAD rest1 l2 in subAD s1 s2 flattenADShare :: [ADShareD d] -> ADShareD d @@ -372,14 +348,13 @@ flattenADShare = foldl' mergeADShare emptyADShare assocsADShare :: ADShareD d -> AstBindingsD d {-# INLINE assocsADShare #-} -- help list fusion assocsADShare ADShareNil = [] -assocsADShare (ADShareCons _ key t rest) = - (key, DynamicExists t) : assocsADShare rest +assocsADShare (ADShareCons _ key t rest) = (key, t) : assocsADShare rest _lengthADShare :: Int -> ADShareD d -> Int _lengthADShare acc ADShareNil = acc _lengthADShare acc (ADShareCons _ _ _ rest) = _lengthADShare (acc + 1) rest -varInADShare :: (forall r. 
AstVarId -> d r -> Bool) +varInADShare :: (AstVarId -> DynamicTensor d -> Bool) -> AstVarId -> ADShareD d -> Bool {-# INLINE varInADShare #-} diff --git a/src/HordeAd/External/Optimizer.hs b/src/HordeAd/External/Optimizer.hs index 1e540cc2c..a74584cd4 100644 --- a/src/HordeAd/External/Optimizer.hs +++ b/src/HordeAd/External/Optimizer.hs @@ -7,13 +7,13 @@ module HordeAd.External.Optimizer import Prelude -import qualified Data.Array.DynamicS as OD import qualified Data.Array.RankedS as OR import Data.Bifunctor.Flip import GHC.TypeLits (KnownNat) import HordeAd.Core.Delta (DualPart (..)) import HordeAd.Core.DualNumber +import HordeAd.Core.TensorADVal () import HordeAd.Core.TensorClass import HordeAd.Core.Types import HordeAd.External.OptimizerTools @@ -24,13 +24,13 @@ import HordeAd.External.OptimizerTools sgd :: forall n r a. (KnownNat n, GoodScalar r) => Double -> (a - -> Domains (DynamicOf (ADVal (Flip OR.Array))) + -> Domains (ADVal (Flip OR.Array)) -> ADVal (Flip OR.Array) r n) -> [a] -- ^ training data -> DomainsOD -- ^ initial parameters -> (DomainsOD, Flip OR.Array r n) sgd gamma f trainingData parameters0 = go trainingData parameters0 where - deltaInputs = generateDeltaInputsOD @(Flip OR.Array) parameters0 + deltaInputs = generateDeltaInputs @(Flip OR.Array) parameters0 go :: [a] -> DomainsOD -> (DomainsOD, Flip OR.Array r n) go [] parameters = (parameters, 0) go (a : rest) !parameters = @@ -42,33 +42,33 @@ sgd gamma f trainingData parameters0 = go trainingData parameters0 where else go rest parametersNew -- | An implementation of the Adam gradient descent. -sgdAdam :: forall f r a y. - ( DynamicOf f ~ DynamicOf (RankedOf f) - , ConvertTensor (RankedOf f) (ShapedOf f) - , DualPart f, UnletGradient f, HasSingletonDict y, GoodScalar r - , DynamicOf f ~ OD.Array, DomainsOf f ~ DomainsOD, Num (f r y) ) - => (a -> Domains (DynamicOf (ADVal f)) -> ADVal f r y) - -> [a] - -> DomainsOD - -> StateAdam - -> (DomainsOD, StateAdam) +sgdAdam + :: forall f r a y. 
+ ( RankedTensor (RankedOf f), RankedTensor (ADVal (RankedOf f)) + , DualPart f, UnletGradient f, HasSingletonDict y, GoodScalar r + , RankedOf f ~ Flip OR.Array, Num (f r y)) + => (a -> Domains (ADVal (RankedOf f)) -> ADVal f r y) + -> [a] + -> DomainsOD + -> StateAdam + -> (DomainsOD, StateAdam) sgdAdam = sgdAdamArgs defaultArgsAdam -sgdAdamArgs :: forall f r a y. - ( DynamicOf f ~ DynamicOf (RankedOf f) - , ConvertTensor (RankedOf f) (ShapedOf f) - , DualPart f, UnletGradient f, HasSingletonDict y, GoodScalar r - , DynamicOf f ~ OD.Array, DomainsOf f ~ DomainsOD, Num (f r y) ) - => ArgsAdam - -> (a -> Domains (DynamicOf (ADVal f)) -> ADVal f r y) - -> [a] - -> DomainsOD - -> StateAdam - -> (DomainsOD, StateAdam) +sgdAdamArgs + :: forall f r a y. + ( RankedTensor (RankedOf f), RankedTensor (ADVal (RankedOf f)) + , DualPart f, UnletGradient f, GoodScalar r, HasSingletonDict y + , RankedOf f ~ Flip OR.Array, Num (f r y) ) + => ArgsAdam + -> (a -> Domains (ADVal (RankedOf f)) -> ADVal f r y) + -> [a] + -> DomainsOD + -> StateAdam + -> (DomainsOD, StateAdam) sgdAdamArgs argsAdam f trainingData !parameters0 !stateAdam0 = go trainingData parameters0 stateAdam0 where - deltaInputs = generateDeltaInputsOD parameters0 + deltaInputs = generateDeltaInputs parameters0 go :: [a] -> DomainsOD -> StateAdam -> (DomainsOD, StateAdam) go [] parameters stateAdam = (parameters, stateAdam) go (a : rest) !parameters !stateAdam = diff --git a/src/HordeAd/External/OptimizerTools.hs b/src/HordeAd/External/OptimizerTools.hs index 2e0eee96c..a38ef2d69 100644 --- a/src/HordeAd/External/OptimizerTools.hs +++ b/src/HordeAd/External/OptimizerTools.hs @@ -9,32 +9,49 @@ module HordeAd.External.OptimizerTools import Prelude -import qualified Data.Array.DynamicS as OD +import qualified Data.Array.RankedS as OR +import Data.Bifunctor.Flip +import Data.Proxy (Proxy (Proxy)) import Data.Type.Equality (testEquality, (:~:) (Refl)) import qualified Data.Vector.Generic as V +import GHC.TypeLits 
(KnownNat, sameNat) import Numeric.LinearAlgebra (Numeric, Vector) import qualified Numeric.LinearAlgebra as LA import Type.Reflection (typeRep) +import HordeAd.Core.TensorClass import HordeAd.Core.Types -import HordeAd.Internal.OrthotopeOrphanInstances (liftVD2) -import HordeAd.Internal.TensorOps (isTensorDummyD) +import HordeAd.Internal.OrthotopeOrphanInstances updateWithGradient :: Double -> DomainsOD -> DomainsOD -> DomainsOD updateWithGradient gamma params gradient = let updateVector :: (Numeric r, Fractional r, Num (Vector r)) => Vector r -> Vector r -> Vector r updateVector i r = i - LA.scale (realToFrac gamma) r - updateR :: DynamicExists OD.Array -> DynamicExists OD.Array - -> DynamicExists OD.Array - updateR ei@(DynamicExists @r1 i) (DynamicExists @r2 r) = - if isTensorDummyD r -- eval didn't update it, would crash - then ei - else ifDifferentiable @r1 - (case testEquality (typeRep @r1) (typeRep @r2) of - Just Refl -> DynamicExists $ liftVD2 updateVector i r - _ -> error "updateWithGradient: type mismatch") - ei + updateR :: DynamicTensor (Flip OR.Array) -> DynamicTensor (Flip OR.Array) + -> DynamicTensor (Flip OR.Array) + updateR i r = case (i, r) of + (DynamicRanked @r1 @n1 t1, DynamicRanked @r2 @n2 t2) -> + ifDifferentiable @r1 + (case sameNat (Proxy @n1) (Proxy @n2) of + Just Refl -> case testEquality (typeRep @r1) (typeRep @r2) of + Just Refl -> + DynamicRanked $ Flip + $ liftVR2 updateVector (runFlip t1) (runFlip t2) + _ -> error "updateWithGradient: scalar mismatch" + _ -> error "updateWithGradient: rank mismatch") + i + (DynamicShaped @r1 @sh1 t1, DynamicShaped @r2 @sh2 t2) -> + ifDifferentiable @r1 + (case sameShape @sh1 @sh2 of + Just Refl -> case testEquality (typeRep @r1) (typeRep @r2) of + Just Refl -> + DynamicShaped $ Flip + $ liftVS2 updateVector (runFlip t1) (runFlip t2) + _ -> error "updateWithGradient: scalar mismatch" + _ -> error "updateWithGradient: rank mismatch") + i + _ -> i -- eval didn't update the gradient, save on computation 
in V.zipWith updateR params gradient {- @@ -47,13 +64,13 @@ minimumGradient :: (Ord r, Numeric r) => DomainsOD -> r minimumGradient (DomainsOD gradient0 gradientR) = min (if V.null gradient0 then 0 else LA.minElement gradient0) (if V.null gradientR then 0 - else V.minimum (V.map OD.minimumA gradientR)) + else V.minimum (V.map OR.minimumA gradientR)) maximumGradient :: (Ord r, Numeric r) => DomainsOD -> r maximumGradient (DomainsOD gradient0 gradientR) = max (if V.null gradient0 then 0 else LA.maxElement gradient0) (if V.null gradientR then 0 - else V.maximum (V.map OD.maximumA gradientR)) + else V.maximum (V.map OR.maximumA gradientR)) -} data ArgsAdam = ArgsAdam @@ -81,10 +98,14 @@ data StateAdam = StateAdam -- The arguments are just sample params0, for dimensions. zeroParameters :: DomainsOD -> DomainsOD -zeroParameters params = - V.map (\(DynamicExists @r a) -> DynamicExists @r - $ OD.constant (OD.shapeL a) 0) - params +zeroParameters = + let f (DynamicRanked @r @n t) = + let sh = rshape @(Flip OR.Array) t + in DynamicRanked @r @n $ rzero @(Flip OR.Array) sh + f (DynamicShaped @r @sh _) = DynamicShaped @r @sh 0 + f DynamicRankedDummy{} = error "zeroParameters: unexpected value" + f DynamicShapedDummy{} = error "zeroParameters: unexpected value" + in V.map f initialStateAdam :: DomainsOD -> StateAdam initialStateAdam parameters0 = @@ -95,28 +116,28 @@ initialStateAdam parameters0 = , vAdam = zeroP } --- TOOD: make sure this is not worse that OD.zipWith3A when transposing +-- TODO: make sure this is not worse than OR.zipWith3A when transposing -- between each application or that we never encounter such situations -- -- | Application of a vector function on the flattened arrays elements. 
liftArray43 :: ( Numeric a, Numeric b, Numeric c, Numeric d - , Numeric x, Numeric y, Numeric z ) + , Numeric x, Numeric y, Numeric z, KnownNat n ) => (Vector a -> Vector b -> Vector c -> Vector d -> (Vector x, Vector y, Vector z)) - -> OD.Array a -> OD.Array b -> OD.Array c -> OD.Array d - -> (OD.Array x, OD.Array y, OD.Array z) + -> OR.Array n a -> OR.Array n b -> OR.Array n c -> OR.Array n d + -> (OR.Array n x, OR.Array n y, OR.Array n z) liftArray43 f m1 m2 m3 m4 = - let sz = OD.shapeL m1 - in if sz == OD.shapeL m2 && sz == OD.shapeL m3 && sz == OD.shapeL m4 - then let (vx, vy, vz) = f (OD.toVector m1) (OD.toVector m2) - (OD.toVector m3) (OD.toVector m4) - in ( OD.fromVector sz vx - , OD.fromVector sz vy - , OD.fromVector sz vz + let sz = OR.shapeL m1 + in if sz == OR.shapeL m2 && sz == OR.shapeL m3 && sz == OR.shapeL m4 + then let (vx, vy, vz) = f (OR.toVector m1) (OR.toVector m2) + (OR.toVector m3) (OR.toVector m4) + in ( OR.fromVector sz vx + , OR.fromVector sz vy + , OR.fromVector sz vz ) else error $ "nonconformant arrays in liftArray43: " - ++ show (OD.shapeL m1, OD.shapeL m2, OD.shapeL m3, OD.shapeL m4) + ++ show (OR.shapeL m1, OR.shapeL m2, OR.shapeL m3, OR.shapeL m4) updateWithGradientAdam :: ArgsAdam -> StateAdam -> DomainsOD -> DomainsOD @@ -145,24 +166,32 @@ updateWithGradientAdam ArgsAdam{..} StateAdam{tAdam, mAdam, vAdam} / (sqrt vANew + LA.scalar (realToFrac epsilon)) ) -- the @scalar@ is safe here; -- @addConstant@ would be better, but it's not exposed - updateR :: DynamicExists OD.Array -> DynamicExists OD.Array - -> DynamicExists OD.Array -> DynamicExists OD.Array - -> ( DynamicExists OD.Array - , DynamicExists OD.Array - , DynamicExists OD.Array ) - updateR emA@(DynamicExists @r1 mA) evA@(DynamicExists @r2 vA) - ep@(DynamicExists @r3 p) (DynamicExists @r4 g) = - if isTensorDummyD g -- eval didn't update it - then (emA, evA, ep) - else ifDifferentiable @r1 - (case ( testEquality (typeRep @r1) (typeRep @r2) - , testEquality (typeRep @r2) 
(typeRep @r3) - , testEquality (typeRep @r3) (typeRep @r4) ) of - (Just Refl, Just Refl, Just Refl) -> - let (od1, od2, od3) = liftArray43 updateVector mA vA p g - in (DynamicExists od1, DynamicExists od2, DynamicExists od3) + updateR :: DynamicTensor (Flip OR.Array) -> DynamicTensor (Flip OR.Array) + -> DynamicTensor (Flip OR.Array) -> DynamicTensor (Flip OR.Array) + -> ( DynamicTensor (Flip OR.Array) + , DynamicTensor (Flip OR.Array) + , DynamicTensor (Flip OR.Array) ) + updateR emA@(DynamicRanked @r1 @n1 mA) evA@(DynamicRanked @r2 @n2 vA) + ep@(DynamicRanked @r3 @n3 p) (DynamicRanked @r4 @n4 g) = + ifDifferentiable @r1 + (case ( sameNat (Proxy @n1) (Proxy @n2) + , sameNat (Proxy @n1) (Proxy @n3) + , sameNat (Proxy @n1) (Proxy @n4) + , testEquality (typeRep @r1) (typeRep @r2) + , testEquality (typeRep @r1) (typeRep @r3) + , testEquality (typeRep @r1) (typeRep @r4) ) of + ( Just Refl, Just Refl, Just Refl + ,Just Refl, Just Refl, Just Refl ) -> + let (od1, od2, od3) = + liftArray43 updateVector (runFlip mA) (runFlip vA) + (runFlip p) (runFlip g) + in ( DynamicRanked $ Flip od1 + , DynamicRanked $ Flip od2 + , DynamicRanked $ Flip od3 ) _ -> error "updateWithGradientAdam: type mismatch") (emA, evA, ep) + updateR emA evA ep _ = + (emA, evA, ep) -- eval didn't update the gradient, save on computation (!mAdamRNew, !vAdamRNew, !paramsRNew) = V.unzip3 $ V.zipWith4 updateR mAdamR vAdamR paramsR gradientR in ( paramsRNew diff --git a/test/simplified/TestAdaptorSimplified.hs b/test/simplified/TestAdaptorSimplified.hs index 087a22e70..4cd73f828 100644 --- a/test/simplified/TestAdaptorSimplified.hs +++ b/test/simplified/TestAdaptorSimplified.hs @@ -6,7 +6,6 @@ module TestAdaptorSimplified import Prelude -import qualified Data.Array.DynamicS as OD import qualified Data.Array.RankedS as OR import qualified Data.Array.Shape as Sh import qualified Data.Array.ShapedS as OS @@ -15,14 +14,12 @@ import qualified Data.EnumMap.Strict as EM import Data.Int (Int64) import Data.List 
(foldl1') import qualified Data.Strict.IntMap as IM -import Data.Type.Equality (testEquality, (:~:) (Refl)) import qualified Data.Vector.Generic as V import Foreign.C (CInt) import GHC.TypeLits (KnownNat) import Numeric.LinearAlgebra (Numeric, Vector) import Test.Tasty import Test.Tasty.HUnit hiding (assert) -import Type.Reflection (typeRep) import HordeAd import HordeAd.Core.AstEnv @@ -659,7 +656,7 @@ testListProdPP = do fT = shapedListProd let (artifactRev, deltas)= revArtifactAdapt True fT [1, 2, 3, 4] printGradient6SimpleS renames artifactRev - @?= "\\dret x2 x3 x4 x5 -> sletInDomains (x2 * x3) (\\x6 -> sletInDomains (x6 * x4) (\\x7 -> sletInDomains (x5 * dret) (\\x8 -> sletInDomains (x4 * x8) (\\x9 -> dmkDomains (fromList [dfromR (x3 * x9), dfromR (x2 * x9), dfromR (x6 * x8), dfromR (x7 * dret)])))))" + @?= "\\dret x2 x3 x4 x5 -> sletInDomains (x2 * x3) (\\x6 -> sletInDomains (x6 * x4) (\\x7 -> sletInDomains (x5 * dret) (\\x8 -> sletInDomains (x4 * x8) (\\x9 -> dmkDomains (fromList [dfromS (x3 * x9), dfromS (x2 * x9), dfromS (x6 * x8), dfromS (x7 * dret)])))))" printPrimal6SimpleS renames artifactRev @?= "\\x2 x3 x4 x5 -> slet (x2 * x3) (\\x6 -> slet (x6 * x4) (\\x7 -> x7 * x5))" printGradient6PrettyS renames (simplifyArtifactRevS artifactRev) @@ -667,7 +664,7 @@ testListProdPP = do printPrimal6PrettyS renames (simplifyArtifactRevS artifactRev) @?= "\\x2 x3 x4 x5 -> ((x2 * x3) * x4) * x5" show deltas - @?= "LetS 100000003 (AddS (ScaleS (AstRToS (AstVar [] (AstVarId 100000005))) (LetS 100000002 (AddS (ScaleS (AstRToS (AstVar [] (AstVarId 100000004))) (LetS 100000001 (AddS (ScaleS (AstRToS (AstVar [] (AstVarId 100000003))) (RToS (InputR [] (InputId 0)))) (ScaleS (AstRToS (AstVar [] (AstVarId 100000002))) (RToS (InputR [] (InputId 1))))))) (ScaleS (AstVarS (AstVarId 100000006)) (RToS (InputR [] (InputId 2))))))) (ScaleS (AstVarS (AstVarId 100000007)) (RToS (InputR [] (InputId 3)))))" + @?= "LetS 100000003 (AddS (ScaleS (AstVarS (AstVarId 100000005)) (LetS 
100000002 (AddS (ScaleS (AstVarS (AstVarId 100000004)) (LetS 100000001 (AddS (ScaleS (AstVarS (AstVarId 100000003)) (InputS (InputId 0))) (ScaleS (AstVarS (AstVarId 100000002)) (InputS (InputId 1)))))) (ScaleS (AstVarS (AstVarId 100000006)) (InputS (InputId 2)))))) (ScaleS (AstVarS (AstVarId 100000007)) (InputS (InputId 3))))" rankedListProdr :: (RankedTensor ranked, GoodScalar r) => [ranked r 0] -> ranked r 0 @@ -939,7 +936,7 @@ testReluPP2 = do printPrimal6Pretty renames (simplifyArtifactRev artifactRev) @?= "\\v2 x3 -> rgather [5] (rconst (fromList [2] [0.0,1.0])) (\\[i5] -> [ifF (v2 ! [i5] * x3 <=. rconst 0.0) 0 1]) * (v2 * rreplicate 5 x3)" show deltas - @?= "LetR 100000009 (ScaleR (AstVar [5] (AstVarId 100000007)) (LetR 100000005 (AddR (ScaleR (AstReplicate 5 (AstVar [] (AstVarId 100000003))) (InputR [5] (InputId 0))) (ScaleR (AstVar [5] (AstVarId 100000002)) (LetR 100000004 (ReplicateR 5 (InputR [] (InputId 1))))))))" + @?= "LetR 100000009 (ScaleR (AstVar [5] (AstVarId 100000007)) (LetR 100000008 (AddR (ScaleR (AstReplicate 5 (AstVar [] (AstVarId 100000003))) (InputR [5] (InputId 0))) (ScaleR (AstVar [5] (AstVarId 100000002)) (LetR 100000007 (ReplicateR 5 (InputR [] (InputId 1))))))))" testReluSimpler :: Assertion testReluSimpler = do @@ -1095,7 +1092,7 @@ testReluSimplerPP4S2 = do printPrimal6PrettyS renames (simplifyArtifactRevS artifactRev) @?= "\\m2 x3 -> let m8 = m2 * sreshape (sreplicate x3) in sgather (sreplicate (sconst (fromList @[2] [0.0,1.0]))) (\\[i9, i10] -> [i9, ifF (m8 !$ [i9, i10] <=. 
sconst 0.0) 0 1]) * m8" show deltas - @?= "LetS 100000007 (ScaleS (AstVarS (AstVarId 100000012)) (LetS 100000003 (AddS (ScaleS (AstVarS (AstVarId 100000007)) (RToS (InputR [3,4] (InputId 0)))) (ScaleS (AstRToS (AstVar [3,4] (AstVarId 100000002))) (LetS 100000002 (ReshapeS (LetS 100000001 (ReplicateS (RToS (InputR [] (InputId 1)))))))))))" + @?= "LetS 100000007 (ScaleS (AstVarS (AstVarId 100000012)) (LetS 100000003 (AddS (ScaleS (AstVarS (AstVarId 100000007)) (InputS (InputId 0))) (ScaleS (AstVarS (AstVarId 100000002)) (LetS 100000002 (ReshapeS (LetS 100000001 (ReplicateS (InputS (InputId 1))))))))))" testReluSimpler4S :: Assertion testReluSimpler4S = do @@ -1971,33 +1968,19 @@ blowupTests = testGroup "Catastrophic blowup avoidance tests" fooRrev :: forall g a. (ADReady g, GoodScalar a, Differentiable a) => (a, a, a) -> (g a 0, g a 0, g a 0) fooRrev (x, y, z) = - let fromDynamicExists :: forall f. ADReady f - => DynamicExists (DynamicOf f) -> f a 0 - fromDynamicExists (DynamicExists @r d) - | Just Refl <- testEquality (typeRep @r) (typeRep @a) = rfromD d - | otherwise = error "fromDynamicExists: type mismatch" - fromDoms :: forall f. ADReady f - => Domains (DynamicOf f) -> (f a 0, f a 0, f a 0) - fromDoms v = ( fromDynamicExists $ v V.! 0 - , fromDynamicExists $ v V.! 1 - , fromDynamicExists $ v V.! 2 ) - fooDomains :: forall f. ADReady f - => Domains (DynamicOf f) -> f a 0 - fooDomains v = foo (fromDoms v) - toDynamicExists :: forall f. ADReady f => a -> DynamicExists (DynamicOf f) - toDynamicExists a = - DynamicExists $ dfromR $ rconst @f $ OR.scalar a - zero :: DynamicExists OD.Array - zero = toDynamicExists @(Flip OR.Array) (0 :: a) + let fDomains :: forall f. ADReady f => Domains f -> f a 0 + fDomains v = foo (rfromD $ v V.! 0, rfromD $ v V.! 1, rfromD $ v V.! 
2) + sh = [] + zero = odFromSh @a @0 sh shapes = V.fromList [zero, zero, zero] - domsOf = - rrev @g - fooDomains - shapes - (V.fromList $ map (toDynamicExists @g) [x, y, z]) - in ( rletDomainsIn shapes domsOf (\v -> fromDynamicExists $ v V.! 0) - , rletDomainsIn shapes domsOf (\v -> fromDynamicExists $ v V.! 1) - , rletDomainsIn shapes domsOf (\v -> fromDynamicExists $ v V.! 2) ) + domsOf = rrev @g fDomains shapes + (V.fromList + $ [ DynamicRanked $ rconst @g $ OR.scalar x + , DynamicRanked $ rconst @g $ OR.scalar y + , DynamicRanked $ rconst @g $ OR.scalar z ]) + in ( rletDomainsIn shapes domsOf (\v -> rfromD $ v V.! 0) + , rletDomainsIn shapes domsOf (\v -> rfromD $ v V.! 1) + , rletDomainsIn shapes domsOf (\v -> rfromD $ v V.! 2) ) testFooRrev :: Assertion testFooRrev = do @@ -2022,7 +2005,7 @@ testFooRrevPP2 :: Assertion testFooRrevPP2 = do let (a1, _, _) = fooRrev @(AstRanked FullSpan) @Double (1.1, 2.2, 3.3) printAstSimple IM.empty a1 - @?= "rletDomainsIn (rletInDomains (sin (rconst 2.2)) (\\x39 -> rletInDomains (rconst 1.1 * x39) (\\x40 -> rletInDomains (recip (rconst 3.3 * rconst 3.3 + x40 * x40)) (\\x41 -> rletInDomains (sin (rconst 2.2)) (\\x42 -> rletInDomains (rconst 1.1 * x42) (\\x43 -> rletInDomains (rreshape [] (rreplicate 1 (rconst 1.0))) (\\x44 -> rletInDomains (rconst 3.3 * x44) (\\x45 -> rletInDomains (negate (rconst 3.3 * x41) * x44) (\\x46 -> dmkDomains (fromList [dfromR (x39 * x46 + x42 * x45), dfromR (cos (rconst 2.2) * (rconst 1.1 * x46) + cos (rconst 2.2) * (rconst 1.1 * x45)), dfromR ((x40 * x41) * x44 + x43 * x44)])))))))))) (\\[x24, x25, x26] -> x24)" + @?= "rletDomainsIn (rletInDomains (sin (rconst 2.2)) (\\x27 -> rletInDomains (rconst 1.1 * x27) (\\x28 -> rletInDomains (recip (rconst 3.3 * rconst 3.3 + x28 * x28)) (\\x29 -> rletInDomains (sin (rconst 2.2)) (\\x30 -> rletInDomains (rconst 1.1 * x30) (\\x31 -> rletInDomains (rreshape [] (rreplicate 1 (rconst 1.0))) (\\x32 -> rletInDomains (rconst 3.3 * x32) (\\x33 -> rletInDomains 
(negate (rconst 3.3 * x29) * x32) (\\x34 -> dmkDomains (fromList [dfromR (x27 * x34 + x30 * x33), dfromR (cos (rconst 2.2) * (rconst 1.1 * x34) + cos (rconst 2.2) * (rconst 1.1 * x33)), dfromR ((x28 * x29) * x32 + x31 * x32)])))))))))) (\\[x24, x25, x26] -> x24)" testFooRrev3 :: Assertion testFooRrev3 = do diff --git a/test/simplified/TestHighRankSimplified.hs b/test/simplified/TestHighRankSimplified.hs index fc49a81ae..f540332fa 100644 --- a/test/simplified/TestHighRankSimplified.hs +++ b/test/simplified/TestHighRankSimplified.hs @@ -590,7 +590,7 @@ testConcatBuild3PP2 = do let (artifactRev, _) = revArtifactAdapt True t (Flip $ OR.fromList [3] [0.651,0.14,0.3414]) printGradient6Simple renames artifactRev - @?= "\\dret v2 -> dmkDomains (fromList [dfromR riota])" + @?= "\\dret v2 -> dmkDomains (fromList [dfromR 0])" printPrimal6Simple renames artifactRev @?= "\\v2 -> rfromIntegral (rgather [5,2] (rfromList [rreplicate 5 (rconst (fromList [2] [0,1])), quot (rtranspose [1,0] (rreplicate 2 (rconst (fromList [5] [0,1,2,3,4])))) (rreplicate 5 (rconst (fromList [2] [0,1]) + rreplicate 2 (rconst 1)))]) (\\[i7, i8] -> [ifF (i8 >=. 
quot i7 (1 + i8)) 0 1, i7, i8]))" printPrimal6Simple renames (simplifyArtifactRev artifactRev) diff --git a/test/simplified/TestMnistCNNR.hs b/test/simplified/TestMnistCNNR.hs index c63f6f18d..31cf636ad 100644 --- a/test/simplified/TestMnistCNNR.hs +++ b/test/simplified/TestMnistCNNR.hs @@ -9,7 +9,6 @@ module TestMnistCNNR import Prelude import Control.Monad (foldM, unless) -import qualified Data.Array.DynamicS as OD import qualified Data.Array.RankedS as OR import qualified Data.Array.ShapedS as OS import Data.Bifunctor.Flip @@ -60,7 +59,7 @@ mnistTestCaseCNNA prefix epochs maxBatches kh kw c_out n_hidden , someNatVal $ toInteger n_hidden ) of ( Just (SomeNat @kh _), Just (SomeNat @kw _) ,Just (SomeNat @c_out _), Just (SomeNat @n_hidden _) ) -> - shapedToRanked $ fst + forgetShape $ fst $ randomVals @(MnistCnnRanked2.ADCnnMnistParametersShaped (Flip OS.Array) SizeMnistHeight SizeMnistWidth kh kw c_out n_hidden r) @@ -87,7 +86,7 @@ mnistTestCaseCNNA prefix epochs maxBatches kh kw c_out n_hidden runBatch :: (DomainsOD, StateAdam) -> (Int, [MnistDataR r]) -> IO (DomainsOD, StateAdam) runBatch (!parameters, !stateAdam) (k, chunk) = do - let f :: MnistDataBatchR r -> Domains (ADValClown OD.Array) + let f :: MnistDataBatchR r -> Domains (ADVal (Flip OR.Array)) -> ADVal ranked r 0 f (glyphR, labelR) adinputs = MnistCnnRanked2.convMnistLossFusedR @@ -159,7 +158,7 @@ mnistTestCaseCNNI prefix epochs maxBatches kh kw c_out n_hidden , someNatVal $ toInteger n_hidden ) of ( Just (SomeNat @kh _), Just (SomeNat @kw _) ,Just (SomeNat @c_out _), Just (SomeNat @n_hidden _) ) -> - shapedToRanked $ fst + forgetShape $ fst $ randomVals @(MnistCnnRanked2.ADCnnMnistParametersShaped (Flip OS.Array) SizeMnistHeight SizeMnistWidth kh kw c_out n_hidden r) @@ -197,7 +196,7 @@ mnistTestCaseCNNI prefix epochs maxBatches kh kw c_out n_hidden runBatch :: (DomainsOD, StateAdam) -> (Int, [MnistDataR r]) -> IO (DomainsOD, StateAdam) runBatch (!parameters, !stateAdam) (k, chunk) = do - let f :: 
MnistDataBatchR r -> Domains (ADValClown OD.Array) + let f :: MnistDataBatchR r -> Domains (ADVal (Flip OR.Array)) -> ADVal ranked r 0 f (glyph, label) varInputs = let env = foldr extendEnvD EM.empty @@ -278,7 +277,7 @@ mnistTestCaseCNNO prefix epochs maxBatches kh kw c_out n_hidden valsInitShaped = fst $ randomVals 0.4 (mkStdGen 44) domainsInit = toDomains valsInitShaped -- == toDomains valsInit valsInit :: MnistCnnRanked2.ADCnnMnistParameters ranked r - valsInit = shapedToRanked valsInitShaped + valsInit = forgetShape valsInitShaped name = prefix ++ ": " ++ unwords [ show epochs, show maxBatches , show kh, show kw, show c_out, show n_hidden @@ -318,8 +317,8 @@ mnistTestCaseCNNO prefix epochs maxBatches kh kw c_out n_hidden -> (DomainsOD, StateAdam) go [] (parameters, stateAdam) = (parameters, stateAdam) go ((glyph, label) : rest) (!parameters, !stateAdam) = - let glyphD = DynamicExists $ dfromR @(Flip OR.Array) $ rconst glyph - labelD = DynamicExists $ dfromR @(Flip OR.Array) $ rconst label + let glyphD = DynamicRanked $ rconst glyph + labelD = DynamicRanked $ rconst label parametersAndInput = V.concat [parameters, V.fromList [glyphD, labelD]] gradientDomain = @@ -398,7 +397,7 @@ testCNNOPP = do $ AstReplicate sizeMnistHeightI 7 valsInit :: MnistCnnRanked2.ADCnnMnistParameters (Flip OR.Array) Double valsInit = - shapedToRanked $ fst + forgetShape $ fst $ randomVals @(MnistCnnRanked2.ADCnnMnistParametersShaped (Flip OS.Array) 4 4 -- see sizeMnistWidthI, etc. 
1 1 1 1 Double) diff --git a/test/simplified/TestMnistFCNNR.hs b/test/simplified/TestMnistFCNNR.hs index 990a949a4..25d2119ea 100644 --- a/test/simplified/TestMnistFCNNR.hs +++ b/test/simplified/TestMnistFCNNR.hs @@ -7,7 +7,6 @@ module TestMnistFCNNR import Prelude import Control.Monad (foldM, unless) -import qualified Data.Array.DynamicS as OD import qualified Data.Array.RankedS as OR import qualified Data.Array.ShapedS as OS import Data.Bifunctor.Flip @@ -62,11 +61,12 @@ mnistTestCase1VTA prefix epochs maxBatches widthHidden widthHidden2 gamma batchSize expected = let nParams1 = MnistFcnnRanked1.afcnnMnistLen1 widthHidden widthHidden2 params1Init = - imap (\i nPV -> OD.fromVector [nPV] - $ V.map realToFrac - $ LA.randomVector (44 + nPV + i) LA.Uniform nPV - - LA.scalar 0.5) - nParams1 + imap (\i nPV -> + DynamicRanked @r @1 $ Flip $ OR.fromVector [nPV] + $ V.map realToFrac + $ LA.randomVector (44 + nPV + i) LA.Uniform nPV + - LA.scalar 0.5) + nParams1 -- This is a very ugly and probably unavoidable boilerplate: -- we have to manually define a dummy value of type ADFcnnMnist1Parameters -- with the correct list lengths (vector lengths can be fake) @@ -74,7 +74,7 @@ mnistTestCase1VTA prefix epochs maxBatches widthHidden widthHidden2 -- avoided only with shapely typed tensors and scalars or when -- not using adaptors. 
emptyR = Flip $ OR.fromList [0] [] - domainsInit = V.fromList $ map (DynamicExists @r) params1Init + domainsInit = V.fromList params1Init valsInit :: MnistFcnnRanked1.ADFcnnMnist1Parameters ranked r valsInit = ( (replicate widthHidden emptyR, emptyR) , (replicate widthHidden2 emptyR, emptyR) @@ -83,7 +83,7 @@ mnistTestCase1VTA prefix epochs maxBatches widthHidden widthHidden2 ++ unwords [ show epochs, show maxBatches , show widthHidden, show widthHidden2 , show (length params1Init) - , show (sum (map OD.size params1Init)) + , show (sizeDomainsOD domainsInit) , show gamma ] ftest :: [MnistData r] -> DomainsOD -> r ftest = MnistFcnnRanked1.afcnnMnistTest1 valsInit widthHidden widthHidden2 @@ -98,7 +98,7 @@ mnistTestCase1VTA prefix epochs maxBatches widthHidden widthHidden2 -- should not print, in principle. let runBatch :: DomainsOD -> (Int, [MnistData r]) -> IO DomainsOD runBatch !domains (k, chunk) = do - let f :: MnistData r -> Domains (ADValClown OD.Array) + let f :: MnistData r -> Domains (ADVal (Flip OR.Array)) -> ADVal ranked r 0 f mnist adinputs = MnistFcnnRanked1.afcnnMnistLoss1 @@ -157,13 +157,14 @@ mnistTestCase1VTI prefix epochs maxBatches widthHidden widthHidden2 gamma batchSize expected = let nParams1 = MnistFcnnRanked1.afcnnMnistLen1 widthHidden widthHidden2 params1Init = - imap (\i nPV -> OD.fromVector [nPV] - $ V.map realToFrac - $ LA.randomVector (44 + nPV + i) LA.Uniform nPV - - LA.scalar 0.5) - nParams1 + imap (\i nPV -> + DynamicRanked @r @1 $ Flip $ OR.fromVector [nPV] + $ V.map realToFrac + $ LA.randomVector (44 + nPV + i) LA.Uniform nPV + - LA.scalar 0.5) + nParams1 emptyR = Flip $ OR.fromList [0] [] - domainsInit = V.fromList $ map (DynamicExists @r) params1Init + domainsInit = V.fromList params1Init -- This is a very ugly and probably unavoidable boilerplate: -- we have to manually define a dummy value of type ADFcnnMnist1Parameters -- with the correct list lengths (vector lengths can be fake) @@ -178,7 +179,7 @@ mnistTestCase1VTI prefix 
epochs maxBatches widthHidden widthHidden2 ++ unwords [ show epochs, show maxBatches , show widthHidden, show widthHidden2 , show (length params1Init) - , show (sum (map OD.size params1Init)) + , show (sizeDomainsOD domainsInit) , show gamma ] ftest :: [MnistData r] -> DomainsOD -> r ftest = MnistFcnnRanked1.afcnnMnistTest1 valsInit widthHidden widthHidden2 @@ -202,7 +203,7 @@ mnistTestCase1VTI prefix epochs maxBatches widthHidden widthHidden2 -- should not print, in principle. let runBatch :: DomainsOD -> (Int, [MnistData r]) -> IO DomainsOD runBatch !domains (k, chunk) = do - let f :: MnistData r -> Domains (ADValClown OD.Array) + let f :: MnistData r -> Domains (ADVal (Flip OR.Array)) -> ADVal ranked r 0 f (glyph, label) varInputs = let env = foldr extendEnvD EM.empty @@ -268,13 +269,14 @@ mnistTestCase1VTO prefix epochs maxBatches widthHidden widthHidden2 gamma batchSize expected = let nParams1 = MnistFcnnRanked1.afcnnMnistLen1 widthHidden widthHidden2 params1Init = - imap (\i nPV -> OD.fromVector [nPV] - $ V.map realToFrac - $ LA.randomVector (44 + nPV + i) LA.Uniform nPV - - LA.scalar 0.5) - nParams1 + imap (\i nPV -> + DynamicRanked @r @1 $ Flip $ OR.fromVector [nPV] + $ V.map realToFrac + $ LA.randomVector (44 + nPV + i) LA.Uniform nPV + - LA.scalar 0.5) + nParams1 emptyR = Flip $ OR.fromList [0] [] - domainsInit = V.fromList $ map (DynamicExists @r) params1Init + domainsInit = V.fromList params1Init -- This is a very ugly and probably unavoidable boilerplate: -- we have to manually define a dummy value of type ADFcnnMnist1Parameters -- with the correct list lengths (vector lengths can be fake) @@ -289,7 +291,7 @@ mnistTestCase1VTO prefix epochs maxBatches widthHidden widthHidden2 ++ unwords [ show epochs, show maxBatches , show widthHidden, show widthHidden2 , show (length params1Init) - , show (sum (map OD.size params1Init)) + , show (sizeDomainsOD domainsInit) , show gamma ] ftest :: [MnistData r] -> DomainsOD -> r ftest = 
MnistFcnnRanked1.afcnnMnistTest1 valsInit widthHidden widthHidden2 @@ -319,10 +321,10 @@ mnistTestCase1VTO prefix epochs maxBatches widthHidden widthHidden2 go :: [MnistData r] -> DomainsOD -> DomainsOD go [] parameters = parameters go ((glyph, label) : rest) !parameters = - let glyphD = DynamicExists - $ OD.fromVector [sizeMnistGlyphInt] glyph - labelD = DynamicExists - $ OD.fromVector [sizeMnistLabelInt] label + let glyphD = DynamicRanked @r @1 + $ Flip $ OR.fromVector [sizeMnistGlyphInt] glyph + labelD = DynamicRanked @r @1 + $ Flip $ OR.fromVector [sizeMnistLabelInt] label parametersAndInput = V.concat [parameters, V.fromList [glyphD, labelD]] gradientDomain = @@ -394,7 +396,7 @@ mnistTestCase2VTA prefix epochs maxBatches widthHidden widthHidden2 Just (SomeNat @widthHidden _) -> case someNatVal $ toInteger widthHidden2 of Just (SomeNat @widthHidden2 _) -> - shapedToRanked $ fst + forgetShape $ fst $ randomVals @(MnistFcnnRanked2.ADFcnnMnist2ParametersShaped (Flip OS.Array) widthHidden widthHidden2 r) @@ -421,7 +423,7 @@ mnistTestCase2VTA prefix epochs maxBatches widthHidden widthHidden2 -- should not print, in principle. 
let runBatch :: DomainsOD -> (Int, [MnistData r]) -> IO DomainsOD runBatch !domains (k, chunk) = do - let f :: MnistData r -> Domains (ADValClown OD.Array) + let f :: MnistData r -> Domains (ADVal (Flip OR.Array)) -> ADVal ranked r 0 f mnist adinputs = MnistFcnnRanked2.afcnnMnistLoss2 @@ -456,13 +458,13 @@ mnistTestCase2VTA prefix epochs maxBatches widthHidden widthHidden2 tensorADValMnistTests2 :: TestTree tensorADValMnistTests2 = testGroup "Ranked2 ADVal MNIST tests" - [ mnistTestCase2VTA "VTA 1 epoch, 1 batch" 1 1 300 100 0.02 5 + [ mnistTestCase2VTA "VTA2 1 epoch, 1 batch" 1 1 300 100 0.02 5 (0.8 :: Double) - , mnistTestCase2VTA "VTA artificial 1 2 3 4 5" 1 2 3 4 5 500 + , mnistTestCase2VTA "VTA2 artificial 1 2 3 4 5" 1 2 3 4 5 500 (0.89 :: Float) - , mnistTestCase2VTA "VTA artificial 5 4 3 2 1" 5 4 3 2 1 499 + , mnistTestCase2VTA "VTA2 artificial 5 4 3 2 1" 5 4 3 2 1 499 (0.8361723446893787 :: Double) - , mnistTestCase2VTA "VTA 1 epoch, 0 batch" 1 0 300 100 0.02 500 + , mnistTestCase2VTA "VTA2 1 epoch, 0 batch" 1 0 300 100 0.02 500 (1 :: Float) ] @@ -485,7 +487,7 @@ mnistTestCase2VTI prefix epochs maxBatches widthHidden widthHidden2 case someNatVal $ toInteger widthHidden2 of Nothing -> error "impossible someNatVal error" Just (SomeNat @widthHidden2 _) -> - shapedToRanked $ fst + forgetShape $ fst $ randomVals @(MnistFcnnRanked2.ADFcnnMnist2ParametersShaped (Flip OS.Array) widthHidden widthHidden2 r) @@ -518,7 +520,7 @@ mnistTestCase2VTI prefix epochs maxBatches widthHidden widthHidden2 -- should not print, in principle. 
let runBatch :: DomainsOD -> (Int, [MnistData r]) -> IO DomainsOD runBatch !domains (k, chunk) = do - let f :: MnistData r -> Domains (ADValClown OD.Array) + let f :: MnistData r -> Domains (ADVal (Flip OR.Array)) -> ADVal ranked r 0 f (glyph, label) varInputs = let env = foldr extendEnvD EM.empty @@ -560,13 +562,13 @@ mnistTestCase2VTI prefix epochs maxBatches widthHidden widthHidden2 tensorIntermediateMnistTests2 :: TestTree tensorIntermediateMnistTests2 = testGroup "Ranked2 Intermediate MNIST tests" - [ mnistTestCase2VTI "VTI 1 epoch, 1 batch" 1 1 300 100 0.02 500 + [ mnistTestCase2VTI "VTI2 1 epoch, 1 batch" 1 1 300 100 0.02 500 (0.534 :: Double) - , mnistTestCase2VTI "VTI artificial 1 2 3 4 5" 1 2 3 4 5 500 + , mnistTestCase2VTI "VTI2 artificial 1 2 3 4 5" 1 2 3 4 5 500 (0.884 :: Float) - , mnistTestCase2VTI "VTI artificial 5 4 3 2 1" 5 4 3 2 1 499 + , mnistTestCase2VTI "VTI2 artificial 5 4 3 2 1" 5 4 3 2 1 499 (0.7464929859719439 :: Double) - , mnistTestCase2VTI "VTI 1 epoch, 0 batch" 1 0 300 100 0.02 500 + , mnistTestCase2VTI "VTI2 1 epoch, 0 batch" 1 0 300 100 0.02 500 (1 :: Float) ] @@ -595,10 +597,10 @@ mnistTestCase2VTO prefix epochs maxBatches widthHidden widthHidden2 domainsInit = toDomains valsInitShaped -- == toDomains valsInit valsInit :: MnistFcnnRanked2.ADFcnnMnist2Parameters ranked r valsInit = - -- This almost works and I wouldn't need shapedToRanked, + -- This almost works and I wouldn't need forgetShape, -- but there is nowhere to get aInit from. 
-- parseDomains aInit domainsInit - shapedToRanked valsInitShaped + forgetShape valsInitShaped name = prefix ++ ": " ++ unwords [ show epochs, show maxBatches , show widthHidden, show widthHidden2 @@ -632,10 +634,10 @@ mnistTestCase2VTO prefix epochs maxBatches widthHidden widthHidden2 go :: [MnistData r] -> DomainsOD -> DomainsOD go [] parameters = parameters go ((glyph, label) : rest) !parameters = - let glyphD = DynamicExists - $ OD.fromVector [sizeMnistGlyphInt] glyph - labelD = DynamicExists - $ OD.fromVector [sizeMnistLabelInt] label + let glyphD = DynamicRanked @r @1 + $ Flip $ OR.fromVector [sizeMnistGlyphInt] glyph + labelD = DynamicRanked @r @1 + $ Flip $ OR.fromVector [sizeMnistLabelInt] label parametersAndInput = V.concat [parameters, V.fromList [glyphD, labelD]] gradientDomain = @@ -677,13 +679,13 @@ mnistTestCase2VTO prefix epochs maxBatches widthHidden widthHidden2 tensorADOnceMnistTests2 :: TestTree tensorADOnceMnistTests2 = testGroup "Ranked2 Once MNIST tests" - [ mnistTestCase2VTO "VTO 1 epoch, 1 batch" 1 1 300 100 0.02 500 + [ mnistTestCase2VTO "VTO2 1 epoch, 1 batch" 1 1 300 100 0.02 500 (0.534 :: Double) - , mnistTestCase2VTO "VTO artificial 1 2 3 4 5" 1 2 3 4 5 500 + , mnistTestCase2VTO "VTO2 artificial 1 2 3 4 5" 1 2 3 4 5 500 (0.884 :: Float) - , mnistTestCase2VTO "VTO artificial 5 4 3 2 1" 5 4 3 2 1 499 + , mnistTestCase2VTO "VTO2 artificial 5 4 3 2 1" 5 4 3 2 1 499 (0.7945891783567134 :: Double) - , mnistTestCase2VTO "VTO 1 epoch, 0 batch" 1 0 300 100 0.02 500 + , mnistTestCase2VTO "VTO2 1 epoch, 0 batch" 1 0 300 100 0.02 500 (1 :: Float) ] diff --git a/test/simplified/TestMnistRNNR.hs b/test/simplified/TestMnistRNNR.hs index 202ec6dc6..b431a886c 100644 --- a/test/simplified/TestMnistRNNR.hs +++ b/test/simplified/TestMnistRNNR.hs @@ -15,7 +15,6 @@ module TestMnistRNNR import Prelude import Control.Monad (foldM, unless) -import qualified Data.Array.DynamicS as OD import qualified Data.Array.RankedS as OR import qualified Data.Array.ShapedS 
as OS import Data.Bifunctor.Flip @@ -63,7 +62,7 @@ mnistTestCaseRNNA prefix epochs maxBatches width miniBatchSize totalBatchSize case someNatVal $ toInteger width of Nothing -> error "impossible someNatVal error" Just (SomeNat @width _) -> - shapedToRanked $ fst + forgetShape $ fst $ randomVals @(MnistRnnRanked2.ADRnnMnistParametersShaped (Flip OS.Array) width r) 0.4 (mkStdGen 44) @@ -87,7 +86,7 @@ mnistTestCaseRNNA prefix epochs maxBatches width miniBatchSize totalBatchSize runBatch :: (DomainsOD, StateAdam) -> (Int, [MnistDataR r]) -> IO (DomainsOD, StateAdam) runBatch (!parameters, !stateAdam) (k, chunk) = do - let f :: MnistDataBatchR r -> Domains (ADValClown OD.Array) + let f :: MnistDataBatchR r -> Domains (ADVal (Flip OR.Array)) -> ADVal ranked r 0 f (glyphR, labelR) adinputs = MnistRnnRanked2.rnnMnistLossFusedR @@ -156,7 +155,7 @@ mnistTestCaseRNNI prefix epochs maxBatches width miniBatchSize totalBatchSize case someNatVal $ toInteger width of Nothing -> error "impossible someNatVal error" Just (SomeNat @width _) -> - shapedToRanked $ fst + forgetShape $ fst $ randomVals @(MnistRnnRanked2.ADRnnMnistParametersShaped (Flip OS.Array) width r) 0.4 (mkStdGen 44) @@ -191,7 +190,7 @@ mnistTestCaseRNNI prefix epochs maxBatches width miniBatchSize totalBatchSize runBatch :: (DomainsOD, StateAdam) -> (Int, [MnistDataR r]) -> IO (DomainsOD, StateAdam) runBatch (!parameters, !stateAdam) (k, chunk) = do - let f :: MnistDataBatchR r -> Domains (ADValClown OD.Array) + let f :: MnistDataBatchR r -> Domains (ADVal (Flip OR.Array)) -> ADVal ranked r 0 f (glyph, label) varInputs = let env = foldr extendEnvD EM.empty @@ -267,7 +266,7 @@ mnistTestCaseRNNO prefix epochs maxBatches width miniBatchSize totalBatchSize valsInitShaped = fst $ randomVals 0.4 (mkStdGen 44) domainsInit = toDomains valsInitShaped -- == toDomains valsInit valsInit :: MnistRnnRanked2.ADRnnMnistParameters ranked r - valsInit = shapedToRanked valsInitShaped + valsInit = forgetShape valsInitShaped name = 
prefix ++ ": " ++ unwords [ show epochs, show maxBatches , show width, show miniBatchSize @@ -306,8 +305,8 @@ mnistTestCaseRNNO prefix epochs maxBatches width miniBatchSize totalBatchSize -> (DomainsOD, StateAdam) go [] (parameters, stateAdam) = (parameters, stateAdam) go ((glyph, label) : rest) (!parameters, !stateAdam) = - let glyphD = DynamicExists $ dfromR @(Flip OR.Array) $ rconst glyph - labelD = DynamicExists $ dfromR @(Flip OR.Array) $ rconst label + let glyphD = DynamicRanked $ rconst glyph + labelD = DynamicRanked $ rconst label parametersAndInput = V.concat [parameters, V.fromList [glyphD, labelD]] gradientDomain = diff --git a/test/simplified/TestMnistRNNS.hs b/test/simplified/TestMnistRNNS.hs index 0fd5c516e..3fe984fea 100644 --- a/test/simplified/TestMnistRNNS.hs +++ b/test/simplified/TestMnistRNNS.hs @@ -9,7 +9,6 @@ import Prelude import Control.Exception.Assert.Sugar import Control.Monad (foldM, unless) import qualified Data.Array.Convert -import qualified Data.Array.DynamicS as OD import qualified Data.Array.RankedS as OR import qualified Data.Array.ShapedS as OS import Data.Bifunctor.Flip @@ -82,7 +81,7 @@ mnistTestCaseRNNSA prefix epochs maxBatches width@SNat batch_size@SNat -> IO (DomainsOD, StateAdam) runBatch (!parameters, !stateAdam) (k, chunk) = do let f :: MnistDataBatchS batch_size r - -> Domains (ADValClown OD.Array) + -> Domains (ADVal (Flip OR.Array)) -> ADVal shaped r '[] f (glyphS, labelS) adinputs = MnistRnnShaped2.rnnMnistLossFusedS @@ -130,11 +129,11 @@ mnistTestCaseRNNSA prefix epochs maxBatches width@SNat batch_size@SNat tensorADValMnistTestsRNNSA :: TestTree tensorADValMnistTestsRNNSA = testGroup "RNNS ADVal MNIST tests" [ mnistTestCaseRNNSA "RNNSA 1 epoch, 1 batch" 1 1 (SNat @128) (SNat @5) 50 - (0.8200000000000001 :: Double) + (0.8933333 :: Double) , mnistTestCaseRNNSA "RNNSA artificial 1 2 3 4 5" 2 3 (SNat @4) (SNat @5) 50 - (0.8933333 :: Float) + (0.9 :: Float) , mnistTestCaseRNNSA "RNNSA artificial 5 4 3 2 1" 5 4 (SNat @3) 
(SNat @2) 49 - (0.8928571428571429 :: Double) + (0.9336734693877551 :: Double) , mnistTestCaseRNNSA "RNNSA 1 epoch, 0 batch" 1 0 (SNat @128) (SNat @5) 50 (1.0 :: Float) ] @@ -191,7 +190,7 @@ mnistTestCaseRNNSI prefix epochs maxBatches width@SNat batch_size@SNat -> IO (DomainsOD, StateAdam) runBatch (!parameters, !stateAdam) (k, chunk) = do let f :: MnistDataBatchS batch_size r - -> Domains (ADValClown OD.Array) + -> Domains (ADVal (Flip OR.Array)) -> ADVal shaped r '[] f (glyph, label) varInputs = let env = foldr extendEnvD EM.empty @@ -312,8 +311,8 @@ mnistTestCaseRNNSO prefix epochs maxBatches width@SNat batch_size@SNat -> (DomainsOD, StateAdam) go [] (parameters, stateAdam) = (parameters, stateAdam) go ((glyph, label) : rest) (!parameters, !stateAdam) = - let glyphD = DynamicExists $ dfromS @(Flip OR.Array) $ sconst glyph - labelD = DynamicExists $ dfromS @(Flip OR.Array) $ sconst label + let glyphD = DynamicShaped $ sconst glyph + labelD = DynamicShaped $ sconst label parametersAndInput = V.concat [parameters, V.fromList [glyphD, labelD]] gradientDomain = diff --git a/test/tool/CrossTesting.hs b/test/tool/CrossTesting.hs index 417c58e03..24623968e 100644 --- a/test/tool/CrossTesting.hs +++ b/test/tool/CrossTesting.hs @@ -9,18 +9,14 @@ module CrossTesting import Prelude -import qualified Data.Array.DynamicS as OD import qualified Data.Array.RankedS as OR import qualified Data.Array.Shape as Sh -import qualified Data.Array.ShapedS as OS import Data.Bifunctor.Flip import qualified Data.EnumMap.Strict as EM -import Data.Type.Equality (testEquality, (:~:) (Refl)) import qualified Data.Vector.Generic as V import GHC.TypeLits (KnownNat) import Numeric.LinearAlgebra (Numeric) import Test.Tasty.HUnit hiding (assert) -import Type.Reflection (typeRep) import HordeAd.Core.Adaptor import HordeAd.Core.Ast @@ -57,13 +53,13 @@ rev' f vals = let value0 = f vals parameters = toDomains vals dt = Nothing - g :: Domains (ADValClown OD.Array) + g :: Domains (ADVal (Flip OR.Array)) 
-> ADVal (Flip OR.Array) r m g inputs = f $ parseDomains vals inputs (advalGrad, value1) = crevOnDomains True dt g parameters gradient1 = parseDomains vals advalGrad gradientRrev1 = rrev1 @(Flip OR.Array) @r @n @m f vals - g9 :: Domains (ADValClown (AstDynamic PrimalSpan)) + g9 :: Domains (ADVal (AstRanked PrimalSpan)) -> ADVal (AstRanked PrimalSpan) r m g9 inputs = f $ parseDomains vals inputs (advalGrad9, value9) = @@ -86,7 +82,7 @@ rev' f vals = => (f1 r m -> AstRanked PrimalSpan r m) -> (AstRanked PrimalSpan r n -> f1 r n) -> (AstRanked PrimalSpan r m -> AstRanked PrimalSpan r m) - -> Domains (ADValClown OD.Array) + -> Domains (ADVal (Flip OR.Array)) -> ADVal (Flip OR.Array) r m h fx1 fx2 gx inputs = hGeneral @(ADVal (Flip OR.Array)) fx1 fx2 gx (parseDomains vals inputs) @@ -133,7 +129,7 @@ rev' f vals = => (f1 r m -> AstRanked PrimalSpan r m) -> (AstRanked PrimalSpan r n -> f1 r n) -> (AstRanked PrimalSpan r m -> AstRanked PrimalSpan r m) - -> Domains (ADValClown (AstDynamic PrimalSpan)) + -> Domains (ADVal (AstRanked PrimalSpan)) -> ADVal (AstRanked PrimalSpan) r m hAst fx1 fx2 gx inputs = hGeneral @(ADVal (AstRanked PrimalSpan)) @@ -452,71 +448,35 @@ rrev1 :: forall g r n m r3. (ADReady g, GoodScalar r, GoodScalar r3, KnownNat n, KnownNat m) => (forall f. ADReady f => f r n -> f r3 m) -> g r n -> g r n rrev1 f u = - let fromDynamicExists :: forall f. ADReady f - => DynamicExists (DynamicOf f) -> f r n - fromDynamicExists (DynamicExists @r2 d) - | dIsDummy @f d = rzero (rshape u) - | Just Refl <- testEquality (typeRep @r2) (typeRep @r) = rfromD d - | otherwise = - error $ "fromDynamicExists type mismatch: " - ++ show (typeRep @r2) ++ " /= " ++ show (typeRep @r) - fDomains :: forall f. ADReady f - => Domains (DynamicOf f) -> f r3 m - fDomains v = f (fromDynamicExists $ v V.! 0) - toDynamicExists :: forall f. 
ADReady f - => f r n -> DynamicExists (DynamicOf f) - toDynamicExists a = DynamicExists $ dfromR a - zero :: DynamicExists OD.Array - zero = toDynamicExists @(Flip OR.Array) (0 :: Flip OR.Array r n) + let fDomains :: forall f. ADReady f => Domains f -> f r3 m + fDomains v = f (rfromD $ v V.! 0) + sh = rshape u + zero = odFromSh @r @n sh shapes = V.fromList [zero] - domsOf = rrev @g fDomains shapes (V.singleton $ toDynamicExists @g u) - in rletDomainsIn shapes domsOf (\v -> fromDynamicExists $ v V.! 0) + domsOf = rrev @g fDomains shapes (V.singleton $ DynamicRanked u) + in rletDomainsIn shapes domsOf (\v -> rfromD $ v V.! 0) rfwd1 :: forall g r n m r3. (ADReady g, GoodScalar r, GoodScalar r3, KnownNat n, KnownNat m) => (forall f. ADReady f => f r n -> f r3 m) -> g r n -> g r3 m rfwd1 f u = - let fromDynamicExists :: forall f. ADReady f - => DynamicExists (DynamicOf f) -> f r n - fromDynamicExists (DynamicExists @r2 d) - | dIsDummy @f d = rzero (rshape u) - | Just Refl <- testEquality (typeRep @r2) (typeRep @r) = rfromD d - | otherwise = - error $ "fromDynamicExists type mismatch: " - ++ show (typeRep @r2) ++ " /= " ++ show (typeRep @r) - fDomains :: forall f. ADReady f - => Domains (DynamicOf f) -> f r3 m - fDomains v = f (fromDynamicExists $ v V.! 0) - toDynamicExists :: forall f. ADReady f - => f r n -> DynamicExists (DynamicOf f) - toDynamicExists a = DynamicExists $ dfromR a - zero :: DynamicExists OD.Array - zero = toDynamicExists @(Flip OR.Array) (0 :: Flip OR.Array r n) + let fDomains :: forall f. ADReady f => Domains f -> f r3 m + fDomains v = f (rfromD $ v V.! 0) + sh = rshape u + zero = odFromSh @r @n sh shapes = V.fromList [zero] - in rfwd @g fDomains shapes (V.singleton $ toDynamicExists @g u) - (V.singleton $ toDynamicExists @g u) -- simple + in rfwd @g fDomains shapes (V.singleton $ DynamicRanked u) + (V.singleton $ DynamicRanked u) -- simple srev1 :: forall g r sh sh2 r3. 
(ADReadyS g, GoodScalar r, GoodScalar r3, Sh.Shape sh, Sh.Shape sh2) => (forall f. ADReadyS f => f r sh -> f r3 sh2) -> g r sh -> g r sh srev1 f u = - let fromDynamicExists :: forall f. ADReadyS f - => DynamicExists (DynamicOf f) -> f r sh - fromDynamicExists (DynamicExists @r2 d) - | dIsDummy @(RankedOf f) d = 0 - | Just Refl <- testEquality (typeRep @r2) (typeRep @r) = sfromD d - | otherwise = - error $ "fromDynamicExists type mismatch: " - ++ show (typeRep @r2) ++ " /= " ++ show (typeRep @r) - fDomains :: forall f. ADReadyS f - => Domains (DynamicOf f) -> f r3 sh2 - fDomains v = f (fromDynamicExists $ v V.! 0) - toDynamicExists :: forall f. ADReadyS f - => f r sh -> DynamicExists (DynamicOf f) - toDynamicExists a = DynamicExists $ dfromS a - zero :: DynamicExists OD.Array - zero = toDynamicExists @(Flip OS.Array) (0 :: Flip OS.Array r sh) + let fDomains :: forall f. ADReadyS f + => Domains (RankedOf f) -> f r3 sh2 + fDomains v = f (sfromD $ v V.! 0) + zero = odFromShS @r @sh shapes = V.fromList [zero] domsOf = srev @(RankedOf g) - fDomains shapes (V.singleton $ toDynamicExists @g u) - in sletDomainsIn shapes domsOf (\v -> fromDynamicExists $ v V.! 0) + fDomains shapes (V.singleton $ DynamicShaped u) + in sletDomainsIn shapes domsOf (\v -> sfromD $ v V.! 0)