Skip to content

Commit

Permalink
Harden rfromD
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 27, 2023
1 parent bf61c42 commit b489e66
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 49 deletions.
40 changes: 18 additions & 22 deletions src/HordeAd/Core/TensorADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ instance ( Dual shaped ~ DeltaS ranked shaped
sD ast (Pair (Clown (Const l)) delta) =
let (l2, r) = sletUnwrap ast
in dD (l `mergeADShare` l2) r delta
sScale ast (Pair (Clown (Const l)) delta) =
sScale ast (Pair (Clown (Const l)) delta) =
let (l2, r) = sletUnwrap ast
in Pair (Clown (Const (l `mergeADShare` l2))) (dScale r delta)

Expand Down Expand Up @@ -411,12 +411,14 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped)
-> ranked rn n
df cx (ca, x, a) =
fst $ cfwdOnDomains (V.fromList [DynamicRanked x, DynamicRanked a])
g (V.fromList [DynamicRanked cx, DynamicRanked ca])
g
(V.fromList [DynamicRanked cx, DynamicRanked ca])
rf :: ranked rn n -> (ranked rn n, ranked rm m)
-> (ranked rn n, ranked rm m)
rf dt (x, a) =
domsToPair $ dunDomains @ranked domsOD $ fst
$ crevOnDomains (Just dt) g (V.fromList [DynamicRanked x, DynamicRanked a])
$ crevOnDomains (Just dt) g
(V.fromList [DynamicRanked x, DynamicRanked a])
in D (l1 `mergeADShare` l2)
(rfold @ranked f x0 as)
(FoldR f x0 as df rf x0' as')
Expand All @@ -437,22 +439,18 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped)
let shn = rshape x0
shm = tailShape $ rshape as
domsOD = V.fromList [odFromSh @rn shn, odFromSh @rm shm]
domsToPair :: forall f. ADReady f
=> Domains f -> (f rn n, f rm m)
domsToPair doms = (rfromD $ doms V.! 0, rfromD $ doms V.! 1)
-- Note that this function, and similarly @f@ and @rf@ instantiated
-- and passed to FoldR, is not a function on dual numbers.
df :: ranked rn n -> (ranked rm m, ranked rn n, ranked rm m)
-> ranked rn n
df cx (ca, x, a) = df0 cx ca x a
rf :: ranked rn n -> (ranked rn n, ranked rm m)
-> (ranked rn n, ranked rm m)
rf cx (x, a) =
let res = rf0 cx x a -- non-explicit sharing, so helps little
in ( rletDomainsIn
domsOD res
(\doms -> rfromD $ doms V.! 0)
, rletDomainsIn
domsOD res
(\doms -> rfromD $ doms V.! 1)
)
rf cx (x, a) = -- TODO: add explicit sharing
domsToPair $ dunDomains domsOD $ rf0 cx x a
in D (l1 `mergeADShare` l2)
(rfold @ranked f x0 as)
(FoldR f x0 as df rf x0' as')
Expand All @@ -473,12 +471,14 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped)
-> shaped rn sh
df cx (ca, x, a) =
fst $ cfwdOnDomains (V.fromList [DynamicShaped x, DynamicShaped a])
g (V.fromList [DynamicShaped cx, DynamicShaped ca])
g
(V.fromList [DynamicShaped cx, DynamicShaped ca])
rf :: shaped rn sh -> (shaped rn sh, shaped rm shm)
-> (shaped rn sh, shaped rm shm)
rf dt (x, a) =
domsToPair $ dunDomains @ranked domsOD $ fst
$ crevOnDomains (Just dt) g (V.fromList [DynamicShaped x, DynamicShaped a])
$ crevOnDomains (Just dt) g
(V.fromList [DynamicShaped x, DynamicShaped a])
in D (l1 `mergeADShare` l2)
(sfold @ranked f x0 as)
(FoldS f x0 as df rf x0' as')
Expand All @@ -496,6 +496,9 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped)
-> ADVal shaped rn sh
sfoldDer f df0 rf0 (D l1 x0 x0') (D l2 as as') =
let domsOD = V.fromList [odFromShS @rn @sh, odFromShS @rm @shm]
domsToPair :: forall f. ADReadyS f
=> Domains (RankedOf f) -> (f rn sh, f rm shm)
domsToPair doms = (sfromD $ doms V.! 0, sfromD $ doms V.! 1)
-- Note that this function, and similarly @f@ and @rf@ instantiated
-- and passed to FoldS, is not a function on dual numbers.
df :: shaped rn sh -> (shaped rm shm, shaped rn sh, shaped rm shm)
Expand All @@ -504,14 +507,7 @@ instance ( ADReady ranked, ADReadySmall (ADVal ranked) (ADVal shaped)
rf :: shaped rn sh -> (shaped rn sh, shaped rm shm)
-> (shaped rn sh, shaped rm shm)
rf cx (x, a) =
let res = rf0 cx x a -- non-explicit sharing, so helps little
in ( sletDomainsIn
domsOD res
(\doms -> sfromD $ doms V.! 0)
, sletDomainsIn
domsOD res
(\doms -> sfromD $ doms V.! 1)
)
domsToPair $ dunDomains domsOD $ rf0 cx x a
in D (l1 `mergeADShare` l2)
(sfold @ranked f x0 as)
(FoldS f x0 as df rf x0' as')
Expand Down
14 changes: 4 additions & 10 deletions src/HordeAd/Core/TensorAst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,6 @@ instance (GoodScalar r, Sh.Shape sh)

-- * Reverse and forward derivative stages instances

-- TODO: it's not clear if the instance should be of Clown OD.Array or of
-- DomainsOD, for which we already have unletAstDomains6, etc.;
-- let's wait until we have rev as a function of Tensor class in case
-- that affects rev and/or Delta
--instance DerivativeStages @() (Clown OD.Array) where
-- revEvalArtifact = undefined
-- revProduceArtifact = undefined

instance DerivativeStages (AstRanked FullSpan) where
forwardPassByInterpretation
:: (GoodScalar r, KnownNat n)
Expand Down Expand Up @@ -155,7 +147,8 @@ instance DerivativeStages (AstRanked FullSpan) where

fwdArtifactFromForwardPass
:: forall r n. (GoodScalar r, KnownNat n)
=> TensorToken (AstRanked FullSpan) -> (Domains (AstRanked PrimalSpan)
=> TensorToken (AstRanked FullSpan)
-> (Domains (AstRanked PrimalSpan)
-> [AstDynamicVarName]
-> Domains (AstRanked FullSpan)
-> ADVal (AstRanked PrimalSpan) r n)
Expand Down Expand Up @@ -252,7 +245,8 @@ instance DerivativeStages (AstShaped FullSpan) where

fwdArtifactFromForwardPass
:: forall r sh. (GoodScalar r, Sh.Shape sh)
=> TensorToken (AstShaped FullSpan) -> (Domains (AstRanked PrimalSpan)
=> TensorToken (AstShaped FullSpan)
-> (Domains (AstRanked PrimalSpan)
-> [AstDynamicVarName]
-> Domains (AstRanked FullSpan)
-> ADVal (AstShaped PrimalSpan) r sh)
Expand Down
20 changes: 4 additions & 16 deletions src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -609,38 +609,26 @@ rfromD (DynamicRanked @r2 @n2 t) = case sameNat (Proxy @n2) (Proxy @n) of
Just Refl -> t
_ -> error "rfromD: type mismatch"
_ -> error "rfromD: rank mismatch"
rfromD (DynamicShaped @r2 @sh2 t) = case matchingRank @sh2 @n of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> rfromS t
_ -> error "rfromD: type mismatch"
_ -> error "rfromD: rank mismatch"
rfromD DynamicShaped{} = error "rfromD: unexpected DynamicShaped"
rfromD (DynamicRankedDummy @r2 @sh2 _ _) = case matchingRank @sh2 @n of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> rfromS @_ @r2 @sh2 0
_ -> error "rfromD: type mismatch"
_ -> error "rfromD: rank mismatch"
rfromD (DynamicShapedDummy @r2 @sh2 _ _) = case matchingRank @sh2 @n of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> rfromS @ranked @r2 @sh2 0
_ -> error "rfromD: type mismatch"
_ -> error "rfromD: rank mismatch"
rfromD DynamicShapedDummy{} = error "rfromD: unexpected DynamicShapedDummy"

sfromD :: forall shaped r sh.
( ShapedTensor shaped
, GoodScalar r, Sh.Shape sh
, ShapedOf (RankedOf shaped) ~ shaped )
=> DynamicTensor (RankedOf shaped) -> shaped r sh
sfromD (DynamicRanked @r2 @n2 t) = case matchingRank @sh @n2 of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> sfromR t
_ -> error "sfromD: type mismatch"
_ -> error "sfromD: rank mismatch"
sfromD DynamicRanked{} = error "sfromD: unexpected DynamicRanked"
sfromD (DynamicShaped @r2 @sh2 t) = case sameShape @sh2 @sh of
Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of
Just Refl -> t
_ -> error "sfromD: type mismatch"
_ -> error "sfromD: shape mismatch"
sfromD DynamicRankedDummy{} = 0
sfromD DynamicRankedDummy{} = error "sfromD: unexpected DynamicRankedDummy"
sfromD DynamicShapedDummy{} = 0


Expand Down
2 changes: 1 addition & 1 deletion test/simplified/TestAdaptorSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2431,7 +2431,7 @@ testSin0Fold18SrevPP = do
(sreplicate @_ @2 a0)
in rfromS . f . sfromR) 1.1
printAstPretty IM.empty (simplifyAst6 a1)
@?= "sconst 2.0 * ssum (ssum (sletDomainsIn (let x68 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])) ; v69 = ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1)))))) ; v70 = sreplicate (sin x68) ; v71 = recip (v69 * v69 + sconst (fromList @[2] [0.0,0.0]) + v70 * v70 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v79 = ssum (stranspose (sletDomainsIn (let x74 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [1])) ; v75 = ssum (stranspose (sin (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))))) ; v76 = sreplicate (sin x74) ; v77 = recip (v75 * v75 + sconst (fromList @[2] [0.0,0.0]) + v76 * v76 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v78 = sconstant (ssum (stranspose (rreplicate 2 (rreplicate 5 (rconst 1.0))))) in (cos (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))) * stranspose (sreplicate ((v76 * v77) * v78)), ssum (sreplicate (cos x74 * ssum (negate (v75 * v77) * v78))))) (\\[m72, x73] -> m72))) in (cos (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1)))) * stranspose (sreplicate ((v70 * v71) * v79)), ssum (sreplicate (cos x68 * ssum (negate (v69 * v71) * v79))))) (\\[m66, x67] -> m66))) + ssum (sfromList [sletDomainsIn (let x86 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])) ; v87 = ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1)))))) ; v88 = sreplicate (sin x86) ; v89 = recip (v87 * v87 + sconst (fromList @[2] [0.0,0.0]) + v88 * v88 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + 
sconst (fromList @[2] [0.0,0.0])) ; v97 = ssum (stranspose (sletDomainsIn (let x92 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [1])) ; v93 = ssum (stranspose (sin (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))))) ; v94 = sreplicate (sin x92) ; v95 = recip (v93 * v93 + sconst (fromList @[2] [0.0,0.0]) + v94 * v94 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v96 = sconstant (ssum (stranspose (rreplicate 2 (rreplicate 5 (rconst 1.0))))) in (cos (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))) * stranspose (sreplicate ((v94 * v95) * v96)), ssum (sreplicate (cos x92 * ssum (negate (v93 * v95) * v96))))) (\\[m90, x91] -> m90))) in (cos (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1)))) * stranspose (sreplicate ((v88 * v89) * v97)), ssum (sreplicate (cos x86 * ssum (negate (v87 * v89) * v97))))) (\\[m84, x85] -> x85), sletDomainsIn (let x102 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [1])) ; v103 = ssum (stranspose (sin (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))))) ; v104 = sreplicate (sin x102) ; v105 = recip (v103 * v103 + sconst (fromList @[2] [0.0,0.0]) + v104 * v104 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v106 = sconstant (ssum (stranspose (rreplicate 2 (rreplicate 5 (rconst 1.0))))) in (cos (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate 
(sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))) * stranspose (sreplicate ((v104 * v105) * v106)), ssum (sreplicate (cos x102 * ssum (negate (v103 * v105) * v106))))) (\\[m100, x101] -> x101)])"
@?= "sconst 2.0 * ssum (ssum (sletDomainsIn (let x72 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])) ; v73 = ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1)))))) ; v74 = sreplicate (sin x72) ; v75 = recip (v73 * v73 + sconst (fromList @[2] [0.0,0.0]) + v74 * v74 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v83 = ssum (stranspose (sletDomainsIn (let x78 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [1])) ; v79 = ssum (stranspose (sin (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))))) ; v80 = sreplicate (sin x78) ; v81 = recip (v79 * v79 + sconst (fromList @[2] [0.0,0.0]) + v80 * v80 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v82 = sconstant (ssum (stranspose (rreplicate 2 (rreplicate 5 (rconst 1.0))))) in (cos (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))) * stranspose (sreplicate ((v80 * v81) * v82)), ssum (sreplicate (cos x78 * ssum (negate (v79 * v81) * v82))))) (\\[m76, x77] -> m76))) in (cos (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1)))) * stranspose (sreplicate ((v74 * v75) * v83)), ssum (sreplicate (cos x72 * ssum (negate (v73 * v75) * v83))))) (\\[m70, x71] -> m70))) + ssum (sfromList [sletDomainsIn (let x88 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])) ; v89 = ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1)))))) ; v90 = sreplicate (sin x88) ; v91 = recip (v89 * v89 + sconst (fromList @[2] [0.0,0.0]) + v90 * v90 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + 
sconst (fromList @[2] [0.0,0.0])) ; v99 = ssum (stranspose (sletDomainsIn (let x94 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [1])) ; v95 = ssum (stranspose (sin (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))))) ; v96 = sreplicate (sin x94) ; v97 = recip (v95 * v95 + sconst (fromList @[2] [0.0,0.0]) + v96 * v96 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v98 = sconstant (ssum (stranspose (rreplicate 2 (rreplicate 5 (rconst 1.0))))) in (cos (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))) * stranspose (sreplicate ((v96 * v97) * v98)), ssum (sreplicate (cos x94 * ssum (negate (v95 * v97) * v98))))) (\\[m92, x93] -> m92))) in (cos (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1)))) * stranspose (sreplicate ((v90 * v91) * v99)), ssum (sreplicate (cos x88 * ssum (negate (v89 * v91) * v99))))) (\\[m86, x87] -> x87), sletDomainsIn (let x102 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [1])) ; v103 = ssum (stranspose (sin (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))))) ; v104 = sreplicate (sin x102) ; v105 = recip (v103 * v103 + sconst (fromList @[2] [0.0,0.0]) + v104 * v104 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v106 = sconstant (ssum (stranspose (rreplicate 2 (rreplicate 5 (rconst 1.0))))) in (cos (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate 
(sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))) * stranspose (sreplicate ((v104 * v105) * v106)), ssum (sreplicate (cos x102 * ssum (negate (v103 * v105) * v106))))) (\\[m100, x101] -> x101)])"

testSin0Fold8fwd :: Assertion
testSin0Fold8fwd = do
Expand Down

0 comments on commit b489e66

Please sign in to comment.