Skip to content

Commit

Permalink
Remove lemKnownNatRank*
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 21, 2024
1 parent c5f6565 commit b47b169
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 34 deletions.
5 changes: 3 additions & 2 deletions src/HordeAd/Core/AstFreshId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import GHC.TypeLits (KnownNat, Nat)
import System.IO.Unsafe (unsafePerformIO)

import Data.Array.Nested (KnownShS (..), Rank)
import Data.Array.Nested.Internal.Shape (shsRank)

import HordeAd.Core.Ast
import HordeAd.Core.CarriersAst
Expand Down Expand Up @@ -155,7 +156,7 @@ funToAstRevIO ftk | Dict <- lemTensorKindOfSTK (ftkToStk ftk) = do
-> IO ( AstDynamic AstMethodShare PrimalSpan
, AstDynamic AstMethodLet FullSpan )
f i (DynamicRankedDummy @r @sh _ _)
| Dict <- lemKnownNatRankS (knownShS @sh) = do
| SNat <- shsRank (knownShS @sh) = do
return
( DynamicRanked @r @(Rank sh)
(AstProjectR astVarPrimal i)
Expand Down Expand Up @@ -219,7 +220,7 @@ funToAstFwdIO ftk | Dict <- lemTensorKindOfSTK (ftkToStk ftk)
, AstDynamic AstMethodShare PrimalSpan
, AstDynamic AstMethodLet FullSpan )
f i (DynamicRankedDummy @r @sh _ _)
| Dict <- lemKnownNatRankS (knownShS @sh) = do
| SNat <- shsRank (knownShS @sh) = do
return
( DynamicRanked @r @(Rank sh)
(AstProjectR astVarPrimalD i)
Expand Down
13 changes: 7 additions & 6 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ import Data.Array.Nested
, type (++)
)
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Internal.Shape (shCvtSX, shsAppend)
import Data.Array.Nested.Internal.Shape (shCvtSX, shsAppend, shsRank)
import Data.Array.Nested.Internal.Shape qualified as Nested.Internal.Shape

import HordeAd.Core.Ast
Expand All @@ -121,6 +121,7 @@ import HordeAd.Core.OpsConcrete ()
import HordeAd.Core.TensorClass
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
import HordeAd.Util.ShapedList (ssxRank)
import HordeAd.Util.ShapedList qualified as ShapedList
import HordeAd.Util.SizedList

Expand Down Expand Up @@ -2236,12 +2237,12 @@ astRFromS :: forall sh s r. (TensorKind r, KnownShS sh)
=> AstTensor AstMethodLet s (TKS2 sh r)
-> AstTensor AstMethodLet s (TKR2 (Rank sh) r)
astRFromS (AstConcrete ftk t)
| Dict <- lemKnownNatRankS (knownShS @sh) = case ftk of
| SNat <- shsRank (knownShS @sh) = case ftk of
FTKS _ x ->
let u = rfromS t
in AstConcrete (FTKR (rshape u) x) u
astRFromS (Ast.AstFromPrimal v)
| Dict <- lemKnownNatRankS (knownShS @sh) =
| SNat <- shsRank (knownShS @sh) =
Ast.AstFromPrimal $ astRFromS v
astRFromS (Ast.AstSFromR v) = v -- no information lost, so no checks
astRFromS v = Ast.AstRFromS v
Expand All @@ -2250,12 +2251,12 @@ astRFromX :: forall sh s r. (TensorKind r, KnownShX sh)
=> AstTensor AstMethodLet s (TKX2 sh r)
-> AstTensor AstMethodLet s (TKR2 (Rank sh) r)
astRFromX (AstConcrete ftk t)
| Dict <- lemKnownNatRankX (knownShX @sh) = case ftk of
| SNat <- ssxRank (knownShX @sh) = case ftk of
FTKX _ x ->
let u = rfromX t
in AstConcrete (FTKR (rshape u) x) u
astRFromX (Ast.AstFromPrimal v)
| Dict <- lemKnownNatRankX (knownShX @sh) =
| SNat <- ssxRank (knownShX @sh) =
Ast.AstFromPrimal $ astRFromX v
astRFromX (Ast.AstXFromR v) = v -- no information lost, so no checks
astRFromX v = Ast.AstRFromX v
Expand Down Expand Up @@ -2760,7 +2761,7 @@ astLetHVectorIn vars l v = case v of
-> AstTensor AstMethodLet s2 z
mkLet i (AstDynamicVarName @ty @r3 @sh3 varId)
| Just Refl <- testEquality (typeRep @ty) (typeRep @Nat)
, Dict <- lemKnownNatRankS (knownShS @sh3) =
, SNat <- shsRank (knownShS @sh3) =
astLet (mkAstVarName @s @(TKR (Rank sh3) r3) varId)
(astProjectR l i)
| otherwise =
Expand Down
9 changes: 5 additions & 4 deletions src/HordeAd/Core/AstTools.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@ import Data.Array.Mixed.Shape
import Data.Array.Nested
(IShR, KnownShS (..), MapJust, Replicate, ShR (..), ShX (..))
import Data.Array.Nested.Internal.Shape
(shCvtRX, shCvtSX, shCvtXR', shrSize, shsSize)
(shCvtRX, shCvtSX, shCvtXR', shrSize, shsRank, shsSize)
import Data.Array.Nested.Internal.Shape qualified as Nested.Internal.Shape

import HordeAd.Core.Ast
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
import HordeAd.Util.ShapedList (ssxRank)
import HordeAd.Util.SizedList

-- * Shape calculation
Expand Down Expand Up @@ -196,17 +197,17 @@ ftkAst t = case t of
FTKX sh (FTKProduct y z) -> FTKProduct (FTKX sh y) (FTKX sh z)

AstRFromS @sh v
| Dict <- lemKnownNatRankS (knownShS @sh) -> case ftkAst v of
| SNat <- shsRank (knownShS @sh) -> case ftkAst v of
FTKS _ x -> FTKR (fromList $ shapeT @sh) x
AstRFromX @sh v
| Dict <- lemKnownNatRankX (knownShX @sh) -> case ftkAst v of
| SNat <- ssxRank (knownShX @sh) -> case ftkAst v of
FTKX shx x -> FTKR (fromList $ toList shx) x
AstSFromR v -> case ftkAst v of
FTKR _ x -> FTKS knownShS x
AstSFromX v -> case ftkAst v of
FTKX _ x -> FTKS knownShS x
AstXFromR @sh v
| Dict <- lemKnownNatRankX (knownShX @sh) -> case ftkAst v of
| SNat <- ssxRank (knownShX @sh) -> case ftkAst v of
FTKR shr x -> FTKX (fromList $ toList shr) x
AstXFromS v -> case ftkAst v of
FTKS sh x -> FTKX (fromList $ toList sh) x
Expand Down
4 changes: 2 additions & 2 deletions src/HordeAd/Core/AstVectorize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import Data.Array.Nested
, type (++)
)
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Internal.Shape (shCvtSX, shrRank)
import Data.Array.Nested.Internal.Shape (shCvtSX, shrRank, shsRank)

import HordeAd.Core.Ast (AstTensor)
import HordeAd.Core.Ast hiding (AstBool (..), AstTensor (..))
Expand Down Expand Up @@ -604,7 +604,7 @@ build1VIndex snat@SNat (var, v0, ix@(_ :.: _)) =
astTrS :: forall n m sh s r.
(KnownNat n, KnownNat m, KnownShS sh, TensorKind r, AstSpan s)
=> AstTensor AstMethodLet s (TKS2 (n ': m ': sh) r) -> AstTensor AstMethodLet s (TKS2 (m ': n ': sh) r)
astTrS | Dict <- lemKnownNatRankS (knownShS @sh) =
astTrS | SNat <- shsRank (knownShS @sh) =
astTransposeS (Permutation.makePerm @'[1, 0])
astTrX :: forall n m sh s r.
-- (KnownNat n, KnownNat m, KnownShX sh, GoodScalar r, AstSpan s)
Expand Down
15 changes: 8 additions & 7 deletions src/HordeAd/Core/Delta.hs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ import HordeAd.Core.HVectorOps
import HordeAd.Core.TensorClass
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
import HordeAd.Util.ShapedList (ssxRank)
import HordeAd.Util.SizedList

type IMap target = DEnumMap (InputId target) (RepM target)
Expand Down Expand Up @@ -815,17 +816,17 @@ shapeDeltaFull = \case
FTKX sh (FTKProduct y z) -> FTKProduct (FTKX sh y) (FTKX sh z)

RFromS @sh d
| Dict <- lemKnownNatRankS (knownShS @sh) -> case shapeDeltaFull d of
| SNat <- shsRank (knownShS @sh) -> case shapeDeltaFull d of
FTKS _ x -> FTKR (fromList $ shapeT @sh) x
RFromX @sh d
| Dict <- lemKnownNatRankX (knownShX @sh) -> case shapeDeltaFull d of
| SNat <- ssxRank (knownShX @sh) -> case shapeDeltaFull d of
FTKX shx x -> FTKR (fromList $ toList shx) x
SFromR d -> case shapeDeltaFull d of
FTKR _ x -> FTKS knownShS x
SFromX d -> case shapeDeltaFull d of
FTKX _ x -> FTKS knownShS x
XFromR @sh d
| Dict <- lemKnownNatRankX (knownShX @sh) -> case shapeDeltaFull d of
| SNat <- ssxRank (knownShX @sh) -> case shapeDeltaFull d of
FTKR shr x -> FTKX (fromList $ toList shr) x
XFromS d -> case shapeDeltaFull d of
FTKS sh x -> FTKX (fromList $ toList sh) x
Expand Down Expand Up @@ -1384,9 +1385,9 @@ evalSame !s !c = \case
evalSame s (xzip c) d

RFromS (SFromR d) -> evalSame s c d -- no information lost, so no checks
RFromS @sh d | Dict <- lemKnownNatRankS (knownShS @sh) ->
RFromS @sh d | SNat <- shsRank (knownShS @sh) ->
evalSame s (sfromR c) d
RFromX @sh d | Dict <- lemKnownNatRankX (knownShX @sh) ->
RFromX @sh d | SNat <- ssxRank (knownShX @sh) ->
evalSame s (xfromR c) d
SFromR @sh (RFromS @sh2 d) ->
case sameShape @sh @sh2 of
Expand All @@ -1401,7 +1402,7 @@ evalSame !s !c = \case
SFromX d ->
evalSame s (xfromS c) d
-- impossible, shapes may differ: XFromS (SFromX d) -> evalSame s c d
XFromR @sh d | Dict <- lemKnownNatRankX (knownShX @sh) ->
XFromR @sh d | SNat <- ssxRank (knownShX @sh) ->
evalSame s (rfromX c) d
XFromS @sh d ->
evalSame s (sfromX c) d
Expand Down Expand Up @@ -1770,7 +1771,7 @@ fwdSame params s = \case
Just Refl -> fwdSame params s d
_ -> error "fwdSame: different shapes in SFromR(RFromS)"
SFromR d -> second sfromR $ fwdSame params s d
XFromR @sh d | Dict <- lemKnownNatRankX (knownShX @sh) ->
XFromR @sh d | SNat <- ssxRank (knownShX @sh) ->
second xfromR $ fwdSame params s d
XFromS d -> second xfromS $ fwdSame params s d
SFromX @sh (XFromS @sh2 d) ->
Expand Down
4 changes: 2 additions & 2 deletions src/HordeAd/Core/HVectorOps.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ import Data.Array.Nested
, type (++)
)
import Data.Array.Nested.Internal.Shape
(shCvtRX, shCvtSX, shrAppend, shrRank, shsAppend)
(shCvtRX, shCvtSX, shrAppend, shrRank, shsAppend, shsRank)

import HordeAd.Core.TensorClass
import HordeAd.Core.TensorKind
Expand Down Expand Up @@ -89,7 +89,7 @@ soneHot :: forall r sh1 sh2 target.
=> target (TKS2 sh2 r) -> IxSOf target sh1
-> target (TKS2 (sh1 ++ sh2) r)
soneHot v ix = case stensorKind @r of
STKScalar{} | Dict <- lemKnownNatRankS (knownShS @sh1) ->
STKScalar{} | SNat <- shsRank (knownShS @sh1) ->
gcastWith (unsafeCoerce Refl :: Take (Rank sh1) (sh1 ++ sh2) :~: sh1) $
gcastWith (unsafeCoerce Refl :: Drop (Rank sh1) (sh1 ++ sh2) :~: sh2) $
sscatter @_ @_ @'[] @(Rank sh1) v (const ix)
Expand Down
3 changes: 2 additions & 1 deletion src/HordeAd/Core/OpsADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import HordeAd.Core.HVectorOps
import HordeAd.Core.TensorClass
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
import HordeAd.Util.ShapedList (ssxRank)
import HordeAd.Util.ShapedList qualified as ShapedList
import HordeAd.Util.SizedList

Expand Down Expand Up @@ -511,7 +512,7 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target)
dSFromX d = SFromX d
xfromR :: forall sh r. (KnownShX sh, TensorKind r)
=> ADVal target (TKR2 (Rank sh) r) -> ADVal target (TKX2 sh r)
xfromR (D u u') | Dict <- lemKnownNatRankX (knownShX @sh) =
xfromR (D u u') | SNat <- ssxRank (knownShX @sh) =
dDnotShared (xfromR u) (XFromR u')
xfromS (D u u') = dDnotShared (xfromS u) (XFromS u')

Expand Down
10 changes: 1 addition & 9 deletions src/HordeAd/Core/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ module HordeAd.Core.Types
, withKnownShS, withKnownShX
, sshapeKnown, slistKnown, sixKnown, knownShR
, shapeT, shapeP, sizeT, sizeP
, withShapeP, sameShape, matchingRank, lemKnownNatRankS, lemKnownNatRankX
, withShapeP, sameShape, matchingRank
, Dict(..), PermC, trustMeThisIsAPermutation
, Take, Drop, Last, Init
-- * Kinds of the functors that determine the structure of a tensor type
Expand Down Expand Up @@ -141,14 +141,6 @@ matchingRank =
then Just (unsafeCoerce Refl :: Rank sh1 :~: n2)
else Nothing

lemKnownNatRankS :: ShS sh -> Dict KnownNat (Rank sh)
lemKnownNatRankS ZSS = Dict
lemKnownNatRankS (_ :$$ sh) | Dict <- lemKnownNatRankS sh = Dict

lemKnownNatRankX :: StaticShX sh -> Dict KnownNat (Rank sh)
lemKnownNatRankX ZKX = Dict
lemKnownNatRankX (_ :!% sh) | Dict <- lemKnownNatRankX sh = Dict

class Permutation.IsPermutation is => PermC is
instance Permutation.IsPermutation is => PermC is

Expand Down
6 changes: 5 additions & 1 deletion src/HordeAd/Util/ShapedList.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ module HordeAd.Util.ShapedList
-- * Operations involving both indexes and shapes
, toLinearIdx, fromLinearIdx
, permutePrefixIndex
, ssxRank
) where

import Prelude
Expand All @@ -36,7 +37,7 @@ import GHC.Exts (IsList (..))
import GHC.TypeLits (KnownNat)

import Data.Array.Mixed.Permutation qualified as Permutation
import Data.Array.Mixed.Shape (IShX)
import Data.Array.Mixed.Shape (IShX, StaticShX (..), listxRank)
import Data.Array.Nested
( IxR
, IxS (..)
Expand Down Expand Up @@ -261,3 +262,6 @@ permutePrefixSized p ix =
permutePrefixIndex :: forall sh sh2 i. (KnownShS sh, KnownShS sh2)
=> Permutation.PermR -> IxS sh i -> IxS sh2 i
permutePrefixIndex p (IxS ix) = IxS $ permutePrefixSized p ix

ssxRank :: StaticShX sh -> SNat (Rank sh)
ssxRank (StaticShX l) = listxRank l

0 comments on commit b47b169

Please sign in to comment.