Skip to content

Commit

Permalink
Merge branch 'master' into lmad-no-perm
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Aug 3, 2023
2 parents 15c0eec + ee56f96 commit 6fa39b5
Show file tree
Hide file tree
Showing 142 changed files with 280 additions and 320 deletions.
2 changes: 1 addition & 1 deletion futhark.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ library
Language.Futhark.TypeChecker.Modules
Language.Futhark.TypeChecker.Monad
Language.Futhark.TypeChecker.Terms
Language.Futhark.TypeChecker.Terms.DoLoop
Language.Futhark.TypeChecker.Terms.Loop
Language.Futhark.TypeChecker.Terms.Monad
Language.Futhark.TypeChecker.Terms.Pat
Language.Futhark.TypeChecker.Types
Expand Down
8 changes: 4 additions & 4 deletions src/Futhark/AD/Fwd.hs
Original file line number Diff line number Diff line change
Expand Up @@ -402,21 +402,21 @@ fwdStm (Let pat aux (Match ses cases defbody (MatchDec ret ifsort))) = do
pat' <- bundleNewPat pat
ret' <- bundleTangents ret
addStm $ Let pat' aux $ Match ses cases' defbody' $ MatchDec ret' ifsort
fwdStm (Let pat aux (DoLoop val_pats loop@(WhileLoop v) body)) = do
fwdStm (Let pat aux (Loop val_pats loop@(WhileLoop v) body)) = do
val_pats' <- bundleNewList val_pats
pat' <- bundleNewPat pat
body' <-
localScope (scopeOfFParams (map fst val_pats) <> scopeOf loop) . slocal' $
fwdBody body
addStm $ Let pat' aux $ DoLoop val_pats' (WhileLoop v) body'
fwdStm (Let pat aux (DoLoop val_pats loop@(ForLoop i it bound loop_vars) body)) = do
addStm $ Let pat' aux $ Loop val_pats' (WhileLoop v) body'
fwdStm (Let pat aux (Loop val_pats loop@(ForLoop i it bound loop_vars) body)) = do
pat' <- bundleNewPat pat
val_pats' <- bundleNewList val_pats
loop_vars' <- bundleNewList loop_vars
body' <-
localScope (scopeOfFParams (map fst val_pats) <> scopeOf loop) . slocal' $
fwdBody body
addStm $ Let pat' aux $ DoLoop val_pats' (ForLoop i it bound loop_vars') body'
addStm $ Let pat' aux $ Loop val_pats' (ForLoop i it bound loop_vars') body'
fwdStm (Let pat aux (WithAcc inputs lam)) = do
inputs' <- forM inputs $ \(shape, arrs, op) -> do
arrs_tan <- mapM tangent arrs
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/AD/Rev.hs
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ diffStm stm@(Let pat _ (Match ses cases defbody _)) m = do
zipWithM_ insAdj branches_free branches_free_adj
diffStm (Let pat aux (Op soac)) m =
vjpSOAC vjpOps pat aux soac m
diffStm (Let pat aux loop@DoLoop {}) m =
diffStm (Let pat aux loop@Loop {}) m =
diffLoop diffStms pat aux loop m
-- See Note [Adjoints of accumulators]
diffStm stm@(Let pat _aux (WithAcc inputs lam)) m = do
Expand Down
4 changes: 2 additions & 2 deletions src/Futhark/AD/Rev/Hist.hs
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ radixSort xs n w = do
radixSortStep (map paramName params) types bit n w

letTupExp "sorted" $
DoLoop
Loop
(zip params $ map Var xs)
(ForLoop i Int64 iters [])
loopbody
Expand All @@ -664,7 +664,7 @@ radixSort xs n w = do

l <-
letTupExp' "log2res" $
DoLoop
Loop
(zip params [cond_init, m, Constant $ blankPrimValue int64])
(WhileLoop $ paramName cond)
body
Expand Down
30 changes: 15 additions & 15 deletions src/Futhark/AD/Rev/Loop.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ bindForLoop ::
a
) ->
a
bindForLoop (DoLoop val_pats form@(ForLoop i it bound loop_vars) body) f =
bindForLoop (Loop val_pats form@(ForLoop i it bound loop_vars) body) f =
f val_pats form i it bound loop_vars body
bindForLoop e _ = error $ "bindForLoop: not a for-loop:\n" <> prettyString e

Expand All @@ -57,7 +57,7 @@ renameForLoop loop f = renameExp loop >>= \loop' -> bindForLoop loop' (f loop')

-- | Is the loop a while-loop?
isWhileLoop :: Exp rep -> Bool
isWhileLoop (DoLoop _ WhileLoop {} _) = True
isWhileLoop (Loop _ WhileLoop {} _) = True
isWhileLoop _ = False

-- | Transforms a 'ForLoop' into a 'ForLoop' with an empty list of
Expand All @@ -73,11 +73,11 @@ removeLoopVars loop =
pure (paramName x_param, x')
(substs_list, subst_stms) <- collectStms $ mapM indexify loop_vars
let Body aux' stms' res' = substituteNames (M.fromList substs_list) body
pure $ DoLoop val_pats form $ Body aux' (subst_stms <> stms') res'
pure $ Loop val_pats form $ Body aux' (subst_stms <> stms') res'

-- | Augments a while-loop to also compute the number of iterations.
computeWhileIters :: Exp SOACS -> ADM SubExp
computeWhileIters (DoLoop val_pats (WhileLoop b) body) = do
computeWhileIters (Loop val_pats (WhileLoop b) body) = do
bound_v <- newVName "bound"
let t = Prim $ IntType Int64
bound_param = Param mempty bound_v t
Expand All @@ -89,17 +89,17 @@ computeWhileIters (DoLoop val_pats (WhileLoop b) body) = do
in letSubExp "bound+1" $ BasicOp $ BinOp (Add Int64 OverflowUndef) (Var bound_v) one
addStms $ bodyStms body
pure (pure (subExpRes bound_plus_one) <> bodyResult body)
res <- letTupExp' "loop" $ DoLoop ((bound_param, bound_init) : val_pats) (WhileLoop b) body'
res <- letTupExp' "loop" $ Loop ((bound_param, bound_init) : val_pats) (WhileLoop b) body'
pure $ head res
computeWhileIters e = error $ "convertWhileIters: not a while-loop:\n" <> prettyString e

-- | Converts a 'WhileLoop' into a 'ForLoop'. Requires that the
-- surrounding 'DoLoop' is annotated with a @#[bound(n)]@ attribute,
-- surrounding 'Loop' is annotated with a @#[bound(n)]@ attribute,
-- where @n@ is an upper bound on the number of iterations of the
-- while-loop. The resulting for-loop will execute for @n@ iterations on
-- all inputs, so the tighter the bound the better.
convertWhileLoop :: SubExp -> Exp SOACS -> ADM (Exp SOACS)
convertWhileLoop bound_se (DoLoop val_pats (WhileLoop cond) body) =
convertWhileLoop bound_se (Loop val_pats (WhileLoop cond) body) =
localScope (scopeOfFParams $ map fst val_pats) $ do
i <- newVName "i"
body' <-
Expand All @@ -109,7 +109,7 @@ convertWhileLoop bound_se (DoLoop val_pats (WhileLoop cond) body) =
(pure body)
(resultBodyM $ map (Var . paramName . fst) val_pats)
]
pure $ DoLoop val_pats (ForLoop i Int64 bound_se mempty) body'
pure $ Loop val_pats (ForLoop i Int64 bound_se mempty) body'
convertWhileLoop _ e = error $ "convertWhileLoopBound: not a while-loop:\n" <> prettyString e

-- | @nestifyLoop n bound loop@ transforms a loop into a depth-@n@ loop nest
Expand Down Expand Up @@ -149,12 +149,12 @@ nestifyLoop bound_se = nestifyLoop' bound_se
=<< nestifyLoop'
offset'
(n - 1)
(DoLoop val_pats'' (ForLoop i' it' bound_se loop_vars') inner_body)
(Loop val_pats'' (ForLoop i' it' bound_se loop_vars') inner_body)
pure $ varsRes inner_loop
pure $
DoLoop val_pats (ForLoop i it bound_se loop_vars) outer_body
Loop val_pats (ForLoop i it bound_se loop_vars) outer_body
| n == 1 =
pure $ DoLoop val_pats (ForLoop i it bound_se loop_vars) body
pure $ Loop val_pats (ForLoop i it bound_se loop_vars) body
| otherwise = pure loop

-- | @stripmine n pat loop@ stripmines a loop into a depth-@n@ loop nest.
Expand Down Expand Up @@ -185,7 +185,7 @@ stripmine n pat loop = do
let loop_params_rem = map fst val_pats'
loop_inits_rem = map (Var . patElemName) $ patElems pat'
val_pats_rem = zip loop_params_rem loop_inits_rem
remain_loop = DoLoop val_pats_rem (ForLoop i' it' remain_iters loop_vars') remain_body
remain_loop = Loop val_pats_rem (ForLoop i' it' remain_iters loop_vars') remain_body
collectStms_ $ do
letBind pat' mined_loop
letBind pat remain_loop
Expand All @@ -194,7 +194,7 @@ stripmine n pat loop = do
-- expression is a for-loop with a @#[stripmine(n)]@ attribute, where
-- @n@ is the nesting depth.
stripmineStm :: Stm SOACS -> ADM (Stms SOACS)
stripmineStm stm@(Let pat aux loop@(DoLoop _ ForLoop {} _)) =
stripmineStm stm@(Let pat aux loop@(Loop _ ForLoop {} _)) =
case nums of
(n : _) -> stripmine n pat loop
_ -> pure $ oneStm stm
Expand Down Expand Up @@ -256,7 +256,7 @@ fwdLoop pat aux loop =

let pat' = pat <> Pat saved_pats
val_pats' = val_pats <> zip saved_params empty_saved_array
addStm $ Let pat' aux $ DoLoop val_pats' form body'
addStm $ Let pat' aux $ Loop val_pats' form body'

-- | Construct a loop value-pattern for the adjoint of the
-- given variable.
Expand Down Expand Up @@ -420,7 +420,7 @@ revLoop diffStms pat loop =
adjs' <-
letTupExp "loop_adj" $
substituteNames (restore_true_deps <> var_array_substs) $
DoLoop val_pat_adjs_list form' body_adj
Loop val_pat_adjs_list form' body_adj
let (loop_res_adjs, loop_free_var_val_adjs) =
splitAt (length $ loopRes loop_adjs) adjs'
(loop_free_adjs, loop_var_val_adjs) =
Expand Down
10 changes: 5 additions & 5 deletions src/Futhark/Analysis/Interference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ analyseStm lumap inuse0 stm =
(namesToList $ inuse_outside <> inuse <> lus <> last_use_mems)
)

-- We conservatively treat all memory arguments to a DoLoop to
-- We conservatively treat all memory arguments to a Loop to
-- interfere with each other, as well as anything used inside the
-- loop. This could potentially be improved by looking at the
-- interference computed by the loop body wrt. the loop arguments, but
Expand Down Expand Up @@ -115,7 +115,7 @@ analyseExp lumap inuse_outside expr =
fmap mconcat $
mapM (analyseBody lumap inuse_outside) $
defbody : map caseBody cases
DoLoop merge _ body ->
Loop merge _ body ->
analyseLoopParams merge <$> analyseBody lumap inuse_outside body
Op (Inner (SegOp segop)) -> do
analyseSegOp lumap inuse_outside segop
Expand Down Expand Up @@ -283,7 +283,7 @@ memSizes stms =
$ kernelBodyStms body
memSizesExp (Match _ cases defbody _) = do
mconcat <$> mapM (memSizes . bodyStms) (defbody : map caseBody cases)
memSizesExp (DoLoop _ _ body) =
memSizesExp (Loop _ _ body) =
memSizes $ bodyStms body
memSizesExp _ = pure mempty

Expand All @@ -300,7 +300,7 @@ memSpaces stms =
foldMap getSpacesStm $ kernelBodyStms $ segBody segop
getSpacesStm (Let _ _ (Match _ cases defbody _)) =
foldMap (foldMap getSpacesStm . bodyStms) $ defbody : map caseBody cases
getSpacesStm (Let _ _ (DoLoop _ _ body)) =
getSpacesStm (Let _ _ (Loop _ _ body)) =
foldMap getSpacesStm (bodyStms body)
getSpacesStm _ = mempty

Expand All @@ -322,7 +322,7 @@ analyseGPU' lumap stms =
inScopeOf stm $
mconcat
<$> mapM (analyseGPU' lumap . bodyStms) (defbody : map caseBody cases)
helper stm@Let {stmExp = DoLoop merge _ body} =
helper stm@Let {stmExp = Loop merge _ body} =
fmap (analyseLoopParams merge) . inScopeOf stm $
analyseGPU' lumap $
bodyStms body
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/Analysis/LastUse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ lastUseExp (Match _ cases body _) used_nms = do
let used_nms' = used_cases <> body_used_nms
(_, last_used_arrs) <- lastUsedInNames used_nms $ free_in_body <> free_in_cases
pure (lutab_cases <> lutab', last_used_arrs, used_nms')
lastUseExp (DoLoop var_ses lf body) used_nms0 = inScopeOf lf $ do
lastUseExp (Loop var_ses lf body) used_nms0 = inScopeOf lf $ do
free_in_body <- aliasTransitiveClosure $ freeIn body
-- compute the aliasing transitive closure of initializers that are not last-uses
var_inis <- catMaybes <$> mapM (initHelper (free_in_body <> used_nms0)) var_ses
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/Analysis/MemAlias.hs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ analyzeStm m (Let pat _ (Match _ cases defbody _)) = do
& mapMaybe (filterFun m')
& foldr (uncurry addAlias) m'
& pure
analyzeStm m (Let pat _ (DoLoop params _ body)) = do
analyzeStm m (Let pat _ (Loop params _ body)) = do
let m_init =
map snd params
& zip (patNames pat)
Expand Down
8 changes: 4 additions & 4 deletions src/Futhark/Analysis/Metrics.hs
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ stmMetrics = expMetrics . stmExp
expMetrics :: OpMetrics (Op rep) => Exp rep -> MetricsM ()
expMetrics (BasicOp op) =
seen "BasicOp" >> basicOpMetrics op
expMetrics (DoLoop _ ForLoop {} body) =
inside "DoLoop" $ seen "ForLoop" >> bodyMetrics body
expMetrics (DoLoop _ WhileLoop {} body) =
inside "DoLoop" $ seen "WhileLoop" >> bodyMetrics body
expMetrics (Loop _ ForLoop {} body) =
inside "Loop" $ seen "ForLoop" >> bodyMetrics body
expMetrics (Loop _ WhileLoop {} body) =
inside "Loop" $ seen "WhileLoop" >> bodyMetrics body
expMetrics (Match _ [Case [Just (BoolValue True)] tb] fb _) =
inside "If" $ do
inside "True" $ bodyMetrics tb
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/Analysis/UsageTable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ usageInExp (Apply _ args _ _) =
| (arg, d) <- args,
d == Consume
]
usageInExp e@DoLoop {} =
usageInExp e@Loop {} =
foldMap consumedUsage $ namesToList $ consumedInExp e
usageInExp (Match _ cases defbody _) =
foldMap (usageInBody . caseBody) cases <> usageInBody defbody
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/CodeGen/ImpGen.hs
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ defCompileExp pat (Apply fname args _ _) = do
(Var v, Mem {}) -> pure $ Just $ Imp.MemArg v
_ -> pure Nothing
defCompileExp pat (BasicOp op) = defCompileBasicOp pat op
defCompileExp pat (DoLoop merge form body) = do
defCompileExp pat (Loop merge form body) = do
attrs <- askAttrs
when ("unroll" `inAttrs` attrs) $
warn (noLoc :: SrcLoc) [] "#[unroll] on loop with unknown number of iterations." -- FIXME: no location.
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/CodeGen/ImpGen/GPU/Group.hs
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ segOpSizes = onStms
S.singleton $ arrayDims $ patElemType pe
onStm (Let _ _ (Match _ cases defbody _)) =
foldMap (onStms . bodyStms . caseBody) cases <> onStms (bodyStms defbody)
onStm (Let _ _ (DoLoop _ _ body)) =
onStm (Let _ _ (Loop _ _ body)) =
onStms (bodyStms body)
onStm _ = mempty

Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/IR/Aliases.hs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ instance
where
merge_dec =
case e of
DoLoop merge _ body ->
Loop merge _ body ->
let mergeParamAliases fparam als
| primType (paramType fparam) =
Nothing
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/IR/GPU/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ simplifyKernelOp _ (GPUBody ts body) = do
keepOnGPU _ _ = keepExpOnGPU . stmExp
keepExpOnGPU (BasicOp Index {}) = True
keepExpOnGPU (BasicOp (ArrayLit _ t)) | primType t = True
keepExpOnGPU DoLoop {} = True
keepExpOnGPU Loop {} = True
keepExpOnGPU _ = False

instance TraverseOpStms (Wise GPU) where
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/IR/Mem.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,7 @@ expReturns (BasicOp (FlatUpdate v _ _)) =
Just . pure <$> varReturns v
expReturns (BasicOp op) =
Just . extReturns . staticShapes <$> basicOpType op
expReturns e@(DoLoop merge _ _) = do
expReturns e@(Loop merge _ _) = do
t <- expExtType e
Just <$> zipWithM typeWithDec t (map fst merge)
where
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/IR/Parse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ pApply pr =
pLoop :: PR rep -> Parser (Exp rep)
pLoop pr =
keyword "loop"
$> DoLoop
$> Loop
<*> pLoopParams
<*> pLoopForm
<* keyword "do"
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/IR/Pretty.hs
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ instance PrettyRep rep => Pretty (Exp rep) where
Unsafe -> "apply <unsafe>"
Safe -> "apply"
pretty (Op op) = pretty op
pretty (DoLoop merge form loopbody) =
pretty (Loop merge form loopbody) =
"loop"
<+> braces (commastack $ map pretty params)
<+> equals
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/IR/Prop.hs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ safeExp (BasicOp op) = safeBasicOp op
safeBasicOp Iota {} = True
safeBasicOp Replicate {} = True
safeBasicOp _ = False
safeExp (DoLoop _ _ body) = safeBody body
safeExp (Loop _ _ body) = safeBody body
safeExp (Apply fname _ _ _) =
isBuiltInFunction fname
safeExp (Match _ cases def_case _) =
Expand Down
4 changes: 2 additions & 2 deletions src/Futhark/IR/Prop/Aliases.hs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ expAliases pes (Match _ cases defbody _) =
onBody body = (bodyAliases body, consumedInBody body)
bound = foldMap boundInBody $ defbody : map caseBody cases
expAliases _ (BasicOp op) = basicOpAliases op
expAliases pes (DoLoop merge _ loopbody) =
expAliases pes (Loop merge _ loopbody) =
mutualAliases (bound <> param_names) pes $ do
(p, als) <-
transitive . zip params $ zipWith (<>) arg_aliases (bodyAliases loopbody)
Expand Down Expand Up @@ -166,7 +166,7 @@ consumedInExp (Apply _ args _ _) =
consumeArg _ = mempty
consumedInExp (Match _ cases defbody _) =
foldMap (consumedInBody . caseBody) cases <> consumedInBody defbody
consumedInExp (DoLoop merge form body) =
consumedInExp (Loop merge form body) =
mconcat
( map (subExpAliases . snd) $
filter (unique . paramDeclType . fst) merge
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/IR/Prop/Names.hs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ instance
) =>
FreeIn (Exp rep)
where
freeIn' (DoLoop merge form loopbody) =
freeIn' (Loop merge form loopbody) =
let (params, args) = unzip merge
bound_here =
namesFromList $ M.keys $ scopeOf form <> scopeOfFParams params
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/IR/Prop/TypeOf.hs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ expExtType ::
m [ExtType]
expExtType (Apply _ _ rt _) = pure $ map (fromDecl . declExtTypeOf . fst) rt
expExtType (Match _ _ _ rt) = pure $ map extTypeOf $ matchReturns rt
expExtType (DoLoop merge _ _) =
expExtType (Loop merge _ _) =
pure $ loopExtType $ map fst merge
expExtType (BasicOp op) = staticShapes <$> basicOpType op
expExtType (WithAcc inputs lam) =
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/IR/SOACS.hs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ usesAD prog = any stmUsesAD (progConsts prog) || any funUsesAD (progFuns prog)
lamUsesAD lam
expUsesAD (Match _ cases def_case _) =
any (bodyUsesAD . caseBody) cases || bodyUsesAD def_case
expUsesAD (DoLoop _ _ body) = bodyUsesAD body
expUsesAD (Loop _ _ body) = bodyUsesAD body
expUsesAD (WithAcc _ lam) = lamUsesAD lam
expUsesAD BasicOp {} = False
expUsesAD Apply {} = False
2 changes: 1 addition & 1 deletion src/Futhark/IR/Syntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ data Exp rep
-- body/ is picked.
Match [SubExp] [Case (Body rep)] (Body rep) (MatchDec (BranchType rep))
| -- | @loop {a} = {v} (for i < n|while b) do b@.
DoLoop [(FParam rep, SubExp)] (LoopForm rep) (Body rep)
Loop [(FParam rep, SubExp)] (LoopForm rep) (Body rep)
| -- | Create accumulators backed by the given arrays (which are
-- consumed) and pass them to the lambda, which must return the
-- updated accumulators and possibly some extra values. The
Expand Down
Loading

0 comments on commit 6fa39b5

Please sign in to comment.