Skip to content

Commit

Permalink
Generalize rfold and friends to nested arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 19, 2024
1 parent 7cfcff8 commit 07b3e41
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 66 deletions.
106 changes: 58 additions & 48 deletions src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1196,56 +1196,58 @@ class ( Num (IntOf target)
-- | A strict left fold.
rfold
:: forall rn rm n m.
(GoodScalar rn, GoodScalar rm, KnownNat n, KnownNat m)
=> (forall f. ADReady f => f (TKR n rn) -> f (TKR m rm) -> f (TKR n rn))
-> target (TKR n rn) -- ^ initial value
-> target (TKR (1 + m) rm) -- ^ iteration is over the outermost dimension
-> target (TKR n rn)
(TensorKind rn, TensorKind rm, KnownNat n, KnownNat m)
=> (forall f. ADReady f => f (TKR2 n rn) -> f (TKR2 m rm) -> f (TKR2 n rn))
-> target (TKR2 n rn) -- ^ initial value
-> target (TKR2 (1 + m) rm) -- ^ iteration is over the outermost dimension
-> target (TKR2 n rn)
rfold f acc0 es =
let shm :: IShR m
(width, shm) = case rshape es of
width2 :$: shm2 -> (width2, shm2)
ZSR -> error "rfold: impossible pattern needlessly required"
sh = rshape acc0
(width, shm, xm) = case tftk stensorKind es of
FTKR (width2 :$: shm2) x2 -> (width2, shm2, x2)
FTKR ZSR _ -> error "rfold: impossible pattern needlessly required"
(sh, x) = case tftk stensorKind acc0 of
FTKR sh2 x2 -> (sh2, x2)
in withSNat width $ \snat ->
tproject1
(dmapAccumL (Proxy @target)
snat
(FTKR @_ sh (FTKScalar @rn))
(FTKR @_ sh x)
(FTKScalar @Z0)
(FTKR @_ shm (FTKScalar @rm))
(FTKR @_ shm xm)
(let g :: forall f. ADReady f
=> f (TKR n rn) -> f (TKR m rm)
-> f (TKProduct (TKR n rn) TKUnit)
=> f (TKR2 n rn) -> f (TKR2 m rm)
-> f (TKProduct (TKR2 n rn) TKUnit)
g !acc !e = tpair (f acc e) tunit
in g)
acc0
es)
-- | A strict left scan.
rscan
:: forall rn rm n m.
(GoodScalar rn, GoodScalar rm, KnownNat n, KnownNat m)
=> (forall f. ADReady f => f (TKR n rn) -> f (TKR m rm) -> f (TKR n rn))
-> target (TKR n rn)
-> target (TKR (1 + m) rm)
-> target (TKR (1 + n) rn)
(TensorKind rn, TensorKind rm, KnownNat n, KnownNat m)
=> (forall f. ADReady f => f (TKR2 n rn) -> f (TKR2 m rm) -> f (TKR2 n rn))
-> target (TKR2 n rn)
-> target (TKR2 (1 + m) rm)
-> target (TKR2 (1 + n) rn)
rscan f acc0 es =
let shm :: IShR m
(width, shm) = case rshape es of
width2 :$: shm2 -> (width2, shm2)
ZSR -> error "rscan: impossible pattern needlessly required"
sh = rshape acc0
(width, shm, xm) = case tftk stensorKind es of
FTKR (width2 :$: shm2) x2 -> (width2, shm2, x2)
FTKR ZSR _ -> error "rfold: impossible pattern needlessly required"
(sh, x) = case tftk stensorKind acc0 of
FTKR sh2 x2 -> (sh2, x2)
in withSNat width $ \snat ->
let bs =
tproject2
$ dmapAccumL (Proxy @target)
snat
(FTKR @_ sh (FTKScalar @rn))
(FTKR @_ sh (FTKScalar @rn))
(FTKR @_ shm (FTKScalar @rm))
(FTKR @_ sh x)
(FTKR @_ sh x)
(FTKR @_ shm xm)
(let g :: forall f. ADReady f
=> f (TKR n rn) -> f (TKR m rm)
-> f (TKProduct (TKR n rn) (TKR n rn))
=> f (TKR2 n rn) -> f (TKR2 m rm)
-> f (TKProduct (TKR2 n rn) (TKR2 n rn))
g !acc !e = tlet (f acc e) $ \ !res -> tpair res res
in g)
acc0
Expand All @@ -1254,43 +1256,51 @@ class ( Num (IntOf target)
-- | A strict left fold.
sfold
:: forall rn rm sh shm k.
(GoodScalar rn, GoodScalar rm, KnownShS sh, KnownShS shm, KnownNat k)
=> (forall f. ADReady f => f (TKS sh rn) -> f (TKS shm rm) -> f (TKS sh rn))
-> target (TKS sh rn)
-> target (TKS (k ': shm) rm)
-> target (TKS sh rn)
(TensorKind rn, TensorKind rm, KnownShS sh, KnownShS shm, KnownNat k)
=> (forall f. ADReady f => f (TKS2 sh rn) -> f (TKS2 shm rm) -> f (TKS2 sh rn))
-> target (TKS2 sh rn)
-> target (TKS2 (k ': shm) rm)
-> target (TKS2 sh rn)
sfold f acc0 es =
tproject1
let xm = case tftk stensorKind es of
FTKS _ x2 -> x2
x = case tftk stensorKind acc0 of
FTKS _ x2 -> x2
in tproject1
(dmapAccumL (Proxy @target)
(SNat @k)
(FTKS @sh knownShS (FTKScalar @rn))
(FTKS @sh knownShS x)
(FTKScalar @Z0)
(FTKS @shm knownShS (FTKScalar @rm))
(FTKS @shm knownShS xm)
(let g :: forall f. ADReady f
=> f (TKS sh rn) -> f (TKS shm rm)
-> f (TKProduct (TKS sh rn) TKUnit)
=> f (TKS2 sh rn) -> f (TKS2 shm rm)
-> f (TKProduct (TKS2 sh rn) TKUnit)
g !acc !e = tpair (f acc e) tunit
in g)
acc0
es)
sscan
:: forall rn rm sh shm k.
(GoodScalar rn, GoodScalar rm, KnownShS sh, KnownShS shm, KnownNat k)
=> (forall f. ADReady f => f (TKS sh rn) -> f (TKS shm rm) -> f (TKS sh rn))
-> target (TKS sh rn)
-> target (TKS (k ': shm) rm)
-> target (TKS (1 + k ': sh) rn)
(TensorKind rn, TensorKind rm, KnownShS sh, KnownShS shm, KnownNat k)
=> (forall f. ADReady f => f (TKS2 sh rn) -> f (TKS2 shm rm) -> f (TKS2 sh rn))
-> target (TKS2 sh rn)
-> target (TKS2 (k ': shm) rm)
-> target (TKS2 (1 + k ': sh) rn)
sscan f acc0 es =
let bs =
let xm = case tftk stensorKind es of
FTKS _ x2 -> x2
x = case tftk stensorKind acc0 of
FTKS _ x2 -> x2
bs =
tproject2
$ dmapAccumL (Proxy @target)
(SNat @k)
(FTKS @sh knownShS (FTKScalar @rn))
(FTKS @sh knownShS (FTKScalar @rn))
(FTKS @shm knownShS (FTKScalar @rm))
(FTKS @sh knownShS x)
(FTKS @sh knownShS x)
(FTKS @shm knownShS xm)
(let g :: forall f. ADReady f
=> f (TKS sh rn) -> f (TKS shm rm)
-> f (TKProduct (TKS sh rn) (TKS sh rn))
=> f (TKS2 sh rn) -> f (TKS2 shm rm)
-> f (TKProduct (TKS2 sh rn) (TKS2 sh rn))
g !acc !e = tlet (f acc e) $ \ !res -> tpair res res
in g)
acc0
Expand Down
4 changes: 2 additions & 2 deletions test/simplified/TestConvSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ testCNNOPP2 :: Assertion
testCNNOPP2 = do
resetVarCounter
printAstPretty IM.empty maxPool2dUnpadded2
@?= "rreplicate 1 (rreplicate 1 (let w38 = rtranspose [1,2,3,0] (rreplicate 1 (rgather [1,1,1,2,2] (rfromVector (fromList [let x26 = 0 + 1 in rtranspose [1,2,0] (rreplicate 1 (let x27 = 2 * 0 in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x28 = 2 * 0 in rreplicate 1 (rreplicate 2 (rfromVector (fromList [tconcrete (FTKR [1,1,2,2] FTKScalar) (rfromListLinear [1,1,2,2] [1.0,1.0,1.0,1.0]) ! [i26, 0, i27, i28], rscalar 0.0]) ! [ifF (1 >. i26) 0 1]))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rreplicate 2 (rreplicate 2 (rscalar 0.0)))))])) (\\[i46, i40, i36] -> [ifF (1 >. 1 + i36) 0 1, i46, i40, i36]))) in rgather [1,1] w38 (\\[i45, i39] -> [i45, i39, 0, 0, 0, 0])))"
@?= "rreplicate 1 (rreplicate 1 (let w38 = rtranspose [1,2,3,0] (rreplicate 1 (rgather [1,1,1,2,2] (rfromVector (fromList [let x26 = 1 + stoScalar (sscalar 0) in rtranspose [1,2,0] (rreplicate 1 (let x27 = 2 * stoScalar (sscalar 0) + stoScalar (sscalar 0) in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x28 = 2 * stoScalar (sscalar 0) + stoScalar (sscalar 0) in rreplicate 1 (rreplicate 2 (rfromVector (fromList [tconcrete (FTKR [1,1,2,2] FTKScalar) (rfromListLinear [1,1,2,2] [1.0,1.0,1.0,1.0]) ! [i26, 0, i27, i28], rscalar 0.0]) ! [ifF (1 >. i26) 0 1]))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rreplicate 2 (rreplicate 2 (rscalar 0.0)))))])) (\\[i46, i40, i36] -> [ifF (1 >. 1 + i36) 0 1, i46, i40, i36]))) in rgather [1,1] w38 (\\[i45, i39] -> [i45, i39, 0, 0, 0, 0])))"

maxPool2dUnpadded2
:: (target ~ AstTensor AstMethodLet FullSpan, r ~ Double)
Expand Down Expand Up @@ -717,7 +717,7 @@ testCNNOPP4 = do
afcnn2T :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
afcnn2T = maxPool2dUnpadded4 $ conv2dUnpadded4 blackGlyph
printAstPretty IM.empty afcnn2T
@?= "rreplicate 1 (rreplicate 1 (let w36 = rgather [1,1,1,1,2,2] (rfromVector (fromList [let x21 = 0 + 1 in rtranspose [2,3,0,1] (rreplicate 1 (rreplicate 1 (let x20 = 2 * 0 in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x12 = 2 * 0 in rreplicate 1 (rreplicate 2 (tconcrete (FTKR [2] FTKScalar) (rfromListLinear [2] [7.0,0.0]) ! [ifF ((0 <=. i21 &&* 1 >. i21) &&* ((0 <=. i20 &&* 2 >. i20) &&* (0 <=. i12 &&* 2 >. i12))) 0 1])))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rgather [1,2,2] (rreplicate 2 (rreplicate 2 (rscalar 0.0))) (\\[i29, i26, i22] -> [i26, i22]))))])) (\\[i44, i38, i33, i30, i31, i32] -> [ifF ((0 <=. 1 + i33 &&* 1 >. 1 + i33) &&* ((0 <=. 1 + i30 &&* 1 >. 1 + i30) &&* ((0 <=. 2 * i44 + i31 &&* 2 >. 2 * i44 + i31) &&* (0 <=. 2 * i38 + i32 &&* 2 >. 2 * i38 + i32)))) 0 1, i44, i38, i33, i30, i31, i32]) in rgather [1,1] w36 (\\[i43, i37] -> [i43, i37, 0, 0, 0, 0])))"
@?= "rreplicate 1 (rreplicate 1 (let w36 = rgather [1,1,1,1,2,2] (rfromVector (fromList [let x21 = 1 + stoScalar (sscalar 0) in rtranspose [2,3,0,1] (rreplicate 1 (rreplicate 1 (let x20 = 2 * stoScalar (sscalar 0) + stoScalar (sscalar 0) in rtranspose [0,2,1] (rreplicate 1 (rreplicate 2 (let x12 = 2 * stoScalar (sscalar 0) + stoScalar (sscalar 0) in rreplicate 1 (rreplicate 2 (tconcrete (FTKR [2] FTKScalar) (rfromListLinear [2] [7.0,0.0]) ! [ifF ((0 <=. i21 &&* 1 >. i21) &&* ((0 <=. i20 &&* 2 >. i20) &&* (0 <=. i12 &&* 2 >. i12))) 0 1])))))))), rreplicate 1 (rreplicate 1 (rreplicate 1 (rgather [1,2,2] (rreplicate 2 (rreplicate 2 (rscalar 0.0))) (\\[i29, i26, i22] -> [i26, i22]))))])) (\\[i44, i38, i33, i30, i31, i32] -> [ifF ((0 <=. 1 + i33 &&* 1 >. 1 + i33) &&* ((0 <=. 1 + i30 &&* 1 >. 1 + i30) &&* ((0 <=. 2 * i44 + i31 &&* 2 >. 2 * i44 + i31) &&* (0 <=. 2 * i38 + i32 &&* 2 >. 2 * i38 + i32)))) 0 1, i44, i38, i33, i30, i31, i32]) in rgather [1,1] w36 (\\[i43, i37] -> [i43, i37, 0, 0, 0, 0])))"
printAstPretty IM.empty (simplifyInline afcnn2T)
@?= "rreplicate 1 (rreplicate 1 (rreplicate 1 (rreplicate 1 (rscalar 0.0))))"

Expand Down
Loading

0 comments on commit 07b3e41

Please sign in to comment.