Skip to content

Commit

Permalink
Replace orthotope's dynamic tensors and DynamicExists with DynamicTensor
Browse files Browse the repository at this point in the history
Extra several tests fail, but the simplification is already clear.
  • Loading branch information
Mikolaj committed Dec 27, 2023
1 parent f0f9c5a commit f7db2d6
Show file tree
Hide file tree
Showing 29 changed files with 1,470 additions and 1,728 deletions.
6 changes: 3 additions & 3 deletions bench/common/BenchMnistTools.hs
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ mnistTrainBench2VTA extraPrefix chunkLength xs widthHidden widthHidden2
Just (SomeNat @widthHidden _) ->
case someNatVal $ toInteger widthHidden2 of
Just (SomeNat @widthHidden2 _) ->
shapedToRanked $ fst
forgetShape $ fst
$ randomVals
@(MnistFcnnRanked2.ADFcnnMnist2ParametersShaped
(Flip OS.Array) widthHidden widthHidden2 r)
Expand Down Expand Up @@ -240,7 +240,7 @@ mnistTestBench2VTA extraPrefix chunkLength xs widthHidden widthHidden2 = do
Just (SomeNat @widthHidden _) ->
case someNatVal $ toInteger widthHidden2 of
Just (SomeNat @widthHidden2 _) ->
shapedToRanked $ fst
forgetShape $ fst
$ randomVals
@(MnistFcnnRanked2.ADFcnnMnist2ParametersShaped
(Flip OS.Array) widthHidden widthHidden2 r)
Expand Down Expand Up @@ -296,7 +296,7 @@ mnistTrainBench2VTO extraPrefix chunkLength xs widthHidden widthHidden2
Just (SomeNat @widthHidden _) ->
case someNatVal $ toInteger widthHidden2 of
Just (SomeNat @widthHidden2 _) ->
shapedToRanked $ fst
forgetShape $ fst
$ randomVals
@(MnistFcnnRanked2.ADFcnnMnist2ParametersShaped
(Flip OS.Array) widthHidden widthHidden2 r)
Expand Down
6 changes: 3 additions & 3 deletions example/MnistData.hs
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,15 @@ chunksOf n = go where
:: KnownNat y
=> Double
-> (MnistData Double
-> Domains (DynamicOf (ADVal (Flip OR.Array)))
-> Domains (ADVal (Flip OR.Array))
-> ADVal (Flip OR.Array) Double y)
-> [MnistData Double]
-> DomainsOD
-> (DomainsOD, Flip OR.Array Double y) #-}

{-# SPECIALIZE sgdAdam
:: KnownNat y
=> (MnistDataBatchR Double -> Domains (DynamicOf (ADVal (Flip OR.Array)))
=> (MnistDataBatchR Double -> Domains (ADVal (Flip OR.Array))
-> ADVal (Flip OR.Array) Double y)
-> [MnistDataBatchR Double]
-> DomainsOD
Expand All @@ -196,7 +196,7 @@ chunksOf n = go where
{-# SPECIALIZE sgdAdamArgs
:: KnownNat y
=> ArgsAdam
-> (MnistDataBatchR Double -> Domains (DynamicOf (ADVal (Flip OR.Array)))
-> (MnistDataBatchR Double -> Domains (ADVal (Flip OR.Array))
-> ADVal (Flip OR.Array) Double y)
-> [MnistDataBatchR Double]
-> DomainsOD
Expand Down
63 changes: 31 additions & 32 deletions src/HordeAd/Core/Adaptor.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ module HordeAd.Core.Adaptor
import Prelude

import Control.Exception (assert)
import Data.Kind (Type)
import qualified Data.Strict.Vector as Data.Vector
import qualified Data.Vector.Generic as V
import System.Random
Expand All @@ -23,14 +22,14 @@ import HordeAd.Core.Types
-- * Adaptor classes

-- Inspired by adaptors from @tomjaguarpaw's branch.
class AdaptableDomains (dynamic :: Type -> Type) vals where
class AdaptableDomains (ranked :: RankedTensorKind) vals where
type Value vals -- ^ a helper type, with the same general shape,
-- but possibly more concrete, e.g., arrays instead of terms
toDomains :: vals -> Domains dynamic
toDomains :: vals -> Domains ranked
-- ^ represent a value of the domain of objective function
-- in a canonical, much less typed way common to all possible types
fromDomains :: Value vals -> Domains dynamic
-> Maybe (vals, Domains dynamic)
fromDomains :: Value vals -> Domains ranked
-> Maybe (vals, Domains ranked)
-- ^ recovers a value of the domain of objective function
-- from its canonical representation, using the general shape
-- recorded in a value of a more concrete type; the remainder
Expand All @@ -40,8 +39,8 @@ class AdaptableDomains (dynamic :: Type -> Type) vals where
-- there is no remainder. This is the main call of the recursive
-- procedure where @fromDomains@ calls itself for sub-values.
parseDomains
:: AdaptableDomains dynamic vals
=> Value vals -> Domains dynamic -> vals
:: AdaptableDomains ranked vals
=> Value vals -> Domains ranked -> vals
parseDomains aInit domains =
case fromDomains aInit domains of
Just (vals, rest) -> assert (V.null rest) vals
Expand All @@ -62,8 +61,8 @@ class RandomDomains vals where
-- * Basic Adaptor class instances

{- This is temporarily moved to TensorADVal in order to specialize manually
instance AdaptableDomains dynamic a
=> AdaptableDomains dynamic [a] where
instance AdaptableDomains ranked a
=> AdaptableDomains ranked [a] where
{-# SPECIALIZE instance
(KnownNat n, AdaptableDomains OD.Array (OR.Array n Double))
=> AdaptableDomains OD.Array
Expand Down Expand Up @@ -96,8 +95,8 @@ instance ForgetShape a
type NoShape [a] = [NoShape a]
forgetShape = map forgetShape

instance AdaptableDomains dynamic a
=> AdaptableDomains dynamic (Data.Vector.Vector a) where
instance AdaptableDomains ranked a
=> AdaptableDomains ranked (Data.Vector.Vector a) where
type Value (Data.Vector.Vector a) = Data.Vector.Vector (Value a)
toDomains = V.concatMap toDomains
fromDomains lInit source =
Expand All @@ -115,8 +114,8 @@ instance ForgetShape a
type NoShape (Data.Vector.Vector a) = Data.Vector.Vector (NoShape a)
forgetShape = V.map forgetShape

instance ( AdaptableDomains dynamic a
, AdaptableDomains dynamic b ) => AdaptableDomains dynamic (a, b) where
instance ( AdaptableDomains ranked a
, AdaptableDomains ranked b ) => AdaptableDomains ranked (a, b) where
type Value (a, b) = (Value a, Value b)
toDomains (a, b) =
let a1 = toDomains a
Expand All @@ -139,10 +138,10 @@ instance ( RandomDomains a
(v2, g2) = randomVals range g1
in ((v1, v2), g2)

instance ( AdaptableDomains dynamic a
, AdaptableDomains dynamic b
, AdaptableDomains dynamic c )
=> AdaptableDomains dynamic (a, b, c) where
instance ( AdaptableDomains ranked a
, AdaptableDomains ranked b
, AdaptableDomains ranked c )
=> AdaptableDomains ranked (a, b, c) where
type Value (a, b, c) = (Value a, Value b, Value c)
toDomains (a, b, c) =
let a1 = toDomains a
Expand Down Expand Up @@ -170,11 +169,11 @@ instance ( RandomDomains a
(v3, g3) = randomVals range g2
in ((v1, v2, v3), g3)

instance ( AdaptableDomains dynamic a
, AdaptableDomains dynamic b
, AdaptableDomains dynamic c
, AdaptableDomains dynamic d )
=> AdaptableDomains dynamic (a, b, c, d) where
instance ( AdaptableDomains ranked a
, AdaptableDomains ranked b
, AdaptableDomains ranked c
, AdaptableDomains ranked d )
=> AdaptableDomains ranked (a, b, c, d) where
type Value (a, b, c, d) = (Value a, Value b, Value c, Value d)
toDomains (a, b, c, d) =
let a1 = toDomains a
Expand Down Expand Up @@ -209,12 +208,12 @@ instance ( RandomDomains a
(v4, g4) = randomVals range g3
in ((v1, v2, v3, v4), g4)

instance ( AdaptableDomains dynamic a
, AdaptableDomains dynamic b
, AdaptableDomains dynamic c
, AdaptableDomains dynamic d
, AdaptableDomains dynamic e )
=> AdaptableDomains dynamic (a, b, c, d, e) where
instance ( AdaptableDomains ranked a
, AdaptableDomains ranked b
, AdaptableDomains ranked c
, AdaptableDomains ranked d
, AdaptableDomains ranked e )
=> AdaptableDomains ranked (a, b, c, d, e) where
type Value (a, b, c, d, e) = (Value a, Value b, Value c, Value d, Value e)
toDomains (a, b, c, d, e) =
let a1 = toDomains a
Expand Down Expand Up @@ -254,8 +253,8 @@ instance ( RandomDomains a
(v5, g5) = randomVals range g4
in ((v1, v2, v3, v4, v5), g5)

instance ( AdaptableDomains dynamic a, AdaptableDomains dynamic b )
=> AdaptableDomains dynamic (Either a b) where
instance ( AdaptableDomains ranked a, AdaptableDomains ranked b )
=> AdaptableDomains ranked (Either a b) where
type Value (Either a b) = Either (Value a) (Value b)
toDomains e = case e of
Left a -> toDomains a
Expand All @@ -275,8 +274,8 @@ instance ( ForgetShape a
Left a -> Left $ forgetShape a
Right b -> Right $ forgetShape b

instance AdaptableDomains dynamic a
=> AdaptableDomains dynamic (Maybe a) where
instance AdaptableDomains ranked a
=> AdaptableDomains ranked (Maybe a) where
type Value (Maybe a) = Maybe (Value a)
toDomains e = case e of
Nothing -> V.concat []
Expand Down
44 changes: 13 additions & 31 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ module HordeAd.Core.Ast
, AstArtifactRev, AstArtifactFwd
, AstIndex, AstVarList, AstIndexS, AstVarListS
-- * ASTs
, AstRanked(..), AstShaped(..), AstDynamic(..), AstDomains(..)
, AstRanked(..), AstShaped(..), AstDynamic, AstDomains(..)
, AstBool(..), OpCodeNum1(..), OpCodeNum2(..), OpCode1(..), OpCode2(..)
, OpCodeIntegral2(..), OpCodeBool(..), OpCodeRel(..)
-- * Boolean definitions and instances
Expand All @@ -31,7 +31,6 @@ import Prelude hiding (foldl')
import qualified Data.Array.RankedS as OR
import qualified Data.Array.Shape as Sh
import qualified Data.Array.ShapedS as OS
import Data.Bifunctor.Clown
import Data.Bifunctor.Flip
import Data.Int (Int64)
import Data.Kind (Type)
Expand Down Expand Up @@ -88,21 +87,14 @@ sameAstSpan = case eqTypeRep (typeRep @s1) (typeRep @s2) of

-- * Basic type family instances

type instance RankedOf (Clown (AstDynamic s)) = AstRanked s
type instance ShapedOf (Clown (AstDynamic s)) = AstShaped s
type instance DynamicOf (Clown (AstDynamic s)) = AstDynamic s
type instance DomainsOf (Clown (AstDynamic s)) = AstDomains s

type instance RankedOf (AstRanked s) = AstRanked s
type instance ShapedOf (AstRanked s) = AstShaped s
type instance DynamicOf (AstRanked s) = AstDynamic s
type instance DomainsOf (AstRanked s) = AstDomains s
type instance PrimalOf (AstRanked s) = AstRanked PrimalSpan
type instance DualOf (AstRanked s) = AstRanked DualSpan

type instance RankedOf (AstShaped s) = AstRanked s
type instance ShapedOf (AstShaped s) = AstShaped s
type instance DynamicOf (AstShaped s) = AstDynamic s
type instance DomainsOf (AstShaped s) = AstDomains s
type instance PrimalOf (AstShaped s) = AstShaped PrimalSpan
type instance DualOf (AstShaped s) = AstShaped DualSpan
Expand Down Expand Up @@ -138,8 +130,8 @@ type family ConcreteOf f = result | result -> f where
ConcreteOf (AstRanked FullSpan) = Flip OR.Array
ConcreteOf (AstShaped FullSpan) = Flip OS.Array

type AstBindings = AstBindingsD (AstDynamic PrimalSpan)
type ADShare = ADShareD (AstDynamic PrimalSpan)
type AstBindings = AstBindingsD (AstRanked PrimalSpan)
type ADShare = ADShareD (AstRanked PrimalSpan)


-- * More and less typed variables and type synonyms containing them
Expand Down Expand Up @@ -297,8 +289,8 @@ data AstRanked :: AstSpanType -> RankedTensorKind where
-> AstRanked s2 r n
AstFwd :: (GoodScalar r, KnownNat n)
=> ([AstDynamicVarName], AstRanked s r n)
-> Domains (AstDynamic s)
-> Domains (AstDynamic s)
-> Domains (AstRanked s)
-> Domains (AstRanked s)
-> AstRanked s r n
AstFold :: forall rn rm n m s. (GoodScalar rm, KnownNat m)
=> ( AstVarName (AstRanked PrimalSpan) rn n
Expand Down Expand Up @@ -438,8 +430,8 @@ data AstShaped :: AstSpanType -> ShapedTensorKind where
-> AstShaped s2 r sh
AstFwdS :: (GoodScalar r, Sh.Shape sh)
=> ([AstDynamicVarName], AstShaped s r sh)
-> Domains (AstDynamic s)
-> Domains (AstDynamic s)
-> Domains (AstRanked s)
-> Domains (AstRanked s)
-> AstShaped s r sh
AstFoldS :: forall rn rm sh shm k s. (GoodScalar rm, Sh.Shape shm, KnownNat k)
=> ( AstVarName (AstShaped PrimalSpan) rn sh
Expand Down Expand Up @@ -468,18 +460,12 @@ data AstShaped :: AstSpanType -> ShapedTensorKind where

deriving instance (GoodScalar r, Sh.Shape sh) => Show (AstShaped s r sh)

type role AstDynamic nominal nominal
data AstDynamic :: AstSpanType -> Type -> Type where
AstRToD :: forall n r s. KnownNat n
=> AstRanked s r n -> AstDynamic s r
AstSToD :: forall sh r s. Sh.Shape sh
=> AstShaped s r sh -> AstDynamic s r
deriving instance GoodScalar r => Show (AstDynamic s r)
type AstDynamic (s :: AstSpanType) = DynamicTensor (AstRanked s)

type role AstDomains nominal
data AstDomains s where
-- There are existential variables inside DynamicExists here.
AstDomains :: Domains (AstDynamic s) -> AstDomains s
AstDomains :: Domains (AstRanked s) -> AstDomains s
-- This operation is why we need AstDomains and so DomainsOf.
-- If we kept a vector of terms instead, we'd need to let-bind in each
-- of the terms separately, duplicating the let-bound term.
Expand All @@ -496,24 +482,24 @@ data AstDomains s where
-> AstDomains s2
AstRev :: (GoodScalar r, KnownNat n)
=> ([AstDynamicVarName], AstRanked s r n)
-> Domains (AstDynamic s)
-> Domains (AstRanked s)
-> AstDomains s
-- ^ the function body can't have any free variables outside those
-- listed in the first component of the pair; this reflects
-- the quantification in 'rrev' and prevents cotangent confusion;
-- the same holds for the similar operations below
AstRevDt :: (GoodScalar r, KnownNat n)
=> ([AstDynamicVarName], AstRanked s r n)
-> Domains (AstDynamic s)
-> Domains (AstRanked s)
-> AstRanked s r n
-> AstDomains s
AstRevS :: (GoodScalar r, Sh.Shape sh)
=> ([AstDynamicVarName], AstShaped s r sh)
-> Domains (AstDynamic s)
-> Domains (AstRanked s)
-> AstDomains s
AstRevDtS :: (GoodScalar r, Sh.Shape sh)
=> ([AstDynamicVarName], AstShaped s r sh)
-> Domains (AstDynamic s)
-> Domains (AstRanked s)
-> AstShaped s r sh
-> AstDomains s

Expand Down Expand Up @@ -877,25 +863,21 @@ maxF u v = ifF (u >=. v) u v

type instance RankedOf (AstNoVectorize s) = AstNoVectorize s
type instance ShapedOf (AstNoVectorize s) = AstNoVectorizeS s
type instance DynamicOf (AstNoVectorize s) = AstDynamic s
type instance DomainsOf (AstNoVectorize s) = AstDomains s
type instance PrimalOf (AstNoVectorize s) = AstRanked PrimalSpan
type instance DualOf (AstNoVectorize s) = AstRanked DualSpan
type instance RankedOf (AstNoVectorizeS s) = AstNoVectorize s
type instance ShapedOf (AstNoVectorizeS s) = AstNoVectorizeS s
type instance DynamicOf (AstNoVectorizeS s) = AstDynamic s
type instance DomainsOf (AstNoVectorizeS s) = AstDomains s
type instance PrimalOf (AstNoVectorizeS s) = AstShaped PrimalSpan
type instance DualOf (AstNoVectorizeS s) = AstShaped DualSpan
type instance RankedOf (AstNoSimplify s) = AstNoSimplify s
type instance ShapedOf (AstNoSimplify s) = AstNoSimplifyS s
type instance DynamicOf (AstNoSimplify s) = AstDynamic s
type instance DomainsOf (AstNoSimplify s) = AstDomains s
type instance PrimalOf (AstNoSimplify s) = AstRanked PrimalSpan
type instance DualOf (AstNoSimplify s) = AstRanked DualSpan
type instance RankedOf (AstNoSimplifyS s) = AstNoSimplify s
type instance ShapedOf (AstNoSimplifyS s) = AstNoSimplifyS s
type instance DynamicOf (AstNoSimplifyS s) = AstDynamic s
type instance DomainsOf (AstNoSimplifyS s) = AstDomains s
type instance PrimalOf (AstNoSimplifyS s) = AstShaped PrimalSpan
type instance DualOf (AstNoSimplifyS s) = AstShaped DualSpan
Expand Down
Loading

0 comments on commit f7db2d6

Please sign in to comment.