Skip to content

Commit

Permalink
Use ox-arrays ListS
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Apr 27, 2024
1 parent fd81bd5 commit 092624d
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 84 deletions.
2 changes: 1 addition & 1 deletion src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import HordeAd.Core.HVector
import HordeAd.Core.Types
import HordeAd.Internal.OrthotopeOrphanInstances
(IntegralF (..), RealFloatF (..), FlipS(..))
import HordeAd.Util.ShapedList (IndexS, SizedListS (..))
import HordeAd.Util.ShapedList (IndexS, SizedListS)
import HordeAd.Util.SizedList

-- * Basic type family instances
Expand Down
2 changes: 1 addition & 1 deletion src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ import HordeAd.Internal.BackendConcrete
import HordeAd.Internal.OrthotopeOrphanInstances
(FlipS (..), MapSucc, trustMeThisIsAPermutation)
import HordeAd.Util.ShapedList
(SizedListS (..), pattern (:.$), pattern ZIS)
(pattern (:.$), pattern (::$), pattern ZIS, pattern ZS)
import qualified HordeAd.Util.ShapedList as ShapedList
import HordeAd.Util.SizedList

Expand Down
2 changes: 1 addition & 1 deletion src/HordeAd/Core/AstVectorize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ import HordeAd.Core.Types
import HordeAd.Internal.OrthotopeOrphanInstances
(MapSucc, trustMeThisIsAPermutation)
import HordeAd.Util.ShapedList
(SizedListS (..), pattern (:.$), pattern ZIS)
(pattern (:.$), pattern (::$), pattern ZIS, pattern ZS)
import HordeAd.Util.SizedList

-- * Vectorization of AstRanked
Expand Down
110 changes: 32 additions & 78 deletions src/HordeAd/Util/ShapedList.hs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
{-# LANGUAGE AllowAmbiguousTypes, DerivingStrategies, ViewPatterns #-}
{-# LANGUAGE AllowAmbiguousTypes, DerivingStrategies #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fconstraint-solver-iterations=10000 #-}
{-# OPTIONS_GHC -Wno-orphans #-}
-- | @[Nat]@-indexed lists to be used as is for lists of tensor variables,
-- tensor shapes and tensor indexes.
module HordeAd.Util.ShapedList
( -- * Shaped lists (sized, where size is shape) and their permutations
IntSh, IndexSh
, SizedListS(..)
, SizedListS, pattern (::$), pattern ZS
, consShaped, unconsContShaped
, singletonSized, snocSized, appendSized
, headSized, tailSized, takeSized, dropSized, splitAt_Sized
Expand Down Expand Up @@ -42,6 +43,17 @@ import GHC.Exts (IsList (..))
import GHC.TypeLits (KnownNat, Nat, type (*))
import Unsafe.Coerce (unsafeCoerce)

import Data.Array.Nested
( IxS (..)
, ListS
, StaticShapeS (..)
, pattern (:$$)
, pattern (:.$)
, pattern (::$)
, pattern ZIS
, pattern ZS
, pattern ZSS
)
import HordeAd.Core.Types
import HordeAd.Util.SizedList (Permutation)
import qualified HordeAd.Util.SizedList as SizedList
Expand All @@ -60,28 +72,15 @@ type IntSh (f :: TensorType ty) (n :: Nat) = ShapedNat n (IntOf f)
-- and up to evaluation.
type IndexSh (f :: TensorType ty) (sh :: [Nat]) = IndexS sh (IntOf f)

-- | Lists indexed by shapes, that is, lists of the GHC @Nat@.
type SizedListS n i = ListS n i

-- | Strict lists indexed by shapes, that is, lists of the GHC @Nat@.
infixr 3 ::$
type role SizedListS nominal representational
data SizedListS (sh :: [Nat]) i where
ZS :: SizedListS '[] i
(::$) :: forall k sh {i}. (KnownNat k, KnownShape sh)
=> i -> SizedListS sh i -> SizedListS (k : sh) i

deriving instance Eq i => Eq (SizedListS sh i)

deriving instance Ord i => Ord (SizedListS sh i)

{-
-- This is only lawful when OverloadedLists is enabled.
-- However, it's much more readable when tracing and debugging.
instance Show i => Show (SizedListS sh i) where
showsPrec d l = showsPrec d (sizedToList l)

deriving stock instance Functor (SizedListS sh)

instance Foldable (SizedListS sh) where
foldr f z l = foldr f z (sizedToList l)
-}

instance KnownShape sh => IsList (SizedListS sh i) where
type Item (SizedListS sh i) = i
Expand Down Expand Up @@ -222,40 +221,18 @@ shapedToSized = SizedList.listToSized . sizedToList

-- * Tensor indexes as fully encapsulated shaped lists, with operations

type role IndexS nominal representational
newtype IndexS sh i = IndexS (SizedListS sh i)
deriving (Eq, Ord)
type IndexS sh i = IxS sh i

pattern IndexS :: forall {sh :: [Nat]} {i}. ListS sh i -> IxS sh i
pattern IndexS l = IxS l
{-# COMPLETE IndexS #-}

{-
-- This is only lawful when OverloadedLists is enabled.
-- However, it's much more readable when tracing and debugging.
instance Show i => Show (IndexS sh i) where
showsPrec d (IndexS l) = showsPrec d l

pattern ZIS :: forall sh i. () => sh ~ '[] => IndexS sh i
pattern ZIS = IndexS ZS

infixr 3 :.$
pattern (:.$)
:: forall {sh1} {i}. ()
=> forall k sh. (KnownNat k, KnownShape sh, (k : sh) ~ sh1)
=> i -> IndexS sh i -> IndexS sh1 i
pattern i :.$ shl <- (unconsIndex -> Just (UnconsIndexRes shl i))
where i :.$ (IndexS shl) = IndexS (i ::$ shl)
{-# COMPLETE ZIS, (:.$) #-}

type role UnconsIndexRes representational nominal
data UnconsIndexRes i sh1 =
forall k sh. (KnownNat k, KnownShape sh, (k : sh) ~ sh1)
=> UnconsIndexRes (IndexS sh i) i
unconsIndex :: IndexS sh1 i -> Maybe (UnconsIndexRes i sh1)
unconsIndex (IndexS shl) = case shl of
i ::$ shl' -> Just (UnconsIndexRes (IndexS shl') i)
ZS -> Nothing

deriving newtype instance Functor (IndexS n)

instance Foldable (IndexS n) where
foldr f z l = foldr f z (indexToList l)
-}

instance KnownShape sh => IsList (IndexS sh i) where
type Item (IndexS sh i) = i
Expand Down Expand Up @@ -320,40 +297,18 @@ shapedToIndex = SizedList.listToIndex . indexToList

-- * Tensor shapes as fully encapsulated shaped lists, with operations

type role ShapeS nominal representational
newtype ShapeS sh i = ShapeS (SizedListS sh i)
deriving (Eq, Ord)
type ShapeS sh i = StaticShapeS sh i

pattern ShapeS :: forall {sh :: [Nat]} {i}. ListS sh i -> StaticShapeS sh i
pattern ShapeS l = StaticShapeS l
{-# COMPLETE ShapeS #-}

{-
-- This is only lawful when OverloadedLists is enabled.
-- However, it's much more readable when tracing and debugging.
instance Show i => Show (ShapeS sh i) where
showsPrec d (ShapeS l) = showsPrec d l

pattern ZSS :: forall sh i. () => sh ~ '[] => ShapeS sh i
pattern ZSS = ShapeS ZS

infixr 3 :$$
pattern (:$$)
:: forall {sh1} {i}. ()
=> forall k sh. (KnownNat k, KnownShape sh, (k : sh) ~ sh1)
=> i -> ShapeS sh i -> ShapeS sh1 i
pattern i :$$ shl <- (unconsShape -> Just (UnconsShapeRes shl i))
where i :$$ (ShapeS shl) = ShapeS (i ::$ shl)
{-# COMPLETE ZSS, (:$$) #-}

type role UnconsShapeRes representational nominal
data UnconsShapeRes i sh1 =
forall k sh. (KnownNat k, KnownShape sh, (k : sh) ~ sh1)
=> UnconsShapeRes (ShapeS sh i) i
unconsShape :: ShapeS sh1 i -> Maybe (UnconsShapeRes i sh1)
unconsShape (ShapeS shl) = case shl of
i ::$ shl' -> Just (UnconsShapeRes (ShapeS shl') i)
ZS -> Nothing

deriving newtype instance Functor (ShapeS n)

instance Foldable (ShapeS n) where
foldr f z l = foldr f z (shapeToList l)
-}

instance KnownShape sh => IsList (ShapeS sh i) where
type Item (ShapeS sh i) = i
Expand Down Expand Up @@ -382,7 +337,6 @@ listToShape = ShapeS . listToSized
shapeToList :: ShapeS sh i -> [i]
shapeToList (ShapeS l) = sizedToList l


-- * Operations involving both indexes and shapes

-- | Given a multidimensional index, get the corresponding linear
Expand Down
6 changes: 3 additions & 3 deletions test/simplified/TestAdaptorSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -987,11 +987,11 @@ testReluSimplerPP4S2 = do
reluT2 (t, r) = reluS (t * sreplicate0N r)
let (artifactRev, _deltas) = revArtifactAdapt True reluT2 (FlipS $ OS.constant 128, 42)
printArtifactPretty renames artifactRev
@?= "\\m12 m1 x2 -> let m6 = sreshape (sreplicate x2) ; m7 = m1 * m6 ; m11 = sgather (sreplicate (sconst @[2] (fromList @[2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m7 !$ [i8, i9] <=. 0.0) 0 1]) ; m13 = m11 * m12 in [m6 * m13, ssum (sreshape (m1 * m13))]"
@?= "\\m11 m1 x2 -> let m6 = sreshape (sreplicate x2) ; m7 = m1 * m6 ; m10 = sgather (sreplicate (sconst @[2] (fromList @[2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m7 !$ [i8, i9] <=. 0.0) 0 1]) ; m12 = m10 * m11 in [m6 * m12, ssum (sreshape (m1 * m12))]"
printArtifactPrimalPretty renames artifactRev
@?= "\\m1 x2 -> let m6 = sreshape (sreplicate x2) ; m7 = m1 * m6 ; m11 = sgather (sreplicate (sconst @[2] (fromList @[2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m7 !$ [i8, i9] <=. 0.0) 0 1]) in [m11 * m7]"
@?= "\\m1 x2 -> let m6 = sreshape (sreplicate x2) ; m7 = m1 * m6 ; m10 = sgather (sreplicate (sconst @[2] (fromList @[2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m7 !$ [i8, i9] <=. 0.0) 0 1]) in [m10 * m7]"
printArtifactPretty renames (simplifyArtifact artifactRev)
@?= "\\m12 m1 x2 -> let m6 = sreshape (sreplicate x2) ; m13 = sgather (sreplicate (sconst @[2] (fromList @[2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m1 !$ [i8, i9] * m6 !$ [i8, i9] <=. 0.0) 0 1]) * m12 in [m6 * m13, ssum (sreshape (m1 * m13))]"
@?= "\\m11 m1 x2 -> let m6 = sreshape (sreplicate x2) ; m12 = sgather (sreplicate (sconst @[2] (fromList @[2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m1 !$ [i8, i9] * m6 !$ [i8, i9] <=. 0.0) 0 1]) * m11 in [m6 * m12, ssum (sreshape (m1 * m12))]"
printArtifactPrimalPretty renames (simplifyArtifact artifactRev)
@?= "\\m1 x2 -> let m7 = m1 * sreshape (sreplicate x2) in [sgather (sreplicate (sconst @[2] (fromList @[2] [0.0,1.0]))) (\\[i8, i9] -> [i8, ifF (m7 !$ [i8, i9] <=. 0.0) 0 1]) * m7]"

Expand Down

0 comments on commit 092624d

Please sign in to comment.