Skip to content

Commit

Permalink
Remove the BaseTensor dictionary from calls to mnistTrainBench1VTO
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Feb 28, 2025
1 parent 5fb0041 commit 62464f5
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 18 deletions.
20 changes: 18 additions & 2 deletions bench/common/BenchMnistTools.hs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ mnistTrainBench1VTO prefix widthHiddenInt widthHidden2Int
in do
let ftkData = FTKProduct (FTKR (sizeMnistGlyphInt :$: ZSR) FTKScalar)
(FTKR (sizeMnistLabelInt :$: ZSR) FTKScalar)
{- -- g is not enough to specialize to Double instead of to r,
-- despite the declaration of r ~ Double above
f :: ( MnistFcnnRanked1.ADFcnnMnist1Parameters
(AstTensor AstMethodLet FullSpan)
widthHidden widthHidden2 r
Expand All @@ -165,6 +167,19 @@ mnistTrainBench1VTO prefix widthHiddenInt widthHidden2Int
MnistFcnnRanked1.afcnnMnistLoss1
widthHiddenSNat widthHidden2SNat
(glyphR, labelR) pars
g :: ( MnistFcnnRanked1.ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) widthHidden widthHidden2 Double, ( AstTensor AstMethodLet FullSpan (TKR 1 Double), AstTensor AstMethodLet FullSpan (TKR 1 Double) ) ) -> AstTensor AstMethodLet FullSpan (TKScalar Double)
g = f
-}
-- Objective function handed to 'revArtifactAdapt' below: it unpacks the
-- (parameters, (glyph, label)) argument and evaluates
-- 'MnistFcnnRanked1.afcnnMnistLoss1' on it.
-- Both the signature and the @Double type application name Double
-- explicitly instead of going through the type variable r; the
-- commented-out variant above records that merely aliasing an r-typed f
-- at a Double type was not enough to get specialization to Double.
f :: ( MnistFcnnRanked1.ADFcnnMnist1Parameters
(AstTensor AstMethodLet FullSpan)
widthHidden widthHidden2 Double
, ( AstTensor AstMethodLet FullSpan (TKR 1 Double)
, AstTensor AstMethodLet FullSpan (TKR 1 Double) ) )
-> AstTensor AstMethodLet FullSpan (TKScalar Double)
f = \ (pars, (glyphR, labelR)) ->
-- Note the argument order: the loss function takes the datum pair
-- before the parameters, while f receives the parameters first.
MnistFcnnRanked1.afcnnMnistLoss1 @_ @Double
widthHiddenSNat widthHidden2SNat
(glyphR, labelR) pars
(artRaw, _) = revArtifactAdapt False f (FTKProduct ftk ftkData)
art = simplifyArtifactGradient artRaw
go :: [MnistDataLinearR r]
Expand Down Expand Up @@ -375,8 +390,9 @@ mnistBGroup2VTO xs0 chunkLength =
, mnistTrainBench2VTO "500|150 " 0.02 chunkLength xs (targetInit, art)
]

-- This is expected to fail with -O0 and to pass with -O1.
-- This is expected to fail with -O0 and to pass with -O1 and -fpolymorphic-specialisation.
-- This prevents running benchmarks without optimization, which is a good thing.
inspect $ hasNoTypeClassesExcept 'mnistTrainBench2VTO [''(~), ''GoodScalar, ''Show, ''Num, ''Ord, ''Eq, ''Nested.PrimElt, ''Nested.KnownElt, ''Nested.NumElt, ''Typeable, ''IfDifferentiable, ''NFData, ''Default.Default]
inspect $ hasNoTypeClassesExcept 'mnistTrainBench2VTC [''(~), ''RealFrac, ''Nested.FloatElt, ''RealFloatF, ''GoodScalar, ''Num, ''Show, ''Ord, ''Eq, ''Nested.Elt, ''Nested.PrimElt, ''Nested.KnownElt, ''Nested.NumElt, ''Typeable, ''IfDifferentiable, ''NFData, ''Default.Default, ''BaseTensor, ''KnownNat, ''Nested.Storable, ''IntegralF, ''Nested.KnownShX, ''WithDict, ''Integral, ''AstSpan, ''Nested.KnownShS, ''Numeric, ''SplitGen, ''RandomGen, ''Fractional, ''Random, ''KnownSTK]
inspect $ hasNoTypeClassesExcept 'mnistTrainBench1VTO [''(~), ''RealFrac, ''Nested.FloatElt, ''RealFloatF, ''GoodScalar, ''Num, ''Show, ''Ord, ''Eq, ''Nested.Elt, ''Nested.PrimElt, ''Nested.KnownElt, ''Nested.NumElt, ''Typeable, ''IfDifferentiable, ''NFData, ''Default.Default, ''BaseTensor, ''KnownNat, ''Nested.Storable, ''IntegralF, ''Nested.KnownShX, ''WithDict, ''Integral, ''AstSpan, ''Nested.KnownShS, ''Numeric, ''SplitGen, ''RandomGen, ''Fractional, ''Random, ''AdaptableTarget, ''Nested.KnownPerm, ''CommonTargetEqOrd, ''ConvertTensor, ''KnownSTK, ''Boolean, ''AllTargetShow, ''ShareTensor, ''LetTensor, ''RandomValue]
inspect $ hasNoTypeClassesExcept 'mnistTrainBench1VTO [''(~), ''RealFrac, ''Nested.FloatElt, ''RealFloatF, ''GoodScalar, ''Num, ''Show, ''Ord, ''Eq, ''Nested.Elt, ''Nested.PrimElt, ''Nested.KnownElt, ''Nested.NumElt, ''Typeable, ''IfDifferentiable, ''NFData, ''Default.Default, ''KnownNat, ''Nested.Storable, ''IntegralF, ''Nested.KnownShX, ''WithDict, ''Integral, ''AstSpan, ''Nested.KnownShS, ''Numeric, ''SplitGen, ''RandomGen, ''Fractional, ''Random, ''AdaptableTarget, ''Nested.KnownPerm, ''CommonTargetEqOrd, ''ConvertTensor, ''KnownSTK, ''Boolean, ''AllTargetShow, ''ShareTensor, ''LetTensor, ''RandomValue]
-- inspect $ coreOf 'mnistTrainBench1VTO
4 changes: 4 additions & 0 deletions example/MnistFcnnRanked1.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import GHC.TypeLits (KnownNat, Nat)
import Data.Array.Nested (ListR (..))
import Data.Array.Nested qualified as Nested

import HordeAd.Core.Ast
import HordeAd.Core.CarriersConcrete
import HordeAd.Core.Ops
import HordeAd.Core.TensorKind
Expand Down Expand Up @@ -79,6 +80,9 @@ afcnnMnistLoss1 widthHidden widthHidden2 (datum, target) adparams =
let result = afcnnMnist1 logisticS softMax1S
widthHidden widthHidden2 (sfromR datum) adparams
in lossCrossEntropyV target result
-- {-# SPECIALIZE afcnnMnistLoss1 :: (GoodScalar r, Differentiable r) => SNat widthHidden -> SNat widthHidden2 -> (AstTensor AstMethodLet FullSpan (TKR 1 r), AstTensor AstMethodLet FullSpan (TKR 1 r)) -> ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) widthHidden widthHidden2 r -> AstTensor AstMethodLet FullSpan (TKScalar r) #-}
{-# SPECIALIZE afcnnMnistLoss1 :: SNat widthHidden -> SNat widthHidden2 -> (AstTensor AstMethodLet FullSpan (TKR 1 Double), AstTensor AstMethodLet FullSpan (TKR 1 Double)) -> ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) widthHidden widthHidden2 Double -> AstTensor AstMethodLet FullSpan (TKScalar Double) #-}
{-# SPECIALIZE afcnnMnistLoss1 :: SNat widthHidden -> SNat widthHidden2 -> (AstTensor AstMethodLet FullSpan (TKR 1 Float), AstTensor AstMethodLet FullSpan (TKR 1 Float)) -> ADFcnnMnist1Parameters (AstTensor AstMethodLet FullSpan) widthHidden widthHidden2 Float -> AstTensor AstMethodLet FullSpan (TKScalar Float) #-}

-- | A function testing the neural network given testing set of inputs
-- and the trained parameters.
Expand Down
13 changes: 13 additions & 0 deletions src/HordeAd/Core/Adaptor.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ import HordeAd.Core.CarriersConcrete
import HordeAd.Core.Ops
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
import HordeAd.Core.OpsConcrete ()
import HordeAd.Core.OpsAst ()
import HordeAd.Core.Ast

-- * Adaptor classes

Expand Down Expand Up @@ -109,6 +112,9 @@ instance AdaptableTarget target (target y) where
type X (target y) = y
toTarget = id
fromTarget t = t
{-# SPECIALIZE instance AdaptableTarget RepN (RepN (TKS sh Double)) #-}
{-# SPECIALIZE instance AdaptableTarget RepN (RepN (TKS sh Float)) #-}
-- a failed attempt to specialize without -fpolymorphic-specialisation

instance (BaseTensor target, BaseTensor (PrimalOf target), KnownSTK y)
=> DualNumberValue (target y) where
Expand Down Expand Up @@ -153,6 +159,8 @@ instance ( KnownShS sh, GoodScalar r, Fractional r, Random r
(g1, g2) = splitGen g
arr = createRandomVector (shsSize (knownShS @sh)) g1
in (arr, g2)
{-# SPECIALIZE instance KnownShS sh => RandomValue (RepN (TKS sh Double)) #-}
{-# SPECIALIZE instance KnownShS sh => RandomValue (RepN (TKS sh Float)) #-}


-- * Compound instances
Expand Down Expand Up @@ -217,6 +225,7 @@ instance (BaseTensor target, KnownNat n, AdaptableTarget target a)
a = fromTarget a1
rest = fromTarget rest1
in (a ::: rest)
{-# SPECIALIZE instance (KnownNat n, AdaptableTarget (AstTensor AstMethodLet FullSpan) a) => AdaptableTarget (AstTensor AstMethodLet FullSpan) (ListR n a) #-}

instance TermValue a => TermValue (ListR n a) where
type Value (ListR n a) = ListR n (Value a)
Expand Down Expand Up @@ -258,6 +267,7 @@ instance ( BaseTensor target
a = fromTarget a1
b = fromTarget b1
in (a, b)
{-# SPECIALIZE instance (AdaptableTarget (AstTensor AstMethodLet FullSpan) a, AdaptableTarget (AstTensor AstMethodLet FullSpan) b) => AdaptableTarget (AstTensor AstMethodLet FullSpan) (a, b) #-}

instance (TermValue a, TermValue b) => TermValue (a, b) where
type Value (a, b) = (Value a, Value b)
Expand Down Expand Up @@ -299,6 +309,7 @@ instance ( BaseTensor target
b = fromTarget b1
c = fromTarget c1
in (a, b, c)
{-# SPECIALIZE instance (AdaptableTarget (AstTensor AstMethodLet FullSpan) a, AdaptableTarget (AstTensor AstMethodLet FullSpan) b, AdaptableTarget (AstTensor AstMethodLet FullSpan) c) => AdaptableTarget (AstTensor AstMethodLet FullSpan) (a, b, c) #-}

instance (TermValue a, TermValue b, TermValue c)
=> TermValue (a, b, c) where
Expand Down Expand Up @@ -350,6 +361,7 @@ instance ( BaseTensor target
c = fromTarget c1
d = fromTarget d1
in (a, b, c, d)
{-# SPECIALIZE instance (AdaptableTarget (AstTensor AstMethodLet FullSpan) a, AdaptableTarget (AstTensor AstMethodLet FullSpan) b, AdaptableTarget (AstTensor AstMethodLet FullSpan) c, AdaptableTarget (AstTensor AstMethodLet FullSpan) d) => AdaptableTarget (AstTensor AstMethodLet FullSpan) (a, b, c, d) #-}

instance (TermValue a, TermValue b, TermValue c, TermValue d)
=> TermValue (a, b, c, d) where
Expand Down Expand Up @@ -413,6 +425,7 @@ instance ( BaseTensor target
d = fromTarget d1
e = fromTarget e1
in (a, b, c, d, e)
{-# SPECIALIZE instance (AdaptableTarget (AstTensor AstMethodLet FullSpan) a, AdaptableTarget (AstTensor AstMethodLet FullSpan) b, AdaptableTarget (AstTensor AstMethodLet FullSpan) c, AdaptableTarget (AstTensor AstMethodLet FullSpan) d, AdaptableTarget (AstTensor AstMethodLet FullSpan) e) => AdaptableTarget (AstTensor AstMethodLet FullSpan) (a, b, c, d, e) #-}

instance (TermValue a, TermValue b, TermValue c, TermValue d, TermValue e)
=> TermValue (a, b, c, d, e) where
Expand Down
18 changes: 9 additions & 9 deletions src/HordeAd/Core/Engine.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ import HordeAd.Core.TensorKind
import HordeAd.Core.Types
import HordeAd.Core.Unwind

-- Orphan instance: it is placed in this module (rather than next to the
-- class or the type) to tweak module dependencies so that specialization
-- works better.
-- The associated 'Value' type is @RepN y@; 'fromValue' turns such a value
-- into a term via 'tconcrete', passing it the result of 'tftkG' applied to
-- the @knownSTK@ singleton for @y@ and the unwrapped value.
instance KnownSTK y
         => TermValue (AstTensor AstMethodLet FullSpan y) where
  type Value (AstTensor AstMethodLet FullSpan y) = RepN y
  fromValue concreteVal =
    tconcrete (tftkG (knownSTK @y) (unRepN concreteVal)) concreteVal


-- * Reverse derivative adaptors

-- VJP (vector-jacobian product) or Lop (left operations) are alternative
Expand Down Expand Up @@ -129,15 +136,8 @@ revArtifactAdapt hasDt f xftk =
g !hv = tlet hv $ \ !hvShared ->
f $ fromTarget hvShared
in revProduceArtifact hasDt g emptyEnv xftk
{- TODO
{-# SPECIALIZE revArtifactAdapt
:: ( KnownNat n
, AdaptableTarget (AstTensor AstMethodLet FullSpan) astvals
, AdaptableTarget RepN (Value astvals)
, TermValue astvals )
=> Bool -> (astvals -> AstTensor AstMethodLet FullSpan n Double) -> FullTensorKind (X astvals)
-> (AstArtifactRev TKUntyped (TKR n Double), Delta (AstRaw PrimalSpan) (TKR n Double)) #-}
-}
{-# SPECIALIZE revArtifactAdapt :: forall astvals. AdaptableTarget (AstTensor AstMethodLet FullSpan) astvals => Bool -> (astvals -> AstTensor AstMethodLet FullSpan (TKScalar Double)) -> FullTensorKind (X astvals) -> (AstArtifactRev (X astvals) (TKScalar Double), Delta (AstRaw PrimalSpan) (TKScalar Double)) #-}
{-# SPECIALIZE revArtifactAdapt :: forall astvals. AdaptableTarget (AstTensor AstMethodLet FullSpan) astvals => Bool -> (astvals -> AstTensor AstMethodLet FullSpan (TKScalar Float)) -> FullTensorKind (X astvals) -> (AstArtifactRev (X astvals) (TKScalar Float), Delta (AstRaw PrimalSpan) (TKScalar Float)) #-}

revProduceArtifactWithoutInterpretation
:: forall x z.
Expand Down
7 changes: 0 additions & 7 deletions src/HordeAd/Core/OpsAst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import Data.Array.Nested.Internal.Shape (shCvtSX, shsProduct, shsRank, shrRank,
import Data.Array.Mixed.Permutation qualified as Permutation
import Data.Array.Nested qualified as Nested

import HordeAd.Core.Adaptor
import HordeAd.Core.Ast
import HordeAd.Core.AstEnv
import HordeAd.Core.AstFreshId
Expand Down Expand Up @@ -142,12 +141,6 @@ fwdProduceArtifact f envInit xftk =

-- * AstTensor instances

-- The associated 'Value' type of a full-span AST term is @RepN y@.
-- 'fromValue' builds a term with 'astConcrete' (given the result of 'tftkG'
-- applied to the @knownSTK@ singleton for @y@ and the unwrapped value)
-- and passes that term through 'fromPrimal'.
instance KnownSTK y
         => TermValue (AstTensor AstMethodLet FullSpan y) where
  type Value (AstTensor AstMethodLet FullSpan y) = RepN y
  fromValue concreteVal =
    fromPrimal
      (astConcrete (tftkG (knownSTK @y) (unRepN concreteVal)) concreteVal)

-- This is a vectorizing combinator that also simplifies
-- the terms touched during vectorization, but not any others.
-- Due to how the Ast instance of Tensor is defined above, vectorization
Expand Down

0 comments on commit 62464f5

Please sign in to comment.