Skip to content

Commit

Permalink
Generalize scatters and gathers to nested arrays
Browse files Browse the repository at this point in the history
Mikolaj committed Dec 17, 2024
1 parent 4f70ec2 commit 5d8c838
Showing 7 changed files with 153 additions and 136 deletions.
24 changes: 12 additions & 12 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
@@ -358,10 +358,10 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
AstSum :: (KnownNat n, GoodScalar r)
=> AstTensor ms s (TKR (1 + n) r) -> AstTensor ms s (TKR n r)
AstScatter :: forall m n p r s ms.
(KnownNat m, KnownNat n, KnownNat p, GoodScalar r)
(KnownNat m, KnownNat n, KnownNat p, TensorKind2 r)
=> IShR (p + n)
-> AstTensor ms s (TKR (m + n) r) -> (AstVarList m, AstIxR ms p)
-> AstTensor ms s (TKR (p + n) r)
-> AstTensor ms s (TKR2 (m + n) r) -> (AstVarList m, AstIxR ms p)
-> AstTensor ms s (TKR2 (p + n) r)
AstFromVector :: (KnownNat n, TensorKind2 r)
=> Data.Vector.Vector (AstTensor ms s (TKR2 n r))
-> AstTensor ms s (TKR2 (1 + n) r)
@@ -379,10 +379,10 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
AstReshape :: (KnownNat n, KnownNat m, TensorKind2 r)
=> IShR m -> AstTensor ms s (TKR2 n r) -> AstTensor ms s (TKR2 m r)
AstGather :: forall m n p r s ms.
(KnownNat m, KnownNat n, KnownNat p, GoodScalar r)
(KnownNat m, KnownNat n, KnownNat p, TensorKind2 r)
=> IShR (m + n)
-> AstTensor ms s (TKR (p + n) r) -> (AstVarList m, AstIxR ms p)
-> AstTensor ms s (TKR (m + n) r)
-> AstTensor ms s (TKR2 (p + n) r) -> (AstVarList m, AstIxR ms p)
-> AstTensor ms s (TKR2 (m + n) r)
-- out of bounds indexing is permitted
AstProjectR :: (GoodScalar r, KnownNat n)
=> AstTensor ms s TKUntyped -> Int -> AstTensor ms s (TKR n r)
@@ -452,10 +452,10 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
AstScatterS :: forall sh2 p sh r s ms.
( KnownShS sh2, KnownShS sh, KnownNat p
, KnownShS (Take p sh), KnownShS (Drop p sh)
, KnownShS (sh2 ++ Drop p sh), GoodScalar r )
=> AstTensor ms s (TKS (sh2 ++ Drop p sh) r)
, KnownShS (sh2 ++ Drop p sh), TensorKind2 r )
=> AstTensor ms s (TKS2 (sh2 ++ Drop p sh) r)
-> (AstVarListS sh2, AstIxS ms (Take p sh))
-> AstTensor ms s (TKS sh r)
-> AstTensor ms s (TKS2 sh r)

AstFromVectorS :: (KnownNat n, KnownShS sh, TensorKind2 r)
=> Data.Vector.Vector (AstTensor ms s (TKS2 sh r))
@@ -480,12 +480,12 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
-- beware that the order of type arguments is different than in orthotope
-- and than the order of value arguments in the ranked version
AstGatherS :: forall sh2 p sh r s ms.
( GoodScalar r, KnownShS sh, KnownShS sh2, KnownNat p
( TensorKind2 r, KnownShS sh, KnownShS sh2, KnownNat p
, KnownShS (Take p sh), KnownShS (Drop p sh)
, KnownShS (sh2 ++ Drop p sh) )
=> AstTensor ms s (TKS sh r)
=> AstTensor ms s (TKS2 sh r)
-> (AstVarListS sh2, AstIxS ms (Take p sh))
-> AstTensor ms s (TKS (sh2 ++ Drop p sh) r)
-> AstTensor ms s (TKS2 (sh2 ++ Drop p sh) r)
-- out of bounds indexing is permitted
AstProjectS :: (GoodScalar r, KnownShS sh)
=> AstTensor ms s TKUntyped -> Int -> AstTensor ms s (TKS sh r)
2 changes: 1 addition & 1 deletion src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
@@ -790,7 +790,7 @@ interpretAst !env = \case
$ gcastWith (unsafeCoerce Refl :: sh2 :~: sh2 ++ Drop p sh)
-- transitivity of type equality doesn't work, by design,
-- so this direct cast is needed instead of more basic laws
$ sbuild @target @(TKScalar r) @(Rank sh2)
$ sbuild @target @r @(Rank sh2)
(interpretLambdaIndexS
interpretAst env
(vars, fromPrimal @s $ AstFromIntegralS $ AstFromScalar i))
127 changes: 68 additions & 59 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
@@ -462,10 +462,10 @@ astIndexKnobsR knobs v0 ix@(i1 :.: (rest1 :: AstIxR AstMethodLet m1)) =
else astIndexKnobsR knobs v2 ix2
astGather
:: forall m' n' p' r'.
(GoodScalar r', KnownNat m', KnownNat p', KnownNat n')
=> IShR (m' + n') -> AstTensor AstMethodLet s (TKR (p' + n') r')
(TensorKind2 r', KnownNat m', KnownNat p', KnownNat n')
=> IShR (m' + n') -> AstTensor AstMethodLet s (TKR2 (p' + n') r')
-> (AstVarList m', AstIxR AstMethodLet p')
-> AstTensor AstMethodLet s (TKR (m' + n') r')
-> AstTensor AstMethodLet s (TKR2 (m' + n') r')
astGather sh2 v2 (vars2, ix2) =
if knobStepOnly knobs
then astGatherKnobsR knobs
@@ -541,7 +541,8 @@ astIndexKnobsR knobs v0 ix@(i1 :.: (rest1 :: AstIxR AstMethodLet m1)) =
astIndex (astScatter sh v (vars, ix2)) rest1
Ast.AstScatter @_ @n7 (_ :$: sh)
v (vars, AstConcrete _ i5 :.: (ix2 :: AstIxR AstMethodLet p71))
| AstConcrete _ i6 <- i1 ->
| AstConcrete _ i6 <- i1
, STKScalar{} <- stensorKind @r ->
gcastWith (unsafeCoerce Refl :: m1 + n :~: p71 + n7) $
if i6 == i5
then astIndex (astScatter sh v (vars, ix2)) rest1
@@ -661,12 +662,12 @@ astIndexKnobsS knobs v0 ix@((:.$) @in1 i1 (rest1 :: AstIxS AstMethodLet shm1)) |
else astIndexKnobsS knobs v2 ix2
astGather
:: forall shm' shn' p' r'.
( GoodScalar r', KnownShS shm', KnownShS shn', KnownNat p'
( TensorKind2 r', KnownShS shm', KnownShS shn', KnownNat p'
, KnownShS (Take p' shm'), KnownShS (Drop p' shm')
, KnownShS (shn' ++ Drop p' shm') )
=> AstTensor AstMethodLet s (TKS shm' r')
=> AstTensor AstMethodLet s (TKS2 shm' r')
-> (AstVarListS shn', AstIxS AstMethodLet (Take p' shm'))
-> AstTensor AstMethodLet s (TKS (shn' ++ Drop p' shm') r')
-> AstTensor AstMethodLet s (TKS2 (shn' ++ Drop p' shm') r')
astGather v2 (vars2, ix2) =
if knobStepOnly knobs
then astGatherKnobsS knobs
@@ -900,39 +901,39 @@ shareIx ix f = unsafePerformIO $ do

astGatherR
:: forall m n p s r.
(KnownNat m, KnownNat p, KnownNat n, GoodScalar r, AstSpan s)
=> IShR (m + n) -> AstTensor AstMethodLet s (TKR (p + n) r) -> (AstVarList m, AstIxR AstMethodLet p)
-> AstTensor AstMethodLet s (TKR (m + n) r)
(KnownNat m, KnownNat p, KnownNat n, TensorKind2 r, AstSpan s)
=> IShR (m + n) -> AstTensor AstMethodLet s (TKR2 (p + n) r) -> (AstVarList m, AstIxR AstMethodLet p)
-> AstTensor AstMethodLet s (TKR2 (m + n) r)
astGatherR = astGatherKnobsR defaultKnobs

astGatherS
:: forall sh2 p sh s r.
( GoodScalar r, KnownShS sh, KnownShS sh2, KnownNat p
( TensorKind2 r, KnownShS sh, KnownShS sh2, KnownNat p
, KnownShS (Take p sh), KnownShS (Drop p sh)
, KnownShS (sh2 ++ Drop p sh) )
=> AstTensor AstMethodLet s (TKS sh r)
=> AstTensor AstMethodLet s (TKS2 sh r)
-> (AstVarListS sh2, AstIxS AstMethodLet (Take p sh))
-> AstTensor AstMethodLet s (TKS (sh2 ++ Drop p sh) r)
-> AstTensor AstMethodLet s (TKS2 (sh2 ++ Drop p sh) r)
astGatherS = astGatherKnobsS defaultKnobs

astGatherStep
:: forall m n p s r.
(KnownNat m, KnownNat p, KnownNat n, GoodScalar r, AstSpan s)
=> IShR (m + n) -> AstTensor AstMethodLet s (TKR (p + n) r) -> (AstVarList m, AstIxR AstMethodLet p)
-> AstTensor AstMethodLet s (TKR (m + n) r)
(KnownNat m, KnownNat p, KnownNat n, TensorKind2 r, AstSpan s)
=> IShR (m + n) -> AstTensor AstMethodLet s (TKR2 (p + n) r) -> (AstVarList m, AstIxR AstMethodLet p)
-> AstTensor AstMethodLet s (TKR2 (m + n) r)
astGatherStep sh v (vars, ix) =
astGatherKnobsR (defaultKnobs {knobStepOnly = True})
sh (astNonIndexStep v)
(vars, simplifyAstIxR ix)

astGatherStepS
:: forall sh2 p sh s r.
( KnownShS sh, KnownShS sh2, KnownNat p, GoodScalar r, AstSpan s
( KnownShS sh, KnownShS sh2, KnownNat p, TensorKind2 r, AstSpan s
, KnownShS (Take p sh), KnownShS (Drop p sh)
, KnownShS (sh2 ++ Drop p sh) )
=> AstTensor AstMethodLet s (TKS sh r)
=> AstTensor AstMethodLet s (TKS2 sh r)
-> (AstVarListS sh2, AstIxS AstMethodLet (Take p sh))
-> AstTensor AstMethodLet s (TKS (sh2 ++ Drop p sh) r)
-> AstTensor AstMethodLet s (TKS2 (sh2 ++ Drop p sh) r)
-- TODO: this probably needs an extra condition similar to kN == vkN below
--astGatherStepS v (AstVarName varId ::$ ZSS, AstIntVarS varId2 :.$ ZIS)
-- | varId == varId2 = ...
@@ -950,10 +951,10 @@ astGatherStepS v (vars, ix) =
-- either from full recursive simplification or from astGatherStep.
astGatherKnobsR
:: forall m n p s r.
(KnownNat m, KnownNat p, KnownNat n, GoodScalar r, AstSpan s)
=> SimplifyKnobs -> IShR (m + n) -> AstTensor AstMethodLet s (TKR (p + n) r)
(KnownNat m, KnownNat p, KnownNat n, TensorKind2 r, AstSpan s)
=> SimplifyKnobs -> IShR (m + n) -> AstTensor AstMethodLet s (TKR2 (p + n) r)
-> (AstVarList m, AstIxR AstMethodLet p)
-> AstTensor AstMethodLet s (TKR (m + n) r)
-> AstTensor AstMethodLet s (TKR2 (m + n) r)
astGatherKnobsR knobs sh0 v0 (vars0, ix0) =
case (sh0, (vars0, ix0)) of
_ | any (`varNameInAst` v0) vars0 ->
@@ -992,19 +993,19 @@ astGatherKnobsR knobs sh0 v0 (vars0, ix0) =
error "astGather: impossible pattern needlessly required"
where
astIndex :: forall m' n' s'. (KnownNat m', KnownNat n', AstSpan s')
=> AstTensor AstMethodLet s' (TKR (m' + n') r) -> AstIxR AstMethodLet m'
-> AstTensor AstMethodLet s' (TKR n' r)
=> AstTensor AstMethodLet s' (TKR2 (m' + n') r) -> AstIxR AstMethodLet m'
-> AstTensor AstMethodLet s' (TKR2 n' r)
astIndex v2 ix2 = if knobStepOnly knobs
then astIndexKnobsR knobs
(astNonIndexStep v2)
(simplifyAstIxR ix2)
else astIndexKnobsR knobs v2 ix2
astGatherRec, astGather
:: forall m' n' p' s' r'.
(KnownNat m', KnownNat p', KnownNat n', AstSpan s', GoodScalar r')
=> IShR (m' + n') -> AstTensor AstMethodLet s' (TKR (p' + n') r')
(KnownNat m', KnownNat p', KnownNat n', AstSpan s', TensorKind2 r')
=> IShR (m' + n') -> AstTensor AstMethodLet s' (TKR2 (p' + n') r')
-> (AstVarList m', AstIxR AstMethodLet p')
-> AstTensor AstMethodLet s' (TKR (m' + n') r')
-> AstTensor AstMethodLet s' (TKR2 (m' + n') r')
astGatherRec sh2 v2 (vars2, ix2) =
if knobStepOnly knobs
then Ast.AstGather sh2 v2 (vars2, ix2)
@@ -1019,10 +1020,10 @@ astGatherKnobsR knobs sh0 v0 (vars0, ix0) =
-- and so we don't have to reduce it to expose any top redexes.
astGatherCase
:: forall m' n' p' r'.
(KnownNat m', KnownNat p', KnownNat n', GoodScalar r')
=> IShR (m' + n') -> AstTensor AstMethodLet s (TKR (p' + n') r')
(KnownNat m', KnownNat p', KnownNat n', TensorKind2 r')
=> IShR (m' + n') -> AstTensor AstMethodLet s (TKR2 (p' + n') r')
-> (AstVarList m', AstIxR AstMethodLet p')
-> AstTensor AstMethodLet s (TKR (m' + n') r')
-> AstTensor AstMethodLet s (TKR2 (m' + n') r')
astGatherCase sh4 v4 (_, ZIR) = astReplicateN sh4 v4 -- not really possible
astGatherCase sh4 v4 ( vars4
, ix4@(i4 :.: (rest4 :: AstIxR AstMethodLet p1')) ) = case v4 of
@@ -1049,7 +1050,8 @@ astGatherKnobsR knobs sh0 v0 (vars0, ix0) =
(astGatherRec sh4 u' (varsFresh, ix5))
Ast.AstCond b v w -> astCond b (astGather sh4 v (vars4, ix4))
(astGather sh4 w (vars4, ix4))
Ast.AstReplicate @y2 snat v | AstConcrete _ (RepN it) <- i4 -> case stensorKind @y2 of
Ast.AstReplicate @y2 snat v | AstConcrete _ (RepN it) <- i4
, STKScalar{} <- stensorKind @r' -> case stensorKind @y2 of
STKR{} ->
let i = fromIntegral it
in if 0 <= i && i < sNatValue snat
@@ -1117,14 +1119,16 @@ astGatherKnobsR knobs sh0 v0 (vars0, ix0) =
(vars4, rest4)
Ast.AstScatter @_ @n7 (_ :$: sh)
v (vars, AstConcrete _ i5 :.: (ix2 :: AstIxR AstMethodLet p71))
| AstConcrete _ i6 <- i4 ->
| AstConcrete _ i6 <- i4
, STKScalar{} <- stensorKind @r' ->
gcastWith (unsafeCoerce Refl :: p1' + n' :~: p71 + n7) $
if i6 == i5
then astGather sh4 (astScatter sh v (vars, ix2)) (vars4, rest4)
else astReplicate0N sh4 0
Ast.AstScatter{} -> -- normal form
Ast.AstGather sh4 v4 (vars4, ix4)
Ast.AstFromVector l | AstConcrete _ (RepN it) <- i4 ->
Ast.AstFromVector l | AstConcrete _ (RepN it) <- i4
, STKScalar{} <- stensorKind @r' ->
let i = fromIntegral it
in if 0 <= i && i < length l
then astGather sh4 (l V.! i) (vars4, rest4)
@@ -1172,11 +1176,11 @@ astGatherKnobsR knobs sh0 v0 (vars0, ix0) =
in astGather sh4 v (vars4, iRev :.: rest4)
Ast.AstTranspose perm v | valueOf @p' >= length perm ->
astGather sh4 v (vars4, Nested.Internal.Shape.ixrPermutePrefix (permInverse perm) ix4)
Ast.AstTranspose perm v ->
if knobExpand knobs
then astGather sh4 (astTransposeAsGather knobs perm v) (vars4, ix4)
else Ast.AstGather sh4 v4 (vars4, ix4)
Ast.AstReshape sh v -> case STKScalar (typeRep @r) of -- TODO
Ast.AstTranspose perm v ->case stensorKind @r' of
STKScalar{} | knobExpand knobs ->
astGather sh4 (astTransposeAsGather knobs perm v) (vars4, ix4)
_ -> Ast.AstGather sh4 v4 (vars4, ix4)
Ast.AstReshape sh v -> case stensorKind @r' of
STKScalar{} | knobExpand knobs ->
astGather sh4 (astReshapeAsGather knobs sh v) (vars4, ix4)
_ -> Ast.AstGather sh4 v4 (vars4, ix4)
@@ -1193,13 +1197,13 @@ astGatherKnobsR knobs sh0 v0 (vars0, ix0) =
simplifyAstInt -- we generate the index, so we simplify on the spot
$ foldr (uncurry astLetInt) i
(zipSized vars (indexToSized ix))
composedGather :: p' <= m2 => AstTensor AstMethodLet s (TKR (m' + n') r')
composedGather :: p' <= m2 => AstTensor AstMethodLet s (TKR2 (m' + n') r')
composedGather =
let (vars2p, vars22) = splitAt_Sized @p' @(m2 - p') vars2
ix22 = fmap (substLet ix4 vars2p) ix2
in gcastWith (unsafeCoerce Refl :: m2 + n2 - p' :~: n')
$ astGather sh4 v2 (appendSized vars4 vars22, ix22)
assimilatedGather :: m2 <= p' => AstTensor AstMethodLet s (TKR (m' + n') r')
assimilatedGather :: m2 <= p' => AstTensor AstMethodLet s (TKR2 (m' + n') r')
assimilatedGather =
let (ix42, ix44) = splitAt_Index @m2 @(p' - m2) ix4
ix22 = fmap (substLet ix42 vars2) ix2
@@ -1231,6 +1235,7 @@ astGatherKnobsR knobs sh0 v0 (vars0, ix0) =
( ShapedList.listToSized $ sizedToList vars4
, ShapedList.listToSized $ indexToList ix4 ) -}
Ast.AstRFromX{} -> error "TODO"
Ast.AstZipR _v -> error "TODO"

Ast.AstApply{} -> Ast.AstGather sh4 v4 (vars4, ix4)

@@ -1262,12 +1267,12 @@ isVar _ = False

astGatherKnobsS
:: forall sh2 p sh s r.
( GoodScalar r, KnownShS sh, KnownShS sh2, KnownNat p
( TensorKind2 r, KnownShS sh, KnownShS sh2, KnownNat p
, KnownShS (Take p sh), KnownShS (Drop p sh)
, KnownShS (sh2 ++ Drop p sh) )
=> SimplifyKnobs -> AstTensor AstMethodLet s (TKS sh r)
=> SimplifyKnobs -> AstTensor AstMethodLet s (TKS2 sh r)
-> (AstVarListS sh2, AstIxS AstMethodLet (Take p sh))
-> AstTensor AstMethodLet s (TKS (sh2 ++ Drop p sh) r)
-> AstTensor AstMethodLet s (TKS2 (sh2 ++ Drop p sh) r)
astGatherKnobsS _ v (vars, ix) = Ast.AstGatherS v (vars, ix) -- TODO
{- TODO: is this beneficial?
AstGatherS @sh2 @p @sh @r AstIotaS (vars, i :.$ ZIS) ->
@@ -1563,17 +1568,19 @@ astSumS t0 = case sameNat (Proxy @n) (Proxy @0) of

-- TODO: fuse scatters, scatter and sum, perhaps more (fromList?)
astScatter :: forall m n p s r.
(GoodScalar r, KnownNat m, KnownNat n, KnownNat p, AstSpan s)
=> IShR (p + n) -> AstTensor AstMethodLet s (TKR (m + n) r)
(TensorKind2 r, KnownNat m, KnownNat n, KnownNat p, AstSpan s)
=> IShR (p + n) -> AstTensor AstMethodLet s (TKR2 (m + n) r)
-> (AstVarList m, AstIxR AstMethodLet p)
-> AstTensor AstMethodLet s (TKR (p + n) r)
-> AstTensor AstMethodLet s (TKR2 (p + n) r)
astScatter _sh v (ZR, ZIR) = v
astScatter sh@(k :$: _) _v (_vars, AstConcrete _ (RepN it) :.: _ix)
| let i = fromIntegral it
, not (0 <= i && i < k) =
, not (0 <= i && i < k)
, STKScalar{} <- stensorKind @r =
astReplicate0N sh 0
-- else update (rzero sh 0) [AstConcreteS it] (astScatter ...)
astScatter sh v (var ::: vars, ix) | not $ varNameToAstVarId var `varInIndex` ix =
astScatter sh v (var ::: vars, ix) | not $ varNameToAstVarId var `varInIndex` ix
, STKScalar{} <- stensorKind @r =
astScatter sh (astSum v) (vars, ix)
-- astScatter sh v (ZR, ix) = update (rzero sh 0) ix v
astScatter sh (Ast.AstFromPrimal v) (vars, ix) =
@@ -1583,17 +1590,18 @@ astScatter sh v (vars, ix) = Ast.AstScatter sh v (vars, ix)
astScatterS :: forall sh2 p sh s r.
( KnownShS sh2, KnownShS sh, KnownNat p
, KnownShS (Take p sh), KnownShS (Drop p sh)
, KnownShS (sh2 ++ Drop p sh), GoodScalar r, AstSpan s )
=> AstTensor AstMethodLet s (TKS (sh2 ++ Drop p sh) r)
, KnownShS (sh2 ++ Drop p sh), TensorKind2 r, AstSpan s )
=> AstTensor AstMethodLet s (TKS2 (sh2 ++ Drop p sh) r)
-> (AstVarListS sh2, AstIxS AstMethodLet (Take p sh))
-> AstTensor AstMethodLet s (TKS sh r)
-> AstTensor AstMethodLet s (TKS2 sh r)
astScatterS v (ZS, ZIS) =
gcastWith (unsafeCoerce Refl
:: Take p sh ++ Drop p sh :~: sh)
v
astScatterS v (Const var ::$ (vars :: AstVarListS sh3), ix)
| not $ varNameToAstVarId var `varInIndexS` ix
, Dict <- slistKnown vars =
, Dict <- slistKnown vars
, STKScalar{} <- stensorKind @r =
withShapeP (shapeT @sh3
++ (shapeT @(Drop p sh))) $ \(Proxy @sh4) ->
gcastWith (unsafeCoerce Refl :: sh3 ++ Drop p sh :~: sh4) $
@@ -1711,21 +1719,22 @@ astReplicate snat@SNat
v -> Ast.AstReplicate snat v

astReplicateN :: forall n p s r.
(KnownNat n, KnownNat p, GoodScalar r, AstSpan s)
=> IShR (n + p) -> AstTensor AstMethodLet s (TKR p r)
-> AstTensor AstMethodLet s (TKR (n + p) r)
(KnownNat n, KnownNat p, TensorKind2 r, AstSpan s)
=> IShR (n + p) -> AstTensor AstMethodLet s (TKR2 p r)
-> AstTensor AstMethodLet s (TKR2 (n + p) r)
astReplicateN sh v =
let go :: IShR n' -> AstTensor AstMethodLet s (TKR (n' + p) r)
let go :: IShR n' -> AstTensor AstMethodLet s (TKR2 (n' + p) r)
go ZSR = v
go (k :$: sh2) | Dict <- knownShR sh2 = withSNat k $ \snat ->
astReplicate snat $ go sh2
in go (takeShape sh)

astReplicateNS :: forall shn shp s r.
(KnownShS shn, KnownShS shp, GoodScalar r, AstSpan s)
=> AstTensor AstMethodLet s (TKS shp r) -> AstTensor AstMethodLet s (TKS (shn ++ shp) r)
(KnownShS shn, KnownShS shp, TensorKind2 r, AstSpan s)
=> AstTensor AstMethodLet s (TKS2 shp r)
-> AstTensor AstMethodLet s (TKS2 (shn ++ shp) r)
astReplicateNS v =
let go :: ShS shn' -> AstTensor AstMethodLet s (TKS (shn' ++ shp) r)
let go :: ShS shn' -> AstTensor AstMethodLet s (TKS2 (shn' ++ shp) r)
go ZSS = v
go ((:$$) @k @shn2 SNat shn2) | Dict <- sshapeKnown shn2 =
withShapeP (shapeT @shn2 ++ shapeT @shp) $ \(Proxy @sh) ->
Loading

0 comments on commit 5d8c838

Please sign in to comment.