Skip to content

Commit

Permalink
Generalize updateNR and updateNS to nested arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 16, 2024
1 parent 0ea03f6 commit e1115a9
Showing 1 changed file with 53 additions and 27 deletions.
80 changes: 53 additions & 27 deletions src/HordeAd/Core/OpsConcrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -629,17 +629,28 @@ tdot0R t u = OR.toVector t LA.<.> OR.toVector u
-- TODO: try to weave a similar magic as in tindex0R
-- TODO: for the non-singleton case see
-- https://github.com/Mikolaj/horde-ad/pull/81#discussion_r1096532164
updateNR :: forall n m a. GoodScalar a
=> RepN (TKR (n + m) a) -> [(IxROf RepN n, RepN (TKR m a))]
-> RepN (TKR (n + m) a)
updateNR arr upd =
let values = Nested.rtoVector $ unRepN arr
sh = rshape arr
f !t (ix, u) =
let v = Nested.rtoVector $ unRepN u
i = fromIntegral $ unRepN $ toLinearIdx @n @m fromIntegral sh ix
in V.concat [V.take i t, v, V.drop (i + V.length v) t]
in RepN $ Nested.rfromVector sh (foldl' f values upd)
updateNR :: forall n m a. (KnownNat n, KnownNat m, TensorKind2 a)
=> RepN (TKR2 (n + m) a) -> [(IxROf RepN n, RepN (TKR2 m a))]
-> RepN (TKR2 (n + m) a)
updateNR arr upd = case stensorKind @a of
STKScalar{} ->
let values = Nested.rtoVector $ unRepN arr
sh = rshape arr
f !t (ix, u) =
let v = Nested.rtoVector $ unRepN u
i = fromIntegral $ unRepN $ toLinearIdx @n @m fromIntegral sh ix
in V.concat [V.take i t, v, V.drop (i + V.length v) t]
in RepN $ Nested.rfromVector sh (foldl' f values upd)
_ ->
let arrNested = rnest (SNat @n) arr
shNested = rshape arrNested
f i v = case lookup (fromLinearIdx
@n (RepN . fromIntegral)
shNested ((RepN . fromIntegral) i)) upd of
Just u -> rnest (SNat @0) u
Nothing -> v
in runNest $ rfromList0N shNested
$ imap f $ runravelToList $ rflatten arrNested

tminIndexR
:: forall r r2 n.
Expand Down Expand Up @@ -853,22 +864,37 @@ tgatherZ1R k t f = case stensorKind @r of
-- TODO: try to weave a similar magic as in tindex0R
-- TODO: for the non-singleton case see
-- https://github.com/Mikolaj/horde-ad/pull/81#discussion_r1096532164
updateNS :: forall n sh r. (GoodScalar r, KnownShS sh, KnownShS (Drop n sh))
=> RepN (TKS sh r)
-> [(IxSOf RepN (Take n sh), RepN (TKS (Drop n sh) r))]
-> RepN (TKS sh r)
updateNS arr upd =
let values = Nested.stoVector $ unRepN arr
sh = knownShS @sh
f !t (ix, u) =
let v = Nested.stoVector $ unRepN u
i = gcastWith (unsafeCoerce Refl
:: sh :~: Take n sh ++ Drop n sh)
$ fromIntegral $ unRepN
$ ShapedList.toLinearIdx @(Take n sh) @(Drop n sh)
fromIntegral sh ix
in V.concat [V.take i t, v, V.drop (i + V.length v) t]
in RepN $ Nested.sfromVector knownShS (foldl' f values upd)
updateNS :: forall n sh r.
( TensorKind2 r, KnownShS sh, KnownShS (Drop n sh)
, KnownShS (Take n sh) )
=> RepN (TKS2 sh r)
-> [(IxSOf RepN (Take n sh), RepN (TKS2 (Drop n sh) r))]
-> RepN (TKS2 sh r)
updateNS arr upd = case stensorKind @r of
STKScalar{} ->
let values = Nested.stoVector $ unRepN arr
sh = knownShS @sh
f !t (ix, u) =
let v = Nested.stoVector $ unRepN u
i = gcastWith (unsafeCoerce Refl
:: sh :~: Take n sh ++ Drop n sh)
$ fromIntegral $ unRepN
$ ShapedList.toLinearIdx @(Take n sh) @(Drop n sh)
fromIntegral sh ix
in V.concat [V.take i t, v, V.drop (i + V.length v) t]
in RepN $ Nested.sfromVector knownShS (foldl' f values upd)
_ -> case shsProduct (knownShS @(Take n sh)) of
SNat ->
gcastWith (unsafeCoerce Refl :: sh :~: Take n sh ++ Drop n sh) $
let arrNested = snest (knownShS @(Take n sh)) arr
shNested = sshape arrNested
f i v = case lookup (ShapedList.fromLinearIdx
@(Take n sh) (RepN . fromIntegral)
shNested ((RepN . fromIntegral) i)) upd of
Just u -> snest (knownShS @'[]) u
Nothing -> v
in sunNest @_ @(Take n sh) $ sfromList0N
$ imap f $ sunravelToList $ sflatten arrNested

tminIndexS
:: forall n sh r r2.
Expand Down

0 comments on commit e1115a9

Please sign in to comment.