From b5854791a1a18b65f26a4045c51070d497f60b5d Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Sun, 24 Dec 2023 15:57:20 +0100 Subject: [PATCH] Simplify AstLetDomainsIn a little --- src/HordeAd/Core/AstInterpret.hs | 28 +++++++- src/HordeAd/Core/AstSimplify.hs | 85 ++++++++++++++++++++++-- src/HordeAd/Core/TensorAst.hs | 2 +- src/HordeAd/Core/TensorClass.hs | 2 +- test/simplified/TestAdaptorSimplified.hs | 15 +++-- test/tool/CrossTesting.hs | 3 +- 6 files changed, 118 insertions(+), 17 deletions(-) diff --git a/src/HordeAd/Core/AstInterpret.hs b/src/HordeAd/Core/AstInterpret.hs index 74ab39e99..0e0ba8a31 100644 --- a/src/HordeAd/Core/AstInterpret.hs +++ b/src/HordeAd/Core/AstInterpret.hs @@ -38,7 +38,7 @@ import HordeAd.Core.AstSimplify import HordeAd.Core.AstTools import HordeAd.Core.TensorClass import HordeAd.Core.Types -import HordeAd.Internal.OrthotopeOrphanInstances (sameShape) +import HordeAd.Internal.OrthotopeOrphanInstances (matchingRank, sameShape) import HordeAd.Util.ShapedList (ShapedList (..)) import HordeAd.Util.SizedIndex @@ -127,7 +127,19 @@ interpretAst !env = \case `blame` (sh, rshape t, var, t, env)) t _ -> error "interpretAst: type mismatch" _ -> error "interpretAst: wrong shape in environment" - Just{} -> error "interpretAst: wrong tensor kind in environment" + -- To impose such checks, we'd need to switch from OD tensors + -- to existential OR/OS tensors so that we can inspect + -- which it is and then seed Delta evaluation maps with that. + -- Just{} -> error "interpretAst: wrong tensor kind in environment" + Just (AstEnvElemS @sh2 @r2 t) -> case shapeToList sh == Sh.shapeT @sh2 of + True -> case matchingRank @sh2 @n of + Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of + Just Refl -> rfromS @_ @_ @r2 @sh2 t + _ -> error "interpretAst: type mismatch" + _ -> error "interpretAst: wrong rank" + False -> error $ "interpretAst: wrong shape in environment" + `showFailure` + (sh, Sh.shapeT @sh2, var, t, env) Nothing -> error $ "interpretAst: unknown variable " ++ show var ++ " in environment " ++ show env AstLet var u v -> @@ -679,7 +691,17 @@ interpretAstS !env = \case Nothing -> error $ "interpretAstS: wrong shape in environment" `showFailure` (Sh.shapeT @sh, Sh.shapeT @sh2, var, t, env) - Just{} -> error "interpretAstS: wrong tensor kind in environment" + -- To impose such checks, we'd need to switch from OD tensors + -- to existential OR/OS tensors so that we can inspect + -- which it is and then seed Delta evaluation maps with that. + -- Just{} -> error "interpretAstS: wrong tensor kind in environment" + Just (AstEnvElemR @n2 @r2 t) -> case matchingRank @sh @n2 of + Just Refl -> case testEquality (typeRep @r) (typeRep @r2) of + Just Refl -> assert (Sh.shapeT @sh == shapeToList (rshape t) + `blame` (Sh.shapeT @sh, rshape t, var, t, env)) + $ sfromR @_ @_ @r2 @sh t + _ -> error "interpretAstS: type mismatch" + _ -> error "interpretAstS: wrong shape in environment" Nothing -> error $ "interpretAstS: unknown variable " ++ show var AstLetS var u v -> -- We assume there are no nested lets with the same variable. diff --git a/src/HordeAd/Core/AstSimplify.hs b/src/HordeAd/Core/AstSimplify.hs index 7cf4dc933..403c30890 100644 --- a/src/HordeAd/Core/AstSimplify.hs +++ b/src/HordeAd/Core/AstSimplify.hs @@ -1651,18 +1651,70 @@ astLetInDomainsS var u v | astIsSmallS True u = astLetInDomainsS var u v = Ast.AstLetInDomainsS var u v astLetDomainsIn - :: forall n s s2 r. (AstSpan s, KnownNat n) + :: forall n s s2 r. (AstSpan s, GoodScalar r, KnownNat n) => [AstDynamicVarName] -> AstDomains s -> AstRanked s2 r n -> AstRanked s2 r n -astLetDomainsIn vars l v = Ast.AstLetDomainsIn vars l v +astLetDomainsIn vars l v = + let sh = shapeAst v + in Sh.withShapeP (shapeToList sh) $ \proxy -> case proxy of + Proxy @sh | Just Refl <- matchingRank @sh @n -> case l of + Ast.AstDomains l3 -> -- TODO: other cases: collect AstLetInDomains + let f :: (AstDynamicVarName, DynamicExists (AstDynamic s)) + -> AstRanked s2 r n + -> AstRanked s2 r n + f ( AstDynamicVarName @_ @r3 @sh3 (AstVarName varId) + , DynamicExists @r4 (Ast.AstRToD @n4 v3) ) + acc + | Just Refl <- matchingRank @sh3 @n4 + -- To impose such checks, we'd need to switch from OD tensors + -- to existential OR/OS tensors so that we can inspect + -- which it is and then seed Delta evaluation maps with that. + -- , Just Refl <- testEquality (typeRep @k) (typeRep @Nat) + , Just Refl <- testEquality (typeRep @r3) (typeRep @r4) = + Ast.AstLet (AstVarName varId) v3 acc + f ( AstDynamicVarName @_ @r3 @sh3 (AstVarName varId) + , DynamicExists @r4 (Ast.AstSToD @sh4 v3) ) + acc + | Just Refl <- sameShape @sh3 @sh4 + , Just Refl <- testEquality (typeRep @r3) (typeRep @r4) = + Ast.AstSToR @sh + $ Ast.AstLetS (AstVarName varId) v3 $ Ast.AstRToS acc + f _ _ = error "astLetDomainsIn: corrupted arguments" + in foldr f v (zip vars (V.toList l3)) + _ -> Ast.AstLetDomainsIn vars l v + _ -> error "astLetDomainsIn: wrong rank of the argument" astLetDomainsInS :: forall sh s s2 r. (AstSpan s, Sh.Shape sh) => [AstDynamicVarName] -> AstDomains s -> AstShaped s2 r sh -> AstShaped s2 r sh -astLetDomainsInS vars l v = Ast.AstLetDomainsInS vars l v +astLetDomainsInS vars l v = + case someNatVal $ toInteger (length (Sh.shapeT @sh)) of + Just (SomeNat @n _) -> gcastWith (unsafeCoerce Refl :: n :~: Sh.Rank sh) + $ case l of + Ast.AstDomains l3 -> -- TODO: other cases: collect AstLetInDomainsS + let f :: (AstDynamicVarName, DynamicExists (AstDynamic s)) + -> AstShaped s2 r sh + -> AstShaped s2 r sh + f ( AstDynamicVarName @_ @r3 @sh3 (AstVarName varId) + , DynamicExists @r4 (Ast.AstRToD @n4 v3) ) + acc + | Just Refl <- matchingRank @sh3 @n4 + , Just Refl <- testEquality (typeRep @r3) (typeRep @r4) = + Ast.AstRToS @sh + $ Ast.AstLet (AstVarName varId) v3 $ Ast.AstSToR acc + f ( AstDynamicVarName @_ @r3 @sh3 (AstVarName varId) + , DynamicExists @r4 (Ast.AstSToD @sh4 v3) ) + acc + | Just Refl <- sameShape @sh3 @sh4 + , Just Refl <- testEquality (typeRep @r3) (typeRep @r4) = + Ast.AstLetS (AstVarName varId) v3 acc + f _ _ = error "astLetDomainsInS: corrupted arguments" + in foldr f v (zip vars (V.toList l3)) + _ -> Ast.AstLetDomainsInS vars l v + _ -> error "astLetDomainsInS: impossible someNatVal" -- * The simplifying bottom-up pass @@ -2211,7 +2263,19 @@ substitute1Ast i var v1 = case v1 of _ -> error "substitute1Ast: scalar" _ -> error "substitute1Ast: rank" _ -> error "substitute1Ast: span" - _ -> error "substitute1Ast: type" + -- To impose such checks, we'd need to switch from OD tensors + -- to existential OR/OS tensors so that we can inspect + -- which it is and then seed Delta evaluation maps with that. + -- _ -> error "substitute1Ast: type" + SubstitutionPayloadShaped @_ @_ @sh2 t -> case sameAstSpan @s @s2 of + Just Refl -> case shapeToList sh == Sh.shapeT @sh2 of + True -> case matchingRank @sh2 @n of + Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of + Just Refl -> Just $ astSToR t + _ -> error "substitute1Ast: scalar" + _ -> error "substitute1Ast: rank" + False -> error "substitute1Ast: shape" + _ -> error "substitute1Ast: span" else Nothing Ast.AstLet var2 u v -> case (substitute1Ast i var u, substitute1Ast i var v) of @@ -2464,7 +2528,18 @@ substitute1AstS i var = \case _ -> error "substitute1AstS: scalar" _ -> error "substitute1AstS: shape" _ -> error "substitute1Ast: span" - _ -> error "substitute1AstS: type" + -- To impose such checks, we'd need to switch from OD tensors + -- to existential OR/OS tensors so that we can inspect + -- which it is and then seed Delta evaluation maps with that. + -- _ -> error "substitute1AstS: type" + SubstitutionPayloadRanked @_ @_ @m t -> case sameAstSpan @s @s2 of + Just Refl -> case matchingRank @sh @m of + Just Refl -> case testEquality (typeRep @r2) (typeRep @r) of + Just Refl -> assert (Sh.shapeT @sh == shapeToList (shapeAst t)) + $ Just $ astRToS t + _ -> error "substitute1Ast: scalar" + _ -> error "substitute1Ast: rank" + _ -> error "substitute1Ast: span" else Nothing Ast.AstLetS var2 u v -> case (substitute1AstS i var u, substitute1AstS i var v) of diff --git a/src/HordeAd/Core/TensorAst.hs b/src/HordeAd/Core/TensorAst.hs index 06884c4aa..1afc77a21 100644 --- a/src/HordeAd/Core/TensorAst.hs +++ b/src/HordeAd/Core/TensorAst.hs @@ -468,7 +468,7 @@ isTensorDummyAst t = case t of -- TODO: move the impure part to AstFreshId astLetDomainsInFun - :: forall n s r. (AstSpan s, KnownNat n) + :: forall n s r. (AstSpan s, GoodScalar r, KnownNat n) => DomainsOD -> AstDomains s -> (Domains (AstDynamic s) -> AstRanked s r n) -> AstRanked s r n {-# NOINLINE astLetDomainsInFun #-} diff --git a/src/HordeAd/Core/TensorClass.hs b/src/HordeAd/Core/TensorClass.hs index 4e4d6b5b1..687cd4923 100644 --- a/src/HordeAd/Core/TensorClass.hs +++ b/src/HordeAd/Core/TensorClass.hs @@ -237,7 +237,7 @@ class ( Integral (IntOf ranked), CRanked ranked Num rzero :: (GoodScalar r, KnownNat n) => ShapeInt n -> ranked r n rzero sh = rreplicate0N sh 0 - rletDomainsIn :: KnownNat n + rletDomainsIn :: (KnownNat n, GoodScalar r) => DomainsOD -> DomainsOf ranked -> (Domains (DynamicOf ranked) -> ranked r n) diff --git a/test/simplified/TestAdaptorSimplified.hs b/test/simplified/TestAdaptorSimplified.hs index 3e6b53d29..c014a13e9 100644 --- a/test/simplified/TestAdaptorSimplified.hs +++ b/test/simplified/TestAdaptorSimplified.hs @@ -2045,13 +2045,14 @@ testSin0RrevPP1 = do resetVarCounter let a1 = rrev1 @(AstRanked FullSpan) @Double @0 @0 sin 1.1 printAstPretty IM.empty a1 - @?= "rletDomainsIn (cos (rconst 1.1) * rreshape [] (rreplicate 1 (rconst 1.0))) (\\[dret] -> dret)" + @?= "let dret = cos (rconst 1.1) * rreshape [] (rreplicate 1 (rconst 1.0)) in dret" testSin0RrevPP2 :: Assertion testSin0RrevPP2 = do + resetVarCounter let a1 = rrev1 @(AstRanked FullSpan) @Double @0 @0 sin 1.1 printAstSimple IM.empty a1 - @?= "rletDomainsIn (dmkDomains (fromList [dfromR (cos (rconst 1.1) * rreshape [] (rreplicate 1 (rconst 1.0)))])) (\\[dret] -> dret)" + @?= "rlet (cos (rconst 1.1) * rreshape [] (rreplicate 1 (rconst 1.0))) (\\dret -> dret)" testSin0Rrev3 :: Assertion testSin0Rrev3 = do @@ -2070,7 +2071,7 @@ testSin0RrevPP4 :: Assertion testSin0RrevPP4 = do let a1 = (rrev1 sin . rrev1 @(AstRanked FullSpan) @Double @0 @0 sin) 1.1 printAstPretty IM.empty (simplifyAst6 a1) - @?= "rletDomainsIn (cos (rletDomainsIn (cos (rconst 1.1) * rconst 1.0) (\\[dret] -> dret)) * rconst 1.0) (\\[x4] -> x4)" + @?= "cos (cos (rconst 1.1) * rconst 1.0) * rconst 1.0" testSin0Rrev5 :: Assertion testSin0Rrev5 = do @@ -2080,9 +2081,10 @@ testSin0Rrev5 = do testSin0RrevPP5 :: Assertion testSin0RrevPP5 = do + resetVarCounter let a1 = rrev1 @(AstRanked FullSpan) @Double @0 @0 (rrev1 sin) 1.1 printAstPretty IM.empty (simplifyAst6 a1) - @?= "rletDomainsIn (negate (sin (rconst 1.1)) * (rconst 1.0 * rconst 1.0)) (\\[x7] -> x7)" + @?= "let dret = negate (sin (rconst 1.1)) * (rconst 1.0 * rconst 1.0) in dret" testSin0Rrev3' :: Assertion testSin0Rrev3' = do @@ -2172,9 +2174,10 @@ testSin0Rrev5S = do testSin0RrevPP5S :: Assertion testSin0RrevPP5S = do + resetVarCounter let a1 = srev1 @(AstShaped FullSpan) @Double @'[] @'[] (srev1 sin) 1.1 printAstPrettyS IM.empty (simplifyAst6S a1) - @?= "sletDomainsIn (negate (sin (sconst 1.1)) * sconst 1.0) (\\[x645] -> x645)" + @?= "let dret = negate (sin (sconst 1.1)) * sconst 1.0 in dret" testSin0Fold0 :: Assertion testSin0Fold0 = do @@ -2445,7 +2448,7 @@ testSin0Fold18SrevPP = do (sreplicate @_ @2 a0) in rfromS . f . sfromR) 1.1 printAstPretty IM.empty (simplifyAst6 a1) - @?= "rletDomainsIn (sconst 2.0 * ssum (ssum (sletDomainsIn (let x68 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])) ; v69 = ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1)))))) ; v70 = sreplicate (sin x68) ; v71 = recip (v69 * v69 + sconst (fromList @[2] [0.0,0.0]) + v70 * v70 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v79 = ssum (stranspose (sletDomainsIn (let x74 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [1])) ; v75 = ssum (stranspose (sin (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))))) ; v76 = sreplicate (sin x74) ; v77 = recip (v75 * v75 + sconst (fromList @[2] [0.0,0.0]) + v76 * v76 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v78 = sconstant (ssum (stranspose (rreplicate 2 (rreplicate 5 (rconst 1.0))))) in (cos (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))) * stranspose (sreplicate ((v76 * v77) * v78)) + rconstant (rreplicate 2 (rreplicate 5 (rconst 0.0))), ssum (sreplicate (cos x74 * ssum (negate (v75 * v77) * v78))) + rconst 0.0)) (\\[m72, x73] -> m72))) in (cos (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1)))) * stranspose (sreplicate ((v70 * v71) * v79)) + rconstant (rreplicate 2 (rreplicate 5 (rconst 0.0))), ssum (sreplicate (cos x68 * ssum (negate (v69 * v71) * v79))) + rconst 0.0)) (\\[m66, x67] -> m66))) + ssum (sfromList [sletDomainsIn (let x90 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])) ; v91 = ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1)))))) ; v92 = sreplicate (sin x90) ; v93 = recip (v91 * v91 + sconst (fromList @[2] [0.0,0.0]) + v92 * v92 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v101 = ssum (stranspose (sletDomainsIn (let x96 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [1])) ; v97 = ssum (stranspose (sin (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))))) ; v98 = sreplicate (sin x96) ; v99 = recip (v97 * v97 + sconst (fromList @[2] [0.0,0.0]) + v98 * v98 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v100 = sconstant (ssum (stranspose (rreplicate 2 (rreplicate 5 (rconst 1.0))))) in (cos (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))) * stranspose (sreplicate ((v98 * v99) * v100)) + rconstant (rreplicate 2 (rreplicate 5 (rconst 0.0))), ssum (sreplicate (cos x96 * ssum (negate (v97 * v99) * v100))) + rconst 0.0)) (\\[m94, x95] -> m94))) in (cos (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1)))) * stranspose (sreplicate ((v92 * v93) * v101)) + rconstant (rreplicate 2 (rreplicate 5 (rconst 0.0))), ssum (sreplicate (cos x90 * ssum (negate (v91 * v93) * v101))) + rconst 0.0)) (\\[m88, x89] -> x89), sletDomainsIn (let x110 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [1])) ; v111 = ssum (stranspose (sin (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))))) ; v112 = sreplicate (sin x110) ; v113 = recip (v111 * v111 + sconst (fromList @[2] [0.0,0.0]) + v112 * v112 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v114 = sconstant (ssum (stranspose (rreplicate 2 (rreplicate 5 (rconst 1.0))))) in (cos (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))) * stranspose (sreplicate ((v112 * v113) * v114)) + rconstant (rreplicate 2 (rreplicate 5 (rconst 0.0))), ssum (sreplicate (cos x110 * ssum (negate (v111 * v113) * v114))) + rconst 0.0)) (\\[m108, x109] -> x109)])) (\\[dret] -> dret)" + @?= "sconst 2.0 * ssum (ssum (sletDomainsIn (let x68 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])) ; v69 = ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1)))))) ; v70 = sreplicate (sin x68) ; v71 = recip (v69 * v69 + sconst (fromList @[2] [0.0,0.0]) + v70 * v70 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v79 = ssum (stranspose (sletDomainsIn (let x74 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [1])) ; v75 = ssum (stranspose (sin (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))))) ; v76 = sreplicate (sin x74) ; v77 = recip (v75 * v75 + sconst (fromList @[2] [0.0,0.0]) + v76 * v76 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v78 = sconstant (ssum (stranspose (rreplicate 2 (rreplicate 5 (rconst 1.0))))) in (cos (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))) * stranspose (sreplicate ((v76 * v77) * v78)) + rconstant (rreplicate 2 (rreplicate 5 (rconst 0.0))), ssum (sreplicate (cos x74 * ssum (negate (v75 * v77) * v78))) + rconst 0.0)) (\\[m72, x73] -> m72))) in (cos (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1)))) * stranspose (sreplicate ((v70 * v71) * v79)) + rconstant (rreplicate 2 (rreplicate 5 (rconst 0.0))), ssum (sreplicate (cos x68 * ssum (negate (v69 * v71) * v79))) + rconst 0.0)) (\\[m66, x67] -> m66))) + ssum (sfromList [sletDomainsIn (let x90 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])) ; v91 = ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1)))))) ; v92 = sreplicate (sin x90) ; v93 = recip (v91 * v91 + sconst (fromList @[2] [0.0,0.0]) + v92 * v92 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v101 = ssum (stranspose (sletDomainsIn (let x96 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [1])) ; v97 = ssum (stranspose (sin (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))))) ; v98 = sreplicate (sin x96) ; v99 = recip (v97 * v97 + sconst (fromList @[2] [0.0,0.0]) + v98 * v98 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v100 = sconstant (ssum (stranspose (rreplicate 2 (rreplicate 5 (rconst 1.0))))) in (cos (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))) * stranspose (sreplicate ((v98 * v99) * v100)) + rconstant (rreplicate 2 (rreplicate 5 (rconst 0.0))), ssum (sreplicate (cos x96 * ssum (negate (v97 * v99) * v100))) + rconst 0.0)) (\\[m94, x95] -> m94))) in (cos (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1)))) * stranspose (sreplicate ((v92 * v93) * v101)) + rconstant (rreplicate 2 (rreplicate 5 (rconst 0.0))), ssum (sreplicate (cos x90 * ssum (negate (v91 * v93) * v101))) + rconst 0.0)) (\\[m88, x89] -> x89), sletDomainsIn (let x110 = ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [1])) ; v111 = ssum (stranspose (sin (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))))) ; v112 = sreplicate (sin x110) ; v113 = recip (v111 * v111 + sconst (fromList @[2] [0.0,0.0]) + v112 * v112 + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0]) + sconst (fromList @[2] [0.0,0.0])) ; v114 = sconstant (ssum (stranspose (rreplicate 2 (rreplicate 5 (rconst 1.0))))) in (cos (stranspose (sreplicate (atan2 (ssum (stranspose (sin (sreplicate (sreplicate (sconst 2.0 * sconstant (rconst 1.1))))))) (sreplicate (sin (ssum (sreplicate (sconstant (sreplicate (rconst 1.1)) !$ [0])))))))) * stranspose (sreplicate ((v112 * v113) * v114)) + rconstant (rreplicate 2 (rreplicate 5 (rconst 0.0))), ssum (sreplicate (cos x110 * ssum (negate (v111 * v113) * v114))) + rconst 0.0)) (\\[m108, x109] -> x109)])" testSin0Fold8fwd :: Assertion testSin0Fold8fwd = do diff --git a/test/tool/CrossTesting.hs b/test/tool/CrossTesting.hs index 7c5013c35..417c58e03 100644 --- a/test/tool/CrossTesting.hs +++ b/test/tool/CrossTesting.hs @@ -334,7 +334,8 @@ assertEqualUpToEpsilon' assertEqualUpToEpsilonWithMark "Forward vs reverse" 1e-5 (rsum0 derivative) (rdot0 expected vals) -- No Eq instance, so let's compare the text. - show (simplifyAst6 astVectSimp) @?= show astVectSimp + show (simplifyAst6 $ simplifyAst6 astVectSimp) + @?= show (simplifyAst6 astVectSimp) -- more simplification is needed show (simplifyAst6 astSimp) @?= show astSimp assertEqualUpToEpsilonShort