Skip to content

Commit

Permalink
Add snest and sunNest to AST
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Nov 20, 2024
1 parent 8f010e4 commit ad2ee2f
Show file tree
Hide file tree
Showing 10 changed files with 84 additions and 6 deletions.
8 changes: 8 additions & 0 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,14 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType -> Type wh
=> Nested.Shaped sh r -> AstTensor ms PrimalSpan (TKS sh r)
AstProjectS :: (GoodScalar r, KnownShS sh)
=> AstTensor ms s TKUntyped -> Int -> AstTensor ms s (TKS sh r)
AstNestS :: forall r sh1 sh2 ms s.
(GoodScalar r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> AstTensor ms s (TKS (sh1 ++ sh2) r)
-> AstTensor ms s (TKS2 sh1 (TKS sh2 r))
AstUnNestS :: forall r sh1 sh2 ms s.
(GoodScalar r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> AstTensor ms s (TKS2 sh1 (TKS sh2 r))
-> AstTensor ms s (TKS (sh1 ++ sh2) r)
AstSFromR :: (KnownShS sh, KnownNat (Rank sh), GoodScalar r)
=> AstTensor ms s (TKR (Rank sh) r) -> AstTensor ms s (TKS sh r)

Expand Down
4 changes: 4 additions & 0 deletions src/HordeAd/Core/AstInline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,8 @@ inlineAst memo v0 = case v0 of
Ast.AstProjectS l p ->
let (memo1, l2) = inlineAst memo l
in (memo1, Ast.AstProjectS l2 p)
Ast.AstNestS v -> second Ast.AstNestS $ inlineAst memo v
Ast.AstUnNestS v -> second Ast.AstUnNestS $ inlineAst memo v
Ast.AstSFromR v -> second Ast.AstSFromR $ inlineAst memo v

Ast.AstMinIndexX a -> second Ast.AstMinIndexX $ inlineAst memo a
Expand Down Expand Up @@ -608,6 +610,8 @@ unshareAst memo = \case
Ast.AstProjectS l p ->
let (memo1, l2) = unshareAst memo l
in (memo1, Ast.AstProjectS l2 p)
Ast.AstNestS v -> second Ast.AstNestS $ unshareAst memo v
Ast.AstUnNestS v -> second Ast.AstUnNestS $ unshareAst memo v
Ast.AstSFromR v -> second Ast.AstSFromR $ unshareAst memo v

Ast.AstMinIndexX a -> second Ast.AstMinIndexX $ unshareAst memo a
Expand Down
2 changes: 2 additions & 0 deletions src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,8 @@ interpretAst !env = \case
let lt = interpretAst env l
in tlet @_ @TKUntyped lt
(\lw -> sfromD $ dunHVector lw V.! p)
AstNestS v -> snest knownShS $ interpretAst env v
AstUnNestS v -> sunNest $ interpretAst env v
AstSFromR v -> sfromR $ interpretAst env v

AstMinIndexX _v -> error "TODO"
Expand Down
2 changes: 2 additions & 0 deletions src/HordeAd/Core/AstPrettyPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,8 @@ printAstAux cfg d = \case
. printAst cfg 11 l
. showString " "
. shows p
AstNestS v -> printPrefixOp printAst cfg d "snest" [v]
AstUnNestS v -> printPrefixOp printAst cfg d "sunNest" [v]
AstSFromR v -> printPrefixOp printAst cfg d "sfromR" [v]

AstMkHVector l ->
Expand Down
44 changes: 43 additions & 1 deletion src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ module HordeAd.Core.AstSimplify
, astReverse, astReverseS
, astTranspose, astTransposeS, astReshape, astReshapeS
, astCast, astCastS, astFromIntegral, astFromIntegralS
, astProject1, astProject2, astProjectR, astProjectS, astRFromS, astSFromR
, astProject1, astProject2, astProjectR, astProjectS, astNestS, astUnNestS
, astRFromS, astSFromR
, astPrimalPart, astDualPart
, astLetHVectorIn, astHApply, astLetFun
-- * The simplifying bottom-up pass
Expand Down Expand Up @@ -355,6 +356,8 @@ astNonIndexStep t = case t of
Ast.AstFromIntegralS v -> astFromIntegralS v
AstConcreteS{} -> t
Ast.AstProjectS l p -> astProjectS l p
Ast.AstNestS v -> astNestS v
Ast.AstUnNestS v -> astUnNestS v
Ast.AstSFromR v -> astSFromR v
_ -> t -- TODO

Expand Down Expand Up @@ -814,6 +817,9 @@ astIndexKnobsS knobs v0 ix@((:.$) @in1 i1 (rest1 :: AstIxS AstMethodLet shm1)) |
Ast.AstProjectS{} -> Ast.AstIndexS v0 ix
Ast.AstLetHVectorIn vars l v ->
astLetHVectorIn vars l (astIndexRec v ix)
-- TODO: generalize AstIndexS? Ast.AstNestS v -> astNestS (astIndexRec v ix)
-- TODO: hard: Ast.AstUnNestS v -> astUnNestS (astIndexRec v ix)
Ast.AstUnNestS _ -> Ast.AstIndexS v0 ix
Ast.AstSFromR t ->
withListSh (Proxy @shn) $ \_ ->
withListSh (Proxy @shm) $ \_ ->
Expand Down Expand Up @@ -2098,6 +2104,32 @@ astProjectS l p = case l of
Ast.AstCond b v1 v2 -> Ast.AstCond b (astProjectS v1 p) (astProjectS v2 p)
_ -> Ast.AstProjectS l p

astNestS
:: forall r sh1 sh2 ms s.
(GoodScalar r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2), AstSpan s)
=> AstTensor ms s (TKS (sh1 ++ sh2) r)
-> AstTensor ms s (TKS2 sh1 (TKS sh2 r))
astNestS t = case t of
Ast.AstLet var u2 d2 -> -- TODO: good idea?
astLet var u2 (astNestS d2)
Ast.AstFromPrimal u -> Ast.AstFromPrimal $ astNestS u
Ast.AstCond b v1 v2 -> Ast.AstCond b (astNestS v1) (astNestS v2) -- TODO: ??
-- TODO: when sh agrees: Ast.AstUnNestS u -> u
_ -> Ast.AstNestS t

astUnNestS
:: forall r sh1 sh2 ms s.
(GoodScalar r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2), AstSpan s)
=> AstTensor ms s (TKS2 sh1 (TKS sh2 r))
-> AstTensor ms s (TKS (sh1 ++ sh2) r)
astUnNestS t = case t of
Ast.AstLet var u2 d2 -> -- TODO: good idea?
astLet var u2 (astUnNestS d2)
Ast.AstFromPrimal u -> Ast.AstFromPrimal $ astUnNestS u
Ast.AstCond b v1 v2 -> Ast.AstCond b (astUnNestS v1) (astUnNestS v2) -- TODO: ??
Ast.AstNestS u -> u
_ -> Ast.AstUnNestS t

astRFromS :: forall sh s r. (GoodScalar r, KnownShS sh)
=> AstTensor AstMethodLet s (TKS sh r) -> AstTensor AstMethodLet s (TKR (Rank sh) r)
astRFromS (AstConcreteS t) =
Expand Down Expand Up @@ -2180,6 +2212,8 @@ astPrimalPart t = case t of
Ast.AstGatherS v (vars, ix) -> astGatherS (astPrimalPart v) (vars, ix)
Ast.AstCastS v -> astCastS $ astPrimalPart v
Ast.AstProjectS l p -> astProjectS (astPrimalPart l) p
Ast.AstNestS v -> astNestS $ astPrimalPart v
Ast.AstUnNestS v -> astUnNestS $ astPrimalPart v
Ast.AstSFromR v -> astSFromR $ astPrimalPart v

Ast.AstMkHVector{} -> Ast.AstPrimalPart t -- TODO
Expand Down Expand Up @@ -2254,6 +2288,8 @@ astDualPart t = case t of
Ast.AstGatherS v (vars, ix) -> astGatherS (astDualPart v) (vars, ix)
Ast.AstCastS v -> astCastS $ astDualPart v
Ast.AstProjectS l p -> astProjectS (astDualPart l) p
Ast.AstNestS v -> astNestS $ astDualPart v
Ast.AstUnNestS v -> astUnNestS $ astDualPart v
Ast.AstSFromR v -> astSFromR $ astDualPart v

Ast.AstMkHVector{} -> Ast.AstDualPart t -- TODO
Expand Down Expand Up @@ -2522,6 +2558,8 @@ simplifyAst t = case t of
Ast.AstFromIntegralS v -> astFromIntegralS $ simplifyAst v
AstConcreteS{} -> t
Ast.AstProjectS l p -> astProjectS (simplifyAst l) p
Ast.AstNestS v -> astNestS $ simplifyAst v
Ast.AstUnNestS v -> astUnNestS $ simplifyAst v
Ast.AstSFromR v -> astSFromR $ simplifyAst v

Ast.AstMkHVector l -> Ast.AstMkHVector $ V.map simplifyAstDynamic l
Expand Down Expand Up @@ -2741,6 +2779,8 @@ expandAst t = case t of
Ast.AstFromIntegralS v -> astFromIntegralS $ expandAst v
AstConcreteS{} -> t
Ast.AstProjectS l p -> astProjectS (expandAst l) p
Ast.AstNestS v -> astNestS $ expandAst v
Ast.AstUnNestS v -> astUnNestS $ expandAst v
Ast.AstSFromR v -> astSFromR $ expandAst v

Ast.AstMkHVector l -> Ast.AstMkHVector $ V.map expandAstDynamic l
Expand Down Expand Up @@ -3279,6 +3319,8 @@ substitute1Ast i var v1 = case v1 of
case substitute1Ast i var l of
Nothing -> Nothing
ml -> Just $ astProjectS (fromMaybe l ml) p
Ast.AstNestS v -> astNestS <$> substitute1Ast i var v
Ast.AstUnNestS v -> astUnNestS <$> substitute1Ast i var v
Ast.AstSFromR v -> astSFromR <$> substitute1Ast i var v

Ast.AstMkHVector args ->
Expand Down
4 changes: 4 additions & 0 deletions src/HordeAd/Core/AstTools.hs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ ftkAst t = case t of
AstFromIntegralS{} -> FTKS knownShS FTKScalar
AstConcreteS{} -> FTKS knownShS FTKScalar
AstProjectS{} -> FTKS knownShS FTKScalar
AstNestS{} -> FTKS knownShS (FTKS knownShS FTKScalar)
AstUnNestS{} -> FTKS knownShS FTKScalar
AstSFromR{} -> FTKS knownShS FTKScalar

AstMkHVector v ->
Expand Down Expand Up @@ -253,6 +255,8 @@ varInAst var = \case
AstFromIntegralS a -> varInAst var a
AstConcreteS{} -> False
AstProjectS l _p -> varInAst var l -- conservative
AstNestS v -> varInAst var v
AstUnNestS v -> varInAst var v
AstSFromR v -> varInAst var v

AstMinIndexX a -> varInAst var a
Expand Down
2 changes: 2 additions & 0 deletions src/HordeAd/Core/AstVectorize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,8 @@ build1V snat@SNat (var, v00) =

Ast.AstProjectS l p ->
astProjectS (build1VOccurenceUnknown snat (var, l)) p
Ast.AstNestS v -> astNestS $ build1V snat (var, v)
Ast.AstUnNestS v -> astUnNestS $ build1V snat (var, v)
Ast.AstSFromR v -> astSFromR $ build1V snat (var, v)

Ast.AstMkHVector l -> traceRule $
Expand Down
12 changes: 12 additions & 0 deletions src/HordeAd/Core/OpsAst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import Data.Vector.Generic qualified as V
import GHC.TypeLits (KnownNat, Nat)

import Data.Array.Nested (IShR, KnownShS (..))
import Data.Array.Nested.Internal.Shape qualified as Nested.Internal.Shape

import HordeAd.Core.Adaptor
import HordeAd.Core.Ast
Expand Down Expand Up @@ -426,6 +427,8 @@ instance AstSpan s => BaseTensor (AstTensor AstMethodLet s) where
scast = astCastS
sfromIntegral = fromPrimal . astFromIntegralS . astSpanPrimal
sconcrete = fromPrimal . AstConcreteS
snest sh | Dict <- Nested.Internal.Shape.shsKnownShS sh = astNestS
sunNest = astUnNestS
sfromR = astSFromR

sfromPrimal = fromPrimal
Expand Down Expand Up @@ -683,6 +686,9 @@ instance AstSpan s => BaseTensor (AstRaw s) where
sfromIntegral = AstRaw . fromPrimal . AstFromIntegralS
. astSpanPrimalRaw . unAstRaw
sconcrete = AstRaw . fromPrimal . AstConcreteS
snest sh | Dict <- Nested.Internal.Shape.shsKnownShS sh =
AstRaw . AstNestS . unAstRaw
sunNest = AstRaw . AstUnNestS . unAstRaw
sfromR = AstRaw . AstSFromR . unAstRaw

sfromPrimal = AstRaw . fromPrimal . unAstRaw
Expand Down Expand Up @@ -910,6 +916,9 @@ instance AstSpan s => BaseTensor (AstNoVectorize s) where
scast = AstNoVectorize . scast . unAstNoVectorize
sfromIntegral = AstNoVectorize . sfromIntegral . unAstNoVectorize
sconcrete = AstNoVectorize . sconcrete
snest sh | Dict <- Nested.Internal.Shape.shsKnownShS sh =
AstNoVectorize . astNestS . unAstNoVectorize
sunNest = AstNoVectorize . astUnNestS . unAstNoVectorize
sfromR = AstNoVectorize . sfromR . unAstNoVectorize
sfromPrimal = AstNoVectorize . sfromPrimal . unAstNoVectorize
sprimalPart = AstNoVectorize . sprimalPart . unAstNoVectorize
Expand Down Expand Up @@ -1140,6 +1149,9 @@ instance AstSpan s => BaseTensor (AstNoSimplify s) where
sfromIntegral = AstNoSimplify . fromPrimal . AstFromIntegralS
. astSpanPrimal . unAstNoSimplify
sconcrete = AstNoSimplify . fromPrimal . AstConcreteS
snest sh | Dict <- Nested.Internal.Shape.shsKnownShS sh =
AstNoSimplify . AstNestS . unAstNoSimplify
sunNest = AstNoSimplify . AstUnNestS . unAstNoSimplify
sfromR = AstNoSimplify . AstSFromR . unAstNoSimplify
sfromPrimal = AstNoSimplify . fromPrimal . unAstNoSimplify
-- exceptionally we do simplify AstFromPrimal to avoid long boring chains
Expand Down
2 changes: 1 addition & 1 deletion src/HordeAd/Core/OpsConcrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ instance BaseTensor RepN where
scast = RepN . tcastS . unRepN
sfromIntegral = RepN . tfromIntegralS . unRepN
sconcrete = RepN
sfromR = RepN . flip Nested.rcastToShaped knownShS . unRepN
snest shs t = RepN $ Nested.snest shs $ unRepN t
sunNest t = RepN $ Nested.sunNest $ unRepN t
sfromR = RepN . flip Nested.rcastToShaped knownShS . unRepN

sscaleByScalar s v =
RepN $ tscaleByScalarS (tunScalarS $ unRepN s) (unRepN v)
Expand Down
10 changes: 6 additions & 4 deletions src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -758,12 +758,14 @@ class ( Num (IntOf target)
sfromIntegral :: (GoodScalar r1, Integral r1, GoodScalar r2, KnownShS sh)
=> target (TKS sh r1) -> target (TKS sh r2)
sconcrete :: (GoodScalar r, KnownShS sh) => Nested.Shaped sh r -> target (TKS sh r)
snest :: forall sh1 sh2 r.
(GoodScalar r, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> ShS sh1 -> target (TKS (sh1 ++ sh2) r) -> target (TKS2 sh1 (TKS sh2 r))
sunNest :: forall sh1 sh2 r.
(GoodScalar r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> target (TKS2 sh1 (TKS sh2 r)) -> target (TKS (sh1 ++ sh2) r)
sfromR :: (GoodScalar r, KnownShS sh, KnownNat (Rank sh))
=> target (TKR (Rank sh) r) -> target (TKS sh r)
snest :: forall sh sh' r. GoodScalar r
=> ShS sh -> target (TKS (sh ++ sh') r) -> target (TKS2 sh (TKS sh' r))
sunNest :: forall sh sh' r. GoodScalar r
=> target (TKS2 sh (TKS sh' r)) -> target (TKS (sh ++ sh') r)

-- ** No serviceable parts beyond this point ** --

Expand Down

0 comments on commit ad2ee2f

Please sign in to comment.