Skip to content

Commit

Permalink
Don't recursively unlet (again) lambdas in folds
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Feb 15, 2024
1 parent cbbebc2 commit dbf1aa4
Showing 1 changed file with 24 additions and 84 deletions.
108 changes: 24 additions & 84 deletions src/HordeAd/Core/AstInline.hs
Original file line number Diff line number Diff line change
@@ -645,32 +645,14 @@ unletAst env t = case t of
Ast.AstFwd (vars, unletAst (emptyUnletEnv emptyADShare) v)
(V.map (unletAstDynamic env) l)
(V.map (unletAstDynamic env) ds)
Ast.AstFold (nvar, mvar, v) x0 as ->
Ast.AstFold (nvar, mvar, unletAst (emptyUnletEnv emptyADShare) v)
(unletAst env x0)
(unletAst env as)
Ast.AstFoldDer (nvar, mvar, v) (varDx, varDa, varn1, varm1, ast1)
(varDt2, nvar2, mvar2, doms) x0 as ->
Ast.AstFoldDer (nvar, mvar, unletAst (emptyUnletEnv emptyADShare) v)
( varDx, varDa, varn1, varm1
, unletAst (emptyUnletEnv emptyADShare) ast1 )
( varDt2, nvar2, mvar2
, unletAstHVector (emptyUnletEnv emptyADShare) doms )
(unletAst env x0)
(unletAst env as)
Ast.AstScan (nvar, mvar, v) x0 as ->
Ast.AstScan (nvar, mvar, unletAst (emptyUnletEnv emptyADShare) v)
(unletAst env x0)
(unletAst env as)
Ast.AstScanDer (nvar, mvar, v) (varDx, varDa, varn1, varm1, ast1)
(varDt2, nvar2, mvar2, doms) x0 as ->
Ast.AstScanDer (nvar, mvar, unletAst (emptyUnletEnv emptyADShare) v)
( varDx, varDa, varn1, varm1
, unletAst (emptyUnletEnv emptyADShare) ast1 )
( varDt2, nvar2, mvar2
, unletAstHVector (emptyUnletEnv emptyADShare) doms )
(unletAst env x0)
(unletAst env as)
Ast.AstFold f x0 as ->
Ast.AstFold f (unletAst env x0) (unletAst env as)
Ast.AstFoldDer f df rf x0 as ->
Ast.AstFoldDer f df rf (unletAst env x0) (unletAst env as)
Ast.AstScan f x0 as ->
Ast.AstScan f (unletAst env x0) (unletAst env as)
Ast.AstScanDer f df rf x0 as ->
Ast.AstScanDer f df rf (unletAst env x0) (unletAst env as)

unletAstS
:: (GoodScalar r, Sh.Shape sh, AstSpan s)
@@ -748,32 +730,14 @@ unletAstS env t = case t of
Ast.AstFwdS (vars, unletAstS (emptyUnletEnv emptyADShare) v)
(V.map (unletAstDynamic env) l)
(V.map (unletAstDynamic env) ds)
Ast.AstFoldS (nvar, mvar, v) x0 as ->
Ast.AstFoldS (nvar, mvar, unletAstS (emptyUnletEnv emptyADShare) v)
(unletAstS env x0)
(unletAstS env as)
Ast.AstFoldDerS (nvar, mvar, v) (varDx, varDa, varn1, varm1, ast1)
(varDt2, nvar2, mvar2, doms) x0 as ->
Ast.AstFoldDerS (nvar, mvar, unletAstS (emptyUnletEnv emptyADShare) v)
( varDx, varDa, varn1, varm1
, unletAstS (emptyUnletEnv emptyADShare) ast1 )
( varDt2, nvar2, mvar2
, unletAstHVector (emptyUnletEnv emptyADShare) doms )
(unletAstS env x0)
(unletAstS env as)
Ast.AstScanS (nvar, mvar, v) x0 as ->
Ast.AstScanS (nvar, mvar, unletAstS (emptyUnletEnv emptyADShare) v)
(unletAstS env x0)
(unletAstS env as)
Ast.AstScanDerS (nvar, mvar, v) (varDx, varDa, varn1, varm1, ast1)
(varDt2, nvar2, mvar2, doms) x0 as ->
Ast.AstScanDerS (nvar, mvar, unletAstS (emptyUnletEnv emptyADShare) v)
( varDx, varDa, varn1, varm1
, unletAstS (emptyUnletEnv emptyADShare) ast1 )
( varDt2, nvar2, mvar2
, unletAstHVector (emptyUnletEnv emptyADShare) doms )
(unletAstS env x0)
(unletAstS env as)
Ast.AstFoldS f x0 as ->
Ast.AstFoldS f (unletAstS env x0) (unletAstS env as)
Ast.AstFoldDerS f df rf x0 as ->
Ast.AstFoldDerS f df rf (unletAstS env x0) (unletAstS env as)
Ast.AstScanS f x0 as ->
Ast.AstScanS f (unletAstS env x0) (unletAstS env as)
Ast.AstScanDerS f df rf x0 as ->
Ast.AstScanDerS f df rf (unletAstS env x0) (unletAstS env as)

unletAstDynamic
:: AstSpan s
@@ -829,44 +793,20 @@ unletAstHVector env = \case
Ast.AstRevDtS (vars, unletAstS (emptyUnletEnv emptyADShare) v)
(V.map (unletAstDynamic env) l)
(unletAstS env dt)
Ast.AstMapAccumR k accShs bShs eShs (accvars, evars, v) acc0 es ->
Ast.AstMapAccumR k accShs bShs eShs
( accvars, evars
, unletAstHVector (emptyUnletEnv emptyADShare) v )
Ast.AstMapAccumR k accShs bShs eShs f acc0 es ->
Ast.AstMapAccumR k accShs bShs eShs f
(V.map (unletAstDynamic env) acc0)
(V.map (unletAstDynamic env) es)
Ast.AstMapAccumRDer k accShs bShs eShs
(accvars, evars, v)
(vs1, vs2, vs3, vs4, ast)
(ws1, ws2, ws3, ws4, bst)
acc0 es ->
Ast.AstMapAccumRDer k accShs bShs eShs
( accvars, evars
, unletAstHVector (emptyUnletEnv emptyADShare) v )
( vs1, vs2, vs3, vs4
, unletAstHVector (emptyUnletEnv emptyADShare) ast )
( ws1, ws2, ws3, ws4
, unletAstHVector (emptyUnletEnv emptyADShare) bst )
Ast.AstMapAccumRDer k accShs bShs eShs f df rf acc0 es ->
Ast.AstMapAccumRDer k accShs bShs eShs f df rf
(V.map (unletAstDynamic env) acc0)
(V.map (unletAstDynamic env) es)
Ast.AstMapAccumL k accShs bShs eShs (accvars, evars, v) acc0 es ->
Ast.AstMapAccumL k accShs bShs eShs
( accvars, evars
, unletAstHVector (emptyUnletEnv emptyADShare) v )
Ast.AstMapAccumL k accShs bShs eShs f acc0 es ->
Ast.AstMapAccumL k accShs bShs eShs f
(V.map (unletAstDynamic env) acc0)
(V.map (unletAstDynamic env) es)
Ast.AstMapAccumLDer k accShs bShs eShs
(accvars, evars, v)
(vs1, vs2, vs3, vs4, ast)
(ws1, ws2, ws3, ws4, bst)
acc0 es ->
Ast.AstMapAccumLDer k accShs bShs eShs
( accvars, evars
, unletAstHVector (emptyUnletEnv emptyADShare) v )
( vs1, vs2, vs3, vs4
, unletAstHVector (emptyUnletEnv emptyADShare) ast )
( ws1, ws2, ws3, ws4
, unletAstHVector (emptyUnletEnv emptyADShare) bst )
Ast.AstMapAccumLDer k accShs bShs eShs f df rf acc0 es ->
Ast.AstMapAccumLDer k accShs bShs eShs f df rf
(V.map (unletAstDynamic env) acc0)
(V.map (unletAstDynamic env) es)

0 comments on commit dbf1aa4

Please sign in to comment.