Commit d9416c5

Port MnistFcnnRanked1 from lists to ListR

Mikolaj committed Jan 23, 2025
1 parent deff8c2 commit d9416c5

Showing 7 changed files with 279 additions and 265 deletions.
15 changes: 11 additions & 4 deletions bench/common/BenchMnistTools.hs
@@ -25,6 +25,13 @@ import MnistData
 import MnistFcnnRanked1 qualified
 import MnistFcnnRanked2 qualified
 
+{-
+afcnnMnistLen1 :: Int -> Int -> [Int]
+afcnnMnistLen1 widthHidden widthHidden2 =
+  replicate widthHidden sizeMnistGlyphInt ++ [widthHidden]
+  ++ replicate widthHidden2 widthHidden ++ [widthHidden2]
+  ++ replicate sizeMnistLabelInt widthHidden2 ++ [sizeMnistLabelInt]
+
 -- * Using lists of vectors, which is rank 1
 -- POPL differentiation, straight via the ADVal instance of RankedTensor,
@@ -35,7 +42,7 @@ mnistTrainBench1VTA :: forall target r. (target ~ RepN, r ~ Double)
                     -> Benchmark
 mnistTrainBench1VTA extraPrefix chunkLength xs widthHidden widthHidden2
                     gamma = do
-  let nParams1 = MnistFcnnRanked1.afcnnMnistLen1 widthHidden widthHidden2
+  let nParams1 = afcnnMnistLen1 widthHidden widthHidden2
       params1Init =
         imap (\i nPV ->
                 DynamicRanked @r @1 $ RepN $ Nested.rfromVector (nPV :$: ZSR)
@@ -66,7 +73,7 @@ mnistTestBench1VTA :: forall target r. (target ~ RepN, r ~ Double)
                    => String -> Int -> [MnistData r] -> Int -> Int
                    -> Benchmark
 mnistTestBench1VTA extraPrefix chunkLength xs widthHidden widthHidden2 = do
-  let nParams1 = MnistFcnnRanked1.afcnnMnistLen1 widthHidden widthHidden2
+  let nParams1 = afcnnMnistLen1 widthHidden widthHidden2
       params1Init =
         imap (\i nPV ->
                 DynamicRanked @r @1 $ RepN $ Nested.rfromVector (nPV :$: ZSR)
@@ -117,7 +124,7 @@ mnistTrainBench1VTO :: forall target r. (target ~ RepN, r ~ Double)
                     -> Benchmark
 mnistTrainBench1VTO extraPrefix chunkLength testData widthHidden widthHidden2
                     gamma = do
-  let nParams1 = MnistFcnnRanked1.afcnnMnistLen1 widthHidden widthHidden2
+  let nParams1 = afcnnMnistLen1 widthHidden widthHidden2
       params1Init =
         imap (\i nPV ->
                 DynamicRanked @r @1 $ RepN $ Nested.rfromVector (nPV :$: ZSR)
@@ -196,7 +203,7 @@ mnistBGroup1VTO xs0 chunkLength =
        -- another common width
      , mnistTrainBench1VTO "500|150 " chunkLength xs 500 150 0.02
      ]
-
+-}
 
 -- * Using matrices, which is rank 2
 
92 changes: 43 additions & 49 deletions example/MnistFcnnRanked1.hs
@@ -6,46 +6,41 @@ module MnistFcnnRanked1 where
 
 import Prelude
 
-import Control.Exception (assert)
 import Data.List.NonEmpty qualified as NonEmpty
 import Data.Vector.Generic qualified as V
 import GHC.Exts (IsList (..), inline)
+import GHC.TypeLits (KnownNat, Nat)
 
+import Data.Array.Nested (ListR (..))
 import Data.Array.Nested qualified as Nested
 
 import HordeAd.Core.Adaptor
 import HordeAd.Core.CarriersConcrete
 import HordeAd.Core.TensorKind
 import HordeAd.Core.TensorClass
 import HordeAd.Core.Types
 import HordeAd.External.CommonRankedOps
 import MnistData
 
-afcnnMnistLen1 :: Int -> Int -> [Int]
-afcnnMnistLen1 widthHidden widthHidden2 =
-  replicate widthHidden sizeMnistGlyphInt ++ [widthHidden]
-  ++ replicate widthHidden2 widthHidden ++ [widthHidden2]
-  ++ replicate sizeMnistLabelInt widthHidden2 ++ [sizeMnistLabelInt]
-

 -- | The differentiable type of all trainable parameters of this nn.
-type ADFcnnMnist1Parameters (target :: Target) r =
-  ( ( [target (TKR 1 r)]  -- ^ @widthHidden@ copies, length @sizeMnistGlyphInt@
-    , target (TKR 1 r) )  -- ^ length @widthHidden@
-  , ( [target (TKR 1 r)]  -- ^ @widthHidden2@ copies, length @widthHidden@
-    , target (TKR 1 r) )  -- ^ length @widthHidden2@
-  , ( [target (TKR 1 r)]  -- ^ @sizeMnistLabelInt@ copies, length @widthHidden2@
-    , target (TKR 1 r) )  -- ^ length @sizeMnistLabelInt@
+type ADFcnnMnist1Parameters
+       (target :: Target) (widthHidden :: Nat) (widthHidden2 :: Nat) r =
+  ( ( ListR widthHidden (target (TKS '[SizeMnistGlyph] r))
+    , target (TKS '[widthHidden] r) )
+  , ( ListR widthHidden2 (target (TKS '[widthHidden] r))
+    , target (TKS '[widthHidden2] r) )
+  , ( ListR SizeMnistLabel (target (TKS '[widthHidden2] r))
+    , target (TKS '[SizeMnistLabel] r) )
   )
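
The heart of the port: the parameter lists become length-indexed. A
self-contained model of ListR (the real type lives in ox-arrays'
Data.Array.Nested and may differ in constructor details) shows what the
new representation buys. The number of weight vectors per layer is now
checked by the compiler instead of being asserted at runtime:

{-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeOperators #-}

import GHC.TypeLits (Nat, type (+))

-- A list whose length is part of its type.
data ListR (n :: Nat) a where
  ZR    :: ListR 0 a
  (:::) :: a -> ListR n a -> ListR (1 + n) a
infixr 5 :::

-- Exactly three weight vectors; adding or dropping one is a type error.
threeWeights :: ListR 3 [Double]
threeWeights = [1, 2] ::: [3, 4] ::: [5, 6] ::: ZR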

 listMatmul1
-  :: forall target r.
-     (BaseTensor target, LetTensor target, GoodScalar r)
-  => target (TKR 1 r) -> [target (TKR 1 r)]
-  -> target (TKR 1 r)
+  :: forall target r w1 w2.
+     ( BaseTensor target, LetTensor target, GoodScalar r
+     , KnownNat w1, KnownNat w2 )
+  => target (TKS '[w1] r) -> [target (TKS '[w1] r)]
+  -> target (TKS '[w2] r)
 listMatmul1 x0 weights = tlet x0 $ \x ->
-  let f :: target (TKR 1 r) -> target (TKR 0 r)
-      f v = v `rdot0` x
-  in rfromList $ NonEmpty.fromList $ map f weights
+  let f :: target (TKS '[w1] r) -> target (TKS '[] r)
+      f v = v `sdot0` x
+  in sfromList $ NonEmpty.fromList $ map f weights

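What listMatmul1 computes is an ordinary matrix-vector product over a
list of rows. A minimal, dependency-free sketch (illustrative names,
not the library's API); the tlet in the real code shares x so it is not
recomputed for every dot product:

-- Each weight vector w_j is dotted with the input x and the scalars
-- are collected into the output vector: y_j = w_j . x.
matVecViaList :: Num r => [r] -> [[r]] -> [r]
matVecViaList x weights = map (\w -> sum (zipWith (*) w x)) weights
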
 -- | Fully connected neural network for the MNIST digit classification task.
 -- There are two hidden layers and both use the same activation function.
@@ -57,37 +52,37 @@ listMatmul1 x0 weights = tlet x0 $ \x ->
 afcnnMnist1 :: (ADReady target, GoodScalar r)
             => (target (TKR 1 r) -> target (TKR 1 r))
             -> (target (TKR 1 r) -> target (TKR 1 r))
-            -> Int -> Int
-            -> target (TKR 1 r)
-            -> ADFcnnMnist1Parameters target r
+            -> SNat widthHidden -> SNat widthHidden2
+            -> target (TKS '[SizeMnistGlyph] r)
+            -> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
             -> target (TKR 1 r)
-afcnnMnist1 factivationHidden factivationOutput widthHidden widthHidden2
+afcnnMnist1 factivationHidden factivationOutput SNat SNat
             datum ((hidden, bias), (hidden2, bias2), (readout, biasr)) =
-  let !_A = assert (sizeMnistGlyphInt == rlength datum
-                    && length hidden == widthHidden
-                    && length hidden2 == widthHidden2) ()
-        -- TODO: disabled for tests: && length readout == sizeMnistLabelInt) ()
-      hiddenLayer1 = listMatmul1 datum hidden + bias
-      nonlinearLayer1 = factivationHidden hiddenLayer1
-      hiddenLayer2 = listMatmul1 nonlinearLayer1 hidden2 + bias2
-      nonlinearLayer2 = factivationHidden hiddenLayer2
-      outputLayer = listMatmul1 nonlinearLayer2 readout + biasr
-  in factivationOutput outputLayer
+  let hiddenLayer1 = listMatmul1 datum (toList hidden) + bias
+      nonlinearLayer1 = sfromR $ factivationHidden $ rfromS hiddenLayer1
+      hiddenLayer2 = listMatmul1 nonlinearLayer1 (toList hidden2) + bias2
+      nonlinearLayer2 = sfromR $ factivationHidden $ rfromS hiddenLayer2
+      outputLayer = listMatmul1 nonlinearLayer2 (toList readout) + biasr
+  in factivationOutput $ rfromS outputLayer

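The Int width arguments are gone: afcnnMnist1 now receives its widths as
SNat singletons, and matching on the SNat pattern brings the
corresponding KnownNat constraints into scope. A sketch of that idiom,
assuming base >= 4.18, where GHC.TypeNats exports the SNat singleton
(horde-ad may source its SNat elsewhere):

{-# LANGUAGE DataKinds, PatternSynonyms #-}

import GHC.TypeNats (SNat, natVal, pattern SNat)

-- Matching on SNat recovers KnownNat w, so the width is still
-- available as a value even though it now lives in the type.
describeWidth :: SNat w -> String
describeWidth w@SNat = "width = " ++ show (natVal w)
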
 -- | The neural network applied to concrete activation functions
 -- and composed with the appropriate loss function.
 afcnnMnistLoss1TensorData
   :: (ADReady target, GoodScalar r, Differentiable r)
-  => Int -> Int -> (target (TKR 1 r), target (TKR 1 r)) -> ADFcnnMnist1Parameters target r
+  => SNat widthHidden -> SNat widthHidden2
+  -> (target (TKR 1 r), target (TKR 1 r))
+  -> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
   -> target (TKR 0 r)
 afcnnMnistLoss1TensorData widthHidden widthHidden2 (datum, target) adparams =
   let result = inline afcnnMnist1 logistic softMax1
-                      widthHidden widthHidden2 datum adparams
+                      widthHidden widthHidden2 (sfromR datum) adparams
   in lossCrossEntropyV target result
 
 afcnnMnistLoss1
   :: (ADReady target, GoodScalar r, Differentiable r)
-  => Int -> Int -> MnistData r -> ADFcnnMnist1Parameters target r
+  => SNat widthHidden -> SNat widthHidden2
+  -> MnistData r
+  -> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
   -> target (TKR 0 r)
 afcnnMnistLoss1 widthHidden widthHidden2 (datum, target) =
   let datum1 = rconcrete $ Nested.rfromVector (fromList [sizeMnistGlyphInt]) datum
@@ -97,23 +92,22 @@ afcnnMnistLoss1 widthHidden widthHidden2 (datum, target) =
 -- | A function testing the neural network given testing set of inputs
 -- and the trained parameters.
 afcnnMnistTest1
-  :: forall target r.
+  :: forall target widthHidden widthHidden2 r.
      (target ~ RepN, GoodScalar r, Differentiable r)
-  => ADFcnnMnist1Parameters target r
-  -> Int -> Int
+  => SNat widthHidden -> SNat widthHidden2
   -> [MnistData r]
-  -> HVector RepN
+  -> ADFcnnMnist1Parameters target widthHidden widthHidden2 r
   -> r
-afcnnMnistTest1 _ _ _ [] _ = 0
-afcnnMnistTest1 valsInit widthHidden widthHidden2 dataList testParams =
+afcnnMnistTest1 _ _ [] _ = 0
+afcnnMnistTest1 widthHidden widthHidden2 dataList testParams =
   let matchesLabels :: MnistData r -> Bool
       matchesLabels (glyph, label) =
         let glyph1 = rconcrete $ Nested.rfromVector (fromList [sizeMnistGlyphInt]) glyph
-            nn :: ADFcnnMnist1Parameters target r
+            nn :: ADFcnnMnist1Parameters target widthHidden widthHidden2 r
                -> target (TKR 1 r)
             nn = inline afcnnMnist1 logistic softMax1
-                        widthHidden widthHidden2 glyph1
-            v = Nested.rtoVector $ unRepN $ nn $ unAsHVector $ parseHVector (AsHVector valsInit) (dmkHVector testParams)
+                        widthHidden widthHidden2 (sfromR glyph1)
+            v = Nested.rtoVector $ unRepN $ nn testParams
         in V.maxIndex v == V.maxIndex label
   in fromIntegral (length (filter matchesLabels dataList))
      / fromIntegral (length dataList)
1 change: 0 additions & 1 deletion horde-ad.cabal
@@ -401,7 +401,6 @@ library testCommonLibrary
     , ghc-typelits-knownnat
     , ghc-typelits-natnormalise
    , hmatrix
-    , ilist
     , ox-arrays
     , random
     , tasty >= 1.0
9 changes: 5 additions & 4 deletions src/HordeAd/Core/Adaptor.hs
@@ -12,12 +12,12 @@ module HordeAd.Core.Adaptor
   ( AdaptableHVector(..), parseHVector, parseHVectorAD
   , TermValue(..), DualNumberValue(..)
   , ForgetShape(..), RandomHVector(..)
-  , AsHVector(..)
+  , AsHVector(..), stkOfListR
   ) where
 
 import Prelude
 
-import Control.Exception (assert)
+import Control.Exception.Assert.Sugar
 import Data.Maybe (fromMaybe)
 import Data.Proxy (Proxy (Proxy))
 import Data.Strict.Vector qualified as Data.Vector
@@ -61,11 +61,12 @@ class AdaptableHVector (target :: Target) vals where
 -- procedure where @fromHVector@ calls itself recursively for sub-values
 -- across multiple instances.
 parseHVector
-  :: (TensorKind (X vals), AdaptableHVector target vals, BaseTensor target)
+  :: ( TensorKind (X vals), AdaptableHVector target vals, BaseTensor target
+     , Show (target (X vals)) )
   => vals -> target (X vals) -> vals
 parseHVector aInit hVector =
   case fromHVector aInit hVector of
-    Just (vals, mrest) -> assert (maybe True nullRep mrest) vals
+    Just (vals, mrest) -> assert (maybe True nullRep mrest `blame` mrest) vals
     Nothing -> error "parseHVector: truncated product of tensors"
 
 parseHVectorAD
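
The move from Control.Exception's bare assert to assert-sugar is what
forces the new Show constraint: blame attaches the offending value to a
failed assertion. A minimal sketch, assuming assert-sugar's
blame :: Show v => Bool -> v -> Bool:

import Control.Exception (assert)
import Control.Exception.Assert.Sugar (blame)

-- On failure the assertion also reports the blamed value (here the
-- list itself), not just the source location.
headChecked :: Show a => [a] -> a
headChecked xs = assert (not (null xs) `blame` xs) (head xs)
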
7 changes: 5 additions & 2 deletions src/HordeAd/Core/TensorKind.hs
@@ -6,7 +6,7 @@
 module HordeAd.Core.TensorKind
   ( -- * Singletons
     STensorKindType(..), TensorKind(..)
-  , lemTensorKindOfSTK, sameTensorKind, sameSTK
+  , withTensorKind, lemTensorKindOfSTK, sameTensorKind, sameSTK
   , stkUnit, buildSTK, razeSTK, aDSTK
   , lemTensorKindOfBuild, lemTensorKindOfAD, lemBuildOfAD
   , FullTensorKind(..), ftkToStk
@@ -36,7 +36,7 @@ import Data.Proxy (Proxy (Proxy))
 import Data.Strict.Vector qualified as Data.Vector
 import Data.Type.Equality (gcastWith, testEquality, (:~:) (Refl))
 import Data.Vector.Generic qualified as V
-import GHC.Exts (IsList (..))
+import GHC.Exts (IsList (..), withDict)
 import GHC.TypeLits (KnownNat, OrderingI (..), cmpNat, type (+))
 import Type.Reflection (TypeRep, typeRep)
 import Unsafe.Coerce (unsafeCoerce)
@@ -104,6 +104,9 @@ instance (TensorKind y, TensorKind z)
 instance TensorKind TKUntyped where
   stensorKind = STKUntyped
 
+withTensorKind :: forall y r. STensorKindType y -> (TensorKind y => r) -> r
+withTensorKind = withDict @(TensorKind y)
+
 lemTensorKindOfSTK :: STensorKindType y -> Dict TensorKind y
 lemTensorKindOfSTK = \case
   STKScalar _ -> Dict
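
The new withTensorKind converts a singleton value into its class
dictionary: withDict turns an STensorKindType y into the TensorKind y
constraint without building a Dict by hand. A self-contained model of
the trick (SBool/KnownBool stand in for the library's types; assumes
GHC >= 9.4, where GHC.Exts exports withDict):

{-# LANGUAGE DataKinds, GADTs, KindSignatures, RankNTypes,
             ScopedTypeVariables, TypeApplications #-}

import GHC.Exts (withDict)

data SBool (b :: Bool) where
  STrue  :: SBool 'True
  SFalse :: SBool 'False

-- withDict requires a class with exactly one method and no superclass.
class KnownBool (b :: Bool) where
  boolSing :: SBool b

-- The singleton value becomes the dictionary for the continuation.
withKnownBool :: forall b r. SBool b -> (KnownBool b => r) -> r
withKnownBool = withDict @(KnownBool b)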
(diffs for the remaining two changed files not shown)
