Skip to content

Commit

Permalink
Trivialize nesting/unnesting in AST
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Jan 31, 2025
1 parent 7b0f7c0 commit 1079c10
Show file tree
Hide file tree
Showing 8 changed files with 302 additions and 375 deletions.
28 changes: 6 additions & 22 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ import Data.Array.Nested
, KnownShX
, ListR
, ListS (..)
, MapJust
, Rank
, Replicate
, ShS (..)
, type (++)
)
Expand Down Expand Up @@ -402,6 +400,12 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
AstUnzipS :: (TensorKind y, TensorKind z, KnownShS sh)
=> AstTensor ms s (TKS2 sh (TKProduct y z))
-> AstTensor ms s (TKProduct (TKS2 sh y) (TKS2 sh z))
AstNestS :: (KnownShS sh1, KnownShS sh2, TensorKind x)
=> AstTensor ms s (TKS2 (sh1 ++ sh2) x)
-> AstTensor ms s (TKS2 sh1 (TKS2 sh2 x))
AstUnNestS :: (KnownShS sh1, KnownShS sh2, TensorKind x)
=> AstTensor ms s (TKS2 sh1 (TKS2 sh2 x))
-> AstTensor ms s (TKS2 (sh1 ++ sh2) x)

-- Conversions
AstFromS :: forall y z ms s.
Expand All @@ -413,26 +417,6 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
AstSFromX :: (KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', TensorKind r)
=> AstTensor ms s (TKX2 sh' r) -> AstTensor ms s (TKS2 sh r)

-- Nesting/unnesting
AstXNestR :: (KnownShX sh1, KnownNat m, TensorKind x)
=> AstTensor ms s (TKX2 (sh1 ++ Replicate m Nothing) x)
-> AstTensor ms s (TKX2 sh1 (TKR2 m x))
AstXNestS :: (KnownShX sh1, KnownShS sh2, TensorKind x)
=> AstTensor ms s (TKX2 (sh1 ++ MapJust sh2) x)
-> AstTensor ms s (TKX2 sh1 (TKS2 sh2 x))
AstXNest :: (KnownShX sh1, KnownShX sh2, TensorKind x)
=> AstTensor ms s (TKX2 (sh1 ++ sh2) x)
-> AstTensor ms s (TKX2 sh1 (TKX2 sh2 x))
AstXUnNestR :: (KnownShX sh1, KnownNat m, TensorKind x)
=> AstTensor ms s (TKX2 sh1 (TKR2 m x))
-> AstTensor ms s (TKX2 (sh1 ++ Replicate m Nothing) x)
AstXUnNestS :: (KnownShX sh1, KnownShS sh2, TensorKind x)
=> AstTensor ms s (TKX2 sh1 (TKS2 sh2 x))
-> AstTensor ms s (TKX2 (sh1 ++ MapJust sh2) x)
AstXUnNest :: (KnownShX sh1, KnownShX sh2, TensorKind x)
=> AstTensor ms s (TKX2 sh1 (TKX2 sh2 x))
-> AstTensor ms s (TKX2 (sh1 ++ sh2) x)

-- Backend-specific primitives
AstReplicate0NS :: ShS sh -> STensorKindType x
-> AstTensor ms s (TKS2 '[] x)
Expand Down
18 changes: 4 additions & 14 deletions src/HordeAd/Core/AstInline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -265,19 +265,14 @@ inlineAst memo v0 = case v0 of
Ast.AstReshapeS v -> second Ast.AstReshapeS (inlineAst memo v)
Ast.AstZipS v -> second Ast.AstZipS (inlineAst memo v)
Ast.AstUnzipS v -> second Ast.AstUnzipS (inlineAst memo v)
Ast.AstNestS v -> second Ast.AstNestS $ inlineAst memo v
Ast.AstUnNestS v -> second Ast.AstUnNestS $ inlineAst memo v

Ast.AstFromS stkz v -> second (Ast.AstFromS stkz) $ inlineAst memo v
Ast.AstSFromK t -> second Ast.AstSFromK (inlineAst memo t)
Ast.AstSFromR v -> second Ast.AstSFromR $ inlineAst memo v
Ast.AstSFromX v -> second Ast.AstSFromX $ inlineAst memo v

Ast.AstXNestR v -> second Ast.AstXNestR $ inlineAst memo v
Ast.AstXNestS v -> second Ast.AstXNestS $ inlineAst memo v
Ast.AstXNest v -> second Ast.AstXNest $ inlineAst memo v
Ast.AstXUnNestR v -> second Ast.AstXUnNestR $ inlineAst memo v
Ast.AstXUnNestS v -> second Ast.AstXUnNestS $ inlineAst memo v
Ast.AstXUnNest v -> second Ast.AstXUnNest $ inlineAst memo v

Ast.AstReplicate0NS sh stk v ->
second (Ast.AstReplicate0NS sh stk) (inlineAst memo v)
Ast.AstSum0S sh stk v ->
Expand Down Expand Up @@ -553,13 +548,8 @@ unshareAst memo = \case
Ast.AstSFromK t -> second Ast.AstSFromK (unshareAst memo t)
Ast.AstSFromR v -> second Ast.AstSFromR $ unshareAst memo v
Ast.AstSFromX v -> second Ast.AstSFromX $ unshareAst memo v

Ast.AstXNestR v -> second Ast.AstXNestR $ unshareAst memo v
Ast.AstXNestS v -> second Ast.AstXNestS $ unshareAst memo v
Ast.AstXNest v -> second Ast.AstXNest $ unshareAst memo v
Ast.AstXUnNestR v -> second Ast.AstXUnNestR $ unshareAst memo v
Ast.AstXUnNestS v -> second Ast.AstXUnNestS $ unshareAst memo v
Ast.AstXUnNest v -> second Ast.AstXUnNest $ unshareAst memo v
Ast.AstNestS v -> second Ast.AstNestS $ unshareAst memo v
Ast.AstUnNestS v -> second Ast.AstUnNestS $ unshareAst memo v

Ast.AstReplicate0NS sh stk v ->
second (Ast.AstReplicate0NS sh stk) (unshareAst memo v)
Expand Down
11 changes: 3 additions & 8 deletions src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import GHC.TypeLits (KnownNat)
import Type.Reflection (Typeable, typeRep)

import Data.Array.Mixed.Shape (withKnownShX)
import Data.Array.Nested (KnownShS (..), KnownShX (..), ListS (..))
import Data.Array.Nested (KnownShS (..), ListS (..))
import Data.Array.Nested.Internal.Shape (shsAppend, shsProduct, withKnownShS)

import HordeAd.Core.Ast
Expand Down Expand Up @@ -375,6 +375,8 @@ interpretAst !env = \case
AstReshapeS v -> sreshape (interpretAst env v)
AstZipS v -> szip $ interpretAst env v
AstUnzipS v -> sunzip $ interpretAst env v
AstNestS v -> snest knownShS $ interpretAst env v
AstUnNestS v -> sunNest $ interpretAst env v

AstFromS stkz v | Dict <- lemTensorKindOfSTK (ftkToStk (ftkAst v))
, Dict <- lemTensorKindOfSTK stkz ->
Expand All @@ -383,13 +385,6 @@ interpretAst !env = \case
AstSFromR v -> sfromR $ interpretAst env v
AstSFromX v -> sfromX $ interpretAst env v

AstXNestR v -> xnestR knownShX $ interpretAst env v
AstXNestS v -> xnestS knownShX $ interpretAst env v
AstXNest v -> xnest knownShX $ interpretAst env v
AstXUnNestR v -> xunNestR $ interpretAst env v
AstXUnNestS v -> xunNestS $ interpretAst env v
AstXUnNest v -> xunNest $ interpretAst env v

AstReplicate0NS sh stk v | Dict <- lemTensorKindOfSTK stk
, SNat <- shsProduct sh ->
withKnownShS sh $
Expand Down
31 changes: 9 additions & 22 deletions src/HordeAd/Core/AstPrettyPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,11 @@ import Data.Vector.Generic qualified as V
import GHC.Exts (IsList (..))
import GHC.TypeLits (fromSNat)

import Data.Array.Mixed.Shape
(ssxAppend, ssxFromShape, ssxReplicate, withKnownShX)
import Data.Array.Mixed.Shape qualified as X
import Data.Array.Mixed.Shape (StaticShX (..), listxRank)
import Data.Array.Nested
(KnownShS (..), KnownShX (..), ListS (..), ShR (..), ShS (..), ShX (..))
(KnownShS (..), ListS (..), ShR (..), ShS (..), ShX (..))
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Internal.Shape
(shCvtSX, shsAppend, shsRank, withKnownShS)
import Data.Array.Nested.Internal.Shape (shsAppend, shsRank, withKnownShS)

import HordeAd.Core.Ast
import HordeAd.Core.AstTools
Expand Down Expand Up @@ -83,8 +80,8 @@ printAstVar cfg var =
rankTensorKind (STKScalar _) = 0
rankTensorKind (STKR snat _) = fromInteger $ fromSNat snat
rankTensorKind (STKS sh _) = fromInteger $ fromSNat $ shsRank sh
rankTensorKind (STKX (X.StaticShX l) _) =
fromInteger $ fromSNat $ X.listxRank l
rankTensorKind (STKX (StaticShX l) _) =
fromInteger $ fromSNat $ listxRank l
rankTensorKind (STKProduct @y1 @z1 sy sz) =
rankTensorKind @y1 sy `max` rankTensorKind @z1 sz
n = rankTensorKind (stensorKind @y)
Expand Down Expand Up @@ -442,6 +439,10 @@ printAstAux cfg d = \case
printPrefixOp printAst cfg d "sreshape" [v]
AstZipS v -> printPrefixOp printAst cfg d "szip" [v]
AstUnzipS v -> printPrefixOp printAst cfg d "sunzip" [v]
AstNestS @sh1 @sh2 v ->
withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $
printPrefixOp printAst cfg d "snestS" [v]
AstUnNestS v -> printPrefixOp printAst cfg d "sunNestS" [v]

AstFromS stkz v | Dict <- lemTensorKindOfSTK (ftkToStk (ftkAst v)) ->
case stkz of
Expand All @@ -453,20 +454,6 @@ printAstAux cfg d = \case
AstSFromR v -> printPrefixOp printAst cfg d "sfromR" [v]
AstSFromX v -> printPrefixOp printAst cfg d "sfromX" [v]

AstXNestR @sh1 @m v ->
withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $
printPrefixOp printAst cfg d "xnestR" [v]
AstXNestS @sh1 @sh2 v ->
withKnownShX (knownShX @sh1
`ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $
printPrefixOp printAst cfg d "xnestS" [v]
AstXNest @sh1 @sh2 v ->
withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $
printPrefixOp printAst cfg d "xnest" [v]
AstXUnNestR v -> printPrefixOp printAst cfg d "xunNestR" [v]
AstXUnNestS v -> printPrefixOp printAst cfg d "xunNestS" [v]
AstXUnNest v -> printPrefixOp printAst cfg d "xunNest" [v]

AstReplicate0NS _sh stk v | Dict <- lemTensorKindOfSTK stk ->
printPrefixOp printAst cfg d "sreplicate0N" [v]
AstSum0S sh stk v | Dict <- lemTensorKindOfSTK stk ->
Expand Down
Loading

0 comments on commit 1079c10

Please sign in to comment.