Skip to content

Commit

Permalink
Fix tests by making BuildTensorKind identity on scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Nov 29, 2024
1 parent 8c76466 commit 6db3aa1
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 24 deletions.
9 changes: 4 additions & 5 deletions src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -419,20 +419,20 @@ interpretAst !env = \case
(AstSum (AstN2R TimesOp (AstTranspose tperm t)
(AstTranspose uperm u))))
AstSum (AstN2R TimesOp (AstTranspose tperm t)
(AstLet varu vu (AstTranspose uperm u))) ->
(AstLet varu vu (AstTranspose uperm u))) ->
interpretAst env
(AstLet varu vu
(AstSum (AstN2R TimesOp (AstTranspose tperm t)
(AstTranspose uperm u))))
AstSum (AstN2R TimesOp (AstLet vart vt (AstTranspose tperm t))
(AstLet varu vu (AstTranspose uperm u))) ->
(AstLet varu vu (AstTranspose uperm u))) ->
interpretAst env
(AstLet vart vt (AstLet varu vu
(AstSum (AstN2R TimesOp (AstTranspose tperm t)
(AstTranspose uperm u)))))
(AstTranspose uperm u)))))
AstSum @n @r
v@(AstN2R TimesOp (AstTranspose tperm (AstReplicate @yt _tk t))
(AstTranspose uperm (AstReplicate @yu _uk u)))
(AstTranspose uperm (AstReplicate @yu _uk u)))
| Just Refl <- sameNat (Proxy @n) (Proxy @2) ->
case (stensorKind @yt, stensorKind @yu) of
(STKR{}, STKR{}) ->
Expand Down Expand Up @@ -487,7 +487,6 @@ interpretAst !env = \case
-- ttr
-- $ interpretMatmul2 (AstTranspose [1, 0] u) (AstTranspose [1, 0] t)
_ -> rsum $ interpretAst env v
_ -> error "interpretAst: type family BuildTensorKind stuck at TKScalar"
AstSum @n (AstN2R TimesOp t u)
| Just Refl <- sameNat (Proxy @n) (Proxy @0) ->
let t1 = interpretAst env t
Expand Down
33 changes: 28 additions & 5 deletions src/HordeAd/Core/AstVectorize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,11 @@ build1V snat@SNat (var, v0) =
| Dict <- lemTensorKindOfBuild snat (stensorKind @x)
, Dict <- lemTensorKindOfBuild snat (stensorKind @y) -> traceRule $
astProject2 (build1V snat (var, t))
Ast.AstVar _ var2 ->
Ast.AstVar _ var2 -> traceRule $
if varNameToAstVarId var2 == varNameToAstVarId var
then error "build1V: AstVar: building over scalars is undefined"
then case isTensorInt v0 of
Just Refl -> 0 -- TODO: ???
_ -> error "build1V: build variable is not an index variable"
else error "build1V: AstVar can't contain other free variables"
Ast.AstPrimalPart v
| Dict <- lemTensorKindOfBuild snat (stensorKind @y) -> traceRule $
Expand Down Expand Up @@ -239,6 +241,7 @@ build1V snat@SNat (var, v0) =
-> AstTensor AstMethodLet s (BuildTensorKind k
(BuildTensorKind k2 z))
repl2Stk stk u = case stk of
STKScalar{} -> u
STKR SNat STKScalar{} -> astTr $ astReplicate snat2 u
STKS sh STKScalar{} -> withKnownShS sh $ astTrS $ astReplicate snat2 u
STKX sh STKScalar{} -> withKnownShX sh $ astTrX $ astReplicate snat2 u
Expand Down Expand Up @@ -287,22 +290,41 @@ build1V snat@SNat (var, v0) =
Ast.AstIota ->
error "build1V: AstIota can't have free index variables"

Ast.AstN1 opCode u -> traceRule $
Ast.AstN1 opCode (build1V snat (var, u))
Ast.AstN2 opCode u v -> traceRule $
Ast.AstN2 opCode (build1VOccurenceUnknown snat (var, u))
(build1VOccurenceUnknown snat (var, v))
-- we permit duplicated bindings, because they can't easily
-- be substituted into one another unlike. e.g., inside a let,
-- which may get inlined
Ast.AstR1 opCode u -> traceRule $
Ast.AstR1 opCode (build1V snat (var, u))
Ast.AstR2 opCode u v -> traceRule $
Ast.AstR2 opCode (build1VOccurenceUnknown snat (var, u))
(build1VOccurenceUnknown snat (var, v))
Ast.AstI2 opCode u v -> traceRule $
Ast.AstI2 opCode (build1VOccurenceUnknown snat (var, u))
(build1VOccurenceUnknown snat (var, v))
Ast.AstSumOfList args -> traceRule $
astSumOfList $ map (\v -> build1VOccurenceUnknown snat (var, v)) args

Ast.AstN1R opCode u -> traceRule $
Ast.AstN1R opCode (build1V snat (var, u))
Ast.AstN2R opCode u v -> traceRule $
Ast.AstN2R opCode (build1VOccurenceUnknown snat (var, u))
(build1VOccurenceUnknown snat (var, v))
(build1VOccurenceUnknown snat (var, v))
-- we permit duplicated bindings, because they can't easily
-- be substituted into one another unlike. e.g., inside a let,
-- which may get inlined
Ast.AstR1R opCode u -> traceRule $
Ast.AstR1R opCode (build1V snat (var, u))
Ast.AstR2R opCode u v -> traceRule $
Ast.AstR2R opCode (build1VOccurenceUnknown snat (var, u))
(build1VOccurenceUnknown snat (var, v))
(build1VOccurenceUnknown snat (var, v))
Ast.AstI2R opCode u v -> traceRule $
Ast.AstI2R opCode (build1VOccurenceUnknown snat (var, u))
(build1VOccurenceUnknown snat (var, v))
(build1VOccurenceUnknown snat (var, v))
Ast.AstSumOfListR args -> traceRule $
astSumOfListR $ map (\v -> build1VOccurenceUnknown snat (var, v)) args

Expand Down Expand Up @@ -711,6 +733,7 @@ substProjRep snat@SNat var ftk2 var1 v
-> FullTensorKind y4
-> AstTensor AstMethodLet s2 y4
projection prVar = \case
FTKScalar -> prVar
FTKR sh FTKScalar | SNat <- shrRank sh ->
Ast.AstIndex prVar (Ast.AstIntVar var :.: ZIR)
FTKS sh FTKScalar -> withKnownShS sh
Expand Down
5 changes: 3 additions & 2 deletions src/HordeAd/Core/HVector.hs
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,9 @@ lemTensorKindOfF = \case
buildFullTensorKind :: SNat k -> FullTensorKind y
-> FullTensorKind (BuildTensorKind k y)
buildFullTensorKind snat@SNat = \case
FTKScalar ->
error "buildFullTensorKind: type family BuildTensorKind stuck at TKScalar"
FTKScalar -> FTKScalar
-- TODO? FTKScalar ->
-- error "buildFullTensorKind: type family BuildTensorKind stuck at TKScalar"
FTKR sh x -> FTKR (sNatValue snat :$: sh) x
FTKS sh x -> FTKS (snat :$$ sh) x
FTKX sh x -> FTKX (SKnown snat :$% sh) x
Expand Down
4 changes: 2 additions & 2 deletions src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ class LetTensor (target :: Target) where
-> target z
-> target (BuildTensorKind k z)
treplicate snat@SNat stk u = case stk of
STKScalar{} ->
error "treplicate: type family BuildTensorKind stuck at TKScalar"
STKScalar{} -> u
-- TODO? error "treplicate: type family BuildTensorKind stuck at TKScalar"
STKR SNat STKScalar{} -> rreplicate (sNatValue snat) u
STKS sh STKScalar{} -> withKnownShS sh $ sreplicate u
-- TODO: STKS sh (STKS _ STKScalar{}) -> withKnownShS sh $ sreplicate u
Expand Down
6 changes: 4 additions & 2 deletions src/HordeAd/Core/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ sameTK y1 y2 = case (y1, y2) of
_ -> Nothing

type family BuildTensorKind k tk where
BuildTensorKind k (TKScalar r) = TKScalar r -- TODO: say why on Earth
BuildTensorKind k (TKR2 n r) = TKR2 (1 + n) r
BuildTensorKind k (TKS2 sh r) = TKS2 (k : sh) r
BuildTensorKind k (TKX2 sh r) = TKX2 (Just k : sh) r
Expand All @@ -278,8 +279,9 @@ type family BuildTensorKind k tk where
lemTensorKindOfBuild :: SNat k -> STensorKindType y
-> Dict TensorKind (BuildTensorKind k y)
lemTensorKindOfBuild snat@SNat = \case
STKScalar{} ->
error "lemTensorKindOfBuild: type family BuildTensorKind stuck at TKScalar"
STKScalar{} -> Dict
-- TODO? STKScalar{} ->
-- error "lemTensorKindOfBuild: type family BuildTensorKind stuck at TKScalar"
STKR SNat x -> case lemTensorKindOfS x of
Dict -> Dict
STKS sh x -> case lemTensorKindOfS x of
Expand Down
4 changes: 2 additions & 2 deletions test/simplified/TestConvSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ testCNNOPP2 :: Assertion
testCNNOPP2 = do
resetVarCounter
printAstPretty IM.empty maxPool2dUnpadded2
@?= "rreplicate 1 (rreplicate 1 (let w52 = rtranspose [1,2,3,0] (rreplicate 1 (rgather [1,1,1,2,2] (rfromVector (fromList [let w26 = stranspose (sreplicate (sreplicate (sreplicate (sreplicate (sreplicate (sscalar 1) + siota))))) ; w27 = stranspose (sreplicate (sreplicate (sreplicate (stranspose (sreplicate (sreplicate (sscalar 2) * siota)) + sreplicate siota)))) ; w28 = stranspose (sreplicate (sreplicate (sreplicate (stranspose (sreplicate (sreplicate (sscalar 2) * siota)) + sreplicate siota)))) in rgather [1,1,1,2,2] (rfromVector (fromList [let w9 = sgather (sgather (sgather (sgather (sgather w26 (\\[i69] -> [i69])) (\\[i73, i57] -> [i73, i57])) (\\[i74, i61, i46] -> [i74, i61, i46])) (\\[i75, i62, i50, i38] -> [i75, i62, i50, i38])) (\\[i76, i63, i51, i42, i32] -> [i76, i63, i51, i42, i32]) in rgather [1,1,1,2,2] (tconcrete (FTKR [1,1,2,2] FTKScalar) (rfromListLinear [1,1,2,2] [1.0,1.0,1.0,1.0])) (\\[i72, i60, i49, i41, i34] -> [w9 !$ [i72, i60, i49, i41, i34], 0, w27 !$ [i72, i60, i49, i41, i34], w28 !$ [i72, i60, i49, i41, i34]]), rreplicate 1 (rreplicate 1 (rreplicate 1 (rreplicate 2 (rreplicate 2 (rscalar 0.0)))))])) (\\[i70, i58, i47, i39, i33] -> [ifF (1 >. w26 !$ [i70, i58, i47, i39, i33]) 0 1, i70, i58, i47, i39, i33]), rreplicate 1 (rreplicate 1 (rreplicate 1 (rreplicate 2 (rreplicate 2 (rscalar 0.0)))))])) (\\[i66, i54, i43] -> [ifF (1 >. 1 + i43) 0 1, i66, i54, i43]))) in rgather [1,1] w52 (\\[i65, i53] -> [i65, i53, 0, 0, 0, 0])))"
@?= "rreplicate 1 (rreplicate 1 (let w38 = rtranspose [1,2,3,0] (rreplicate 1 (rgather [1,1,1,2,2] (rfromVector (fromList [rtranspose [1,2,0] (rreplicate 1 (let x27 = rreplicate 1 2 * 0 in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x28 = rreplicate 1 2 * 0 in rreplicate 1 (rreplicate 2 (rfromVector (fromList [tconcrete (FTKR [1,1,2,2] FTKScalar) (rfromListLinear [1,1,2,2] [1.0,1.0,1.0,1.0]) ! [rreplicate 1 (rreplicate 1 1), 0, i27, i28], rscalar 0.0]) ! [ifF (1 >. rreplicate 1 (rreplicate 1 1)) 0 1]))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rreplicate 2 (rreplicate 2 (rscalar 0.0)))))])) (\\[i46, i40, i36] -> [ifF (1 >. 1 + i36) 0 1, i46, i40, i36]))) in rgather [1,1] w38 (\\[i45, i39] -> [i45, i39, 0, 0, 0, 0])))"

maxPool2dUnpadded2
:: (target ~ AstTensor AstMethodLet FullSpan, r ~ Double)
Expand Down Expand Up @@ -717,7 +717,7 @@ testCNNOPP4 = do
afcnn2T :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
afcnn2T = maxPool2dUnpadded4 $ conv2dUnpadded4 blackGlyph
printAstPretty IM.empty afcnn2T
@?= "rreplicate 1 (rreplicate 1 (let w41 = rgather [1,1,1,1,2,2] (rfromVector (fromList [let w21 = stranspose (sreplicate (sreplicate (sreplicate (sreplicate (sreplicate (sscalar 1) + siota))))) ; w20 = stranspose (sreplicate (sreplicate (sreplicate (stranspose (sreplicate (sreplicate (sscalar 2) * siota)) + sreplicate siota)))) ; w12 = stranspose (sreplicate (sreplicate (sreplicate (stranspose (sreplicate (sreplicate (sscalar 2) * siota)) + sreplicate siota)))) in rgather [1,1,1,1,2,2] (tconcrete (FTKR [2] FTKScalar) (rfromListLinear [2] [7.0,0.0])) (\\[i54, i47, i36, i31, i30, i25] -> [ifF ((0 <=. w21 !$ [i54, i47, i36, i30, i25] &&* 1 >. w21 !$ [i54, i47, i36, i30, i25]) &&* ((0 <=. w20 !$ [i54, i47, i36, i30, i25] &&* 2 >. w20 !$ [i54, i47, i36, i30, i25]) &&* (0 <=. w12 !$ [i54, i47, i36, i30, i25] &&* 2 >. w12 !$ [i54, i47, i36, i30, i25]))) 0 1]), rreplicate 1 (rreplicate 1 (rreplicate 1 (rgather [1,2,2] (rreplicate 2 (rreplicate 2 (rscalar 0.0))) (\\[i31, i26, i22] -> [i26, i22]))))])) (\\[i50, i43, i35, i32, i33, i34] -> [ifF ((0 <=. 1 + i35 &&* 1 >. 1 + i35) &&* ((0 <=. 1 + i32 &&* 1 >. 1 + i32) &&* ((0 <=. 2 * i50 + i33 &&* 2 >. 2 * i50 + i33) &&* (0 <=. 2 * i43 + i34 &&* 2 >. 2 * i43 + i34)))) 0 1, i50, i43, i35, i32, i33, i34]) in rgather [1,1] w41 (\\[i49, i42] -> [i49, i42, 0, 0, 0, 0])))"
@?= "rreplicate 1 (rreplicate 1 (let w36 = rgather [1,1,1,1,2,2] (rfromVector (fromList [rtranspose [2,3,0,1] (rreplicate 1 (rreplicate 1 (let x20 = rreplicate 1 2 * 0 in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x12 = rreplicate 1 2 * 0 in rreplicate 1 (rreplicate 2 (tconcrete (FTKR [2] FTKScalar) (rfromListLinear [2] [7.0,0.0]) ! [ifF ((0 <=. rreplicate 1 (rreplicate 1 1) &&* 1 >. rreplicate 1 (rreplicate 1 1)) &&* ((0 <=. i20 &&* 2 >. i20) &&* (0 <=. i12 &&* 2 >. i12))) 0 1])))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rgather [1,2,2] (rreplicate 2 (rreplicate 2 (rscalar 0.0))) (\\[i29, i26, i22] -> [i26, i22]))))])) (\\[i44, i38, i33, i30, i31, i32] -> [ifF ((0 <=. 1 + i33 &&* 1 >. 1 + i33) &&* ((0 <=. 1 + i30 &&* 1 >. 1 + i30) &&* ((0 <=. 2 * i44 + i31 &&* 2 >. 2 * i44 + i31) &&* (0 <=. 2 * i38 + i32 &&* 2 >. 2 * i38 + i32)))) 0 1, i44, i38, i33, i30, i31, i32]) in rgather [1,1] w36 (\\[i43, i37] -> [i43, i37, 0, 0, 0, 0])))"

maxPool2dUnpadded4
:: (ADReady target, GoodScalar r)
Expand Down
Loading

0 comments on commit 6db3aa1

Please sign in to comment.