diff --git a/src/HordeAd/Core/Ast.hs b/src/HordeAd/Core/Ast.hs index 987c9f106..883e353e8 100644 --- a/src/HordeAd/Core/Ast.hs +++ b/src/HordeAd/Core/Ast.hs @@ -42,7 +42,7 @@ import HordeAd.Core.HVector import HordeAd.Core.Types import HordeAd.Internal.OrthotopeOrphanInstances (IntegralF (..), RealFloatF (..), FlipS(..)) -import HordeAd.Util.ShapedList (IndexS, SizedListS (..)) +import HordeAd.Util.ShapedList (IndexS, SizedListS) import HordeAd.Util.SizedList -- * Basic type family instances diff --git a/src/HordeAd/Core/AstSimplify.hs b/src/HordeAd/Core/AstSimplify.hs index fe065f6d8..33f188973 100644 --- a/src/HordeAd/Core/AstSimplify.hs +++ b/src/HordeAd/Core/AstSimplify.hs @@ -96,7 +96,7 @@ import HordeAd.Internal.BackendConcrete import HordeAd.Internal.OrthotopeOrphanInstances (FlipS (..), MapSucc, trustMeThisIsAPermutation) import HordeAd.Util.ShapedList - (SizedListS (..), pattern (:.$), pattern ZIS) + (pattern (:.$), pattern (::$), pattern ZIS, pattern ZS) import qualified HordeAd.Util.ShapedList as ShapedList import HordeAd.Util.SizedList diff --git a/src/HordeAd/Core/AstVectorize.hs b/src/HordeAd/Core/AstVectorize.hs index bddac9b4d..6bfe2f489 100644 --- a/src/HordeAd/Core/AstVectorize.hs +++ b/src/HordeAd/Core/AstVectorize.hs @@ -50,7 +50,7 @@ import HordeAd.Core.Types import HordeAd.Internal.OrthotopeOrphanInstances (MapSucc, trustMeThisIsAPermutation) import HordeAd.Util.ShapedList - (SizedListS (..), pattern (:.$), pattern ZIS) + (pattern (:.$), pattern (::$), pattern ZIS, pattern ZS) import HordeAd.Util.SizedList -- * Vectorization of AstRanked diff --git a/src/HordeAd/Util/ShapedList.hs b/src/HordeAd/Util/ShapedList.hs index f43147101..5943f2483 100644 --- a/src/HordeAd/Util/ShapedList.hs +++ b/src/HordeAd/Util/ShapedList.hs @@ -1,13 +1,14 @@ -{-# LANGUAGE AllowAmbiguousTypes, DerivingStrategies, ViewPatterns #-} +{-# LANGUAGE AllowAmbiguousTypes, DerivingStrategies #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} {-# OPTIONS_GHC -fconstraint-solver-iterations=10000 #-} +{-# OPTIONS_GHC -Wno-orphans #-} -- | @[Nat]@-indexed lists to be used as is for lists of tensor variables, -- tensor shapes and tensor indexes. module HordeAd.Util.ShapedList ( -- * Shaped lists (sized, where size is shape) and their permutations IntSh, IndexSh - , SizedListS(..) + , SizedListS, pattern (::$), pattern ZS , consShaped, unconsContShaped , singletonSized, snocSized, appendSized , headSized, tailSized, takeSized, dropSized, splitAt_Sized @@ -42,6 +43,17 @@ import GHC.Exts (IsList (..)) import GHC.TypeLits (KnownNat, Nat, type (*)) import Unsafe.Coerce (unsafeCoerce) +import Data.Array.Nested + ( IxS (..) + , ListS + , StaticShapeS (..) + , pattern (:$$) + , pattern (:.$) + , pattern (::$) + , pattern ZIS + , pattern ZS + , pattern ZSS + ) import HordeAd.Core.Types import HordeAd.Util.SizedList (Permutation) import qualified HordeAd.Util.SizedList as SizedList @@ -60,28 +72,15 @@ type IntSh (f :: TensorType ty) (n :: Nat) = ShapedNat n (IntOf f) -- and up to evaluation. type IndexSh (f :: TensorType ty) (sh :: [Nat]) = IndexS sh (IntOf f) +-- | Lists indexed by shapes, that is, lists of the GHC @Nat@. +type SizedListS n i = ListS n i --- | Strict lists indexed by shapes, that is, lists of the GHC @Nat@. -infixr 3 ::$ -type role SizedListS nominal representational -data SizedListS (sh :: [Nat]) i where - ZS :: SizedListS '[] i - (::$) :: forall k sh {i}. (KnownNat k, KnownShape sh) - => i -> SizedListS sh i -> SizedListS (k : sh) i - -deriving instance Eq i => Eq (SizedListS sh i) - -deriving instance Ord i => Ord (SizedListS sh i) - +{- -- This is only lawful when OverloadedLists is enabled. -- However, it's much more readable when tracing and debugging. instance Show i => Show (SizedListS sh i) where showsPrec d l = showsPrec d (sizedToList l) - -deriving stock instance Functor (SizedListS sh) - -instance Foldable (SizedListS sh) where - foldr f z l = foldr f z (sizedToList l) +-} instance KnownShape sh => IsList (SizedListS sh i) where type Item (SizedListS sh i) = i @@ -222,40 +221,18 @@ shapedToSized = SizedList.listToSized . sizedToList -- * Tensor indexes as fully encapsulated shaped lists, with operations -type role IndexS nominal representational -newtype IndexS sh i = IndexS (SizedListS sh i) - deriving (Eq, Ord) +type IndexS sh i = IxS sh i +pattern IndexS :: forall {sh :: [Nat]} {i}. ListS sh i -> IxS sh i +pattern IndexS l = IxS l +{-# COMPLETE IndexS #-} + +{- -- This is only lawful when OverloadedLists is enabled. -- However, it's much more readable when tracing and debugging. instance Show i => Show (IndexS sh i) where showsPrec d (IndexS l) = showsPrec d l - -pattern ZIS :: forall sh i. () => sh ~ '[] => IndexS sh i -pattern ZIS = IndexS ZS - -infixr 3 :.$ -pattern (:.$) - :: forall {sh1} {i}. () - => forall k sh. (KnownNat k, KnownShape sh, (k : sh) ~ sh1) - => i -> IndexS sh i -> IndexS sh1 i -pattern i :.$ shl <- (unconsIndex -> Just (UnconsIndexRes shl i)) - where i :.$ (IndexS shl) = IndexS (i ::$ shl) -{-# COMPLETE ZIS, (:.$) #-} - -type role UnconsIndexRes representational nominal -data UnconsIndexRes i sh1 = - forall k sh. (KnownNat k, KnownShape sh, (k : sh) ~ sh1) - => UnconsIndexRes (IndexS sh i) i -unconsIndex :: IndexS sh1 i -> Maybe (UnconsIndexRes i sh1) -unconsIndex (IndexS shl) = case shl of - i ::$ shl' -> Just (UnconsIndexRes (IndexS shl') i) - ZS -> Nothing - -deriving newtype instance Functor (IndexS n) - -instance Foldable (IndexS n) where - foldr f z l = foldr f z (indexToList l) +-} instance KnownShape sh => IsList (IndexS sh i) where type Item (IndexS sh i) = i @@ -320,40 +297,18 @@ shapedToIndex = SizedList.listToIndex . indexToList -- * Tensor shapes as fully encapsulated shaped lists, with operations -type role ShapeS nominal representational -newtype ShapeS sh i = ShapeS (SizedListS sh i) - deriving (Eq, Ord) +type ShapeS sh i = StaticShapeS sh i + +pattern ShapeS :: forall {sh :: [Nat]} {i}. ListS sh i -> StaticShapeS sh i +pattern ShapeS l = StaticShapeS l +{-# COMPLETE ShapeS #-} +{- -- This is only lawful when OverloadedLists is enabled. -- However, it's much more readable when tracing and debugging. instance Show i => Show (ShapeS sh i) where showsPrec d (ShapeS l) = showsPrec d l - -pattern ZSS :: forall sh i. () => sh ~ '[] => ShapeS sh i -pattern ZSS = ShapeS ZS - -infixr 3 :$$ -pattern (:$$) - :: forall {sh1} {i}. () - => forall k sh. (KnownNat k, KnownShape sh, (k : sh) ~ sh1) - => i -> ShapeS sh i -> ShapeS sh1 i -pattern i :$$ shl <- (unconsShape -> Just (UnconsShapeRes shl i)) - where i :$$ (ShapeS shl) = ShapeS (i ::$ shl) -{-# COMPLETE ZSS, (:$$) #-} - -type role UnconsShapeRes representational nominal -data UnconsShapeRes i sh1 = - forall k sh. (KnownNat k, KnownShape sh, (k : sh) ~ sh1) - => UnconsShapeRes (ShapeS sh i) i -unconsShape :: ShapeS sh1 i -> Maybe (UnconsShapeRes i sh1) -unconsShape (ShapeS shl) = case shl of - i ::$ shl' -> Just (UnconsShapeRes (ShapeS shl') i) - ZS -> Nothing - -deriving newtype instance Functor (ShapeS n) - -instance Foldable (ShapeS n) where - foldr f z l = foldr f z (shapeToList l) +-} instance KnownShape sh => IsList (ShapeS sh i) where type Item (ShapeS sh i) = i @@ -382,7 +337,6 @@ listToShape = ShapeS . listToSized shapeToList :: ShapeS sh i -> [i] shapeToList (ShapeS l) = sizedToList l - -- * Operations involving both indexes and shapes -- | Given a multidimensional index, get the corresponding linear diff --git a/test/simplified/TestAdaptorSimplified.hs b/test/simplified/TestAdaptorSimplified.hs index 48af72acc..fb1db75c6 100644 --- a/test/simplified/TestAdaptorSimplified.hs +++ b/test/simplified/TestAdaptorSimplified.hs @@ -987,11 +987,11 @@ testReluSimplerPP4S2 = do reluT2 (t, r) = reluS (t * sreplicate0N r) let (artifactRev, _deltas) = revArtifactAdapt True reluT2 (FlipS $ OS.constant 128, 42) printArtifactPretty renames artifactRev - @?= "\\m12 m1 x2 -> let m6 = sreshape (sreplicate x2) ; m7 = m1 * m6 ; m11 = sgather (sreplicate (sconst @[2] (fromList @[2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m7 !$ [i8, i9] <=. 0.0) 0 1]) ; m13 = m11 * m12 in [m6 * m13, ssum (sreshape (m1 * m13))]" + @?= "\\m11 m1 x2 -> let m6 = sreshape (sreplicate x2) ; m7 = m1 * m6 ; m10 = sgather (sreplicate (sconst @[2] (fromList @[2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m7 !$ [i8, i9] <=. 0.0) 0 1]) ; m12 = m10 * m11 in [m6 * m12, ssum (sreshape (m1 * m12))]" printArtifactPrimalPretty renames artifactRev - @?= "\\m1 x2 -> let m6 = sreshape (sreplicate x2) ; m7 = m1 * m6 ; m11 = sgather (sreplicate (sconst @[2] (fromList @[2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m7 !$ [i8, i9] <=. 0.0) 0 1]) in [m11 * m7]" + @?= "\\m1 x2 -> let m6 = sreshape (sreplicate x2) ; m7 = m1 * m6 ; m10 = sgather (sreplicate (sconst @[2] (fromList @[2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m7 !$ [i8, i9] <=. 0.0) 0 1]) in [m10 * m7]" printArtifactPretty renames (simplifyArtifact artifactRev) - @?= "\\m12 m1 x2 -> let m6 = sreshape (sreplicate x2) ; m13 = sgather (sreplicate (sconst @[2] (fromList @[2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m1 !$ [i8, i9] * m6 !$ [i8, i9] <=. 0.0) 0 1]) * m12 in [m6 * m13, ssum (sreshape (m1 * m13))]" + @?= "\\m11 m1 x2 -> let m6 = sreshape (sreplicate x2) ; m12 = sgather (sreplicate (sconst @[2] (fromList @[2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m1 !$ [i8, i9] * m6 !$ [i8, i9] <=. 0.0) 0 1]) * m11 in [m6 * m12, ssum (sreshape (m1 * m12))]" printArtifactPrimalPretty renames (simplifyArtifact artifactRev) @?= "\\m1 x2 -> let m7 = m1 * sreshape (sreplicate x2) in [sgather (sreplicate (sconst @[2] (fromList @[2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m7 !$ [i8, i9] <=. 0.0) 0 1]) * m7]"