Skip to content

Commit

Permalink
Remove the HordeAd.Internal.BackendOX dependency from AstSimplify
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 14, 2024
1 parent b5fb2c2 commit a6cac7b
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 53 deletions.
105 changes: 53 additions & 52 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,10 @@ import HordeAd.Core.AstFreshId
import HordeAd.Core.AstTools
import HordeAd.Core.CarriersConcrete
import HordeAd.Core.HVectorOps
import HordeAd.Core.OpsConcrete ()
import HordeAd.Core.TensorClass
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
import HordeAd.Internal.BackendOX
import HordeAd.Util.ShapedList qualified as ShapedList
import HordeAd.Util.SizedList

Expand Down Expand Up @@ -598,14 +599,14 @@ astIndexKnobsR knobs v0 ix@(i1 :.: (rest1 :: AstIxR AstMethodLet m1)) =
Ast.AstCastR t -> astCastR $ astIndexKnobsR knobs t ix
Ast.AstFromIntegralR v -> astFromIntegralR $ astIndexKnobsR knobs v ix
AstConcrete (FTKR _ x) t ->
let unConst :: AstInt AstMethodLet -> Maybe [Int64]
-> Maybe [Int64]
unConst (AstConcrete _ (RepN i)) (Just l) = Just $ i : l
let unConst :: AstInt AstMethodLet -> Maybe [IntOf RepN]
-> Maybe [IntOf RepN]
unConst (AstConcrete _ i) (Just l) = Just $ i : l
unConst _ _ = Nothing
in case foldr unConst (Just []) ix of
Just ixInt ->
let u = tindexZR (unRepN t) $ listToIndex ixInt
in AstConcrete (FTKR (Nested.rshape u) x) $ RepN u
let u = rindex t (fromList ixInt)
in AstConcrete (FTKR (rshape u) x) u
-- TODO: we'd need mapM for Index to keep this rank-typed
Nothing -> Ast.AstIndex v0 ix
Ast.AstProjectR{} -> Ast.AstIndex v0 ix
Expand Down Expand Up @@ -854,14 +855,13 @@ astIndexKnobsS knobs v0 ix@((:.$) @in1 i1 (rest1 :: AstIxS AstMethodLet shm1)) |
Ast.AstCastS t -> astCastS $ astIndexKnobsS knobs t ix
Ast.AstFromIntegralS v -> astFromIntegralS $ astIndexKnobsS knobs v ix
AstConcrete (FTKS _ x) t ->
let unConst :: AstInt AstMethodLet -> Maybe [Int64]
-> Maybe [Int64]
unConst (AstConcrete _ (RepN i)) (Just l) = Just $ i : l
let unConst :: AstInt AstMethodLet -> Maybe [IntOf RepN]
-> Maybe [IntOf RepN]
unConst (AstConcrete _ i) (Just l) = Just $ i : l
unConst _ _ = Nothing
in case foldr unConst (Just []) ix of
Just ixInt -> AstConcrete (FTKS knownShS x)
$ RepN $ tindexZS (unRepN t)
$ ShapedList.listToIndex @shm ixInt
$ sindex @_ @_ @shm t (fromList ixInt)
-- TODO: we'd need mapM for Index to keep this rank-typed
Nothing -> Ast.AstIndexS v0 ix
Ast.AstProjectS{} -> Ast.AstIndexS v0 ix
Expand Down Expand Up @@ -1523,7 +1523,7 @@ astSum t0 = case shapeAst t0 of
Ast.AstSlice i 1 v -> astIndexR v (fromIntegral i :.: ZIR)
Ast.AstReverse v -> astSum v
AstConcrete (FTKR sh FTKScalar) t ->
AstConcrete (FTKR (Nested.Internal.Shape.shrTail sh) FTKScalar) $ RepN $ tsumR $ unRepN t
AstConcrete (FTKR (Nested.Internal.Shape.shrTail sh) FTKScalar) $ rsum t
Ast.AstFromPrimal v -> Ast.AstFromPrimal $ astSum v
_ -> Ast.AstSum t0

Expand Down Expand Up @@ -1557,7 +1557,7 @@ astSumS t0 = case sameNat (Proxy @n) (Proxy @0) of
Ast.AstSliceS @i @k v | Just Refl <- sameNat (Proxy @k) (Proxy @1) ->
astIndexS v (valueOf @i :.$ ZIS)
Ast.AstReverseS v -> astSumS v
AstConcrete _ t -> AstConcrete (FTKS knownShS FTKScalar) $ RepN $ Nested.ssumOuter1 $ unRepN t
AstConcrete _ t -> AstConcrete (FTKS knownShS FTKScalar) $ ssum t
Ast.AstFromPrimal v -> Ast.AstFromPrimal $ astSumS v
_ -> Ast.AstSumS t0

Expand Down Expand Up @@ -1609,13 +1609,13 @@ astFromVector :: forall s r n. (KnownNat n, TensorKind2 r, AstSpan s)
astFromVector v | V.length v == 1 = astReplicate (SNat @1) (v V.! 0)
astFromVector l | Just Refl <- sameAstSpan @s @PrimalSpan =
let unConst :: AstTensor AstMethodLet PrimalSpan (TKR2 n r)
-> Maybe (FullTensorKind r, Nested.Ranked n (RepORArray r))
unConst (AstConcrete (FTKR _ x) (RepN t)) = Just (x, t)
-> Maybe (FullTensorKind r, RepN (TKR2 n r))
unConst (AstConcrete (FTKR _ x) t) = Just (x, t)
unConst _ = Nothing
in case V.mapM unConst l of
Just l4 | Just ((x, _), _) <- V.uncons l4 ->
let l3 = V.map snd l4
in AstConcrete (FTKR (V.length l :$: Nested.rshape (l3 V.! 0)) x) $ RepN $ tfromVectorR l3
in AstConcrete (FTKR (V.length l :$: rshape (l3 V.! 0)) x) $ rfromVector l3
_ -> Ast.AstFromVector l
astFromVector l | Just Refl <- sameAstSpan @s @FullSpan =
let unFromPrimal :: AstTensor AstMethodLet FullSpan (TKR2 n r)
Expand All @@ -1635,13 +1635,13 @@ astFromVectorS :: forall s r n sh.
astFromVectorS v | V.length v == 1 = astReplicate SNat (v V.! 0)
astFromVectorS l | Just Refl <- sameAstSpan @s @PrimalSpan =
let unConst :: AstTensor AstMethodLet PrimalSpan (TKS2 sh r)
-> Maybe (FullTensorKind r, Nested.Shaped sh (RepORArray r))
unConst (AstConcrete (FTKS _ x) (RepN t)) = Just (x, t)
-> Maybe (FullTensorKind r, RepN (TKS2 sh r))
unConst (AstConcrete (FTKS _ x) t) = Just (x, t)
unConst _ = Nothing
in case V.mapM unConst l of
Just l4 | Just ((x, _), _) <- V.uncons l4 ->
let l3 = V.map snd l4
in AstConcrete (FTKS knownShS x) $ RepN $ tfromVectorS l3
in AstConcrete (FTKS knownShS x) $ sfromVector l3
_ -> Ast.AstFromVectorS l
astFromVectorS l | Just Refl <- sameAstSpan @s @FullSpan =
let unFromPrimal :: AstTensor AstMethodLet FullSpan (TKS2 sh r)
Expand All @@ -1660,12 +1660,13 @@ astFromVectorX :: forall s r n sh.
-> AstTensor AstMethodLet s (TKX (Just n ': sh) r)
astFromVectorX v | V.length v == 1 = astReplicate SNat (v V.! 0)
astFromVectorX l | Just Refl <- sameAstSpan @s @PrimalSpan =
let unConst :: AstTensor AstMethodLet PrimalSpan (TKX sh r) -> Maybe (Nested.Mixed sh r)
unConst (AstConcrete _ (RepN t)) = Just t
let unConst :: AstTensor AstMethodLet PrimalSpan (TKX sh r)
-> Maybe (RepN (TKX sh r))
unConst (AstConcrete _ t) = Just t
unConst _ = Nothing
in case V.mapM unConst l of
Just l3 | V.length l3 >= 1 ->
AstConcrete (FTKX (SKnown (SNat @n) :$% Nested.mshape (l3 V.! 0)) FTKScalar) $ RepN $ tfromVectorX l3
AstConcrete (FTKX (SKnown (SNat @n) :$% xshape (l3 V.! 0)) FTKScalar) $ xfromVector l3
_ -> Ast.AstFromVectorX l
astFromVectorX l | Just Refl <- sameAstSpan @s @FullSpan =
let unFromPrimal :: AstTensor AstMethodLet FullSpan (TKX sh r)
Expand Down Expand Up @@ -1734,15 +1735,15 @@ astReplicateNS v =

astReplicate0N :: forall n s r. (GoodScalar r, AstSpan s)
=> IShR n -> r -> AstTensor AstMethodLet s (TKR n r)
astReplicate0N sh = astReplicate0NT sh . fromPrimal . AstConcrete (FTKR ZSR FTKScalar) . RepN . Nested.rscalar
astReplicate0N sh = astReplicate0NT sh . fromPrimal . AstConcrete (FTKR ZSR FTKScalar) . rscalar

astReplicate0NS :: forall shn s r. (KnownShS shn, GoodScalar r, AstSpan s)
=> r -> AstTensor AstMethodLet s (TKS shn r)
astReplicate0NS =
let go :: ShS sh' -> AstTensor AstMethodLet s (TKS '[] r) -> AstTensor AstMethodLet s (TKS sh' r)
go ZSS v = v
go ((:$$) SNat sh') v | Dict <- sshapeKnown sh' = astReplicate SNat $ go sh' v
in go (knownShS @shn) . fromPrimal . AstConcrete (FTKS ZSS FTKScalar) . RepN . Nested.sscalar
in go (knownShS @shn) . fromPrimal . AstConcrete (FTKS ZSS FTKScalar) . sscalar

astReplicate0NT :: forall n s r. (GoodScalar r, AstSpan s)
=> IShR n -> AstTensor AstMethodLet s (TKR 0 r) -> AstTensor AstMethodLet s (TKR n r)
Expand All @@ -1758,7 +1759,7 @@ astAppend :: (KnownNat n, GoodScalar r, AstSpan s)
-> AstTensor AstMethodLet s (TKR (1 + n) r)
astAppend (AstConcrete (FTKR (ulen :$: sh) FTKScalar) u)
(AstConcrete (FTKR (vlen :$: _) FTKScalar) v) =
AstConcrete (FTKR (ulen + vlen :$: sh) FTKScalar) $ RepN $ tappendR (unRepN u) (unRepN v)
AstConcrete (FTKR (ulen + vlen :$: sh) FTKScalar) $ rappend u v
astAppend (Ast.AstFromPrimal u) (Ast.AstFromPrimal v) =
Ast.AstFromPrimal $ astAppend u v
astAppend (Ast.AstFromVector l1) (Ast.AstFromVector l2) =
Expand All @@ -1769,7 +1770,7 @@ astAppendS :: (KnownNat m, KnownNat n, KnownShS sh, GoodScalar r, AstSpan s)
=> AstTensor AstMethodLet s (TKS (m ': sh) r) -> AstTensor AstMethodLet s (TKS (n ': sh) r)
-> AstTensor AstMethodLet s (TKS ((m + n) ': sh) r)
astAppendS (AstConcrete _ u) (AstConcrete _ v) =
AstConcrete (FTKS knownShS FTKScalar) $ RepN $ tappendS (unRepN u) (unRepN v)
AstConcrete (FTKS knownShS FTKScalar) $ sappend u v
astAppendS (Ast.AstFromPrimal u) (Ast.AstFromPrimal v) =
Ast.AstFromPrimal $ astAppendS u v
astAppendS (Ast.AstFromVectorS l1) (Ast.AstFromVectorS l2) =
Expand All @@ -1780,7 +1781,7 @@ astSlice :: forall k s r. (KnownNat k, GoodScalar r, AstSpan s)
=> Int -> Int -> AstTensor AstMethodLet s (TKR (1 + k) r)
-> AstTensor AstMethodLet s (TKR (1 + k) r)
astSlice i n (AstConcrete (FTKR (_ :$: sh) FTKScalar) t) =
AstConcrete (FTKR (n :$: sh) FTKScalar) $ RepN $ tsliceR i n (unRepN t)
AstConcrete (FTKR (n :$: sh) FTKScalar) $ rslice i n t
astSlice i n (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astSlice i n v
astSlice 0 n v | n == lengthAst v = v
astSlice _i n (Ast.AstReplicate @y2 _ v) = case stensorKind @y2 of
Expand Down Expand Up @@ -1817,7 +1818,7 @@ astSliceS :: forall i n k sh s r.
=> AstTensor AstMethodLet s (TKS (i + n + k ': sh) r)
-> AstTensor AstMethodLet s (TKS (n ': sh) r)
astSliceS (AstConcrete _ t) =
AstConcrete (FTKS knownShS FTKScalar) $ RepN $ tsliceS @i @n (unRepN t)
AstConcrete (FTKS knownShS FTKScalar) $ sslice (Proxy @i) (Proxy @n) t
astSliceS (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astSliceS @i @n v
astSliceS v | Just Refl <- sameNat (Proxy @i) (Proxy @0)
, Just Refl <- sameNat (Proxy @k) (Proxy @0) = v
Expand Down Expand Up @@ -1857,7 +1858,7 @@ astSliceS v = Ast.AstSliceS @i v
astReverse :: forall n s r. (KnownNat n, TensorKind2 r, AstSpan s)
=> AstTensor AstMethodLet s (TKR2 (1 + n) r)
-> AstTensor AstMethodLet s (TKR2 (1 + n) r)
astReverse (AstConcrete ftk t) = AstConcrete ftk $ RepN $ treverseR (unRepN t)
astReverse (AstConcrete ftk t) = AstConcrete ftk $ rreverse t
astReverse (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astReverse v
astReverse (Ast.AstReplicate k v) = astReplicate k v
astReverse (Ast.AstFromVector l) = astFromVector $ V.reverse l
Expand All @@ -1872,7 +1873,7 @@ astReverse v = Ast.AstReverse v
astReverseS :: forall n sh s r. (KnownNat n, KnownShS sh, TensorKind2 r, AstSpan s)
=> AstTensor AstMethodLet s (TKS2 (n ': sh) r)
-> AstTensor AstMethodLet s (TKS2 (n ': sh) r)
astReverseS (AstConcrete ftk t) = AstConcrete ftk $ RepN $ treverseS (unRepN t)
astReverseS (AstConcrete ftk t) = AstConcrete ftk $ sreverse t
astReverseS (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astReverseS v
astReverseS (Ast.AstReplicate k v) = astReplicate k v
astReverseS (Ast.AstFromVectorS l) = astFromVectorS $ V.reverse l
Expand Down Expand Up @@ -1918,7 +1919,7 @@ astTranspose perm = \case
-- TODO: should the below be backpermute or permute?
astGatherR (Nested.Internal.Shape.shrPermutePrefix perm sh) v
(Nested.Internal.Shape.listrPermutePrefix perm vars, ix)
AstConcrete (FTKR sh FTKScalar) t -> AstConcrete (FTKR (Nested.Internal.Shape.shrPermutePrefix perm sh) FTKScalar) $ RepN $ ttransposeR perm (unRepN t)
AstConcrete (FTKR sh FTKScalar) t -> AstConcrete (FTKR (Nested.Internal.Shape.shrPermutePrefix perm sh) FTKScalar) $ rtranspose perm t
Ast.AstFromPrimal v -> Ast.AstFromPrimal $ astTranspose perm v
u -> Ast.AstTranspose perm u
-- we don't go inside AstSumOfList, because they are usually long
Expand Down Expand Up @@ -2074,7 +2075,7 @@ astReshape shOut = \case
Ast.AstFromVector l | [x] <- V.toList l -> astReshape shOut x
Ast.AstReshape _ v -> astReshape shOut v
AstConcrete (FTKR _ x) t -> AstConcrete (FTKR shOut x)
$ RepN $ Nested.rreshape shOut (unRepN t)
$ rreshape shOut t
Ast.AstFromPrimal v -> Ast.AstFromPrimal $ astReshape shOut v
v -> let shIn = shapeAst v
in case sameNat (Proxy @p) (Proxy @m) of
Expand Down Expand Up @@ -2104,23 +2105,23 @@ astReshapeS = \case
astReshapeS $ l V.! 0
Ast.AstReshapeS v -> astReshapeS @_ @sh2 v
AstConcrete (FTKS _ x) t -> AstConcrete (FTKS knownShS x)
$ RepN $ treshapeS (unRepN t)
$ sreshape t
Ast.AstFromPrimal v -> Ast.AstFromPrimal $ astReshapeS v
v -> case sameShape @sh @sh2 of
Just Refl -> v
_ -> Ast.AstReshapeS v

astCast :: (GoodScalar r1, GoodScalar r2, RealFrac r1, RealFrac r2)
=> AstTensor ms s (TKScalar r1) -> AstTensor ms s (TKScalar r2)
astCast (AstConcrete FTKScalar t) = AstConcrete FTKScalar $ RepN $ realToFrac (unRepN t)
astCast (AstConcrete FTKScalar t) = AstConcrete FTKScalar $ kcast t
astCast (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astCast v
astCast (Ast.AstCast v) = astCast v
astCast (Ast.AstFromIntegral v) = astFromIntegral v
astCast v = Ast.AstCast v

astCastR :: (KnownNat n, GoodScalar r1, GoodScalar r2, RealFrac r1, RealFrac r2)
=> AstTensor AstMethodLet s (TKR n r1) -> AstTensor AstMethodLet s (TKR n r2)
astCastR (AstConcrete (FTKR sh FTKScalar) t) = AstConcrete (FTKR sh FTKScalar) $ RepN $ tcastR (unRepN t)
astCastR (AstConcrete (FTKR sh FTKScalar) t) = AstConcrete (FTKR sh FTKScalar) $ rcast t
astCastR (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astCastR v
astCastR (Ast.AstCastR v) = astCastR v
astCastR (Ast.AstFromIntegralR v) = astFromIntegralR v
Expand All @@ -2129,7 +2130,7 @@ astCastR v = Ast.AstCastR v
astCastS :: ( KnownShS sh, GoodScalar r1, GoodScalar r2, RealFrac r1
, RealFrac r2 )
=> AstTensor AstMethodLet s (TKS sh r1) -> AstTensor AstMethodLet s (TKS sh r2)
astCastS (AstConcrete (FTKS sh FTKScalar) t) = AstConcrete (FTKS sh FTKScalar) $ RepN $ tcastS (unRepN t)
astCastS (AstConcrete (FTKS sh FTKScalar) t) = AstConcrete (FTKS sh FTKScalar) $ scast t
astCastS (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astCastS v
astCastS (Ast.AstCastS v) = astCastS v
astCastS (Ast.AstFromIntegralS v) = astFromIntegralS v
Expand All @@ -2138,21 +2139,21 @@ astCastS v = Ast.AstCastS v
astFromIntegral :: (GoodScalar r1, GoodScalar r2, Integral r1)
=> AstTensor ms PrimalSpan (TKScalar r1)
-> AstTensor ms PrimalSpan (TKScalar r2)
astFromIntegral (AstConcrete FTKScalar t) = AstConcrete FTKScalar $ RepN $ fromIntegral (unRepN t)
astFromIntegral (AstConcrete FTKScalar t) = AstConcrete FTKScalar $ kfromIntegral t
astFromIntegral (Ast.AstFromIntegral v) = astFromIntegral v
astFromIntegral v = Ast.AstFromIntegral v

astFromIntegralR :: (KnownNat n, GoodScalar r1, GoodScalar r2, Integral r1)
=> AstTensor AstMethodLet PrimalSpan (TKR n r1)
-> AstTensor AstMethodLet PrimalSpan (TKR n r2)
astFromIntegralR (AstConcrete (FTKR sh FTKScalar) t) = AstConcrete (FTKR sh FTKScalar) $ RepN $ tfromIntegralR (unRepN t)
astFromIntegralR (AstConcrete (FTKR sh FTKScalar) t) = AstConcrete (FTKR sh FTKScalar) $ rfromIntegral t
astFromIntegralR (Ast.AstFromIntegralR v) = astFromIntegralR v
astFromIntegralR v = Ast.AstFromIntegralR v

astFromIntegralS :: (KnownShS sh, GoodScalar r1, GoodScalar r2, Integral r1)
=> AstTensor AstMethodLet PrimalSpan (TKS sh r1)
-> AstTensor AstMethodLet PrimalSpan (TKS sh r2)
astFromIntegralS (AstConcrete (FTKS sh FTKScalar) t) = AstConcrete (FTKS sh FTKScalar) $ RepN $ tfromIntegralS (unRepN t)
astFromIntegralS (AstConcrete (FTKS sh FTKScalar) t) = AstConcrete (FTKS sh FTKScalar) $ sfromIntegral t
astFromIntegralS (Ast.AstFromIntegralS v) = astFromIntegralS v
astFromIntegralS v = Ast.AstFromIntegralS v

Expand Down Expand Up @@ -2215,8 +2216,8 @@ astRFromS :: forall sh s r. (TensorKind1 r, KnownShS sh)
astRFromS (AstConcrete ftk t)
| Dict <- lemKnownNatRankS (knownShS @sh) = case ftk of
FTKS _ x ->
let u = Nested.stoRanked (unRepN t)
in AstConcrete (FTKR (Nested.rshape u) x) (RepN u)
let u = rfromS t
in AstConcrete (FTKR (rshape u) x) u
astRFromS (Ast.AstFromPrimal v)
| Dict <- lemKnownNatRankS (knownShS @sh) =
Ast.AstFromPrimal $ astRFromS v
Expand All @@ -2229,8 +2230,8 @@ astRFromX :: forall sh s r. (TensorKind1 r, KnownShX sh)
astRFromX (AstConcrete ftk t)
| Dict <- lemKnownNatRankX (knownShX @sh) = case ftk of
FTKX _ x ->
let u = Nested.mtoRanked (unRepN t)
in AstConcrete (FTKR (Nested.rshape u) x) (RepN u)
let u = rfromX t
in AstConcrete (FTKR (rshape u) x) u
astRFromX (Ast.AstFromPrimal v)
| Dict <- lemKnownNatRankX (knownShX @sh) =
Ast.AstFromPrimal $ astRFromX v
Expand All @@ -2242,8 +2243,8 @@ astSFromR :: forall sh s r. (TensorKind1 r, KnownShS sh, KnownNat (Rank sh))
-> AstTensor AstMethodLet s (TKS2 sh r)
astSFromR (AstConcrete ftk t) = case ftk of
FTKR _ x ->
AstConcrete (FTKS knownShS x) $ RepN
$ Nested.rcastToShaped (unRepN t) Nested.knownShS
let u = sfromR t
in AstConcrete (FTKS knownShS x) u
astSFromR (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astSFromR v
astSFromR (Ast.AstRFromS @sh1 v) =
case sameShape @sh1 @sh of
Expand All @@ -2257,8 +2258,8 @@ astSFromX :: forall sh sh' s r.
-> AstTensor AstMethodLet s (TKS2 sh r)
astSFromX (AstConcrete ftk t) = case ftk of
FTKX _ x ->
AstConcrete (FTKS knownShS x)
$ RepN $ Nested.mcastToShaped (unRepN t) Nested.knownShS
let u = sfromX t
in AstConcrete (FTKS knownShS x) u
astSFromX (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astSFromX v
astSFromX (Ast.AstXFromS @sh1 v) =
case sameShape @sh1 @sh of
Expand All @@ -2272,8 +2273,8 @@ astXFromR :: forall sh s r.
-> AstTensor AstMethodLet s (TKX2 sh r)
astXFromR (AstConcrete ftk t) = case ftk of
FTKR _ x ->
let u = Nested.rcastToMixed (knownShX @sh) (unRepN t)
in AstConcrete (FTKX (Nested.mshape u) x) (RepN u)
let u = xfromR t
in AstConcrete (FTKX (xshape u) x) u
astXFromR (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astXFromR v
astXFromR v = Ast.AstXFromR v

Expand All @@ -2283,8 +2284,8 @@ astXFromS :: forall sh sh' s r.
-> AstTensor AstMethodLet s (TKX2 sh' r)
astXFromS (AstConcrete ftk t) = case ftk of
FTKS _ x ->
let u = Nested.scastToMixed (knownShX @sh') (unRepN t)
in AstConcrete (FTKX (Nested.mshape u) x) (RepN u)
let u = xfromS t
in AstConcrete (FTKX (xshape u) x) u
astXFromS (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astXFromS v
-- impossible, shapes may differ: astXFromS (Ast.AstSFromX v) = v
astXFromS v = Ast.AstXFromS v
Expand Down Expand Up @@ -2763,7 +2764,7 @@ astFromScalar t = case t of
Ast.AstToScalar u -> u
Ast.AstCond b a2 a3 -> Ast.AstCond b (astFromScalar a2) (astFromScalar a3)
AstConcrete FTKScalar (RepN v) ->
AstConcrete (FTKS ZSS FTKScalar) $ RepN $ Nested.sscalar v
AstConcrete (FTKS ZSS FTKScalar) $ sscalar v
AstN1 opCode u -> AstN1S opCode (astFromScalar u)
AstN2 opCode u v -> AstN2S opCode (astFromScalar u) (astFromScalar v)
-- TODO: Ast.AstR1 opCode u -> Ast.AstR1S opCode (astFromScalar u)
Expand Down
2 changes: 1 addition & 1 deletion src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ class ( Num (IntOf target)
=> BaseTensor (target :: Target) where

-- Integer codomain
rshape :: TensorKind2 r => target (TKR2 n r) -> IShR n
rshape :: TensorKind1 r => target (TKR2 n r) -> IShR n
rrank :: forall r n. (TensorKind2 r, KnownNat n) => target (TKR2 n r) -> Int
rrank _ = valueOf @n
rsize :: TensorKind2 r => target (TKR2 n r) -> Int
Expand Down

0 comments on commit a6cac7b

Please sign in to comment.