Skip to content

Commit

Permalink
Add AD for snest and sunNest
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Nov 20, 2024
1 parent ad2ee2f commit 95d1420
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/HordeAd/Core/Delta.hs
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,12 @@ data Delta :: Target -> TensorKindType -> Type where
-- TODO: this is a haddock for Gather1; fix.
CastS :: (GoodScalar r1, RealFrac r1, GoodScalar r2, RealFrac r2, KnownShS sh)
=> Delta target (TKS sh r1) -> Delta target (TKS sh r2)
NestS :: (GoodScalar r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> Delta target (TKS (sh1 ++ sh2) r)
-> Delta target (TKS2 sh1 (TKS sh2 r))
UnNestS :: (GoodScalar r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> Delta target (TKS2 sh1 (TKS sh2 r))
-> Delta target (TKS (sh1 ++ sh2) r)
SFromR :: forall sh r target. (GoodScalar r, KnownShS sh, KnownNat (Rank sh))
=> Delta target (TKR (Rank sh) r)
-> Delta target (TKS sh r)
Expand Down Expand Up @@ -694,6 +700,8 @@ shapeDeltaFull = \case
ReshapeS{} -> FTKS knownShS FTKScalar
GatherS{} -> FTKS knownShS FTKScalar
CastS{} -> FTKS knownShS FTKScalar
NestS{} -> FTKS knownShS (FTKS knownShS FTKScalar)
UnNestS{} -> FTKS knownShS FTKScalar
SFromR{} -> FTKS knownShS FTKScalar
SFromH{} -> FTKS knownShS FTKScalar

Expand Down Expand Up @@ -1207,6 +1215,10 @@ evalSame !s !c = \case
CastS @r1 @_ @sh d ->
evalSRuntimeSpecialized s (toADTensorKindShared (stensorKind @(TKS sh r1))
$ scast c) d
NestS d ->
evalSame s (sunNest c) d
UnNestS d ->
evalSame s (snest knownShS c) d
SFromR @sh (RFromS @sh2 d) ->
case sameShape @sh @sh2 of
Just Refl -> evalSame s c d
Expand Down Expand Up @@ -1554,6 +1566,8 @@ fwdSame params s = \case
case sameTensorKind @(TKS sh r1) @(ADTensorKind (TKS sh r1)) of
Just Refl -> second scast $ fwdSame params s d
_ -> (s, repConstant 0 $ aDTensorKind $ shapeDeltaFull d0)
NestS d -> second (snest knownShS) $ fwdSame params s d
UnNestS d -> second sunNest $ fwdSame params s d
SFromR @sh (RFromS @sh2 d) ->
case sameShape @sh @sh2 of
Just Refl -> fwdSame params s d
Expand Down
3 changes: 3 additions & 0 deletions src/HordeAd/Core/OpsADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,9 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target)
let v = sfromIntegral u
in fromPrimalADVal v
sconcrete t = fromPrimalADVal (sconcrete t)
snest sh (D u u') | Dict <- Nested.Internal.Shape.shsKnownShS sh =
dD (snest sh u) (NestS u')
sunNest (D u u') = dD (sunNest u) (UnNestS u')
sfromR :: forall r sh. (GoodScalar r, KnownShS sh, KnownNat (Rank sh))
=> ADVal target (TKR (Rank sh) r) -> ADVal target (TKS sh r)
sfromR (D u u') = dDnotShared (sfromR u) (dSFromR u')
Expand Down

0 comments on commit 95d1420

Please sign in to comment.