diff --git a/test/simplified/TestAdaptorSimplified.hs b/test/simplified/TestAdaptorSimplified.hs index 963091ddd..2d509ef2f 100644 --- a/test/simplified/TestAdaptorSimplified.hs +++ b/test/simplified/TestAdaptorSimplified.hs @@ -9,7 +9,6 @@ import Prelude import Data.Int (Int64) import Data.IntMap.Strict qualified as IM -import Data.List (foldl1') import Data.List.NonEmpty qualified as NonEmpty import Foreign.C (CInt) import GHC.Exts (IsList (..)) @@ -74,10 +73,10 @@ testTrees = , testCase "2fooPP" testFooPP , testCase "2fooLet" testFooLet , testCase "2fooLetPP" testFooLetPP --- TODO: , testCase "2listProdPP" testListProdPP + , testCase "2listProdPP" testListProdPP , testCase "2listProdrPP" testListProdrPP , testCase "2listProdrLongPP" testListProdrLongPP --- TODO , testCase "2listProd" testListProd + , testCase "2listProd" testListProd , testCase "2listProdr" testListProdr , testCase "2listSumrPP" testListSumrPP , testCase "2listSum2rPP" testListSum2rPP @@ -615,27 +614,25 @@ testFooLetPP = do printArtifactPrimalPretty renames (simplifyArtifact artifactRev) @?= "\\x1 -> rfromS (let z = sfromR (tproject1 (tproject1 x1)) * sin (sfromR (tproject2 (tproject1 x1))) in atan2F (sfromR (tproject2 x1)) z + sfromR (tproject2 x1) * z)" -_shapedListProd :: (BaseTensor target, GoodScalar r) - => [target (TKS '[] r)] -> target (TKS '[] r) -_shapedListProd = foldl1' (*) +shapedListProd :: forall k target r. (BaseTensor target, GoodScalar r) + => ListR k (target (TKS '[] r)) -> target (TKS '[] r) +shapedListProd = foldr1 (*) -{- TODO: this requires a better AdaptableHVector instance for [a] testListProdPP :: Assertion testListProdPP = do resetVarCounter >> resetIdCounter let renames = IM.empty - fT :: [AstTensor AstMethodLet FullSpan (TKS '[] Double)] -> AstTensor AstMethodLet FullSpan (TKS '[] Double) + fT :: ListR 4 (AstTensor AstMethodLet FullSpan (TKS '[] Double)) -> AstTensor AstMethodLet FullSpan (TKS '[] Double) fT = shapedListProd - let (artifactRev, _deltas) = revArtifactAdapt True fT [srepl 1, srepl 2, srepl 3, srepl 4] + let (artifactRev, _deltas) = revArtifactAdapt True fT (fromList $ [srepl 1, srepl 2, srepl 3, srepl 4]) printArtifactSimple renames artifactRev - @?= "\\x17 x20 x21 x22 x23 -> tlet (x20 * x21) (\\x12 -> tlet (x12 * x22) (\\x15 -> tlet (x23 * x17) (\\x18 -> tlet (x22 * x18) (\\x19 -> dmkHVector (fromList [DynamicShaped (x21 * x19), DynamicShaped (x20 * x19), DynamicShaped (x12 * x18), DynamicShaped (x15 * x17)])))))" + @?= "\\x4 x1 -> tlet (tproject1 (tproject2 (tproject2 x1)) * tproject1 (tproject2 (tproject2 (tproject2 x1)))) (\\x2 -> tlet (tproject1 (tproject2 x1) * x2) (\\x3 -> tlet (tproject1 x1 * x4) (\\x5 -> tlet (tproject1 (tproject2 x1) * x5) (\\x6 -> tpair (x3 * x4, tpair (x2 * x5, tpair (tproject1 (tproject2 (tproject2 (tproject2 x1))) * x6, tpair (tproject1 (tproject2 (tproject2 x1)) * x6, Z0))))))))" printArtifactPrimalSimple renames artifactRev - @?= "\\x56 x57 x58 x59 -> tlet (x56 * x57) (\\x12 -> tlet (x12 * x58) (\\x15 -> x15 * x59))" + @?= "\\x1 -> tlet (tproject1 (tproject2 (tproject2 x1)) * tproject1 (tproject2 (tproject2 (tproject2 x1)))) (\\x2 -> tlet (tproject1 (tproject2 x1) * x2) (\\x3 -> tproject1 x1 * x3))" printArtifactPretty renames (simplifyArtifact artifactRev) - @?= "\\x17 x124 x125 x126 x127 -> let x12 = x124 * x125 ; x18 = x127 * x17 ; x19 = x126 * x18 in [x125 * x19, x124 * x19, x12 * x18, (x12 * x126) * x17]" + @?= "\\x4 x1 -> let x2 = tproject1 (tproject2 (tproject2 x1)) * tproject1 (tproject2 (tproject2 (tproject2 x1))) ; x5 = tproject1 x1 * x4 ; x6 = tproject1 (tproject2 x1) * x5 in tpair ((tproject1 (tproject2 x1) * x2) * x4, tpair (x2 * x5, tpair (tproject1 (tproject2 (tproject2 (tproject2 x1))) * x6, tpair (tproject1 (tproject2 (tproject2 x1)) * x6, Z0))))" printArtifactPrimalPretty renames (simplifyArtifact artifactRev) - @?= "\\x160 x161 x162 x163 -> ((x160 * x161) * x162) * x163" --} + @?= "\\x1 -> tproject1 x1 * (tproject1 (tproject2 x1) * (tproject1 (tproject2 (tproject2 x1)) * tproject1 (tproject2 (tproject2 (tproject2 x1)))))" rankedListProdr :: forall k target r. (BaseTensor target, GoodScalar r) => ListR k (target (TKR 0 r)) -> target (TKR 0 r) @@ -670,14 +667,12 @@ testListProdrLongPP = do printArtifactPrimalPretty renames (simplifyArtifact artifactRev) @?= "\\x1 -> rfromS (sfromR (tproject1 x1) * (sfromR (tproject1 (tproject2 x1)) * (sfromR (tproject1 (tproject2 (tproject2 x1))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 x1)))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 x1))))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 x1)))))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 x1))))))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 x1)))))))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 x1))))))))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 x1)))))))))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 x1))))))))))) * (sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 x1)))))))))))) * sfromR (tproject1 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 (tproject2 x1)))))))))))))))))))))))))" -{- TODO: this requires a better AdaptableHVector instance for [a] testListProd :: Assertion testListProd = do assertEqualUpToEpsilon 1e-10 [srepl 24, srepl 12, srepl 8, srepl 6] (rev @_ @(TKS '[] Double) - shapedListProd [srepl 1, srepl 2, srepl 3, srepl 4]) --} + (shapedListProd @4) [srepl 1, srepl 2, srepl 3, srepl 4]) testListProdr :: Assertion testListProdr = do @@ -694,7 +689,7 @@ testListSumrPP :: Assertion testListSumrPP = do resetVarCounter >> resetIdCounter let renames = IM.empty - fT :: ListR 4 (AstTensor AstMethodLet FullSpan (TKR 0 Double)) -> AstTensor AstMethodLet FullSpan (TKR 0 Double) + fT :: ListR 4 (AstTensor AstMethodLet FullSpan (TKR 0 Double)) -> AstTensor AstMethodLet FullSpan (TKR 0 Double) fT = rankedListSumr let (artifactRev, deltas) = revArtifactAdapt True fT [rscalar 1, rscalar 2, rscalar 3, rscalar 4] printArtifactPretty renames (simplifyArtifact artifactRev) diff --git a/test/simplified/TestConvSimplified.hs b/test/simplified/TestConvSimplified.hs index 5edf47b8a..9e1ee1500 100644 --- a/test/simplified/TestConvSimplified.hs +++ b/test/simplified/TestConvSimplified.hs @@ -551,24 +551,6 @@ test_disparitySmall = do -- * PP Tests -{- This probably needs some exotic instance of AdaptableHVector, so should be removed: -testConv2dUnpaddedPP :: Assertion -testConv2dUnpaddedPP = do - resetVarCounter - let f :: HVector (AstGeneric AstMethodLet FullSpan) -> AAstTensor AstMethodLet FullSpan (TKR 4 Double) - f v = conv2dUnpadded (rfromD $ rankedHVector v V.! 0) (rfromD $ rankedHVector v V.! 1) - g :: Double -> RepN (TKR 4 Double) - g x = Nested.rfromOrthotope SNat $ OR.fromList [2,2,2,2] $ replicate 16 x - (artifactRev, _) = - revArtifactAdapt - True - f - (V.fromList [ DynamicRanked @Double @4 (g 1.1) - , DynamicRanked @Double @4 (g 2.3) ]) - printArtifactPretty IM.empty (simplifyArtifact artifactRev) - @?= "\\u61 u175 u176 -> [rscatter [2,2,2,2] (rscatter [2,2,1,2,2,2] (rsum (rsum (rsum (rgather [2,2,2,2,1,2,2,2] (rtranspose [0,4,1,2,3] (rreplicate 2 (rreshape [2,2,2,8] (rgather [2,2,2,1,2,2,2] (rfromVector (fromList [rgather [2,2,2,1,2,2,2] u176 (\\[i37, i38, i39, i40, i41, i42, i43] -> [i37 + i40, i41, i38 + i42, i39 + i43]), rreplicate 2 (rreplicate 2 (rreplicate 2 (rreplicate 1 (rreplicate 2 (rreplicate 2 (rreplicate 2 0.0))))))])) (\\[i44, i45, i46, i47, i48, i49, i50] -> [ifF ((0 <=. i44 + i47 &&* 2 >. i44 + i47) &&* ((0 <=. i48 &&* 2 >. i48) &&* ((0 <=. i45 + i49 &&* 2 >. i45 + i49) &&* (0 <=. i46 + i50 &&* 2 >. i46 + i50)))) 0 1, i44, i45, i46, i47, i48, i49, i50])))) * rtranspose [2,0,1] (rreplicate 8 u61)) (\\[i148, i149, i150, i151, i152, i153, i154, i155] -> [remF (quotF (i155 + 2 * i154 + 4 * i153 + 8 * i151 + 8 * i152) 8) 2, remF (i155 + 2 * i154 + 4 * i153 + 8 * i151 + 8 * i152) 8, i148, i149, i150]))))) (\\[i62, i63, i64, i65, i66] -> [ifF ((0 <=. i62 + i63 &&* 2 >. i62 + i63) &&* ((0 <=. i64 &&* 2 >. i64) &&* ((0 <=. i65 &&* 2 >. i65) &&* (0 <=. i66 &&* 2 >. i66)))) 0 1, i62, i63, i64, i65, i66]) ! [0]) (\\[i68, i69] -> [i68 + i69]), rscatter [2,2,2,2] (rscatter [2,2,2,2,1,2,2,2] (rsum (rgather [2,2,2,2,1,2,2,2] (rtranspose [0,1,2,4,3] (rreplicate 2 (rreplicate 2 (rreplicate 2 (rreshape [2,8] (rgather [2,1,2,2,2] (rfromVector (fromList [rgather [2,1,2,2,2] u175 (\\[i52, i53] -> [i52 + i53]), rreplicate 2 (rreplicate 1 (rreplicate 2 (rreplicate 2 (rreplicate 2 0.0))))])) (\\[i54, i55, i56, i57, i58] -> [ifF ((0 <=. i54 + i55 &&* 2 >. i54 + i55) &&* ((0 <=. i56 &&* 2 >. i56) &&* ((0 <=. i57 &&* 2 >. i57) &&* (0 <=. i58 &&* 2 >. i58)))) 0 1, i54, i55, i56, i57, i58])))))) * rtranspose [1,3,4,0,2] (rreplicate 8 u61)) (\\[i163, i164, i165, i166, i167, i168, i169, i170] -> [remF (quotF (i170 + 2 * i169 + 4 * i168 + 8 * i167 + 8 * i166 + 32 * i164 + 16 * i165) 32) 2, remF (quotF (i170 + 2 * i169 + 4 * i168 + 8 * i167 + 8 * i166 + 32 * i164 + 16 * i165) 16) 2, remF (quotF (i170 + 2 * i169 + 4 * i168 + 8 * i167 + 8 * i166 + 32 * i164 + 16 * i165) 8) 2, remF (i170 + 2 * i169 + 4 * i168 + 8 * i167 + 8 * i166 + 32 * i164 + 16 * i165) 8, i163]))) (\\[i70, i71, i72, i73, i74, i75, i76] -> [ifF ((0 <=. i70 + i73 &&* 2 >. i70 + i73) &&* ((0 <=. i74 &&* 2 >. i74) &&* ((0 <=. i71 + i75 &&* 2 >. i71 + i75) &&* (0 <=. i72 + i76 &&* 2 >. i72 + i76)))) 0 1, i70, i71, i72, i73, i74, i75, i76]) ! [0]) (\\[i78, i79, i80, i81, i82, i83, i84] -> [i78 + i81, i82, i79 + i83, i80 + i84])]" --} - testConv2dUnpadded2PP :: Assertion testConv2dUnpadded2PP = do resetVarCounter diff --git a/test/simplified/TestHighRankSimplified.hs b/test/simplified/TestHighRankSimplified.hs index 23b5e4c27..8e75be2ad 100644 --- a/test/simplified/TestHighRankSimplified.hs +++ b/test/simplified/TestHighRankSimplified.hs @@ -13,7 +13,15 @@ import Test.Tasty import Test.Tasty.HUnit hiding (assert) import Data.Array.Nested - (IShR, IxR (..), IxS (..), KnownShS (..), ListR (..), Rank, ShR (..)) + ( IShR + , IxR (..) + , IxS (..) + , KnownShS (..) + , ListR (..) + , Rank + , ShR (..) + , rfromListLinear + ) import Data.Array.Nested qualified as Nested import Data.Array.Nested.Internal.Shape (shrTail) @@ -25,8 +33,8 @@ import EqEpsilon testTrees :: [TestTree] testTrees = [ testCase "3foo" testFoo --- , testCase "3bar" testBar --- , testCase "3barS" testBarS + , testCase "3bar" testBar + , testCase "3barS" testBarS , testCase "3fooD T Double [1.1, 2.2, 3.3]" testFooD , testCase "3fooBuild0" testFooBuild0 , testCase "3fooBuildOut" testFooBuildOut @@ -98,18 +106,16 @@ barF (x, y) = let w = fooF (x, y, x) * sin y in atan2F x w + y * w --- Numerically unstable ATM. -_testBar :: Assertion -_testBar = +testBar :: Assertion +testBar = assertEqualUpToEpsilon 1e-5 - (ringestData [3, 1, 2, 2, 1, 2, 2] [304.13867,914.9335,823.0187,1464.4688,5264.3306,1790.0055,1535.4309,3541.6572,304.13867,914.9335,823.0187,1464.4688,6632.4355,6047.113,1535.4309,1346.6815,45.92141,6.4903135,5.5406737,1.4242969,6.4903135,1.1458766,4.6446533,2.3550234,88.783676,27.467598,125.27507,18.177452,647.1917,0.3878851,2177.6152,786.1792,6.4903135,6.4903135,6.4903135,6.4903135,2.3550234,2.3550234,2.3550234,2.3550234,21.783596,2.3550234,2.3550234,2.3550234,21.783596,21.783596,21.783596,21.783596], ringestData [3, 1, 2, 2, 1, 2, 2] [-5728.761,24965.113,32825.074,-63505.957,-42592.203,145994.89,-500082.5,-202480.05,-5728.761,24965.113,32825.074,-63505.957,49494.473,-2446.7632,-500082.5,-125885.58,-43.092484,-1.9601007,-98.97708,2.1931143,-1.9601007,1.8243167,-4.0434446,-1.5266151,2020.9731,-538.06036,-84.28139,62.963818,-34987.0,-9.917454,135.3003,17741.996,-1.9601007,-1.9601007,-1.9601007,-1.9601007,-1.5266151,-1.5266151,-1.5266151,-1.5266151,-4029.1775,-1.5266151,-1.5266151,-1.5266151,-4029.1775,-4029.1775,-4029.1775,-4029.1775]) + (RepN $ rfromListLinear [3,1,2,2,1,2,2] [304.13867,914.9335,823.0187,1464.4688,5264.3306,1790.0055,1535.4309,3541.6572,304.13867,914.9335,823.0187,1464.4688,6632.4355,6047.113,1535.4309,1346.6815,45.92141,6.4903135,5.5406737,1.4242969,6.4903135,1.1458766,4.6446533,2.3550234,88.783676,27.467598,125.27507,18.177452,647.1915,0.3878851,2177.6152,786.1792,6.4903135,6.4903135,6.4903135,6.4903135,2.3550234,2.3550234,2.3550234,2.3550234,21.783596,2.3550234,2.3550234,2.3550234,21.783596,21.783596,21.783596,21.783596],RepN $ rfromListLinear [3,1,2,2,1,2,2] [-5728.7617,24965.113,32825.07,-63505.953,-42592.203,145994.88,-500082.5,-202480.06,-5728.7617,24965.113,32825.07,-63505.953,49494.473,-2446.7632,-500082.5,-125885.58,-43.092484,-1.9601002,-98.97709,2.1931143,-1.9601002,1.8243169,-4.0434446,-1.5266153,2020.9731,-538.0603,-84.28137,62.963814,-34986.996,-9.917454,135.30023,17741.998,-1.9601002,-1.9601002,-1.9601002,-1.9601002,-1.5266153,-1.5266153,-1.5266153,-1.5266153,-4029.1775,-1.5266153,-1.5266153,-1.5266153,-4029.1775,-4029.1775,-4029.1775,-4029.1775]) (crev (bar @(ADVal RepN (TKR 7 Float))) (t48, t48)) --- Numerically unstable ATM. -_testBarS :: Assertion -_testBarS = +testBarS :: Assertion +testBarS = assertEqualUpToEpsilon 1e-5 - (sconcrete $ Nested.sfromListPrimLinear @_ @'[3, 1, 2, 2, 1, 2, 2] knownShS [304.13867,914.9335,823.0187,1464.4688,5264.3306,1790.0055,1535.4309,3541.6572,304.13867,914.9335,823.0187,1464.4688,6632.4355,6047.113,1535.4309,1346.6815,45.92141,6.4903135,5.5406737,1.4242969,6.4903135,1.1458766,4.6446533,2.3550234,88.783676,27.467598,125.27507,18.177452,647.1915,0.3878851,2177.6152,786.1792,6.4903135,6.4903135,6.4903135,6.4903135,2.3550234,2.3550234,2.3550234,2.3550234,21.783596,2.3550234,2.3550234,2.3550234,21.783596,21.783596,21.783596,21.783596], sconcrete $ Nested.sfromListPrimLinear @_ @'[3, 1, 2, 2, 1, 2, 2] knownShS [-5728.7617,24965.113,32825.07,-63505.953,-42592.203,145994.88,-500082.5,-202480.06,-5728.7617,24965.113,32825.07,-63505.953,49494.473,-2446.7632,-500082.5,-125885.58,-43.092484,-1.9601002,-98.97709,2.1931143,-1.9601002,1.8243169,-4.0434446,-1.5266153,2020.9731,-538.0603,-84.28137,62.963814,-34986.996,-9.917454,135.30023,17741.998,-1.9601002,-1.9601002,-1.9601002,-1.9601002,-1.5266153,-1.5266153,-1.5266153,-1.5266153,-4029.1775,-1.5266153,-1.5266153,-1.5266153,-4029.1775,-4029.1775,-4029.1775,-4029.1775]) + (sconcrete $ Nested.sfromListPrimLinear @_ @'[3, 1, 2, 2, 1, 2, 2] knownShS [304.13867,914.9335,823.0187,1464.4688,5264.3306,1790.0055,1535.4309,3541.6572,304.13867,914.9335,823.0187,1464.4688,6632.4355,6047.113,1535.4309,1346.6815,45.92141,6.4903135,5.5406737,1.4242969,6.4903135,1.1458766,4.6446533,2.3550234,88.783676,27.467598,125.27507,18.177452,647.1915,0.3878851,2177.6152,786.1792,6.4903135,6.4903135,6.4903135,6.4903135,2.3550234,2.3550234,2.3550234,2.3550234,21.783596,2.3550234,2.3550234,2.3550234,21.783596,21.783596,21.783596,21.783596], sconcrete $ Nested.sfromListPrimLinear @_ @'[3, 1, 2, 2, 1, 2, 2] knownShS [-5728.761,24965.113,32825.074,-63505.957,-42592.203,145994.89,-500082.5,-202480.05,-5728.761,24965.113,32825.074,-63505.957,49494.473,-2446.7632,-500082.5,-125885.58,-43.092484,-1.9601007,-98.97708,2.1931143,-1.9601007,1.8243167,-4.0434446,-1.5266151,2020.9731,-538.06036,-84.28139,62.963818,-34986.992,-9.917454,135.3003,17741.996,-1.9601007,-1.9601007,-1.9601007,-1.9601007,-1.5266151,-1.5266151,-1.5266151,-1.5266151,-4029.1775,-1.5266151,-1.5266151,-1.5266151,-4029.1775,-4029.1775,-4029.1775,-4029.1775]) (crev (barF @(ADVal RepN (TKS '[3, 1, 2, 2, 1, 2, 2] Float))) (sfromR t48, sfromR t48)) -- A dual-number and list-based version of a function that goes