Skip to content

Commit

Permalink
Simplify and fix vectorization of AstReplicate
Browse files Browse the repository at this point in the history
Mikolaj committed Dec 15, 2024
1 parent 1a376fb commit bce8ba6
Showing 4 changed files with 10 additions and 43 deletions.
40 changes: 4 additions & 36 deletions src/HordeAd/Core/AstVectorize.hs
Original file line number Diff line number Diff line change
@@ -237,42 +237,10 @@ build1V snat@SNat (var, v0) =
STKProduct{} -> error "TODO"
STKUntyped -> error "TODO"
_ -> error "TODO"
Ast.AstReplicate @y2 snat2@(SNat @k2) v -> traceRule $
let repl2Stk :: forall z.
STensorKindType z
-> AstTensor AstMethodLet s (BuildTensorKind k z)
-> AstTensor AstMethodLet s (BuildTensorKind k
(BuildTensorKind k2 z))
repl2Stk stk u = case stk of
STKScalar{} -> u
STKR SNat STKScalar{} -> astTr $ astReplicate snat2 u
STKS sh STKScalar{} -> withKnownShS sh $ astTrS $ astReplicate snat2 u
STKX sh STKScalar{} -> withKnownShX sh $ astTrX $ astReplicate snat2 u
STKProduct @z1 @z2 stk1 stk2
| (Dict, Dict) <- lemTensorKind1OfBuild snat stk1
, Dict <- lemTensorKindOfBuild snat2 stk1
, (Dict, Dict) <- lemTensorKind1OfBuild
snat (stensorKind @(BuildTensorKind k2 z1))
, (Dict, Dict) <- lemTensorKind1OfBuild snat stk2
, Dict <- lemTensorKindOfBuild snat2 stk2
, (Dict, Dict) <- lemTensorKind1OfBuild
snat (stensorKind @(BuildTensorKind k2 z2)) ->
astLetFun u $ \ !uShared ->
let (u1, u2) = (astProject1 uShared, astProject2 uShared)
in astPair (repl2Stk stk1 u1) (repl2Stk stk2 u2)
STKUntyped ->
astTrAstHVector
$ fun1DToAst (shapeAstHVector u) $ \ !vars !asts ->
astLetHVectorIn
vars
u
(Ast.AstMkHVector
$ replicate1HVectorF
(\k3 -> withSNat k3 $ \snat3 -> astReplicate snat3)
(astReplicate SNat)
snat2 asts)
_ -> error "TODO"
in repl2Stk (stensorKind @y2) (build1V snat (var, v))
Ast.AstReplicate @y2 snat2@(SNat @k2) v
| Dict <- lemTensorKindOfBuild snat (stensorKind @y2) -> traceRule $
astTrGeneral @k2 (stensorKind @y2) (astReplicate snat2
$ build1V snat (var, v))
Ast.AstBuild1 snat2 (var2, v2) ->
build1VOccurenceUnknown snat (var, build1VOccurenceUnknown snat2 (var2, v2))
-- happens only when testing and mixing different pipelines
1 change: 0 additions & 1 deletion src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
@@ -22,7 +22,6 @@ module HordeAd.Core.TensorClass
import Prelude

import Data.Kind (Constraint, Type)
import Data.List (foldl')
import Data.List.NonEmpty (NonEmpty)
import Data.List.NonEmpty qualified as NonEmpty
import Data.Proxy (Proxy (Proxy))
4 changes: 2 additions & 2 deletions test/simplified/TestConvSimplified.hs
Original file line number Diff line number Diff line change
@@ -615,7 +615,7 @@ testCNNOPP2 :: Assertion
testCNNOPP2 = do
resetVarCounter
printAstPretty IM.empty maxPool2dUnpadded2
@?= "rreplicate 1 (rreplicate 1 (let w38 = rtranspose [1,2,3,0] (rreplicate 1 (rgather [1,1,1,2,2] (rfromVector (fromList [rtranspose [1,2,0] (rreplicate 1 (let x27 = 2 * 0 in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x28 = 2 * 0 in rreplicate 1 (rreplicate 2 (rfromVector (fromList [tconcrete (FTKR [1,1,2,2] FTKScalar) (rfromListLinear [1,1,2,2] [1.0,1.0,1.0,1.0]) ! [1, 0, i27, i28], rscalar 0.0]) ! [ifF (1 >. 1) 0 1]))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rreplicate 2 (rreplicate 2 (rscalar 0.0)))))])) (\\[i46, i40, i36] -> [ifF (1 >. 1 + i36) 0 1, i46, i40, i36]))) in rgather [1,1] w38 (\\[i45, i39] -> [i45, i39, 0, 0, 0, 0])))"
@?= "rreplicate 1 (rreplicate 1 (let w38 = rtranspose [1,2,3,0] (rreplicate 1 (rgather [1,1,1,2,2] (rfromVector (fromList [let x26 = 0 + 1 in rtranspose [1,2,0] (rreplicate 1 (let x27 = 2 * 0 in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x28 = 2 * 0 in rreplicate 1 (rreplicate 2 (rfromVector (fromList [tconcrete (FTKR [1,1,2,2] FTKScalar) (rfromListLinear [1,1,2,2] [1.0,1.0,1.0,1.0]) ! [i26, 0, i27, i28], rscalar 0.0]) ! [ifF (1 >. i26) 0 1]))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rreplicate 2 (rreplicate 2 (rscalar 0.0)))))])) (\\[i46, i40, i36] -> [ifF (1 >. 1 + i36) 0 1, i46, i40, i36]))) in rgather [1,1] w38 (\\[i45, i39] -> [i45, i39, 0, 0, 0, 0])))"

maxPool2dUnpadded2
:: (target ~ AstTensor AstMethodLet FullSpan, r ~ Double)
@@ -717,7 +717,7 @@ testCNNOPP4 = do
afcnn2T :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
afcnn2T = maxPool2dUnpadded4 $ conv2dUnpadded4 blackGlyph
printAstPretty IM.empty afcnn2T
@?= "rreplicate 1 (rreplicate 1 (let w36 = rgather [1,1,1,1,2,2] (rfromVector (fromList [rtranspose [2,3,0,1] (rreplicate 1 (rreplicate 1 (let x20 = 2 * 0 in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x12 = 2 * 0 in rreplicate 1 (rreplicate 2 (tconcrete (FTKR [2] FTKScalar) (rfromListLinear [2] [7.0,0.0]) ! [ifF ((0 <=. 1 &&* 1 >. 1) &&* ((0 <=. i20 &&* 2 >. i20) &&* (0 <=. i12 &&* 2 >. i12))) 0 1])))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rgather [1,2,2] (rreplicate 2 (rreplicate 2 (rscalar 0.0))) (\\[i29, i26, i22] -> [i26, i22]))))])) (\\[i44, i38, i33, i30, i31, i32] -> [ifF ((0 <=. 1 + i33 &&* 1 >. 1 + i33) &&* ((0 <=. 1 + i30 &&* 1 >. 1 + i30) &&* ((0 <=. 2 * i44 + i31 &&* 2 >. 2 * i44 + i31) &&* (0 <=. 2 * i38 + i32 &&* 2 >. 2 * i38 + i32)))) 0 1, i44, i38, i33, i30, i31, i32]) in rgather [1,1] w36 (\\[i43, i37] -> [i43, i37, 0, 0, 0, 0])))"
@?= "rreplicate 1 (rreplicate 1 (let w36 = rgather [1,1,1,1,2,2] (rfromVector (fromList [let x21 = 0 + 1 in rtranspose [2,3,0,1] (rreplicate 1 (rreplicate 1 (let x20 = 2 * 0 in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x12 = 2 * 0 in rreplicate 1 (rreplicate 2 (tconcrete (FTKR [2] FTKScalar) (rfromListLinear [2] [7.0,0.0]) ! [ifF ((0 <=. i21 &&* 1 >. i21) &&* ((0 <=. i20 &&* 2 >. i20) &&* (0 <=. i12 &&* 2 >. i12))) 0 1])))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rgather [1,2,2] (rreplicate 2 (rreplicate 2 (rscalar 0.0))) (\\[i29, i26, i22] -> [i26, i22]))))])) (\\[i44, i38, i33, i30, i31, i32] -> [ifF ((0 <=. 1 + i33 &&* 1 >. 1 + i33) &&* ((0 <=. 1 + i30 &&* 1 >. 1 + i30) &&* ((0 <=. 2 * i44 + i31 &&* 2 >. 2 * i44 + i31) &&* (0 <=. 2 * i38 + i32 &&* 2 >. 2 * i38 + i32)))) 0 1, i44, i38, i33, i30, i31, i32]) in rgather [1,1] w36 (\\[i43, i37] -> [i43, i37, 0, 0, 0, 0])))"
printAstPretty IM.empty (simplifyInline afcnn2T)
@?= "rreplicate 1 (rreplicate 1 (rreplicate 1 (rreplicate 1 (rscalar 0.0))))"

8 changes: 4 additions & 4 deletions test/simplified/TestMnistCNNR.hs

Large diffs are not rendered by default.

0 comments on commit bce8ba6

Please sign in to comment.