Skip to content

Commit

Permalink
Move the manual vectorization hack to where it's necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Mar 2, 2025
1 parent 8cdbe84 commit 24447e0
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 63 deletions.
44 changes: 16 additions & 28 deletions src/HordeAd/Core/Ops.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ import Data.Array.Mixed.Lemmas
import Data.Array.Mixed.Permutation qualified as Permutation
import Data.Array.Mixed.Shape
( IShX
, fromSMayNat
, fromSMayNat'
, shxAppend
, shxDropSSX
Expand Down Expand Up @@ -511,19 +510,14 @@ class ( Num (IntOf target)
=> target (TKR 2 r) -> target (TKR 1 r) -> target (TKR 1 r)
-- How to generalize (#69)? The few straightforward generalizations
-- differ in types but all are far from matmul2.
-- rmatvecmul m v = rbuild1 (rlength m) (\i -> rdot0 v (m ! [i]))
-- rmatvecmul m v = rflatten $ rmap1 (rreplicate 1 . rdot0 v) m
rmatvecmul m v = rsum (rtr (rreplicate (rlength m) v * m))
rmatvecmul m v = rbuild1 (rlength m) (\i -> rdot0 v (m ! [i]))
rmatmul2 :: (GoodScalar r, Numeric r)
=> target (TKR 2 r) -> target (TKR 2 r) -> target (TKR 2 r)
-- How to generalize to tmatmul (#69)?
-- Just rmatmul2 the two outermost dimensions?
-- rmatmul2 m1 m2 = rmap1 (rmatvecmul (rtr m2)) m1
-- rmatmul2 m1 m2 = rbuild1 (rlength m1) (\i -> rmatvecmul (rtr m2) (m1 ! [i]))
rmatmul2 m1 m2 = case rshape m2 of
_ :$: width2 :$: ZSR ->
rsum (rtranspose [2,1,0] (rreplicate width2 m1)
* rtranspose [1,0] (rreplicate (rlength m1) m2))
rmatmul2 m1 m2 = rbuild1 (rlength m1) (\i -> rmatvecmul (rtr m2) (m1 ! [i]))
rreplicate :: (KnownSTK r, KnownNat n)
=> Int -> target (TKR2 n r) -> target (TKR2 (1 + n) r)
rreplicate0N :: (KnownSTK r, KnownNat n)
Expand Down Expand Up @@ -788,14 +782,13 @@ class ( Num (IntOf target)
smatvecmul :: forall r m n. (GoodScalar r, KnownNat m, KnownNat n)
=> target (TKS '[m, n] r) -> target (TKS '[n] r)
-> target (TKS '[m] r)
smatvecmul m v = ssum (str (sreplicate @_ @m v * m))
smatvecmul m v = sbuild1 @_ @m (\i -> sdot0 v (m `sindex` (i :.$ ZIS)))
smatmul2 :: forall r n m p.
(GoodScalar r, Numeric r, KnownNat n, KnownNat m, KnownNat p)
=> target (TKS '[m, n] r) -> target (TKS '[n, p] r)
-> target (TKS '[m, p] r)
smatmul2 m1 m2 =
ssum (stranspose @_ @'[2, 1, 0] (sreplicate @target @p m1)
* stranspose @_ @'[1, 0] (sreplicate @target @m m2))
sbuild1 @_ @m (\i -> smatvecmul (str m2) (m1 `sindex` (i :.$ ZIS)))
sreplicate :: (KnownNat k, KnownShS sh, KnownSTK r)
=> target (TKS2 sh r) -> target (TKS2 (k ': sh) r)
sreplicate = tsreplicate knownShS
Expand Down Expand Up @@ -1157,30 +1150,25 @@ class ( Num (IntOf target)
-> target (TKX '[Just m] r) -- TODO: generalize
xdot1In t u = xsum $ xtr (t * u)
xmatvecmul :: forall r mm mn. (GoodScalar r, ConvertTensor target)
=> Nested.SMayNat Int SNat mm -> Nested.SMayNat () SNat mn
=> Nested.SMayNat Int SNat mm -> Nested.SMayNat Int SNat mn
-> target (TKX '[mm, mn] r) -> target (TKX '[mn] r)
-> target (TKX '[mm] r)
-- This variant is not vectorized, so will be slow without vectorization.
xmatvecmul mm mn u v =
let mu :: Nested.SMayNat () SNat mm
mu = fromSMayNat (const $ Nested.SUnknown ()) Nested.SKnown mm
in withKnownShX (mu :!% ZKX) $
withKnownShX (mu :!% mn :!% ZKX) $
withKnownShX (mn :!% ZKX) $
withSNat (fromSMayNat' mm) $ \(SNat @n) ->
xmcast (mu :!% ZKX)
$ xbuild1 @_ @n @'[] (\i -> xdot0 v (u `xindex` (i :.% ZIX)))
-- TODO: when we switch to singletons, generalize this to non-Just types
-- or split into ranked-style and shaped-style variants or provide
-- convenient ways to lift ranked and shaped operations into mixed.
xmatvecmul mm mn m v =
withKnownShX (ssxFromShape $ mm :$% ZSX) $
withKnownShX (ssxFromShape $ mn :$% ZSX) $
withSNat (fromSMayNat' mm) $ \(SNat @k) ->
xmcast (ssxFromShape $ mm :$% ZSX)
$ xbuild1 @_ @k (\i -> xdot0 v (m `xindex` (i :.% ZIX)))
xmatmul2 :: forall r n m p.
(GoodScalar r, Numeric r, KnownNat n, KnownNat m, KnownNat p)
( GoodScalar r, ConvertTensor target
, Numeric r, KnownNat n, KnownNat m, KnownNat p )
=> target (TKX '[Just m, Just n] r)
-> target (TKX '[Just n, Just p] r)
-> target (TKX '[Just m, Just p] r)
xmatmul2 m1 m2 =
xsum (xtranspose @_ @'[2, 1, 0] (xreplicate @target @p m1)
* xtranspose @_ @'[1, 0] (xreplicate @target @m m2))
xbuild1 @_ @m (\i ->
xmatvecmul (Nested.SKnown (SNat @p)) (Nested.SKnown (SNat @n))
(xtr m2) (m1 `xindex` (i :.% ZIX)))
xreplicate :: (KnownNat k, KnownShX sh, KnownSTK r)
=> target (TKX2 sh r) -> target (TKX2 (Just k ': sh) r)
xreplicate0N :: (KnownSTK r, KnownShX sh)
Expand Down
32 changes: 32 additions & 0 deletions src/HordeAd/Core/OpsADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import Data.Array.Nested
, IxS (..)
, IxX (..)
, StaticShX(..)
, ShR (..)
, ShX (..)
, ShS (..)
, KnownShS (..)
Expand All @@ -36,6 +37,7 @@ import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Internal.Shape (shsInit, withKnownShS)
import Data.Array.Mixed.Types (unsafeCoerceRefl)
import Data.Array.Mixed.Permutation qualified as Permutation
import Data.Array.Mixed.Shape (withKnownShX, ssxFromShape, fromSMayNat')

import HordeAd.Core.CarriersADVal
import HordeAd.Core.CarriersConcrete
Expand Down Expand Up @@ -163,6 +165,13 @@ instance ( ADReadyNoLet target, ShareTensor target
let !u = tshare ue in
let !v = tshare ve
in dD (rdot0 u v) (dAdd (DeltaDot0R v u') (DeltaDot0R u v'))
-- These two are manually vectorized to avoid delta blowup when run
-- via primitive pipelines.
-- Matrix-vector product expressed with whole-tensor operations only:
-- v is replicated (rlength m) times, the result is multiplied elementwise
-- by m, and the transposed product is passed to rsum. No rbuild1 and no
-- indexing occur in this definition.
-- NOTE(review): presumably rsum reduces the outermost dimension, so that
-- element i of the result is the dot product of v with row i of m -- confirm.
rmatvecmul m v = rsum (rtr (rreplicate (rlength m) v * m))
-- Matrix-matrix product expressed with whole-tensor operations only.
-- The shape of m2 is inspected solely to obtain its second dimension
-- (width2). Then m1 is replicated width2 times and transposed with the
-- permutation [2,1,0], m2 is replicated (rlength m1) times and transposed
-- with the permutation [1,0], the two are multiplied elementwise and
-- the product is passed to rsum.
-- NOTE(review): the case has a single alternative; presumably it is
-- exhaustive because m2 has rank 2 by the method's type -- confirm.
rmatmul2 m1 m2 = case rshape m2 of
_ :$: width2 :$: ZSR ->
rsum (rtranspose [2,1,0] (rreplicate width2 m1)
* rtranspose [1,0] (rreplicate (rlength m1) m2))
rreplicate k (D u u') = withSNat k $ \snat ->
dD (rreplicate k u) (DeltaReplicate snat knownSTK u')
-- TODO: speed up by using tindex0R and dDeltaIndex0 if the codomain has rank 0
Expand Down Expand Up @@ -226,6 +235,12 @@ instance ( ADReadyNoLet target, ShareTensor target
let !u = tshare ue in
let !v = tshare ve
in dD (sdot0 u v) (dAdd (DeltaDot0S v u') (DeltaDot0S u v'))
-- These two are manually vectorized to avoid delta blowup when run
-- via primitive pipelines.
-- Shaped matrix-vector product expressed with whole-tensor operations only:
-- v is replicated, multiplied elementwise by m, and the transposed product
-- is passed to ssum. The replication count is not written out here;
-- it is left to type inference.
-- NOTE(review): presumably ssum reduces the outermost dimension -- confirm.
smatvecmul m v = ssum (str (sreplicate v * m))
-- Shaped matrix-matrix product expressed with whole-tensor operations only:
-- m1 is replicated and transposed with the permutation [2, 1, 0],
-- m2 is replicated and transposed with the permutation [1, 0], the two
-- are multiplied elementwise and the product is passed to ssum.
-- Both replication counts are left to type inference.
smatmul2 m1 m2 =
ssum (stranspose @_ @'[2, 1, 0] (sreplicate m1)
* stranspose @_ @'[1, 0] (sreplicate m2))
sindex (D u u') i =
let ix = tprimalPart <$> i
in dD (sindex u ix) (DeltaIndexS knownShS u' ix)
Expand Down Expand Up @@ -280,6 +295,23 @@ instance ( ADReadyNoLet target, ShareTensor target
let !u = tshare ue in
let !v = tshare ve
in dD (xdot0 u v) (dAdd (DeltaDot0X v u') (DeltaDot0X u v'))
-- These two are manually vectorized to avoid delta blowup when run
-- via primitive pipelines.
-- Mixed-shape matrix-vector product expressed without xbuild1 or indexing.
-- The steps, in the order they appear below:
--   * bring KnownShX evidence for the shapes [mn] and [mm, mn] into scope;
--   * turn both dimension sizes into type-level naturals m and n
--     (fromSMayNat' followed by withSNat);
--   * cast v to the fully known shape [n] and the matrix to [m, n],
--     replicate the vector m times, multiply elementwise, transpose
--     and pass the product to xsum;
--   * cast the result to the static shape derived from mm.
-- The type variable m bound by the SNat pattern is distinct from
-- the term-level argument m (the matrix).
-- NOTE(review): presumably fromSMayNat' yields the runtime size of
-- a dimension whether or not it is statically known -- confirm.
xmatvecmul mm mn m v =
withKnownShX (ssxFromShape $ mn :$% ZSX) $
withKnownShX (ssxFromShape $ mm :$% mn :$% ZSX) $
withSNat (fromSMayNat' mm) $ \(SNat @m) ->
withSNat (fromSMayNat' mn) $ \(SNat @n) ->
xmcast (ssxFromShape (mm :$% ZSX))
$ xsum (xtr (xreplicate @_ @m
(xmcast (ssxFromShape (Nested.SKnown (SNat @n)
:$% ZSX)) v)
* xmcast (ssxFromShape (Nested.SKnown (SNat @m)
:$% Nested.SKnown (SNat @n)
:$% ZSX)) m))
-- Mixed-shape matrix-matrix product expressed with whole-tensor operations
-- only: m1 is replicated and transposed with the permutation [2, 1, 0],
-- m2 is replicated and transposed with the permutation [1, 0], the two
-- are multiplied elementwise and the product is passed to xsum.
-- Both replication counts are left to type inference.
xmatmul2 m1 m2 =
xsum (xtranspose @_ @'[2, 1, 0] (xreplicate m1)
* xtranspose @_ @'[1, 0] (xreplicate m2))
xreplicate (D u u') = dD (xreplicate u) (DeltaReplicate SNat knownSTK u')
xindex (D u u') i =
let ix = tprimalPart <$> i
Expand Down
16 changes: 8 additions & 8 deletions test/simplified/TestAdaptorSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1153,11 +1153,11 @@ testMatvecmulPP = do
True (uncurry rmatvecmul)
(FTKProduct (FTKR [2, 3] FTKScalar) (FTKR [3] FTKScalar))
printArtifactPretty @_ @(TKR 1 Double) renames artifactRev
@?= "\\v2 m1 -> tpair (rfromS (str (str (sfromR (rreplicate 2 (tproject2 m1))) * sreplicate @_ @3 (sfromR v2))), rsum (rfromS (str (str (sfromR (tproject1 m1)) * sreplicate @_ @3 (sfromR v2)))))"
@?= "\\v3 m1 -> tpair (rfromS (str (str (sreplicate @_ @2 (sfromR (tproject2 m1))) * sreplicate @_ @3 (sfromR v3))), rfromS (ssum @_ @2 (str (str (sfromR (tproject1 m1)) * sreplicate @_ @3 (sfromR v3)))))"
printArtifactPrimalPretty renames artifactRev
@?= "\\m1 -> rfromS (ssum @_ @3 (str (sfromR (rreplicate 2 (tproject2 m1))) * str (sfromR (tproject1 m1))))"
@?= "\\m1 -> rfromS (ssum @_ @3 (str (sreplicate @_ @2 (sfromR (tproject2 m1))) * str (sfromR (tproject1 m1))))"
printArtifactPretty renames (simplifyArtifact artifactRev)
@?= "\\v2 m1 -> tfromS (tpair (sreplicate @_ @2 (sfromR (tproject2 m1)) * str (sreplicate @_ @3 (sfromR v2)), smatvecmul (str (sfromR (tproject1 m1))) (sfromR v2)))"
@?= "\\v3 m1 -> tfromS (tpair (sreplicate @_ @2 (sfromR (tproject2 m1)) * str (sreplicate @_ @3 (sfromR v3)), smatvecmul (str (sfromR (tproject1 m1))) (sfromR v3)))"
printArtifactPrimalPretty renames (simplifyArtifact artifactRev)
@?= "\\m1 -> rfromS (smatvecmul (sfromR (tproject1 m1)) (sfromR (tproject2 m1)))"

Expand All @@ -1170,11 +1170,11 @@ testMatmul2PP = do
True (uncurry rmatmul2)
(FTKProduct (FTKR [2, 3] FTKScalar) (FTKR [3, 4] FTKScalar))
printArtifactPretty @_ @(TKR 2 Double) renames artifactRev
@?= "\\m2 m1 -> tpair (rfromS (ssum @_ @4 (stranspose @_ @[2,1,0] (str (sreplicate @_ @2 (sfromR (tproject2 m1))) * sreplicate @_ @3 (sfromR m2)))), rfromS (ssum @_ @2 (str (stranspose @_ @[2,1,0] (sreplicate @_ @4 (sfromR (tproject1 m1))) * sreplicate @_ @3 (sfromR m2)))))"
@?= "\\m7 m1 -> tpair (rfromS (ssum @_ @4 (stranspose @_ @[2,1,0] (str (sreplicate @_ @2 (sfromR (tproject2 m1))) * sreplicate @_ @3 (sfromR m7)))), rfromS (ssum @_ @2 (str (stranspose @_ @[2,1,0] (sreplicate @_ @4 (sfromR (tproject1 m1))) * sreplicate @_ @3 (sfromR m7)))))"
printArtifactPrimalPretty renames artifactRev
@?= "\\m1 -> rfromS (ssum @_ @3 (stranspose @_ @[2,1,0] (sreplicate @_ @4 (sfromR (tproject1 m1))) * str (sreplicate @_ @2 (sfromR (tproject2 m1)))))"
printArtifactPretty renames (simplifyArtifact artifactRev)
@?= "\\m2 m1 -> tfromS (tpair (smatmul2 (sfromR m2) (str (sfromR (tproject2 m1))), smatmul2 (str (sfromR (tproject1 m1))) (sfromR m2)))"
@?= "\\m7 m1 -> tfromS (tpair (smatmul2 (sfromR m7) (str (sfromR (tproject2 m1))), smatmul2 (str (sfromR (tproject1 m1))) (sfromR m7)))"
printArtifactPrimalPretty renames (simplifyArtifact artifactRev)
@?= "\\m1 -> rfromS (smatmul2 (sfromR (tproject1 m1)) (sfromR (tproject2 m1)))"

Expand All @@ -1191,7 +1191,7 @@ testMatmul2FromMatvecmulPP = do
True (uncurry rmatmul2F)
(FTKProduct (FTKR [2, 3] FTKScalar) (FTKR [3, 4] FTKScalar))
printArtifactPretty @_ @(TKR 2 Double) renames artifactRev
@?= "\\m3 m1 -> tpair (rfromS (ssum @_ @4 (stranspose @_ @[2,1,0] (str (sreplicate @_ @2 (sfromR (tproject2 m1))) * sreplicate @_ @3 (sfromR m3)))), rfromS (ssum @_ @2 (str (stranspose @_ @[2,1,0] (sreplicate @_ @4 (sfromR (tproject1 m1))) * sreplicate @_ @3 (sfromR m3)))))"
@?= "\\m7 m1 -> tpair (rfromS (ssum @_ @4 (stranspose @_ @[2,1,0] (str (sreplicate @_ @2 (sfromR (tproject2 m1))) * sreplicate @_ @3 (sfromR m7)))), rfromS (ssum @_ @2 (str (stranspose @_ @[2,1,0] (sreplicate @_ @4 (sfromR (tproject1 m1))) * sreplicate @_ @3 (sfromR m7)))))"
printArtifactPrimalPretty renames artifactRev
@?= "\\m1 -> rfromS (ssum @_ @3 (stranspose @_ @[2,1,0] (sreplicate @_ @4 (sfromR (tproject1 m1))) * str (sreplicate @_ @2 (sfromR (tproject2 m1)))))"

Expand Down Expand Up @@ -1225,11 +1225,11 @@ testMatmul2PPS = do
True (uncurry smatmul2)
(FTKProduct (FTKS (SNat @2 :$$ SNat @3 :$$ ZSS) (FTKScalar @Float)) (FTKS (SNat @3 :$$ SNat @4 :$$ ZSS) (FTKScalar @Float)))
printArtifactPretty renames artifactRev
@?= "\\m2 m1 -> tpair (ssum @_ @4 (stranspose @_ @[2,1,0] (str (sreplicate @_ @2 (tproject2 m1)) * sreplicate @_ @3 m2)), ssum @_ @2 (str (stranspose @_ @[2,1,0] (sreplicate @_ @4 (tproject1 m1)) * sreplicate @_ @3 m2)))"
@?= "\\m7 m1 -> tpair (ssum @_ @4 (stranspose @_ @[2,1,0] (str (sreplicate @_ @2 (tproject2 m1)) * sreplicate @_ @3 m7)), ssum @_ @2 (str (stranspose @_ @[2,1,0] (sreplicate @_ @4 (tproject1 m1)) * sreplicate @_ @3 m7)))"
printArtifactPrimalPretty renames artifactRev
@?= "\\m1 -> ssum @_ @3 (stranspose @_ @[2,1,0] (sreplicate @_ @4 (tproject1 m1)) * str (sreplicate @_ @2 (tproject2 m1)))"
printArtifactPretty renames (simplifyArtifact artifactRev)
@?= "\\m2 m1 -> tpair (smatmul2 m2 (str (tproject2 m1)), smatmul2 (str (tproject1 m1)) m2)"
@?= "\\m7 m1 -> tpair (smatmul2 m7 (str (tproject2 m1)), smatmul2 (str (tproject1 m1)) m7)"
printArtifactPrimalPretty renames (simplifyArtifact artifactRev)
@?= "\\m1 -> smatmul2 (tproject1 m1) (tproject2 m1)"

Expand Down
8 changes: 4 additions & 4 deletions test/simplified/TestGatherSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -454,8 +454,8 @@ testGatherSimpPP33 = do
resetVarCounter
let !t1 = gatherTranspose33 @(AstTensor AstMethodLet PrimalSpan)
$ AstVar (mkAstVarName (FTKR [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] FTKScalar) . intToAstVarId $ 100000000)
length (show t1) @?= 992
length (show (simplifyInlineContract @(TKR 2 Float) t1)) @?= 791
length (show t1) @?= 1574
length (show (simplifyInlineContract @(TKR 2 Float) t1)) @?= 1334
resetVarCounter
let !t2 = (\t -> rmatmul2 (rreshape [6, 8] (rconcrete $ unRepN t48))
(rreshape @(AstTensor AstMethodLet PrimalSpan) @_ @10 [8, 16] t))
Expand All @@ -469,8 +469,8 @@ testGatherSimpPP34 = do
let !t1 = (\t -> rbuild1 4 (\i ->
gatherTranspose33 @(AstTensor AstMethodLet PrimalSpan) (t * rreplicate0N [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] (rfromIndex0 i))))
$ AstVar (mkAstVarName (FTKR [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] FTKScalar) . intToAstVarId $ 100000000)
length (show t1) @?= 1546
length (show (simplifyInlineContract @(TKR 3 Float) t1)) @?= 1546
length (show t1) @?= 2163
length (show (simplifyInlineContract @(TKR 3 Float) t1)) @?= 2163
resetVarCounter
let !t2 = (\t -> rbuild1 4 (\i ->
(\t' -> rmatmul2 (rreshape [6, 8] (rconcrete $ unRepN t48))
Expand Down
Loading

0 comments on commit 24447e0

Please sign in to comment.