Skip to content

Commit

Permalink
Re-enable most tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Jan 25, 2025
1 parent 7a8cf98 commit 1799d18
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 47 deletions.
31 changes: 13 additions & 18 deletions test/simplified/TestAdaptorSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import Prelude

import Data.Int (Int64)
import Data.IntMap.Strict qualified as IM
import Data.List (foldl1')
import Data.List.NonEmpty qualified as NonEmpty
import Foreign.C (CInt)
import GHC.Exts (IsList (..))
Expand Down Expand Up @@ -74,10 +73,10 @@ testTrees =
, testCase "2fooPP" testFooPP
, testCase "2fooLet" testFooLet
, testCase "2fooLetPP" testFooLetPP
-- TODO: , testCase "2listProdPP" testListProdPP
, testCase "2listProdPP" testListProdPP
, testCase "2listProdrPP" testListProdrPP
, testCase "2listProdrLongPP" testListProdrLongPP
-- TODO , testCase "2listProd" testListProd
, testCase "2listProd" testListProd
, testCase "2listProdr" testListProdr
, testCase "2listSumrPP" testListSumrPP
, testCase "2listSum2rPP" testListSum2rPP
Expand Down Expand Up @@ -615,27 +614,25 @@ testFooLetPP = do
printArtifactPrimalPretty renames (simplifyArtifact artifactRev)
@?= "\\x1 -> rfromS (let z = sfromR (tproject1 (tproject1 x1)) * sin (sfromR (tproject2 (tproject1 x1))) in atan2F (sfromR (tproject2 x1)) z + sfromR (tproject2 x1) * z)"

_shapedListProd :: (BaseTensor target, GoodScalar r)
=> [target (TKS '[] r)] -> target (TKS '[] r)
_shapedListProd = foldl1' (*)
shapedListProd :: forall k target r. (BaseTensor target, GoodScalar r)
=> ListR k (target (TKS '[] r)) -> target (TKS '[] r)
shapedListProd = foldr1 (*)

{- TODO: this requires a better AdaptableHVector instance for [a]
testListProdPP :: Assertion
testListProdPP = do
resetVarCounter >> resetIdCounter
let renames = IM.empty
fT :: [AstTensor AstMethodLet FullSpan (TKS '[] Double)] -> AstTensor AstMethodLet FullSpan (TKS '[] Double)
fT :: ListR 4 (AstTensor AstMethodLet FullSpan (TKS '[] Double)) -> AstTensor AstMethodLet FullSpan (TKS '[] Double)
fT = shapedListProd
let (artifactRev, _deltas) = revArtifactAdapt True fT [srepl 1, srepl 2, srepl 3, srepl 4]
let (artifactRev, _deltas) = revArtifactAdapt True fT (fromList $ [srepl 1, srepl 2, srepl 3, srepl 4])
printArtifactSimple renames artifactRev
@?= "\\x17 x20 x21 x22 x23 -> tlet (x20 * x21) (\\x12 -> tlet (x12 * x22) (\\x15 -> tlet (x23 * x17) (\\x18 -> tlet (x22 * x18) (\\x19 -> dmkHVector (fromList [DynamicShaped (x21 * x19), DynamicShaped (x20 * x19), DynamicShaped (x12 * x18), DynamicShaped (x15 * x17)])))))"
@?= "\\x4 x1 -> tlet (tproject1 (tproject2 (tproject2 x1)) * tproject1 (tproject2 (tproject2 (tproject2 x1)))) (\\x2 -> tlet (tproject1 (tproject2 x1) * x2) (\\x3 -> tlet (tproject1 x1 * x4) (\\x5 -> tlet (tproject1 (tproject2 x1) * x5) (\\x6 -> tpair (x3 * x4, tpair (x2 * x5, tpair (tproject1 (tproject2 (tproject2 (tproject2 x1))) * x6, tpair (tproject1 (tproject2 (tproject2 x1)) * x6, Z0))))))))"
printArtifactPrimalSimple renames artifactRev
@?= "\\x56 x57 x58 x59 -> tlet (x56 * x57) (\\x12 -> tlet (x12 * x58) (\\x15 -> x15 * x59))"
@?= "\\x1 -> tlet (tproject1 (tproject2 (tproject2 x1)) * tproject1 (tproject2 (tproject2 (tproject2 x1)))) (\\x2 -> tlet (tproject1 (tproject2 x1) * x2) (\\x3 -> tproject1 x1 * x3))"
printArtifactPretty renames (simplifyArtifact artifactRev)
@?= "\\x17 x124 x125 x126 x127 -> let x12 = x124 * x125 ; x18 = x127 * x17 ; x19 = x126 * x18 in [x125 * x19, x124 * x19, x12 * x18, (x12 * x126) * x17]"
@?= "\\x4 x1 -> let x2 = tproject1 (tproject2 (tproject2 x1)) * tproject1 (tproject2 (tproject2 (tproject2 x1))) ; x5 = tproject1 x1 * x4 ; x6 = tproject1 (tproject2 x1) * x5 in tpair ((tproject1 (tproject2 x1) * x2) * x4, tpair (x2 * x5, tpair (tproject1 (tproject2 (tproject2 (tproject2 x1))) * x6, tpair (tproject1 (tproject2 (tproject2 x1)) * x6, Z0))))"
printArtifactPrimalPretty renames (simplifyArtifact artifactRev)
@?= "\\x160 x161 x162 x163 -> ((x160 * x161) * x162) * x163"
-}
@?= "\\x1 -> tproject1 x1 * (tproject1 (tproject2 x1) * (tproject1 (tproject2 (tproject2 x1)) * tproject1 (tproject2 (tproject2 (tproject2 x1)))))"

rankedListProdr :: forall k target r. (BaseTensor target, GoodScalar r)
=> ListR k (target (TKR 0 r)) -> target (TKR 0 r)
Expand Down Expand Up @@ -670,14 +667,12 @@ testListProdrLongPP = do
printArtifactPrimalPretty renames (simplifyArtifact artifactRev)
@?= "\\x1 -> rfromS (sfromR (tproject1 x1) * (sfromR (tproject1 (tproject2 x1)) * (sfromR (tproject1 (tproject2 (tproject2 x1))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 x1)))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 x1))))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 x1)))))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 x1))))))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 x1)))))))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 x1))))))))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 x1)))))))))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 x1))))))))))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 x1)))))))))))) * sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 x1)))))))))))))))))))))))))"

{- TODO: this requires a better AdaptableHVector instance for [a]
testListProd :: Assertion
testListProd = do
assertEqualUpToEpsilon 1e-10
[srepl 24, srepl 12, srepl 8, srepl 6]
(rev @_ @(TKS '[] Double)
shapedListProd [srepl 1, srepl 2, srepl 3, srepl 4])
-}
(shapedListProd @4) [srepl 1, srepl 2, srepl 3, srepl 4])

testListProdr :: Assertion
testListProdr = do
Expand All @@ -694,7 +689,7 @@ testListSumrPP :: Assertion
testListSumrPP = do
resetVarCounter >> resetIdCounter
let renames = IM.empty
fT :: ListR 4 (AstTensor AstMethodLet FullSpan (TKR 0 Double)) -> AstTensor AstMethodLet FullSpan (TKR 0 Double)
fT :: ListR 4 (AstTensor AstMethodLet FullSpan (TKR 0 Double)) -> AstTensor AstMethodLet FullSpan (TKR 0 Double)
fT = rankedListSumr
let (artifactRev, deltas) = revArtifactAdapt True fT [rscalar 1, rscalar 2, rscalar 3, rscalar 4]
printArtifactPretty renames (simplifyArtifact artifactRev)
Expand Down
18 changes: 0 additions & 18 deletions test/simplified/TestConvSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -551,24 +551,6 @@ test_disparitySmall = do

-- * PP Tests

{- This probably needs some exotic instance of AdaptableHVector, so should be removed:
testConv2dUnpaddedPP :: Assertion
testConv2dUnpaddedPP = do
resetVarCounter
let f :: HVector (AstGeneric AstMethodLet FullSpan) -> AAstTensor AstMethodLet FullSpan (TKR 4 Double)
f v = conv2dUnpadded (rfromD $ rankedHVector v V.! 0) (rfromD $ rankedHVector v V.! 1)
g :: Double -> RepN (TKR 4 Double)
g x = Nested.rfromOrthotope SNat $ OR.fromList [2,2,2,2] $ replicate 16 x
(artifactRev, _) =
revArtifactAdapt
True
f
(V.fromList [ DynamicRanked @Double @4 (g 1.1)
, DynamicRanked @Double @4 (g 2.3) ])
printArtifactPretty IM.empty (simplifyArtifact artifactRev)
@?= "\\u61 u175 u176 -> [rscatter [2,2,2,2] (rscatter [2,2,1,2,2,2] (rsum (rsum (rsum (rgather [2,2,2,2,1,2,2,2] (rtranspose [0,4,1,2,3] (rreplicate 2 (rreshape [2,2,2,8] (rgather [2,2,2,1,2,2,2] (rfromVector (fromList [rgather [2,2,2,1,2,2,2] u176 (\\[i37, i38, i39, i40, i41, i42, i43] -> [i37 + i40, i41, i38 + i42, i39 + i43]), rreplicate 2 (rreplicate 2 (rreplicate 2 (rreplicate 1 (rreplicate 2 (rreplicate 2 (rreplicate 2 0.0))))))])) (\\[i44, i45, i46, i47, i48, i49, i50] -> [ifF ((0 <=. i44 + i47 &&* 2 >. i44 + i47) &&* ((0 <=. i48 &&* 2 >. i48) &&* ((0 <=. i45 + i49 &&* 2 >. i45 + i49) &&* (0 <=. i46 + i50 &&* 2 >. i46 + i50)))) 0 1, i44, i45, i46, i47, i48, i49, i50])))) * rtranspose [2,0,1] (rreplicate 8 u61)) (\\[i148, i149, i150, i151, i152, i153, i154, i155] -> [remF (quotF (i155 + 2 * i154 + 4 * i153 + 8 * i151 + 8 * i152) 8) 2, remF (i155 + 2 * i154 + 4 * i153 + 8 * i151 + 8 * i152) 8, i148, i149, i150]))))) (\\[i62, i63, i64, i65, i66] -> [ifF ((0 <=. i62 + i63 &&* 2 >. i62 + i63) &&* ((0 <=. i64 &&* 2 >. i64) &&* ((0 <=. i65 &&* 2 >. i65) &&* (0 <=. i66 &&* 2 >. i66)))) 0 1, i62, i63, i64, i65, i66]) ! [0]) (\\[i68, i69] -> [i68 + i69]), rscatter [2,2,2,2] (rscatter [2,2,2,2,1,2,2,2] (rsum (rgather [2,2,2,2,1,2,2,2] (rtranspose [0,1,2,4,3] (rreplicate 2 (rreplicate 2 (rreplicate 2 (rreshape [2,8] (rgather [2,1,2,2,2] (rfromVector (fromList [rgather [2,1,2,2,2] u175 (\\[i52, i53] -> [i52 + i53]), rreplicate 2 (rreplicate 1 (rreplicate 2 (rreplicate 2 (rreplicate 2 0.0))))])) (\\[i54, i55, i56, i57, i58] -> [ifF ((0 <=. i54 + i55 &&* 2 >. i54 + i55) &&* ((0 <=. i56 &&* 2 >. i56) &&* ((0 <=. i57 &&* 2 >. i57) &&* (0 <=. i58 &&* 2 >. i58)))) 0 1, i54, i55, i56, i57, i58])))))) * rtranspose [1,3,4,0,2] (rreplicate 8 u61)) (\\[i163, i164, i165, i166, i167, i168, i169, i170] -> [remF (quotF (i170 + 2 * i169 + 4 * i168 + 8 * i167 + 8 * i166 + 32 * i164 + 16 * i165) 32) 2, remF (quotF (i170 + 2 * i169 + 4 * i168 + 8 * i167 + 8 * i166 + 32 * i164 + 16 * i165) 16) 2, remF (quotF (i170 + 2 * i169 + 4 * i168 + 8 * i167 + 8 * i166 + 32 * i164 + 16 * i165) 8) 2, remF (i170 + 2 * i169 + 4 * i168 + 8 * i167 + 8 * i166 + 32 * i164 + 16 * i165) 8, i163]))) (\\[i70, i71, i72, i73, i74, i75, i76] -> [ifF ((0 <=. i70 + i73 &&* 2 >. i70 + i73) &&* ((0 <=. i74 &&* 2 >. i74) &&* ((0 <=. i71 + i75 &&* 2 >. i71 + i75) &&* (0 <=. i72 + i76 &&* 2 >. i72 + i76)))) 0 1, i70, i71, i72, i73, i74, i75, i76]) ! [0]) (\\[i78, i79, i80, i81, i82, i83, i84] -> [i78 + i81, i82, i79 + i83, i80 + i84])]"
-}

testConv2dUnpadded2PP :: Assertion
testConv2dUnpadded2PP = do
resetVarCounter
Expand Down
28 changes: 17 additions & 11 deletions test/simplified/TestHighRankSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@ import Test.Tasty
import Test.Tasty.HUnit hiding (assert)

import Data.Array.Nested
(IShR, IxR (..), IxS (..), KnownShS (..), ListR (..), Rank, ShR (..))
( IShR
, IxR (..)
, IxS (..)
, KnownShS (..)
, ListR (..)
, Rank
, ShR (..)
, rfromListLinear
)
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Internal.Shape (shrTail)

Expand All @@ -25,8 +33,8 @@ import EqEpsilon
testTrees :: [TestTree]
testTrees =
[ testCase "3foo" testFoo
-- , testCase "3bar" testBar
-- , testCase "3barS" testBarS
, testCase "3bar" testBar
, testCase "3barS" testBarS
, testCase "3fooD T Double [1.1, 2.2, 3.3]" testFooD
, testCase "3fooBuild0" testFooBuild0
, testCase "3fooBuildOut" testFooBuildOut
Expand Down Expand Up @@ -98,18 +106,16 @@ barF (x, y) =
let w = fooF (x, y, x) * sin y
in atan2F x w + y * w

-- Numerically unstable ATM.
_testBar :: Assertion
_testBar =
testBar :: Assertion
testBar =
assertEqualUpToEpsilon 1e-5
(ringestData [3, 1, 2, 2, 1, 2, 2] [304.13867,914.9335,823.0187,1464.4688,5264.3306,1790.0055,1535.4309,3541.6572,304.13867,914.9335,823.0187,1464.4688,6632.4355,6047.113,1535.4309,1346.6815,45.92141,6.4903135,5.5406737,1.4242969,6.4903135,1.1458766,4.6446533,2.3550234,88.783676,27.467598,125.27507,18.177452,647.1917,0.3878851,2177.6152,786.1792,6.4903135,6.4903135,6.4903135,6.4903135,2.3550234,2.3550234,2.3550234,2.3550234,21.783596,2.3550234,2.3550234,2.3550234,21.783596,21.783596,21.783596,21.783596], ringestData [3, 1, 2, 2, 1, 2, 2] [-5728.761,24965.113,32825.074,-63505.957,-42592.203,145994.89,-500082.5,-202480.05,-5728.761,24965.113,32825.074,-63505.957,49494.473,-2446.7632,-500082.5,-125885.58,-43.092484,-1.9601007,-98.97708,2.1931143,-1.9601007,1.8243167,-4.0434446,-1.5266151,2020.9731,-538.06036,-84.28139,62.963818,-34987.0,-9.917454,135.3003,17741.996,-1.9601007,-1.9601007,-1.9601007,-1.9601007,-1.5266151,-1.5266151,-1.5266151,-1.5266151,-4029.1775,-1.5266151,-1.5266151,-1.5266151,-4029.1775,-4029.1775,-4029.1775,-4029.1775])
(RepN $ rfromListLinear [3,1,2,2,1,2,2] [304.13867,914.9335,823.0187,1464.4688,5264.3306,1790.0055,1535.4309,3541.6572,304.13867,914.9335,823.0187,1464.4688,6632.4355,6047.113,1535.4309,1346.6815,45.92141,6.4903135,5.5406737,1.4242969,6.4903135,1.1458766,4.6446533,2.3550234,88.783676,27.467598,125.27507,18.177452,647.1915,0.3878851,2177.6152,786.1792,6.4903135,6.4903135,6.4903135,6.4903135,2.3550234,2.3550234,2.3550234,2.3550234,21.783596,2.3550234,2.3550234,2.3550234,21.783596,21.783596,21.783596,21.783596],RepN $ rfromListLinear [3,1,2,2,1,2,2] [-5728.7617,24965.113,32825.07,-63505.953,-42592.203,145994.88,-500082.5,-202480.06,-5728.7617,24965.113,32825.07,-63505.953,49494.473,-2446.7632,-500082.5,-125885.58,-43.092484,-1.9601002,-98.97709,2.1931143,-1.9601002,1.8243169,-4.0434446,-1.5266153,2020.9731,-538.0603,-84.28137,62.963814,-34986.996,-9.917454,135.30023,17741.998,-1.9601002,-1.9601002,-1.9601002,-1.9601002,-1.5266153,-1.5266153,-1.5266153,-1.5266153,-4029.1775,-1.5266153,-1.5266153,-1.5266153,-4029.1775,-4029.1775,-4029.1775,-4029.1775])
(crev (bar @(ADVal RepN (TKR 7 Float))) (t48, t48))

-- Numerically unstable ATM.
_testBarS :: Assertion
_testBarS =
testBarS :: Assertion
testBarS =
assertEqualUpToEpsilon 1e-5
(sconcrete $ Nested.sfromListPrimLinear @_ @'[3, 1, 2, 2, 1, 2, 2] knownShS [304.13867,914.9335,823.0187,1464.4688,5264.3306,1790.0055,1535.4309,3541.6572,304.13867,914.9335,823.0187,1464.4688,6632.4355,6047.113,1535.4309,1346.6815,45.92141,6.4903135,5.5406737,1.4242969,6.4903135,1.1458766,4.6446533,2.3550234,88.783676,27.467598,125.27507,18.177452,647.1915,0.3878851,2177.6152,786.1792,6.4903135,6.4903135,6.4903135,6.4903135,2.3550234,2.3550234,2.3550234,2.3550234,21.783596,2.3550234,2.3550234,2.3550234,21.783596,21.783596,21.783596,21.783596], sconcrete $ Nested.sfromListPrimLinear @_ @'[3, 1, 2, 2, 1, 2, 2] knownShS [-5728.7617,24965.113,32825.07,-63505.953,-42592.203,145994.88,-500082.5,-202480.06,-5728.7617,24965.113,32825.07,-63505.953,49494.473,-2446.7632,-500082.5,-125885.58,-43.092484,-1.9601002,-98.97709,2.1931143,-1.9601002,1.8243169,-4.0434446,-1.5266153,2020.9731,-538.0603,-84.28137,62.963814,-34986.996,-9.917454,135.30023,17741.998,-1.9601002,-1.9601002,-1.9601002,-1.9601002,-1.5266153,-1.5266153,-1.5266153,-1.5266153,-4029.1775,-1.5266153,-1.5266153,-1.5266153,-4029.1775,-4029.1775,-4029.1775,-4029.1775])
(sconcrete $ Nested.sfromListPrimLinear @_ @'[3, 1, 2, 2, 1, 2, 2] knownShS [304.13867,914.9335,823.0187,1464.4688,5264.3306,1790.0055,1535.4309,3541.6572,304.13867,914.9335,823.0187,1464.4688,6632.4355,6047.113,1535.4309,1346.6815,45.92141,6.4903135,5.5406737,1.4242969,6.4903135,1.1458766,4.6446533,2.3550234,88.783676,27.467598,125.27507,18.177452,647.1915,0.3878851,2177.6152,786.1792,6.4903135,6.4903135,6.4903135,6.4903135,2.3550234,2.3550234,2.3550234,2.3550234,21.783596,2.3550234,2.3550234,2.3550234,21.783596,21.783596,21.783596,21.783596], sconcrete $ Nested.sfromListPrimLinear @_ @'[3, 1, 2, 2, 1, 2, 2] knownShS [-5728.761,24965.113,32825.074,-63505.957,-42592.203,145994.89,-500082.5,-202480.05,-5728.761,24965.113,32825.074,-63505.957,49494.473,-2446.7632,-500082.5,-125885.58,-43.092484,-1.9601007,-98.97708,2.1931143,-1.9601007,1.8243167,-4.0434446,-1.5266151,2020.9731,-538.06036,-84.28139,62.963818,-34986.992,-9.917454,135.3003,17741.996,-1.9601007,-1.9601007,-1.9601007,-1.9601007,-1.5266151,-1.5266151,-1.5266151,-1.5266151,-4029.1775,-1.5266151,-1.5266151,-1.5266151,-4029.1775,-4029.1775,-4029.1775,-4029.1775])
(crev (barF @(ADVal RepN (TKS '[3, 1, 2, 2, 1, 2, 2] Float))) (sfromR t48, sfromR t48))

-- A dual-number and list-based version of a function that goes
Expand Down

0 comments on commit 1799d18

Please sign in to comment.