diff --git a/src/HordeAd/Core/OpsConcrete.hs b/src/HordeAd/Core/OpsConcrete.hs index 3352db2a7..5f683ead1 100644 --- a/src/HordeAd/Core/OpsConcrete.hs +++ b/src/HordeAd/Core/OpsConcrete.hs @@ -629,17 +629,28 @@ tdot0R t u = OR.toVector t LA.<.> OR.toVector u -- TODO: try to weave a similar magic as in tindex0R -- TODO: for the non-singleton case see -- https://github.com/Mikolaj/horde-ad/pull/81#discussion_r1096532164 -updateNR :: forall n m a. GoodScalar a - => RepN (TKR (n + m) a) -> [(IxROf RepN n, RepN (TKR m a))] - -> RepN (TKR (n + m) a) -updateNR arr upd = - let values = Nested.rtoVector $ unRepN arr - sh = rshape arr - f !t (ix, u) = - let v = Nested.rtoVector $ unRepN u - i = fromIntegral $ unRepN $ toLinearIdx @n @m fromIntegral sh ix - in V.concat [V.take i t, v, V.drop (i + V.length v) t] - in RepN $ Nested.rfromVector sh (foldl' f values upd) +updateNR :: forall n m a. (KnownNat n, KnownNat m, TensorKind2 a) + => RepN (TKR2 (n + m) a) -> [(IxROf RepN n, RepN (TKR2 m a))] + -> RepN (TKR2 (n + m) a) +updateNR arr upd = case stensorKind @a of + STKScalar{} -> + let values = Nested.rtoVector $ unRepN arr + sh = rshape arr + f !t (ix, u) = + let v = Nested.rtoVector $ unRepN u + i = fromIntegral $ unRepN $ toLinearIdx @n @m fromIntegral sh ix + in V.concat [V.take i t, v, V.drop (i + V.length v) t] + in RepN $ Nested.rfromVector sh (foldl' f values upd) + _ -> + let arrNested = rnest (SNat @n) arr + shNested = rshape arrNested + f i v = case lookup (fromLinearIdx + @n (RepN . fromIntegral) + shNested ((RepN . fromIntegral) i)) upd of + Just u -> rnest (SNat @0) u + Nothing -> v + in runNest $ rfromList0N shNested + $ imap f $ runravelToList $ rflatten arrNested tminIndexR :: forall r r2 n. @@ -853,22 +864,37 @@ tgatherZ1R k t f = case stensorKind @r of -- TODO: try to weave a similar magic as in tindex0R -- TODO: for the non-singleton case see -- https://github.com/Mikolaj/horde-ad/pull/81#discussion_r1096532164 -updateNS :: forall n sh r. (GoodScalar r, KnownShS sh, KnownShS (Drop n sh)) - => RepN (TKS sh r) - -> [(IxSOf RepN (Take n sh), RepN (TKS (Drop n sh) r))] - -> RepN (TKS sh r) -updateNS arr upd = - let values = Nested.stoVector $ unRepN arr - sh = knownShS @sh - f !t (ix, u) = - let v = Nested.stoVector $ unRepN u - i = gcastWith (unsafeCoerce Refl - :: sh :~: Take n sh ++ Drop n sh) - $ fromIntegral $ unRepN - $ ShapedList.toLinearIdx @(Take n sh) @(Drop n sh) - fromIntegral sh ix - in V.concat [V.take i t, v, V.drop (i + V.length v) t] - in RepN $ Nested.sfromVector knownShS (foldl' f values upd) +updateNS :: forall n sh r. + ( TensorKind2 r, KnownShS sh, KnownShS (Drop n sh) + , KnownShS (Take n sh) ) + => RepN (TKS2 sh r) + -> [(IxSOf RepN (Take n sh), RepN (TKS2 (Drop n sh) r))] + -> RepN (TKS2 sh r) +updateNS arr upd = case stensorKind @r of + STKScalar{} -> + let values = Nested.stoVector $ unRepN arr + sh = knownShS @sh + f !t (ix, u) = + let v = Nested.stoVector $ unRepN u + i = gcastWith (unsafeCoerce Refl + :: sh :~: Take n sh ++ Drop n sh) + $ fromIntegral $ unRepN + $ ShapedList.toLinearIdx @(Take n sh) @(Drop n sh) + fromIntegral sh ix + in V.concat [V.take i t, v, V.drop (i + V.length v) t] + in RepN $ Nested.sfromVector knownShS (foldl' f values upd) + _ -> case shsProduct (knownShS @(Take n sh)) of + SNat -> + gcastWith (unsafeCoerce Refl :: sh :~: Take n sh ++ Drop n sh) $ + let arrNested = snest (knownShS @(Take n sh)) arr + shNested = sshape arrNested + f i v = case lookup (ShapedList.fromLinearIdx + @(Take n sh) (RepN . fromIntegral) + shNested ((RepN . fromIntegral) i)) upd of + Just u -> snest (knownShS @'[]) u + Nothing -> v + in sunNest @_ @(Take n sh) $ sfromList0N + $ imap f $ sunravelToList $ sflatten arrNested tminIndexS :: forall n sh r r2.