Skip to content

Commit

Permalink
Use lemRankMapJust instead of unsafeCoerceRefl
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Jan 29, 2025
1 parent b47d13d commit a28f29a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 24 deletions.
1 change: 1 addition & 0 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
AstSFromX :: (KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', TensorKind r)
=> AstTensor ms s (TKX2 sh' r) -> AstTensor ms s (TKS2 sh r)

-- Nesting/unnesting
AstXNestR :: (KnownShX sh1, KnownNat m, TensorKind x)
=> AstTensor ms s (TKX2 (sh1 ++ Replicate m Nothing) x)
-> AstTensor ms s (TKX2 sh1 (TKR2 m x))
Expand Down
13 changes: 7 additions & 6 deletions src/HordeAd/Core/Ops.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ import Data.Array.Nested
, type (++)
)
import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Internal.Lemmas
import Data.Array.Nested.Internal.Shape
( shCvtSX
, shrAppend
Expand Down Expand Up @@ -1398,7 +1399,7 @@ class ( Num (IntOf target)
=> ShS sh1 -> target (TKX2 (MapJust sh1 ++ Replicate m Nothing) x)
-> target (TKS2 sh1 (TKR2 m x))
snestR sh1 =
gcastWith (unsafeCoerceRefl :: Rank (MapJust sh1) :~: Rank sh1) $
gcastWith (lemRankMapJust sh1) $
withKnownShS sh1 $
withKnownShX (ssxFromShape (shCvtSX sh1)) $
sfromX . xnestR (ssxFromShape (shCvtSX sh1))
Expand All @@ -1407,7 +1408,7 @@ class ( Num (IntOf target)
=> ShS sh1 -> target (TKS2 (sh1 ++ sh2) x)
-> target (TKS2 sh1 (TKS2 sh2 x))
snest sh1 =
gcastWith (unsafeCoerceRefl :: Rank (MapJust sh1) :~: Rank sh1) $
gcastWith (lemRankMapJust sh1) $
gcastWith (unsafeCoerceRefl :: Rank (MapJust sh1 ++ MapJust sh2)
:~: Rank (sh1 ++ sh2)) $
withKnownShS sh1 $
Expand All @@ -1421,7 +1422,7 @@ class ( Num (IntOf target)
=> ShS sh1 -> target (TKX2 (MapJust sh1 ++ sh2) x)
-> target (TKS2 sh1 (TKX2 sh2 x))
snestX sh1 =
gcastWith (unsafeCoerceRefl :: Rank (MapJust sh1) :~: Rank sh1) $
gcastWith (lemRankMapJust sh1) $
withKnownShS sh1 $
withKnownShX (ssxFromShape (shCvtSX sh1)) $
sfromX . xnest (ssxFromShape (shCvtSX sh1))
Expand Down Expand Up @@ -1471,14 +1472,14 @@ class ( Num (IntOf target)
=> target (TKS2 sh1 (TKR2 m x))
-> target (TKX2 (MapJust sh1 ++ Replicate m Nothing) x)
sunNestR =
gcastWith (unsafeCoerceRefl :: Rank (MapJust sh1) :~: Rank sh1) $
gcastWith (lemRankMapJust (knownShS @sh1)) $
withKnownShX (ssxFromShape (shCvtSX (knownShS @sh1))) $
xunNestR . xfromS @_ @_ @(MapJust sh1)
sunNest :: forall sh1 sh2 x.
(TensorKind x, KnownShS sh1, KnownShS sh2)
=> target (TKS2 sh1 (TKS2 sh2 x)) -> target (TKS2 (sh1 ++ sh2) x)
sunNest =
gcastWith (unsafeCoerceRefl :: Rank (MapJust sh1) :~: Rank sh1) $
gcastWith (lemRankMapJust (knownShS @sh1)) $
gcastWith (unsafeCoerceRefl
:: Rank (MapJust sh1 ++ MapJust sh2) :~: Rank (sh1 ++ sh2)) $
withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $
Expand All @@ -1491,7 +1492,7 @@ class ( Num (IntOf target)
=> target (TKS2 sh1 (TKX2 sh2 x))
-> target (TKX2 (MapJust sh1 ++ sh2) x)
sunNestX =
gcastWith (unsafeCoerceRefl :: Rank (MapJust sh1) :~: Rank sh1) $
gcastWith (lemRankMapJust (knownShS @sh1)) $
withKnownShX (ssxFromShape (shCvtSX (knownShS @sh1))) $
withKnownShX (ssxFromShape (shCvtSX (knownShS @sh1))
`ssxAppend` knownShX @sh2) $
Expand Down
24 changes: 6 additions & 18 deletions src/HordeAd/Core/OpsAst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -653,15 +653,9 @@ instance AstSpan s => BaseTensor (AstTensor AstMethodLet s) where
sfromX = astSFromX

-- Nesting/unnesting
xnestR sh =
withKnownShX sh $
astXNestR
xnestS sh =
withKnownShX sh $
astXNestS
xnest sh =
withKnownShX sh $
astXNest
xnestR sh = withKnownShX sh $ astXNestR
xnestS sh = withKnownShX sh $ astXNestS
xnest sh = withKnownShX sh $ astXNest
xunNestR = astXUnNestR
xunNestS = astXUnNestS
xunNest = astXUnNest
Expand Down Expand Up @@ -1247,15 +1241,9 @@ instance AstSpan s => BaseTensor (AstRaw s) where
xfromS @_ @sh' @x = AstRaw . AstFromS (stensorKind @(TKX2 sh' x)) . unAstRaw

-- Nesting/unnesting
xnestR sh =
withKnownShX sh $
AstRaw . AstXNestR . unAstRaw
xnestS sh =
withKnownShX sh $
AstRaw . AstXNestS . unAstRaw
xnest sh =
withKnownShX sh $
AstRaw . AstXNest . unAstRaw
xnestR sh = withKnownShX sh $ AstRaw . AstXNestR . unAstRaw
xnestS sh = withKnownShX sh $ AstRaw . AstXNestS . unAstRaw
xnest sh = withKnownShX sh $ AstRaw . AstXNest . unAstRaw
xunNestR = AstRaw . AstXUnNestR . unAstRaw
xunNestS = AstRaw . AstXUnNestS . unAstRaw
xunNest = AstRaw . AstXUnNest . unAstRaw
Expand Down

0 comments on commit a28f29a

Please sign in to comment.