Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Simplify and fix vectorization of AstReplicate
Browse files Browse the repository at this point in the history
Mikolaj committed Dec 15, 2024
1 parent 1a376fb commit 928a5be
Showing 2 changed files with 4 additions and 37 deletions.
40 changes: 4 additions & 36 deletions src/HordeAd/Core/AstVectorize.hs
Original file line number Diff line number Diff line change
@@ -237,42 +237,10 @@ build1V snat@SNat (var, v0) =
STKProduct{} -> error "TODO"
STKUntyped -> error "TODO"
_ -> error "TODO"
Ast.AstReplicate @y2 snat2@(SNat @k2) v -> traceRule $
let repl2Stk :: forall z.
STensorKindType z
-> AstTensor AstMethodLet s (BuildTensorKind k z)
-> AstTensor AstMethodLet s (BuildTensorKind k
(BuildTensorKind k2 z))
repl2Stk stk u = case stk of
STKScalar{} -> u
STKR SNat STKScalar{} -> astTr $ astReplicate snat2 u
STKS sh STKScalar{} -> withKnownShS sh $ astTrS $ astReplicate snat2 u
STKX sh STKScalar{} -> withKnownShX sh $ astTrX $ astReplicate snat2 u
STKProduct @z1 @z2 stk1 stk2
| (Dict, Dict) <- lemTensorKind1OfBuild snat stk1
, Dict <- lemTensorKindOfBuild snat2 stk1
, (Dict, Dict) <- lemTensorKind1OfBuild
snat (stensorKind @(BuildTensorKind k2 z1))
, (Dict, Dict) <- lemTensorKind1OfBuild snat stk2
, Dict <- lemTensorKindOfBuild snat2 stk2
, (Dict, Dict) <- lemTensorKind1OfBuild
snat (stensorKind @(BuildTensorKind k2 z2)) ->
astLetFun u $ \ !uShared ->
let (u1, u2) = (astProject1 uShared, astProject2 uShared)
in astPair (repl2Stk stk1 u1) (repl2Stk stk2 u2)
STKUntyped ->
astTrAstHVector
$ fun1DToAst (shapeAstHVector u) $ \ !vars !asts ->
astLetHVectorIn
vars
u
(Ast.AstMkHVector
$ replicate1HVectorF
(\k3 -> withSNat k3 $ \snat3 -> astReplicate snat3)
(astReplicate SNat)
snat2 asts)
_ -> error "TODO"
in repl2Stk (stensorKind @y2) (build1V snat (var, v))
Ast.AstReplicate @y2 snat2@(SNat @k2) v
| Dict <- lemTensorKindOfBuild snat (stensorKind @y2) -> traceRule $
astTrGeneral @k2 (stensorKind @y2) (astReplicate snat2
$ build1V snat (var, v))
Ast.AstBuild1 snat2 (var2, v2) ->
build1VOccurenceUnknown snat (var, build1VOccurenceUnknown snat2 (var2, v2))
-- happens only when testing and mixing different pipelines
1 change: 0 additions & 1 deletion src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
@@ -22,7 +22,6 @@ module HordeAd.Core.TensorClass
import Prelude

import Data.Kind (Constraint, Type)
import Data.List (foldl')
import Data.List.NonEmpty (NonEmpty)
import Data.List.NonEmpty qualified as NonEmpty
import Data.Proxy (Proxy (Proxy))

0 comments on commit 928a5be

Please sign in to comment.