diff --git a/src/HordeAd/Core/Ast.hs b/src/HordeAd/Core/Ast.hs index a34bba625..e0f37eab6 100644 --- a/src/HordeAd/Core/Ast.hs +++ b/src/HordeAd/Core/Ast.hs @@ -413,6 +413,7 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType AstSFromX :: (KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', TensorKind r) => AstTensor ms s (TKX2 sh' r) -> AstTensor ms s (TKS2 sh r) + -- Nesting/unnesting AstXNestR :: (KnownShX sh1, KnownNat m, TensorKind x) => AstTensor ms s (TKX2 (sh1 ++ Replicate m Nothing) x) -> AstTensor ms s (TKX2 sh1 (TKR2 m x)) diff --git a/src/HordeAd/Core/Ops.hs b/src/HordeAd/Core/Ops.hs index 059b16472..2fc26a862 100644 --- a/src/HordeAd/Core/Ops.hs +++ b/src/HordeAd/Core/Ops.hs @@ -68,6 +68,7 @@ import Data.Array.Nested , type (++) ) import Data.Array.Nested qualified as Nested +import Data.Array.Nested.Internal.Lemmas import Data.Array.Nested.Internal.Shape ( shCvtSX , shrAppend @@ -1398,7 +1399,7 @@ class ( Num (IntOf target) => ShS sh1 -> target (TKX2 (MapJust sh1 ++ Replicate m Nothing) x) -> target (TKS2 sh1 (TKR2 m x)) snestR sh1 = - gcastWith (unsafeCoerceRefl :: Rank (MapJust sh1) :~: Rank sh1) $ + gcastWith (lemRankMapJust sh1) $ withKnownShS sh1 $ withKnownShX (ssxFromShape (shCvtSX sh1)) $ sfromX . xnestR (ssxFromShape (shCvtSX sh1)) @@ -1407,7 +1408,7 @@ class ( Num (IntOf target) => ShS sh1 -> target (TKS2 (sh1 ++ sh2) x) -> target (TKS2 sh1 (TKS2 sh2 x)) snest sh1 = - gcastWith (unsafeCoerceRefl :: Rank (MapJust sh1) :~: Rank sh1) $ + gcastWith (lemRankMapJust sh1) $ gcastWith (unsafeCoerceRefl :: Rank (MapJust sh1 ++ MapJust sh2) :~: Rank (sh1 ++ sh2)) $ withKnownShS sh1 $ @@ -1421,7 +1422,7 @@ class ( Num (IntOf target) => ShS sh1 -> target (TKX2 (MapJust sh1 ++ sh2) x) -> target (TKS2 sh1 (TKX2 sh2 x)) snestX sh1 = - gcastWith (unsafeCoerceRefl :: Rank (MapJust sh1) :~: Rank sh1) $ + gcastWith (lemRankMapJust sh1) $ withKnownShS sh1 $ withKnownShX (ssxFromShape (shCvtSX sh1)) $ sfromX . xnest (ssxFromShape (shCvtSX sh1)) @@ -1471,14 +1472,14 @@ class ( Num (IntOf target) => target (TKS2 sh1 (TKR2 m x)) -> target (TKX2 (MapJust sh1 ++ Replicate m Nothing) x) sunNestR = - gcastWith (unsafeCoerceRefl :: Rank (MapJust sh1) :~: Rank sh1) $ + gcastWith (lemRankMapJust (knownShS @sh1)) $ withKnownShX (ssxFromShape (shCvtSX (knownShS @sh1))) $ xunNestR . xfromS @_ @_ @(MapJust sh1) sunNest :: forall sh1 sh2 x. (TensorKind x, KnownShS sh1, KnownShS sh2) => target (TKS2 sh1 (TKS2 sh2 x)) -> target (TKS2 (sh1 ++ sh2) x) sunNest = - gcastWith (unsafeCoerceRefl :: Rank (MapJust sh1) :~: Rank sh1) $ + gcastWith (lemRankMapJust (knownShS @sh1)) $ gcastWith (unsafeCoerceRefl :: Rank (MapJust sh1 ++ MapJust sh2) :~: Rank (sh1 ++ sh2)) $ withKnownShS (knownShS @sh1 `shsAppend` knownShS @sh2) $ @@ -1491,7 +1492,7 @@ class ( Num (IntOf target) => target (TKS2 sh1 (TKX2 sh2 x)) -> target (TKX2 (MapJust sh1 ++ sh2) x) sunNestX = - gcastWith (unsafeCoerceRefl :: Rank (MapJust sh1) :~: Rank sh1) $ + gcastWith (lemRankMapJust (knownShS @sh1)) $ withKnownShX (ssxFromShape (shCvtSX (knownShS @sh1))) $ withKnownShX (ssxFromShape (shCvtSX (knownShS @sh1)) `ssxAppend` knownShX @sh2) $ diff --git a/src/HordeAd/Core/OpsAst.hs b/src/HordeAd/Core/OpsAst.hs index ca077c653..0c00df098 100644 --- a/src/HordeAd/Core/OpsAst.hs +++ b/src/HordeAd/Core/OpsAst.hs @@ -653,15 +653,9 @@ instance AstSpan s => BaseTensor (AstTensor AstMethodLet s) where sfromX = astSFromX -- Nesting/unnesting - xnestR sh = - withKnownShX sh $ - astXNestR - xnestS sh = - withKnownShX sh $ - astXNestS - xnest sh = - withKnownShX sh $ - astXNest + xnestR sh = withKnownShX sh $ astXNestR + xnestS sh = withKnownShX sh $ astXNestS + xnest sh = withKnownShX sh $ astXNest xunNestR = astXUnNestR xunNestS = astXUnNestS xunNest = astXUnNest @@ -1247,15 +1241,9 @@ instance AstSpan s => BaseTensor (AstRaw s) where xfromS @_ @sh' @x = AstRaw . AstFromS (stensorKind @(TKX2 sh' x)) . unAstRaw -- Nesting/unnesting - xnestR sh = - withKnownShX sh $ - AstRaw . AstXNestR . unAstRaw - xnestS sh = - withKnownShX sh $ - AstRaw . AstXNestS . unAstRaw - xnest sh = - withKnownShX sh $ - AstRaw . AstXNest . unAstRaw + xnestR sh = withKnownShX sh $ AstRaw . AstXNestR . unAstRaw + xnestS sh = withKnownShX sh $ AstRaw . AstXNestS . unAstRaw + xnest sh = withKnownShX sh $ AstRaw . AstXNest . unAstRaw xunNestR = AstRaw . AstXUnNestR . unAstRaw xunNestS = AstRaw . AstXUnNestS . unAstRaw xunNest = AstRaw . AstXUnNest . unAstRaw