From 603010715addd905fe9d5cd1cfe12d05513d0b6b Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Mon, 2 Dec 2024 07:37:44 +0100 Subject: [PATCH] Interpret trivial gathers and scatters better --- src/HordeAd/Core/AstInterpret.hs | 22 +++++++++++++++------- test/simplified/TestAdaptorSimplified.hs | 12 ++++++------ test/simplified/TestConvSimplified.hs | 6 ++++-- 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/src/HordeAd/Core/AstInterpret.hs b/src/HordeAd/Core/AstInterpret.hs index ab96436b9..742fbb024 100644 --- a/src/HordeAd/Core/AstInterpret.hs +++ b/src/HordeAd/Core/AstInterpret.hs @@ -31,16 +31,13 @@ import Unsafe.Coerce (unsafeCoerce) import Data.Array.Mixed.Shape (pattern (:.%), pattern ZIX) import Data.Array.Nested - ( IxS (..) + ( IxR (..) + , IxS (..) , KnownShS (..) + , ListR (..) + , ListS (..) , Rank , ShR (..) - , pattern (:$:) - , pattern (:.$) - , pattern (:.:) - , pattern ZIR - , pattern ZIS - , pattern ZSR , type (++) ) import Data.Array.Nested qualified as Nested @@ -55,6 +52,7 @@ import HordeAd.Core.HVector import HordeAd.Core.HVectorOps import HordeAd.Core.TensorClass import HordeAd.Core.Types +import HordeAd.Util.SizedList interpretAstPrimalRuntimeSpecialized :: forall target n r. @@ -511,6 +509,8 @@ interpretAst !env = \case -- TODO: recognize when sum0 may be used instead, which is much cheaper -- or should I do that in Delta instead? no, because tsum0R -- is cheaper, too + AstScatter sh v (ZR, ix) -> + roneHot (takeShape sh) (interpretAst env v) (interpretAstPrimal env <$> ix) AstScatter sh v (vars, ix) -> let t1 = interpretAst env v f2 = interpretLambdaIndexToIndex interpretAstPrimal env (vars, ix) @@ -537,6 +537,8 @@ interpretAst !env = \case AstReshape sh (AstLet var v (AstReplicate k t)) -> interpretAst env (AstLet var v (AstReshape sh (AstReplicate k t))) AstReshape sh v -> rreshape sh (interpretAst env v) + AstGather _ v (ZR, ix) -> + rindex (interpretAst env v) (interpretAstPrimal env <$> ix) AstGather sh AstIotaR (vars, i :.: ZIR) -> rbuild sh (interpretLambdaIndex interpretAst env (vars, fromPrimal @s $ AstFromIntegralR $ AstRFromS $ AstFromScalar i)) @@ -758,6 +760,9 @@ interpretAst !env = \case -- TODO: recognize when sum0 may be used instead, which is much cheaper -- or should I do that in Delta instead? no, because tsum0R -- is cheaper, too + AstScatterS @_ @p @sh v (ZS, ix) -> + gcastWith (unsafeCoerce Refl :: Take p sh ++ Drop p sh :~: sh) + soneHot (interpretAst env v) (interpretAstPrimal env <$> ix) AstScatterS v (vars, ix) -> let t1 = interpretAst env v f2 = interpretLambdaIndexToIndexS interpretAstPrimal env (vars, ix) @@ -781,6 +786,9 @@ interpretAst !env = \case AstReverseS v -> sreverse (interpretAst env v) AstTransposeS perm v -> stranspose perm $ interpretAst env v AstReshapeS v -> sreshape (interpretAst env v) + AstGatherS @_ @p @sh v (ZS, ix) -> + gcastWith (unsafeCoerce Refl :: Take p sh ++ Drop p sh :~: sh) + sindex (interpretAst env v) (interpretAstPrimal env <$> ix) AstGatherS @sh2 @p @sh @r AstIotaS (vars, i :.$ ZIS) -> gcastWith (unsafeCoerce Refl :: Take (Rank sh2) sh2 :~: sh2) $ gcastWith (unsafeCoerce Refl :: Drop (Rank sh2) sh2 :~: '[]) diff --git a/test/simplified/TestAdaptorSimplified.hs b/test/simplified/TestAdaptorSimplified.hs index 2c8cf9ee2..6042550f1 100644 --- a/test/simplified/TestAdaptorSimplified.hs +++ b/test/simplified/TestAdaptorSimplified.hs @@ -1013,13 +1013,13 @@ testReluSimplerPP4s2 = do reluT2 (t, r) = reluS (t * sreplicate0N r) let (artifactRev, _deltas) = revArtifactAdapt True reluT2 (srepl 128, srepl 42) printArtifactPretty renames artifactRev - @?= "\\m11 x1 -> let m5 = sreshape (sreplicate (tproject2 m1)) ; m6 = tproject1 m1 * m5 ; m10 = sgather (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0]))) (\\[i7, i8] -> [i7, ifF (m6 !$ [i7, i8] <=. sscalar 0.0) 0 1]) ; m12 = m10 * m11 in tpair (m5 * m12, ssum (sreshape (tproject1 m1 * m12)))" + @?= "\\m11 x1 -> let m5 = sreshape (sreplicate (tproject2 m1)) ; m6 = tproject1 m1 * m5 ; m10 = sgather (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m6 !$ [i8, i9] <=. sscalar 0.0) 0 1]) ; m12 = m10 * m11 in tpair (m5 * m12, ssum (sreshape (tproject1 m1 * m12)))" printArtifactPrimalPretty renames artifactRev - @?= "\\x1 -> let m5 = sreshape (sreplicate (tproject2 m1)) ; m6 = tproject1 m1 * m5 ; m10 = sgather (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0]))) (\\[i7, i8] -> [i7, ifF (m6 !$ [i7, i8] <=. sscalar 0.0) 0 1]) in m10 * m6" + @?= "\\x1 -> let m5 = sreshape (sreplicate (tproject2 m1)) ; m6 = tproject1 m1 * m5 ; m10 = sgather (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m6 !$ [i8, i9] <=. sscalar 0.0) 0 1]) in m10 * m6" printArtifactPretty renames (simplifyArtifact artifactRev) - @?= "\\m11 x1 -> let m5 = sreshape (sreplicate (tproject2 m1)) ; m12 = sgather (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0]))) (\\[i7, i8] -> [i7, ifF (tproject1 m1 !$ [i7, i8] * m5 !$ [i7, i8] <=. sscalar 0.0) 0 1]) * m11 in tpair (m5 * m12, ssum (sreshape (tproject1 m1) * sreshape m12))" + @?= "\\m11 x1 -> let m5 = sreshape (sreplicate (tproject2 m1)) ; m12 = sgather (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (tproject1 m1 !$ [i8, i9] * m5 !$ [i8, i9] <=. sscalar 0.0) 0 1]) * m11 in tpair (m5 * m12, ssum (sreshape (tproject1 m1) * sreshape m12))" printArtifactPrimalPretty renames (simplifyArtifact artifactRev) - @?= "\\x1 -> let m6 = tproject1 m1 * sreshape (sreplicate (tproject2 m1)) in sgather (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0]))) (\\[i7, i8] -> [i7, ifF (m6 !$ [i7, i8] <=. sscalar 0.0) 0 1]) * m6" + @?= "\\x1 -> let m6 = tproject1 m1 * sreshape (sreplicate (tproject2 m1)) in sgather (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m6 !$ [i8, i9] <=. sscalar 0.0) 0 1]) * m6" testReluSimpler4s :: Assertion testReluSimpler4s = do @@ -2113,6 +2113,6 @@ testConcatBuild3PP2 = do printArtifactSimple renames artifactRev @?= "\\m8 x1 -> rreshape [3] (rreplicate 3 (rscalar 0.0))" printArtifactPrimalSimple renames artifactRev - @?= "\\x1 -> rfromIntegral (rfromS (sgather (stranspose (sfromVector (fromList [sreplicate siota, quotF (stranspose (sreplicate siota)) (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0,0]) + siota + sreplicate (sscalar 1)))]))) (\\[i5, i6] -> [i5, i6, ifF (i6 >=. quotF i5 (1 + i6)) 0 1])))" + @?= "\\x1 -> rfromIntegral (rfromS (sgather (stranspose (sfromVector (fromList [sreplicate siota, quotF (stranspose (sreplicate siota)) (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0,0]) + siota + sreplicate (sscalar 1)))]))) (\\[i6, i7] -> [i6, i7, ifF (i7 >=. quotF i6 (1 + i7)) 0 1])))" printArtifactPrimalSimple renames (simplifyArtifact artifactRev) - @?= "\\x1 -> rfromIntegral (rfromS (sgather (stranspose (sfromVector (fromList [sreplicate siota, quotF (stranspose (sreplicate siota)) (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0,0]) + siota + sreplicate (sscalar 1)))]))) (\\[i5, i6] -> [i5, i6, ifF (i6 >=. quotF i5 (1 + i6)) 0 1])))" + @?= "\\x1 -> rfromIntegral (rfromS (sgather (stranspose (sfromVector (fromList [sreplicate siota, quotF (stranspose (sreplicate siota)) (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0,0]) + siota + sreplicate (sscalar 1)))]))) (\\[i6, i7] -> [i6, i7, ifF (i7 >=. quotF i6 (1 + i7)) 0 1])))" diff --git a/test/simplified/TestConvSimplified.hs b/test/simplified/TestConvSimplified.hs index e35ae0db4..9df086875 100644 --- a/test/simplified/TestConvSimplified.hs +++ b/test/simplified/TestConvSimplified.hs @@ -615,7 +615,7 @@ testCNNOPP2 :: Assertion testCNNOPP2 = do resetVarCounter printAstPretty IM.empty maxPool2dUnpadded2 - @?= "rreplicate 1 (rreplicate 1 (let w38 = rtranspose [1,2,3,0] (rreplicate 1 (rgather [1,1,1,2,2] (rfromVector (fromList [rtranspose [1,2,0] (rreplicate 1 (let x27 = rreplicate 1 2 * 0 in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x28 = rreplicate 1 2 * 0 in rreplicate 1 (rreplicate 2 (rfromVector (fromList [tconcrete (FTKR [1,1,2,2] FTKScalar) (rfromListLinear [1,1,2,2] [1.0,1.0,1.0,1.0]) ! [rreplicate 1 (rreplicate 1 1), 0, i27, i28], rscalar 0.0]) ! [ifF (1 >. rreplicate 1 (rreplicate 1 1)) 0 1]))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rreplicate 2 (rreplicate 2 (rscalar 0.0)))))])) (\\[i46, i40, i36] -> [ifF (1 >. 1 + i36) 0 1, i46, i40, i36]))) in rgather [1,1] w38 (\\[i45, i39] -> [i45, i39, 0, 0, 0, 0])))" + @?= "rreplicate 1 (rreplicate 1 (let w38 = rtranspose [1,2,3,0] (rreplicate 1 (rgather [1,1,1,2,2] (rfromVector (fromList [rtranspose [1,2,0] (rreplicate 1 (let x27 = 2 * 0 in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x28 = 2 * 0 in rreplicate 1 (rreplicate 2 (rfromVector (fromList [tconcrete (FTKR [1,1,2,2] FTKScalar) (rfromListLinear [1,1,2,2] [1.0,1.0,1.0,1.0]) ! [1, 0, i27, i28], rscalar 0.0]) ! [ifF (1 >. 1) 0 1]))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rreplicate 2 (rreplicate 2 (rscalar 0.0)))))])) (\\[i46, i40, i36] -> [ifF (1 >. 1 + i36) 0 1, i46, i40, i36]))) in rgather [1,1] w38 (\\[i45, i39] -> [i45, i39, 0, 0, 0, 0])))" maxPool2dUnpadded2 :: (target ~ AstTensor AstMethodLet FullSpan, r ~ Double) @@ -717,7 +717,9 @@ testCNNOPP4 = do afcnn2T :: AstTensor AstMethodLet FullSpan (TKR 4 Double) afcnn2T = maxPool2dUnpadded4 $ conv2dUnpadded4 blackGlyph printAstPretty IM.empty afcnn2T - @?= "rreplicate 1 (rreplicate 1 (let w36 = rgather [1,1,1,1,2,2] (rfromVector (fromList [rtranspose [2,3,0,1] (rreplicate 1 (rreplicate 1 (let x20 = rreplicate 1 2 * 0 in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x12 = rreplicate 1 2 * 0 in rreplicate 1 (rreplicate 2 (tconcrete (FTKR [2] FTKScalar) (rfromListLinear [2] [7.0,0.0]) ! [ifF ((0 <=. rreplicate 1 (rreplicate 1 1) &&* 1 >. rreplicate 1 (rreplicate 1 1)) &&* ((0 <=. i20 &&* 2 >. i20) &&* (0 <=. i12 &&* 2 >. i12))) 0 1])))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rgather [1,2,2] (rreplicate 2 (rreplicate 2 (rscalar 0.0))) (\\[i29, i26, i22] -> [i26, i22]))))])) (\\[i44, i38, i33, i30, i31, i32] -> [ifF ((0 <=. 1 + i33 &&* 1 >. 1 + i33) &&* ((0 <=. 1 + i30 &&* 1 >. 1 + i30) &&* ((0 <=. 2 * i44 + i31 &&* 2 >. 2 * i44 + i31) &&* (0 <=. 2 * i38 + i32 &&* 2 >. 2 * i38 + i32)))) 0 1, i44, i38, i33, i30, i31, i32]) in rgather [1,1] w36 (\\[i43, i37] -> [i43, i37, 0, 0, 0, 0])))" + @?= "rreplicate 1 (rreplicate 1 (let w36 = rgather [1,1,1,1,2,2] (rfromVector (fromList [rtranspose [2,3,0,1] (rreplicate 1 (rreplicate 1 (let x20 = 2 * 0 in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x12 = 2 * 0 in rreplicate 1 (rreplicate 2 (tconcrete (FTKR [2] FTKScalar) (rfromListLinear [2] [7.0,0.0]) ! [ifF ((0 <=. 1 &&* 1 >. 1) &&* ((0 <=. i20 &&* 2 >. i20) &&* (0 <=. i12 &&* 2 >. i12))) 0 1])))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rgather [1,2,2] (rreplicate 2 (rreplicate 2 (rscalar 0.0))) (\\[i29, i26, i22] -> [i26, i22]))))])) (\\[i44, i38, i33, i30, i31, i32] -> [ifF ((0 <=. 1 + i33 &&* 1 >. 1 + i33) &&* ((0 <=. 1 + i30 &&* 1 >. 1 + i30) &&* ((0 <=. 2 * i44 + i31 &&* 2 >. 2 * i44 + i31) &&* (0 <=. 2 * i38 + i32 &&* 2 >. 2 * i38 + i32)))) 0 1, i44, i38, i33, i30, i31, i32]) in rgather [1,1] w36 (\\[i43, i37] -> [i43, i37, 0, 0, 0, 0])))" + printAstPretty IM.empty (simplifyInline afcnn2T) + @?= "rreplicate 1 (rreplicate 1 (rreplicate 1 (rreplicate 1 (rscalar 0.0))))" maxPool2dUnpadded4 :: (ADReady target, GoodScalar r)