Skip to content

Commit

Permalink
Remove a constraint from sminIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Jan 26, 2025
1 parent ef81cb6 commit a5472f1
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 25 deletions.
6 changes: 2 additions & 4 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -362,12 +362,10 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
=> AstTensor ms s (TKS sh r1) -> AstTensor ms s (TKS sh r2)

-- Shaped tensor operations
AstMinIndexS :: ( KnownShS sh, KnownNat n, GoodScalar r, GoodScalar r2
, GoodScalar r2, KnownShS (Init (n ': sh)) )
AstMinIndexS :: (KnownShS sh, KnownNat n, GoodScalar r, GoodScalar r2)
=> AstTensor ms PrimalSpan (TKS (n ': sh) r)
-> AstTensor ms PrimalSpan (TKS (Init (n ': sh)) r2)
AstMaxIndexS :: ( KnownShS sh, KnownNat n, GoodScalar r, GoodScalar r2
, GoodScalar r2, KnownShS (Init (n ': sh)) )
AstMaxIndexS :: (KnownShS sh, KnownNat n, GoodScalar r, GoodScalar r2)
=> AstTensor ms PrimalSpan (TKS (n ': sh) r)
-> AstTensor ms PrimalSpan (TKS (Init (n ': sh)) r2)
AstIotaS :: (KnownNat n, GoodScalar r)
Expand Down
5 changes: 3 additions & 2 deletions src/HordeAd/Core/AstTools.hs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import Data.Array.Nested.Internal.Shape
, shrRank
, shrSize
, shsAppend
, shsInit
, shsPermutePrefix
, shsRank
, shsSize
Expand Down Expand Up @@ -96,8 +97,8 @@ ftkAst t = case t of
AstCastK{} -> FTKScalar
AstFromIntegralK{} -> FTKScalar

AstMinIndexS{} -> FTKS knownShS FTKScalar
AstMaxIndexS{} -> FTKS knownShS FTKScalar
AstMinIndexS @sh @n _ -> FTKS (shsInit (knownShS @(n ': sh))) FTKScalar
AstMaxIndexS @sh @n _ -> FTKS (shsInit (knownShS @(n ': sh))) FTKScalar
AstFloorS{} -> FTKS knownShS FTKScalar
AstIotaS{} -> FTKS knownShS FTKScalar
AstN1S{} -> FTKS knownShS FTKScalar
Expand Down
9 changes: 6 additions & 3 deletions src/HordeAd/Core/OpsADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ import Data.Array.Nested
, IxX (..)
, StaticShX(..)
, ShX (..)
, ShS (..)
, KnownShS (..)
, KnownShX (..)
, Rank
)
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Internal.Shape (shCvtSX, withKnownShS, shsAppend)
import Data.Array.Nested.Internal.Shape (shsInit, shCvtSX, withKnownShS, shsAppend)
import Data.Array.Nested.Internal.Shape qualified as Nested.Internal.Shape

import HordeAd.Core.CarriersADVal
Expand Down Expand Up @@ -236,10 +237,12 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target)
rfromK (D t d) =
dDnotShared (rfromK t) (DeltaFromS $ DeltaSFromK d)

sminIndex (D u _) =
sminIndex @_ @_ @sh @n (D u _) =
withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $
let v = sminIndex u
in fromPrimalADVal v
smaxIndex (D u _) =
smaxIndex @_ @_ @sh @n (D u _) =
withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $
let v = smaxIndex u
in fromPrimalADVal v
sfloor (D u _) =
Expand Down
28 changes: 20 additions & 8 deletions src/HordeAd/Core/OpsAst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,12 @@ instance AstSpan s => BaseTensor (AstTensor AstMethodLet s) where
kfromR = astFromS stensorKind . astSFromR @'[]
rfromK @r = astFromS (stensorKind @(TKR 0 r)) . astFromK

sminIndex = fromPrimal . AstMinIndexS . astSpanPrimal
smaxIndex = fromPrimal . AstMaxIndexS . astSpanPrimal
sminIndex @_ @_ @sh @n a =
withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $
fromPrimal . AstMinIndexS . astSpanPrimal $ a
smaxIndex @_ @_ @sh @n a =
withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $
fromPrimal . AstMaxIndexS . astSpanPrimal $ a
sfloor = fromPrimal . AstFloorS . astSpanPrimal

siota = fromPrimal $ AstIotaS
Expand Down Expand Up @@ -1176,8 +1180,12 @@ instance AstSpan s => BaseTensor (AstRaw s) where
xfromK @r = AstRaw . AstFromS (stensorKind @(TKX '[] r))
. AstSFromK . unAstRaw

sminIndex = AstRaw . fromPrimal . AstMinIndexS . astSpanPrimalRaw . unAstRaw
smaxIndex = AstRaw . fromPrimal . AstMaxIndexS . astSpanPrimalRaw . unAstRaw
sminIndex @_ @_ @sh @n a =
withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $
AstRaw . fromPrimal . AstMinIndexS . astSpanPrimalRaw . unAstRaw $ a
smaxIndex @_ @_ @sh @n a =
withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $
AstRaw . fromPrimal . AstMaxIndexS . astSpanPrimalRaw . unAstRaw $ a
sfloor = AstRaw . fromPrimal . AstFloorS . astSpanPrimalRaw . unAstRaw
siota = AstRaw . fromPrimal $ AstIotaS
sindex v ix = AstRaw $ AstIndexS (unAstRaw v) (unAstRaw <$> ix)
Expand Down Expand Up @@ -1913,10 +1921,14 @@ instance AstSpan s => BaseTensor (AstNoSimplify s) where
xfromK @r = AstNoSimplify . AstFromS (stensorKind @(TKX '[] r))
. AstSFromK . unAstNoSimplify

sminIndex = AstNoSimplify . fromPrimal . AstMinIndexS
. astSpanPrimal . unAstNoSimplify
smaxIndex = AstNoSimplify . fromPrimal . AstMaxIndexS
. astSpanPrimal . unAstNoSimplify
sminIndex @_ @_ @sh @n a =
withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $
AstNoSimplify . fromPrimal . AstMinIndexS
. astSpanPrimal . unAstNoSimplify $ a
smaxIndex @_ @_ @sh @n a =
withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $
AstNoSimplify . fromPrimal . AstMaxIndexS
. astSpanPrimal . unAstNoSimplify $ a
sfloor = AstNoSimplify . fromPrimal . AstFloorS
. astSpanPrimal . unAstNoSimplify
siota = AstNoSimplify . fromPrimal $ AstIotaS
Expand Down
10 changes: 7 additions & 3 deletions src/HordeAd/Core/OpsConcrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Internal.Mixed qualified as Nested.Internal.Mixed
import Data.Array.Nested.Internal.Ranked qualified as Nested.Internal
import Data.Array.Nested.Internal.Shape
(shrRank, shrSize, shsTail, withKnownShS, shrTail, shsAppend, shsProduct, shsSize)
(shsInit, shrRank, shrSize, shsTail, withKnownShS, shrTail, shsAppend, shsProduct, shsSize)
import Data.Array.Nested.Internal.Shape qualified as Nested.Internal.Shape
import Data.Array.Nested.Internal.Shaped qualified as Nested.Internal
import Data.Array.Mixed.Types (Init)
Expand Down Expand Up @@ -188,8 +188,12 @@ instance BaseTensor RepN where
RepN $ liftVR (V.map (* Nested.runScalar (unRepN s))) (unRepN v)
rdot1In u v = RepN $ Nested.rdot1Inner (unRepN u) (unRepN v)

sminIndex = RepN . tminIndexS . unRepN
smaxIndex = RepN . tmaxIndexS . unRepN
sminIndex @_ @_ @sh @n a =
withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $
RepN . tminIndexS . unRepN $ a
smaxIndex @_ @_ @sh @n a =
withKnownShS (shsInit (SNat @n :$$ knownShS @sh)) $
RepN . tmaxIndexS . unRepN $ a
sfloor = RepN . liftVS (V.map floor) . unRepN
siota @n = let n = valueOf @n
in RepN $ Nested.sfromList1 SNat
Expand Down
9 changes: 4 additions & 5 deletions src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -400,9 +400,9 @@ class ( Num (IntOf target)
rlength v = case rshape v of
ZSR -> error "rlength: impossible pattern needlessly required"
k :$: _ -> k
rminIndex, rmaxIndex
rminIndex, rmaxIndex -- partial
:: (GoodScalar r, GoodScalar r2, KnownNat n)
=> target (TKR (1 + n) r) -> target (TKR n r2) -- partial
=> target (TKR (1 + n) r) -> target (TKR n r2)
rfloor :: (GoodScalar r, RealFrac r, GoodScalar r2, Integral r2, KnownNat n)
=> target (TKR n r) -> target (TKR n r2)
riota :: GoodScalar r => Int -> target (TKR 1 r) -- from 0 to n - 1
Expand Down Expand Up @@ -697,9 +697,8 @@ class ( Num (IntOf target)
slength :: forall r n sh. (TensorKind r, KnownNat n)
=> target (TKS2 (n ': sh) r) -> Int
slength _ = valueOf @n
sminIndex, smaxIndex
:: ( GoodScalar r, GoodScalar r2, KnownShS sh, KnownNat n
, KnownShS (Init (n ': sh)) ) -- partial
sminIndex, smaxIndex -- partial
:: (GoodScalar r, GoodScalar r2, KnownShS sh, KnownNat n)
=> target (TKS (n ': sh) r) -> target (TKS (Init (n ': sh)) r2)
sfloor :: (GoodScalar r, RealFrac r, GoodScalar r2, Integral r2, KnownShS sh)
=> target (TKS sh r) -> target (TKS sh r2)
Expand Down

0 comments on commit a5472f1

Please sign in to comment.