Skip to content

Commit

Permalink
Interpret trivial gathers and scatters better
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 2, 2024
1 parent 4c5ceea commit 6030107
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 15 deletions.
22 changes: 15 additions & 7 deletions src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,13 @@ import Unsafe.Coerce (unsafeCoerce)

import Data.Array.Mixed.Shape (pattern (:.%), pattern ZIX)
import Data.Array.Nested
( IxS (..)
( IxR (..)
, IxS (..)
, KnownShS (..)
, ListR (..)
, ListS (..)
, Rank
, ShR (..)
, pattern (:$:)
, pattern (:.$)
, pattern (:.:)
, pattern ZIR
, pattern ZIS
, pattern ZSR
, type (++)
)
import Data.Array.Nested qualified as Nested
Expand All @@ -55,6 +52,7 @@ import HordeAd.Core.HVector
import HordeAd.Core.HVectorOps
import HordeAd.Core.TensorClass
import HordeAd.Core.Types
import HordeAd.Util.SizedList

interpretAstPrimalRuntimeSpecialized
:: forall target n r.
Expand Down Expand Up @@ -511,6 +509,8 @@ interpretAst !env = \case
-- TODO: recognize when sum0 may be used instead, which is much cheaper
-- or should I do that in Delta instead? no, because tsum0R
-- is cheaper, too
AstScatter sh v (ZR, ix) ->
roneHot (takeShape sh) (interpretAst env v) (interpretAstPrimal env <$> ix)
AstScatter sh v (vars, ix) ->
let t1 = interpretAst env v
f2 = interpretLambdaIndexToIndex interpretAstPrimal env (vars, ix)
Expand All @@ -537,6 +537,8 @@ interpretAst !env = \case
AstReshape sh (AstLet var v (AstReplicate k t)) ->
interpretAst env (AstLet var v (AstReshape sh (AstReplicate k t)))
AstReshape sh v -> rreshape sh (interpretAst env v)
AstGather _ v (ZR, ix) ->
rindex (interpretAst env v) (interpretAstPrimal env <$> ix)
AstGather sh AstIotaR (vars, i :.: ZIR) ->
rbuild sh (interpretLambdaIndex interpretAst env
(vars, fromPrimal @s $ AstFromIntegralR $ AstRFromS $ AstFromScalar i))
Expand Down Expand Up @@ -758,6 +760,9 @@ interpretAst !env = \case
-- TODO: recognize when sum0 may be used instead, which is much cheaper
-- or should I do that in Delta instead? no, because tsum0R
-- is cheaper, too
AstScatterS @_ @p @sh v (ZS, ix) ->
gcastWith (unsafeCoerce Refl :: Take p sh ++ Drop p sh :~: sh)
soneHot (interpretAst env v) (interpretAstPrimal env <$> ix)
AstScatterS v (vars, ix) ->
let t1 = interpretAst env v
f2 = interpretLambdaIndexToIndexS interpretAstPrimal env (vars, ix)
Expand All @@ -781,6 +786,9 @@ interpretAst !env = \case
AstReverseS v -> sreverse (interpretAst env v)
AstTransposeS perm v -> stranspose perm $ interpretAst env v
AstReshapeS v -> sreshape (interpretAst env v)
AstGatherS @_ @p @sh v (ZS, ix) ->
gcastWith (unsafeCoerce Refl :: Take p sh ++ Drop p sh :~: sh)
sindex (interpretAst env v) (interpretAstPrimal env <$> ix)
AstGatherS @sh2 @p @sh @r AstIotaS (vars, i :.$ ZIS) ->
gcastWith (unsafeCoerce Refl :: Take (Rank sh2) sh2 :~: sh2)
$ gcastWith (unsafeCoerce Refl :: Drop (Rank sh2) sh2 :~: '[])
Expand Down
12 changes: 6 additions & 6 deletions test/simplified/TestAdaptorSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1013,13 +1013,13 @@ testReluSimplerPP4s2 = do
reluT2 (t, r) = reluS (t * sreplicate0N r)
let (artifactRev, _deltas) = revArtifactAdapt True reluT2 (srepl 128, srepl 42)
printArtifactPretty renames artifactRev
@?= "\\m11 x1 -> let m5 = sreshape (sreplicate (tproject2 m1)) ; m6 = tproject1 m1 * m5 ; m10 = sgather (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0]))) (\\[i7, i8] -> [i7, ifF (m6 !$ [i7, i8] <=. sscalar 0.0) 0 1]) ; m12 = m10 * m11 in tpair (m5 * m12, ssum (sreshape (tproject1 m1 * m12)))"
@?= "\\m11 x1 -> let m5 = sreshape (sreplicate (tproject2 m1)) ; m6 = tproject1 m1 * m5 ; m10 = sgather (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m6 !$ [i8, i9] <=. sscalar 0.0) 0 1]) ; m12 = m10 * m11 in tpair (m5 * m12, ssum (sreshape (tproject1 m1 * m12)))"
printArtifactPrimalPretty renames artifactRev
@?= "\\x1 -> let m5 = sreshape (sreplicate (tproject2 m1)) ; m6 = tproject1 m1 * m5 ; m10 = sgather (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0]))) (\\[i7, i8] -> [i7, ifF (m6 !$ [i7, i8] <=. sscalar 0.0) 0 1]) in m10 * m6"
@?= "\\x1 -> let m5 = sreshape (sreplicate (tproject2 m1)) ; m6 = tproject1 m1 * m5 ; m10 = sgather (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m6 !$ [i8, i9] <=. sscalar 0.0) 0 1]) in m10 * m6"
printArtifactPretty renames (simplifyArtifact artifactRev)
@?= "\\m11 x1 -> let m5 = sreshape (sreplicate (tproject2 m1)) ; m12 = sgather (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0]))) (\\[i7, i8] -> [i7, ifF (tproject1 m1 !$ [i7, i8] * m5 !$ [i7, i8] <=. sscalar 0.0) 0 1]) * m11 in tpair (m5 * m12, ssum (sreshape (tproject1 m1) * sreshape m12))"
@?= "\\m11 x1 -> let m5 = sreshape (sreplicate (tproject2 m1)) ; m12 = sgather (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (tproject1 m1 !$ [i8, i9] * m5 !$ [i8, i9] <=. sscalar 0.0) 0 1]) * m11 in tpair (m5 * m12, ssum (sreshape (tproject1 m1) * sreshape m12))"
printArtifactPrimalPretty renames (simplifyArtifact artifactRev)
@?= "\\x1 -> let m6 = tproject1 m1 * sreshape (sreplicate (tproject2 m1)) in sgather (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0]))) (\\[i7, i8] -> [i7, ifF (m6 !$ [i7, i8] <=. sscalar 0.0) 0 1]) * m6"
@?= "\\x1 -> let m6 = tproject1 m1 * sreshape (sreplicate (tproject2 m1)) in sgather (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m6 !$ [i8, i9] <=. sscalar 0.0) 0 1]) * m6"

testReluSimpler4s :: Assertion
testReluSimpler4s = do
Expand Down Expand Up @@ -2113,6 +2113,6 @@ testConcatBuild3PP2 = do
printArtifactSimple renames artifactRev
@?= "\\m8 x1 -> rreshape [3] (rreplicate 3 (rscalar 0.0))"
printArtifactPrimalSimple renames artifactRev
@?= "\\x1 -> rfromIntegral (rfromS (sgather (stranspose (sfromVector (fromList [sreplicate siota, quotF (stranspose (sreplicate siota)) (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0,0]) + siota + sreplicate (sscalar 1)))]))) (\\[i5, i6] -> [i5, i6, ifF (i6 >=. quotF i5 (1 + i6)) 0 1])))"
@?= "\\x1 -> rfromIntegral (rfromS (sgather (stranspose (sfromVector (fromList [sreplicate siota, quotF (stranspose (sreplicate siota)) (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0,0]) + siota + sreplicate (sscalar 1)))]))) (\\[i6, i7] -> [i6, i7, ifF (i7 >=. quotF i6 (1 + i7)) 0 1])))"
printArtifactPrimalSimple renames (simplifyArtifact artifactRev)
@?= "\\x1 -> rfromIntegral (rfromS (sgather (stranspose (sfromVector (fromList [sreplicate siota, quotF (stranspose (sreplicate siota)) (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0,0]) + siota + sreplicate (sscalar 1)))]))) (\\[i5, i6] -> [i5, i6, ifF (i6 >=. quotF i5 (1 + i6)) 0 1])))"
@?= "\\x1 -> rfromIntegral (rfromS (sgather (stranspose (sfromVector (fromList [sreplicate siota, quotF (stranspose (sreplicate siota)) (sreplicate (tconcrete (FTKS [2] FTKScalar) (sfromListLinear [2] [0,0]) + siota + sreplicate (sscalar 1)))]))) (\\[i6, i7] -> [i6, i7, ifF (i7 >=. quotF i6 (1 + i7)) 0 1])))"
6 changes: 4 additions & 2 deletions test/simplified/TestConvSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ testCNNOPP2 :: Assertion
testCNNOPP2 = do
resetVarCounter
printAstPretty IM.empty maxPool2dUnpadded2
@?= "rreplicate 1 (rreplicate 1 (let w38 = rtranspose [1,2,3,0] (rreplicate 1 (rgather [1,1,1,2,2] (rfromVector (fromList [rtranspose [1,2,0] (rreplicate 1 (let x27 = rreplicate 1 2 * 0 in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x28 = rreplicate 1 2 * 0 in rreplicate 1 (rreplicate 2 (rfromVector (fromList [tconcrete (FTKR [1,1,2,2] FTKScalar) (rfromListLinear [1,1,2,2] [1.0,1.0,1.0,1.0]) ! [rreplicate 1 (rreplicate 1 1), 0, i27, i28], rscalar 0.0]) ! [ifF (1 >. rreplicate 1 (rreplicate 1 1)) 0 1]))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rreplicate 2 (rreplicate 2 (rscalar 0.0)))))])) (\\[i46, i40, i36] -> [ifF (1 >. 1 + i36) 0 1, i46, i40, i36]))) in rgather [1,1] w38 (\\[i45, i39] -> [i45, i39, 0, 0, 0, 0])))"
@?= "rreplicate 1 (rreplicate 1 (let w38 = rtranspose [1,2,3,0] (rreplicate 1 (rgather [1,1,1,2,2] (rfromVector (fromList [rtranspose [1,2,0] (rreplicate 1 (let x27 = 2 * 0 in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x28 = 2 * 0 in rreplicate 1 (rreplicate 2 (rfromVector (fromList [tconcrete (FTKR [1,1,2,2] FTKScalar) (rfromListLinear [1,1,2,2] [1.0,1.0,1.0,1.0]) ! [1, 0, i27, i28], rscalar 0.0]) ! [ifF (1 >. 1) 0 1]))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rreplicate 2 (rreplicate 2 (rscalar 0.0)))))])) (\\[i46, i40, i36] -> [ifF (1 >. 1 + i36) 0 1, i46, i40, i36]))) in rgather [1,1] w38 (\\[i45, i39] -> [i45, i39, 0, 0, 0, 0])))"

maxPool2dUnpadded2
:: (target ~ AstTensor AstMethodLet FullSpan, r ~ Double)
Expand Down Expand Up @@ -717,7 +717,9 @@ testCNNOPP4 = do
afcnn2T :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
afcnn2T = maxPool2dUnpadded4 $ conv2dUnpadded4 blackGlyph
printAstPretty IM.empty afcnn2T
@?= "rreplicate 1 (rreplicate 1 (let w36 = rgather [1,1,1,1,2,2] (rfromVector (fromList [rtranspose [2,3,0,1] (rreplicate 1 (rreplicate 1 (let x20 = rreplicate 1 2 * 0 in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x12 = rreplicate 1 2 * 0 in rreplicate 1 (rreplicate 2 (tconcrete (FTKR [2] FTKScalar) (rfromListLinear [2] [7.0,0.0]) ! [ifF ((0 <=. rreplicate 1 (rreplicate 1 1) &&* 1 >. rreplicate 1 (rreplicate 1 1)) &&* ((0 <=. i20 &&* 2 >. i20) &&* (0 <=. i12 &&* 2 >. i12))) 0 1])))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rgather [1,2,2] (rreplicate 2 (rreplicate 2 (rscalar 0.0))) (\\[i29, i26, i22] -> [i26, i22]))))])) (\\[i44, i38, i33, i30, i31, i32] -> [ifF ((0 <=. 1 + i33 &&* 1 >. 1 + i33) &&* ((0 <=. 1 + i30 &&* 1 >. 1 + i30) &&* ((0 <=. 2 * i44 + i31 &&* 2 >. 2 * i44 + i31) &&* (0 <=. 2 * i38 + i32 &&* 2 >. 2 * i38 + i32)))) 0 1, i44, i38, i33, i30, i31, i32]) in rgather [1,1] w36 (\\[i43, i37] -> [i43, i37, 0, 0, 0, 0])))"
@?= "rreplicate 1 (rreplicate 1 (let w36 = rgather [1,1,1,1,2,2] (rfromVector (fromList [rtranspose [2,3,0,1] (rreplicate 1 (rreplicate 1 (let x20 = 2 * 0 in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x12 = 2 * 0 in rreplicate 1 (rreplicate 2 (tconcrete (FTKR [2] FTKScalar) (rfromListLinear [2] [7.0,0.0]) ! [ifF ((0 <=. 1 &&* 1 >. 1) &&* ((0 <=. i20 &&* 2 >. i20) &&* (0 <=. i12 &&* 2 >. i12))) 0 1])))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rgather [1,2,2] (rreplicate 2 (rreplicate 2 (rscalar 0.0))) (\\[i29, i26, i22] -> [i26, i22]))))])) (\\[i44, i38, i33, i30, i31, i32] -> [ifF ((0 <=. 1 + i33 &&* 1 >. 1 + i33) &&* ((0 <=. 1 + i30 &&* 1 >. 1 + i30) &&* ((0 <=. 2 * i44 + i31 &&* 2 >. 2 * i44 + i31) &&* (0 <=. 2 * i38 + i32 &&* 2 >. 2 * i38 + i32)))) 0 1, i44, i38, i33, i30, i31, i32]) in rgather [1,1] w36 (\\[i43, i37] -> [i43, i37, 0, 0, 0, 0])))"
printAstPretty IM.empty (simplifyInline afcnn2T)
@?= "rreplicate 1 (rreplicate 1 (rreplicate 1 (rreplicate 1 (rscalar 0.0))))"

maxPool2dUnpadded4
:: (ADReady target, GoodScalar r)
Expand Down

0 comments on commit 6030107

Please sign in to comment.