From 1079c10453abb635e32ba8fd5a712f9ae3090002 Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Fri, 31 Jan 2025 12:25:57 +0100 Subject: [PATCH] Trivialize nesting/unnesting in AST --- src/HordeAd/Core/Ast.hs | 28 +-- src/HordeAd/Core/AstInline.hs | 18 +- src/HordeAd/Core/AstInterpret.hs | 11 +- src/HordeAd/Core/AstPrettyPrint.hs | 31 +-- src/HordeAd/Core/AstSimplify.hs | 305 +++++++---------------------- src/HordeAd/Core/AstTools.hs | 50 +---- src/HordeAd/Core/AstVectorize.hs | 30 +-- src/HordeAd/Core/OpsAst.hs | 204 +++++++++++++++++-- 8 files changed, 302 insertions(+), 375 deletions(-) diff --git a/src/HordeAd/Core/Ast.hs b/src/HordeAd/Core/Ast.hs index e0f37eab6..ade09ef45 100644 --- a/src/HordeAd/Core/Ast.hs +++ b/src/HordeAd/Core/Ast.hs @@ -46,9 +46,7 @@ import Data.Array.Nested , KnownShX , ListR , ListS (..) - , MapJust , Rank - , Replicate , ShS (..) , type (++) ) @@ -402,6 +400,12 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType AstUnzipS :: (TensorKind y, TensorKind z, KnownShS sh) => AstTensor ms s (TKS2 sh (TKProduct y z)) -> AstTensor ms s (TKProduct (TKS2 sh y) (TKS2 sh z)) + AstNestS :: (KnownShS sh1, KnownShS sh2, TensorKind x) + => AstTensor ms s (TKS2 (sh1 ++ sh2) x) + -> AstTensor ms s (TKS2 sh1 (TKS2 sh2 x)) + AstUnNestS :: (KnownShS sh1, KnownShS sh2, TensorKind x) + => AstTensor ms s (TKS2 sh1 (TKS2 sh2 x)) + -> AstTensor ms s (TKS2 (sh1 ++ sh2) x) -- Conversions AstFromS :: forall y z ms s. @@ -413,26 +417,6 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType AstSFromX :: (KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', TensorKind r) => AstTensor ms s (TKX2 sh' r) -> AstTensor ms s (TKS2 sh r) - -- Nesting/unnesting - AstXNestR :: (KnownShX sh1, KnownNat m, TensorKind x) - => AstTensor ms s (TKX2 (sh1 ++ Replicate m Nothing) x) - -> AstTensor ms s (TKX2 sh1 (TKR2 m x)) - AstXNestS :: (KnownShX sh1, KnownShS sh2, TensorKind x) - => AstTensor ms s (TKX2 (sh1 ++ MapJust sh2) x) - -> AstTensor ms s (TKX2 sh1 (TKS2 sh2 x)) - AstXNest :: (KnownShX sh1, KnownShX sh2, TensorKind x) - => AstTensor ms s (TKX2 (sh1 ++ sh2) x) - -> AstTensor ms s (TKX2 sh1 (TKX2 sh2 x)) - AstXUnNestR :: (KnownShX sh1, KnownNat m, TensorKind x) - => AstTensor ms s (TKX2 sh1 (TKR2 m x)) - -> AstTensor ms s (TKX2 (sh1 ++ Replicate m Nothing) x) - AstXUnNestS :: (KnownShX sh1, KnownShS sh2, TensorKind x) - => AstTensor ms s (TKX2 sh1 (TKS2 sh2 x)) - -> AstTensor ms s (TKX2 (sh1 ++ MapJust sh2) x) - AstXUnNest :: (KnownShX sh1, KnownShX sh2, TensorKind x) - => AstTensor ms s (TKX2 sh1 (TKX2 sh2 x)) - -> AstTensor ms s (TKX2 (sh1 ++ sh2) x) - -- Backend-specific primitives AstReplicate0NS :: ShS sh -> STensorKindType x -> AstTensor ms s (TKS2 '[] x) diff --git a/src/HordeAd/Core/AstInline.hs b/src/HordeAd/Core/AstInline.hs index 27d17decf..7962458fd 100644 --- a/src/HordeAd/Core/AstInline.hs +++ b/src/HordeAd/Core/AstInline.hs @@ -265,19 +265,14 @@ inlineAst memo v0 = case v0 of Ast.AstReshapeS v -> second Ast.AstReshapeS (inlineAst memo v) Ast.AstZipS v -> second Ast.AstZipS (inlineAst memo v) Ast.AstUnzipS v -> second Ast.AstUnzipS (inlineAst memo v) + Ast.AstNestS v -> second Ast.AstNestS $ inlineAst memo v + Ast.AstUnNestS v -> second Ast.AstUnNestS $ inlineAst memo v Ast.AstFromS stkz v -> second (Ast.AstFromS stkz) $ inlineAst memo v Ast.AstSFromK t -> second Ast.AstSFromK (inlineAst memo t) Ast.AstSFromR v -> second Ast.AstSFromR $ inlineAst memo v Ast.AstSFromX v -> second Ast.AstSFromX $ inlineAst memo v - Ast.AstXNestR v -> second Ast.AstXNestR $ inlineAst memo v - Ast.AstXNestS v -> second Ast.AstXNestS $ inlineAst memo v - Ast.AstXNest v -> second Ast.AstXNest $ inlineAst memo v - Ast.AstXUnNestR v -> second Ast.AstXUnNestR $ inlineAst memo v - Ast.AstXUnNestS v -> second Ast.AstXUnNestS $ inlineAst memo v - Ast.AstXUnNest v -> second Ast.AstXUnNest $ inlineAst memo v - Ast.AstReplicate0NS sh stk v -> second (Ast.AstReplicate0NS sh stk) (inlineAst memo v) Ast.AstSum0S sh stk v -> @@ -553,13 +548,8 @@ unshareAst memo = \case Ast.AstSFromK t -> second Ast.AstSFromK (unshareAst memo t) Ast.AstSFromR v -> second Ast.AstSFromR $ unshareAst memo v Ast.AstSFromX v -> second Ast.AstSFromX $ unshareAst memo v - - Ast.AstXNestR v -> second Ast.AstXNestR $ unshareAst memo v - Ast.AstXNestS v -> second Ast.AstXNestS $ unshareAst memo v - Ast.AstXNest v -> second Ast.AstXNest $ unshareAst memo v - Ast.AstXUnNestR v -> second Ast.AstXUnNestR $ unshareAst memo v - Ast.AstXUnNestS v -> second Ast.AstXUnNestS $ unshareAst memo v - Ast.AstXUnNest v -> second Ast.AstXUnNest $ unshareAst memo v + Ast.AstNestS v -> second Ast.AstNestS $ unshareAst memo v + Ast.AstUnNestS v -> second Ast.AstUnNestS $ unshareAst memo v Ast.AstReplicate0NS sh stk v -> second (Ast.AstReplicate0NS sh stk) (unshareAst memo v) diff --git a/src/HordeAd/Core/AstInterpret.hs b/src/HordeAd/Core/AstInterpret.hs index c9009c2ea..9ae30668c 100644 --- a/src/HordeAd/Core/AstInterpret.hs +++ b/src/HordeAd/Core/AstInterpret.hs @@ -29,7 +29,7 @@ import GHC.TypeLits (KnownNat) import Type.Reflection (Typeable, typeRep) import Data.Array.Mixed.Shape (withKnownShX) -import Data.Array.Nested (KnownShS (..), KnownShX (..), ListS (..)) +import Data.Array.Nested (KnownShS (..), ListS (..)) import Data.Array.Nested.Internal.Shape (shsAppend, shsProduct, withKnownShS) import HordeAd.Core.Ast @@ -375,6 +375,8 @@ interpretAst !env = \case AstReshapeS v -> sreshape (interpretAst env v) AstZipS v -> szip $ interpretAst env v AstUnzipS v -> sunzip $ interpretAst env v + AstNestS v -> snest knownShS $ interpretAst env v + AstUnNestS v -> sunNest $ interpretAst env v AstFromS stkz v | Dict <- lemTensorKindOfSTK (ftkToStk (ftkAst v)) , Dict <- lemTensorKindOfSTK stkz -> @@ -383,13 +385,6 @@ interpretAst !env = \case AstSFromR v -> sfromR $ interpretAst env v AstSFromX v -> sfromX $ interpretAst env v - AstXNestR v -> xnestR knownShX $ interpretAst env v - AstXNestS v -> xnestS knownShX $ interpretAst env v - AstXNest v -> xnest knownShX $ interpretAst env v - AstXUnNestR v -> xunNestR $ interpretAst env v - AstXUnNestS v -> xunNestS $ interpretAst env v - AstXUnNest v -> xunNest $ interpretAst env v - AstReplicate0NS sh stk v | Dict <- lemTensorKindOfSTK stk , SNat <- shsProduct sh -> withKnownShS sh $ diff --git a/src/HordeAd/Core/AstPrettyPrint.hs b/src/HordeAd/Core/AstPrettyPrint.hs index a52002c61..19a86bf46 100644 --- a/src/HordeAd/Core/AstPrettyPrint.hs +++ b/src/HordeAd/Core/AstPrettyPrint.hs @@ -21,14 +21,11 @@ import Data.Vector.Generic qualified as V import GHC.Exts (IsList (..)) import GHC.TypeLits (fromSNat) -import Data.Array.Mixed.Shape - (ssxAppend, ssxFromShape, ssxReplicate, withKnownShX) -import Data.Array.Mixed.Shape qualified as X +import Data.Array.Mixed.Shape (StaticShX (..), listxRank) import Data.Array.Nested - (KnownShS (..), KnownShX (..), ListS (..), ShR (..), ShS (..), ShX (..)) + (KnownShS (..), ListS (..), ShR (..), ShS (..), ShX (..)) import Data.Array.Nested qualified as Nested -import Data.Array.Nested.Internal.Shape - (shCvtSX, shsAppend, shsRank, withKnownShS) +import Data.Array.Nested.Internal.Shape (shsAppend, shsRank, withKnownShS) import HordeAd.Core.Ast import HordeAd.Core.AstTools @@ -83,8 +80,8 @@ printAstVar cfg var = rankTensorKind (STKScalar _) = 0 rankTensorKind (STKR snat _) = fromInteger $ fromSNat snat rankTensorKind (STKS sh _) = fromInteger $ fromSNat $ shsRank sh - rankTensorKind (STKX (X.StaticShX l) _) = - fromInteger $ fromSNat $ X.listxRank l + rankTensorKind (STKX (StaticShX l) _) = + fromInteger $ fromSNat $ listxRank l rankTensorKind (STKProduct @y1 @z1 sy sz) = rankTensorKind @y1 sy `max` rankTensorKind @z1 sz n = rankTensorKind (stensorKind @y) @@ -442,6 +439,10 @@ printAstAux cfg d = \case printPrefixOp printAst cfg d "sreshape" [v] AstZipS v -> printPrefixOp printAst cfg d "szip" [v] AstUnzipS v -> printPrefixOp printAst cfg d "sunzip" [v] + AstNestS @sh1 @sh2 v -> + withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $ + printPrefixOp printAst cfg d "snestS" [v] + AstUnNestS v -> printPrefixOp printAst cfg d "sunNestS" [v] AstFromS stkz v | Dict <- lemTensorKindOfSTK (ftkToStk (ftkAst v)) -> case stkz of @@ -453,20 +454,6 @@ printAstAux cfg d = \case AstSFromR v -> printPrefixOp printAst cfg d "sfromR" [v] AstSFromX v -> printPrefixOp printAst cfg d "sfromX" [v] - AstXNestR @sh1 @m v -> - withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $ - printPrefixOp printAst cfg d "xnestR" [v] - AstXNestS @sh1 @sh2 v -> - withKnownShX (knownShX @sh1 - `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $ - printPrefixOp printAst cfg d "xnestS" [v] - AstXNest @sh1 @sh2 v -> - withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $ - printPrefixOp printAst cfg d "xnest" [v] - AstXUnNestR v -> printPrefixOp printAst cfg d "xunNestR" [v] - AstXUnNestS v -> printPrefixOp printAst cfg d "xunNestS" [v] - AstXUnNest v -> printPrefixOp printAst cfg d "xunNest" [v] - AstReplicate0NS _sh stk v | Dict <- lemTensorKindOfSTK stk -> printPrefixOp printAst cfg d "sreplicate0N" [v] AstSum0S sh stk v | Dict <- lemTensorKindOfSTK stk -> diff --git a/src/HordeAd/Core/AstSimplify.hs b/src/HordeAd/Core/AstSimplify.hs index 1811828fa..c21867d6a 100644 --- a/src/HordeAd/Core/AstSimplify.hs +++ b/src/HordeAd/Core/AstSimplify.hs @@ -32,11 +32,10 @@ module HordeAd.Core.AstSimplify , astIndexStepS, astScatterS, astGatherStepS , astAppendS, astSliceS, astReverseS, astTransposeS, astReshapeS + , astNestS, astUnNestS , astFromS, astSFromK, astSFromR, astSFromX - , astXNestR, astXNestS, astXNest, astXUnNestR, astXUnNestS, astXUnNest - -- * Helper combinators , astLetFun, astReplicate0NS -- * A cheap simplification of only the topmost nodes @@ -86,18 +85,15 @@ import Type.Reflection (typeRep) import Data.Array.Mixed.Lemmas import Data.Array.Mixed.Permutation (DropLen, Perm (..), TakeLen, permInverse) import Data.Array.Mixed.Permutation qualified as Permutation -import Data.Array.Mixed.Shape - (ssxAppend, ssxFromShape, ssxReplicate, withKnownShX) +import Data.Array.Mixed.Shape (ssxFromShape, withKnownShX) import Data.Array.Mixed.Types (Init, Last, Tail, unsafeCoerceRefl) import Data.Array.Nested ( IxS (..) , KnownShS (..) , KnownShX (..) , ListS (..) - , MapJust , Product , Rank - , Replicate , ShR (..) , ShS (..) , type (++) @@ -974,6 +970,10 @@ astPrimalPart t = case t of Ast.AstReshapeS v -> astReshapeS (astPrimalPart v) Ast.AstZipS v -> Ast.AstZipS (astPrimalPart v) Ast.AstUnzipS v -> Ast.AstUnzipS (astPrimalPart v) + Ast.AstNestS @sh1 @sh2 v -> + withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $ + astNestS $ astPrimalPart v + Ast.AstUnNestS v -> astUnNestS $ astPrimalPart v Ast.AstFromS stkz v | Dict <- lemTensorKindOfSTK (ftkToStk (ftkAst v)) -> astFromS stkz $ astPrimalPart v @@ -982,20 +982,6 @@ astPrimalPart t = case t of Ast.AstSFromR{} -> Ast.AstPrimalPart t Ast.AstSFromX{} -> Ast.AstPrimalPart t - Ast.AstXNestR @sh1 @m v -> - withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $ - astXNestR $ astPrimalPart v - Ast.AstXNestS @sh1 @sh2 v -> - withKnownShX (knownShX @sh1 - `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $ - astXNestS $ astPrimalPart v - Ast.AstXNest @sh1 @sh2 v -> - withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $ - astXNest $ astPrimalPart v - Ast.AstXUnNestR v -> astXUnNestR $ astPrimalPart v - Ast.AstXUnNestS v -> astXUnNestS $ astPrimalPart v - Ast.AstXUnNest v -> astXUnNest $ astPrimalPart v - -- These should not appear in this context unless via wacky tests. Ast.AstReplicate0NS{} -> Ast.AstPrimalPart t Ast.AstSum0S{} -> Ast.AstPrimalPart t @@ -1070,6 +1056,10 @@ astDualPart t = case t of Ast.AstReshapeS v -> astReshapeS (astDualPart v) Ast.AstZipS v -> Ast.AstZipS (astDualPart v) Ast.AstUnzipS v -> Ast.AstUnzipS (astDualPart v) + Ast.AstNestS @sh1 @sh2 v -> + withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $ + astNestS $ astDualPart v + Ast.AstUnNestS v -> astUnNestS $ astDualPart v Ast.AstFromS stkz v | Dict <- lemTensorKindOfSTK (ftkToStk (ftkAst v)) -> astFromS stkz $ astDualPart v @@ -1078,20 +1068,6 @@ astDualPart t = case t of Ast.AstSFromR{} -> Ast.AstDualPart t Ast.AstSFromX{} -> Ast.AstDualPart t - Ast.AstXNestR @sh1 @m v -> - withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $ - astXNestR $ astDualPart v - Ast.AstXNestS @sh1 @sh2 v -> - withKnownShX (knownShX @sh1 - `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $ - astXNestS $ astDualPart v - Ast.AstXNest @sh1 @sh2 v -> - withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $ - astXNest $ astDualPart v - Ast.AstXUnNestR v -> astXUnNestR $ astDualPart v - Ast.AstXUnNestS v -> astXUnNestS $ astDualPart v - Ast.AstXUnNest v -> astXUnNest $ astDualPart v - -- These should not appear in this context unless via wacky tests. Ast.AstReplicate0NS{} -> Ast.AstDualPart t Ast.AstSum0S{} -> Ast.AstDualPart t @@ -1421,6 +1397,8 @@ astIndexKnobsS knobs v0 ix@((:.$) @in1 @shm1 i1 rest1) astIndex @(Permutation.PermutePrefix perm shm) v ix2 Ast.AstReshapeS v -> astIndex (astReshapeAsGatherS knobs v) ix Ast.AstZipS _ -> Ast.AstIndexS v0 ix + Ast.AstNestS _ -> Ast.AstIndexS v0 ix + Ast.AstUnNestS _ -> Ast.AstIndexS v0 ix Ast.AstFromS stkz v -> case sameSTK (ftkToStk (ftkAst v)) stkz of Just Refl -> astIndex v ix -- rare, usually simplifies away earlier @@ -1625,7 +1603,8 @@ astGatherKnobsS knobs v0 (vars0, ix0) = astGatherCase v4 (_, ZIS) = astReplicateNS @shm' @shn' v4 -- not really possible astGatherCase v4 ( vars4 , ix4@((:.$) @p1' @shp1' i4 rest4) ) - | Dict <- shsKnownShS (knownShS @shm' `shsAppend` knownShS @shn') + | Dict <- shsKnownShS (knownShS @shm' + `shsAppend` knownShS @shn') , Dict <- sixKnown rest4 = case v4 of Ast.AstProject1{} -> Ast.AstGatherS @shm' @shn' @shp' v4 (vars4, ix4) Ast.AstProject2{} -> Ast.AstGatherS @shm' @shn' @shp' v4 (vars4, ix4) @@ -1963,6 +1942,8 @@ astGatherKnobsS knobs v0 (vars0, ix0) = then astGather @shm' @shn' @shp' (astReshapeAsGatherS knobs v) (vars4, ix4) else Ast.AstGatherS @shm' @shn' @shp' v4 (vars4, ix4) Ast.AstZipS _v -> Ast.AstGatherS @shm' @shn' @shp' v4 (vars4, ix4) + Ast.AstNestS _v -> Ast.AstGatherS @shm' @shn' @shp' v4 (vars4, ix4) + Ast.AstUnNestS _v -> Ast.AstGatherS @shm' @shn' @shp' v4 (vars4, ix4) Ast.AstFromS stkz v -> case sameSTK (ftkToStk (ftkAst v)) stkz of Just Refl -> astGatherCase @shm' @shn' @shp' v (vars4, ix4) @@ -2207,6 +2188,47 @@ astReshapeS = \case Just Refl -> v _ -> Ast.AstReshapeS v +astNestS + :: forall sh1 sh2 x ms s. + (TensorKind x, KnownShS sh1, KnownShS sh2, AstSpan s) + => AstTensor ms s (TKS2 (sh1 ++ sh2) x) + -> AstTensor ms s (TKS2 sh1 (TKS2 sh2 x)) +astNestS t = case t of + Ast.AstCond b v1 v2 -> + withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $ + Ast.AstCond b (astNestS v1) (astNestS v2) -- TODO: ?? + Ast.AstLet var u2 d2 -> -- TODO: good idea? + withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $ + astLet var u2 (astNestS d2) + Ast.AstFromPrimal u -> + withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $ + Ast.AstFromPrimal $ astNestS u + Ast.AstFromDual u -> + withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $ + Ast.AstFromDual $ astNestS u + _ -> Ast.AstNestS t + +astUnNestS + :: forall sh1 sh2 x ms s. + (TensorKind x, KnownShS sh1, KnownShS sh2, AstSpan s) + => AstTensor ms s (TKS2 sh1 (TKS2 sh2 x)) + -> AstTensor ms s (TKS2 (sh1 ++ sh2) x) +astUnNestS t = case t of + Ast.AstCond b v1 v2 -> + withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $ + Ast.AstCond b (astUnNestS v1) (astUnNestS v2) -- TODO: ?? + Ast.AstLet var u2 d2 -> -- TODO: good idea? + withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $ + astLet var u2 (astUnNestS d2) + Ast.AstFromPrimal u -> + withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $ + Ast.AstFromPrimal $ astUnNestS u + Ast.AstFromDual u -> + withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $ + Ast.AstFromDual $ astUnNestS u +-- Ast.AstNestS u -> u + _ -> Ast.AstUnNestS t + astFromS :: forall y z s. STensorKindType z -> AstTensor AstMethodLet s y -> AstTensor AstMethodLet s z @@ -2353,138 +2375,6 @@ astSFromX (Ast.AstFromS _ v) = case sameSTK (ftkToStk (ftkAst v)) _ -> error "astSFromX: different shapes in SFromX(FromS)" astSFromX v = Ast.AstSFromX v -astXNestR - :: forall sh1 m x ms s. - (TensorKind x, KnownShX sh1, KnownNat m, AstSpan s) - => AstTensor ms s (TKX2 (sh1 ++ Replicate m Nothing) x) - -> AstTensor ms s (TKX2 sh1 (TKR2 m x)) -astXNestR t = case t of - Ast.AstCond b v1 v2 -> - withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $ - Ast.AstCond b (astXNestR v1) (astXNestR v2) -- TODO: ?? - Ast.AstLet var u2 d2 -> -- TODO: good idea? - withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $ - astLet var u2 (astXNestR d2) - Ast.AstFromPrimal u -> - withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $ - Ast.AstFromPrimal $ astXNestR u - Ast.AstFromDual u -> - withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $ - Ast.AstFromDual $ astXNestR u --- TODO: when sh agrees: Ast.AstUnNestS u -> u - _ -> Ast.AstXNestR t - -astXNestS - :: forall sh1 sh2 x ms s. - (TensorKind x, KnownShX sh1, KnownShS sh2, AstSpan s) - => AstTensor ms s (TKX2 (sh1 ++ MapJust sh2) x) - -> AstTensor ms s (TKX2 sh1 (TKS2 sh2 x)) -astXNestS t = case t of - Ast.AstCond b v1 v2 -> - withKnownShX (knownShX @sh1 - `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $ - Ast.AstCond b (astXNestS v1) (astXNestS v2) -- TODO: ?? - Ast.AstLet var u2 d2 -> -- TODO: good idea? - withKnownShX (knownShX @sh1 - `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $ - astLet var u2 (astXNestS d2) - Ast.AstFromPrimal u -> - withKnownShX (knownShX @sh1 - `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $ - Ast.AstFromPrimal $ astXNestS u - Ast.AstFromDual u -> - withKnownShX (knownShX @sh1 - `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $ - Ast.AstFromDual $ astXNestS u - _ -> Ast.AstXNestS t - -astXNest - :: forall sh1 sh2 x ms s. - (TensorKind x, KnownShX sh1, KnownShX sh2, AstSpan s) - => AstTensor ms s (TKX2 (sh1 ++ sh2) x) - -> AstTensor ms s (TKX2 sh1 (TKX2 sh2 x)) -astXNest t = case t of - Ast.AstCond b v1 v2 -> - withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $ - Ast.AstCond b (astXNest v1) (astXNest v2) -- TODO: ?? - Ast.AstLet var u2 d2 -> -- TODO: good idea? - withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $ - astLet var u2 (astXNest d2) - Ast.AstFromPrimal u -> - withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $ - Ast.AstFromPrimal $ astXNest u - Ast.AstFromDual u -> - withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $ - Ast.AstFromDual $ astXNest u - _ -> Ast.AstXNest t - -astXUnNestR - :: forall sh1 m x ms s. - (TensorKind x, KnownShX sh1, KnownNat m, AstSpan s) - => AstTensor ms s (TKX2 sh1 (TKR2 m x)) - -> AstTensor ms s (TKX2 (sh1 ++ Replicate m Nothing) x) -astXUnNestR t = case t of - Ast.AstCond b v1 v2 -> - withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $ - Ast.AstCond b (astXUnNestR v1) (astXUnNestR v2) -- TODO: ?? - Ast.AstLet var u2 d2 -> -- TODO: good idea? - withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $ - astLet var u2 (astXUnNestR d2) - Ast.AstFromPrimal u -> - withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $ - Ast.AstFromPrimal $ astXUnNestR u - Ast.AstFromDual u -> - withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $ - Ast.AstFromDual $ astXUnNestR u --- Ast.AstNestS u -> u - _ -> Ast.AstXUnNestR t - -astXUnNestS - :: forall sh1 sh2 x ms s. - (TensorKind x, KnownShX sh1, KnownShS sh2, AstSpan s) - => AstTensor ms s (TKX2 sh1 (TKS2 sh2 x)) - -> AstTensor ms s (TKX2 (sh1 ++ MapJust sh2) x) -astXUnNestS t = case t of - Ast.AstCond b v1 v2 -> - withKnownShX (knownShX @sh1 - `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $ - Ast.AstCond b (astXUnNestS v1) (astXUnNestS v2) -- TODO: ?? - Ast.AstLet var u2 d2 -> -- TODO: good idea? - withKnownShX (knownShX @sh1 - `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $ - astLet var u2 (astXUnNestS d2) - Ast.AstFromPrimal u -> - withKnownShX (knownShX @sh1 - `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $ - Ast.AstFromPrimal $ astXUnNestS u - Ast.AstFromDual u -> - withKnownShX (knownShX @sh1 - `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $ - Ast.AstFromDual $ astXUnNestS u --- Ast.AstNestS u -> u - _ -> Ast.AstXUnNestS t - -astXUnNest - :: forall sh1 sh2 x ms s. - (TensorKind x, KnownShX sh1, KnownShX sh2, AstSpan s) - => AstTensor ms s (TKX2 sh1 (TKX2 sh2 x)) - -> AstTensor ms s (TKX2 (sh1 ++ sh2) x) -astXUnNest t = case t of - Ast.AstCond b v1 v2 -> - withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $ - Ast.AstCond b (astXUnNest v1) (astXUnNest v2) -- TODO: ?? - Ast.AstLet var u2 d2 -> -- TODO: good idea? - withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $ - astLet var u2 (astXUnNest d2) - Ast.AstFromPrimal u -> - withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $ - Ast.AstFromPrimal $ astXUnNest u - Ast.AstFromDual u -> - withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $ - Ast.AstFromDual $ astXUnNest u --- Ast.AstNestS u -> u - _ -> Ast.AstXUnNest t - -- * Helper combinators @@ -2628,19 +2518,14 @@ astNonIndexStep t = case t of Ast.AstReshapeS v -> astReshapeS v Ast.AstZipS _ -> t Ast.AstUnzipS _ -> t + Ast.AstNestS v -> astNestS v + Ast.AstUnNestS v -> astUnNestS v Ast.AstFromS stkz v -> astFromS stkz v Ast.AstSFromK u -> astSFromK $ astNonIndexStep u Ast.AstSFromR v -> astSFromR v Ast.AstSFromX v -> astSFromX v - Ast.AstXNestR v -> astXNestR v - Ast.AstXNestS v -> astXNestS v - Ast.AstXNest v -> astXNest v - Ast.AstXUnNestR v -> astXUnNestR v - Ast.AstXUnNestS v -> astXUnNestS v - Ast.AstXUnNest v -> astXUnNest v - -- These should not appear here unless via wacky tests. Ast.AstReplicate0NS{} -> t Ast.AstSum0S{} -> t @@ -2814,6 +2699,10 @@ expandAst t = case t of -- this is expensive but the only way to guarantee full simplification Ast.AstZipS v -> Ast.AstZipS (expandAst v) Ast.AstUnzipS v -> Ast.AstUnzipS (expandAst v) + Ast.AstNestS @sh1 @sh2 v -> + withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $ + astNestS $ expandAst v + Ast.AstUnNestS v -> astUnNestS $ expandAst v Ast.AstFromS stkz v | Dict <- lemTensorKindOfSTK (ftkToStk (ftkAst v)) -> astFromS stkz $ expandAst v @@ -2821,20 +2710,6 @@ expandAst t = case t of Ast.AstSFromR v -> astSFromR $ expandAst v Ast.AstSFromX v -> astSFromX $ expandAst v - Ast.AstXNestR @sh1 @m v -> - withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $ - astXNestR $ expandAst v - Ast.AstXNestS @sh1 @sh2 v -> - withKnownShX (knownShX @sh1 - `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $ - astXNestS $ expandAst v - Ast.AstXNest @sh1 @sh2 v -> - withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $ - astXNest $ expandAst v - Ast.AstXUnNestR v -> astXUnNestR $ expandAst v - Ast.AstXUnNestS v -> astXUnNestS $ expandAst v - Ast.AstXUnNest v -> astXUnNest $ expandAst v - -- These should not appear in this context unless via wacky tests. Ast.AstReplicate0NS{} -> t Ast.AstSum0S{} -> t @@ -2974,6 +2849,10 @@ simplifyAst t = case t of Ast.AstReshapeS v -> astReshapeS $ simplifyAst v Ast.AstZipS v -> Ast.AstZipS (simplifyAst v) Ast.AstUnzipS v -> Ast.AstUnzipS (simplifyAst v) + Ast.AstNestS @sh1 @sh2 v -> + withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $ + astNestS $ simplifyAst v + Ast.AstUnNestS v -> astUnNestS $ simplifyAst v Ast.AstFromS stkz v | Dict <- lemTensorKindOfSTK (ftkToStk (ftkAst v)) -> astFromS stkz $ simplifyAst v @@ -2981,20 +2860,6 @@ simplifyAst t = case t of Ast.AstSFromR v -> astSFromR $ simplifyAst v Ast.AstSFromX v -> astSFromX $ simplifyAst v - Ast.AstXNestR @sh1 @m v -> - withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $ - astXNestR $ simplifyAst v - Ast.AstXNestS @sh1 @sh2 v -> - withKnownShX (knownShX @sh1 - `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $ - astXNestS $ simplifyAst v - Ast.AstXNest @sh1 @sh2 v -> - withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $ - astXNest $ simplifyAst v - Ast.AstXUnNestR v -> astXUnNestR $ simplifyAst v - Ast.AstXUnNestS v -> astXUnNestS $ simplifyAst v - Ast.AstXUnNest v -> astXUnNest $ simplifyAst v - -- These should not appear in this context unless via wacky tests. Ast.AstReplicate0NS{} -> t Ast.AstSum0S{} -> t @@ -3402,6 +3267,10 @@ contractAst t = case t of Ast.AstReshapeS v -> astReshapeS $ contractAst v Ast.AstZipS v -> Ast.AstZipS (contractAst v) Ast.AstUnzipS v -> Ast.AstUnzipS (contractAst v) + Ast.AstNestS @sh1 @sh2 v -> + withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $ + astNestS $ contractAst v + Ast.AstUnNestS v -> astUnNestS $ contractAst v Ast.AstFromS stkz v | Dict <- lemTensorKindOfSTK (ftkToStk (ftkAst v)) -> astFromS stkz $ contractAst v @@ -3409,20 +3278,6 @@ contractAst t = case t of Ast.AstSFromR v -> astSFromR $ contractAst v Ast.AstSFromX v -> astSFromX $ contractAst v - Ast.AstXNestR @sh1 @m v -> - withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $ - astXNestR $ contractAst v - Ast.AstXNestS @sh1 @sh2 v -> - withKnownShX (knownShX @sh1 - `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $ - astXNestS $ contractAst v - Ast.AstXNest @sh1 @sh2 v -> - withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $ - astXNest $ contractAst v - Ast.AstXUnNestR v -> astXUnNestR $ contractAst v - Ast.AstXUnNestS v -> astXUnNestS $ contractAst v - Ast.AstXUnNest v -> astXUnNest $ contractAst v - -- These should not appear in this context unless via wacky tests. Ast.AstReplicate0NS{} -> t Ast.AstSum0S{} -> t @@ -3910,6 +3765,10 @@ substitute1Ast i var v1 = case v1 of Ast.AstReshapeS v -> astReshapeS <$> substitute1Ast i var v Ast.AstZipS v -> Ast.AstZipS <$> substitute1Ast i var v Ast.AstUnzipS v -> Ast.AstUnzipS <$> substitute1Ast i var v + Ast.AstNestS @sh1 @sh2 v -> + withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $ + astNestS <$> substitute1Ast i var v + Ast.AstUnNestS v -> astUnNestS <$> substitute1Ast i var v Ast.AstFromS stkz v | Dict <- lemTensorKindOfSTK (ftkToStk (ftkAst v)) -> astFromS stkz <$> substitute1Ast i var v @@ -3917,20 +3776,6 @@ substitute1Ast i var v1 = case v1 of Ast.AstSFromR v -> astSFromR <$> substitute1Ast i var v Ast.AstSFromX v -> astSFromX <$> substitute1Ast i var v - Ast.AstXNestR @sh1 @m v -> - withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $ - astXNestR <$> substitute1Ast i var v - Ast.AstXNestS @sh1 @sh2 v -> - withKnownShX (knownShX @sh1 - `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $ - astXNestS <$> substitute1Ast i var v - Ast.AstXNest @sh1 @sh2 v -> - withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $ - astXNest <$> substitute1Ast i var v - Ast.AstXUnNestR v -> astXUnNestR <$> substitute1Ast i var v - Ast.AstXUnNestS v -> astXUnNestS <$> substitute1Ast i var v - Ast.AstXUnNest v -> astXUnNest <$> substitute1Ast i var v - Ast.AstReplicate0NS sh stk v | Dict <- lemTensorKindOfSTK stk -> Ast.AstReplicate0NS sh stk <$> substitute1Ast i var v Ast.AstSum0S sh stk v | Dict <- lemTensorKindOfSTK stk -> diff --git a/src/HordeAd/Core/AstTools.hs b/src/HordeAd/Core/AstTools.hs index cd0593542..99fd50a76 100644 --- a/src/HordeAd/Core/AstTools.hs +++ b/src/HordeAd/Core/AstTools.hs @@ -18,27 +18,16 @@ module HordeAd.Core.AstTools import Prelude hiding (foldl') import Control.Exception.Assert.Sugar -import Data.Proxy (Proxy (Proxy)) import Data.Type.Equality (testEquality, (:~:) (Refl)) import Data.Vector.Generic qualified as V import GHC.TypeLits (KnownNat, sameNat) import Type.Reflection (typeRep) import Data.Array.Mixed.Shape - ( KnownShX (..) - , shxAppend - , shxDropSSX - , shxSize - , shxTakeSSX - , ssxFromShape - , withKnownShX - ) -import Data.Array.Nested (KnownShS (..), MapJust, Rank, Replicate, ShS (..)) + (KnownShX (..), shxSize, ssxFromShape, withKnownShX) +import Data.Array.Nested (KnownShS (..), Rank, ShS (..)) import Data.Array.Nested.Internal.Shape - ( shCvtRX - , shCvtSX - , shCvtXR' - , shrRank + ( shrRank , shrSize , shsAppend , shsInit @@ -140,6 +129,10 @@ ftkAst t = case t of FTKProduct (FTKS sh y) (FTKS _ z) -> FTKS sh (FTKProduct y z) AstUnzipS v -> case ftkAst v of FTKS sh (FTKProduct y z) -> FTKProduct (FTKS sh y) (FTKS sh z) + AstNestS @sh1 @sh2 v -> case ftkAst v of + FTKS _ x -> FTKS (knownShS @sh1) (FTKS (knownShS @sh2) x) + AstUnNestS @sh1 @sh2 v -> case ftkAst v of + FTKS _ (FTKS _ x) -> FTKS (knownShS @sh1 `shsAppend` knownShS @sh2) x AstFromS stkz v -> let fromS :: FullTensorKind y2 -> STensorKindType z2 -> FullTensorKind z2 @@ -172,26 +165,6 @@ ftkAst t = case t of AstSFromX v -> case ftkAst v of FTKX _ x -> FTKS knownShS x - AstXNestR @sh1 @m v -> case ftkAst v of - FTKX sh x -> FTKX (shxTakeSSX (Proxy @(Replicate m Nothing)) - sh (knownShX @sh1)) - (FTKR (shCvtXR' (shxDropSSX sh (knownShX @sh1))) x) - AstXNestS @sh1 @sh2 v -> case ftkAst v of - FTKX sh x -> FTKX (shxTakeSSX (Proxy @(MapJust sh2)) sh (knownShX @sh1)) - (FTKS knownShS x) - AstXNest @sh1 @sh2 v -> case ftkAst v of - FTKX sh x -> FTKX (shxTakeSSX (Proxy @sh2) sh (knownShX @sh1)) - (FTKX (shxDropSSX sh (knownShX @sh1)) x) - AstXUnNestR v -> case ftkAst v of - FTKX sh1 (FTKR sh2 x) -> - FTKX (sh1 `shxAppend` shCvtRX sh2) x - AstXUnNestS v -> case ftkAst v of - FTKX sh1 (FTKS sh2 x) -> - FTKX (sh1 `shxAppend` shCvtSX sh2) x - AstXUnNest v -> case ftkAst v of - FTKX sh1 (FTKX sh2 x) -> - FTKX (sh1 `shxAppend` sh2) x - AstReplicate0NS sh _ v -> case ftkAst v of FTKS _ x -> FTKS sh x AstSum0S _ _ v -> case ftkAst v of @@ -270,19 +243,14 @@ varInAst var = \case AstReshapeS v -> varInAst var v AstZipS v -> varInAst var v AstUnzipS v -> varInAst var v + AstNestS v -> varInAst var v + AstUnNestS v -> varInAst var v AstFromS _ v -> varInAst var v AstSFromK t -> varInAst var t AstSFromR v -> varInAst var v AstSFromX v -> varInAst var v - AstXNestR v -> varInAst var v - AstXNestS v -> varInAst var v - AstXNest v -> varInAst var v - AstXUnNestR v -> varInAst var v - AstXUnNestS v -> varInAst var v - AstXUnNest v -> varInAst var v - AstReplicate0NS _ _ v -> varInAst var v AstSum0S _ _ v -> varInAst var v AstDot0S _ u v -> varInAst var u || varInAst var v diff --git a/src/HordeAd/Core/AstVectorize.hs b/src/HordeAd/Core/AstVectorize.hs index 30934e2a9..83da2571f 100644 --- a/src/HordeAd/Core/AstVectorize.hs +++ b/src/HordeAd/Core/AstVectorize.hs @@ -22,8 +22,7 @@ import System.IO (Handle, hFlush, hPutStrLn, stderr, stdout) import System.IO.Unsafe (unsafePerformIO) import Data.Array.Mixed.Permutation qualified as Permutation -import Data.Array.Mixed.Shape - (ssxAppend, ssxFromShape, ssxReplicate, withKnownShX) +import Data.Array.Mixed.Shape (ssxFromShape, withKnownShX) import Data.Array.Mixed.Types (unsafeCoerceRefl) import Data.Array.Nested ( IShX @@ -36,14 +35,7 @@ import Data.Array.Nested , type (++) ) import Data.Array.Nested.Internal.Shape - ( shCvtSX - , shrRank - , shsAppend - , shsLength - , shsPermutePrefix - , shsRank - , withKnownShS - ) + (shrRank, shsAppend, shsLength, shsPermutePrefix, shsRank, withKnownShS) import HordeAd.Core.Ast (AstTensor) import HordeAd.Core.Ast hiding (AstBool (..), AstTensor (..)) @@ -383,6 +375,10 @@ build1V snat@SNat (var, v0) = Ast.AstZipS $ build1V snat (var, v) Ast.AstUnzipS v -> traceRule $ Ast.AstUnzipS $ build1V snat (var, v) + Ast.AstNestS @sh1 @sh2 v -> traceRule $ + withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $ + astNestS $ build1V snat (var, v) + Ast.AstUnNestS v -> traceRule $ astUnNestS $ build1V snat (var, v) Ast.AstFromS stkz v | Dict <- lemTensorKindOfSTK (ftkToStk (ftkAst v)) -> traceRule $ @@ -391,20 +387,6 @@ build1V snat@SNat (var, v0) = Ast.AstSFromR v -> traceRule $ astSFromR $ build1V snat (var, v) Ast.AstSFromX v -> traceRule $ astSFromX $ build1V snat (var, v) - Ast.AstXNestR @sh1 @m v -> traceRule $ - withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $ - astXNestR $ build1V snat (var, v) - Ast.AstXNestS @sh1 @sh2 v -> traceRule $ - withKnownShX (knownShX @sh1 - `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $ - astXNestS $ build1V snat (var, v) - Ast.AstXNest @sh1 @sh2 v -> traceRule $ - withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $ - astXNest $ build1V snat (var, v) - Ast.AstXUnNestR v -> traceRule $ astXUnNestR $ build1V snat (var, v) - Ast.AstXUnNestS v -> traceRule $ astXUnNestS $ build1V snat (var, v) - Ast.AstXUnNest v -> traceRule $ astXUnNest $ build1V snat (var, v) - Ast.AstReplicate0NS{} -> error "build1V: term not accessible from user API" Ast.AstSum0S{} -> error "build1V: term not accessible from user API" Ast.AstDot0S{} -> error "build1V: term not accessible from user API" diff --git a/src/HordeAd/Core/OpsAst.hs b/src/HordeAd/Core/OpsAst.hs index 9f07bbfbc..487c61eb6 100644 --- a/src/HordeAd/Core/OpsAst.hs +++ b/src/HordeAd/Core/OpsAst.hs @@ -24,8 +24,8 @@ import Unsafe.Coerce (unsafeCoerce) import Data.Array.Nested (type (++), Rank, KnownShS (..), KnownShX (..), ShX (..), ShS (..)) import Data.Array.Mixed.Types (Init, unsafeCoerceRefl) -import Data.Array.Mixed.Shape (shxInit, IShX, ssxFromShape, withKnownShX) -import Data.Array.Nested.Internal.Shape (shsProduct, shsRank, shsPermutePrefix, shrRank, shsInit, withKnownShS) +import Data.Array.Mixed.Shape (ssxAppend, ssxReplicate, shxInit, IShX, ssxFromShape, withKnownShX) +import Data.Array.Nested.Internal.Shape (shCvtSX, shsProduct, shsRank, shsPermutePrefix, shrRank, shsInit, withKnownShS) import Data.Array.Mixed.Permutation qualified as Permutation import HordeAd.Core.Adaptor @@ -653,12 +653,100 @@ instance AstSpan s => BaseTensor (AstTensor AstMethodLet s) where sfromX = astSFromX -- Nesting/unnesting - xnestR sh = withKnownShX sh $ astXNestR - xnestS sh = withKnownShX sh $ astXNestS - xnest sh = withKnownShX sh $ astXNest - xunNestR = astXUnNestR - xunNestS = astXUnNestS - xunNest = astXUnNest + xnestR @sh1' @m @x sh1' a = case ftkAst a of + FTKX sh1sh2' x | SNat <- ssxRank sh1' -> + withCastXS sh1sh2' $ \(sh1sh2 :: ShS sh1sh2) -> + withKnownShX (ssxFromShape sh1sh2') $ + withKnownShS sh1sh2 $ + withKnownShS (takeShS @(Rank sh1') sh1sh2) $ + withKnownShS (dropShS @(Rank sh1') sh1sh2) $ + gcastWith (unsafeCoerceRefl + :: Take (Rank sh1') sh1sh2 ++ Drop (Rank sh1') sh1sh2 + :~: sh1sh2) $ + (unsafeCoerce + :: AstTensor AstMethodLet s + (TKX2 sh1' (TKS2 (Drop (Rank sh1') sh1sh2) x)) + -> AstTensor AstMethodLet s (TKX2 sh1' (TKR2 m x))) + $ astFromS @(TKS2 (Take (Rank sh1') sh1sh2) + (TKS2 (Drop (Rank sh1') sh1sh2) x)) + (STKX sh1' (STKS knownShS (ftkToStk x))) + $ astNestS @(Take (Rank sh1') sh1sh2) @(Drop (Rank sh1') sh1sh2) + $ astSFromX @sh1sh2 a + xnestS @sh1' @sh2 @x sh1' a = case ftkAst a of + FTKX sh1sh2' x | SNat <- ssxRank sh1' -> + withCastXS sh1sh2' $ \(sh1sh2 :: ShS sh1sh2) -> + withKnownShX (ssxFromShape sh1sh2') $ + withKnownShS sh1sh2 $ + withKnownShS (takeShS @(Rank sh1') sh1sh2) $ + gcastWith (unsafeCoerceRefl + :: Take (Rank sh1') sh1sh2 ++ sh2 + :~: sh1sh2) $ + astFromS @(TKS2 (Take (Rank sh1') sh1sh2) (TKS2 sh2 x)) + (STKX sh1' (STKS knownShS (ftkToStk x))) + $ astNestS @(Take (Rank sh1') sh1sh2) @sh2 + $ astSFromX @sh1sh2 a + xnest @sh1' @sh2' @x sh1' a = case ftkAst a of + FTKX sh1sh2' x | SNat <- ssxRank sh1' -> + withCastXS sh1sh2' $ \(sh1sh2 :: ShS sh1sh2) -> + withKnownShX (ssxFromShape sh1sh2') $ + withKnownShS sh1sh2 $ + withKnownShS (takeShS @(Rank sh1') sh1sh2) $ + withKnownShS (dropShS @(Rank sh1') sh1sh2) $ + gcastWith (unsafeCoerceRefl + :: Take (Rank sh1') sh1sh2 ++ Drop (Rank sh1') sh1sh2 + :~: sh1sh2) $ + (unsafeCoerce + :: AstTensor AstMethodLet s + (TKX2 sh1' (TKS2 (Drop (Rank sh1') sh1sh2) x)) + -> AstTensor AstMethodLet s (TKX2 sh1' (TKX2 sh2' x))) + $ astFromS @(TKS2 (Take (Rank sh1') sh1sh2) + (TKS2 (Drop (Rank sh1') sh1sh2) x)) + (STKX sh1' (STKS knownShS (ftkToStk x))) + $ astNestS @(Take (Rank sh1') sh1sh2) @(Drop (Rank sh1') sh1sh2) + $ astSFromX @sh1sh2 a + xunNestR @sh1' @m @x a = case ftkAst a of + FTKX sh1' y -> case y of + FTKR sh2' x -> + withCastXS sh1' $ \(sh1 :: ShS sh1) -> + withCastRS sh2' $ \(sh2 :: ShS sh2) -> + withKnownShS sh1 $ + withKnownShS sh2 $ + astFromS @(TKS2 (sh1 ++ sh2) x) + (STKX (ssxFromShape sh1' `ssxAppend` ssxReplicate (SNat @m)) + (ftkToStk x)) + $ astUnNestS @sh1 @sh2 + $ astSFromX @sh1 + $ (unsafeCoerce + :: AstTensor AstMethodLet s (TKX2 sh1' (TKR2 m x)) + -> AstTensor AstMethodLet s (TKX2 sh1' (TKS2 sh2 x))) + a + xunNestS @_ @sh2 @x a = case ftkAst a of + FTKX sh1' y -> case y of + FTKS _ x -> + withCastXS sh1' $ \(sh1 :: ShS sh1) -> + withKnownShS sh1 $ + astFromS @(TKS2 (sh1 ++ sh2) x) + (STKX (ssxFromShape sh1' + `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) + (ftkToStk x)) + $ astUnNestS @sh1 @sh2 + $ astSFromX @sh1 a + xunNest @sh1' @sh2' @x a = case ftkAst a of + FTKX sh1' y -> case y of + FTKX sh2' x -> + withCastXS sh1' $ \(sh1 :: ShS sh1) -> + withCastXS sh2' $ \(sh2 :: ShS sh2) -> + withKnownShS sh1 $ + withKnownShS sh2 $ + astFromS @(TKS2 (sh1 ++ sh2) x) + (STKX (ssxFromShape sh1' `ssxAppend` (knownShX @sh2')) + (ftkToStk x)) + $ astUnNestS @sh1 @sh2 + $ astSFromX @sh1 + $ (unsafeCoerce + :: AstTensor AstMethodLet s (TKX2 sh1' (TKX2 sh2' x)) + -> AstTensor AstMethodLet s (TKX2 sh1' (TKS2 sh2 x))) + a -- General operations that don't require LetTensor nor ShareTensor tftk _stk = ftkAst @@ -1241,12 +1329,100 @@ instance AstSpan s => BaseTensor (AstRaw s) where xfromS @_ @sh' @x = AstRaw . AstFromS (stensorKind @(TKX2 sh' x)) . unAstRaw -- Nesting/unnesting - xnestR sh = withKnownShX sh $ AstRaw . AstXNestR . unAstRaw - xnestS sh = withKnownShX sh $ AstRaw . AstXNestS . unAstRaw - xnest sh = withKnownShX sh $ AstRaw . AstXNest . unAstRaw - xunNestR = AstRaw . AstXUnNestR . unAstRaw - xunNestS = AstRaw . AstXUnNestS . unAstRaw - xunNest = AstRaw . AstXUnNest . unAstRaw + xnestR @sh1' @m @x sh1' (AstRaw a) = AstRaw $ case ftkAst a of + FTKX sh1sh2' x | SNat <- ssxRank sh1' -> + withCastXS sh1sh2' $ \(sh1sh2 :: ShS sh1sh2) -> + withKnownShX (ssxFromShape sh1sh2') $ + withKnownShS sh1sh2 $ + withKnownShS (takeShS @(Rank sh1') sh1sh2) $ + withKnownShS (dropShS @(Rank sh1') sh1sh2) $ + gcastWith (unsafeCoerceRefl + :: Take (Rank sh1') sh1sh2 ++ Drop (Rank sh1') sh1sh2 + :~: sh1sh2) $ + (unsafeCoerce + :: AstTensor AstMethodShare s + (TKX2 sh1' (TKS2 (Drop (Rank sh1') sh1sh2) x)) + -> AstTensor AstMethodShare s (TKX2 sh1' (TKR2 m x))) + $ AstFromS @(TKS2 (Take (Rank sh1') sh1sh2) + (TKS2 (Drop (Rank sh1') sh1sh2) x)) + (STKX sh1' (STKS knownShS (ftkToStk x))) + $ AstNestS @(Take (Rank sh1') sh1sh2) @(Drop (Rank sh1') sh1sh2) + $ AstSFromX @sh1sh2 a + xnestS @sh1' @sh2 @x sh1' (AstRaw a) = AstRaw $ case ftkAst a of + FTKX sh1sh2' x | SNat <- ssxRank sh1' -> + withCastXS sh1sh2' $ \(sh1sh2 :: ShS sh1sh2) -> + withKnownShX (ssxFromShape sh1sh2') $ + withKnownShS sh1sh2 $ + withKnownShS (takeShS @(Rank sh1') sh1sh2) $ + gcastWith (unsafeCoerceRefl + :: Take (Rank sh1') sh1sh2 ++ sh2 + :~: sh1sh2) $ + AstFromS @(TKS2 (Take (Rank sh1') sh1sh2) (TKS2 sh2 x)) + (STKX sh1' (STKS knownShS (ftkToStk x))) + $ AstNestS @(Take (Rank sh1') sh1sh2) @sh2 + $ AstSFromX @sh1sh2 a + xnest @sh1' @sh2' @x sh1' (AstRaw a) = AstRaw $ case ftkAst a of + FTKX sh1sh2' x | SNat <- ssxRank sh1' -> + withCastXS sh1sh2' $ \(sh1sh2 :: ShS sh1sh2) -> + withKnownShX (ssxFromShape sh1sh2') $ + withKnownShS sh1sh2 $ + withKnownShS (takeShS @(Rank sh1') sh1sh2) $ + withKnownShS (dropShS @(Rank sh1') sh1sh2) $ + gcastWith (unsafeCoerceRefl + :: Take (Rank sh1') sh1sh2 ++ Drop (Rank sh1') sh1sh2 + :~: sh1sh2) $ + (unsafeCoerce + :: AstTensor AstMethodShare s + (TKX2 sh1' (TKS2 (Drop (Rank sh1') sh1sh2) x)) + -> AstTensor AstMethodShare s (TKX2 sh1' (TKX2 sh2' x))) + $ AstFromS @(TKS2 (Take (Rank sh1') sh1sh2) + (TKS2 (Drop (Rank sh1') sh1sh2) x)) + (STKX sh1' (STKS knownShS (ftkToStk x))) + $ AstNestS @(Take (Rank sh1') sh1sh2) @(Drop (Rank sh1') sh1sh2) + $ AstSFromX @sh1sh2 a + xunNestR @sh1' @m @x (AstRaw a) = AstRaw $ case ftkAst a of + FTKX sh1' y -> case y of + FTKR sh2' x -> + withCastXS sh1' $ \(sh1 :: ShS sh1) -> + withCastRS sh2' $ \(sh2 :: ShS sh2) -> + withKnownShS sh1 $ + withKnownShS sh2 $ + AstFromS @(TKS2 (sh1 ++ sh2) x) + (STKX (ssxFromShape sh1' `ssxAppend` ssxReplicate (SNat @m)) + (ftkToStk x)) + $ AstUnNestS @sh1 @sh2 + $ AstSFromX @sh1 + $ (unsafeCoerce + :: AstTensor AstMethodShare s (TKX2 sh1' (TKR2 m x)) + -> AstTensor AstMethodShare s (TKX2 sh1' (TKS2 sh2 x))) + a + xunNestS @_ @sh2 @x (AstRaw a) = AstRaw $ case ftkAst a of + FTKX sh1' y -> case y of + FTKS _ x -> + withCastXS sh1' $ \(sh1 :: ShS sh1) -> + withKnownShS sh1 $ + AstFromS @(TKS2 (sh1 ++ sh2) x) + (STKX (ssxFromShape sh1' + `ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) + (ftkToStk x)) + $ AstUnNestS @sh1 @sh2 + $ AstSFromX @sh1 a + xunNest @sh1' @sh2' @x (AstRaw a) = AstRaw $ case ftkAst a of + FTKX sh1' y -> case y of + FTKX sh2' x -> + withCastXS sh1' $ \(sh1 :: ShS sh1) -> + withCastXS sh2' $ \(sh2 :: ShS sh2) -> + withKnownShS sh1 $ + withKnownShS sh2 $ + AstFromS @(TKS2 (sh1 ++ sh2) x) + (STKX (ssxFromShape sh1' `ssxAppend` (knownShX @sh2')) + (ftkToStk x)) + $ AstUnNestS @sh1 @sh2 + $ AstSFromX @sh1 + $ (unsafeCoerce + :: AstTensor AstMethodShare s (TKX2 sh1' (TKX2 sh2' x)) + -> AstTensor AstMethodShare s (TKX2 sh1' (TKS2 sh2 x))) + a -- General operations that don't require LetTensor nor ShareTensor tftk _stk = ftkAst . unAstRaw