Skip to content

Commit

Permalink
Fix Delta eval for ScanR
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Jan 1, 2024
1 parent 12cf49f commit 9163b79
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 8 deletions.
25 changes: 19 additions & 6 deletions src/HordeAd/Core/Delta.hs
Original file line number Diff line number Diff line change
Expand Up @@ -917,9 +917,10 @@ buildFinMaps s0 deltaDt =
-> [(ranked r n1, ranked r n1, ranked rm m)]
-> (ranked r n1, [ranked rm m])
rg = mapAccumR $ \cr (cx, x, a) -> rf (cr + cx) (x, a)
(cx0, cas) = rg (rzero $ rshape x0) (zip3 cxs (init p) las)
(cx0, cas) = assert (length cxs == length p) $
rg (cxs !! 0) (zip3 (drop 1 cxs) (init p) las)
s2 = evalR sShared cx0 x0'
in evalR s2 (rfromList cas) as' -}
in evalR s2 (rfromList cas) as'
ScanR @rm @m @_ @_ @n1 f x0 as _df rf x0' as' -> -- n1 ~ n - 1
let cxs :: [ranked r n1]
cxs = runravelToList cShared
Expand All @@ -929,14 +930,25 @@ buildFinMaps s0 deltaDt =
las = runravelToList as
crs :: [ranked r n1]
crs = scanr (\(cx, x, a) cr -> fst $ rf (cr + cx) (x, a))
(rzero $ rshape x0) (zip3 cxs (init p) las)
(cxs !! 0) (zip3 (drop 1 cxs) (init p) las)
rg :: [ranked r n1] -> [ranked r n1] -> [ranked r n1]
-> [ranked rm m]
-> [ranked rm m]
rg = zipWith4 (\cr cx x a -> snd $ rf (cr + cx) (x, a))
cas = rg (drop 1 crs) cxs (init p) las
cas = rg (drop 1 crs) (drop 1 cxs) (init p) las
s2 = evalR sShared (crs !! 0) x0'
in evalR s2 (rfromList cas) as'
in evalR s2 (rfromList cas) as' -}
ScanR f x0 as _df rf x0' as' ->
let g (asPrefix, as'Prefix) = FoldR f x0 asPrefix _df rf x0' as'Prefix
-- starting from 0 would be better, but I'm
-- getting "tfromListR: shape ambiguity, no arguments"
initsViaSlice t = map (\k -> rslice @ranked 0 k t)
[1..rlength t]
initsViaSliceD t = map (\k -> SliceR 0 k t)
[1..lengthDelta @ranked t]
d = FromListR
$ x0' : map g (zip (initsViaSlice as) (initsViaSliceD as'))
in evalR s c d
{- Scan2R @rm @m @rp @p @_ @_ @n1 f x0 as bs _df rf x0' as' bs' ->
let cxs :: [ranked r n1]
cxs = runravelToList cShared
Expand Down Expand Up @@ -1286,7 +1298,8 @@ buildDerivative dimR deltaDt params = do
let lcas = runravelToList cas
las = runravelToList as
p = scanl' f x0 las
return $! rfromList $ scanl' df cx0 (zip3 lcas (init p) las)
return $! rfromList $ assert (length lcas == length las) $
scanl' df cx0 (zip3 lcas (init p) las)
ScanDR f x0 as df _rf x0' as' -> do
cx0 <- evalR x0'
let evalRDynamicRanked
Expand Down
220 changes: 220 additions & 0 deletions test/simplified/TestRevFwdFold.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ testTrees =
, testCase "4Sin0Rrev5S" testSin0Rrev5S
, testCase "4Sin0RrevPP5S" testSin0RrevPP5S
, testCase "4Sin0Fold0" testSin0Fold0
, testCase "4Sin0Fold0ForComparison" testSin0Fold0ForComparison
, testCase "4Sin0Fold1" testSin0Fold1
, testCase "4Sin0Fold2" testSin0Fold2
, testCase "4Sin0FoldForComparison" testSin0FoldForComparison
Expand Down Expand Up @@ -81,6 +82,26 @@ testTrees =
, testCase "4Sin0Fold8Sfwd2" testSin0Fold8Sfwd2
, testCase "4Sin0Fold5Sfwd" testSin0Fold5Sfwd
, testCase "4Sin0Fold5Sfwds" testSin0Fold5Sfwds
, testCase "4Sin0Scan0" testSin0Scan0
, testCase "4Sin0Scan1" testSin0Scan1
, testCase "4Sin0Scan1ForComparison" testSin0Scan1ForComparison
, testCase "4Sin0Scan2" testSin0Scan2
, testCase "4Sin0Scan3" testSin0Scan3
, testCase "4Sin0Scan4" testSin0Scan4
, testCase "4Sin0Scan5" testSin0Scan5
, testCase "4Sin0Scan6" testSin0Scan6
, testCase "4Sin0Scan7" testSin0Scan7
, testCase "4Sin0Scan8" testSin0Scan8
, testCase "4Sin0Scan8rev" testSin0Scan8rev
, testCase "4Sin0Scan8rev2" testSin0Scan8rev2
, testCase "4Sin0Scan1RevPP" testSin0Scan1RevPP
, testCase "4Sin0Scan1RevPPForComparison" testSin0Scan1RevPPForComparison
, testCase "4Sin0Scan0fwd" testSin0Scan0fwd
, testCase "4Sin0Scan1fwd" testSin0Scan1fwd
, testCase "4Sin0Scan1FwdForComparison" testSin0Scan1FwdForComparison
, testCase "4Sin0ScanFwdPP" testSin0ScanFwdPP
, testCase "4Sin0Scan8fwd" testSin0Scan8fwd
, testCase "4Sin0Scan8fwd2" testSin0Scan8fwd2
]

foo :: RealFloat a => (a, a, a) -> a
Expand Down Expand Up @@ -294,6 +315,14 @@ testSin0Fold0 = do
x0 (rzero @f @Double (0 :$ ZS))
in f) 1.1)

testSin0Fold0ForComparison :: Assertion
testSin0Fold0ForComparison = do
assertEqualUpToEpsilon' 1e-10
(1.0 :: OR.Array 0 Double)
(rev' (let f :: forall f. f Double 0 -> f Double 0
f = id
in f) 1.1)

testSin0Fold1 :: Assertion
testSin0Fold1 = do
assertEqualUpToEpsilon' 1e-10
Expand Down Expand Up @@ -662,3 +691,194 @@ testSin0Fold5Sfwds = do
(sreplicate @f @2
(sreplicate @f @5 a0)))
in f) 1.1 1.1)

testSin0Scan0 :: Assertion
testSin0Scan0 = do
assertEqualUpToEpsilon' 1e-10
1
(rev' (let f :: forall f. ADReady f => f Double 0 -> f Double 1
f x0 = rscan (\x _a -> sin x)
x0 (rzero @f @Double (0 :$ ZS))
in f) 1.1)

testSin0Scan1 :: Assertion
testSin0Scan1 = do
assertEqualUpToEpsilon' 1e-10
(OR.fromList [1,1,1,1,1] [1.4535961214255773] :: OR.Array 5 Double)
(rev' (\x0 -> rscan (\x _a -> sin x)
x0 (rconst (OR.constant @Double @1 [1] 42)))
(rreplicate0N [1,1,1,1,1] 1.1))

testSin0Scan1ForComparison :: Assertion
testSin0Scan1ForComparison = do
assertEqualUpToEpsilon' 1e-10
(OR.fromList [1,1,1,1,1] [1.4535961214255773] :: OR.Array 5 Double)
(rev' (\x0 -> rfromList [x0, sin x0])
(rreplicate0N [1,1,1,1,1] 1.1))

testSin0Scan2 :: Assertion
testSin0Scan2 = do
assertEqualUpToEpsilon' 1e-10
(OR.fromList [1,1,1,1,1] [2.2207726343670955] :: OR.Array 5 Double)
(rev' (\x0 -> rscan (\x _a -> sin x)
x0 (rconst (OR.constant @Double @1 [5] 42)))
(rreplicate0N [1,1,1,1,1] 1.1))

testSin0Scan3 :: Assertion
testSin0Scan3 = do
assertEqualUpToEpsilon' 1e-10
(OR.fromList [1,1,1,1,1] [1.360788364276732] :: OR.Array 5 Double)
(rev' (\a0 -> rscan (\_x a -> sin a)
(rreplicate0N [1,1,1,1,1] 84)
(rreplicate 3 a0)) (rreplicate0N [1,1,1,1,1] 1.1))

testSin0Scan4 :: Assertion
testSin0Scan4 = do
assertEqualUpToEpsilon' 1e-10
(OR.fromList [1,1,1,1,1] [-0.4458209450295252] :: OR.Array 5 Double)
(rev' (\a0 -> rscan (\x a -> atan2 (sin x) (sin a))
(rreplicate0N [1,1,1,1,1] 2 * a0)
(rreplicate 3 a0)) (rreplicate0N [1,1,1,1,1] 1.1))

testSin0Scan5 :: Assertion
testSin0Scan5 = do
assertEqualUpToEpsilon' 1e-10
(OR.fromList [1,1,1,1] [4.126141830000979] :: OR.Array 4 Double)
(rev' (\a0 -> rscan (\x a -> rsum
$ atan2 (sin $ rreplicate 5 x)
(rsum $ sin $ rsum
$ rtr $ rreplicate 7 a))
(rreplicate0N [1,1,1,1] 2 * a0)
(rreplicate 3 (rreplicate 2 (rreplicate 5 a0))))
(rreplicate0N [1,1,1,1] 1.1))

testSin0Scan6 :: Assertion
testSin0Scan6 = do
assertEqualUpToEpsilon' 1e-10
(OR.fromList [1,1] [12] :: OR.Array 2 Double)
(rev' (\a0 -> rscan (\x a -> rtr
$ rtr x + rreplicate 1 (rreplicate 2 a))
(rreplicate 2 (rreplicate 1 a0))
(rreplicate 2 a0)) (rreplicate0N [1,1] 1.1))

testSin0Scan7 :: Assertion
testSin0Scan7 = do
assertEqualUpToEpsilon' 1e-10
(OR.fromList [1,1] [310] :: OR.Array 2 Double)
(rev' (\a0 -> rscan (\x _a -> rtr $ rreplicate 5
$ (rsum (rtr x)))
(rreplicate 2 (rreplicate 5 a0))
(rreplicate 2 a0)) (rreplicate0N [1,1] 1.1))

testSin0Scan8 :: Assertion
testSin0Scan8 = do
assertEqualUpToEpsilon' 1e-10
(OR.fromList [1,1,1] [9.532987357352765] :: OR.Array 3 Double)
(rev' (\a0 -> rscan (\x a -> rtr $ rreplicate 5
$ atan2 (rsum (rtr $ sin x))
(rreplicate 2
$ sin (rsum $ rreplicate 7 a)))
(rreplicate 2 (rreplicate 5 (rreplicate0N [1,1,1] 2 * a0)))
(rreplicate 3 a0)) (rreplicate0N [1,1,1] 1.1))

testSin0Scan8rev :: Assertion
testSin0Scan8rev = do
assertEqualUpToEpsilon 1e-10
(Flip $ OR.fromList [] [9.53298735735276])
(rrev1 @(Flip OR.Array) @Double @0 @3
(\a0 -> rscan (\x a -> rtr $ rreplicate 5
$ atan2 (rsum (rtr $ sin x))
(rreplicate 2
$ sin (rsum $ rreplicate 7 a)))
(rreplicate 2 (rreplicate 5 (2 * a0)))
(rreplicate 3 a0)) 1.1)

testSin0Scan8rev2 :: Assertion
testSin0Scan8rev2 = do
let h = rrev1 @(ADVal (Flip OR.Array)) @Double @0 @3
(\a0 -> rscan (\x a -> rtr $ rreplicate 5
$ atan2 (rsum (rtr $ sin x))
(rreplicate 2
$ sin (rsum $ rreplicate 7 a)))
(rreplicate 2 (rreplicate 5 (2 * a0)))
(rreplicate 3 a0))
assertEqualUpToEpsilon 1e-10
(Flip $ OR.fromList [] [285.9579482947575])
(crev h 1.1)

testSin0Scan1RevPP :: Assertion
testSin0Scan1RevPP = do
resetVarCounter
let a1 = rrev1 @(AstRanked FullSpan) @Double @0 @1
(\x0 -> rscan (\x _a -> sin x) x0
(rconst (OR.constant @Double @1 [1] 42))) 1.1
printAstPretty IM.empty (simplifyAst6 a1)
@?= "cos (rconst 1.1) * rconst 1.0 + rconst 1.0"

testSin0Scan1RevPPForComparison :: Assertion
testSin0Scan1RevPPForComparison = do
resetVarCounter
let a1 = rrev1 @(AstRanked FullSpan) @Double @0 @1
(\x0 -> rfromList [x0, sin x0]) 1.1
printAstPretty IM.empty (simplifyAst6 a1)
@?= "cos (rconst 1.1) * rconst 1.0 + rconst 1.0"

testSin0Scan0fwd :: Assertion
testSin0Scan0fwd = do
assertEqualUpToEpsilon 1e-10
(Flip $ OR.fromList [1] [1.1])
(rfwd1 @(Flip OR.Array) @Double @0 @1
(let f :: forall f. ADReady f => f Double 0 -> f Double 1
f x0 = rscan (\x _a -> sin x)
x0 (rzero @f @Double (0 :$ ZS))
in f) 1.1)

testSin0Scan1fwd :: Assertion
testSin0Scan1fwd = do
assertEqualUpToEpsilon 1e-10
(Flip $ OR.fromList [2] [1.1,0.4989557335681351])
(rfwd1 @(Flip OR.Array) @Double @0 @1
(\x0 -> rscan (\x _a -> sin x)
x0 (rconst (OR.constant @Double @1 [1] 42)))
1.1)

testSin0Scan1FwdForComparison :: Assertion
testSin0Scan1FwdForComparison = do
assertEqualUpToEpsilon 1e-10
(Flip $ OR.fromList [2] [1.1,0.4989557335681351])
(rfwd1 @(Flip OR.Array) @Double @0 @1
(\x0 -> rfromList [x0, sin x0]) 1.1)

testSin0ScanFwdPP :: Assertion
testSin0ScanFwdPP = do
resetVarCounter
let a1 = rfwd1 @(AstRanked FullSpan) @Double @0 @1
(\x0 -> rscan (\x _a -> sin x) x0
(rconst (OR.constant @Double @1 [1] 42))) 1.1
printAstPretty IM.empty (simplifyAst6 a1)
@?= "rfromList [rconst 1.1, rconst 1.1 * cos (rconst 1.1)]"

testSin0Scan8fwd :: Assertion
testSin0Scan8fwd = do
assertEqualUpToEpsilon 1e-10
(Flip $ OR.fromList [4,2,5] [2.2,2.2,2.2,2.2,2.2,2.2,2.2,2.2,2.2,2.2,-0.6450465372542022,-0.6450465372542022,-0.6450465372542022,-0.6450465372542022,-0.6450465372542022,-0.6450465372542022,-0.6450465372542022,-0.6450465372542022,-0.6450465372542022,-0.6450465372542022,-0.2642905982717151,-0.2642905982717151,-0.2642905982717151,-0.2642905982717151,-0.2642905982717151,-0.2642905982717151,-0.2642905982717151,-0.2642905982717151,-0.2642905982717151,-0.2642905982717151,-0.242034255165279,-0.242034255165279,-0.242034255165279,-0.242034255165279,-0.242034255165279,-0.242034255165279,-0.242034255165279,-0.242034255165279,-0.242034255165279,-0.242034255165279])
(rfwd1 @(Flip OR.Array) @Double @0 @3
(\a0 -> rscan (\x a -> rtr $ rreplicate 5
$ atan2 (rsum (rtr $ sin x))
(rreplicate 2
$ sin (rsum $ rreplicate 7 a)))
(rreplicate 2 (rreplicate 5 (2 * a0)))
(rreplicate 3 a0)) 1.1)

testSin0Scan8fwd2 :: Assertion
testSin0Scan8fwd2 = do
let h = rfwd1 @(ADVal (Flip OR.Array)) @Double @0 @3
(\a0 -> rscan (\x a -> rtr $ rreplicate 5
$ atan2 (rsum (rtr $ sin x))
(rreplicate 2
$ sin (rsum $ rreplicate 7 a)))
(rreplicate 2 (rreplicate 5 (2 * a0)))
(rreplicate 3 a0))
assertEqualUpToEpsilon 1e-10
(Flip $ OR.fromList [] [324.086730481586])
(crev h 1.1)
4 changes: 2 additions & 2 deletions test/tool/CrossTesting.hs
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@ assertEqualUpToEpsilon'
assertEqualUpToEpsilonWithMark "Derivatives" errMargin cderivative derivative
-- The formula for comparing derivative and gradient is due to @awf
-- at https://github.com/Mikolaj/horde-ad/issues/15#issuecomment-1063251319
assertEqualUpToEpsilonWithMark "Forward vs reverse"
1e-5 (rsum0 derivative) (rdot0 expected vals)
assertEqualUpToEpsilonWithMark "Reverse vs forward"
1e-5 (rdot0 expected vals) (rsum0 derivative)
-- No Eq instance, so let's compare the text.
show (simplifyAst6 $ simplifyAst6 astVectSimp)
@?= show (simplifyAst6 astVectSimp) -- more simplification is needed
Expand Down

0 comments on commit 9163b79

Please sign in to comment.