diff --git a/src/Futhark/IR/SOACS/Simplify.hs b/src/Futhark/IR/SOACS/Simplify.hs index 73cd63115f..bdeae1b2d8 100644 --- a/src/Futhark/IR/SOACS/Simplify.hs +++ b/src/Futhark/IR/SOACS/Simplify.hs @@ -521,37 +521,49 @@ isMapWithOp pat e -- the data dependencies to see that the "dead" result is not -- actually used for computing one of the live ones. removeDeadReduction :: BottomUpRuleOp (Wise SOACS) -removeDeadReduction (_, used) pat aux (Screma w arrs form) - | Just ([Reduce comm redlam nes], maplam) <- isRedomapSOAC form, - not $ all (`UT.used` used) $ patNames pat, -- Quick/cheap check - let (red_pes, map_pes) = splitAt (length nes) $ patElems pat, - let redlam_deps = dataDependencies $ lambdaBody redlam, - let redlam_res = bodyResult $ lambdaBody redlam, - let redlam_params = lambdaParams redlam, - let used_after = - map snd . filter ((`UT.used` used) . patElemName . fst) $ - zip red_pes redlam_params, - let necessary = - findNecessaryForReturned - (`elem` used_after) - (zip redlam_params $ map resSubExp $ redlam_res <> redlam_res) - redlam_deps, - let alive_mask = map ((`nameIn` necessary) . paramName) redlam_params, - not $ all (== True) (take (length nes) alive_mask) = Simplify $ do - let fixDeadToNeutral lives ne = if lives then Nothing else Just ne - dead_fix = zipWith fixDeadToNeutral alive_mask nes - (used_red_pes, _, used_nes) = - unzip3 . filter (\(_, x, _) -> paramName x `nameIn` necessary) $ - zip3 red_pes redlam_params nes - - let maplam' = removeLambdaResults (take (length nes) alive_mask) maplam - redlam' <- removeLambdaResults (take (length nes) alive_mask) <$> fixLambdaParams redlam (dead_fix ++ dead_fix) - - auxing aux $ - letBind (Pat $ used_red_pes ++ map_pes) $ - Op $ - Screma w arrs $ - redomapSOAC [Reduce comm redlam' used_nes] maplam' +removeDeadReduction (_, used) pat aux (Screma w arrs form) = + case isRedomapSOAC form of + Just ([Reduce comm redlam rednes], maplam) -> + let mkOp lam nes' maplam' = redomapSOAC [Reduce comm lam nes'] maplam' + in removeDeadReduction' redlam rednes maplam mkOp + _ -> + case isScanomapSOAC form of + Just ([Scan scanlam nes], maplam) -> + let mkOp lam nes' maplam' = scanomapSOAC [Scan lam nes'] maplam' + in removeDeadReduction' scanlam nes maplam mkOp + _ -> Skip + where + removeDeadReduction' redlam nes maplam mkOp + | not $ all (`UT.used` used) $ patNames pat, -- Quick/cheap check + let (red_pes, map_pes) = splitAt (length nes) $ patElems pat, + let redlam_deps = dataDependencies $ lambdaBody redlam, + let redlam_res = bodyResult $ lambdaBody redlam, + let redlam_params = lambdaParams redlam, + let used_after = + map snd . filter ((`UT.used` used) . patElemName . fst) $ + zip red_pes redlam_params, + let necessary = + findNecessaryForReturned + (`elem` used_after) + (zip redlam_params $ map resSubExp $ redlam_res <> redlam_res) + redlam_deps, + let alive_mask = map ((`nameIn` necessary) . paramName) redlam_params, + not $ all (== True) (take (length nes) alive_mask) = Simplify $ do + let fixDeadToNeutral lives ne = if lives then Nothing else Just ne + dead_fix = zipWith fixDeadToNeutral alive_mask nes + (used_red_pes, _, used_nes) = + unzip3 . filter (\(_, x, _) -> paramName x `nameIn` necessary) $ + zip3 red_pes redlam_params nes + + let maplam' = removeLambdaResults (take (length nes) alive_mask) maplam + redlam' <- removeLambdaResults (take (length nes) alive_mask) <$> fixLambdaParams redlam (dead_fix ++ dead_fix) + + auxing aux $ + letBind (Pat $ used_red_pes ++ map_pes) $ + Op $ + Screma w arrs $ + mkOp redlam' used_nes maplam' + removeDeadReduction' _ _ _ _ = Skip removeDeadReduction _ _ _ _ = Skip -- | If we are writing to an array that is never used, get rid of it.