diff --git a/src/HordeAd/Core/AstSimplify.hs b/src/HordeAd/Core/AstSimplify.hs index c8590d5ce..550e9f4d1 100644 --- a/src/HordeAd/Core/AstSimplify.hs +++ b/src/HordeAd/Core/AstSimplify.hs @@ -116,9 +116,10 @@ import HordeAd.Core.AstFreshId import HordeAd.Core.AstTools import HordeAd.Core.CarriersConcrete import HordeAd.Core.HVectorOps +import HordeAd.Core.OpsConcrete () +import HordeAd.Core.TensorClass import HordeAd.Core.TensorKind import HordeAd.Core.Types -import HordeAd.Internal.BackendOX import HordeAd.Util.ShapedList qualified as ShapedList import HordeAd.Util.SizedList @@ -598,14 +599,14 @@ astIndexKnobsR knobs v0 ix@(i1 :.: (rest1 :: AstIxR AstMethodLet m1)) = Ast.AstCastR t -> astCastR $ astIndexKnobsR knobs t ix Ast.AstFromIntegralR v -> astFromIntegralR $ astIndexKnobsR knobs v ix AstConcrete (FTKR _ x) t -> - let unConst :: AstInt AstMethodLet -> Maybe [Int64] - -> Maybe [Int64] - unConst (AstConcrete _ (RepN i)) (Just l) = Just $ i : l + let unConst :: AstInt AstMethodLet -> Maybe [IntOf RepN] + -> Maybe [IntOf RepN] + unConst (AstConcrete _ i) (Just l) = Just $ i : l unConst _ _ = Nothing in case foldr unConst (Just []) ix of Just ixInt -> - let u = tindexZR (unRepN t) $ listToIndex ixInt - in AstConcrete (FTKR (Nested.rshape u) x) $ RepN u + let u = rindex t (fromList ixInt) + in AstConcrete (FTKR (rshape u) x) u -- TODO: we'd need mapM for Index to keep this rank-typed Nothing -> Ast.AstIndex v0 ix Ast.AstProjectR{} -> Ast.AstIndex v0 ix @@ -854,14 +855,13 @@ astIndexKnobsS knobs v0 ix@((:.$) @in1 i1 (rest1 :: AstIxS AstMethodLet shm1)) | Ast.AstCastS t -> astCastS $ astIndexKnobsS knobs t ix Ast.AstFromIntegralS v -> astFromIntegralS $ astIndexKnobsS knobs v ix AstConcrete (FTKS _ x) t -> - let unConst :: AstInt AstMethodLet -> Maybe [Int64] - -> Maybe [Int64] - unConst (AstConcrete _ (RepN i)) (Just l) = Just $ i : l + let unConst :: AstInt AstMethodLet -> Maybe [IntOf RepN] + -> Maybe [IntOf RepN] + unConst (AstConcrete _ i) (Just l) = Just $ i : l unConst _ _ = Nothing in case foldr unConst (Just []) ix of Just ixInt -> AstConcrete (FTKS knownShS x) - $ RepN $ tindexZS (unRepN t) - $ ShapedList.listToIndex @shm ixInt + $ sindex @_ @_ @shm t (fromList ixInt) -- TODO: we'd need mapM for Index to keep this rank-typed Nothing -> Ast.AstIndexS v0 ix Ast.AstProjectS{} -> Ast.AstIndexS v0 ix @@ -1523,7 +1523,7 @@ astSum t0 = case shapeAst t0 of Ast.AstSlice i 1 v -> astIndexR v (fromIntegral i :.: ZIR) Ast.AstReverse v -> astSum v AstConcrete (FTKR sh FTKScalar) t -> - AstConcrete (FTKR (Nested.Internal.Shape.shrTail sh) FTKScalar) $ RepN $ tsumR $ unRepN t + AstConcrete (FTKR (Nested.Internal.Shape.shrTail sh) FTKScalar) $ rsum t Ast.AstFromPrimal v -> Ast.AstFromPrimal $ astSum v _ -> Ast.AstSum t0 @@ -1557,7 +1557,7 @@ astSumS t0 = case sameNat (Proxy @n) (Proxy @0) of Ast.AstSliceS @i @k v | Just Refl <- sameNat (Proxy @k) (Proxy @1) -> astIndexS v (valueOf @i :.$ ZIS) Ast.AstReverseS v -> astSumS v - AstConcrete _ t -> AstConcrete (FTKS knownShS FTKScalar) $ RepN $ Nested.ssumOuter1 $ unRepN t + AstConcrete _ t -> AstConcrete (FTKS knownShS FTKScalar) $ ssum t Ast.AstFromPrimal v -> Ast.AstFromPrimal $ astSumS v _ -> Ast.AstSumS t0 @@ -1609,13 +1609,13 @@ astFromVector :: forall s r n. (KnownNat n, TensorKind2 r, AstSpan s) astFromVector v | V.length v == 1 = astReplicate (SNat @1) (v V.! 0) astFromVector l | Just Refl <- sameAstSpan @s @PrimalSpan = let unConst :: AstTensor AstMethodLet PrimalSpan (TKR2 n r) - -> Maybe (FullTensorKind r, Nested.Ranked n (RepORArray r)) - unConst (AstConcrete (FTKR _ x) (RepN t)) = Just (x, t) + -> Maybe (FullTensorKind r, RepN (TKR2 n r)) + unConst (AstConcrete (FTKR _ x) t) = Just (x, t) unConst _ = Nothing in case V.mapM unConst l of Just l4 | Just ((x, _), _) <- V.uncons l4 -> let l3 = V.map snd l4 - in AstConcrete (FTKR (V.length l :$: Nested.rshape (l3 V.! 0)) x) $ RepN $ tfromVectorR l3 + in AstConcrete (FTKR (V.length l :$: rshape (l3 V.! 0)) x) $ rfromVector l3 _ -> Ast.AstFromVector l astFromVector l | Just Refl <- sameAstSpan @s @FullSpan = let unFromPrimal :: AstTensor AstMethodLet FullSpan (TKR2 n r) @@ -1635,13 +1635,13 @@ astFromVectorS :: forall s r n sh. astFromVectorS v | V.length v == 1 = astReplicate SNat (v V.! 0) astFromVectorS l | Just Refl <- sameAstSpan @s @PrimalSpan = let unConst :: AstTensor AstMethodLet PrimalSpan (TKS2 sh r) - -> Maybe (FullTensorKind r, Nested.Shaped sh (RepORArray r)) - unConst (AstConcrete (FTKS _ x) (RepN t)) = Just (x, t) + -> Maybe (FullTensorKind r, RepN (TKS2 sh r)) + unConst (AstConcrete (FTKS _ x) t) = Just (x, t) unConst _ = Nothing in case V.mapM unConst l of Just l4 | Just ((x, _), _) <- V.uncons l4 -> let l3 = V.map snd l4 - in AstConcrete (FTKS knownShS x) $ RepN $ tfromVectorS l3 + in AstConcrete (FTKS knownShS x) $ sfromVector l3 _ -> Ast.AstFromVectorS l astFromVectorS l | Just Refl <- sameAstSpan @s @FullSpan = let unFromPrimal :: AstTensor AstMethodLet FullSpan (TKS2 sh r) @@ -1660,12 +1660,13 @@ astFromVectorX :: forall s r n sh. -> AstTensor AstMethodLet s (TKX (Just n ': sh) r) astFromVectorX v | V.length v == 1 = astReplicate SNat (v V.! 0) astFromVectorX l | Just Refl <- sameAstSpan @s @PrimalSpan = - let unConst :: AstTensor AstMethodLet PrimalSpan (TKX sh r) -> Maybe (Nested.Mixed sh r) - unConst (AstConcrete _ (RepN t)) = Just t + let unConst :: AstTensor AstMethodLet PrimalSpan (TKX sh r) + -> Maybe (RepN (TKX sh r)) + unConst (AstConcrete _ t) = Just t unConst _ = Nothing in case V.mapM unConst l of Just l3 | V.length l3 >= 1 -> - AstConcrete (FTKX (SKnown (SNat @n) :$% Nested.mshape (l3 V.! 0)) FTKScalar) $ RepN $ tfromVectorX l3 + AstConcrete (FTKX (SKnown (SNat @n) :$% xshape (l3 V.! 0)) FTKScalar) $ xfromVector l3 _ -> Ast.AstFromVectorX l astFromVectorX l | Just Refl <- sameAstSpan @s @FullSpan = let unFromPrimal :: AstTensor AstMethodLet FullSpan (TKX sh r) @@ -1734,7 +1735,7 @@ astReplicateNS v = astReplicate0N :: forall n s r. (GoodScalar r, AstSpan s) => IShR n -> r -> AstTensor AstMethodLet s (TKR n r) -astReplicate0N sh = astReplicate0NT sh . fromPrimal . AstConcrete (FTKR ZSR FTKScalar) . RepN . Nested.rscalar +astReplicate0N sh = astReplicate0NT sh . fromPrimal . AstConcrete (FTKR ZSR FTKScalar) . rscalar astReplicate0NS :: forall shn s r. (KnownShS shn, GoodScalar r, AstSpan s) => r -> AstTensor AstMethodLet s (TKS shn r) @@ -1742,7 +1743,7 @@ astReplicate0NS = let go :: ShS sh' -> AstTensor AstMethodLet s (TKS '[] r) -> AstTensor AstMethodLet s (TKS sh' r) go ZSS v = v go ((:$$) SNat sh') v | Dict <- sshapeKnown sh' = astReplicate SNat $ go sh' v - in go (knownShS @shn) . fromPrimal . AstConcrete (FTKS ZSS FTKScalar) . RepN . Nested.sscalar + in go (knownShS @shn) . fromPrimal . AstConcrete (FTKS ZSS FTKScalar) . sscalar astReplicate0NT :: forall n s r. (GoodScalar r, AstSpan s) => IShR n -> AstTensor AstMethodLet s (TKR 0 r) -> AstTensor AstMethodLet s (TKR n r) @@ -1758,7 +1759,7 @@ astAppend :: (KnownNat n, GoodScalar r, AstSpan s) -> AstTensor AstMethodLet s (TKR (1 + n) r) astAppend (AstConcrete (FTKR (ulen :$: sh) FTKScalar) u) (AstConcrete (FTKR (vlen :$: _) FTKScalar) v) = - AstConcrete (FTKR (ulen + vlen :$: sh) FTKScalar) $ RepN $ tappendR (unRepN u) (unRepN v) + AstConcrete (FTKR (ulen + vlen :$: sh) FTKScalar) $ rappend u v astAppend (Ast.AstFromPrimal u) (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astAppend u v astAppend (Ast.AstFromVector l1) (Ast.AstFromVector l2) = @@ -1769,7 +1770,7 @@ astAppendS :: (KnownNat m, KnownNat n, KnownShS sh, GoodScalar r, AstSpan s) => AstTensor AstMethodLet s (TKS (m ': sh) r) -> AstTensor AstMethodLet s (TKS (n ': sh) r) -> AstTensor AstMethodLet s (TKS ((m + n) ': sh) r) astAppendS (AstConcrete _ u) (AstConcrete _ v) = - AstConcrete (FTKS knownShS FTKScalar) $ RepN $ tappendS (unRepN u) (unRepN v) + AstConcrete (FTKS knownShS FTKScalar) $ sappend u v astAppendS (Ast.AstFromPrimal u) (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astAppendS u v astAppendS (Ast.AstFromVectorS l1) (Ast.AstFromVectorS l2) = @@ -1780,7 +1781,7 @@ astSlice :: forall k s r. (KnownNat k, GoodScalar r, AstSpan s) => Int -> Int -> AstTensor AstMethodLet s (TKR (1 + k) r) -> AstTensor AstMethodLet s (TKR (1 + k) r) astSlice i n (AstConcrete (FTKR (_ :$: sh) FTKScalar) t) = - AstConcrete (FTKR (n :$: sh) FTKScalar) $ RepN $ tsliceR i n (unRepN t) + AstConcrete (FTKR (n :$: sh) FTKScalar) $ rslice i n t astSlice i n (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astSlice i n v astSlice 0 n v | n == lengthAst v = v astSlice _i n (Ast.AstReplicate @y2 _ v) = case stensorKind @y2 of @@ -1817,7 +1818,7 @@ astSliceS :: forall i n k sh s r. => AstTensor AstMethodLet s (TKS (i + n + k ': sh) r) -> AstTensor AstMethodLet s (TKS (n ': sh) r) astSliceS (AstConcrete _ t) = - AstConcrete (FTKS knownShS FTKScalar) $ RepN $ tsliceS @i @n (unRepN t) + AstConcrete (FTKS knownShS FTKScalar) $ sslice (Proxy @i) (Proxy @n) t astSliceS (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astSliceS @i @n v astSliceS v | Just Refl <- sameNat (Proxy @i) (Proxy @0) , Just Refl <- sameNat (Proxy @k) (Proxy @0) = v @@ -1857,7 +1858,7 @@ astSliceS v = Ast.AstSliceS @i v astReverse :: forall n s r. (KnownNat n, TensorKind2 r, AstSpan s) => AstTensor AstMethodLet s (TKR2 (1 + n) r) -> AstTensor AstMethodLet s (TKR2 (1 + n) r) -astReverse (AstConcrete ftk t) = AstConcrete ftk $ RepN $ treverseR (unRepN t) +astReverse (AstConcrete ftk t) = AstConcrete ftk $ rreverse t astReverse (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astReverse v astReverse (Ast.AstReplicate k v) = astReplicate k v astReverse (Ast.AstFromVector l) = astFromVector $ V.reverse l @@ -1872,7 +1873,7 @@ astReverse v = Ast.AstReverse v astReverseS :: forall n sh s r. (KnownNat n, KnownShS sh, TensorKind2 r, AstSpan s) => AstTensor AstMethodLet s (TKS2 (n ': sh) r) -> AstTensor AstMethodLet s (TKS2 (n ': sh) r) -astReverseS (AstConcrete ftk t) = AstConcrete ftk $ RepN $ treverseS (unRepN t) +astReverseS (AstConcrete ftk t) = AstConcrete ftk $ sreverse t astReverseS (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astReverseS v astReverseS (Ast.AstReplicate k v) = astReplicate k v astReverseS (Ast.AstFromVectorS l) = astFromVectorS $ V.reverse l @@ -1918,7 +1919,7 @@ astTranspose perm = \case -- TODO: should the below be backpermute or permute? astGatherR (Nested.Internal.Shape.shrPermutePrefix perm sh) v (Nested.Internal.Shape.listrPermutePrefix perm vars, ix) - AstConcrete (FTKR sh FTKScalar) t -> AstConcrete (FTKR (Nested.Internal.Shape.shrPermutePrefix perm sh) FTKScalar) $ RepN $ ttransposeR perm (unRepN t) + AstConcrete (FTKR sh FTKScalar) t -> AstConcrete (FTKR (Nested.Internal.Shape.shrPermutePrefix perm sh) FTKScalar) $ rtranspose perm t Ast.AstFromPrimal v -> Ast.AstFromPrimal $ astTranspose perm v u -> Ast.AstTranspose perm u -- we don't go inside AstSumOfList, because they are usually long @@ -2074,7 +2075,7 @@ astReshape shOut = \case Ast.AstFromVector l | [x] <- V.toList l -> astReshape shOut x Ast.AstReshape _ v -> astReshape shOut v AstConcrete (FTKR _ x) t -> AstConcrete (FTKR shOut x) - $ RepN $ Nested.rreshape shOut (unRepN t) + $ rreshape shOut t Ast.AstFromPrimal v -> Ast.AstFromPrimal $ astReshape shOut v v -> let shIn = shapeAst v in case sameNat (Proxy @p) (Proxy @m) of @@ -2104,7 +2105,7 @@ astReshapeS = \case astReshapeS $ l V.! 0 Ast.AstReshapeS v -> astReshapeS @_ @sh2 v AstConcrete (FTKS _ x) t -> AstConcrete (FTKS knownShS x) - $ RepN $ treshapeS (unRepN t) + $ sreshape t Ast.AstFromPrimal v -> Ast.AstFromPrimal $ astReshapeS v v -> case sameShape @sh @sh2 of Just Refl -> v @@ -2112,7 +2113,7 @@ astReshapeS = \case astCast :: (GoodScalar r1, GoodScalar r2, RealFrac r1, RealFrac r2) => AstTensor ms s (TKScalar r1) -> AstTensor ms s (TKScalar r2) -astCast (AstConcrete FTKScalar t) = AstConcrete FTKScalar $ RepN $ realToFrac (unRepN t) +astCast (AstConcrete FTKScalar t) = AstConcrete FTKScalar $ kcast t astCast (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astCast v astCast (Ast.AstCast v) = astCast v astCast (Ast.AstFromIntegral v) = astFromIntegral v @@ -2120,7 +2121,7 @@ astCast v = Ast.AstCast v astCastR :: (KnownNat n, GoodScalar r1, GoodScalar r2, RealFrac r1, RealFrac r2) => AstTensor AstMethodLet s (TKR n r1) -> AstTensor AstMethodLet s (TKR n r2) -astCastR (AstConcrete (FTKR sh FTKScalar) t) = AstConcrete (FTKR sh FTKScalar) $ RepN $ tcastR (unRepN t) +astCastR (AstConcrete (FTKR sh FTKScalar) t) = AstConcrete (FTKR sh FTKScalar) $ rcast t astCastR (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astCastR v astCastR (Ast.AstCastR v) = astCastR v astCastR (Ast.AstFromIntegralR v) = astFromIntegralR v @@ -2129,7 +2130,7 @@ astCastR v = Ast.AstCastR v astCastS :: ( KnownShS sh, GoodScalar r1, GoodScalar r2, RealFrac r1 , RealFrac r2 ) => AstTensor AstMethodLet s (TKS sh r1) -> AstTensor AstMethodLet s (TKS sh r2) -astCastS (AstConcrete (FTKS sh FTKScalar) t) = AstConcrete (FTKS sh FTKScalar) $ RepN $ tcastS (unRepN t) +astCastS (AstConcrete (FTKS sh FTKScalar) t) = AstConcrete (FTKS sh FTKScalar) $ scast t astCastS (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astCastS v astCastS (Ast.AstCastS v) = astCastS v astCastS (Ast.AstFromIntegralS v) = astFromIntegralS v @@ -2138,21 +2139,21 @@ astCastS v = Ast.AstCastS v astFromIntegral :: (GoodScalar r1, GoodScalar r2, Integral r1) => AstTensor ms PrimalSpan (TKScalar r1) -> AstTensor ms PrimalSpan (TKScalar r2) -astFromIntegral (AstConcrete FTKScalar t) = AstConcrete FTKScalar $ RepN $ fromIntegral (unRepN t) +astFromIntegral (AstConcrete FTKScalar t) = AstConcrete FTKScalar $ kfromIntegral t astFromIntegral (Ast.AstFromIntegral v) = astFromIntegral v astFromIntegral v = Ast.AstFromIntegral v astFromIntegralR :: (KnownNat n, GoodScalar r1, GoodScalar r2, Integral r1) => AstTensor AstMethodLet PrimalSpan (TKR n r1) -> AstTensor AstMethodLet PrimalSpan (TKR n r2) -astFromIntegralR (AstConcrete (FTKR sh FTKScalar) t) = AstConcrete (FTKR sh FTKScalar) $ RepN $ tfromIntegralR (unRepN t) +astFromIntegralR (AstConcrete (FTKR sh FTKScalar) t) = AstConcrete (FTKR sh FTKScalar) $ rfromIntegral t astFromIntegralR (Ast.AstFromIntegralR v) = astFromIntegralR v astFromIntegralR v = Ast.AstFromIntegralR v astFromIntegralS :: (KnownShS sh, GoodScalar r1, GoodScalar r2, Integral r1) => AstTensor AstMethodLet PrimalSpan (TKS sh r1) -> AstTensor AstMethodLet PrimalSpan (TKS sh r2) -astFromIntegralS (AstConcrete (FTKS sh FTKScalar) t) = AstConcrete (FTKS sh FTKScalar) $ RepN $ tfromIntegralS (unRepN t) +astFromIntegralS (AstConcrete (FTKS sh FTKScalar) t) = AstConcrete (FTKS sh FTKScalar) $ sfromIntegral t astFromIntegralS (Ast.AstFromIntegralS v) = astFromIntegralS v astFromIntegralS v = Ast.AstFromIntegralS v @@ -2215,8 +2216,8 @@ astRFromS :: forall sh s r. (TensorKind1 r, KnownShS sh) astRFromS (AstConcrete ftk t) | Dict <- lemKnownNatRankS (knownShS @sh) = case ftk of FTKS _ x -> - let u = Nested.stoRanked (unRepN t) - in AstConcrete (FTKR (Nested.rshape u) x) (RepN u) + let u = rfromS t + in AstConcrete (FTKR (rshape u) x) u astRFromS (Ast.AstFromPrimal v) | Dict <- lemKnownNatRankS (knownShS @sh) = Ast.AstFromPrimal $ astRFromS v @@ -2229,8 +2230,8 @@ astRFromX :: forall sh s r. (TensorKind1 r, KnownShX sh) astRFromX (AstConcrete ftk t) | Dict <- lemKnownNatRankX (knownShX @sh) = case ftk of FTKX _ x -> - let u = Nested.mtoRanked (unRepN t) - in AstConcrete (FTKR (Nested.rshape u) x) (RepN u) + let u = rfromX t + in AstConcrete (FTKR (rshape u) x) u astRFromX (Ast.AstFromPrimal v) | Dict <- lemKnownNatRankX (knownShX @sh) = Ast.AstFromPrimal $ astRFromX v @@ -2242,8 +2243,8 @@ astSFromR :: forall sh s r. (TensorKind1 r, KnownShS sh, KnownNat (Rank sh)) -> AstTensor AstMethodLet s (TKS2 sh r) astSFromR (AstConcrete ftk t) = case ftk of FTKR _ x -> - AstConcrete (FTKS knownShS x) $ RepN - $ Nested.rcastToShaped (unRepN t) Nested.knownShS + let u = sfromR t + in AstConcrete (FTKS knownShS x) u astSFromR (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astSFromR v astSFromR (Ast.AstRFromS @sh1 v) = case sameShape @sh1 @sh of @@ -2257,8 +2258,8 @@ astSFromX :: forall sh sh' s r. -> AstTensor AstMethodLet s (TKS2 sh r) astSFromX (AstConcrete ftk t) = case ftk of FTKX _ x -> - AstConcrete (FTKS knownShS x) - $ RepN $ Nested.mcastToShaped (unRepN t) Nested.knownShS + let u = sfromX t + in AstConcrete (FTKS knownShS x) u astSFromX (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astSFromX v astSFromX (Ast.AstXFromS @sh1 v) = case sameShape @sh1 @sh of @@ -2272,8 +2273,8 @@ astXFromR :: forall sh s r. -> AstTensor AstMethodLet s (TKX2 sh r) astXFromR (AstConcrete ftk t) = case ftk of FTKR _ x -> - let u = Nested.rcastToMixed (knownShX @sh) (unRepN t) - in AstConcrete (FTKX (Nested.mshape u) x) (RepN u) + let u = xfromR t + in AstConcrete (FTKX (xshape u) x) u astXFromR (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astXFromR v astXFromR v = Ast.AstXFromR v @@ -2283,8 +2284,8 @@ astXFromS :: forall sh sh' s r. -> AstTensor AstMethodLet s (TKX2 sh' r) astXFromS (AstConcrete ftk t) = case ftk of FTKS _ x -> - let u = Nested.scastToMixed (knownShX @sh') (unRepN t) - in AstConcrete (FTKX (Nested.mshape u) x) (RepN u) + let u = xfromS t + in AstConcrete (FTKX (xshape u) x) u astXFromS (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astXFromS v -- impossible, shapes may differ: astXFromS (Ast.AstSFromX v) = v astXFromS v = Ast.AstXFromS v @@ -2763,7 +2764,7 @@ astFromScalar t = case t of Ast.AstToScalar u -> u Ast.AstCond b a2 a3 -> Ast.AstCond b (astFromScalar a2) (astFromScalar a3) AstConcrete FTKScalar (RepN v) -> - AstConcrete (FTKS ZSS FTKScalar) $ RepN $ Nested.sscalar v + AstConcrete (FTKS ZSS FTKScalar) $ sscalar v AstN1 opCode u -> AstN1S opCode (astFromScalar u) AstN2 opCode u v -> AstN2S opCode (astFromScalar u) (astFromScalar v) -- TODO: Ast.AstR1 opCode u -> Ast.AstR1S opCode (astFromScalar u) diff --git a/src/HordeAd/Core/TensorClass.hs b/src/HordeAd/Core/TensorClass.hs index e48f5da66..22abe0e40 100644 --- a/src/HordeAd/Core/TensorClass.hs +++ b/src/HordeAd/Core/TensorClass.hs @@ -175,7 +175,7 @@ class ( Num (IntOf target) => BaseTensor (target :: Target) where -- Integer codomain - rshape :: TensorKind2 r => target (TKR2 n r) -> IShR n + rshape :: TensorKind1 r => target (TKR2 n r) -> IShR n rrank :: forall r n. (TensorKind2 r, KnownNat n) => target (TKR2 n r) -> Int rrank _ = valueOf @n rsize :: TensorKind2 r => target (TKR2 n r) -> Int