Skip to content

Commit

Permalink
Fix HFunOf(AST) to match the span used in AstHFun inside AstMapAccumR
Browse files Browse the repository at this point in the history
Mikolaj committed Feb 18, 2024
1 parent e7cd98f commit 2232cbb
Showing 5 changed files with 17 additions and 17 deletions.
6 changes: 3 additions & 3 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
@@ -92,7 +92,7 @@ sameAstSpan = case eqTypeRep (typeRep @s1) (typeRep @s2) of
type instance RankedOf (AstRanked s) = AstRanked s
type instance ShapedOf (AstRanked s) = AstShaped s
type instance HVectorOf (AstRanked s) = AstHVector s
type instance HFunOf (AstRanked s) = AstHFun s
type instance HFunOf (AstRanked s) = AstHFun PrimalSpan
type instance PrimalOf (AstRanked s) = AstRanked PrimalSpan
type instance DualOf (AstRanked s) = AstRanked DualSpan

@@ -996,7 +996,7 @@ maxF u v = ifF (u >=. v) u v
type instance RankedOf (AstNoVectorize s) = AstNoVectorize s
type instance ShapedOf (AstNoVectorize s) = AstNoVectorizeS s
type instance HVectorOf (AstNoVectorize s) = AstHVector s
type instance HFunOf (AstNoVectorize s) = AstHFun s
type instance HFunOf (AstNoVectorize s) = AstHFun PrimalSpan
type instance PrimalOf (AstNoVectorize s) = AstRanked PrimalSpan
type instance DualOf (AstNoVectorize s) = AstRanked DualSpan
type instance RankedOf (AstNoVectorizeS s) = AstNoVectorize s
@@ -1006,7 +1006,7 @@ type instance DualOf (AstNoVectorizeS s) = AstShaped DualSpan
type instance RankedOf (AstNoSimplify s) = AstNoSimplify s
type instance ShapedOf (AstNoSimplify s) = AstNoSimplifyS s
type instance HVectorOf (AstNoSimplify s) = AstHVector s
type instance HFunOf (AstNoSimplify s) = AstHFun s
type instance HFunOf (AstNoSimplify s) = AstHFun PrimalSpan
type instance PrimalOf (AstNoSimplify s) = AstRanked PrimalSpan
type instance DualOf (AstNoSimplify s) = AstRanked DualSpan
type instance RankedOf (AstNoSimplifyS s) = AstNoSimplify s
8 changes: 3 additions & 5 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
@@ -3065,18 +3065,16 @@ substitute1AstHVector i var = \case
Ast.AstHApply t ll ->
case ( substitute1AstHFun i var t
, map (V.map (substitute1AstDynamic i var)) ll ) of
(Nothing, llm) | all (V.all isNothing) llm -> Nothing
(Nothing, mll) | all (V.all isNothing) mll -> Nothing
(mt, mll) ->
Just $ astHApply (fromMaybe t mt) (zipWith (V.zipWith fromMaybe) ll mll)
Ast.AstLetHVectorInHVector vars2 u v ->
case ( substitute1AstHVector i var u
, substitute1AstHVector i var v ) of
case (substitute1AstHVector i var u, substitute1AstHVector i var v) of
(Nothing, Nothing) -> Nothing
(mu, mv) ->
Just $ astLetHVectorInHVector vars2 (fromMaybe u mu) (fromMaybe v mv)
Ast.AstLetHFunInHVector var2 f v ->
case ( substitute1AstHFun i var f
, substitute1AstHVector i var v ) of
case (substitute1AstHFun i var f, substitute1AstHVector i var v) of
(Nothing, Nothing) -> Nothing
(mf, mv) ->
Just $ astLetHFunInHVector var2 (fromMaybe f mf) (fromMaybe v mv)
3 changes: 3 additions & 0 deletions src/HordeAd/Core/AstVectorize.hs
Original file line number Diff line number Diff line change
@@ -972,6 +972,9 @@ build1VHFun k (var, v0) = case v0 of
Ast.AstHFun vvars l -> withSNat k $ \(SNat @k) ->
-- This handles the case of l having free variable beyond vvars,
-- which is not possible for lambdas used in folds, etc.
-- But note that due to substProjVarsHVector l2 has var occurences,
-- so build1VOccurenceUnknownHVectorRefresh is neccessary to handle
-- them and to eliminate them so that the function is closed again.
let f acc vars = substProjVarsHVector @k var vars acc
(l2, vvars2) = mapAccumR f l vvars
in Ast.AstHFun vvars2 (build1VOccurenceUnknownHVectorRefresh k (var, l2))
16 changes: 7 additions & 9 deletions src/HordeAd/Core/TensorAst.hs
Original file line number Diff line number Diff line change
@@ -468,8 +468,7 @@ astLetHVectorInFun a0 a f =
fun1DToAst a0 $ \ !vars !asts -> astLetHVectorIn vars a (f asts)

astLetHFunInFun
:: forall n s r. AstSpan s
=> AstHFun s -> (AstHFun s -> AstRanked s r n)
:: AstHFun PrimalSpan -> (AstHFun PrimalSpan -> AstRanked s r n)
-> AstRanked s r n
{-# INLINE astLetHFunInFun #-}
astLetHFunInFun a f =
@@ -607,8 +606,7 @@ astLetHVectorInFunS a0 a f =
fun1DToAst a0 $ \ !vars !asts -> astLetHVectorInS vars a (f asts)

astLetHFunInFunS
:: forall sh s r. AstSpan s
=> AstHFun s -> (AstHFun s -> AstShaped s r sh)
:: AstHFun PrimalSpan -> (AstHFun PrimalSpan -> AstShaped s r sh)
-> AstShaped s r sh
{-# INLINE astLetHFunInFunS #-}
astLetHFunInFunS a f =
@@ -660,11 +658,12 @@ astBuild1VectorizeS f =

-- * HVectorTensor instance

instance AstSpan s => HVectorTensor (AstRanked s) (AstShaped s) where
instance forall s. AstSpan s => HVectorTensor (AstRanked s) (AstShaped s) where
dshape = shapeAstHVector
dmkHVector = AstHVector
dlambda shss f = fun1LToAst shss $ \ !vvars !ll -> AstHFun vvars (unHFun f ll)
dHApply = AstHApply
dHApply f ll | Just Refl <- sameAstSpan @s @PrimalSpan = AstHApply f ll
dHApply _ _ = error "dHApply: wrong span"
dunHVector shs hVectorOf =
let f :: Int -> DynamicTensor VoidTensor -> AstDynamic s
f i = \case
@@ -701,7 +700,7 @@ instance AstSpan s => HVectorTensor (AstRanked s) (AstShaped s) where
dsharePrimal _ _ _ = error "dsharePrimal: wrong span"
dregister !domsOD !r !l =
fun1DToAst domsOD $ \ !vars !asts -> case vars of
[] -> error "dregister: empty hVector"
[] -> (l, V.empty)
!var : _ -> -- vars are fresh, so var uniquely represent vars
((dynamicVarNameToAstVarId var, AstBindingsHVector vars r) : l, asts)
dbuild1 = astBuildHVector1Vectorize
@@ -1089,8 +1088,7 @@ astLetHVectorInHVectorFun a0 a f =
fun1DToAst a0 $ \ !vars !asts -> astLetHVectorInHVector vars a (f asts)

astLetHFunInHVectorFun
:: forall s. AstSpan s
=> AstHFun s -> (AstHFun s -> AstHVector s)
:: AstHFun PrimalSpan -> (AstHFun PrimalSpan -> AstHVector s)
-> AstHVector s
{-# INLINE astLetHFunInHVectorFun #-}
astLetHFunInHVectorFun a f =
1 change: 1 addition & 0 deletions src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
@@ -736,6 +736,7 @@ class HVectorTensor (ranked :: RankedTensorType)
dmkHVector :: HVector ranked -> HVectorOf ranked
dlambda :: [VoidHVector] -> HFun -> HFunOf ranked
dHApply :: HFunOf ranked -> [HVector ranked] -> HVectorOf ranked
-- TODO: remove if still unused after a longer time
dunHVector :: VoidHVector -> HVectorOf ranked -> HVector ranked
-- ^ Warning: this operation easily breaks sharing.
dletHVectorInHVector

0 comments on commit 2232cbb

Please sign in to comment.