From a5472f11332e0c806fa6c644088a9eccc712f08a Mon Sep 17 00:00:00 2001 From: Mikolaj Konarski Date: Sun, 26 Jan 2025 21:59:24 +0100 Subject: [PATCH] Remove a constraint from sminIndex --- src/HordeAd/Core/Ast.hs | 6 ++---- src/HordeAd/Core/AstTools.hs | 5 +++-- src/HordeAd/Core/OpsADVal.hs | 9 ++++++--- src/HordeAd/Core/OpsAst.hs | 28 ++++++++++++++++++++-------- src/HordeAd/Core/OpsConcrete.hs | 10 +++++++--- src/HordeAd/Core/TensorClass.hs | 9 ++++----- 6 files changed, 42 insertions(+), 25 deletions(-) diff --git a/src/HordeAd/Core/Ast.hs b/src/HordeAd/Core/Ast.hs index 3718c57d8..31ce29535 100644 --- a/src/HordeAd/Core/Ast.hs +++ b/src/HordeAd/Core/Ast.hs @@ -362,12 +362,10 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType => AstTensor ms s (TKS sh r1) -> AstTensor ms s (TKS sh r2) -- Shaped tensor operations - AstMinIndexS :: ( KnownShS sh, KnownNat n, GoodScalar r, GoodScalar r2 - , GoodScalar r2, KnownShS (Init (n ': sh)) ) + AstMinIndexS :: (KnownShS sh, KnownNat n, GoodScalar r, GoodScalar r2) => AstTensor ms PrimalSpan (TKS (n ': sh) r) -> AstTensor ms PrimalSpan (TKS (Init (n ': sh)) r2) - AstMaxIndexS :: ( KnownShS sh, KnownNat n, GoodScalar r, GoodScalar r2 - , GoodScalar r2, KnownShS (Init (n ': sh)) ) + AstMaxIndexS :: (KnownShS sh, KnownNat n, GoodScalar r, GoodScalar r2) => AstTensor ms PrimalSpan (TKS (n ': sh) r) -> AstTensor ms PrimalSpan (TKS (Init (n ': sh)) r2) AstIotaS :: (KnownNat n, GoodScalar r) diff --git a/src/HordeAd/Core/AstTools.hs b/src/HordeAd/Core/AstTools.hs index 0605a8231..89d974757 100644 --- a/src/HordeAd/Core/AstTools.hs +++ b/src/HordeAd/Core/AstTools.hs @@ -46,6 +46,7 @@ import Data.Array.Nested.Internal.Shape , shrRank , shrSize , shsAppend + , shsInit , shsPermutePrefix , shsRank , shsSize @@ -96,8 +97,8 @@ ftkAst t = case t of AstCastK{} -> FTKScalar AstFromIntegralK{} -> FTKScalar - AstMinIndexS{} -> FTKS knownShS FTKScalar - AstMaxIndexS{} -> FTKS knownShS FTKScalar + AstMinIndexS @sh @n _ -> FTKS (shsInit (knownShS @(n ': sh))) FTKScalar + AstMaxIndexS @sh @n _ -> FTKS (shsInit (knownShS @(n ': sh))) FTKScalar AstFloorS{} -> FTKS knownShS FTKScalar AstIotaS{} -> FTKS knownShS FTKScalar AstN1S{} -> FTKS knownShS FTKScalar diff --git a/src/HordeAd/Core/OpsADVal.hs b/src/HordeAd/Core/OpsADVal.hs index ce9c62363..23d864ae7 100644 --- a/src/HordeAd/Core/OpsADVal.hs +++ b/src/HordeAd/Core/OpsADVal.hs @@ -31,12 +31,13 @@ import Data.Array.Nested , IxX (..) , StaticShX(..) , ShX (..) + , ShS (..) , KnownShS (..) , KnownShX (..) , Rank ) import Data.Array.Nested qualified as Nested -import Data.Array.Nested.Internal.Shape (shCvtSX, withKnownShS, shsAppend) +import Data.Array.Nested.Internal.Shape (shsInit, shCvtSX, withKnownShS, shsAppend) import Data.Array.Nested.Internal.Shape qualified as Nested.Internal.Shape import HordeAd.Core.CarriersADVal @@ -236,10 +237,12 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target) rfromK (D t d) = dDnotShared (rfromK t) (DeltaFromS $ DeltaSFromK d) - sminIndex (D u _) = + sminIndex @_ @_ @sh @n (D u _) = + withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $ let v = sminIndex u in fromPrimalADVal v - smaxIndex (D u _) = + smaxIndex @_ @_ @sh @n (D u _) = + withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $ let v = smaxIndex u in fromPrimalADVal v sfloor (D u _) = diff --git a/src/HordeAd/Core/OpsAst.hs b/src/HordeAd/Core/OpsAst.hs index 83196d545..233622b1d 100644 --- a/src/HordeAd/Core/OpsAst.hs +++ b/src/HordeAd/Core/OpsAst.hs @@ -416,8 +416,12 @@ instance AstSpan s => BaseTensor (AstTensor AstMethodLet s) where kfromR = astFromS stensorKind . astSFromR @'[] rfromK @r = astFromS (stensorKind @(TKR 0 r)) . astFromK - sminIndex = fromPrimal . AstMinIndexS . astSpanPrimal - smaxIndex = fromPrimal . AstMaxIndexS . astSpanPrimal + sminIndex @_ @_ @sh @n a = + withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $ + fromPrimal . AstMinIndexS . astSpanPrimal $ a + smaxIndex @_ @_ @sh @n a = + withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $ + fromPrimal . AstMaxIndexS . astSpanPrimal $ a sfloor = fromPrimal . AstFloorS . astSpanPrimal siota = fromPrimal $ AstIotaS @@ -1176,8 +1180,12 @@ instance AstSpan s => BaseTensor (AstRaw s) where xfromK @r = AstRaw . AstFromS (stensorKind @(TKX '[] r)) . AstSFromK . unAstRaw - sminIndex = AstRaw . fromPrimal . AstMinIndexS . astSpanPrimalRaw . unAstRaw - smaxIndex = AstRaw . fromPrimal . AstMaxIndexS . astSpanPrimalRaw . unAstRaw + sminIndex @_ @_ @sh @n a = + withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $ + AstRaw . fromPrimal . AstMinIndexS . astSpanPrimalRaw . unAstRaw $ a + smaxIndex @_ @_ @sh @n a = + withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $ + AstRaw . fromPrimal . AstMaxIndexS . astSpanPrimalRaw . unAstRaw $ a sfloor = AstRaw . fromPrimal . AstFloorS . astSpanPrimalRaw . unAstRaw siota = AstRaw . fromPrimal $ AstIotaS sindex v ix = AstRaw $ AstIndexS (unAstRaw v) (unAstRaw <$> ix) @@ -1913,10 +1921,14 @@ instance AstSpan s => BaseTensor (AstNoSimplify s) where xfromK @r = AstNoSimplify . AstFromS (stensorKind @(TKX '[] r)) . AstSFromK . unAstNoSimplify - sminIndex = AstNoSimplify . fromPrimal . AstMinIndexS - . astSpanPrimal . unAstNoSimplify - smaxIndex = AstNoSimplify . fromPrimal . AstMaxIndexS - . astSpanPrimal . unAstNoSimplify + sminIndex @_ @_ @sh @n a = + withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $ + AstNoSimplify . fromPrimal . AstMinIndexS + . astSpanPrimal . unAstNoSimplify $ a + smaxIndex @_ @_ @sh @n a = + withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $ + AstNoSimplify . fromPrimal . AstMaxIndexS + . astSpanPrimal . unAstNoSimplify $ a sfloor = AstNoSimplify . fromPrimal . AstFloorS . astSpanPrimal . unAstNoSimplify siota = AstNoSimplify . fromPrimal $ AstIotaS diff --git a/src/HordeAd/Core/OpsConcrete.hs b/src/HordeAd/Core/OpsConcrete.hs index c228c9bbf..2ab1b01b2 100644 --- a/src/HordeAd/Core/OpsConcrete.hs +++ b/src/HordeAd/Core/OpsConcrete.hs @@ -55,7 +55,7 @@ import Data.Array.Nested qualified as Nested import Data.Array.Nested.Internal.Mixed qualified as Nested.Internal.Mixed import Data.Array.Nested.Internal.Ranked qualified as Nested.Internal import Data.Array.Nested.Internal.Shape - (shrRank, shrSize, shsTail, withKnownShS, shrTail, shsAppend, shsProduct, shsSize) + (shsInit, shrRank, shrSize, shsTail, withKnownShS, shrTail, shsAppend, shsProduct, shsSize) import Data.Array.Nested.Internal.Shape qualified as Nested.Internal.Shape import Data.Array.Nested.Internal.Shaped qualified as Nested.Internal import Data.Array.Mixed.Types (Init) @@ -188,8 +188,12 @@ instance BaseTensor RepN where RepN $ liftVR (V.map (* Nested.runScalar (unRepN s))) (unRepN v) rdot1In u v = RepN $ Nested.rdot1Inner (unRepN u) (unRepN v) - sminIndex = RepN . tminIndexS . unRepN - smaxIndex = RepN . tmaxIndexS . unRepN + sminIndex @_ @_ @sh @n a = + withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $ + RepN . tminIndexS . unRepN $ a + smaxIndex @_ @_ @sh @n a = + withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $ + RepN . tmaxIndexS . unRepN $ a sfloor = RepN . liftVS (V.map floor) . unRepN siota @n = let n = valueOf @n in RepN $ Nested.sfromList1 SNat diff --git a/src/HordeAd/Core/TensorClass.hs b/src/HordeAd/Core/TensorClass.hs index d6d66bf66..f530cc3e8 100644 --- a/src/HordeAd/Core/TensorClass.hs +++ b/src/HordeAd/Core/TensorClass.hs @@ -400,9 +400,9 @@ class ( Num (IntOf target) rlength v = case rshape v of ZSR -> error "rlength: impossible pattern needlessly required" k :$: _ -> k - rminIndex, rmaxIndex + rminIndex, rmaxIndex -- partial :: (GoodScalar r, GoodScalar r2, KnownNat n) - => target (TKR (1 + n) r) -> target (TKR n r2) -- partial + => target (TKR (1 + n) r) -> target (TKR n r2) rfloor :: (GoodScalar r, RealFrac r, GoodScalar r2, Integral r2, KnownNat n) => target (TKR n r) -> target (TKR n r2) riota :: GoodScalar r => Int -> target (TKR 1 r) -- from 0 to n - 1 @@ -697,9 +697,8 @@ class ( Num (IntOf target) slength :: forall r n sh. (TensorKind r, KnownNat n) => target (TKS2 (n ': sh) r) -> Int slength _ = valueOf @n - sminIndex, smaxIndex - :: ( GoodScalar r, GoodScalar r2, KnownShS sh, KnownNat n - , KnownShS (Init (n ': sh)) ) -- partial + sminIndex, smaxIndex -- partial + :: (GoodScalar r, GoodScalar r2, KnownShS sh, KnownNat n) => target (TKS (n ': sh) r) -> target (TKS (Init (n ': sh)) r2) sfloor :: (GoodScalar r, RealFrac r, GoodScalar r2, Integral r2, KnownShS sh) => target (TKS sh r) -> target (TKS sh r2)