Skip to content

Commit

Permalink
Trivialize nesting/unnesting in AST
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Jan 31, 2025
1 parent 7b0f7c0 commit fda17f1
Show file tree
Hide file tree
Showing 8 changed files with 284 additions and 335 deletions.
26 changes: 6 additions & 20 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,12 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
AstUnzipS :: (TensorKind y, TensorKind z, KnownShS sh)
=> AstTensor ms s (TKS2 sh (TKProduct y z))
-> AstTensor ms s (TKProduct (TKS2 sh y) (TKS2 sh z))
AstNestS :: (KnownShS sh1, KnownShS sh2, TensorKind x)
=> AstTensor ms s (TKS2 (sh1 ++ sh2) x)
-> AstTensor ms s (TKS2 sh1 (TKS2 sh2 x))
AstUnNestS :: (KnownShS sh1, KnownShS sh2, TensorKind x)
=> AstTensor ms s (TKS2 sh1 (TKS2 sh2 x))
-> AstTensor ms s (TKS2 (sh1 ++ sh2) x)

-- Conversions
AstFromS :: forall y z ms s.
Expand All @@ -413,26 +419,6 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
AstSFromX :: (KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', TensorKind r)
=> AstTensor ms s (TKX2 sh' r) -> AstTensor ms s (TKS2 sh r)

-- Nesting/unnesting
AstXNestR :: (KnownShX sh1, KnownNat m, TensorKind x)
=> AstTensor ms s (TKX2 (sh1 ++ Replicate m Nothing) x)
-> AstTensor ms s (TKX2 sh1 (TKR2 m x))
AstXNestS :: (KnownShX sh1, KnownShS sh2, TensorKind x)
=> AstTensor ms s (TKX2 (sh1 ++ MapJust sh2) x)
-> AstTensor ms s (TKX2 sh1 (TKS2 sh2 x))
AstXNest :: (KnownShX sh1, KnownShX sh2, TensorKind x)
=> AstTensor ms s (TKX2 (sh1 ++ sh2) x)
-> AstTensor ms s (TKX2 sh1 (TKX2 sh2 x))
AstXUnNestR :: (KnownShX sh1, KnownNat m, TensorKind x)
=> AstTensor ms s (TKX2 sh1 (TKR2 m x))
-> AstTensor ms s (TKX2 (sh1 ++ Replicate m Nothing) x)
AstXUnNestS :: (KnownShX sh1, KnownShS sh2, TensorKind x)
=> AstTensor ms s (TKX2 sh1 (TKS2 sh2 x))
-> AstTensor ms s (TKX2 (sh1 ++ MapJust sh2) x)
AstXUnNest :: (KnownShX sh1, KnownShX sh2, TensorKind x)
=> AstTensor ms s (TKX2 sh1 (TKX2 sh2 x))
-> AstTensor ms s (TKX2 (sh1 ++ sh2) x)

-- Backend-specific primitives
AstReplicate0NS :: ShS sh -> STensorKindType x
-> AstTensor ms s (TKS2 '[] x)
Expand Down
18 changes: 4 additions & 14 deletions src/HordeAd/Core/AstInline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -265,19 +265,14 @@ inlineAst memo v0 = case v0 of
Ast.AstReshapeS v -> second Ast.AstReshapeS (inlineAst memo v)
Ast.AstZipS v -> second Ast.AstZipS (inlineAst memo v)
Ast.AstUnzipS v -> second Ast.AstUnzipS (inlineAst memo v)
Ast.AstNestS v -> second Ast.AstNestS $ inlineAst memo v
Ast.AstUnNestS v -> second Ast.AstUnNestS $ inlineAst memo v

Ast.AstFromS stkz v -> second (Ast.AstFromS stkz) $ inlineAst memo v
Ast.AstSFromK t -> second Ast.AstSFromK (inlineAst memo t)
Ast.AstSFromR v -> second Ast.AstSFromR $ inlineAst memo v
Ast.AstSFromX v -> second Ast.AstSFromX $ inlineAst memo v

Ast.AstXNestR v -> second Ast.AstXNestR $ inlineAst memo v
Ast.AstXNestS v -> second Ast.AstXNestS $ inlineAst memo v
Ast.AstXNest v -> second Ast.AstXNest $ inlineAst memo v
Ast.AstXUnNestR v -> second Ast.AstXUnNestR $ inlineAst memo v
Ast.AstXUnNestS v -> second Ast.AstXUnNestS $ inlineAst memo v
Ast.AstXUnNest v -> second Ast.AstXUnNest $ inlineAst memo v

Ast.AstReplicate0NS sh stk v ->
second (Ast.AstReplicate0NS sh stk) (inlineAst memo v)
Ast.AstSum0S sh stk v ->
Expand Down Expand Up @@ -553,13 +548,8 @@ unshareAst memo = \case
Ast.AstSFromK t -> second Ast.AstSFromK (unshareAst memo t)
Ast.AstSFromR v -> second Ast.AstSFromR $ unshareAst memo v
Ast.AstSFromX v -> second Ast.AstSFromX $ unshareAst memo v

Ast.AstXNestR v -> second Ast.AstXNestR $ unshareAst memo v
Ast.AstXNestS v -> second Ast.AstXNestS $ unshareAst memo v
Ast.AstXNest v -> second Ast.AstXNest $ unshareAst memo v
Ast.AstXUnNestR v -> second Ast.AstXUnNestR $ unshareAst memo v
Ast.AstXUnNestS v -> second Ast.AstXUnNestS $ unshareAst memo v
Ast.AstXUnNest v -> second Ast.AstXUnNest $ unshareAst memo v
Ast.AstNestS v -> second Ast.AstNestS $ unshareAst memo v
Ast.AstUnNestS v -> second Ast.AstUnNestS $ unshareAst memo v

Ast.AstReplicate0NS sh stk v ->
second (Ast.AstReplicate0NS sh stk) (unshareAst memo v)
Expand Down
9 changes: 2 additions & 7 deletions src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ interpretAst !env = \case
AstReshapeS v -> sreshape (interpretAst env v)
AstZipS v -> szip $ interpretAst env v
AstUnzipS v -> sunzip $ interpretAst env v
AstNestS v -> snest knownShS $ interpretAst env v
AstUnNestS v -> sunNest $ interpretAst env v

AstFromS stkz v | Dict <- lemTensorKindOfSTK (ftkToStk (ftkAst v))
, Dict <- lemTensorKindOfSTK stkz ->
Expand All @@ -383,13 +385,6 @@ interpretAst !env = \case
AstSFromR v -> sfromR $ interpretAst env v
AstSFromX v -> sfromX $ interpretAst env v

AstXNestR v -> xnestR knownShX $ interpretAst env v
AstXNestS v -> xnestS knownShX $ interpretAst env v
AstXNest v -> xnest knownShX $ interpretAst env v
AstXUnNestR v -> xunNestR $ interpretAst env v
AstXUnNestS v -> xunNestS $ interpretAst env v
AstXUnNest v -> xunNest $ interpretAst env v

AstReplicate0NS sh stk v | Dict <- lemTensorKindOfSTK stk
, SNat <- shsProduct sh ->
withKnownShS sh $
Expand Down
18 changes: 4 additions & 14 deletions src/HordeAd/Core/AstPrettyPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,10 @@ printAstAux cfg d = \case
printPrefixOp printAst cfg d "sreshape" [v]
AstZipS v -> printPrefixOp printAst cfg d "szip" [v]
AstUnzipS v -> printPrefixOp printAst cfg d "sunzip" [v]
AstNestS @sh1 @sh2 v ->
withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $
printPrefixOp printAst cfg d "snestS" [v]
AstUnNestS v -> printPrefixOp printAst cfg d "sunNestS" [v]

AstFromS stkz v | Dict <- lemTensorKindOfSTK (ftkToStk (ftkAst v)) ->
case stkz of
Expand All @@ -453,20 +457,6 @@ printAstAux cfg d = \case
AstSFromR v -> printPrefixOp printAst cfg d "sfromR" [v]
AstSFromX v -> printPrefixOp printAst cfg d "sfromX" [v]

AstXNestR @sh1 @m v ->
withKnownShX (knownShX @sh1 `ssxAppend` ssxReplicate (SNat @m)) $
printPrefixOp printAst cfg d "xnestR" [v]
AstXNestS @sh1 @sh2 v ->
withKnownShX (knownShX @sh1
`ssxAppend` ssxFromShape (shCvtSX (knownShS @sh2))) $
printPrefixOp printAst cfg d "xnestS" [v]
AstXNest @sh1 @sh2 v ->
withKnownShX (knownShX @sh1 `ssxAppend` knownShX @sh2) $
printPrefixOp printAst cfg d "xnest" [v]
AstXUnNestR v -> printPrefixOp printAst cfg d "xunNestR" [v]
AstXUnNestS v -> printPrefixOp printAst cfg d "xunNestS" [v]
AstXUnNest v -> printPrefixOp printAst cfg d "xunNest" [v]

AstReplicate0NS _sh stk v | Dict <- lemTensorKindOfSTK stk ->
printPrefixOp printAst cfg d "sreplicate0N" [v]
AstSum0S sh stk v | Dict <- lemTensorKindOfSTK stk ->
Expand Down
Loading

0 comments on commit fda17f1

Please sign in to comment.