Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Feb 27, 2025
1 parent c6ac296 commit 487bfdf
Showing 1 changed file with 42 additions and 42 deletions.
84 changes: 42 additions & 42 deletions test/simplified/TestGatherSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ testGatherSimpPP1 :: Assertion
testGatherSimpPP1 = do
resetVarCounter
let !t1 = gatherNested1 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (mkAstVarName (FTKR [7, 2] FTKScalar) . intToAstVarId $ 100000000)
length (show t1) @?= 282
length (show t1) @?= 229
resetVarCounter
let !t2 = gather1 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (mkAstVarName (FTKR [7, 2] FTKScalar) . intToAstVarId $ 100000000)
length (show t2) @?= 282
length (show t2) @?= 229
length (show (simplifyInlineContract @(TKR 1 Float) t1))
@?= length (show (simplifyInlineContract @(TKR 1 Float) t2))

Expand Down Expand Up @@ -214,12 +214,12 @@ testGatherSimpPP2 :: Assertion
testGatherSimpPP2 = do
resetVarCounter
let !t1 = gatherNested2 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (mkAstVarName (FTKR [7, 2] FTKScalar) . intToAstVarId $ 100000000)
length (show t1) @?= 481
length (show t1) @?= 398
resetVarCounter
let !t2 = gather2 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (mkAstVarName (FTKR [7, 2] FTKScalar) . intToAstVarId $ 100000000)
length (show t2) @?= 411
length (show (simplifyInlineContract @(TKR 2 Float) t1)) @?= 411
length (show (simplifyInlineContract @(TKR 2 Float) t2)) @?= 411
length (show t2) @?= 338
length (show (simplifyInlineContract @(TKR 2 Float) t1)) @?= 338
length (show (simplifyInlineContract @(TKR 2 Float) t2)) @?= 338

gatherNested12 :: forall target r. (ADReady target, GoodScalar r)
=> target (TKR 2 r) -> target (TKR 2 r)
Expand Down Expand Up @@ -281,12 +281,12 @@ testGatherSimpPP12 :: Assertion
testGatherSimpPP12 = do
resetVarCounter
let !t1 = gatherNested12 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (mkAstVarName (FTKR [7, 2] FTKScalar) . intToAstVarId $ 100000000)
length (show t1) @?= 481
length (show t1) @?= 398
resetVarCounter
let !t2 = gather12 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (mkAstVarName (FTKR [7, 2] FTKScalar) . intToAstVarId $ 100000000)
length (show t2) @?= 411
length (show (simplifyInlineContract @(TKR 2 Float) t1)) @?= 411
length (show (simplifyInlineContract @(TKR 2 Float) t2)) @?= 411
length (show t2) @?= 338
length (show (simplifyInlineContract @(TKR 2 Float) t1)) @?= 338
length (show (simplifyInlineContract @(TKR 2 Float) t2)) @?= 338

gatherReshape22 :: forall target r. (ADReady target, GoodScalar r)
=> target (TKR 2 r) -> target (TKR 2 r)
Expand Down Expand Up @@ -319,13 +319,13 @@ testGatherSimpPP22 :: Assertion
testGatherSimpPP22 = do
resetVarCounter
let !t1 = gatherReshape22 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (mkAstVarName (FTKR [6, 2] FTKScalar) . intToAstVarId $ 100000000)
length (show t1) @?= 126
length (show (simplifyInlineContract @(TKR 2 Float) t1)) @?= 126
length (show t1) @?= 103
length (show (simplifyInlineContract @(TKR 2 Float) t1)) @?= 103
resetVarCounter
let !t2 = rreshape @(AstTensor AstMethodLet PrimalSpan) @_ @2 @2 [2, 6]
$ AstVar (mkAstVarName (FTKR [6, 2] FTKScalar) . intToAstVarId $ 100000000)
length (show t2) @?= 126
length (show (simplifyInlineContract @(TKR 2 Float) t2)) @?= 126
length (show t2) @?= 103
length (show (simplifyInlineContract @(TKR 2 Float) t2)) @?= 103

testGatherSimpPP23 :: Assertion
testGatherSimpPP23 = do
Expand All @@ -334,15 +334,15 @@ testGatherSimpPP23 = do
gatherReshape22 @(AstTensor AstMethodLet PrimalSpan)
(t * rreplicate0N [6, 2] (rfromIndex0 i))))
$ AstVar (mkAstVarName (FTKR [6, 2] FTKScalar) . intToAstVarId $ 100000000)
length (show t1) @?= 335
length (show (simplifyInlineContract @(TKR 3 Float) t1)) @?= 335
length (show t1) @?= 312
length (show (simplifyInlineContract @(TKR 3 Float) t1)) @?= 312
resetVarCounter
let !t2 = (\t -> rbuild1 4 (\i ->
rreshape @(AstTensor AstMethodLet PrimalSpan) @_ @2 @2 [2, 6]
(t * rreplicate0N [6, 2] (rfromIndex0 i))))
$ AstVar (mkAstVarName (FTKR [6, 2] FTKScalar) . intToAstVarId $ 100000000)
length (show t2) @?= 335
length (show (simplifyInlineContract @(TKR 3 Float) t2)) @?= 335
length (show t2) @?= 312
length (show (simplifyInlineContract @(TKR 3 Float) t2)) @?= 312

-- Depending on if and how transpose it desugared, this may or may not result
-- in dozens of nested gathers that should vanish after simplification.
Expand Down Expand Up @@ -454,31 +454,31 @@ testGatherSimpPP33 = do
resetVarCounter
let !t1 = gatherTranspose33 @(AstTensor AstMethodLet PrimalSpan)
$ AstVar (mkAstVarName (FTKR [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] FTKScalar) . intToAstVarId $ 100000000)
length (show t1) @?= 1031
length (show (simplifyInlineContract @(TKR 2 Float) t1)) @?= 830
length (show t1) @?= 992
length (show (simplifyInlineContract @(TKR 2 Float) t1)) @?= 791
resetVarCounter
let !t2 = (\t -> rmatmul2 (rreshape [6, 8] (rconcrete $ unRepN t48))
(rreshape @(AstTensor AstMethodLet PrimalSpan) @_ @10 [8, 16] t))
$ AstVar (mkAstVarName (FTKR [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] FTKScalar) . intToAstVarId $ 100000000)
length (show t2) @?= 749
length (show (simplifyInlineContract @(TKR 2 Float) t2)) @?= 509
length (show t2) @?= 710
length (show (simplifyInlineContract @(TKR 2 Float) t2)) @?= 470

testGatherSimpPP34 :: Assertion
testGatherSimpPP34 = do
resetVarCounter
let !t1 = (\t -> rbuild1 4 (\i ->
gatherTranspose33 @(AstTensor AstMethodLet PrimalSpan) (t * rreplicate0N [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] (rfromIndex0 i))))
$ AstVar (mkAstVarName (FTKR [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] FTKScalar) . intToAstVarId $ 100000000)
length (show t1) @?= 1585
length (show (simplifyInlineContract @(TKR 3 Float) t1)) @?= 1585
length (show t1) @?= 1546
length (show (simplifyInlineContract @(TKR 3 Float) t1)) @?= 1546
resetVarCounter
let !t2 = (\t -> rbuild1 4 (\i ->
(\t' -> rmatmul2 (rreshape [6, 8] (rconcrete $ unRepN t48))
(rreshape @(AstTensor AstMethodLet PrimalSpan) @_ @10 [8, 16] t'))
(t * rreplicate0N [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] (rfromIndex0 i))))
$ AstVar (mkAstVarName (FTKR [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] FTKScalar) . intToAstVarId $ 100000000)
length (show t2) @?= 1082
length (show (simplifyInlineContract @(TKR 3 Float) t2)) @?= 1082
length (show t2) @?= 1043
length (show (simplifyInlineContract @(TKR 3 Float) t2)) @?= 1043

-- scatters instead of gathers

Expand Down Expand Up @@ -538,12 +538,12 @@ testScatterSimpPP1 :: Assertion
testScatterSimpPP1 = do
resetVarCounter
let !t1 = scatterNested1 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (mkAstVarName (FTKR [7, 2] FTKScalar) . intToAstVarId $ 100000000)
length (show t1) @?= 394
length (show t1) @?= 341
resetVarCounter
let !t2 = scatter1 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (mkAstVarName (FTKR [7, 2] FTKScalar) . intToAstVarId $ 100000000)
length (show t2) @?= 481
length (show (simplifyInlineContract @(TKR 1 Float) t1)) @?= 394
length (show (simplifyInlineContract @(TKR 1 Float) t2)) @?= 481
length (show t2) @?= 418
length (show (simplifyInlineContract @(TKR 1 Float) t1)) @?= 341
length (show (simplifyInlineContract @(TKR 1 Float) t2)) @?= 418

scatterNested2 :: forall target r. (ADReady target, GoodScalar r)
=> target (TKR 2 r) -> target (TKR 2 r)
Expand Down Expand Up @@ -604,12 +604,12 @@ testScatterSimpPP2 :: Assertion
testScatterSimpPP2 = do
resetVarCounter
let !t1 = scatterNested2 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (mkAstVarName (FTKR [7, 2] FTKScalar) . intToAstVarId $ 100000000)
length (show t1) @?= 1182
length (show t1) @?= 1019
resetVarCounter
let !t2 = scatter2 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (mkAstVarName (FTKR [7, 2] FTKScalar) . intToAstVarId $ 100000000)
length (show t2) @?= 765
length (show (simplifyInlineContract @(TKR 2 Float) t1)) @?= 1182
length (show (simplifyInlineContract @(TKR 2 Float) t2)) @?= 765
length (show t2) @?= 642
length (show (simplifyInlineContract @(TKR 2 Float) t1)) @?= 1019
length (show (simplifyInlineContract @(TKR 2 Float) t2)) @?= 642

scatterNested12 :: forall target r. (ADReady target, GoodScalar r)
=> target (TKR 2 r) -> target (TKR 2 r)
Expand Down Expand Up @@ -672,12 +672,12 @@ testScatterSimpPP12 :: Assertion
testScatterSimpPP12 = do
resetVarCounter
let !t1 = scatterNested12 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (mkAstVarName (FTKR [7, 2] FTKScalar) . intToAstVarId $ 100000000)
length (show t1) @?= 1017
length (show t1) @?= 874
resetVarCounter
let !t2 = scatter12 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (mkAstVarName (FTKR [7, 2] FTKScalar) . intToAstVarId $ 100000000)
length (show t2) @?= 765
length (show (simplifyInlineContract @(TKR 2 Float) t1)) @?= 1017
length (show (simplifyInlineContract @(TKR 2 Float) t2)) @?= 765
length (show t2) @?= 642
length (show (simplifyInlineContract @(TKR 2 Float) t1)) @?= 874
length (show (simplifyInlineContract @(TKR 2 Float) t2)) @?= 642

foo :: RealFloatF a => (a,a,a) -> a
foo (x,y,z) =
Expand Down Expand Up @@ -713,10 +713,10 @@ testReluSimpPP = do
resetVarCounter
let !t1 = barRelu10xSlower @(AstTensor AstMethodLet PrimalSpan)
$ AstVar (mkAstVarName (FTKR [1,2,2,1,2,2,2,2,2,1] FTKScalar) . intToAstVarId $ 100000000)
length (show t1) @?= 19502
length (show (simplifyInlineContract @(TKR 10 Float) t1)) @?= 19502
length (show t1) @?= 17454
length (show (simplifyInlineContract @(TKR 10 Float) t1)) @?= 17454
resetVarCounter
let !t2 = barRelu @(AstTensor AstMethodLet PrimalSpan)
$ AstVar (mkAstVarName (FTKR [1,2,2,1,2,2,2,2,2,1] FTKScalar) . intToAstVarId $ 100000000)
length (show t2) @?= 12334
length (show (simplifyInlineContract @(TKR 10 Float) t2)) @?= 19502
length (show t2) @?= 10286
length (show (simplifyInlineContract @(TKR 10 Float) t2)) @?= 17454

0 comments on commit 487bfdf

Please sign in to comment.