From 2576b2e5fb6fb29853ff266afbfb66fced95046b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 14 Oct 2022 11:42:34 +0200 Subject: [PATCH 01/64] wip --- futhark.cabal | 3 + src/Futhark/CLI/Dev.hs | 2 + src/Futhark/Pass/Flatten.hs | 154 ++++++++++++++++++ src/Futhark/Pass/Flatten/Builtins.hs | 206 +++++++++++++++++++++++++ src/Futhark/Pass/Flatten/Distribute.hs | 149 ++++++++++++++++++ 5 files changed, 514 insertions(+) create mode 100644 src/Futhark/Pass/Flatten.hs create mode 100644 src/Futhark/Pass/Flatten/Builtins.hs create mode 100644 src/Futhark/Pass/Flatten/Distribute.hs diff --git a/futhark.cabal b/futhark.cabal index c8e3a05f73..2e95191112 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -380,6 +380,9 @@ library Futhark.Pass.ExtractKernels.ToGPU Futhark.Pass.ExtractMulticore Futhark.Pass.FirstOrderTransform + Futhark.Pass.Flatten + Futhark.Pass.Flatten.Builtins + Futhark.Pass.Flatten.Distribute Futhark.Pass.KernelBabysitting Futhark.Pass.Simplify Futhark.Passes diff --git a/src/Futhark/CLI/Dev.hs b/src/Futhark/CLI/Dev.hs index 0eab6ba73a..29820665f5 100644 --- a/src/Futhark/CLI/Dev.hs +++ b/src/Futhark/CLI/Dev.hs @@ -47,6 +47,7 @@ import Futhark.Pass.ExplicitAllocations.Seq qualified as Seq import Futhark.Pass.ExtractKernels import Futhark.Pass.ExtractMulticore import Futhark.Pass.FirstOrderTransform +import Futhark.Pass.Flatten (flattenSOACs) import Futhark.Pass.KernelBabysitting import Futhark.Pass.Simplify import Futhark.Passes @@ -571,6 +572,7 @@ commandLineOptions = kernelsPassOption sinkGPU [], kernelsPassOption reduceDeviceSyncs [], typedPassOption soacsProg GPU extractKernels [], + typedPassOption soacsProg GPU flattenSOACs [], typedPassOption soacsProg MC extractMulticore [], iplOption [], allocateOption "a", diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs new file mode 100644 index 0000000000..79d1d81efd --- /dev/null +++ b/src/Futhark/Pass/Flatten.hs @@ -0,0 +1,154 @@ +{-# LANGUAGE TypeFamilies #-} + +-- The idea 
is to perform distribution on one level at a time, and +-- produce "irregular Maps" that can accept and produce irregular +-- arrays. These irregular maps will then be transformed into flat +-- parallelism based on their contents. This is a sensitive detail, +-- but if irregular maps contain only a single Stm, then it is fairly +-- straightforward, as we simply implement flattening rules for every +-- single kind of expression. Of course that is also somewhat +-- inefficient, so we want to support multiple Stms for things like +-- scalar code. +module Futhark.Pass.Flatten (flattenSOACs) where + +import Control.Monad.Reader +import Control.Monad.State +import Data.Bifunctor (bimap, first, second) +import Data.Foldable +import Data.Map qualified as M +import Data.Maybe (fromMaybe) +import Debug.Trace +import Futhark.IR.GPU +import Futhark.IR.SOACS +import Futhark.MonadFreshNames +import Futhark.Pass +import Futhark.Pass.ExtractKernels.BlockedKernel (mkSegSpace, segScan) +import Futhark.Pass.ExtractKernels.ToGPU (scopeForGPU, soacsLambdaToGPU, soacsStmToGPU) +import Futhark.Pass.Flatten.Builtins +import Futhark.Pass.Flatten.Distribute +import Futhark.Tools + +data FlattenEnv = FlattenEnv + +newtype FlattenM a = FlattenM (StateT VNameSource (Reader FlattenEnv) a) + deriving + ( MonadState VNameSource, + MonadFreshNames, + MonadReader FlattenEnv, + Monad, + Functor, + Applicative + ) + +data IrregularRep = IrregularRep + { irregularSegments :: VName, + irregularFlags :: VName, + irregularOffsets :: VName, + iregularElems :: VName + } + +data ResRep + = -- | This variable is represented + -- completely straightforwardly- if it is + -- an array, it is a regular array. + Regular VName + | -- | The representation of an + -- irregular array. 
+ Irregular IrregularRep + +data DistEnv = DistEnv {distResMap :: M.Map ResTag ResRep} + +insertRep :: ResTag -> ResRep -> DistEnv -> DistEnv +insertRep rt rep env = env {distResMap = M.insert rt rep $ distResMap env} + +insertReps :: [(ResTag, ResRep)] -> DistEnv -> DistEnv +insertReps = flip $ foldl (flip $ uncurry insertRep) + +instance Monoid DistEnv where + mempty = DistEnv mempty + +instance Semigroup DistEnv where + DistEnv x <> DistEnv y = DistEnv (x <> y) + +flagsAndElems :: DistEnv -> [DistInput] -> (Maybe (VName, VName), [VName]) +flagsAndElems env [] = (Nothing, []) +flagsAndElems env (DistInputFree v _ : vs) = + second (v :) $ flagsAndElems env vs +flagsAndElems env (DistInput rt _ : vs) = + case M.lookup rt $ distResMap env of + Just (Regular v') -> + second (v' :) $ flagsAndElems env vs + Just (Irregular (IrregularRep _ flags offsets elems)) -> + bimap (mplus $ Just (flags, offsets)) (elems :) $ flagsAndElems env vs + _ -> + error "flagsAndElems: nope" + +transformDistStm :: DistEnv -> DistStm -> Builder GPU DistEnv +transformDistStm env (DistStm inps res stm) = + case stm of + Let _ _ (BasicOp (Iota (Var n) x s Int64)) + | Just (DistInputFree ns _) <- lookup n inps -> do + let ~[DistResult rt _] = res + (flags, offsets, elems) <- doSegIota ns + let rep = Irregular $ IrregularRep ns flags offsets elems + pure $ insertRep rt rep env + Let _ _ (Op (Screma w arrs form)) + | Just reds <- isReduceSOAC form, + Just arrs' <- mapM (`lookup` inps) arrs, + (Just (flags, offsets), elems) <- flagsAndElems env arrs' -> do + elems' <- genSegRed flags offsets elems $ singleReduce reds + pure $ insertReps (zip (map distResTag res) (map Regular elems')) env + _ -> error $ "Unhandled:\n" ++ prettyString stm + +transformDistributed :: Distributed -> Builder GPU () +transformDistributed (Distributed dstms resmap) = do + env <- foldM transformDistStm mempty dstms + forM_ (M.toList resmap) $ \(rt, v) -> do + case M.lookup rt $ distResMap env of + Just (Regular v') -> 
letBindNames [v] $ BasicOp $ SubExp $ Var v' + Just Irregular {} -> error $ "Result is irregular: " ++ prettyString v + Nothing -> error $ "Missing result binding: " ++ prettyString v + +transformStm :: Scope SOACS -> Stm SOACS -> PassM (Stms GPU) +transformStm scope (Let pat _ (Op (Screma w arrs form))) + | Just lam <- isMapSOAC form = do + let distributed = distributeMap scope pat w arrs lam + m = transformDistributed distributed + traceM $ prettyString distributed + runReaderT (runBuilder_ m) scope +transformStm scope stm = pure $ oneStm $ soacsStmToGPU stm + +transformStms :: Scope SOACS -> Stms SOACS -> PassM (Stms GPU) +transformStms scope stms = + fold <$> traverse (transformStm (scope <> scopeOf stms)) stms + +transformFunDef :: Scope SOACS -> FunDef SOACS -> PassM (FunDef GPU) +transformFunDef consts_scope fd = do + let FunDef + { funDefBody = Body () stms res, + funDefParams = fparams, + funDefRetType = rettype + } = fd + stms' <- transformStms (consts_scope <> scopeOfFParams fparams) stms + pure $ + fd + { funDefBody = Body () stms' res, + funDefRetType = rettype, + funDefParams = fparams + } + +transformProg :: Prog SOACS -> PassM (Prog GPU) +transformProg prog = do + consts' <- transformStms mempty $ progConsts prog + funs' <- mapM (transformFunDef $ scopeOf (progConsts prog)) $ progFuns prog + pure $ prog {progConsts = consts', progFuns = flatteningBuiltins <> funs'} + +-- | Transform a SOACS program to a GPU program, using flattening. 
+flattenSOACs :: Pass SOACS GPU +flattenSOACs = + Pass + { passName = "flatten", + passDescription = "Perform full flattening", + passFunction = transformProg + } +{-# NOINLINE flattenSOACs #-} diff --git a/src/Futhark/Pass/Flatten/Builtins.hs b/src/Futhark/Pass/Flatten/Builtins.hs new file mode 100644 index 0000000000..95f82a20e0 --- /dev/null +++ b/src/Futhark/Pass/Flatten/Builtins.hs @@ -0,0 +1,206 @@ +{-# LANGUAGE TypeFamilies #-} + +module Futhark.Pass.Flatten.Builtins + ( flatteningBuiltins, + genSegScan, + genSegRed, + doSegIota, + ) +where + +import Control.Monad.Reader +import Control.Monad.State.Strict +import Data.Foldable +import Data.Map qualified as M +import Data.Maybe (fromMaybe) +import Debug.Trace +import Futhark.IR.GPU +import Futhark.IR.SOACS +import Futhark.MonadFreshNames +import Futhark.Pass +import Futhark.Pass.ExtractKernels.BlockedKernel (mkSegSpace, segScan) +import Futhark.Pass.ExtractKernels.ToGPU + ( scopeForGPU, + soacsLambdaToGPU, + soacsStmToGPU, + ) +import Futhark.Pass.Flatten.Distribute +import Futhark.Tools + +segIotaName, segRepName :: Name +segIotaName = "builtin#segiota" +segRepName = "builtin#segrep" + +genScan :: String -> SubExp -> Lambda GPU -> [SubExp] -> [VName] -> Builder GPU [VName] +genScan desc w lam nes arrs = do + pat <- fmap Pat $ forM (lambdaReturnType lam) $ \t -> + PatElem <$> newVName desc <*> pure (arrayOf t (Shape [w]) NoUniqueness) + let op = SegBinOp Commutative lam nes mempty + map_lam <- mkIdentityLambda $ lambdaReturnType lam + addStms =<< segScan lvl pat mempty w [op] map_lam arrs [] [] + pure $ patNames pat + where + lvl = SegThread SegNoVirt Nothing + +segScanLambda :: + (MonadBuilder m, BranchType (Rep m) ~ ExtType, LParamInfo (Rep m) ~ Type) => + Lambda (Rep m) -> + m (Lambda (Rep m)) +segScanLambda lam = do + x_flag_p <- newParam "x_flag" $ Prim Bool + y_flag_p <- newParam "y_flag" $ Prim Bool + let ts = lambdaReturnType lam + (xps, yps) = splitAt (length ts) $ lambdaParams lam + mkLambda 
([x_flag_p] ++ xps ++ [y_flag_p] ++ yps) $ + bodyBind + =<< eBody + [ eBinOp LogOr (eParam x_flag_p) (eParam y_flag_p), + eIf + (eParam y_flag_p) + (eBody (map eParam yps)) + (pure $ lambdaBody lam) + ] + +genSegScan :: String -> Lambda GPU -> [SubExp] -> VName -> [VName] -> Builder GPU [VName] +genSegScan desc lam nes flags arrs = do + w <- arraySize 0 <$> lookupType flags + lam' <- segScanLambda lam + drop 1 <$> genScan desc w lam' (constant False : nes) (flags : arrs) + +genPrefixSum :: String -> VName -> Builder GPU VName +genPrefixSum desc ns = do + w <- arraySize 0 <$> lookupType ns + add_lam <- binOpLambda (Add Int64 OverflowUndef) int64 + head <$> genScan desc w add_lam [intConst Int64 0] [ns] + +genSegPrefixSum :: String -> VName -> VName -> Builder GPU VName +genSegPrefixSum desc flags ns = do + add_lam <- binOpLambda (Add Int64 OverflowUndef) int64 + head <$> genSegScan desc add_lam [intConst Int64 0] flags [ns] + +genScatter :: VName -> VName -> SubExp -> Builder GPU (Exp GPU) +genScatter dest is v = do + n <- arraySize 0 <$> lookupType is + m <- arraySize 0 <$> lookupType dest + gtid <- newVName "gtid" + space <- mkSegSpace [(gtid, n)] + v_t <- subExpType v + (res, stms) <- collectStms $ localScope (scopeOfSegSpace space) $ do + i <- letSubExp "i" =<< eIndex is (eSubExp $ Var gtid) + pure $ WriteReturns mempty (Shape [m]) dest [(Slice [DimFix i], v)] + let kbody = KernelBody () stms [res] + pure $ Op $ SegOp $ SegMap (SegThread SegNoVirt Nothing) space [v_t] kbody + +genTabulate :: SubExp -> (SubExp -> Builder GPU [SubExp]) -> Builder GPU (Exp GPU) +genTabulate w m = do + gtid <- newVName "gtid" + space <- mkSegSpace [(gtid, w)] + ((res, ts), stms) <- collectStms $ localScope (scopeOfSegSpace space) $ do + ses <- m $ Var gtid + ts <- mapM subExpType ses + pure (map (Returns ResultMaySimplify mempty) ses, ts) + let kbody = KernelBody () stms res + pure $ Op $ SegOp $ SegMap (SegThread SegNoVirt Nothing) space ts kbody + +genFlags :: SubExp -> VName -> 
Builder GPU VName +genFlags m offsets = do + flags_allfalse <- + letExp "flags_allfalse" . BasicOp $ + Replicate (Shape [m]) (constant False) + letExp "flags" =<< genScatter flags_allfalse offsets (constant True) + +genSegRed :: VName -> VName -> [VName] -> Reduce SOACS -> Builder GPU [VName] +genSegRed flags offsets elems red = do + scanned <- + genSegScan + "red" + (soacsLambdaToGPU $ redLambda red) + (redNeutral red) + flags + elems + num_segments <- arraySize 0 <$> lookupType offsets + letTupExp "segred" <=< genTabulate num_segments $ \i -> do + next_start <- + letSubExp "next_start" =<< eIndex offsets (toExp (pe64 i)) + this_end <- + letSubExp "this_end" =<< toExp (pe64 next_start - 1) + mapM (letSubExp "res" <=< (`eIndex` eSubExp this_end)) scanned + +genSegIota :: VName -> Builder GPU (VName, VName, VName) +genSegIota ns = do + n <- arraySize 0 <$> lookupType ns + is_empty <- letSubExp "is_empty" =<< toExp (pe64 n .==. 0) + offsets <- genPrefixSum "offsets" ns + m <- + letSubExp "m" + =<< eIf + (eSubExp is_empty) + (eBody [eSubExp $ intConst Int64 0]) + (eBody [eLast offsets]) + flags <- genFlags m offsets + ones <- letExp "ones" $ BasicOp $ Replicate (Shape [m]) one + iotas <- genSegPrefixSum "iotas" flags ones + res <- letExp "res" <=< genTabulate m $ \i -> do + x <- letSubExp "x" =<< eIndex iotas (eSubExp i) + letTupExp' "xm1" $ BasicOp $ BinOp (Sub Int64 OverflowUndef) x one + pure (flags, offsets, res) + where + one = intConst Int64 1 + +buildingBuiltin :: Builder GPU (FunDef GPU) -> FunDef GPU +buildingBuiltin m = fst $ evalState (runBuilderT m mempty) blankNameSource + +segIotaBuiltin :: FunDef GPU +segIotaBuiltin = buildingBuiltin $ do + np <- newParam "n" $ Prim int64 + nsp <- newParam "ns" $ Array int64 (Shape [Var (paramName np)]) Nonunique + body <- + localScope (scopeOfFParams [np, nsp]) . 
buildBody_ $ do + (flags, offsets, res) <- genSegIota (paramName nsp) + m <- arraySize 0 <$> lookupType res + pure $ subExpsRes [m, Var flags, Var offsets, Var res] + pure + FunDef + { funDefEntryPoint = Nothing, + funDefAttrs = mempty, + funDefName = segIotaName, + funDefRetType = + [ Prim int64, + Array Bool (Shape [Ext 0]) Unique, + Array int64 (Shape [Free $ Var $ paramName np]) Unique, + Array int64 (Shape [Ext 0]) Unique + ], + funDefParams = [np, nsp], + funDefBody = body + } + +-- | Builtin functions used in flattening. Must be prepended to a +-- program that is transformed by flattening. The intention is to +-- avoid the code explosion that would result if we inserted +-- primitives everywhere. +flatteningBuiltins :: [FunDef GPU] +flatteningBuiltins = [segIotaBuiltin] + +-- | Perform a segmented iota. Returns flags,offsets,data. +doSegIota :: VName -> Builder GPU (VName, VName, VName) +doSegIota ns = do + ns_t <- lookupType ns + let n = arraySize 0 ns_t + m <- newVName "m" + flags <- newVName "segiota_flags" + offsets <- newVName "segiota_offsets" + elems <- newVName "segiota_elems" + let args = [(n, Prim int64), (Var ns, ns_t)] + restype = + fromMaybe (error "doSegIota: bad application") $ + applyRetType + (funDefRetType segIotaBuiltin) + (funDefParams segIotaBuiltin) + args + letBindNames [m, flags, offsets, elems] $ + Apply + (funDefName segIotaBuiltin) + [(n, Observe), (Var ns, Observe)] + restype + (Safe, mempty, mempty) + pure (flags, offsets, elems) diff --git a/src/Futhark/Pass/Flatten/Distribute.hs b/src/Futhark/Pass/Flatten/Distribute.hs new file mode 100644 index 0000000000..f357181cb3 --- /dev/null +++ b/src/Futhark/Pass/Flatten/Distribute.hs @@ -0,0 +1,149 @@ +module Futhark.Pass.Flatten.Distribute + ( distributeMap, + ResMap, + Distributed (..), + DistStm (..), + DistInput (..), + DistType (..), + DistResult (..), + ResTag, + ) +where + +import Control.Monad.Reader +import Control.Monad.State +import Data.Bifunctor (second) +import Data.List 
qualified as L +import Data.Map qualified as M +import Data.Maybe (mapMaybe) +import Futhark.IR.GPU +import Futhark.IR.SOACS +import Futhark.MonadFreshNames +import Futhark.Util.Pretty + +newtype ResTag = ResTag Int + deriving (Eq, Ord, Show) + +-- | Something that is mapped. +data DistInput = DistInputFree VName Type | DistInput ResTag Type + deriving (Eq, Ord, Show) + +distInputType :: DistInput -> Type +distInputType (DistInputFree _ t) = t +distInputType (DistInput _ t) = t + +data DistType + = DistType + SubExp + -- ^ Outer regular size. + Rank + -- ^ Irregular dimensions on top (but after the leading regular + -- size). + Type + -- ^ The regular "element type" - in the worst case, at least a + -- scalar. + deriving (Eq, Ord, Show) + +data DistResult = DistResult {distResTag :: ResTag, distResType :: DistType} + deriving (Eq, Ord, Show) + +data DistStm = DistStm + { distStmInputs :: [(VName, DistInput)], + distStmResult :: [DistResult], + distStm :: Stm SOACS + } + deriving (Eq, Ord, Show) + +type ResMap = M.Map ResTag VName + +data Distributed = Distributed [DistStm] ResMap + deriving (Eq, Ord, Show) + +instance Pretty ResTag where + pretty (ResTag x) = "r" <> pretty x + +instance Pretty DistInput where + pretty (DistInputFree v _) = pretty v + pretty (DistInput rt _) = pretty rt + +instance Pretty DistType where + pretty (DistType w r t) = + brackets (pretty w) <> pretty r <> pretty t + +instance Pretty DistResult where + pretty (DistResult rt t) = + pretty rt <> colon <+> pretty t + +instance Pretty DistStm where + pretty (DistStm inputs res stm) = + "let" <+> ppTuple' (map pretty res) <+> "=" </> indent 2 stm' + where + res' = "return" <+> ppTuple' (map pretty res) + stm' = + "map" + <+> nestedBlock + "{" + "}" + (stack (map onInput inputs ++ [pretty stm, res'])) + onInput (v, inp) = + "for" + <+> parens (pretty v <> colon <+> pretty (distInputType inp)) + <+> "<-" + <+> pretty inp + +instance Pretty Distributed where + pretty (Distributed stms res) = + stms' </>
res' + where + res' = stack $ map onRes $ M.toList res + stms' = stack $ map pretty stms + onRes (rt, v) = "let" <+> pretty v <+> "=" <+> pretty rt + +resultMap :: [DistStm] -> Pat Type -> Result -> M.Map ResTag VName +resultMap stms pat res = mconcat $ map f stms + where + res_map = zip (map resSubExp res) (patNames pat) + f stm = + foldMap g $ zip (distStmResult stm) (patNames (stmPat (distStm stm))) + g (DistResult rt _, v) = + maybe mempty (M.singleton rt) $ lookup (Var v) res_map + +splitIrregDims :: Names -> Type -> (Rank, Type) +splitIrregDims bound_outside (Array pt shape u) = + let (irreg, reg) = second reverse $ span regDim $ reverse $ shapeDims shape + in (Rank $ length irreg, Array pt (Shape reg) u) + where + regDim (Var v) = v `notNameIn` bound_outside + regDim Constant {} = True +splitIrregDims _ t = (mempty, t) + +distributeMap :: Scope SOACS -> Pat Type -> SubExp -> [VName] -> Lambda SOACS -> Distributed +distributeMap outer_scope map_pat w arrs lam = + let param_inputs = + zipWith paramInput (lambdaParams lam) arrs + (_, stms) = + L.mapAccumL distributeStm (ResTag 0, param_inputs) $ + stmsToList $ + bodyStms $ + lambdaBody lam + in Distributed stms $ resultMap stms map_pat (bodyResult (lambdaBody lam)) + where + bound_outside = namesFromList $ M.keys outer_scope + paramInput p arr = (paramName p, DistInputFree arr $ paramType p) + freeInput avail_inputs v = + (v,) <$> lookup v avail_inputs + patInput tag pe = + (patElemName pe, DistInput tag $ patElemType pe) + distType t = + uncurry (DistType w) $ splitIrregDims bound_outside t + distributeStm (ResTag tag, avail_inputs) stm = + let pat = stmPat stm + new_tags = map ResTag $ take (patSize pat) [tag ..] 
+ avail_inputs' = + avail_inputs <> zipWith patInput new_tags (patElems pat) + stm' = + DistStm + (mapMaybe (freeInput avail_inputs) $ namesToList $ freeIn stm) + (zipWith DistResult new_tags $ map distType $ patTypes pat) + stm + in ((ResTag $ tag + length new_tags, avail_inputs'), stm') From f8712577279eb3cebfebaefd8a1b757a69f0ecba Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 3 Nov 2022 18:52:24 +0100 Subject: [PATCH 02/64] More half-baked work. --- src/Futhark/AD/Rev/Scan.hs | 2 +- src/Futhark/Construct.hs | 10 +- src/Futhark/IR/TypeCheck.hs | 6 +- src/Futhark/Pass/ExtractKernels/ToGPU.hs | 4 + src/Futhark/Pass/Flatten.hs | 193 +++++++++++++++--- src/Futhark/Pass/Flatten/Builtins.hs | 160 +++++++++++++-- src/Futhark/Pass/Flatten/Distribute.hs | 34 +-- tests/flattening/CosminArrayExample.fut | 17 -- tests/flattening/HighlyNestedMap.fut | 41 ---- tests/flattening/IntmRes1.fut | 23 --- tests/flattening/IntmRes2.fut | 30 --- tests/flattening/IntmRes3.fut | 36 ---- tests/flattening/LoopInv1.fut | 24 --- tests/flattening/LoopInv2.fut | 26 --- tests/flattening/LoopInv3.fut | 34 --- tests/flattening/LoopInvReshape.fut | 16 -- tests/flattening/Map-IotaMapReduce.fut | 14 -- tests/flattening/Map-Map-IotaMapReduce.fut | 18 -- tests/flattening/MapIotaReduce.fut | 12 -- tests/flattening/MatrixAddition.fut | 16 -- tests/flattening/SimpleReduce.fut | 11 - tests/flattening/VectorAddition.fut | 10 - tests/flattening/binop.fut | 5 + tests/flattening/flattening-pipeline | 2 - tests/flattening/flattening-test | 11 - tests/flattening/iota-index.fut | 10 + tests/flattening/iota-opaque-index.fut | 9 + tests/flattening/iota-opaque-slice-reduce.fut | 5 + tests/flattening/iota-reduce.fut | 7 + tests/flattening/redomap1.fut | 17 -- tests/flattening/redomap2.fut | 13 -- 31 files changed, 377 insertions(+), 439 deletions(-) delete mode 100644 tests/flattening/CosminArrayExample.fut delete mode 100644 tests/flattening/HighlyNestedMap.fut delete mode 100644 
tests/flattening/IntmRes1.fut delete mode 100644 tests/flattening/IntmRes2.fut delete mode 100644 tests/flattening/IntmRes3.fut delete mode 100644 tests/flattening/LoopInv1.fut delete mode 100644 tests/flattening/LoopInv2.fut delete mode 100644 tests/flattening/LoopInv3.fut delete mode 100644 tests/flattening/LoopInvReshape.fut delete mode 100644 tests/flattening/Map-IotaMapReduce.fut delete mode 100644 tests/flattening/Map-Map-IotaMapReduce.fut delete mode 100644 tests/flattening/MapIotaReduce.fut delete mode 100644 tests/flattening/MatrixAddition.fut delete mode 100644 tests/flattening/SimpleReduce.fut delete mode 100644 tests/flattening/VectorAddition.fut create mode 100644 tests/flattening/binop.fut delete mode 100755 tests/flattening/flattening-pipeline delete mode 100755 tests/flattening/flattening-test create mode 100644 tests/flattening/iota-index.fut create mode 100644 tests/flattening/iota-opaque-index.fut create mode 100644 tests/flattening/iota-opaque-slice-reduce.fut create mode 100644 tests/flattening/iota-reduce.fut delete mode 100644 tests/flattening/redomap1.fut delete mode 100644 tests/flattening/redomap2.fut diff --git a/src/Futhark/AD/Rev/Scan.hs b/src/Futhark/AD/Rev/Scan.hs index 197957d1ea..86c858cabe 100644 --- a/src/Futhark/AD/Rev/Scan.hs +++ b/src/Futhark/AD/Rev/Scan.hs @@ -117,7 +117,7 @@ mkScanFinalMap ops w scan_lam xs ys rs = do ( buildBody_ $ do im1 <- letSubExp "im1" =<< toExp (le64 i - 1) ys_im1 <- forM ys $ \y -> - letSubExp (baseString y <> "_im1") =<< eIndex y (eSubExp im1) + letSubExp (baseString y <> "_im1") =<< eIndex y [eSubExp im1] lam_res <- mapM (letExp "const" . BasicOp . SubExp . resSubExp) diff --git a/src/Futhark/Construct.hs b/src/Futhark/Construct.hs index 2dd5ada11c..335d0b880f 100644 --- a/src/Futhark/Construct.hs +++ b/src/Futhark/Construct.hs @@ -417,11 +417,11 @@ eOutOfBounds arr is = do foldBinOp LogOr (constant False) =<< zipWithM checkDim ws is' -- | The array element at this index. 
-eIndex :: MonadBuilder m => VName -> m (Exp (Rep m)) -> m (Exp (Rep m)) -eIndex arr i = do - i' <- letSubExp "i" =<< i +eIndex :: MonadBuilder m => VName -> [m (Exp (Rep m))] -> m (Exp (Rep m)) +eIndex arr is = do + is' <- mapM (letSubExp "i" =<<) is arr_t <- lookupType arr - pure $ BasicOp $ Index arr $ fullSlice arr_t [DimFix i'] + pure $ BasicOp $ Index arr $ fullSlice arr_t $ map DimFix is' -- | The last element of the given array. eLast :: MonadBuilder m => VName -> m (Exp (Rep m)) @@ -430,7 +430,7 @@ eLast arr = do nm1 <- letSubExp "nm1" . BasicOp $ BinOp (Sub Int64 OverflowUndef) n (intConst Int64 1) - eIndex arr (eSubExp nm1) + eIndex arr [eSubExp nm1] -- | Construct an unspecified value of the given type. eBlank :: MonadBuilder m => Type -> m (Exp (Rep m)) diff --git a/src/Futhark/IR/TypeCheck.hs b/src/Futhark/IR/TypeCheck.hs index a763f2bacf..60a9eea492 100644 --- a/src/Futhark/IR/TypeCheck.hs +++ b/src/Futhark/IR/TypeCheck.hs @@ -723,7 +723,7 @@ checkSubExp (Var ident) = context ("In subexp " <> prettyText ident) $ do lookupType ident checkCerts :: Checkable rep => Certs -> TypeM rep () -checkCerts (Certs cs) = mapM_ (requireI [Prim Unit]) cs +checkCerts = mapM_ lookupType . 
unCerts checkSubExpRes :: Checkable rep => SubExpRes -> TypeM rep Type checkSubExpRes (SubExpRes cs se) = do @@ -1276,8 +1276,8 @@ checkStm :: Stm (Aliases rep) -> TypeM rep a -> TypeM rep a -checkStm stm@(Let pat (StmAux (Certs cs) _ (_, dec)) e) m = do - context "When checking certificates" $ mapM_ (requireI [Prim Unit]) cs +checkStm stm@(Let pat (StmAux cs _ (_, dec)) e) m = do + context "When checking certificates" $ checkCerts cs context "When checking expression annotation" $ checkExpDec dec context ("When matching\n" <> message " " pat <> "\nwith\n" <> message " " e) $ matchPat pat e diff --git a/src/Futhark/Pass/ExtractKernels/ToGPU.hs b/src/Futhark/Pass/ExtractKernels/ToGPU.hs index 83ff454643..4dd9b507f8 100644 --- a/src/Futhark/Pass/ExtractKernels/ToGPU.hs +++ b/src/Futhark/Pass/ExtractKernels/ToGPU.hs @@ -5,6 +5,7 @@ module Futhark.Pass.ExtractKernels.ToGPU segThread, soacsLambdaToGPU, soacsStmToGPU, + soacsExpToGPU, scopeForGPU, scopeForSOACs, injectSOACS, @@ -75,6 +76,9 @@ injectSOACS f = soacsStmToGPU :: Stm SOACS -> Stm GPU soacsStmToGPU = runIdentity . rephraseStm (injectSOACS OtherOp) +soacsExpToGPU :: Exp SOACS -> Exp GPU +soacsExpToGPU = runIdentity . rephraseExp (injectSOACS OtherOp) + soacsLambdaToGPU :: Lambda SOACS -> Lambda GPU soacsLambdaToGPU = runIdentity . 
rephraseLambda (injectSOACS OtherOp) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 79d1d81efd..9ce2627029 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -15,6 +15,7 @@ import Control.Monad.Reader import Control.Monad.State import Data.Bifunctor (bimap, first, second) import Data.Foldable +import Data.List.NonEmpty qualified as NE import Data.Map qualified as M import Data.Maybe (fromMaybe) import Debug.Trace @@ -23,10 +24,12 @@ import Futhark.IR.SOACS import Futhark.MonadFreshNames import Futhark.Pass import Futhark.Pass.ExtractKernels.BlockedKernel (mkSegSpace, segScan) -import Futhark.Pass.ExtractKernels.ToGPU (scopeForGPU, soacsLambdaToGPU, soacsStmToGPU) +import Futhark.Pass.ExtractKernels.ToGPU (scopeForGPU, soacsExpToGPU, soacsLambdaToGPU, soacsStmToGPU) import Futhark.Pass.Flatten.Builtins import Futhark.Pass.Flatten.Distribute import Futhark.Tools +import Futhark.Transform.Rename +import Futhark.Transform.Substitute data FlattenEnv = FlattenEnv @@ -41,7 +44,8 @@ newtype FlattenM a = FlattenM (StateT VNameSource (Reader FlattenEnv) a) ) data IrregularRep = IrregularRep - { irregularSegments :: VName, + { -- | Array of size of each segment, type @[]i64@. + irregularSegments :: VName, irregularFlags :: VName, irregularOffsets :: VName, iregularElems :: VName @@ -56,7 +60,7 @@ data ResRep -- irregular array. 
Irregular IrregularRep -data DistEnv = DistEnv {distResMap :: M.Map ResTag ResRep} +newtype DistEnv = DistEnv {distResMap :: M.Map ResTag ResRep} insertRep :: ResTag -> ResRep -> DistEnv -> DistEnv insertRep rt rep env = env {distResMap = M.insert rt rep $ distResMap env} @@ -70,50 +74,189 @@ instance Monoid DistEnv where instance Semigroup DistEnv where DistEnv x <> DistEnv y = DistEnv (x <> y) +resVar :: ResTag -> DistEnv -> ResRep +resVar rt env = fromMaybe bad $ M.lookup rt $ distResMap env + where + bad = error $ "resVar: unknown tag: " ++ show rt + flagsAndElems :: DistEnv -> [DistInput] -> (Maybe (VName, VName), [VName]) flagsAndElems env [] = (Nothing, []) flagsAndElems env (DistInputFree v _ : vs) = second (v :) $ flagsAndElems env vs flagsAndElems env (DistInput rt _ : vs) = - case M.lookup rt $ distResMap env of - Just (Regular v') -> + case resVar rt env of + Regular v' -> second (v' :) $ flagsAndElems env vs - Just (Irregular (IrregularRep _ flags offsets elems)) -> + Irregular (IrregularRep _ flags offsets elems) -> bimap (mplus $ Just (flags, offsets)) (elems :) $ flagsAndElems env vs - _ -> - error "flagsAndElems: nope" -transformDistStm :: DistEnv -> DistStm -> Builder GPU DistEnv -transformDistStm env (DistStm inps res stm) = - case stm of - Let _ _ (BasicOp (Iota (Var n) x s Int64)) +type Segments = NE.NonEmpty SubExp + +segMap1 :: Segments -> ([SubExp] -> Builder GPU Result) -> Builder GPU (Exp GPU) +segMap1 segments f = do + gtids <- replicateM (length segments) (newVName "gtid") + space <- mkSegSpace $ zip gtids $ toList segments + ((res, ts), stms) <- collectStms $ localScope (scopeOfSegSpace space) $ do + res <- f $ map Var gtids + ts <- mapM (subExpType . 
resSubExp) res + let resToRes (SubExpRes cs se) = Returns ResultMaySimplify cs se + pure (map resToRes res, ts) + let kbody = KernelBody () stms res + pure $ Op $ SegOp $ SegMap (SegThread SegNoVirt Nothing) space ts kbody + where + mkResult (SubExpRes cs se) = Returns ResultMaySimplify cs se + +readInput :: Segments -> DistEnv -> [SubExp] -> [(VName, DistInput)] -> SubExp -> Builder GPU SubExp +readInput _ _ _ _ (Constant x) = pure $ Constant x +readInput segments env is inputs (Var v) = + case lookup v inputs of + Nothing -> pure $ Var v + Just (DistInputFree arr _) -> + letSubExp (baseString v) =<< eIndex arr (map eSubExp is) + Just (DistInput rt _) -> do + case resVar rt env of + Regular arr -> + letSubExp (baseString v) =<< eIndex arr (map eSubExp is) + Irregular (IrregularRep _ flags offsets elems) -> + undefined + +readInputs :: Segments -> DistEnv -> [SubExp] -> [(VName, DistInput)] -> Builder GPU () +readInputs segments env is = mapM_ onInput + where + onInput (v, DistInputFree arr _) = + letBindNames [v] =<< eIndex arr (map eSubExp is) + onInput (v, DistInput rt t) = + case M.lookup rt $ distResMap env of + Just (Regular arr) -> + letBindNames [v] =<< eIndex arr (map eSubExp is) + Just (Irregular (IrregularRep _ flags offsets elems)) -> do + offset <- letSubExp "offset" =<< eIndex offsets (map eSubExp is) + num_elems <- letSubExp "num_elems" =<< toExp (product $ map pe64 $ arrayDims t) + let slice = Slice [DimSlice offset num_elems (intConst Int64 1)] + letBindNames [v] $ BasicOp $ Index elems slice + Nothing -> error $ "readInputs: " <> show rt + +transformScalarStms :: + Segments -> + DistEnv -> + [(VName, DistInput)] -> + [DistResult] -> + Stms SOACS -> + [VName] -> + Builder GPU DistEnv +transformScalarStms segments env inps distres stms res = do + vs <- letTupExp "scalar_dist" <=< renameExp <=< segMap1 segments $ \is -> do + readInputs segments env is inps + addStms $ fmap soacsStmToGPU stms + pure $ subExpsRes $ map Var res + pure $ insertReps (zip 
(map distResTag distres) $ map Regular vs) env + +transformScalarStm :: + Segments -> + DistEnv -> + [(VName, DistInput)] -> + [DistResult] -> + Stm SOACS -> + Builder GPU DistEnv +transformScalarStm segments env inps res stm = + transformScalarStms segments env inps res (oneStm stm) (patNames (stmPat stm)) + +transformDistBasicOp :: + Segments -> + DistEnv -> + ( [(VName, DistInput)], + DistResult, + PatElem Type, + StmAux (), + BasicOp + ) -> + Builder GPU DistEnv +transformDistBasicOp segments env (inps, res, pe, aux, e) = + case e of + BinOp {} -> + scalarCase + CmpOp {} -> + scalarCase + ConvOp {} -> + scalarCase + UnOp {} -> + scalarCase + Assert {} -> + scalarCase + Opaque op se + | Var v <- se, + Just (DistInput rt_in _) <- lookup v inps -> do + -- TODO: actually insert opaques + pure $ insertRep (distResTag res) (resVar rt_in env) env + | otherwise -> + scalarCase + Index arr slice + | null $ sliceDims slice -> + scalarCase + | Just rep <- lookup arr inps -> + case rep of + DistInput arr_rt _ -> + case resVar arr_rt env of + Irregular (IrregularRep arr_ns arr_flags arr_offsets arr_elems) -> do + ns <- letExp "slice_sizes" <=< segMap1 segments $ \is -> do + slice_ns <- mapM (readInput segments env is inps) $ sliceDims slice + fmap varsRes . letTupExp "n" <=< toExp $ product $ map pe64 slice_ns + offsets <- doPrefixSum ns + m <- letSubExp "total_elems" =<< eLast offsets + flags <- genFlags m offsets + elems <- letExp "elems" <=< segMap1 (NE.singleton m) $ \is -> + fmap (subExpsRes . pure) . 
letSubExp "v" + =<< eIndex arr_elems (map eSubExp is) + let rep = Irregular $ IrregularRep ns flags offsets elems + pure $ insertRep (distResTag res) rep env + Iota (Var n) x s Int64 | Just (DistInputFree ns _) <- lookup n inps -> do - let ~[DistResult rt _] = res (flags, offsets, elems) <- doSegIota ns let rep = Irregular $ IrregularRep ns flags offsets elems - pure $ insertRep rt rep env - Let _ _ (Op (Screma w arrs form)) + pure $ insertRep (distResTag res) rep env + _ -> error $ "Unhandled BasicOp:\n" ++ prettyString e + where + scalarCase = + transformScalarStm segments env inps [res] $ + Let (Pat [pe]) aux (BasicOp e) + +transformDistStm :: Segments -> DistEnv -> DistStm -> Builder GPU DistEnv +transformDistStm segments env (DistStm inps res stm) = do + case stm of + Let pat aux (BasicOp e) -> do + let ~[res'] = res + ~[pe] = patElems pat + transformDistBasicOp segments env (inps, res', pe, stmAux stm, e) + Let _ _ (Op (Screma _ arrs form)) | Just reds <- isReduceSOAC form, Just arrs' <- mapM (`lookup` inps) arrs, (Just (flags, offsets), elems) <- flagsAndElems env arrs' -> do elems' <- genSegRed flags offsets elems $ singleReduce reds pure $ insertReps (zip (map distResTag res) (map Regular elems')) env - _ -> error $ "Unhandled:\n" ++ prettyString stm + _ -> error $ "Unhandled Stm:\n" ++ prettyString stm + +distResCerts :: DistEnv -> [DistInput] -> Certs +distResCerts env = Certs . 
map f + where + f (DistInputFree v _) = v + f (DistInput rt _) = case resVar rt env of + Regular v -> v + Irregular {} -> error "resCerts: irregular" -transformDistributed :: Distributed -> Builder GPU () -transformDistributed (Distributed dstms resmap) = do - env <- foldM transformDistStm mempty dstms - forM_ (M.toList resmap) $ \(rt, v) -> do - case M.lookup rt $ distResMap env of - Just (Regular v') -> letBindNames [v] $ BasicOp $ SubExp $ Var v' - Just Irregular {} -> error $ "Result is irregular: " ++ prettyString v - Nothing -> error $ "Missing result binding: " ++ prettyString v +transformDistributed :: Segments -> Distributed -> Builder GPU () +transformDistributed segments (Distributed dstms resmap) = do + env <- foldM (transformDistStm segments) mempty dstms + forM_ (M.toList resmap) $ \(rt, (cs_inps, v)) -> + certifying (distResCerts env cs_inps) $ + case resVar rt env of + Regular v' -> letBindNames [v] $ BasicOp $ SubExp $ Var v' + Irregular {} -> error $ "Result is irregular: " ++ prettyString v transformStm :: Scope SOACS -> Stm SOACS -> PassM (Stms GPU) transformStm scope (Let pat _ (Op (Screma w arrs form))) | Just lam <- isMapSOAC form = do let distributed = distributeMap scope pat w arrs lam - m = transformDistributed distributed + m = transformDistributed (NE.singleton w) distributed traceM $ prettyString distributed runReaderT (runBuilder_ m) scope transformStm scope stm = pure $ oneStm $ soacsStmToGPU stm diff --git a/src/Futhark/Pass/Flatten/Builtins.hs b/src/Futhark/Pass/Flatten/Builtins.hs index 95f82a20e0..1b8f756a36 100644 --- a/src/Futhark/Pass/Flatten/Builtins.hs +++ b/src/Futhark/Pass/Flatten/Builtins.hs @@ -2,9 +2,12 @@ module Futhark.Pass.Flatten.Builtins ( flatteningBuiltins, + genFlags, genSegScan, genSegRed, doSegIota, + doPrefixSum, + doRepIota, ) where @@ -13,23 +16,25 @@ import Control.Monad.State.Strict import Data.Foldable import Data.Map qualified as M import Data.Maybe (fromMaybe) -import Debug.Trace +import Data.Text 
qualified as T import Futhark.IR.GPU import Futhark.IR.SOACS import Futhark.MonadFreshNames -import Futhark.Pass import Futhark.Pass.ExtractKernels.BlockedKernel (mkSegSpace, segScan) import Futhark.Pass.ExtractKernels.ToGPU ( scopeForGPU, soacsLambdaToGPU, soacsStmToGPU, ) -import Futhark.Pass.Flatten.Distribute import Futhark.Tools -segIotaName, segRepName :: Name -segIotaName = "builtin#segiota" -segRepName = "builtin#segrep" +builtinName :: T.Text -> Name +builtinName = nameFromText . ("builtin#" <>) + +segIotaName, repIotaName, prefixSumName :: Name +segIotaName = builtinName "segiota" +repIotaName = builtinName "repiota" +prefixSumName = builtinName "prefixsum" genScan :: String -> SubExp -> Lambda GPU -> [SubExp] -> [VName] -> Builder GPU [VName] genScan desc w lam nes arrs = do @@ -78,16 +83,15 @@ genSegPrefixSum desc flags ns = do add_lam <- binOpLambda (Add Int64 OverflowUndef) int64 head <$> genSegScan desc add_lam [intConst Int64 0] flags [ns] -genScatter :: VName -> VName -> SubExp -> Builder GPU (Exp GPU) -genScatter dest is v = do - n <- arraySize 0 <$> lookupType is +genScatter :: VName -> SubExp -> (SubExp -> Builder GPU (VName, SubExp)) -> Builder GPU (Exp GPU) +genScatter dest n f = do m <- arraySize 0 <$> lookupType dest gtid <- newVName "gtid" space <- mkSegSpace [(gtid, n)] - v_t <- subExpType v - (res, stms) <- collectStms $ localScope (scopeOfSegSpace space) $ do - i <- letSubExp "i" =<< eIndex is (eSubExp $ Var gtid) - pure $ WriteReturns mempty (Shape [m]) dest [(Slice [DimFix i], v)] + ((res, v_t), stms) <- collectStms $ localScope (scopeOfSegSpace space) $ do + (i, v) <- f $ Var gtid + v_t <- subExpType v + pure (WriteReturns mempty (Shape [m]) dest [(Slice [DimFix (Var i)], v)], v_t) let kbody = KernelBody () stms [res] pure $ Op $ SegOp $ SegMap (SegThread SegNoVirt Nothing) space [v_t] kbody @@ -107,7 +111,10 @@ genFlags m offsets = do flags_allfalse <- letExp "flags_allfalse" . 
BasicOp $ Replicate (Shape [m]) (constant False) - letExp "flags" =<< genScatter flags_allfalse offsets (constant True) + n <- arraySize 0 <$> lookupType offsets + letExp "flags" <=< genScatter flags_allfalse n $ \gtid -> do + i <- letExp "i" =<< eIndex offsets [eSubExp gtid] + pure (i, constant True) genSegRed :: VName -> VName -> [VName] -> Reduce SOACS -> Builder GPU [VName] genSegRed flags offsets elems red = do @@ -121,16 +128,16 @@ genSegRed flags offsets elems red = do num_segments <- arraySize 0 <$> lookupType offsets letTupExp "segred" <=< genTabulate num_segments $ \i -> do next_start <- - letSubExp "next_start" =<< eIndex offsets (toExp (pe64 i)) + letSubExp "next_start" =<< eIndex offsets [toExp (pe64 i)] this_end <- letSubExp "this_end" =<< toExp (pe64 next_start - 1) - mapM (letSubExp "res" <=< (`eIndex` eSubExp this_end)) scanned + mapM (letSubExp "res" <=< (`eIndex` [eSubExp this_end])) scanned genSegIota :: VName -> Builder GPU (VName, VName, VName) -genSegIota ns = do - n <- arraySize 0 <$> lookupType ns +genSegIota ks = do + n <- arraySize 0 <$> lookupType ks is_empty <- letSubExp "is_empty" =<< toExp (pe64 n .==. 0) - offsets <- genPrefixSum "offsets" ns + offsets <- genPrefixSum "offsets" ks m <- letSubExp "m" =<< eIf @@ -141,12 +148,45 @@ genSegIota ns = do ones <- letExp "ones" $ BasicOp $ Replicate (Shape [m]) one iotas <- genSegPrefixSum "iotas" flags ones res <- letExp "res" <=< genTabulate m $ \i -> do - x <- letSubExp "x" =<< eIndex iotas (eSubExp i) + x <- letSubExp "x" =<< eIndex iotas [eSubExp i] letTupExp' "xm1" $ BasicOp $ BinOp (Sub Int64 OverflowUndef) x one pure (flags, offsets, res) where one = intConst Int64 1 +genRepIota :: VName -> Builder GPU (VName, VName, VName) +genRepIota ks = do + n <- arraySize 0 <$> lookupType ks + is_empty <- letSubExp "is_empty" =<< toExp (pe64 n .==. 
0) + offsets <- genPrefixSum "offsets" ks + m <- + letSubExp "m" + =<< eIf + (eSubExp is_empty) + (eBody [eSubExp zero]) + (eBody [eLast offsets]) + is <- letExp "is" <=< genTabulate n $ \i -> do + o <- letSubExp "o" =<< eIndex offsets [eSubExp i] + k <- letSubExp "n" =<< eIndex ks [eSubExp i] + letTupExp' "i" + =<< eIf + (toExp (pe64 k .==. 0)) + (eBody [eSubExp negone]) + (eBody [toExp $ pe64 o - pe64 k]) + zeroes <- letExp "zeroes" $ BasicOp $ Replicate (Shape [m]) zero + starts <- + letExp "starts" <=< genScatter zeroes n $ \gtid -> do + i <- letExp "i" =<< eIndex is [eSubExp gtid] + pure (i, gtid) + flags <- letExp "flags" <=< genTabulate m $ \i -> do + x <- letSubExp "x" =<< eIndex starts [eSubExp i] + letTupExp' "nonzero" =<< toExp (pe64 x .>. 0) + res <- genSegPrefixSum "res" flags starts + pure (flags, offsets, res) + where + zero = intConst Int64 0 + negone = intConst Int64 (-1) + buildingBuiltin :: Builder GPU (FunDef GPU) -> FunDef GPU buildingBuiltin m = fst $ evalState (runBuilderT m mempty) blankNameSource @@ -174,12 +214,56 @@ segIotaBuiltin = buildingBuiltin $ do funDefBody = body } +repIotaBuiltin :: FunDef GPU +repIotaBuiltin = buildingBuiltin $ do + np <- newParam "n" $ Prim int64 + nsp <- newParam "ns" $ Array int64 (Shape [Var (paramName np)]) Nonunique + body <- + localScope (scopeOfFParams [np, nsp]) . 
buildBody_ $ do + (flags, offsets, res) <- genRepIota (paramName nsp) + m <- arraySize 0 <$> lookupType res + pure $ subExpsRes [m, Var flags, Var offsets, Var res] + pure + FunDef + { funDefEntryPoint = Nothing, + funDefAttrs = mempty, + funDefName = repIotaName, + funDefRetType = + [ Prim int64, + Array Bool (Shape [Ext 0]) Unique, + Array int64 (Shape [Free $ Var $ paramName np]) Unique, + Array int64 (Shape [Ext 0]) Unique + ], + funDefParams = [np, nsp], + funDefBody = body + } + +prefixSumBuiltin :: FunDef GPU +prefixSumBuiltin = buildingBuiltin $ do + np <- newParam "n" $ Prim int64 + nsp <- newParam "ns" $ Array int64 (Shape [Var (paramName np)]) Nonunique + body <- + localScope (scopeOfFParams [np, nsp]) . buildBody_ $ + varsRes . pure <$> genPrefixSum "res" (paramName nsp) + pure + FunDef + { funDefEntryPoint = Nothing, + funDefAttrs = mempty, + funDefName = prefixSumName, + funDefRetType = + [ Prim int64, + Array int64 (Shape [Free $ Var $ paramName np]) Nonunique + ], + funDefParams = [np, nsp], + funDefBody = body + } + -- | Builtin functions used in flattening. Must be prepended to a -- program that is transformed by flattening. The intention is to -- avoid the code explosion that would result if we inserted -- primitives everywhere. flatteningBuiltins :: [FunDef GPU] -flatteningBuiltins = [segIotaBuiltin] +flatteningBuiltins = [segIotaBuiltin, segIotaBuiltin, prefixSumBuiltin] -- | Perform a segmented iota. Returns flags,offsets,data. 
doSegIota :: VName -> Builder GPU (VName, VName, VName) @@ -204,3 +288,37 @@ doSegIota ns = do restype (Safe, mempty, mempty) pure (flags, offsets, elems) + +doRepIota :: VName -> Builder GPU (VName, VName, VName) +doRepIota ns = do + ns_t <- lookupType ns + let n = arraySize 0 ns_t + m <- newVName "m" + flags <- newVName "repiota_flags" + offsets <- newVName "repiota_offsets" + elems <- newVName "repiota_elems" + let args = [(n, Prim int64), (Var ns, ns_t)] + restype = + fromMaybe (error "doRepIota: bad application") $ + applyRetType + (funDefRetType repIotaBuiltin) + (funDefParams repIotaBuiltin) + args + letBindNames [m, flags, offsets, elems] $ + Apply + (funDefName repIotaBuiltin) + [(n, Observe), (Var ns, Observe)] + restype + (Safe, mempty, mempty) + pure (flags, offsets, elems) + +doPrefixSum :: VName -> Builder GPU VName +doPrefixSum ns = do + ns_t <- lookupType ns + let n = arraySize 0 ns_t + letExp "prefix_sum" $ + Apply + (funDefName prefixSumBuiltin) + [(n, Observe), (Var ns, Observe)] + [toDecl (staticShapes1 ns_t) Unique] + (Safe, mempty, mempty) diff --git a/src/Futhark/Pass/Flatten/Distribute.hs b/src/Futhark/Pass/Flatten/Distribute.hs index f357181cb3..2b7b292263 100644 --- a/src/Futhark/Pass/Flatten/Distribute.hs +++ b/src/Futhark/Pass/Flatten/Distribute.hs @@ -10,15 +10,11 @@ module Futhark.Pass.Flatten.Distribute ) where -import Control.Monad.Reader -import Control.Monad.State import Data.Bifunctor (second) import Data.List qualified as L import Data.Map qualified as M -import Data.Maybe (mapMaybe) -import Futhark.IR.GPU +import Data.Maybe (fromMaybe, mapMaybe) import Futhark.IR.SOACS -import Futhark.MonadFreshNames import Futhark.Util.Pretty newtype ResTag = ResTag Int @@ -54,7 +50,8 @@ data DistStm = DistStm } deriving (Eq, Ord, Show) -type ResMap = M.Map ResTag VName +-- | First element of tuple are certificates for this result. 
+type ResMap = M.Map ResTag ([DistInput], VName) data Distributed = Distributed [DistStm] ResMap deriving (Eq, Ord, Show) @@ -99,14 +96,18 @@ instance Pretty Distributed where stms' = stack $ map pretty stms onRes (rt, v) = "let" <+> pretty v <+> "=" <+> pretty rt -resultMap :: [DistStm] -> Pat Type -> Result -> M.Map ResTag VName -resultMap stms pat res = mconcat $ map f stms +resultMap :: [(VName, DistInput)] -> [DistStm] -> Pat Type -> Result -> ResMap +resultMap avail_inputs stms pat res = mconcat $ map f stms where - res_map = zip (map resSubExp res) (patNames pat) f stm = foldMap g $ zip (distStmResult stm) (patNames (stmPat (distStm stm))) g (DistResult rt _, v) = - maybe mempty (M.singleton rt) $ lookup (Var v) res_map + maybe mempty (M.singleton rt) $ findRes v + findRes v = do + (SubExpRes cs _, pv) <- + L.find ((Var v ==) . resSubExp . fst) $ zip res $ patNames pat + Just (map findCert $ unCerts cs, pv) + findCert v = fromMaybe (DistInputFree v (Prim Unit)) $ lookup v avail_inputs splitIrregDims :: Names -> Type -> (Rank, Type) splitIrregDims bound_outside (Array pt shape u) = @@ -121,12 +122,12 @@ distributeMap :: Scope SOACS -> Pat Type -> SubExp -> [VName] -> Lambda SOACS -> distributeMap outer_scope map_pat w arrs lam = let param_inputs = zipWith paramInput (lambdaParams lam) arrs - (_, stms) = + ((_, avail_inputs), stms) = L.mapAccumL distributeStm (ResTag 0, param_inputs) $ stmsToList $ bodyStms $ lambdaBody lam - in Distributed stms $ resultMap stms map_pat (bodyResult (lambdaBody lam)) + in Distributed stms $ resultMap avail_inputs stms map_pat $ bodyResult $ lambdaBody lam where bound_outside = namesFromList $ M.keys outer_scope paramInput p arr = (paramName p, DistInputFree arr $ paramType p) @@ -141,9 +142,16 @@ distributeMap outer_scope map_pat w arrs lam = new_tags = map ResTag $ take (patSize pat) [tag ..] 
avail_inputs' = avail_inputs <> zipWith patInput new_tags (patElems pat) + free_in_stm = freeIn stm + used_free = mapMaybe (freeInput avail_inputs) $ namesToList free_in_stm + used_free_types = + mapMaybe (freeInput avail_inputs) + . namesToList + . foldMap (freeIn . distInputType . snd) + $ used_free stm' = DistStm - (mapMaybe (freeInput avail_inputs) $ namesToList $ freeIn stm) + (used_free_types <> used_free) (zipWith DistResult new_tags $ map distType $ patTypes pat) stm in ((ResTag $ tag + length new_tags, avail_inputs'), stm') diff --git a/tests/flattening/CosminArrayExample.fut b/tests/flattening/CosminArrayExample.fut deleted file mode 100644 index 2d9af63d6d..0000000000 --- a/tests/flattening/CosminArrayExample.fut +++ /dev/null @@ -1,17 +0,0 @@ --- Problem here is that we need will distribute the map --- let arrs = map (\x -> iota(2*x)) xs --- let arr's = map (\x arr -> reshape( (x,2), arr) $ zip xs arrs --- let res = map(\arr' -> reduce(op(+), 0, arr')) arr's --- == --- input { --- [ 1i64, 2i64, 3i64, 4i64] --- } --- output { --- [1i64, 6i64, 15i64, 28i64] --- } -def main (xs: []i64): []i64 = - map (\(x: i64) -> - let arr = #[unsafe] 0..<(2 * x) - let arr' = #[unsafe] unflatten 2 x arr in - reduce (+) 0 (arr'[0]) + reduce (+) 0 (arr'[1]) - ) xs diff --git a/tests/flattening/HighlyNestedMap.fut b/tests/flattening/HighlyNestedMap.fut deleted file mode 100644 index 42ea0087a5..0000000000 --- a/tests/flattening/HighlyNestedMap.fut +++ /dev/null @@ -1,41 +0,0 @@ --- == --- input { --- [ [ [ [1,2,3], [4,5,6] ] --- , [ [6,7,8], [9,10,11] ] --- ] --- , [ [ [3,2,1], [4,5,6] ] --- , [ [8,7,6], [11,10,9] ] --- ] --- ] --- [ [ [ [4,5,6] , [1,2,3] ] --- , [ [9,10,11], [6,7,8] ] --- ] --- , [ [ [4,5,6] , [3,2,1] ] --- , [ [11,10,9], [8,7,6] ] --- ] --- ] --- } --- output { --- [[[[5, 7, 9], --- [5, 7, 9]], --- [[15, 17, 19], --- [15, 17, 19]]], --- [[[7, 7, 7], --- [7, 7, 7]], --- [[19, 17, 15], --- [19, 17, 15]]]] --- } -def add1 [n] (xs: [n]i32, ys: [n]i32): [n]i32 
= - map2 (+) xs ys - -def add2 [n][m] (xs: [n][m]i32, ys: [n][m]i32): [n][m]i32 = - map add1 (zip xs ys) - -def add3 [n][m][l] (xs: [n][m][l]i32, ys: [n][m][l]i32): [n][m][l]i32 = - map add2 (zip xs ys) - -def add4 (xs: [][][][]i32, ys: [][][][]i32): [][][][]i32 = - map add3 (zip xs ys) - -def main (a: [][][][]i32) (b: [][][][]i32): [][][][]i32 = - add4(a,b) diff --git a/tests/flattening/IntmRes1.fut b/tests/flattening/IntmRes1.fut deleted file mode 100644 index 5c90367bc9..0000000000 --- a/tests/flattening/IntmRes1.fut +++ /dev/null @@ -1,23 +0,0 @@ --- == --- input { --- [ [1,2,3], [4,5,6] --- , [6,7,8], [9,10,11] --- ] --- [1,2,3,4] --- 5 --- } --- output { --- [[7, 8, 9], --- [16, 17, 18], --- [24, 25, 26], --- [33, 34, 35]] --- } -def addToRow [n] (xs: [n]i32, y: i32): [n]i32 = - map (\(x: i32): i32 -> x+y) xs - -def main (xss: [][]i32) (cs: []i32) (y: i32): [][]i32 = - map (\(xs: []i32, c: i32) -> - let y' = y * c + c - let zs = addToRow(xs,y') in - zs - ) (zip xss cs) diff --git a/tests/flattening/IntmRes2.fut b/tests/flattening/IntmRes2.fut deleted file mode 100644 index 8f4f1bd5cd..0000000000 --- a/tests/flattening/IntmRes2.fut +++ /dev/null @@ -1,30 +0,0 @@ --- == --- input { --- [ [ [1,2,3], [4,5,6] ] --- , [ [6,7,8], [9,10,11] ] --- , [ [3,2,1], [4,5,6] ] --- , [ [8,7,6], [11,10,9] ] --- ] --- [1,2,3,4] --- 5 --- } --- output { --- [[[7, 8, 9], --- [10, 11, 12]], --- [[18, 19, 20], --- [21, 22, 23]], --- [[21, 20, 19], --- [22, 23, 24]], --- [[32, 31, 30], --- [35, 34, 33]]] --- } -def addToRow [n] (xs: [n]i32, y: i32): [n]i32 = - map (\(x: i32): i32 -> x+y) xs - -def main (xsss: [][][]i32) (cs: []i32) (y: i32): [][][]i32 = - map (\(xss: [][]i32, c: i32) -> - let y' = y * c + c in - map (\(xs: []i32) -> - addToRow(xs,y') - ) xss - ) (zip xsss cs) diff --git a/tests/flattening/IntmRes3.fut b/tests/flattening/IntmRes3.fut deleted file mode 100644 index 230dbf6405..0000000000 --- a/tests/flattening/IntmRes3.fut +++ /dev/null @@ -1,36 +0,0 @@ --- == --- 
input { --- [ [ [ [1,2,3], [4,5,6] ] --- ] --- , [ [ [6,7,8], [9,10,11] ] --- ] --- , [ [ [3,2,1], [4,5,6] ] --- ] --- , [ [ [8,7,6], [11,10,9] ] --- ] --- ] --- [1,2,3,4] --- 5 --- } --- output { --- [[[[7, 8, 9], --- [10, 11, 12]]], --- [[[18, 19, 20], --- [21, 22, 23]]], --- [[[21, 20, 19], --- [22, 23, 24]]], --- [[[32, 31, 30], --- [35, 34, 33]]]] --- } -def addToRow [n] (xs: [n]i32, y: i32): [n]i32 = - map (\(x: i32): i32 -> x+y) xs - -def main (xssss: [][][][]i32) (cs: []i32) (y: i32): [][][][]i32 = - map (\(xsss: [][][]i32, c: i32) -> - let y' = y * c + c in - map (\(xss: [][]i32) -> - map (\(xs: []i32) -> - addToRow(xs,y') - ) xss - ) xsss - ) (zip xssss cs) diff --git a/tests/flattening/LoopInv1.fut b/tests/flattening/LoopInv1.fut deleted file mode 100644 index c616c16562..0000000000 --- a/tests/flattening/LoopInv1.fut +++ /dev/null @@ -1,24 +0,0 @@ --- == --- input { --- [ [1,2,3], [4,5,6] --- , [6,7,8], [9,10,11] --- , [3,2,1], [4,5,6] --- , [8,7,6], [11,10,9] --- ] --- [1,2,3] --- } --- output { --- [[2, 4, 6], --- [5, 7, 9], --- [7, 9, 11], --- [10, 12, 14], --- [4, 4, 4], --- [5, 7, 9], --- [9, 9, 9], --- [12, 12, 12]] --- } -def addRows [n] (xs: [n]i32, ys: [n]i32): [n]i32 = - map2 (+) xs ys - -def main (xss: [][]i32) (ys: []i32): [][]i32 = - map (\(xs: []i32) -> addRows(xs,ys)) xss diff --git a/tests/flattening/LoopInv2.fut b/tests/flattening/LoopInv2.fut deleted file mode 100644 index 7af5c15362..0000000000 --- a/tests/flattening/LoopInv2.fut +++ /dev/null @@ -1,26 +0,0 @@ --- == --- input { --- [ [ [1,2,3], [4,5,6] ] --- , [ [6,7,8], [9,10,11] ] --- , [ [3,2,1], [4,5,6] ] --- , [ [8,7,6], [11,10,9] ] --- ] --- [1,2,3] --- } --- output { --- [[[2, 4, 6], --- [5, 7, 9]], --- [[7, 9, 11], --- [10, 12, 14]], --- [[4, 4, 4], --- [5, 7, 9]], --- [[9, 9, 9], --- [12, 12, 12]]] --- } -def addRows [n] (xs: [n]i32, ys: [n]i32): [n]i32 = - map2 (+) xs ys - -def main (xsss: [][][]i32) (ys: []i32): [][][]i32 = - map (\(xss: [][]i32) -> - map (\(xs: []i32) -> 
addRows(xs,ys)) xss - ) xsss diff --git a/tests/flattening/LoopInv3.fut b/tests/flattening/LoopInv3.fut deleted file mode 100644 index 3cffc4dfe9..0000000000 --- a/tests/flattening/LoopInv3.fut +++ /dev/null @@ -1,34 +0,0 @@ --- == --- input { --- [ [ [ [1,2,3], [4,5,6] ] --- ] --- , [ [ [6,7,8], [9,10,11] ] --- ] --- , [ [ [3,2,1], [4,5,6] ] --- ] --- , [ [ [8,7,6], [11,10,9] ] --- ] --- ] --- [1,2,3] --- } --- output { --- [[[[2, 4, 6], --- [5, 7, 9]]], --- [[[7, 9, 11], --- [10, 12, 14]]], --- [[[4, 4, 4], --- [5, 7, 9]]], --- [[[9, 9, 9], --- [12, 12, 12]]]] --- } -def addRows [n] (xs: [n]i32, ys: [n]i32): [n]i32 = - map2 (+) xs ys - -def main (xssss: [][][][]i32) (ys: []i32): [][][][]i32 = - map (\(xsss: [][][]i32) -> - map (\(xss: [][]i32) -> - map (\(xs: []i32) -> - addRows(xs,ys) - ) xss - ) xsss - ) xssss diff --git a/tests/flattening/LoopInvReshape.fut b/tests/flattening/LoopInvReshape.fut deleted file mode 100644 index 4f365889df..0000000000 --- a/tests/flattening/LoopInvReshape.fut +++ /dev/null @@ -1,16 +0,0 @@ --- This example presents difficulty for me right now, but also has a --- large potential for improvement later on. --- --- we could turn it into: --- --- let []i32 bettermain ([]i32 xs, [#n]i32 ys, [#n]i32 zs, [#n]i32 is, [#n]i32 js) = --- map (\i32 (i32 y, i32 z, i32 i, i32 j) -> --- xs[i*z + j] --- , zip(ys,zs,is,js)) - -def main [n][m] (xs: [m]i32, ys: [n]i64, zs: [n]i64, is: [n]i32, js: [n]i32): []i32 = - map (\(y: i64, z: i64, i: i32, j: i32): i32 -> - #[unsafe] - let tmp = unflatten y z xs - in tmp[i,j] - ) (zip4 ys zs is js) diff --git a/tests/flattening/Map-IotaMapReduce.fut b/tests/flattening/Map-IotaMapReduce.fut deleted file mode 100644 index ea6377f1f8..0000000000 --- a/tests/flattening/Map-IotaMapReduce.fut +++ /dev/null @@ -1,14 +0,0 @@ --- == --- input { --- [2,3,4] --- [8,3,2] --- } --- output { --- [8,9,12] --- } -def main [n] (xs: [n]i32) (ys: [n]i32): []i32 = - map (\(x: i32, y: i32): i32 -> - let tmp1 = 0.. 
- map (\(x: i32): i32 -> - let tmp1 = map i32.i64(iota(i64.i32 x)) - let tmp2 = map (*y) tmp1 in - reduce (+) 0 tmp2 - ) xs - ) (zip xss ys ) diff --git a/tests/flattening/MapIotaReduce.fut b/tests/flattening/MapIotaReduce.fut deleted file mode 100644 index e6840d5bea..0000000000 --- a/tests/flattening/MapIotaReduce.fut +++ /dev/null @@ -1,12 +0,0 @@ --- == --- input { --- [1,2,3,4] --- } --- output { --- [0, 1, 3, 6] --- } -def main (xs: []i32): []i32 = - map (\(x: i32): i32 -> - let tmp = 0.. - reduce (+) 0 xs - ) xss diff --git a/tests/flattening/VectorAddition.fut b/tests/flattening/VectorAddition.fut deleted file mode 100644 index 0b9445729b..0000000000 --- a/tests/flattening/VectorAddition.fut +++ /dev/null @@ -1,10 +0,0 @@ --- == --- input { --- [1,2,3,4] --- [5,6,7,8] --- } --- output { --- [6,8,10,12] --- } -def main (xs: []i32) (ys: []i32): []i32 = - map2 (+) xs ys diff --git a/tests/flattening/binop.fut b/tests/flattening/binop.fut new file mode 100644 index 0000000000..6496c9f5c3 --- /dev/null +++ b/tests/flattening/binop.fut @@ -0,0 +1,5 @@ +-- == +-- input { [1,2,3] [4,5,6] } +-- output { [5,7,9] } + +def main = map2 (i32.+) diff --git a/tests/flattening/flattening-pipeline b/tests/flattening/flattening-pipeline deleted file mode 100755 index ed91df97eb..0000000000 --- a/tests/flattening/flattening-pipeline +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/sh -futhark -s --flattening -i "$1" diff --git a/tests/flattening/flattening-test b/tests/flattening/flattening-test deleted file mode 100755 index 92bc4de552..0000000000 --- a/tests/flattening/flattening-test +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/sh - -HERE=$(dirname "$0") - -if [ $# -lt 1 ]; then - FILES="$HERE/"*.fut -else - FILES=$* -fi - -futhark-test --only-interpret --interpreter="$HERE/flattening-pipeline" $FILES diff --git a/tests/flattening/iota-index.fut b/tests/flattening/iota-index.fut new file mode 100644 index 0000000000..a21c5f4096 --- /dev/null +++ b/tests/flattening/iota-index.fut @@ -0,0 
+1,10 @@ +-- iota is probably simplified away, but certs must be kept. +-- == +-- input { [1i64,2i64] [0,1] } +-- output { [0i64,1i64] } +-- input { [1i64,2i64] [0,2] } +-- error: out of bounds +-- input { [1i64,-2i64] [0,1] } +-- error: Range 0..1..<-2 is invalid + +def main = map2 (\n (i:i32) -> (iota n)[i]) diff --git a/tests/flattening/iota-opaque-index.fut b/tests/flattening/iota-opaque-index.fut new file mode 100644 index 0000000000..065c55d294 --- /dev/null +++ b/tests/flattening/iota-opaque-index.fut @@ -0,0 +1,9 @@ +-- == +-- input { [1i64,2i64] [0,1] } +-- output { [0i64,1i64] } +-- input { [1i64,2i64] [0,2] } +-- error: out of bounds +-- input { [1i64,-2i64] [0,1] } +-- error: Range 0..1..<-2 is invalid + +def main = map2 (\n (i:i32) -> (opaque (iota n))[i]) diff --git a/tests/flattening/iota-opaque-slice-reduce.fut b/tests/flattening/iota-opaque-slice-reduce.fut new file mode 100644 index 0000000000..f18ebbc06f --- /dev/null +++ b/tests/flattening/iota-opaque-slice-reduce.fut @@ -0,0 +1,5 @@ +-- == +-- input { [1i64,3i64] [0,1] [1,2] } +-- output { [0i64,3i64] } + +def main = map3 (\n i m -> i64.sum (opaque (iota n))[i:i+m]) diff --git a/tests/flattening/iota-reduce.fut b/tests/flattening/iota-reduce.fut new file mode 100644 index 0000000000..ba2d5ea6fa --- /dev/null +++ b/tests/flattening/iota-reduce.fut @@ -0,0 +1,7 @@ +-- == +-- input { [0i64,1i64,2i64] } +-- output { [0i64, 0i64, 1i64] } +-- input { [0i64,1i64,-2i64] } +-- error: Range 0..1..<-2 is invalid + +def main = map (\n -> i64.sum (iota n)) diff --git a/tests/flattening/redomap1.fut b/tests/flattening/redomap1.fut deleted file mode 100644 index 0621ac3b84..0000000000 --- a/tests/flattening/redomap1.fut +++ /dev/null @@ -1,17 +0,0 @@ --- == --- input { --- [[1,2,3],[1,2,3]] --- [[3,2,1],[6,7,8]] --- } --- output { --- [12, 27] --- } -def main [m][n] (xss: [m][n]i32) (yss: [m][n]i32): [m]i32 = - let final_res = - map (\(xs: [n]i32, ys: [n]i32): i32 -> - let tmp = - map (\(x: i32, y: i32): i32 
-> x+y - ) (zip xs ys) in - reduce (+) 0 tmp - ) (zip xss yss) - in final_res diff --git a/tests/flattening/redomap2.fut b/tests/flattening/redomap2.fut deleted file mode 100644 index fa96cdb488..0000000000 --- a/tests/flattening/redomap2.fut +++ /dev/null @@ -1,13 +0,0 @@ --- == --- input { --- [1,2,3] --- [6,7,8] --- } --- output { --- 27 --- } -def main [n] (xs: [n]i32) (ys: [n]i32): i32 = - let tmp = - map (\(x: i32, y: i32): i32 -> x+y - ) (zip xs ys) in - reduce (+) 0 tmp From 8c87f3e80e7d4520e8d61795ec1576cf8f1a84e8 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 6 Nov 2022 10:24:00 +0100 Subject: [PATCH 03/64] Fix offset calculation. --- src/Futhark/Pass/Flatten/Builtins.hs | 76 ++++++++++++++++++---------- 1 file changed, 49 insertions(+), 27 deletions(-) diff --git a/src/Futhark/Pass/Flatten/Builtins.hs b/src/Futhark/Pass/Flatten/Builtins.hs index 1b8f756a36..3e42cdc05b 100644 --- a/src/Futhark/Pass/Flatten/Builtins.hs +++ b/src/Futhark/Pass/Flatten/Builtins.hs @@ -20,7 +20,7 @@ import Data.Text qualified as T import Futhark.IR.GPU import Futhark.IR.SOACS import Futhark.MonadFreshNames -import Futhark.Pass.ExtractKernels.BlockedKernel (mkSegSpace, segScan) +import Futhark.Pass.ExtractKernels.BlockedKernel (mkSegSpace) import Futhark.Pass.ExtractKernels.ToGPU ( scopeForGPU, soacsLambdaToGPU, @@ -36,17 +36,35 @@ segIotaName = builtinName "segiota" repIotaName = builtinName "repiota" prefixSumName = builtinName "prefixsum" -genScan :: String -> SubExp -> Lambda GPU -> [SubExp] -> [VName] -> Builder GPU [VName] -genScan desc w lam nes arrs = do - pat <- fmap Pat $ forM (lambdaReturnType lam) $ \t -> - PatElem <$> newVName desc <*> pure (arrayOf t (Shape [w]) NoUniqueness) - let op = SegBinOp Commutative lam nes mempty - map_lam <- mkIdentityLambda $ lambdaReturnType lam - addStms =<< segScan lvl pat mempty w [op] map_lam arrs [] [] - pure $ patNames pat +genScanomap :: String -> SubExp -> Lambda GPU -> [SubExp] -> (SubExp -> Builder GPU [SubExp]) 
-> Builder GPU [VName] +genScanomap desc w lam nes m = do + gtid <- newVName "gtid" + space <- mkSegSpace [(gtid, w)] + ((res, res_t), stms) <- runBuilder . localScope (scopeOfSegSpace space) $ do + res <- m $ Var gtid + res_t <- mapM subExpType res + pure (map (Returns ResultMaySimplify mempty) res, res_t) + let kbody = KernelBody () stms res + op = SegBinOp Commutative lam nes mempty + letTupExp desc $ Op $ SegOp $ SegScan lvl space [op] res_t kbody where lvl = SegThread SegNoVirt Nothing +genScan :: String -> SubExp -> Lambda GPU -> [SubExp] -> [VName] -> Builder GPU [VName] +genScan desc w lam nes arrs = + genScanomap desc w lam nes $ \gtid -> forM arrs $ \arr -> + letSubExp (baseString arr <> "_elem") =<< eIndex arr [eSubExp gtid] + +-- Also known as a prescan. +genExScan :: String -> SubExp -> Lambda GPU -> [SubExp] -> [VName] -> Builder GPU [VName] +genExScan desc w lam nes arrs = + genScanomap desc w lam nes $ \gtid -> + letTupExp' "to_prescan" + =<< eIf + (toExp $ pe64 gtid .==. 
0) + (eBody (map eSubExp nes)) + (eBody (map (`eIndex` [toExp $ pe64 gtid - 1]) arrs)) + segScanLambda :: (MonadBuilder m, BranchType (Rep m) ~ ExtType, LParamInfo (Rep m) ~ Type) => Lambda (Rep m) -> @@ -78,6 +96,12 @@ genPrefixSum desc ns = do add_lam <- binOpLambda (Add Int64 OverflowUndef) int64 head <$> genScan desc w add_lam [intConst Int64 0] [ns] +genExPrefixSum :: String -> VName -> Builder GPU VName +genExPrefixSum desc ns = do + w <- arraySize 0 <$> lookupType ns + add_lam <- binOpLambda (Add Int64 OverflowUndef) int64 + head <$> genExScan desc w add_lam [intConst Int64 0] [ns] + genSegPrefixSum :: String -> VName -> VName -> Builder GPU VName genSegPrefixSum desc flags ns = do add_lam <- binOpLambda (Add Int64 OverflowUndef) int64 @@ -133,17 +157,25 @@ genSegRed flags offsets elems red = do letSubExp "this_end" =<< toExp (pe64 next_start - 1) mapM (letSubExp "res" <=< (`eIndex` [eSubExp this_end])) scanned -genSegIota :: VName -> Builder GPU (VName, VName, VName) -genSegIota ks = do +-- Returns (#segments, segment start offsets, sum of segment sizes) +exScanAndSum :: VName -> Builder GPU (SubExp, VName, SubExp) +exScanAndSum ks = do n <- arraySize 0 <$> lookupType ks is_empty <- letSubExp "is_empty" =<< toExp (pe64 n .==. 0) - offsets <- genPrefixSum "offsets" ks + offsets <- genExPrefixSum "offsets" ks m <- letSubExp "m" =<< eIf (eSubExp is_empty) (eBody [eSubExp $ intConst Int64 0]) - (eBody [eLast offsets]) + -- Add last size because 'offsets' is an *exclusive* prefix + -- sum. 
+ (eBody [eBinOp (Add Int64 OverflowUndef) (eLast offsets) (eLast ks)]) + pure (n, offsets, m) + +genSegIota :: VName -> Builder GPU (VName, VName, VName) +genSegIota ks = do + (_n, offsets, m) <- exScanAndSum ks flags <- genFlags m offsets ones <- letExp "ones" $ BasicOp $ Replicate (Shape [m]) one iotas <- genSegPrefixSum "iotas" flags ones @@ -156,15 +188,7 @@ genSegIota ks = do genRepIota :: VName -> Builder GPU (VName, VName, VName) genRepIota ks = do - n <- arraySize 0 <$> lookupType ks - is_empty <- letSubExp "is_empty" =<< toExp (pe64 n .==. 0) - offsets <- genPrefixSum "offsets" ks - m <- - letSubExp "m" - =<< eIf - (eSubExp is_empty) - (eBody [eSubExp zero]) - (eBody [eLast offsets]) + (n, offsets, m) <- exScanAndSum ks is <- letExp "is" <=< genTabulate n $ \i -> do o <- letSubExp "o" =<< eIndex offsets [eSubExp i] k <- letSubExp "n" =<< eIndex ks [eSubExp i] @@ -172,7 +196,7 @@ genRepIota ks = do =<< eIf (toExp (pe64 k .==. 0)) (eBody [eSubExp negone]) - (eBody [toExp $ pe64 o - pe64 k]) + (eBody [toExp $ pe64 o]) zeroes <- letExp "zeroes" $ BasicOp $ Replicate (Shape [m]) zero starts <- letExp "starts" <=< genScatter zeroes n $ \gtid -> do @@ -251,9 +275,7 @@ prefixSumBuiltin = buildingBuiltin $ do funDefAttrs = mempty, funDefName = prefixSumName, funDefRetType = - [ Prim int64, - Array int64 (Shape [Free $ Var $ paramName np]) Nonunique - ], + [Array int64 (Shape [Free $ Var $ paramName np]) Nonunique], funDefParams = [np, nsp], funDefBody = body } @@ -263,7 +285,7 @@ prefixSumBuiltin = buildingBuiltin $ do -- avoid the code explosion that would result if we inserted -- primitives everywhere. flatteningBuiltins :: [FunDef GPU] -flatteningBuiltins = [segIotaBuiltin, segIotaBuiltin, prefixSumBuiltin] +flatteningBuiltins = [segIotaBuiltin, repIotaBuiltin, prefixSumBuiltin] -- | Perform a segmented iota. Returns flags,offsets,data. 
doSegIota :: VName -> Builder GPU (VName, VName, VName) From eb10d91abc6f5ef0c5b49f81bd83b648b30d09cd Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 14 Nov 2022 00:50:52 +0100 Subject: [PATCH 04/64] Fix some things. --- src/Futhark/Pass/Flatten.hs | 29 ++++++++++--------- src/Futhark/Pass/Flatten/Builtins.hs | 17 ++++++----- tests/flattening/iota-opaque-slice-red.fut | 11 +++++++ tests/flattening/iota-opaque-slice-reduce.fut | 5 ---- .../{iota-reduce.fut => iota-red.fut} | 0 tests/flattening/slice-red.fut | 5 ++++ 6 files changed, 42 insertions(+), 25 deletions(-) create mode 100644 tests/flattening/iota-opaque-slice-red.fut delete mode 100644 tests/flattening/iota-opaque-slice-reduce.fut rename tests/flattening/{iota-reduce.fut => iota-red.fut} (100%) create mode 100644 tests/flattening/slice-red.fut diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 9ce2627029..8039095f1f 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -79,16 +79,16 @@ resVar rt env = fromMaybe bad $ M.lookup rt $ distResMap env where bad = error $ "resVar: unknown tag: " ++ show rt -flagsAndElems :: DistEnv -> [DistInput] -> (Maybe (VName, VName), [VName]) -flagsAndElems env [] = (Nothing, []) -flagsAndElems env (DistInputFree v _ : vs) = - second (v :) $ flagsAndElems env vs -flagsAndElems env (DistInput rt _ : vs) = +segsAndElems :: DistEnv -> [DistInput] -> (Maybe (VName, VName, VName), [VName]) +segsAndElems env [] = (Nothing, []) +segsAndElems env (DistInputFree v _ : vs) = + second (v :) $ segsAndElems env vs +segsAndElems env (DistInput rt _ : vs) = case resVar rt env of Regular v' -> - second (v' :) $ flagsAndElems env vs - Irregular (IrregularRep _ flags offsets elems) -> - bimap (mplus $ Just (flags, offsets)) (elems :) $ flagsAndElems env vs + second (v' :) $ segsAndElems env vs + Irregular (IrregularRep segments flags offsets elems) -> + bimap (mplus $ Just (segments, flags, offsets)) (elems :) $ segsAndElems env vs 
type Segments = NE.NonEmpty SubExp @@ -201,10 +201,13 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = ns <- letExp "slice_sizes" <=< segMap1 segments $ \is -> do slice_ns <- mapM (readInput segments env is inps) $ sliceDims slice fmap varsRes . letTupExp "n" <=< toExp $ product $ map pe64 slice_ns - offsets <- doPrefixSum ns - m <- letSubExp "total_elems" =<< eLast offsets + (_n, offsets, m) <- exScanAndSum ns + (_, _, repiota_elems) <- doRepIota ns flags <- genFlags m offsets - elems <- letExp "elems" <=< segMap1 (NE.singleton m) $ \is -> + elems <- letExp "elems" <=< renameExp <=< segMap1 (NE.singleton m) $ \is -> do + readInputs segments env is inps + segment <- letSubExp "segment" =<< eIndex repiota_elems (map eSubExp is) + is <- traverse (toSubExp "x") $ fixSlice (fmap pe64 slice) [pe64 segment] -- TODO - only works for scalars. fmap (subExpsRes . pure) . letSubExp "v" =<< eIndex arr_elems (map eSubExp is) let rep = Irregular $ IrregularRep ns flags offsets elems @@ -230,8 +233,8 @@ transformDistStm segments env (DistStm inps res stm) = do Let _ _ (Op (Screma _ arrs form)) | Just reds <- isReduceSOAC form, Just arrs' <- mapM (`lookup` inps) arrs, - (Just (flags, offsets), elems) <- flagsAndElems env arrs' -> do - elems' <- genSegRed flags offsets elems $ singleReduce reds + (Just (arr_segments, flags, offsets), elems) <- segsAndElems env arrs' -> do + elems' <- genSegRed arr_segments flags offsets elems $ singleReduce reds pure $ insertReps (zip (map distResTag res) (map Regular elems')) env _ -> error $ "Unhandled Stm:\n" ++ prettyString stm diff --git a/src/Futhark/Pass/Flatten/Builtins.hs b/src/Futhark/Pass/Flatten/Builtins.hs index 3e42cdc05b..3c457431ff 100644 --- a/src/Futhark/Pass/Flatten/Builtins.hs +++ b/src/Futhark/Pass/Flatten/Builtins.hs @@ -5,6 +5,7 @@ module Futhark.Pass.Flatten.Builtins genFlags, genSegScan, genSegRed, + exScanAndSum, doSegIota, doPrefixSum, doRepIota, @@ -140,8 +141,8 @@ genFlags m offsets = do i <- letExp "i" 
=<< eIndex offsets [eSubExp gtid] pure (i, constant True) -genSegRed :: VName -> VName -> [VName] -> Reduce SOACS -> Builder GPU [VName] -genSegRed flags offsets elems red = do +genSegRed :: VName -> VName -> VName -> [VName] -> Reduce SOACS -> Builder GPU [VName] +genSegRed segments flags offsets elems red = do scanned <- genSegScan "red" @@ -151,11 +152,13 @@ genSegRed flags offsets elems red = do elems num_segments <- arraySize 0 <$> lookupType offsets letTupExp "segred" <=< genTabulate num_segments $ \i -> do - next_start <- - letSubExp "next_start" =<< eIndex offsets [toExp (pe64 i)] - this_end <- - letSubExp "this_end" =<< toExp (pe64 next_start - 1) - mapM (letSubExp "res" <=< (`eIndex` [eSubExp this_end])) scanned + n <- letSubExp "n" =<< eIndex segments [eSubExp i] + offset <- letSubExp "offset" =<< eIndex offsets [toExp (pe64 i)] + letTupExp' "segment_res" <=< eIf (toExp $ pe64 n .==. 0) (eBody $ map eSubExp nes) $ + eBody $ + map (`eIndex` [toExp $ pe64 offset + pe64 n - 1]) scanned + where + nes = redNeutral red -- Returns (#segments, segment start offsets, sum of segment sizes) exScanAndSum :: VName -> Builder GPU (SubExp, VName, SubExp) diff --git a/tests/flattening/iota-opaque-slice-red.fut b/tests/flattening/iota-opaque-slice-red.fut new file mode 100644 index 0000000000..56b1218c9c --- /dev/null +++ b/tests/flattening/iota-opaque-slice-red.fut @@ -0,0 +1,11 @@ +-- == +-- input { [1i64,2i64] [0i64,1i64] } +-- output { [0i64,1i64] } +-- input { [1i64,5i64] [0i64,3i64] } +-- output { [0i64,7i64] } +-- input { [1i64,2i64] [0i64,3i64] } +-- error: out of bounds +-- input { [1i64,-2i64] [0i64,1i64] } +-- error: Range 0..1..<-2 is invalid + +def main = map2 (\n (i:i64) -> i64.sum (opaque (iota n))[i:]) diff --git a/tests/flattening/iota-opaque-slice-reduce.fut b/tests/flattening/iota-opaque-slice-reduce.fut deleted file mode 100644 index f18ebbc06f..0000000000 --- a/tests/flattening/iota-opaque-slice-reduce.fut +++ /dev/null @@ -1,5 +0,0 @@ --- == --- 
input { [1i64,3i64] [0,1] [1,2] } --- output { [0i64,3i64] } - -def main = map3 (\n i m -> i64.sum (opaque (iota n))[i:i+m]) diff --git a/tests/flattening/iota-reduce.fut b/tests/flattening/iota-red.fut similarity index 100% rename from tests/flattening/iota-reduce.fut rename to tests/flattening/iota-red.fut diff --git a/tests/flattening/slice-red.fut b/tests/flattening/slice-red.fut new file mode 100644 index 0000000000..4362860300 --- /dev/null +++ b/tests/flattening/slice-red.fut @@ -0,0 +1,5 @@ +-- == +-- input { [[0i64,1i64,5i64],[-2i64,9i64,1i64]] [0i64,1i64] } +-- output { [6i64,10i64] } + +def main = map2 (\A (i:i64) -> i64.sum A[i:]) From d58ba996300c98d6d933f782c084ef6ffe7ee37d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 15 Nov 2022 22:04:17 -0600 Subject: [PATCH 05/64] Use full flattening in GPU pipelines. --- src/Futhark/Passes.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Passes.hs b/src/Futhark/Passes.hs index 1ed5ffb88e..56ae2db936 100644 --- a/src/Futhark/Passes.hs +++ b/src/Futhark/Passes.hs @@ -38,9 +38,9 @@ import Futhark.Pass.ExpandAllocations import Futhark.Pass.ExplicitAllocations.GPU qualified as GPU import Futhark.Pass.ExplicitAllocations.MC qualified as MC import Futhark.Pass.ExplicitAllocations.Seq qualified as Seq -import Futhark.Pass.ExtractKernels import Futhark.Pass.ExtractMulticore import Futhark.Pass.FirstOrderTransform +import Futhark.Pass.Flatten import Futhark.Pass.KernelBabysitting import Futhark.Pass.LiftAllocations as LiftAllocations import Futhark.Pass.LowerAllocations as LowerAllocations @@ -85,7 +85,7 @@ adPipeline = kernelsPipeline :: Pipeline SOACS GPU kernelsPipeline = standardPipeline - >>> onePass extractKernels + >>> onePass flattenSOACs >>> passes [ simplifyGPU, optimiseGenRed, From 138cecae13709766fd375c0d8242f2f4baff3236 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 15 Nov 2022 22:26:42 -0600 Subject: [PATCH 06/64] Irregular slice now works. 
--- src/Futhark/Pass/Flatten.hs | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 8039095f1f..056014f61a 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -48,7 +48,7 @@ data IrregularRep = IrregularRep irregularSegments :: VName, irregularFlags :: VName, irregularOffsets :: VName, - iregularElems :: VName + irregularElems :: VName } data ResRep @@ -161,6 +161,17 @@ transformScalarStm :: transformScalarStm segments env inps res stm = transformScalarStms segments env inps res (oneStm stm) (patNames (stmPat stm)) +distCerts :: [(VName, DistInput)] -> StmAux a -> DistEnv -> Certs +distCerts inps aux env = Certs $ map f $ unCerts $ stmAuxCerts aux + where + f v = case lookup v inps of + Nothing -> v + Just (DistInputFree v _) -> v + Just (DistInput rt _) -> + case resVar rt env of + Regular v' -> v' + Irregular r -> irregularElems r + transformDistBasicOp :: Segments -> DistEnv -> @@ -193,8 +204,8 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = Index arr slice | null $ sliceDims slice -> scalarCase - | Just rep <- lookup arr inps -> - case rep of + | Just arr_rep <- lookup arr inps -> + case arr_rep of DistInput arr_rt _ -> case resVar arr_rt env of Irregular (IrregularRep arr_ns arr_flags arr_offsets arr_elems) -> do @@ -205,16 +216,21 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = (_, _, repiota_elems) <- doRepIota ns flags <- genFlags m offsets elems <- letExp "elems" <=< renameExp <=< segMap1 (NE.singleton m) $ \is -> do - readInputs segments env is inps segment <- letSubExp "segment" =<< eIndex repiota_elems (map eSubExp is) - is <- traverse (toSubExp "x") $ fixSlice (fmap pe64 slice) [pe64 segment] -- TODO - only works for scalars. - fmap (subExpsRes . pure) . 
letSubExp "v" - =<< eIndex arr_elems (map eSubExp is) + segment_start <- letSubExp "segment_start" =<< eIndex offsets [eSubExp segment] + readInputs segments env [segment] inps + let slice' = + fixSlice (fmap pe64 slice) $ + map (subtract (pe64 segment_start) . pe64) is + -- TODO: multidimensional segments and non-primitive type. + auxing aux $ + fmap (subExpsRes . pure) . letSubExp "v" + =<< eIndex arr (map toExp slice') let rep = Irregular $ IrregularRep ns flags offsets elems pure $ insertRep (distResTag res) rep env Iota (Var n) x s Int64 | Just (DistInputFree ns _) <- lookup n inps -> do - (flags, offsets, elems) <- doSegIota ns + (flags, offsets, elems) <- certifying (distCerts inps aux env) $ doSegIota ns let rep = Irregular $ IrregularRep ns flags offsets elems pure $ insertRep (distResTag res) rep env _ -> error $ "Unhandled BasicOp:\n" ++ prettyString e @@ -229,7 +245,7 @@ transformDistStm segments env (DistStm inps res stm) = do Let pat aux (BasicOp e) -> do let ~[res'] = res ~[pe] = patElems pat - transformDistBasicOp segments env (inps, res', pe, stmAux stm, e) + transformDistBasicOp segments env (inps, res', pe, aux, e) Let _ _ (Op (Screma _ arrs form)) | Just reds <- isReduceSOAC form, Just arrs' <- mapM (`lookup` inps) arrs, From 8e946a28a79afadc345ee5d60e413c15263690bd Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 16 Nov 2022 09:03:15 -0600 Subject: [PATCH 07/64] This always works. 
--- src/Futhark/Pass/Flatten.hs | 45 +++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 056014f61a..6de52fc1cf 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -204,30 +204,27 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = Index arr slice | null $ sliceDims slice -> scalarCase - | Just arr_rep <- lookup arr inps -> - case arr_rep of - DistInput arr_rt _ -> - case resVar arr_rt env of - Irregular (IrregularRep arr_ns arr_flags arr_offsets arr_elems) -> do - ns <- letExp "slice_sizes" <=< segMap1 segments $ \is -> do - slice_ns <- mapM (readInput segments env is inps) $ sliceDims slice - fmap varsRes . letTupExp "n" <=< toExp $ product $ map pe64 slice_ns - (_n, offsets, m) <- exScanAndSum ns - (_, _, repiota_elems) <- doRepIota ns - flags <- genFlags m offsets - elems <- letExp "elems" <=< renameExp <=< segMap1 (NE.singleton m) $ \is -> do - segment <- letSubExp "segment" =<< eIndex repiota_elems (map eSubExp is) - segment_start <- letSubExp "segment_start" =<< eIndex offsets [eSubExp segment] - readInputs segments env [segment] inps - let slice' = - fixSlice (fmap pe64 slice) $ - map (subtract (pe64 segment_start) . pe64) is - -- TODO: multidimensional segments and non-primitive type. - auxing aux $ - fmap (subExpsRes . pure) . letSubExp "v" - =<< eIndex arr (map toExp slice') - let rep = Irregular $ IrregularRep ns flags offsets elems - pure $ insertRep (distResTag res) rep env + | otherwise -> do + -- Maximally irregular case. + ns <- letExp "slice_sizes" <=< segMap1 segments $ \is -> do + slice_ns <- mapM (readInput segments env is inps) $ sliceDims slice + fmap varsRes . 
letTupExp "n" <=< toExp $ product $ map pe64 slice_ns + (_n, offsets, m) <- exScanAndSum ns + (_, _, repiota_elems) <- doRepIota ns + flags <- genFlags m offsets + elems <- letExp "elems" <=< renameExp <=< segMap1 (NE.singleton m) $ \is -> do + segment <- letSubExp "segment" =<< eIndex repiota_elems (map eSubExp is) + segment_start <- letSubExp "segment_start" =<< eIndex offsets [eSubExp segment] + readInputs segments env [segment] inps + let slice' = + fixSlice (fmap pe64 slice) $ + map (subtract (pe64 segment_start) . pe64) is + -- TODO: multidimensional segments and non-primitive type. + auxing aux $ + fmap (subExpsRes . pure) . letSubExp "v" + =<< eIndex arr (map toExp slice') + let rep = Irregular $ IrregularRep ns flags offsets elems + pure $ insertRep (distResTag res) rep env Iota (Var n) x s Int64 | Just (DistInputFree ns _) <- lookup n inps -> do (flags, offsets, elems) <- certifying (distCerts inps aux env) $ doSegIota ns From fb21c050e70bd2feba8cf88b3eac895aad7d1962 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 16 Nov 2022 10:21:55 -0600 Subject: [PATCH 08/64] Reshape and further Index fixes. 
--- src/Futhark/Pass/Flatten.hs | 11 ++++++++--- src/Futhark/Pass/Flatten/Distribute.hs | 8 +++++++- tests/flattening/slice2d-red.fut | 5 +++++ 3 files changed, 20 insertions(+), 4 deletions(-) create mode 100644 tests/flattening/slice2d-red.fut diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 6de52fc1cf..dde982b601 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -196,11 +196,14 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = scalarCase Opaque op se | Var v <- se, - Just (DistInput rt_in _) <- lookup v inps -> do + Just (DistInput rt_in _) <- lookup v inps -> -- TODO: actually insert opaques pure $ insertRep (distResTag res) (resVar rt_in env) env | otherwise -> scalarCase + Reshape kind newshape arr + | Just (DistInput rt_in _) <- lookup arr inps -> + pure $ insertRep (distResTag res) (resVar rt_in env) env Index arr slice | null $ sliceDims slice -> scalarCase @@ -216,10 +219,12 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = segment <- letSubExp "segment" =<< eIndex repiota_elems (map eSubExp is) segment_start <- letSubExp "segment_start" =<< eIndex offsets [eSubExp segment] readInputs segments env [segment] inps + -- TODO: multidimensional segments let slice' = fixSlice (fmap pe64 slice) $ - map (subtract (pe64 segment_start) . pe64) is - -- TODO: multidimensional segments and non-primitive type. + unflattenIndex (map pe64 (sliceDims slice)) $ + subtract (pe64 segment_start) . pe64 $ + head is auxing aux $ fmap (subExpsRes . pure) . letSubExp "v" =<< eIndex arr (map toExp slice') diff --git a/src/Futhark/Pass/Flatten/Distribute.hs b/src/Futhark/Pass/Flatten/Distribute.hs index 2b7b292263..1d2483180c 100644 --- a/src/Futhark/Pass/Flatten/Distribute.hs +++ b/src/Futhark/Pass/Flatten/Distribute.hs @@ -21,7 +21,13 @@ newtype ResTag = ResTag Int deriving (Eq, Ord, Show) -- | Something that is mapped. 
-data DistInput = DistInputFree VName Type | DistInput ResTag Type +data DistInput + = -- | A value bound outside the original map nest. By necessity + -- regular. + DistInputFree VName Type + | -- | A value constructed inside the original map nest. May be + -- irregular. + DistInput ResTag Type deriving (Eq, Ord, Show) distInputType :: DistInput -> Type diff --git a/tests/flattening/slice2d-red.fut b/tests/flattening/slice2d-red.fut new file mode 100644 index 0000000000..ad2bc650e5 --- /dev/null +++ b/tests/flattening/slice2d-red.fut @@ -0,0 +1,5 @@ +-- == +-- input { [[[0i64,1i64],[4i64,5i64]],[[-2i64,9i64],[9i64,2i64]]] [0i64,1i64] [1i64,0i64] } +-- output { [6i64,11i64] } + +def main = map3 (\A (i:i64) (j: i64) -> i64.sum (flatten A[i:,j:])) From 35669f999a95987345f14a9fed2235d36d5b12f5 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 16 Nov 2022 11:29:45 -0600 Subject: [PATCH 09/64] Handle the most general case of Iota. --- src/Futhark/Pass/Flatten.hs | 36 +++++++++++++++++++++++++-- tests/flattening/range-opaque-red.fut | 7 ++++++ 2 files changed, 41 insertions(+), 2 deletions(-) create mode 100644 tests/flattening/range-opaque-red.fut diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index dde982b601..28923f7ae5 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -92,6 +92,9 @@ segsAndElems env (DistInput rt _ : vs) = type Segments = NE.NonEmpty SubExp +segmentsShape :: Segments -> Shape +segmentsShape = Shape . toList + segMap1 :: Segments -> ([SubExp] -> Builder GPU Result) -> Builder GPU (Exp GPU) segMap1 segments f = do gtids <- replicateM (length segments) (newVName "gtid") @@ -172,6 +175,18 @@ distCerts inps aux env = Certs $ map f $ unCerts $ stmAuxCerts aux Regular v' -> v' Irregular r -> irregularElems r +-- | Only sensible for variables of segment-invariant type. 
+elemArr :: Segments -> DistEnv -> [(VName, DistInput)] -> SubExp -> Builder GPU VName +elemArr segments env inps (Var v) + | Just v_inp <- lookup v inps = + pure $ case v_inp of + DistInputFree ns _ -> ns + DistInput rt _ -> case resVar rt env of + Irregular r -> irregularElems r + Regular vs -> vs +elemArr segments env inps se = + letExp "rep" $ BasicOp $ Replicate (segmentsShape segments) se + transformDistBasicOp :: Segments -> DistEnv -> @@ -230,11 +245,28 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = =<< eIndex arr (map toExp slice') let rep = Irregular $ IrregularRep ns flags offsets elems pure $ insertRep (distResTag res) rep env - Iota (Var n) x s Int64 - | Just (DistInputFree ns _) <- lookup n inps -> do + Iota n (Constant x) (Constant s) Int64 + | zeroIsh x, + oneIsh s -> do + ns <- elemArr segments env inps n (flags, offsets, elems) <- certifying (distCerts inps aux env) $ doSegIota ns let rep = Irregular $ IrregularRep ns flags offsets elems pure $ insertRep (distResTag res) rep env + Iota n x s Int64 -> do + ns <- elemArr segments env inps n + xs <- elemArr segments env inps x + ss <- elemArr segments env inps s + (flags, offsets, elems) <- certifying (distCerts inps aux env) $ doSegIota ns + (_, _, repiota_elems) <- doRepIota ns + m <- arraySize 0 <$> lookupType elems + elems' <- letExp "elems_fixed" <=< segMap1 (NE.singleton m) $ \is -> do + segment <- letSubExp "segment" =<< eIndex repiota_elems (map eSubExp is) + v' <- letSubExp "v" =<< eIndex elems (map eSubExp is) + x' <- letSubExp "x" =<< eIndex xs [eSubExp segment] + s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] + fmap (subExpsRes . pure) . 
letSubExp "v" =<< toExp (pe64 x' + pe64 v' * pe64 s') + let rep = Irregular $ IrregularRep ns flags offsets elems' + pure $ insertRep (distResTag res) rep env _ -> error $ "Unhandled BasicOp:\n" ++ prettyString e where scalarCase = diff --git a/tests/flattening/range-opaque-red.fut b/tests/flattening/range-opaque-red.fut new file mode 100644 index 0000000000..a2ca853795 --- /dev/null +++ b/tests/flattening/range-opaque-red.fut @@ -0,0 +1,7 @@ +-- == +-- input { [1i64,2i64] [3i64,3i64] [10i64,8i64] } +-- output { [25i64, 27i64] } +-- input { [1i64,2i64] [3i64,2i64] [10i64,-8i64] } +-- error: Range 2..2..<-8 is invalid + +def main = map3 (\a b c -> i64.sum (opaque (a..b..<c))) Date: Wed, 16 Nov 2022 11:32:32 -0600 Subject: [PATCH 10/64] Remove some warnings. --- src/Futhark/Pass/Flatten.hs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 28923f7ae5..e90efeb0a5 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -132,7 +132,7 @@ readInputs segments env is = mapM_ onInput case M.lookup rt $ distResMap env of Just (Regular arr) -> letBindNames [v] =<< eIndex arr (map eSubExp is) - Just (Irregular (IrregularRep _ flags offsets elems)) -> do + Just (Irregular (IrregularRep _ _ offsets elems)) -> do offset <- letSubExp "offset" =<< eIndex offsets (map eSubExp is) num_elems <- letSubExp "num_elems" =<< toExp (product $ map pe64 $ arrayDims t) let slice = Slice [DimSlice offset num_elems (intConst Int64 1)] @@ -169,22 +169,22 @@ distCerts inps aux env = Certs $ map f $ unCerts $ stmAuxCerts aux where f v = case lookup v inps of Nothing -> v - Just (DistInputFree v _) -> v + Just (DistInputFree vs _) -> vs Just (DistInput rt _) -> case resVar rt env of - Regular v' -> v' + Regular vs -> vs Irregular r -> irregularElems r -- | Only sensible for variables of segment-invariant type.
elemArr :: Segments -> DistEnv -> [(VName, DistInput)] -> SubExp -> Builder GPU VName -elemArr segments env inps (Var v) +elemArr _ env inps (Var v) | Just v_inp <- lookup v inps = pure $ case v_inp of DistInputFree ns _ -> ns DistInput rt _ -> case resVar rt env of Irregular r -> irregularElems r Regular vs -> vs -elemArr segments env inps se = +elemArr segments _ _ se = letExp "rep" $ BasicOp $ Replicate (segmentsShape segments) se transformDistBasicOp :: @@ -216,7 +216,7 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = pure $ insertRep (distResTag res) (resVar rt_in env) env | otherwise -> scalarCase - Reshape kind newshape arr + Reshape _ _ arr | Just (DistInput rt_in _) <- lookup arr inps -> pure $ insertRep (distResTag res) (resVar rt_in env) env Index arr slice From 1682d88d8980886461059d37355d7e7308181afc Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 16 Nov 2022 11:46:52 -0600 Subject: [PATCH 11/64] More foldable. --- futhark.cabal | 1 + src/Futhark/Pass/Flatten.hs | 34 +++++++++++++++++----------------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/futhark.cabal b/futhark.cabal index 607b60160c..75c4acacdb 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -492,6 +492,7 @@ library , lsp >= 1.5.0 , mainland-pretty >=0.7.1 , cmark-gfm >=0.2.1 + , OneTuple , megaparsec >=9.0.0 , mtl >=2.2.1 , neat-interpolation >=0.3 diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index e90efeb0a5..2b68574a47 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -18,6 +18,7 @@ import Data.Foldable import Data.List.NonEmpty qualified as NE import Data.Map qualified as M import Data.Maybe (fromMaybe) +import Data.Tuple.Solo import Debug.Trace import Futhark.IR.GPU import Futhark.IR.SOACS @@ -95,15 +96,14 @@ type Segments = NE.NonEmpty SubExp segmentsShape :: Segments -> Shape segmentsShape = Shape . 
toList -segMap1 :: Segments -> ([SubExp] -> Builder GPU Result) -> Builder GPU (Exp GPU) -segMap1 segments f = do - gtids <- replicateM (length segments) (newVName "gtid") - space <- mkSegSpace $ zip gtids $ toList segments +segMap :: Traversable f => f SubExp -> (f SubExp -> Builder GPU Result) -> Builder GPU (Exp GPU) +segMap segments f = do + gtids <- traverse (const $ newVName "gtid") segments + space <- mkSegSpace $ zip (toList gtids) (toList segments) ((res, ts), stms) <- collectStms $ localScope (scopeOfSegSpace space) $ do - res <- f $ map Var gtids + res <- f $ fmap Var gtids ts <- mapM (subExpType . resSubExp) res - let resToRes (SubExpRes cs se) = Returns ResultMaySimplify cs se - pure (map resToRes res, ts) + pure (map mkResult res, ts) let kbody = KernelBody () stms res pure $ Op $ SegOp $ SegMap (SegThread SegNoVirt Nothing) space ts kbody where @@ -148,8 +148,8 @@ transformScalarStms :: [VName] -> Builder GPU DistEnv transformScalarStms segments env inps distres stms res = do - vs <- letTupExp "scalar_dist" <=< renameExp <=< segMap1 segments $ \is -> do - readInputs segments env is inps + vs <- letTupExp "scalar_dist" <=< renameExp <=< segMap segments $ \is -> do + readInputs segments env (toList is) inps addStms $ fmap soacsStmToGPU stms pure $ subExpsRes $ map Var res pure $ insertReps (zip (map distResTag distres) $ map Regular vs) env @@ -224,14 +224,14 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = scalarCase | otherwise -> do -- Maximally irregular case. - ns <- letExp "slice_sizes" <=< segMap1 segments $ \is -> do - slice_ns <- mapM (readInput segments env is inps) $ sliceDims slice + ns <- letExp "slice_sizes" <=< segMap segments $ \is -> do + slice_ns <- mapM (readInput segments env (toList is) inps) $ sliceDims slice fmap varsRes . 
letTupExp "n" <=< toExp $ product $ map pe64 slice_ns (_n, offsets, m) <- exScanAndSum ns (_, _, repiota_elems) <- doRepIota ns flags <- genFlags m offsets - elems <- letExp "elems" <=< renameExp <=< segMap1 (NE.singleton m) $ \is -> do - segment <- letSubExp "segment" =<< eIndex repiota_elems (map eSubExp is) + elems <- letExp "elems" <=< renameExp <=< segMap (NE.singleton m) $ \is -> do + segment <- letSubExp "segment" =<< eIndex repiota_elems (toList $ fmap eSubExp is) segment_start <- letSubExp "segment_start" =<< eIndex offsets [eSubExp segment] readInputs segments env [segment] inps -- TODO: multidimensional segments @@ -239,7 +239,7 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = fixSlice (fmap pe64 slice) $ unflattenIndex (map pe64 (sliceDims slice)) $ subtract (pe64 segment_start) . pe64 $ - head is + NE.head is auxing aux $ fmap (subExpsRes . pure) . letSubExp "v" =<< eIndex arr (map toExp slice') @@ -259,9 +259,9 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = (flags, offsets, elems) <- certifying (distCerts inps aux env) $ doSegIota ns (_, _, repiota_elems) <- doRepIota ns m <- arraySize 0 <$> lookupType elems - elems' <- letExp "elems_fixed" <=< segMap1 (NE.singleton m) $ \is -> do - segment <- letSubExp "segment" =<< eIndex repiota_elems (map eSubExp is) - v' <- letSubExp "v" =<< eIndex elems (map eSubExp is) + elems' <- letExp "elems_fixed" <=< segMap (Solo m) $ \(Solo i) -> do + segment <- letSubExp "segment" =<< eIndex repiota_elems [eSubExp i] + v' <- letSubExp "v" =<< eIndex elems [eSubExp i] x' <- letSubExp "x" =<< eIndex xs [eSubExp segment] s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] fmap (subExpsRes . pure) . letSubExp "v" =<< toExp (pe64 x' + pe64 v' * pe64 s') From ef35b71c79c2d6cb3e881d2974decec595766a0b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 17 Nov 2022 16:33:28 -0600 Subject: [PATCH 12/64] Hacky initial support for flattening nested maps. 
--- src/Futhark/Pass/Flatten.hs | 46 +++++++++++++++++++++----- src/Futhark/Pass/Flatten/Builtins.hs | 1 + src/Futhark/Pass/Flatten/Distribute.hs | 22 +++++++----- tests/flattening/map-nested.fut | 5 +++ tests/flattening/map-slice-nested.fut | 5 +++ 5 files changed, 62 insertions(+), 17 deletions(-) create mode 100644 tests/flattening/map-nested.fut create mode 100644 tests/flattening/map-slice-nested.fut diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 2b68574a47..1504e3e680 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -109,7 +109,7 @@ segMap segments f = do where mkResult (SubExpRes cs se) = Returns ResultMaySimplify cs se -readInput :: Segments -> DistEnv -> [SubExp] -> [(VName, DistInput)] -> SubExp -> Builder GPU SubExp +readInput :: Segments -> DistEnv -> [SubExp] -> DistInputs -> SubExp -> Builder GPU SubExp readInput _ _ _ _ (Constant x) = pure $ Constant x readInput segments env is inputs (Var v) = case lookup v inputs of @@ -123,7 +123,7 @@ readInput segments env is inputs (Var v) = Irregular (IrregularRep _ flags offsets elems) -> undefined -readInputs :: Segments -> DistEnv -> [SubExp] -> [(VName, DistInput)] -> Builder GPU () +readInputs :: Segments -> DistEnv -> [SubExp] -> DistInputs -> Builder GPU () readInputs segments env is = mapM_ onInput where onInput (v, DistInputFree arr _) = @@ -142,7 +142,7 @@ readInputs segments env is = mapM_ onInput transformScalarStms :: Segments -> DistEnv -> - [(VName, DistInput)] -> + DistInputs -> [DistResult] -> Stms SOACS -> [VName] -> @@ -157,14 +157,14 @@ transformScalarStms segments env inps distres stms res = do transformScalarStm :: Segments -> DistEnv -> - [(VName, DistInput)] -> + DistInputs -> [DistResult] -> Stm SOACS -> Builder GPU DistEnv transformScalarStm segments env inps res stm = transformScalarStms segments env inps res (oneStm stm) (patNames (stmPat stm)) -distCerts :: [(VName, DistInput)] -> StmAux a -> DistEnv -> Certs +distCerts :: 
DistInputs -> StmAux a -> DistEnv -> Certs distCerts inps aux env = Certs $ map f $ unCerts $ stmAuxCerts aux where f v = case lookup v inps of @@ -176,7 +176,7 @@ distCerts inps aux env = Certs $ map f $ unCerts $ stmAuxCerts aux Irregular r -> irregularElems r -- | Only sensible for variables of segment-invariant type. -elemArr :: Segments -> DistEnv -> [(VName, DistInput)] -> SubExp -> Builder GPU VName +elemArr :: Segments -> DistEnv -> DistInputs -> SubExp -> Builder GPU VName elemArr _ env inps (Var v) | Just v_inp <- lookup v inps = pure $ case v_inp of @@ -190,7 +190,7 @@ elemArr segments _ _ se = transformDistBasicOp :: Segments -> DistEnv -> - ( [(VName, DistInput)], + ( DistInputs, DistResult, PatElem Type, StmAux (), @@ -273,6 +273,12 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = transformScalarStm segments env inps [res] $ Let (Pat [pe]) aux (BasicOp e) +repPerSegment :: SubExp -> VName -> [VName] -> Builder GPU [VName] +repPerSegment w segments_per_elem vs = + letTupExp "replicated" <=< segMap (Solo w) $ \(Solo i) -> do + segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] + subExpsRes <$> mapM (letSubExp "v" <=< flip eIndex [eSubExp segment]) vs + transformDistStm :: Segments -> DistEnv -> DistStm -> Builder GPU DistEnv transformDistStm segments env (DistStm inps res stm) = do case stm of @@ -280,12 +286,34 @@ transformDistStm segments env (DistStm inps res stm) = do let ~[res'] = res ~[pe] = patElems pat transformDistBasicOp segments env (inps, res', pe, aux, e) - Let _ _ (Op (Screma _ arrs form)) + Let pat _ (Op (Screma w arrs form)) | Just reds <- isReduceSOAC form, Just arrs' <- mapM (`lookup` inps) arrs, (Just (arr_segments, flags, offsets), elems) <- segsAndElems env arrs' -> do elems' <- genSegRed arr_segments flags offsets elems $ singleReduce reds pure $ insertReps (zip (map distResTag res) (map Regular elems')) env + | Just map_lam <- isMapSOAC form -> do + arrs' <- mapM (elemArr segments env inps . 
Var) arrs + ws <- elemArr segments env inps w + (ws_flags, ws_offsets, ws_elems) <- doRepIota ws + new_segment <- arraySize 0 <$> lookupType ws_elems + let free_in_map = namesToList $ freeIn map_lam + replicated <- + repPerSegment new_segment ws_elems + =<< mapM (elemArr segments env inps . Var) free_in_map + free_ps <- + mapM (newParam "free_p" . rowType <=< lookupType) replicated + scope <- askScope + let substs = M.fromList $ zip free_in_map $ map paramName free_ps + map_lam' = + (substituteNames substs map_lam) + { lambdaParams = lambdaParams map_lam <> free_ps + } + distributed = distributeMap scope pat new_segment (arrs' <> replicated) map_lam' + m = transformDistributed (NE.singleton new_segment) distributed + addStms =<< runReaderT (runBuilder_ m) scope + let mkRep = Irregular . IrregularRep ws ws_flags ws_offsets + pure $ insertReps (zip (map distResTag res) (map mkRep (patNames pat))) env _ -> error $ "Unhandled Stm:\n" ++ prettyString stm distResCerts :: DistEnv -> [DistInput] -> Certs @@ -312,7 +340,7 @@ transformStm scope (Let pat _ (Op (Screma w arrs form))) m = transformDistributed (NE.singleton w) distributed traceM $ prettyString distributed runReaderT (runBuilder_ m) scope -transformStm scope stm = pure $ oneStm $ soacsStmToGPU stm +transformStm _ stm = pure $ oneStm $ soacsStmToGPU stm transformStms :: Scope SOACS -> Stms SOACS -> PassM (Stms GPU) transformStms scope stms = diff --git a/src/Futhark/Pass/Flatten/Builtins.hs b/src/Futhark/Pass/Flatten/Builtins.hs index 3c457431ff..9cdd5aa59a 100644 --- a/src/Futhark/Pass/Flatten/Builtins.hs +++ b/src/Futhark/Pass/Flatten/Builtins.hs @@ -314,6 +314,7 @@ doSegIota ns = do (Safe, mempty, mempty) pure (flags, offsets, elems) +-- | Returns @(flags, offsets, elems)@. 
doRepIota :: VName -> Builder GPU (VName, VName, VName) doRepIota ns = do ns_t <- lookupType ns diff --git a/src/Futhark/Pass/Flatten/Distribute.hs b/src/Futhark/Pass/Flatten/Distribute.hs index 1d2483180c..63fb3987b7 100644 --- a/src/Futhark/Pass/Flatten/Distribute.hs +++ b/src/Futhark/Pass/Flatten/Distribute.hs @@ -4,6 +4,7 @@ module Futhark.Pass.Flatten.Distribute Distributed (..), DistStm (..), DistInput (..), + DistInputs, DistType (..), DistResult (..), ResTag, @@ -30,6 +31,8 @@ data DistInput DistInput ResTag Type deriving (Eq, Ord, Show) +type DistInputs = [(VName, DistInput)] + distInputType :: DistInput -> Type distInputType (DistInputFree _ t) = t distInputType (DistInput _ t) = t @@ -50,7 +53,7 @@ data DistResult = DistResult {distResTag :: ResTag, distResType :: DistType} deriving (Eq, Ord, Show) data DistStm = DistStm - { distStmInputs :: [(VName, DistInput)], + { distStmInputs :: DistInputs, distStmResult :: [DistResult], distStm :: Stm SOACS } @@ -124,7 +127,15 @@ splitIrregDims bound_outside (Array pt shape u) = regDim Constant {} = True splitIrregDims _ t = (mempty, t) -distributeMap :: Scope SOACS -> Pat Type -> SubExp -> [VName] -> Lambda SOACS -> Distributed +freeInput :: [(VName, DistInput)] -> VName -> Maybe (VName, DistInput) +freeInput avail_inputs v = + (v,) <$> lookup v avail_inputs + +patInput :: ResTag -> PatElem Type -> (VName, DistInput) +patInput tag pe = + (patElemName pe, DistInput tag $ patElemType pe) + +distributeMap :: Scope rep -> Pat Type -> SubExp -> [VName] -> Lambda SOACS -> Distributed distributeMap outer_scope map_pat w arrs lam = let param_inputs = zipWith paramInput (lambdaParams lam) arrs @@ -137,12 +148,7 @@ distributeMap outer_scope map_pat w arrs lam = where bound_outside = namesFromList $ M.keys outer_scope paramInput p arr = (paramName p, DistInputFree arr $ paramType p) - freeInput avail_inputs v = - (v,) <$> lookup v avail_inputs - patInput tag pe = - (patElemName pe, DistInput tag $ patElemType pe) - distType 
t = - uncurry (DistType w) $ splitIrregDims bound_outside t + distType t = uncurry (DistType w) $ splitIrregDims bound_outside t distributeStm (ResTag tag, avail_inputs) stm = let pat = stmPat stm new_tags = map ResTag $ take (patSize pat) [tag ..] diff --git a/tests/flattening/map-nested.fut b/tests/flattening/map-nested.fut new file mode 100644 index 0000000000..b8e871cd93 --- /dev/null +++ b/tests/flattening/map-nested.fut @@ -0,0 +1,5 @@ +-- == +-- input { [5i64,7i64] } +-- output { [20i64, 35i64] } + +def main = map (\n -> i64.sum (opaque (map (+2) (iota n)))) diff --git a/tests/flattening/map-slice-nested.fut b/tests/flattening/map-slice-nested.fut new file mode 100644 index 0000000000..6ec4f9ccfc --- /dev/null +++ b/tests/flattening/map-slice-nested.fut @@ -0,0 +1,5 @@ +-- == +-- input { [1i64,2i64,3i64,4i64,5i64] [-5i64,7i64] [2i64,3i64] [3i64,4i64] } +-- output { [-2i64, 11i64] } + +def main A = map3 (\x i j -> i64.sum (opaque (map (+x) A[i:j]))) From d2d7c7f1efcea2d2e175883dc18f7f757eb23372 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 18 Nov 2022 07:01:17 -0600 Subject: [PATCH 13/64] Flatten redomaps. 
--- src/Futhark/Pass/Flatten.hs | 92 +++++++++++++++++++-------- tests/flattening/map-nested.fut | 2 +- tests/flattening/map-slice-nested.fut | 2 +- tests/flattening/mapout.fut | 11 ++++ 4 files changed, 78 insertions(+), 29 deletions(-) create mode 100644 tests/flattening/mapout.fut diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 1504e3e680..e162dd7fa0 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -69,6 +69,21 @@ insertRep rt rep env = env {distResMap = M.insert rt rep $ distResMap env} insertReps :: [(ResTag, ResRep)] -> DistEnv -> DistEnv insertReps = flip $ foldl (flip $ uncurry insertRep) +insertIrregular :: VName -> VName -> VName -> ResTag -> VName -> DistEnv -> DistEnv +insertIrregular ns flags offsets rt elems env = + let rep = Irregular $ IrregularRep ns flags offsets elems + in insertRep rt rep env + +insertIrregulars :: VName -> VName -> VName -> [(ResTag, VName)] -> DistEnv -> DistEnv +insertIrregulars ns flags offsets bnds env = + let (tags, elems) = unzip bnds + mkRep = Irregular . IrregularRep ns flags offsets + in insertReps (zip tags $ map mkRep elems) env + +insertRegulars :: [ResTag] -> [VName] -> DistEnv -> DistEnv +insertRegulars rts xs = + insertReps (zip rts $ map Regular xs) + instance Monoid DistEnv where mempty = DistEnv mempty @@ -243,15 +258,13 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = auxing aux $ fmap (subExpsRes . pure) . 
letSubExp "v" =<< eIndex arr (map toExp slice') - let rep = Irregular $ IrregularRep ns flags offsets elems - pure $ insertRep (distResTag res) rep env + pure $ insertIrregular ns flags offsets (distResTag res) elems env Iota n (Constant x) (Constant s) Int64 | zeroIsh x, oneIsh s -> do ns <- elemArr segments env inps n (flags, offsets, elems) <- certifying (distCerts inps aux env) $ doSegIota ns - let rep = Irregular $ IrregularRep ns flags offsets elems - pure $ insertRep (distResTag res) rep env + pure $ insertIrregular ns flags offsets (distResTag res) elems env Iota n x s Int64 -> do ns <- elemArr segments env inps n xs <- elemArr segments env inps x @@ -265,8 +278,7 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = x' <- letSubExp "x" =<< eIndex xs [eSubExp segment] s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] fmap (subExpsRes . pure) . letSubExp "v" =<< toExp (pe64 x' + pe64 v' * pe64 s') - let rep = Irregular $ IrregularRep ns flags offsets elems' - pure $ insertRep (distResTag res) rep env + pure $ insertIrregular ns flags offsets (distResTag res) elems' env _ -> error $ "Unhandled BasicOp:\n" ++ prettyString e where scalarCase = @@ -279,6 +291,37 @@ repPerSegment w segments_per_elem vs = segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] subExpsRes <$> mapM (letSubExp "v" <=< flip eIndex [eSubExp segment]) vs +transformMap :: + Segments -> + DistEnv -> + DistInputs -> + Pat Type -> + SubExp -> + [VName] -> + Lambda SOACS -> + Builder GPU (VName, VName, VName) +transformMap segments env inps pat w arrs map_lam = do + arrs' <- mapM (elemArr segments env inps . Var) arrs + ws <- elemArr segments env inps w + (ws_flags, ws_offsets, ws_elems) <- doRepIota ws + new_segment <- arraySize 0 <$> lookupType ws_elems + let free_in_map = namesToList $ freeIn map_lam + replicated <- + repPerSegment new_segment ws_elems + =<< mapM (elemArr segments env inps . Var) free_in_map + free_ps <- + mapM (newParam "free_p" . 
rowType <=< lookupType) replicated + scope <- askScope + let substs = M.fromList $ zip free_in_map $ map paramName free_ps + map_lam' = + (substituteNames substs map_lam) + { lambdaParams = lambdaParams map_lam <> free_ps + } + distributed = distributeMap scope pat new_segment (arrs' <> replicated) map_lam' + m = transformDistributed (NE.singleton new_segment) distributed + addStms =<< runReaderT (runBuilder_ m) scope + pure (ws_flags, ws_offsets, ws) + transformDistStm :: Segments -> DistEnv -> DistStm -> Builder GPU DistEnv transformDistStm segments env (DistStm inps res stm) = do case stm of @@ -292,28 +335,23 @@ transformDistStm segments env (DistStm inps res stm) = do (Just (arr_segments, flags, offsets), elems) <- segsAndElems env arrs' -> do elems' <- genSegRed arr_segments flags offsets elems $ singleReduce reds pure $ insertReps (zip (map distResTag res) (map Regular elems')) env + | Just (reds, map_lam) <- isRedomapSOAC form -> do + map_pat <- fmap Pat $ forM (lambdaReturnType map_lam) $ \t -> + PatElem <$> newVName "map" <*> pure (t `arrayOfRow` w) + (ws_flags, ws_offsets, ws) <- + transformMap segments env inps map_pat w arrs map_lam + let (redout_names, mapout_names) = + splitAt (redResults reds) (patNames map_pat) + elems' <- + genSegRed ws ws_flags ws_offsets redout_names $ + singleReduce reds + let (red_tags, map_tags) = splitAt (redResults reds) $ map distResTag res + pure $ + insertRegulars red_tags elems' $ + insertIrregulars ws ws_flags ws_offsets (zip map_tags mapout_names) env | Just map_lam <- isMapSOAC form -> do - arrs' <- mapM (elemArr segments env inps . Var) arrs - ws <- elemArr segments env inps w - (ws_flags, ws_offsets, ws_elems) <- doRepIota ws - new_segment <- arraySize 0 <$> lookupType ws_elems - let free_in_map = namesToList $ freeIn map_lam - replicated <- - repPerSegment new_segment ws_elems - =<< mapM (elemArr segments env inps . Var) free_in_map - free_ps <- - mapM (newParam "free_p" . 
rowType <=< lookupType) replicated - scope <- askScope - let substs = M.fromList $ zip free_in_map $ map paramName free_ps - map_lam' = - (substituteNames substs map_lam) - { lambdaParams = lambdaParams map_lam <> free_ps - } - distributed = distributeMap scope pat new_segment (arrs' <> replicated) map_lam' - m = transformDistributed (NE.singleton new_segment) distributed - addStms =<< runReaderT (runBuilder_ m) scope - let mkRep = Irregular . IrregularRep ws ws_flags ws_offsets - pure $ insertReps (zip (map distResTag res) (map mkRep (patNames pat))) env + (ws_flags, ws_offsets, ws) <- transformMap segments env inps pat w arrs map_lam + pure $ insertIrregulars ws ws_flags ws_offsets (zip (map distResTag res) $ patNames pat) env _ -> error $ "Unhandled Stm:\n" ++ prettyString stm distResCerts :: DistEnv -> [DistInput] -> Certs diff --git a/tests/flattening/map-nested.fut b/tests/flattening/map-nested.fut index b8e871cd93..3942a7868d 100644 --- a/tests/flattening/map-nested.fut +++ b/tests/flattening/map-nested.fut @@ -2,4 +2,4 @@ -- input { [5i64,7i64] } -- output { [20i64, 35i64] } -def main = map (\n -> i64.sum (opaque (map (+2) (iota n)))) +def main = map (\n -> i64.sum (map (+2) (iota n))) diff --git a/tests/flattening/map-slice-nested.fut b/tests/flattening/map-slice-nested.fut index 6ec4f9ccfc..0b01ac7880 100644 --- a/tests/flattening/map-slice-nested.fut +++ b/tests/flattening/map-slice-nested.fut @@ -2,4 +2,4 @@ -- input { [1i64,2i64,3i64,4i64,5i64] [-5i64,7i64] [2i64,3i64] [3i64,4i64] } -- output { [-2i64, 11i64] } -def main A = map3 (\x i j -> i64.sum (opaque (map (+x) A[i:j]))) +def main A = map3 (\x i j -> i64.sum (map (+x) A[i:j])) diff --git a/tests/flattening/mapout.fut b/tests/flattening/mapout.fut new file mode 100644 index 0000000000..67ee76a39c --- /dev/null +++ b/tests/flattening/mapout.fut @@ -0,0 +1,11 @@ +-- A redomap where part of the result is not reduced. 
+-- == +-- input { [5i64,7i64] [0i64,1i64] } +-- output { [20i64, 35i64] [0i64, 1i64] } + +def main ns is = map2 (\n (i:i64) -> let is = iota n + let xs = map (+2) is + let ys = map (*i) is + in (i64.sum xs, (opaque ys)[i])) + ns is + |> unzip From 4945cd17a4984171b782844ea63a2fca0e20da2b Mon Sep 17 00:00:00 2001 From: CKuke Date: Tue, 10 Jan 2023 11:02:34 +0100 Subject: [PATCH 14/64] starting out --- src/Futhark/Pass/Flatten.hs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index e162dd7fa0..11a203c233 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -279,6 +279,23 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] fmap (subExpsRes . pure) . letSubExp "v" =<< toExp (pe64 x' + pe64 v' * pe64 s') pure $ insertIrregular ns flags offsets (distResTag res) elems' env + Update _ vname slice exp2 -> do + let [DimSlice b n s] = unSlice slice + bs <- elemArr segments env inps b + ns <- elemArr segments env inps n + ss <- elemArr segments env inps s + + + + + -- Access the slice dimensions + -- dims <- unslice slice + + + + pure mempty + -- slice_ns <- mapM (readInput segments env (toList )) + -- (_, _, repiota_elems) <- doRepIota -- II1 _ -> error $ "Unhandled BasicOp:\n" ++ prettyString e where scalarCase = From 1e801d9cd309920a77d2e140b268cc28cb680e3b Mon Sep 17 00:00:00 2001 From: Cornelius Sevald-Krause Date: Wed, 11 Jan 2023 14:56:01 +0100 Subject: [PATCH 15/64] [WIP] `Update` flattening case Compiles but does not work. 
--- src/Futhark/Pass/Flatten.hs | 62 ++++++++++++++++++++-------- src/Futhark/Pass/Flatten/Builtins.hs | 1 + 2 files changed, 46 insertions(+), 17 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 11a203c233..c85644f2ef 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -202,6 +202,28 @@ elemArr _ env inps (Var v) elemArr segments _ _ se = letExp "rep" $ BasicOp $ Replicate (segmentsShape segments) se +-- [WIP] Get the array size of irregular segments. +-- This could propbably be a pure function +segmentArr :: Segments -> DistEnv -> DistInputs -> SubExp -> Builder GPU VName +segmentArr _ env inps (Var v) + | Just v_inp <- lookup v inps = + pure $ case v_inp of + DistInputFree _ _ -> error "segmentArr: Free variables not handled (yet)" + DistInput rt _ -> case resVar rt env of + Irregular r -> irregularSegments r + Regular _ -> error "segmentArr: Regulat arrays not handled (yet)" + +-- Get the irregular representation of a var. +-- Var *must* be irregular +getIrregRep :: DistEnv -> DistInputs -> VName -> IrregularRep +getIrregRep env inps v + | Just v_inp <- lookup v inps = + case v_inp of + DistInputFree _ _ -> error "getIrregRep: Free variables not handled (yet)" + DistInput rt _ -> case resVar rt env of + Irregular r -> r + Regular _ -> error "getIrregRep: Regulat arrays not handled (yet)" + transformDistBasicOp :: Segments -> DistEnv -> @@ -279,23 +301,29 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] fmap (subExpsRes . pure) . 
letSubExp "v" =<< toExp (pe64 x' + pe64 v' * pe64 s') pure $ insertIrregular ns flags offsets (distResTag res) elems' env - Update _ vname slice exp2 -> do - let [DimSlice b n s] = unSlice slice - bs <- elemArr segments env inps b - ns <- elemArr segments env inps n - ss <- elemArr segments env inps s - - - - - -- Access the slice dimensions - -- dims <- unslice slice - - - - pure mempty - -- slice_ns <- mapM (readInput segments env (toList )) - -- (_, _, repiota_elems) <- doRepIota -- II1 + Update safety as slice vs -> + case unSlice slice of + [DimSlice x n s] -> do + vs' <- letExp "vs" =<< toExp vs -- VName of `vs` + ns <- elemArr segments env inps n + xs <- elemArr segments env inps x + ss <- elemArr segments env inps s + let (IrregularRep shape flags offsets elems) = getIrregRep env inps as -- Irregular representation of `as` + (_, _, ii1_ns) <- doRepIota ns + (_, _, ii2_ns) <- certifying (distCerts inps aux env) $ doSegIota ns + m <- arraySize 0 <$> lookupType ii2_ns + genScatter as m $ \gid -> do + segment <- letSubExp "segment" =<< eIndex ii1_ns [eSubExp gid] + v' <- letExp "v" =<< eIndex vs' [eSubExp gid] -- Value to write + n' <- letSubExp "n" =<< eIndex ii2_ns [eSubExp gid] + x' <- letSubExp "x" =<< eIndex xs [eSubExp segment] + s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] + o' <- letSubExp "o" =<< eIndex offsets [eSubExp segment] + i' <- letSubExp "i" =<< toExp (pe64 o' + pe64 x' + pe64 n' * pe64 s') -- Index to write at + pure (v', i') + pure $ insertIrregular shape flags offsets (distResTag res) as env -- ?? 
+ [DimFix n] -> error $ "Update: Single dimension fixed index unhandled.\n" ++ prettyString e + _ -> error $ "Multi dimension update unhandled:\n" ++ prettyString e _ -> error $ "Unhandled BasicOp:\n" ++ prettyString e where scalarCase = diff --git a/src/Futhark/Pass/Flatten/Builtins.hs b/src/Futhark/Pass/Flatten/Builtins.hs index 9cdd5aa59a..b93596105a 100644 --- a/src/Futhark/Pass/Flatten/Builtins.hs +++ b/src/Futhark/Pass/Flatten/Builtins.hs @@ -5,6 +5,7 @@ module Futhark.Pass.Flatten.Builtins genFlags, genSegScan, genSegRed, + genScatter, exScanAndSum, doSegIota, doPrefixSum, From 3fc3e0e435df167642a6e8fbf733252c9e10e9f0 Mon Sep 17 00:00:00 2001 From: Cornelius Sevald-Krause Date: Fri, 13 Jan 2023 20:09:00 +0100 Subject: [PATCH 16/64] [WIP] Further work on `Update` flattening Fixing parts I think was wrong before. --- src/Futhark/Pass/Flatten.hs | 36 +++++++++++++----------------------- 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index c85644f2ef..1977d12be2 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -202,27 +202,17 @@ elemArr _ env inps (Var v) elemArr segments _ _ se = letExp "rep" $ BasicOp $ Replicate (segmentsShape segments) se --- [WIP] Get the array size of irregular segments. --- This could propbably be a pure function -segmentArr :: Segments -> DistEnv -> DistInputs -> SubExp -> Builder GPU VName -segmentArr _ env inps (Var v) - | Just v_inp <- lookup v inps = - pure $ case v_inp of - DistInputFree _ _ -> error "segmentArr: Free variables not handled (yet)" - DistInput rt _ -> case resVar rt env of - Irregular r -> irregularSegments r - Regular _ -> error "segmentArr: Regulat arrays not handled (yet)" - -- Get the irregular representation of a var. 
-- Var *must* be irregular -getIrregRep :: DistEnv -> DistInputs -> VName -> IrregularRep -getIrregRep env inps v - | Just v_inp <- lookup v inps = - case v_inp of +getIrregRep :: Segments -> DistEnv -> DistInputs -> VName -> Builder GPU IrregularRep +getIrregRep _ env inps v = + case lookup v inps of + Just v_inp -> case v_inp of DistInputFree _ _ -> error "getIrregRep: Free variables not handled (yet)" DistInput rt _ -> case resVar rt env of - Irregular r -> r + Irregular r -> pure r Regular _ -> error "getIrregRep: Regulat arrays not handled (yet)" + Nothing -> error $ "getIrregRep: variable '" ++ prettyString v ++ "' not found" transformDistBasicOp :: Segments -> @@ -301,28 +291,28 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] fmap (subExpsRes . pure) . letSubExp "v" =<< toExp (pe64 x' + pe64 v' * pe64 s') pure $ insertIrregular ns flags offsets (distResTag res) elems' env - Update safety as slice vs -> + Update _ as slice v -> case unSlice slice of [DimSlice x n s] -> do - vs' <- letExp "vs" =<< toExp vs -- VName of `vs` + vs <- elemArr segments env inps v ns <- elemArr segments env inps n xs <- elemArr segments env inps x ss <- elemArr segments env inps s - let (IrregularRep shape flags offsets elems) = getIrregRep env inps as -- Irregular representation of `as` + IrregularRep shape flags offsets elems <- getIrregRep segments env inps as -- Irregular representation of `as` (_, _, ii1_ns) <- doRepIota ns (_, _, ii2_ns) <- certifying (distCerts inps aux env) $ doSegIota ns m <- arraySize 0 <$> lookupType ii2_ns - genScatter as m $ \gid -> do + elems' <- letExp "elems_scatter" <=< genScatter elems m $ \gid -> do segment <- letSubExp "segment" =<< eIndex ii1_ns [eSubExp gid] - v' <- letExp "v" =<< eIndex vs' [eSubExp gid] -- Value to write + v' <- letExp "v" =<< eIndex vs [eSubExp gid] -- Value to write n' <- letSubExp "n" =<< eIndex ii2_ns [eSubExp gid] x' <- letSubExp "x" =<< eIndex xs 
[eSubExp segment] s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] o' <- letSubExp "o" =<< eIndex offsets [eSubExp segment] i' <- letSubExp "i" =<< toExp (pe64 o' + pe64 x' + pe64 n' * pe64 s') -- Index to write at pure (v', i') - pure $ insertIrregular shape flags offsets (distResTag res) as env -- ?? - [DimFix n] -> error $ "Update: Single dimension fixed index unhandled.\n" ++ prettyString e + pure $ insertIrregular shape flags offsets (distResTag res) elems' env -- ?? + [DimFix _] -> error $ "Update: Single dimension fixed index unhandled.\n" ++ prettyString e _ -> error $ "Multi dimension update unhandled:\n" ++ prettyString e _ -> error $ "Unhandled BasicOp:\n" ++ prettyString e where From ff6ceb37b945be2d9db12af03b9a61bf7f4e6ec2 Mon Sep 17 00:00:00 2001 From: Kaffe Date: Mon, 16 Jan 2023 10:44:48 +0100 Subject: [PATCH 17/64] test case --- tests/update1.fut | 1 + 1 file changed, 1 insertion(+) create mode 100644 tests/update1.fut diff --git a/tests/update1.fut b/tests/update1.fut new file mode 100644 index 0000000000..fb5ef968d7 --- /dev/null +++ b/tests/update1.fut @@ -0,0 +1 @@ +let main is js = map2(\i j -> iota 5 with [i:j] = iota (j-i)) is js \ No newline at end of file From c05cc91e31281f6678241b0ef6607d9ae9833f17 Mon Sep 17 00:00:00 2001 From: CKuke Date: Mon, 16 Jan 2023 13:38:29 +0100 Subject: [PATCH 18/64] update1 test update --- tests/update1.fut | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/update1.fut b/tests/update1.fut index fb5ef968d7..c3ab6ebab9 100644 --- a/tests/update1.fut +++ b/tests/update1.fut @@ -1 +1 @@ -let main is js = map2(\i j -> iota 5 with [i:j] = iota (j-i)) is js \ No newline at end of file +let main is js = map2(\i j -> reduce (+) 0 (iota 5 with [i:j] = iota (j-i))) is js \ No newline at end of file From 1768e8f59c3e50c24763a05da8f318fc9a53f122 Mon Sep 17 00:00:00 2001 From: Cornelius Sevald-Krause Date: Mon, 16 Jan 2023 15:16:44 +0100 Subject: [PATCH 19/64] Fix `Update` flattening for single 
dimension slice Switch the index and value in the result of `genScatter`. --- src/Futhark/Pass/Flatten.hs | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 1977d12be2..5444ba23f6 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -298,20 +298,25 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = ns <- elemArr segments env inps n xs <- elemArr segments env inps x ss <- elemArr segments env inps s - IrregularRep shape flags offsets elems <- getIrregRep segments env inps as -- Irregular representation of `as` + -- Irregular representation of `as` + IrregularRep shape flags offsets elems <- getIrregRep segments env inps as + -- Inner indices (1 and 2) of `ns` (_, _, ii1_ns) <- doRepIota ns (_, _, ii2_ns) <- certifying (distCerts inps aux env) $ doSegIota ns + -- Number of updates to preform m <- arraySize 0 <$> lookupType ii2_ns elems' <- letExp "elems_scatter" <=< genScatter elems m $ \gid -> do segment <- letSubExp "segment" =<< eIndex ii1_ns [eSubExp gid] - v' <- letExp "v" =<< eIndex vs [eSubExp gid] -- Value to write + -- Value to write + v' <- letSubExp "v" =<< eIndex vs [eSubExp gid] n' <- letSubExp "n" =<< eIndex ii2_ns [eSubExp gid] x' <- letSubExp "x" =<< eIndex xs [eSubExp segment] s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] o' <- letSubExp "o" =<< eIndex offsets [eSubExp segment] - i' <- letSubExp "i" =<< toExp (pe64 o' + pe64 x' + pe64 n' * pe64 s') -- Index to write at - pure (v', i') - pure $ insertIrregular shape flags offsets (distResTag res) elems' env -- ?? 
+ -- Index to write `v'` at + i' <- letExp "i" =<< toExp (pe64 o' + pe64 x' + pe64 n' * pe64 s') + pure (i', v') + pure $ insertIrregular shape flags offsets (distResTag res) elems' env [DimFix _] -> error $ "Update: Single dimension fixed index unhandled.\n" ++ prettyString e _ -> error $ "Multi dimension update unhandled:\n" ++ prettyString e _ -> error $ "Unhandled BasicOp:\n" ++ prettyString e From 31e86ddac856c63441dcb9586c1ab38995fa0a54 Mon Sep 17 00:00:00 2001 From: Cornelius Sevald-Krause Date: Mon, 16 Jan 2023 15:44:54 +0100 Subject: [PATCH 20/64] Irregular flattening `Update` fixed dimension --- src/Futhark/Pass/Flatten.hs | 17 ++++++++++++++++- tests/update_dimfix.fut | 7 +++++++ 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 tests/update_dimfix.fut diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 5444ba23f6..12e295335c 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -317,7 +317,22 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = i' <- letExp "i" =<< toExp (pe64 o' + pe64 x' + pe64 n' * pe64 s') pure (i', v') pure $ insertIrregular shape flags offsets (distResTag res) elems' env - [DimFix _] -> error $ "Update: Single dimension fixed index unhandled.\n" ++ prettyString e + [DimFix n] -> do + -- Irregular representation of `as` + IrregularRep shape flags offsets elems <- getIrregRep segments env inps as + vs <- elemArr segments env inps v + ns <- elemArr segments env inps n + -- Number of updates to preform + m <- arraySize 0 <$> lookupType vs + elems' <- letExp "elems_scatter" <=< genScatter elems m $ \gid -> do + -- Value to write + v' <- letSubExp "v" =<< eIndex vs [eSubExp gid] + n' <- letSubExp "n" =<< eIndex ns [eSubExp gid] + o' <- letSubExp "o" =<< eIndex offsets [eSubExp gid] + -- Index to write `v'` at + i' <- letExp "i" =<< toExp (pe64 o' + pe64 n') + pure (i', v') + pure $ insertIrregular shape flags offsets (distResTag res) elems' env _ -> error $ 
"Multi dimension update unhandled:\n" ++ prettyString e _ -> error $ "Unhandled BasicOp:\n" ++ prettyString e where diff --git a/tests/update_dimfix.fut b/tests/update_dimfix.fut new file mode 100644 index 0000000000..1892e58e1a --- /dev/null +++ b/tests/update_dimfix.fut @@ -0,0 +1,7 @@ +let main is js = + [let n = 0 in map2(\i j -> (iota 5 with [i] = j)[n]) is js + ,let n = 1 in map2(\i j -> (iota 5 with [i] = j)[n]) is js + ,let n = 2 in map2(\i j -> (iota 5 with [i] = j)[n]) is js + ,let n = 3 in map2(\i j -> (iota 5 with [i] = j)[n]) is js + ,let n = 4 in map2(\i j -> (iota 5 with [i] = j)[n]) is js + ] From effa35397ca2cce9a2201e130c17e094c8dc26e3 Mon Sep 17 00:00:00 2001 From: Cornelius Sevald-Krause Date: Mon, 16 Jan 2023 16:35:08 +0100 Subject: [PATCH 21/64] [WIP] Unhandled multi-dim `Update` test case --- tests/update_multdim.fut | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 tests/update_multdim.fut diff --git a/tests/update_multdim.fut b/tests/update_multdim.fut new file mode 100644 index 0000000000..3823b1e4fe --- /dev/null +++ b/tests/update_multdim.fut @@ -0,0 +1,2 @@ +let main [n] (is : [n]i64) (js : [n]i64) (ass : [n][][]f64) (vss : [n][][]f64) = + map4(\i j as vs -> (copy as with [i:j,i:j] = vs)) is js ass vss From 6e35c60a850a68b8102f8b21472555a9fb60cf0f Mon Sep 17 00:00:00 2001 From: Cornelius Sevald-Krause Date: Mon, 16 Jan 2023 16:40:32 +0100 Subject: [PATCH 22/64] Format code --- src/Futhark/Pass/Flatten.hs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 12e295335c..70202d2f52 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -208,10 +208,10 @@ getIrregRep :: Segments -> DistEnv -> DistInputs -> VName -> Builder GPU Irregul getIrregRep _ env inps v = case lookup v inps of Just v_inp -> case v_inp of - DistInputFree _ _ -> error "getIrregRep: Free variables not handled (yet)" - DistInput rt _ -> case resVar rt 
env of - Irregular r -> pure r - Regular _ -> error "getIrregRep: Regulat arrays not handled (yet)" + DistInputFree _ _ -> error "getIrregRep: Free variables not handled (yet)" + DistInput rt _ -> case resVar rt env of + Irregular r -> pure r + Regular _ -> error "getIrregRep: Regulat arrays not handled (yet)" Nothing -> error $ "getIrregRep: variable '" ++ prettyString v ++ "' not found" transformDistBasicOp :: @@ -317,7 +317,7 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = i' <- letExp "i" =<< toExp (pe64 o' + pe64 x' + pe64 n' * pe64 s') pure (i', v') pure $ insertIrregular shape flags offsets (distResTag res) elems' env - [DimFix n] -> do + [DimFix n] -> do -- Irregular representation of `as` IrregularRep shape flags offsets elems <- getIrregRep segments env inps as vs <- elemArr segments env inps v @@ -333,7 +333,7 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = i' <- letExp "i" =<< toExp (pe64 o' + pe64 n') pure (i', v') pure $ insertIrregular shape flags offsets (distResTag res) elems' env - _ -> error $ "Multi dimension update unhandled:\n" ++ prettyString e + _ -> error $ "Multi dimension update unhandled:\n" ++ prettyString e _ -> error $ "Unhandled BasicOp:\n" ++ prettyString e where scalarCase = From 1d303c070a5e58702b19a30752cc3da8820930f6 Mon Sep 17 00:00:00 2001 From: Cornelius Sevald-Krause Date: Wed, 18 Jan 2023 11:30:50 +0100 Subject: [PATCH 23/64] Move `Update` tests to `tests/flattening/` --- tests/{ => flattening}/update1.fut | 0 tests/{ => flattening}/update_dimfix.fut | 0 tests/{ => flattening}/update_multdim.fut | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename tests/{ => flattening}/update1.fut (100%) rename tests/{ => flattening}/update_dimfix.fut (100%) rename tests/{ => flattening}/update_multdim.fut (100%) diff --git a/tests/update1.fut b/tests/flattening/update1.fut similarity index 100% rename from tests/update1.fut rename to tests/flattening/update1.fut diff --git a/tests/update_dimfix.fut 
b/tests/flattening/update_dimfix.fut similarity index 100% rename from tests/update_dimfix.fut rename to tests/flattening/update_dimfix.fut diff --git a/tests/update_multdim.fut b/tests/flattening/update_multdim.fut similarity index 100% rename from tests/update_multdim.fut rename to tests/flattening/update_multdim.fut From 84e3e85e49727d30e38e151f7ada67cbaac110ef Mon Sep 17 00:00:00 2001 From: Cornelius Sevald-Krause Date: Wed, 18 Jan 2023 11:38:54 +0100 Subject: [PATCH 24/64] Fully flatten replicated `elemArr` result Dubious fix, but makes `Update` case with invariant `vs` work. --- src/Futhark/Pass/Flatten.hs | 7 +++++-- tests/flattening/update_variant_is.fut | 7 +++++++ 2 files changed, 12 insertions(+), 2 deletions(-) create mode 100644 tests/flattening/update_variant_is.fut diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 70202d2f52..bedf6c358d 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -199,8 +199,11 @@ elemArr _ env inps (Var v) DistInput rt _ -> case resVar rt env of Irregular r -> irregularElems r Regular vs -> vs -elemArr segments _ _ se = - letExp "rep" $ BasicOp $ Replicate (segmentsShape segments) se +elemArr segments _ _ se = do + rep <- letExp "rep" $ BasicOp $ Replicate (segmentsShape segments) se + dims <- arrayDims <$> lookupType rep + n <- toSubExp "n" $ product $ map pe64 dims + letExp "reshape" $ BasicOp $ Reshape ReshapeArbitrary (Shape [n]) rep -- Get the irregular representation of a var. -- Var *must* be irregular diff --git a/tests/flattening/update_variant_is.fut b/tests/flattening/update_variant_is.fut new file mode 100644 index 0000000000..8c62f3ba46 --- /dev/null +++ b/tests/flattening/update_variant_is.fut @@ -0,0 +1,7 @@ +-- Test with only variant indices. 
+-- == +-- input { [0i64,3i64,1i64] [5i64,8i64,6i64] } +-- output { [28i64,13i64,23i64] } + +entry main [n] (is : [n]i64) (js : [n]i64) = + map2 (\i j -> reduce (+) 0 (iota 8 with [i:j] = iota 5)) is js From d1caa9425b3cd23638ff3c70c1afdfe30af81737 Mon Sep 17 00:00:00 2001 From: Cornelius Sevald-Krause Date: Wed, 18 Jan 2023 12:26:18 +0100 Subject: [PATCH 25/64] More `Update` flattening tests --- tests/flattening/update1.fut | 1 - tests/flattening/update_dimfix.fut | 9 +++++++++ tests/flattening/update_multdim.fut | 2 ++ tests/flattening/update_variant_vs.fut | 7 +++++++ tests/flattening/update_variant_xs.fut | 6 ++++++ 5 files changed, 24 insertions(+), 1 deletion(-) delete mode 100644 tests/flattening/update1.fut create mode 100644 tests/flattening/update_variant_vs.fut create mode 100644 tests/flattening/update_variant_xs.fut diff --git a/tests/flattening/update1.fut b/tests/flattening/update1.fut deleted file mode 100644 index c3ab6ebab9..0000000000 --- a/tests/flattening/update1.fut +++ /dev/null @@ -1 +0,0 @@ -let main is js = map2(\i j -> reduce (+) 0 (iota 5 with [i:j] = iota (j-i))) is js \ No newline at end of file diff --git a/tests/flattening/update_dimfix.fut b/tests/flattening/update_dimfix.fut index 1892e58e1a..41e31099fe 100644 --- a/tests/flattening/update_dimfix.fut +++ b/tests/flattening/update_dimfix.fut @@ -1,3 +1,12 @@ +-- Test with fixed dimension +-- == +-- input { [0,1,2,3,4] [5i64,6i64,7i64,8i64,9i64] } +-- output { [[5i64,0i64,0i64,0i64,0i64] +-- ,[1i64,6i64,1i64,1i64,1i64] +-- ,[2i64,2i64,7i64,2i64,2i64] +-- ,[3i64,3i64,3i64,8i64,3i64] +-- ,[4i64,4i64,4i64,4i64,9i64]] } + let main is js = [let n = 0 in map2(\i j -> (iota 5 with [i] = j)[n]) is js ,let n = 1 in map2(\i j -> (iota 5 with [i] = j)[n]) is js diff --git a/tests/flattening/update_multdim.fut b/tests/flattening/update_multdim.fut index 3823b1e4fe..a56fb0c807 100644 --- a/tests/flattening/update_multdim.fut +++ b/tests/flattening/update_multdim.fut @@ -1,2 +1,4 @@ +-- TODO: add 
test case(s) once this compiles. + let main [n] (is : [n]i64) (js : [n]i64) (ass : [n][][]f64) (vss : [n][][]f64) = map4(\i j as vs -> (copy as with [i:j,i:j] = vs)) is js ass vss diff --git a/tests/flattening/update_variant_vs.fut b/tests/flattening/update_variant_vs.fut new file mode 100644 index 0000000000..0452da5693 --- /dev/null +++ b/tests/flattening/update_variant_vs.fut @@ -0,0 +1,7 @@ +-- Test with only variant vs. +-- == +-- input { [3i64,3i64,3i64] } +-- output { [7i64,7i64,7i64] } + +entry main (vs : []i64) = + map (\v -> reduce (+) 0 (iota 5 with [1:4] = iota v)) vs diff --git a/tests/flattening/update_variant_xs.fut b/tests/flattening/update_variant_xs.fut new file mode 100644 index 0000000000..c923829c82 --- /dev/null +++ b/tests/flattening/update_variant_xs.fut @@ -0,0 +1,6 @@ +-- Test with only variant 'xs'. +-- == +-- input { [1i64,2i64,3i64] [3i64,3i64,3i64] } +-- output { [8i64,8i64,10i64] } + +let main is js = map2(\i j -> reduce (+) 0 (iota 5 with [i:j] = iota (j-i))) is js From 450a5d56c03b766421cfc5cf843e25d1d5f81758 Mon Sep 17 00:00:00 2001 From: Kaffe Date: Wed, 18 Jan 2023 12:42:48 +0100 Subject: [PATCH 26/64] futhark implementation of rule --- rewritefut/segupdate.fut | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 rewritefut/segupdate.fut diff --git a/rewritefut/segupdate.fut b/rewritefut/segupdate.fut new file mode 100644 index 0000000000..e3296652dc --- /dev/null +++ b/rewritefut/segupdate.fut @@ -0,0 +1,36 @@ +-- Flat-Parallel Segmented Update +-- == +-- compiled input { [1i64,2i64,3i64,1i64,2i64,1i64,2i64,3i64,4i64] [3i64,2i64,4i64] [0i64,0i64,0i64,0i64,0i64] [2i64,1i64,2i64] [0i64, 1i64, 0i64] [1i64, 1i64, 2i64] } output { [0i64,0i64,3i64,1i64,0i64,0i64,2i64,0i64,4i64] } + +let sgmSumI64 [n] (flg : [n]i64) (arr : [n]i64) : [n]i64 = + let flgs_vals = + scan ( \ (f1, x1) (f2,x2) -> + let f = f1 | f2 in + if f2 != 0 then (f, x2) + else (f, x1 + x2) ) + (0, 0i64) (zip flg arr) + let (_, 
vals) = unzip flgs_vals + in vals + +let mkFlagArray [m] (aoa_shp: [m]i64) (zero: i64) + (aoa_val: [m]i64) : []i64 = + let shp_rot = map(\i -> if i==0i64 then 0i64 else aoa_shp[i-1]) (iota m) + let shp_scn = scan (+) 0i64 shp_rot + let aoa_len = shp_scn[m-1]+aoa_shp[m-1] + let shp_ind = map2 (\shp ind -> if shp==0 then -1i64 else ind) aoa_shp shp_scn + in scatter (replicate aoa_len zero) shp_ind aoa_val + +let segUpdate [n][m][t] (xss_val : [n]i64) (shp_xss : [t]i64) (vss_val : [m]i64) + (shp_vss : [t]i64) (bs : [t]i64) (ss : [t]i64): [n]i64 = + let fvss = (mkFlagArray shp_vss 0 (1...t :> [t]i64)) :> [m]i64 + let II1 = sgmSumI64 fvss fvss |> map (\x -> x - 1) + let shp_xss_rot = map(\i -> if i==0i64 then 0i64 else shp_xss[i-1]) (iota t) + let bxss = scan (+) 0 shp_xss_rot + let II2 = sgmSumI64 fvss (replicate m 1) |> map (\x -> x - 1) + let iss = map (\i -> bxss[II1[i]] + bs[II1[i]] + (II2[i] * ss[II1[i]])) (iota m) + in scatter (copy xss_val) iss vss_val + + +let main [n][m][t] (xss_val : [n]i64) (shp_xss : [t]i64) (vss_val : [m]i64) + (shp_vss : [t]i64) (bs : [t]i64) (ss : [t]i64): [n]i64 = + segUpdate xss_val shp_xss vss_val shp_vss bs ss \ No newline at end of file From ca0ddbb82726d36fd60d9fad9d0e0ca37658ebce Mon Sep 17 00:00:00 2001 From: Cornelius Sevald-Krause Date: Wed, 18 Jan 2023 12:50:05 +0100 Subject: [PATCH 27/64] Flatten `Update` fully irregular test case --- tests/flattening/update_fully_irregular.fut | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 tests/flattening/update_fully_irregular.fut diff --git a/tests/flattening/update_fully_irregular.fut b/tests/flattening/update_fully_irregular.fut new file mode 100644 index 0000000000..2577332b05 --- /dev/null +++ b/tests/flattening/update_fully_irregular.fut @@ -0,0 +1,7 @@ +-- Fully irregular test-case +-- == +-- input { [5i64,4i64,3i64] [1i64,2i64,3i64] [3i64,3i64,3i64] } +-- output { [8i64,4i64,3i64] } + +entry main [n] (xs : [n]i64) (is : [n]i64) (js : [n]i64) = + map3 (\x i j -> 
reduce (+) 0 (iota x with [i:j] = iota (j-i))) xs is js From 1fe9ee8bb073ce612235312856191316445139c6 Mon Sep 17 00:00:00 2001 From: Cornelius Sevald-Krause Date: Wed, 18 Jan 2023 13:33:51 +0100 Subject: [PATCH 28/64] More `Update` flattening test cases Now checks all cases where `xs`, `is` and `vs` are variant/invariant except for the case where all are invariant (as that would not be irregular). --- tests/flattening/update_invariant_is.fut | 7 +++++++ tests/flattening/update_invariant_vs.fut | 8 ++++++++ tests/flattening/update_invariant_xs.fut | 7 +++++++ tests/flattening/update_variant_vs.fut | 2 +- tests/flattening/update_variant_xs.fut | 7 ++++--- 5 files changed, 27 insertions(+), 4 deletions(-) create mode 100644 tests/flattening/update_invariant_is.fut create mode 100644 tests/flattening/update_invariant_vs.fut create mode 100644 tests/flattening/update_invariant_xs.fut diff --git a/tests/flattening/update_invariant_is.fut b/tests/flattening/update_invariant_is.fut new file mode 100644 index 0000000000..31c8672602 --- /dev/null +++ b/tests/flattening/update_invariant_is.fut @@ -0,0 +1,7 @@ +-- Test with only invariant indices. +-- == +-- input { [4i64,5i64,6i64] [3i64,3i64,3i64] } +-- output { [3i64,7i64,12i64] } + +entry main [n] (xs : [n]i64) (vs : [n]i64) = + map2(\x v -> reduce (+) 0 (iota x with [1:4] = iota v)) xs vs diff --git a/tests/flattening/update_invariant_vs.fut b/tests/flattening/update_invariant_vs.fut new file mode 100644 index 0000000000..ab2a01225a --- /dev/null +++ b/tests/flattening/update_invariant_vs.fut @@ -0,0 +1,8 @@ +-- Test with only invariant 'vs'. 
+-- == +-- input { [6i64,7i64,8i64] [0i64,1i64,2i64] [5i64,6i64,7i64] } +-- output { [15i64,16i64,18i64] } + +entry main [n] (xs : [n]i64) (is : [n]i64) (js : [n]i64) = + map3(\x i j -> reduce (+) 0 (iota x with [i:j] = iota 5)) xs is js + diff --git a/tests/flattening/update_invariant_xs.fut b/tests/flattening/update_invariant_xs.fut new file mode 100644 index 0000000000..fe9dc215dc --- /dev/null +++ b/tests/flattening/update_invariant_xs.fut @@ -0,0 +1,7 @@ +-- Test with only invariant 'xs'. +-- == +-- input { [1i64,2i64,3i64] [3i64,3i64,3i64] } +-- output { [8i64,8i64,10i64] } + +entry main [n] (is : [n]i64) (js : [n]i64) = + map2(\i j -> reduce (+) 0 (iota 5 with [i:j] = iota (j-i))) is js diff --git a/tests/flattening/update_variant_vs.fut b/tests/flattening/update_variant_vs.fut index 0452da5693..3fdec27aae 100644 --- a/tests/flattening/update_variant_vs.fut +++ b/tests/flattening/update_variant_vs.fut @@ -1,4 +1,4 @@ --- Test with only variant vs. +-- Test with only variant 'vs'. -- == -- input { [3i64,3i64,3i64] } -- output { [7i64,7i64,7i64] } diff --git a/tests/flattening/update_variant_xs.fut b/tests/flattening/update_variant_xs.fut index c923829c82..89730113b3 100644 --- a/tests/flattening/update_variant_xs.fut +++ b/tests/flattening/update_variant_xs.fut @@ -1,6 +1,7 @@ -- Test with only variant 'xs'. 
-- == --- input { [1i64,2i64,3i64] [3i64,3i64,3i64] } --- output { [8i64,8i64,10i64] } +-- input { [4i64,5i64,6i64] } +-- output { [3i64,7i64,12i64] } -let main is js = map2(\i j -> reduce (+) 0 (iota 5 with [i:j] = iota (j-i))) is js +entry main [n] (xs : [n]i64) = + map (\x -> reduce (+) 0 (iota x with [1:4] = iota 3)) xs From e9bd637feaa1f04cecf20597f1efdd4bde77a17c Mon Sep 17 00:00:00 2001 From: Kaffe Date: Thu, 19 Jan 2023 13:29:09 +0100 Subject: [PATCH 29/64] updated futhark seg update --- rewritefut/segupdate.fut | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rewritefut/segupdate.fut b/rewritefut/segupdate.fut index e3296652dc..980d3702db 100644 --- a/rewritefut/segupdate.fut +++ b/rewritefut/segupdate.fut @@ -20,7 +20,7 @@ let mkFlagArray [m] (aoa_shp: [m]i64) (zero: i64) let shp_ind = map2 (\shp ind -> if shp==0 then -1i64 else ind) aoa_shp shp_scn in scatter (replicate aoa_len zero) shp_ind aoa_val -let segUpdate [n][m][t] (xss_val : [n]i64) (shp_xss : [t]i64) (vss_val : [m]i64) +let segUpdate [n][m][t] (xss_val : *[n]i64) (shp_xss : [t]i64) (vss_val : [m]i64) (shp_vss : [t]i64) (bs : [t]i64) (ss : [t]i64): [n]i64 = let fvss = (mkFlagArray shp_vss 0 (1...t :> [t]i64)) :> [m]i64 let II1 = sgmSumI64 fvss fvss |> map (\x -> x - 1) @@ -28,9 +28,9 @@ let segUpdate [n][m][t] (xss_val : [n]i64) (shp_xss : [t]i64) (vss_val : [m]i64 let bxss = scan (+) 0 shp_xss_rot let II2 = sgmSumI64 fvss (replicate m 1) |> map (\x -> x - 1) let iss = map (\i -> bxss[II1[i]] + bs[II1[i]] + (II2[i] * ss[II1[i]])) (iota m) - in scatter (copy xss_val) iss vss_val + in scatter xss_val iss vss_val -let main [n][m][t] (xss_val : [n]i64) (shp_xss : [t]i64) (vss_val : [m]i64) +let main [n][m][t] (xss_val : *[n]i64) (shp_xss : [t]i64) (vss_val : [m]i64) (shp_vss : [t]i64) (bs : [t]i64) (ss : [t]i64): [n]i64 = segUpdate xss_val shp_xss vss_val shp_vss bs ss \ No newline at end of file From 5c7e03979995cb84d75d01b3d6a1d155513ae57a Mon Sep 17 00:00:00 2001 From: 
Cornelius Sevald-Krause Date: Thu, 19 Jan 2023 13:54:39 +0100 Subject: [PATCH 30/64] Variant stride in `update_fully_irregular.fut` test --- tests/flattening/update_fully_irregular.fut | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/flattening/update_fully_irregular.fut b/tests/flattening/update_fully_irregular.fut index 2577332b05..cbd5a98f93 100644 --- a/tests/flattening/update_fully_irregular.fut +++ b/tests/flattening/update_fully_irregular.fut @@ -1,7 +1,7 @@ -- Fully irregular test-case -- == --- input { [5i64,4i64,3i64] [1i64,2i64,3i64] [3i64,3i64,3i64] } --- output { [8i64,4i64,3i64] } +-- input { [5i64,6i64,7i64] [2i64,3i64,1i64] [3i64,1i64,2i64] [5i64,6i64,3i64] [1i64,2i64,3i64] } +-- output { [4i64,9i64,19i64] } -entry main [n] (xs : [n]i64) (is : [n]i64) (js : [n]i64) = - map3 (\x i j -> reduce (+) 0 (iota x with [i:j] = iota (j-i))) xs is js +entry main [n] (xs : [n]i64) (vs : [n]i64) (is : [n]i64) (js : [n]i64) (ss: [n]i64) = + map5 (\x v i j s -> reduce (+) 0 (iota x with [i:j:s] = iota v)) xs vs is js ss From cf924d6d9777915dfd4032b2b34fd4be6dad7e92 Mon Sep 17 00:00:00 2001 From: Cornelius Sevald-Krause Date: Thu, 19 Jan 2023 14:44:40 +0100 Subject: [PATCH 31/64] Rename variables Rename `ii1_ns` and `ii2_ns` to `ii1_vss` and `ii2_vss`. 
--- src/Futhark/Pass/Flatten.hs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index bedf6c358d..2b88a407f6 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -304,15 +304,15 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = -- Irregular representation of `as` IrregularRep shape flags offsets elems <- getIrregRep segments env inps as -- Inner indices (1 and 2) of `ns` - (_, _, ii1_ns) <- doRepIota ns - (_, _, ii2_ns) <- certifying (distCerts inps aux env) $ doSegIota ns + (_, _, ii1_vss) <- doRepIota ns + (_, _, ii2_vss) <- certifying (distCerts inps aux env) $ doSegIota ns -- Number of updates to preform - m <- arraySize 0 <$> lookupType ii2_ns + m <- arraySize 0 <$> lookupType ii2_vss elems' <- letExp "elems_scatter" <=< genScatter elems m $ \gid -> do - segment <- letSubExp "segment" =<< eIndex ii1_ns [eSubExp gid] + segment <- letSubExp "segment" =<< eIndex ii1_vss [eSubExp gid] -- Value to write v' <- letSubExp "v" =<< eIndex vs [eSubExp gid] - n' <- letSubExp "n" =<< eIndex ii2_ns [eSubExp gid] + n' <- letSubExp "n" =<< eIndex ii2_vss [eSubExp gid] x' <- letSubExp "x" =<< eIndex xs [eSubExp segment] s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] o' <- letSubExp "o" =<< eIndex offsets [eSubExp segment] From 63cae7275f9da8b1417ed675ad28b29b52f8f86d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 1 Feb 2023 11:28:29 +0100 Subject: [PATCH 32/64] Disable for now. --- tests/flattening/update_multdim.fut | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/flattening/update_multdim.fut b/tests/flattening/update_multdim.fut index a56fb0c807..6f1886097f 100644 --- a/tests/flattening/update_multdim.fut +++ b/tests/flattening/update_multdim.fut @@ -1,4 +1,6 @@ -- TODO: add test case(s) once this compiles. 
+-- == +-- tags { disable } let main [n] (is : [n]i64) (js : [n]i64) (ass : [n][][]f64) (vss : [n][][]f64) = map4(\i j as vs -> (copy as with [i:j,i:j] = vs)) is js ass vss From 5b01fa15b33ec8709f552b7162a17baed18bd89c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 1 Feb 2023 13:47:38 +0100 Subject: [PATCH 33/64] Refactor to use readInputs. --- src/Futhark/Pass/Flatten.hs | 83 ++++++++++++-------------- src/Futhark/Pass/Flatten/Builtins.hs | 5 +- src/Futhark/Pass/Flatten/Distribute.hs | 3 +- 3 files changed, 43 insertions(+), 48 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 2b88a407f6..e570ed94a7 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -200,9 +200,9 @@ elemArr _ env inps (Var v) Irregular r -> irregularElems r Regular vs -> vs elemArr segments _ _ se = do - rep <- letExp "rep" $ BasicOp $ Replicate (segmentsShape segments) se + rep <- letExp "rep" $ BasicOp $ Replicate (segmentsShape segments) se dims <- arrayDims <$> lookupType rep - n <- toSubExp "n" $ product $ map pe64 dims + n <- toSubExp "n" $ product $ map pe64 dims letExp "reshape" $ BasicOp $ Reshape ReshapeArbitrary (Shape [n]) rep -- Get the irregular representation of a var. @@ -294,49 +294,42 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] fmap (subExpsRes . pure) . 
letSubExp "v" =<< toExp (pe64 x' + pe64 v' * pe64 s') pure $ insertIrregular ns flags offsets (distResTag res) elems' env - Update _ as slice v -> - case unSlice slice of - [DimSlice x n s] -> do - vs <- elemArr segments env inps v - ns <- elemArr segments env inps n - xs <- elemArr segments env inps x - ss <- elemArr segments env inps s - -- Irregular representation of `as` - IrregularRep shape flags offsets elems <- getIrregRep segments env inps as - -- Inner indices (1 and 2) of `ns` - (_, _, ii1_vss) <- doRepIota ns - (_, _, ii2_vss) <- certifying (distCerts inps aux env) $ doSegIota ns - -- Number of updates to preform - m <- arraySize 0 <$> lookupType ii2_vss - elems' <- letExp "elems_scatter" <=< genScatter elems m $ \gid -> do - segment <- letSubExp "segment" =<< eIndex ii1_vss [eSubExp gid] - -- Value to write - v' <- letSubExp "v" =<< eIndex vs [eSubExp gid] - n' <- letSubExp "n" =<< eIndex ii2_vss [eSubExp gid] - x' <- letSubExp "x" =<< eIndex xs [eSubExp segment] - s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] - o' <- letSubExp "o" =<< eIndex offsets [eSubExp segment] - -- Index to write `v'` at - i' <- letExp "i" =<< toExp (pe64 o' + pe64 x' + pe64 n' * pe64 s') - pure (i', v') - pure $ insertIrregular shape flags offsets (distResTag res) elems' env - [DimFix n] -> do - -- Irregular representation of `as` - IrregularRep shape flags offsets elems <- getIrregRep segments env inps as - vs <- elemArr segments env inps v - ns <- elemArr segments env inps n - -- Number of updates to preform - m <- arraySize 0 <$> lookupType vs - elems' <- letExp "elems_scatter" <=< genScatter elems m $ \gid -> do - -- Value to write - v' <- letSubExp "v" =<< eIndex vs [eSubExp gid] - n' <- letSubExp "n" =<< eIndex ns [eSubExp gid] - o' <- letSubExp "o" =<< eIndex offsets [eSubExp gid] - -- Index to write `v'` at - i' <- letExp "i" =<< toExp (pe64 o' + pe64 n') - pure (i', v') - pure $ insertIrregular shape flags offsets (distResTag res) elems' env - _ -> error $ "Multi 
dimension update unhandled:\n" ++ prettyString e + Update _ as (Slice [DimSlice x n s]) (Var v) -> do + ns <- elemArr segments env inps n + -- Irregular representation of `as` + IrregularRep shape flags offsets elems <- getIrregRep segments env inps as + -- Inner indices (1 and 2) of `ns` + (_, _, ii1_vss) <- doRepIota ns + (_, _, ii2_vss) <- certifying (distCerts inps aux env) $ doSegIota ns + -- Number of updates to preform + m <- arraySize 0 <$> lookupType ii2_vss + elems' <- letExp "elems_scatter" <=< genScatter elems m $ \gid -> do + seg_i <- letSubExp "seg_i" =<< eIndex ii1_vss [eSubExp gid] + in_seg_i <- letSubExp "in_seg_i" =<< eIndex ii2_vss [eSubExp gid] + readInputs segments env [seg_i] $ filter ((/= as) . fst) inps + -- Value to write + v' <- letSubExp "v" =<< eIndex v [eSubExp in_seg_i] + o' <- letSubExp "o" =<< eIndex offsets [eSubExp seg_i] + -- Index to write `v'` at + i <- letExp "i" =<< toExp (pe64 o' + pe64 x + pe64 in_seg_i * pe64 s) + pure (i, v') + pure $ insertIrregular shape flags offsets (distResTag res) elems' env + Update _ as (Slice [DimFix n]) v -> do + -- Irregular representation of `as` + IrregularRep shape flags offsets elems <- getIrregRep segments env inps as + vs <- elemArr segments env inps v + ns <- elemArr segments env inps n + -- Number of updates to preform + m <- arraySize 0 <$> lookupType vs + elems' <- letExp "elems_scatter" <=< genScatter elems m $ \gid -> do + -- Value to write + v' <- letSubExp "v" =<< eIndex vs [eSubExp gid] + n' <- letSubExp "n" =<< eIndex ns [eSubExp gid] + o' <- letSubExp "o" =<< eIndex offsets [eSubExp gid] + -- Index to write `v'` at + i' <- letExp "i" =<< toExp (pe64 o' + pe64 n') + pure (i', v') + pure $ insertIrregular shape flags offsets (distResTag res) elems' env _ -> error $ "Unhandled BasicOp:\n" ++ prettyString e where scalarCase = diff --git a/src/Futhark/Pass/Flatten/Builtins.hs b/src/Futhark/Pass/Flatten/Builtins.hs index b93596105a..2c462008b7 100644 --- 
a/src/Futhark/Pass/Flatten/Builtins.hs +++ b/src/Futhark/Pass/Flatten/Builtins.hs @@ -291,7 +291,7 @@ prefixSumBuiltin = buildingBuiltin $ do flatteningBuiltins :: [FunDef GPU] flatteningBuiltins = [segIotaBuiltin, repIotaBuiltin, prefixSumBuiltin] --- | Perform a segmented iota. Returns flags,offsets,data. +-- | @[0,1,2,0,1,0,1,2,3,4,...]@. Returns @(flags,offsets,elems)@. doSegIota :: VName -> Builder GPU (VName, VName, VName) doSegIota ns = do ns_t <- lookupType ns @@ -315,7 +315,8 @@ doSegIota ns = do (Safe, mempty, mempty) pure (flags, offsets, elems) --- | Returns @(flags, offsets, elems)@. +-- | Produces @[0,0,0,1,1,2,2,2,...]@. Returns @(flags, offsets, +-- elems)@. doRepIota :: VName -> Builder GPU (VName, VName, VName) doRepIota ns = do ns_t <- lookupType ns diff --git a/src/Futhark/Pass/Flatten/Distribute.hs b/src/Futhark/Pass/Flatten/Distribute.hs index 63fb3987b7..ef44bf4956 100644 --- a/src/Futhark/Pass/Flatten/Distribute.hs +++ b/src/Futhark/Pass/Flatten/Distribute.hs @@ -16,6 +16,7 @@ import Data.List qualified as L import Data.Map qualified as M import Data.Maybe (fromMaybe, mapMaybe) import Futhark.IR.SOACS +import Futhark.Util (nubOrd) import Futhark.Util.Pretty newtype ResTag = ResTag Int @@ -163,7 +164,7 @@ distributeMap outer_scope map_pat w arrs lam = $ used_free stm' = DistStm - (used_free_types <> used_free) + (nubOrd $ used_free_types <> used_free) (zipWith DistResult new_tags $ map distType $ patTypes pat) stm in ((ResTag $ tag + length new_tags, avail_inputs'), stm') From 49227a2d12c58b540d15f66490a8f9cad996e23a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 1 Feb 2023 14:03:45 +0100 Subject: [PATCH 34/64] Refactor slightly. 
--- src/Futhark/Pass/Flatten.hs | 46 +++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index e570ed94a7..ab26a6be8d 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -294,26 +294,32 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] fmap (subExpsRes . pure) . letSubExp "v" =<< toExp (pe64 x' + pe64 v' * pe64 s') pure $ insertIrregular ns flags offsets (distResTag res) elems' env - Update _ as (Slice [DimSlice x n s]) (Var v) -> do - ns <- elemArr segments env inps n - -- Irregular representation of `as` - IrregularRep shape flags offsets elems <- getIrregRep segments env inps as - -- Inner indices (1 and 2) of `ns` - (_, _, ii1_vss) <- doRepIota ns - (_, _, ii2_vss) <- certifying (distCerts inps aux env) $ doSegIota ns - -- Number of updates to preform - m <- arraySize 0 <$> lookupType ii2_vss - elems' <- letExp "elems_scatter" <=< genScatter elems m $ \gid -> do - seg_i <- letSubExp "seg_i" =<< eIndex ii1_vss [eSubExp gid] - in_seg_i <- letSubExp "in_seg_i" =<< eIndex ii2_vss [eSubExp gid] - readInputs segments env [seg_i] $ filter ((/= as) . 
fst) inps - -- Value to write - v' <- letSubExp "v" =<< eIndex v [eSubExp in_seg_i] - o' <- letSubExp "o" =<< eIndex offsets [eSubExp seg_i] - -- Index to write `v'` at - i <- letExp "i" =<< toExp (pe64 o' + pe64 x + pe64 in_seg_i * pe64 s) - pure (i, v') - pure $ insertIrregular shape flags offsets (distResTag res) elems' env + Update _ as slice@(Slice [DimSlice x n s]) (Var v) + | Just (DistInput _ as_t) <- lookup as inps -> do + ns <- elemArr segments env inps n + -- Irregular representation of `as` + IrregularRep shape flags offsets elems <- getIrregRep segments env inps as + -- Inner indices (1 and 2) of `ns` + (_, _, ii1_vss) <- doRepIota ns + (_, _, ii2_vss) <- certifying (distCerts inps aux env) $ doSegIota ns + -- Number of updates to perform + m <- arraySize 0 <$> lookupType ii2_vss + elems' <- letExp "elems_scatter" <=< genScatter elems m $ \gid -> do + seg_i <- letSubExp "seg_i" =<< eIndex ii1_vss [eSubExp gid] + in_seg_i <- letSubExp "in_seg_i" =<< eIndex ii2_vss [eSubExp gid] + readInputs segments env [seg_i] $ filter ((/= as) . fst) inps + let slice' = fmap pe64 slice + flat_i = + flattenIndex + (map pe64 $ arrayDims as_t) + (fixSlice slice' [pe64 in_seg_i]) + -- Value to write + v' <- letSubExp "v" =<< eIndex v [eSubExp in_seg_i] + o' <- letSubExp "o" =<< eIndex offsets [eSubExp seg_i] + -- Index to write `v'` at + i <- letExp "i" =<< toExp (pe64 o' + flat_i) + pure (i, v') + pure $ insertIrregular shape flags offsets (distResTag res) elems' env Update _ as (Slice [DimFix n]) v -> do -- Irregular representation of `as` IrregularRep shape flags offsets elems <- getIrregRep segments env inps as From 8b916b33699cf60c08b5e237c174b416415f53e6 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 1 Feb 2023 14:18:20 +0100 Subject: [PATCH 35/64] Perhaps this makes multiple dimensions work. 
--- src/Futhark/Pass/Flatten.hs | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index ab26a6be8d..6230512e7f 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -294,9 +294,16 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] fmap (subExpsRes . pure) . letSubExp "v" =<< toExp (pe64 x' + pe64 v' * pe64 s') pure $ insertIrregular ns flags offsets (distResTag res) elems' env - Update _ as slice@(Slice [DimSlice x n s]) (Var v) + Update _ as slice (Var v) | Just (DistInput _ as_t) <- lookup as inps -> do - ns <- elemArr segments env inps n + ns <- letExp "slice_sizes" + <=< renameExp + <=< segMap (shapeDims (segmentsShape segments)) + $ \is -> do + readInputs segments env is $ + filter ((`elem` sliceDims slice) . Var . fst) inps + n <- letSubExp "n" <=< toExp $ product $ map pe64 $ sliceDims slice + pure [subExpRes n] -- Irregular representation of `as` IrregularRep shape flags offsets elems <- getIrregRep segments env inps as -- Inner indices (1 and 2) of `ns` @@ -304,7 +311,7 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = (_, _, ii2_vss) <- certifying (distCerts inps aux env) $ doSegIota ns -- Number of updates to perform m <- arraySize 0 <$> lookupType ii2_vss - elems' <- letExp "elems_scatter" <=< genScatter elems m $ \gid -> do + elems' <- letExp "elems_scatter" <=< renameExp <=< genScatter elems m $ \gid -> do seg_i <- letSubExp "seg_i" =<< eIndex ii1_vss [eSubExp gid] in_seg_i <- letSubExp "in_seg_i" =<< eIndex ii2_vss [eSubExp gid] readInputs segments env [seg_i] $ filter ((/= as) . 
fst) inps @@ -314,28 +321,16 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = (map pe64 $ arrayDims as_t) (fixSlice slice' [pe64 in_seg_i]) -- Value to write - v' <- letSubExp "v" =<< eIndex v [eSubExp in_seg_i] + v_t <- lookupType v + v' <- + if primType v_t + then pure $ Var v + else letSubExp "v" =<< eIndex v [eSubExp in_seg_i] o' <- letSubExp "o" =<< eIndex offsets [eSubExp seg_i] -- Index to write `v'` at i <- letExp "i" =<< toExp (pe64 o' + flat_i) pure (i, v') pure $ insertIrregular shape flags offsets (distResTag res) elems' env - Update _ as (Slice [DimFix n]) v -> do - -- Irregular representation of `as` - IrregularRep shape flags offsets elems <- getIrregRep segments env inps as - vs <- elemArr segments env inps v - ns <- elemArr segments env inps n - -- Number of updates to preform - m <- arraySize 0 <$> lookupType vs - elems' <- letExp "elems_scatter" <=< genScatter elems m $ \gid -> do - -- Value to write - v' <- letSubExp "v" =<< eIndex vs [eSubExp gid] - n' <- letSubExp "n" =<< eIndex ns [eSubExp gid] - o' <- letSubExp "o" =<< eIndex offsets [eSubExp gid] - -- Index to write `v'` at - i' <- letExp "i" =<< toExp (pe64 o' + pe64 n') - pure (i', v') - pure $ insertIrregular shape flags offsets (distResTag res) elems' env _ -> error $ "Unhandled BasicOp:\n" ++ prettyString e where scalarCase = From 211effa3dacf6d0ebe004c16e960378979beca61 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 1 Feb 2023 16:01:39 +0100 Subject: [PATCH 36/64] Handle general Updates. 
--- src/Futhark/Construct.hs | 4 ++- src/Futhark/Pass/Flatten.hs | 50 ++++++++++++++++++++------ src/Futhark/Pass/Flatten/Distribute.hs | 3 ++ tests/flattening/update_mixdim.fut | 12 +++++++ tests/flattening/update_multdim.fut | 11 ++++-- 5 files changed, 65 insertions(+), 15 deletions(-) create mode 100644 tests/flattening/update_mixdim.fut diff --git a/src/Futhark/Construct.hs b/src/Futhark/Construct.hs index 335d0b880f..abb8d9f203 100644 --- a/src/Futhark/Construct.hs +++ b/src/Futhark/Construct.hs @@ -416,8 +416,10 @@ eOutOfBounds arr is = do BinOp LogOr less_than_zero greater_than_size foldBinOp LogOr (constant False) =<< zipWithM checkDim ws is' --- | The array element at this index. +-- | The array element at this index. Returns array unmodified if +-- indexes are null (does not even need to be an array in that case). eIndex :: MonadBuilder m => VName -> [m (Exp (Rep m))] -> m (Exp (Rep m)) +eIndex arr [] = eSubExp $ Var arr eIndex arr is = do is' <- mapM (letSubExp "i" =<<) is arr_t <- lookupType arr diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 6230512e7f..31658db74a 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -31,6 +31,8 @@ import Futhark.Pass.Flatten.Distribute import Futhark.Tools import Futhark.Transform.Rename import Futhark.Transform.Substitute +import Futhark.Util.IntegralExp +import Prelude hiding (rem) data FlattenEnv = FlattenEnv @@ -206,12 +208,39 @@ elemArr segments _ _ se = do letExp "reshape" $ BasicOp $ Reshape ReshapeArbitrary (Shape [n]) rep -- Get the irregular representation of a var. 
--- Var *must* be irregular getIrregRep :: Segments -> DistEnv -> DistInputs -> VName -> Builder GPU IrregularRep -getIrregRep _ env inps v = +getIrregRep segments env inps v = case lookup v inps of Just v_inp -> case v_inp of - DistInputFree _ _ -> error "getIrregRep: Free variables not handled (yet)" + DistInputFree arr t -> do + arr_t <- lookupType arr + segment_size <- + letSubExp "reg_seg_size" <=< toExp $ product $ map pe64 $ arrayDims t + segments_arr <- + letExp "reg_segments" . BasicOp $ + Replicate (segmentsShape segments) segment_size + num_elems <- + letSubExp "reg_num_elems" <=< toExp $ product $ map pe64 $ arrayDims arr_t + elems <- + letExp "reg_elems" . BasicOp $ + Reshape ReshapeArbitrary (Shape [num_elems]) arr + flags <- letExp "reg_flags" <=< segMap (Solo num_elems) $ \(Solo i) -> do + flag <- letSubExp "flag" <=< toExp $ (pe64 i `rem` pe64 segment_size) .==. 0 + pure [subExpRes flag] + offsets <- letExp "reg_offsets" <=< segMap (shapeDims (segmentsShape segments)) $ \is -> do + let flat_seg_i = + flattenIndex + (map pe64 (shapeDims (segmentsShape segments))) + (map pe64 is) + offset <- letSubExp "offset" <=< toExp $ flat_seg_i * pe64 segment_size + pure [subExpRes offset] + pure $ + IrregularRep + { irregularSegments = segments_arr, + irregularFlags = flags, + irregularOffsets = offsets, + irregularElems = elems + } DistInput rt _ -> case resVar rt env of Irregular r -> pure r Regular _ -> error "getIrregRep: Regulat arrays not handled (yet)" @@ -295,7 +324,7 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = fmap (subExpsRes . pure) . 
letSubExp "v" =<< toExp (pe64 x' + pe64 v' * pe64 s') pure $ insertIrregular ns flags offsets (distResTag res) elems' env Update _ as slice (Var v) - | Just (DistInput _ as_t) <- lookup as inps -> do + | Just as_t <- distInputType <$> lookup as inps -> do ns <- letExp "slice_sizes" <=< renameExp <=< segMap (shapeDims (segmentsShape segments)) @@ -315,17 +344,16 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = seg_i <- letSubExp "seg_i" =<< eIndex ii1_vss [eSubExp gid] in_seg_i <- letSubExp "in_seg_i" =<< eIndex ii2_vss [eSubExp gid] readInputs segments env [seg_i] $ filter ((/= as) . fst) inps - let slice' = fmap pe64 slice + v_t <- lookupType v + let in_seg_is = + unflattenIndex (map pe64 (arrayDims v_t)) (pe64 in_seg_i) + slice' = fmap pe64 slice flat_i = flattenIndex (map pe64 $ arrayDims as_t) - (fixSlice slice' [pe64 in_seg_i]) + (fixSlice slice' in_seg_is) -- Value to write - v_t <- lookupType v - v' <- - if primType v_t - then pure $ Var v - else letSubExp "v" =<< eIndex v [eSubExp in_seg_i] + v' <- letSubExp "v" =<< eIndex v (map toExp in_seg_is) o' <- letSubExp "o" =<< eIndex offsets [eSubExp seg_i] -- Index to write `v'` at i <- letExp "i" =<< toExp (pe64 o' + flat_i) diff --git a/src/Futhark/Pass/Flatten/Distribute.hs b/src/Futhark/Pass/Flatten/Distribute.hs index ef44bf4956..73662b9429 100644 --- a/src/Futhark/Pass/Flatten/Distribute.hs +++ b/src/Futhark/Pass/Flatten/Distribute.hs @@ -6,6 +6,7 @@ module Futhark.Pass.Flatten.Distribute DistInput (..), DistInputs, DistType (..), + distInputType, DistResult (..), ResTag, ) @@ -34,6 +35,8 @@ data DistInput type DistInputs = [(VName, DistInput)] +-- | The type of a 'DistInput'. This corresponds to the parameter +-- type of the original map nest. 
distInputType :: DistInput -> Type distInputType (DistInputFree _ t) = t distInputType (DistInput _ t) = t diff --git a/tests/flattening/update_mixdim.fut b/tests/flattening/update_mixdim.fut new file mode 100644 index 0000000000..77b6aa16bf --- /dev/null +++ b/tests/flattening/update_mixdim.fut @@ -0,0 +1,12 @@ +-- Mixing slices and indexes in complex ways. +-- == +-- input { [0i64,1i64] +-- [2i64,3i64] +-- [[[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64]], +-- [[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64]]] +-- [[0f64,1f64],[4f64,5f64]] +-- } +-- output { [91.0, 99.0] } + +let main [n] (is : [n]i64) (js : [n]i64) (ass : [n][][]f64) (vs : [n][]f64) = + map4(\i j as vs -> f64.sum(flatten(copy as with [i,i:j] = vs))) is js ass vs diff --git a/tests/flattening/update_multdim.fut b/tests/flattening/update_multdim.fut index 6f1886097f..f90aa3f2b3 100644 --- a/tests/flattening/update_multdim.fut +++ b/tests/flattening/update_multdim.fut @@ -1,6 +1,11 @@ --- TODO: add test case(s) once this compiles. -- == --- tags { disable } +-- input { [0i64,1i64] +-- [2i64,3i64] +-- [[[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64]], +-- [[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64],[9f64,9f64,9f64,9f64]]] +-- [[[0f64,1f64],[2f64,3f64]],[[4f64,5f64],[6f64,7f64]]] +-- } +-- output { [78.0, 94.0] } let main [n] (is : [n]i64) (js : [n]i64) (ass : [n][][]f64) (vss : [n][][]f64) = - map4(\i j as vs -> (copy as with [i:j,i:j] = vs)) is js ass vss + map4(\i j as vs -> f64.sum(flatten(copy as with [i:j,i:j] = vs))) is js ass vss From 8d4ea7263f4f5c0ad2ba06749a04198c84e58689 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 2 Feb 2023 12:17:21 +0100 Subject: [PATCH 37/64] Catch this error. 
--- src/Futhark/Pass/Flatten.hs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 31658db74a..24e487d1d2 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -359,6 +359,8 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = i <- letExp "i" =<< toExp (pe64 o' + flat_i) pure (i, v') pure $ insertIrregular shape flags offsets (distResTag res) elems' env + | otherwise -> + error "Flattening update: destination is not input." _ -> error $ "Unhandled BasicOp:\n" ++ prettyString e where scalarCase = From 9658a97a4f9763796ea6918feb4c821d74d77bea Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 17 Feb 2023 16:42:25 +0100 Subject: [PATCH 38/64] Handle array results. --- src/Futhark/Pass/Flatten.hs | 18 +++++++++++++----- src/Futhark/Pass/Flatten/Distribute.hs | 16 ++++++++++------ tests/flattening/range-irreg-stride.fut | 7 +++++++ 3 files changed, 30 insertions(+), 11 deletions(-) create mode 100644 tests/flattening/range-irreg-stride.fut diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 24e487d1d2..1d8daa8859 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -309,19 +309,22 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = ns <- elemArr segments env inps n (flags, offsets, elems) <- certifying (distCerts inps aux env) $ doSegIota ns pure $ insertIrregular ns flags offsets (distResTag res) elems env - Iota n x s Int64 -> do + Iota n x s it -> do ns <- elemArr segments env inps n xs <- elemArr segments env inps x ss <- elemArr segments env inps s (flags, offsets, elems) <- certifying (distCerts inps aux env) $ doSegIota ns (_, _, repiota_elems) <- doRepIota ns m <- arraySize 0 <$> lookupType elems - elems' <- letExp "elems_fixed" <=< segMap (Solo m) $ \(Solo i) -> do + elems' <- letExp "iota_elems_fixed" <=< segMap (Solo m) $ \(Solo i) -> do segment <- letSubExp "segment" =<< eIndex repiota_elems 
[eSubExp i] v' <- letSubExp "v" =<< eIndex elems [eSubExp i] x' <- letSubExp "x" =<< eIndex xs [eSubExp segment] s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] - fmap (subExpsRes . pure) . letSubExp "v" =<< toExp (pe64 x' + pe64 v' * pe64 s') + fmap (subExpsRes . pure) . letSubExp "v" <=< toExp $ + primExpFromSubExp (IntType it) x' + ~+~ sExt it (untyped (pe64 v')) + ~*~ primExpFromSubExp (IntType it) s' pure $ insertIrregular ns flags offsets (distResTag res) elems' env Update _ as slice (Var v) | Just as_t <- distInputType <$> lookup as inps -> do @@ -447,11 +450,16 @@ distResCerts env = Certs . map f transformDistributed :: Segments -> Distributed -> Builder GPU () transformDistributed segments (Distributed dstms resmap) = do env <- foldM (transformDistStm segments) mempty dstms - forM_ (M.toList resmap) $ \(rt, (cs_inps, v)) -> + forM_ (M.toList resmap) $ \(rt, (cs_inps, v, v_t)) -> certifying (distResCerts env cs_inps) $ case resVar rt env of Regular v' -> letBindNames [v] $ BasicOp $ SubExp $ Var v' - Irregular {} -> error $ "Result is irregular: " ++ prettyString v + Irregular irreg -> do + -- It might have an irregular representation, but we know + -- that it is actually regular because it is a result. + let shape = segmentsShape segments <> arrayShape v_t + letBindNames [v] $ + BasicOp (Reshape ReshapeArbitrary shape (irregularElems irreg)) transformStm :: Scope SOACS -> Stm SOACS -> PassM (Stms GPU) transformStm scope (Let pat _ (Op (Screma w arrs form))) diff --git a/src/Futhark/Pass/Flatten/Distribute.hs b/src/Futhark/Pass/Flatten/Distribute.hs index 73662b9429..0398b559c1 100644 --- a/src/Futhark/Pass/Flatten/Distribute.hs +++ b/src/Futhark/Pass/Flatten/Distribute.hs @@ -64,7 +64,11 @@ data DistStm = DistStm deriving (Eq, Ord, Show) -- | First element of tuple are certificates for this result. -type ResMap = M.Map ResTag ([DistInput], VName) +-- +-- Second is the name to which it should be bound. +-- +-- Third is the element type (i.e.
excluding segments). +type ResMap = M.Map ResTag ([DistInput], VName, Type) data Distributed = Distributed [DistStm] ResMap deriving (Eq, Ord, Show) @@ -113,13 +117,13 @@ resultMap :: [(VName, DistInput)] -> [DistStm] -> Pat Type -> Result -> ResMap resultMap avail_inputs stms pat res = mconcat $ map f stms where f stm = - foldMap g $ zip (distStmResult stm) (patNames (stmPat (distStm stm))) - g (DistResult rt _, v) = - maybe mempty (M.singleton rt) $ findRes v - findRes v = do + foldMap g $ zip (distStmResult stm) (patElems (stmPat (distStm stm))) + g (DistResult rt _, pe) = + maybe mempty (M.singleton rt) $ findRes pe + findRes (PatElem v v_t) = do (SubExpRes cs _, pv) <- L.find ((Var v ==) . resSubExp . fst) $ zip res $ patNames pat - Just (map findCert $ unCerts cs, pv) + Just (map findCert $ unCerts cs, pv, v_t) findCert v = fromMaybe (DistInputFree v (Prim Unit)) $ lookup v avail_inputs splitIrregDims :: Names -> Type -> (Rank, Type) diff --git a/tests/flattening/range-irreg-stride.fut b/tests/flattening/range-irreg-stride.fut new file mode 100644 index 0000000000..464fd4acef --- /dev/null +++ b/tests/flattening/range-irreg-stride.fut @@ -0,0 +1,7 @@ +-- == +-- input { 10i64 [1,2] } +-- output { [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], +-- [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]] +-- } + +def main k = map (\s -> (0..s.. [k]i32) From 7d264e4ce5a70ab85fccb7c19f4931b166aad1ab Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 17 Feb 2023 17:13:06 +0100 Subject: [PATCH 39/64] Flattening of Copy. 
--- src/Futhark/Pass/Flatten.hs | 93 ++++++++++++++++++++---------- tests/flattening/update_dimfix.fut | 46 +++++++++++---- 2 files changed, 96 insertions(+), 43 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 1d8daa8859..5f921554c8 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -207,44 +207,55 @@ elemArr segments _ _ se = do n <- toSubExp "n" $ product $ map pe64 dims letExp "reshape" $ BasicOp $ Reshape ReshapeArbitrary (Shape [n]) rep +mkIrregFromReg :: + Segments -> + VName -> + Builder GPU IrregularRep +mkIrregFromReg segments arr = do + arr_t <- lookupType arr + segment_size <- + letSubExp "reg_seg_size" <=< toExp . product . map pe64 $ + drop (shapeRank (segmentsShape segments)) (arrayDims arr_t) + segments_arr <- + letExp "reg_segments" . BasicOp $ + Replicate (segmentsShape segments) segment_size + num_elems <- + letSubExp "reg_num_elems" <=< toExp $ product $ map pe64 $ arrayDims arr_t + elems <- + letExp "reg_elems" . BasicOp $ + Reshape ReshapeArbitrary (Shape [num_elems]) arr + flags <- letExp "reg_flags" <=< segMap (Solo num_elems) $ \(Solo i) -> do + flag <- letSubExp "flag" <=< toExp $ (pe64 i `rem` pe64 segment_size) .==. 0 + pure [subExpRes flag] + offsets <- letExp "reg_offsets" <=< segMap (shapeDims (segmentsShape segments)) $ \is -> do + let flat_seg_i = + flattenIndex + (map pe64 (shapeDims (segmentsShape segments))) + (map pe64 is) + offset <- letSubExp "offset" <=< toExp $ flat_seg_i * pe64 segment_size + pure [subExpRes offset] + pure $ + IrregularRep + { irregularSegments = segments_arr, + irregularFlags = flags, + irregularOffsets = offsets, + irregularElems = elems + } + -- Get the irregular representation of a var. 
getIrregRep :: Segments -> DistEnv -> DistInputs -> VName -> Builder GPU IrregularRep getIrregRep segments env inps v = case lookup v inps of Just v_inp -> case v_inp of - DistInputFree arr t -> do - arr_t <- lookupType arr - segment_size <- - letSubExp "reg_seg_size" <=< toExp $ product $ map pe64 $ arrayDims t - segments_arr <- - letExp "reg_segments" . BasicOp $ - Replicate (segmentsShape segments) segment_size - num_elems <- - letSubExp "reg_num_elems" <=< toExp $ product $ map pe64 $ arrayDims arr_t - elems <- - letExp "reg_elems" . BasicOp $ - Reshape ReshapeArbitrary (Shape [num_elems]) arr - flags <- letExp "reg_flags" <=< segMap (Solo num_elems) $ \(Solo i) -> do - flag <- letSubExp "flag" <=< toExp $ (pe64 i `rem` pe64 segment_size) .==. 0 - pure [subExpRes flag] - offsets <- letExp "reg_offsets" <=< segMap (shapeDims (segmentsShape segments)) $ \is -> do - let flat_seg_i = - flattenIndex - (map pe64 (shapeDims (segmentsShape segments))) - (map pe64 is) - offset <- letSubExp "offset" <=< toExp $ flat_seg_i * pe64 segment_size - pure [subExpRes offset] - pure $ - IrregularRep - { irregularSegments = segments_arr, - irregularFlags = flags, - irregularOffsets = offsets, - irregularElems = elems - } + DistInputFree arr _ -> mkIrregFromReg segments arr DistInput rt _ -> case resVar rt env of Irregular r -> pure r - Regular _ -> error "getIrregRep: Regulat arrays not handled (yet)" - Nothing -> error $ "getIrregRep: variable '" ++ prettyString v ++ "' not found" + Regular arr -> mkIrregFromReg segments arr + Nothing -> do + v' <- + letExp (baseString v <> "_rep") . 
BasicOp $ + Replicate (segmentsShape segments) (Var v) + mkIrregFromReg segments v' transformDistBasicOp :: Segments -> @@ -326,6 +337,26 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = ~+~ sExt it (untyped (pe64 v')) ~*~ primExpFromSubExp (IntType it) s' pure $ insertIrregular ns flags offsets (distResTag res) elems' env + Copy v -> + case lookup v inps of + Just (DistInputFree v' _) -> do + v'' <- letExp (baseString v' <> "_copy") $ BasicOp $ Copy v' + pure $ insertRegulars [distResTag res] [v''] env + Just (DistInput rt _) -> + case resVar rt env of + Irregular r -> do + let name = baseString (irregularElems r) <> "_copy" + elems_copy <- letExp name $ BasicOp $ Copy $ irregularElems r + let rep = Irregular $ r {irregularElems = elems_copy} + pure $ insertRep (distResTag res) rep env + Regular v' -> do + v'' <- letExp (baseString v' <> "_copy") $ BasicOp $ Copy v' + pure $ insertRegulars [distResTag res] [v''] env + Nothing -> do + v' <- + letExp (baseString v <> "_copy_free") . 
BasicOp $ + Replicate (segmentsShape segments) (Var v) + pure $ insertRegulars [distResTag res] [v'] env Update _ as slice (Var v) | Just as_t <- distInputType <$> lookup as inps -> do ns <- letExp "slice_sizes" diff --git a/tests/flattening/update_dimfix.fut b/tests/flattening/update_dimfix.fut index 41e31099fe..6ddad56a10 100644 --- a/tests/flattening/update_dimfix.fut +++ b/tests/flattening/update_dimfix.fut @@ -1,16 +1,38 @@ -- Test with fixed dimension -- == --- input { [0,1,2,3,4] [5i64,6i64,7i64,8i64,9i64] } --- output { [[5i64,0i64,0i64,0i64,0i64] --- ,[1i64,6i64,1i64,1i64,1i64] --- ,[2i64,2i64,7i64,2i64,2i64] --- ,[3i64,3i64,3i64,8i64,3i64] --- ,[4i64,4i64,4i64,4i64,9i64]] } +-- input { [0,1,2,3,4] [0,1,2,3,4] [5,6,7,8,9] } +-- output { +-- [[[5, 1, 2, 3, 4], +-- [0, 6, 2, 3, 4], +-- [0, 1, 7, 3, 4], +-- [0, 1, 2, 8, 4], +-- [0, 1, 2, 3, 9]], +-- [[5, 1, 2, 3, 4], +-- [0, 6, 2, 3, 4], +-- [0, 1, 7, 3, 4], +-- [0, 1, 2, 8, 4], +-- [0, 1, 2, 3, 9]], +-- [[5, 1, 2, 3, 4], +-- [0, 6, 2, 3, 4], +-- [0, 1, 7, 3, 4], +-- [0, 1, 2, 8, 4], +-- [0, 1, 2, 3, 9]], +-- [[5, 1, 2, 3, 4], +-- [0, 6, 2, 3, 4], +-- [0, 1, 7, 3, 4], +-- [0, 1, 2, 8, 4], +-- [0, 1, 2, 3, 9]], +-- [[5, 1, 2, 3, 4], +-- [0, 6, 2, 3, 4], +-- [0, 1, 7, 3, 4], +-- [0, 1, 2, 8, 4], +-- [0, 1, 2, 3, 9]]] +-- } -let main is js = - [let n = 0 in map2(\i j -> (iota 5 with [i] = j)[n]) is js - ,let n = 1 in map2(\i j -> (iota 5 with [i] = j)[n]) is js - ,let n = 2 in map2(\i j -> (iota 5 with [i] = j)[n]) is js - ,let n = 3 in map2(\i j -> (iota 5 with [i] = j)[n]) is js - ,let n = 4 in map2(\i j -> (iota 5 with [i] = j)[n]) is js +let main (arr: []i32) (is: []i32) (js: []i32) = + [map2(\i j -> (copy arr with [i] = j)) is js + ,map2(\i j -> (copy arr with [i] = j)) is js + ,map2(\i j -> (copy arr with [i] = j)) is js + ,map2(\i j -> (copy arr with [i] = j)) is js + ,map2(\i j -> (copy arr with [i] = j)) is js ] From f9e5bc66a93c6e63ba8c63f885cd647cdf1b8ee2 Mon Sep 17 00:00:00 2001 From: Troels 
Henriksen Date: Wed, 15 Mar 2023 09:49:43 +0100 Subject: [PATCH 40/64] Incomplete work on nested maps. --- src/Futhark/Pass/Flatten.hs | 40 +++++++++++++++++--------- src/Futhark/Pass/Flatten/Distribute.hs | 2 +- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 5f921554c8..4a51eb38eb 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -32,7 +32,7 @@ import Futhark.Tools import Futhark.Transform.Rename import Futhark.Transform.Substitute import Futhark.Util.IntegralExp -import Prelude hiding (rem) +import Prelude hiding (div, rem) data FlattenEnv = FlattenEnv @@ -113,6 +113,9 @@ type Segments = NE.NonEmpty SubExp segmentsShape :: Segments -> Shape segmentsShape = Shape . toList +segmentsDims :: Segments -> [SubExp] +segmentsDims = shapeDims . segmentsShape + segMap :: Traversable f => f SubExp -> (f SubExp -> Builder GPU Result) -> Builder GPU (Exp GPU) segMap segments f = do gtids <- traverse (const $ newVName "gtid") segments @@ -194,13 +197,13 @@ distCerts inps aux env = Certs $ map f $ unCerts $ stmAuxCerts aux -- | Only sensible for variables of segment-invariant type. 
elemArr :: Segments -> DistEnv -> DistInputs -> SubExp -> Builder GPU VName -elemArr _ env inps (Var v) +elemArr segments env inps (Var v) | Just v_inp <- lookup v inps = - pure $ case v_inp of - DistInputFree ns _ -> ns + case v_inp of + DistInputFree vs _ -> irregularElems <$> mkIrregFromReg segments vs DistInput rt _ -> case resVar rt env of - Irregular r -> irregularElems r - Regular vs -> vs + Irregular r -> pure $ irregularElems r + Regular vs -> irregularElems <$> mkIrregFromReg segments vs elemArr segments _ _ se = do rep <- letExp "rep" $ BasicOp $ Replicate (segmentsShape segments) se dims <- arrayDims <$> lookupType rep @@ -401,11 +404,22 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = transformScalarStm segments env inps [res] $ Let (Pat [pe]) aux (BasicOp e) -repPerSegment :: SubExp -> VName -> [VName] -> Builder GPU [VName] -repPerSegment w segments_per_elem vs = - letTupExp "replicated" <=< segMap (Solo w) $ \(Solo i) -> do - segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] - subExpsRes <$> mapM (letSubExp "v" <=< flip eIndex [eSubExp segment]) vs +-- Replicates inner dimension. +repPerSegment :: Segments -> DistEnv -> DistInputs -> SubExp -> VName -> VName -> Builder GPU VName +repPerSegment segments env inps w segments_per_elem v + | Just v_inp <- lookup v inps = + case v_inp of + DistInputFree v' _ -> do + letExp (baseString v <> "_rep_free_reg_inp") <=< segMap (Solo w) $ \(Solo i) -> do + segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] + subExpsRes . pure <$> (letSubExp "v" =<< eIndex v' [eSubExp segment]) + DistInput rt _ -> case resVar rt env of + Irregular r -> + letExp (baseString v <> "_rep_free_irreg_inp") <=< segMap (Solo w) $ \(Solo i) -> do + segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] + subExpsRes . 
pure <$> (letSubExp "v" =<< eIndex v [eSubExp segment]) + | otherwise = + letExp (baseString v <> "_rep_free") $ BasicOp $ Replicate (Shape [w]) (Var v) transformMap :: Segments -> @@ -423,8 +437,7 @@ transformMap segments env inps pat w arrs map_lam = do new_segment <- arraySize 0 <$> lookupType ws_elems let free_in_map = namesToList $ freeIn map_lam replicated <- - repPerSegment new_segment ws_elems - =<< mapM (elemArr segments env inps . Var) free_in_map + mapM (repPerSegment segments env inps new_segment ws_elems) free_in_map free_ps <- mapM (newParam "free_p" . rowType <=< lookupType) replicated scope <- askScope @@ -435,6 +448,7 @@ transformMap segments env inps pat w arrs map_lam = do } distributed = distributeMap scope pat new_segment (arrs' <> replicated) map_lam' m = transformDistributed (NE.singleton new_segment) distributed + traceM $ unlines ["inner map", prettyString distributed] addStms =<< runReaderT (runBuilder_ m) scope pure (ws_flags, ws_offsets, ws) diff --git a/src/Futhark/Pass/Flatten/Distribute.hs b/src/Futhark/Pass/Flatten/Distribute.hs index 0398b559c1..034a5df08e 100644 --- a/src/Futhark/Pass/Flatten/Distribute.hs +++ b/src/Futhark/Pass/Flatten/Distribute.hs @@ -26,7 +26,7 @@ newtype ResTag = ResTag Int -- | Something that is mapped. data DistInput = -- | A value bound outside the original map nest. By necessity - -- regular. + -- regular. The type is the parameter type. DistInputFree VName Type | -- | A value constructed inside the original map nest. May be -- irregular. From 30583cd5b03a087fcba5d4b8623d878570ab1afa Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 15 Mar 2023 17:26:03 +0100 Subject: [PATCH 41/64] More work on nested maps. 
--- src/Futhark/Pass/Flatten.hs | 64 +++++++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 21 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 4a51eb38eb..52a087ff53 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -17,7 +17,7 @@ import Data.Bifunctor (bimap, first, second) import Data.Foldable import Data.List.NonEmpty qualified as NE import Data.Map qualified as M -import Data.Maybe (fromMaybe) +import Data.Maybe (fromMaybe, mapMaybe) import Data.Tuple.Solo import Debug.Trace import Futhark.IR.GPU @@ -404,22 +404,43 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = transformScalarStm segments env inps [res] $ Let (Pat [pe]) aux (BasicOp e) --- Replicates inner dimension. -repPerSegment :: Segments -> DistEnv -> DistInputs -> SubExp -> VName -> VName -> Builder GPU VName -repPerSegment segments env inps w segments_per_elem v - | Just v_inp <- lookup v inps = +-- Replicates inner dimension for inputs. +onMapFreeVar :: Segments -> DistEnv -> DistInputs -> SubExp -> VName -> VName -> Maybe (Builder GPU (VName, VName)) +onMapFreeVar segments env inps w segments_per_elem v = do + v_inp <- lookup v inps + pure . fmap (v,) $ case v_inp of + DistInputFree v' _ -> do + letExp (baseString v <> "_rep_free_reg_inp") <=< segMap (Solo w) $ \(Solo i) -> do + segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] + subExpsRes . pure <$> (letSubExp "v" =<< eIndex v' [eSubExp segment]) + DistInput rt _ -> case resVar rt env of + Irregular r -> + letExp (baseString v <> "_rep_free_irreg_inp") <=< segMap (Solo w) $ \(Solo i) -> do + segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] + subExpsRes . 
pure <$> (letSubExp "v" =<< eIndex v [eSubExp segment]) + +onMapInputArr :: + Segments -> + DistEnv -> + DistInputs -> + SubExp -> + VName -> + Builder GPU VName +onMapInputArr segments env inps w arr = + case lookup arr inps of + Just v_inp -> case v_inp of - DistInputFree v' _ -> do - letExp (baseString v <> "_rep_free_reg_inp") <=< segMap (Solo w) $ \(Solo i) -> do - segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] - subExpsRes . pure <$> (letSubExp "v" =<< eIndex v' [eSubExp segment]) - DistInput rt _ -> case resVar rt env of - Irregular r -> - letExp (baseString v <> "_rep_free_irreg_inp") <=< segMap (Solo w) $ \(Solo i) -> do - segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] - subExpsRes . pure <$> (letSubExp "v" =<< eIndex v [eSubExp segment]) - | otherwise = - letExp (baseString v <> "_rep_free") $ BasicOp $ Replicate (Shape [w]) (Var v) + DistInputFree vs _ -> + letExp (baseString vs <> "_flat") . BasicOp $ + Reshape ReshapeArbitrary (Shape [w]) vs + DistInput rt _ -> undefined + Nothing -> do + arr_row_t <- rowType <$> lookupType arr + arr_rep <- + letExp (baseString arr <> "_inp_rep") . BasicOp $ + Replicate (segmentsShape segments) (Var arr) + letExp (baseString arr <> "_inp_rep_flat") . BasicOp $ + Reshape ReshapeArbitrary (Shape [w] <> arrayShape arr_row_t) arr_rep transformMap :: Segments -> @@ -431,17 +452,18 @@ transformMap :: Lambda SOACS -> Builder GPU (VName, VName, VName) transformMap segments env inps pat w arrs map_lam = do - arrs' <- mapM (elemArr segments env inps . Var) arrs ws <- elemArr segments env inps w (ws_flags, ws_offsets, ws_elems) <- doRepIota ws new_segment <- arraySize 0 <$> lookupType ws_elems - let free_in_map = namesToList $ freeIn map_lam - replicated <- - mapM (repPerSegment segments env inps new_segment ws_elems) free_in_map + arrs' <- mapM (onMapInputArr segments env inps new_segment) arrs + (free_replicated, replicated) <- + fmap unzip . 
sequence $ + mapMaybe (onMapFreeVar segments env inps new_segment ws_elems) $ + namesToList (freeIn map_lam) free_ps <- mapM (newParam "free_p" . rowType <=< lookupType) replicated scope <- askScope - let substs = M.fromList $ zip free_in_map $ map paramName free_ps + let substs = M.fromList $ zip free_replicated $ map paramName free_ps map_lam' = (substituteNames substs map_lam) { lambdaParams = lambdaParams map_lam <> free_ps From fd9fc78b085b8d59f36bd544520b6edf6f2f6b8d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 20 Mar 2023 09:05:03 +0100 Subject: [PATCH 42/64] All these SegOps should be virtualised. --- src/Futhark/Pass/Flatten.hs | 2 +- src/Futhark/Pass/Flatten/Builtins.hs | 14 ++++---------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 52a087ff53..7ca1f105f7 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -125,7 +125,7 @@ segMap segments f = do ts <- mapM (subExpType . 
resSubExp) res pure (map mkResult res, ts) let kbody = KernelBody () stms res - pure $ Op $ SegOp $ SegMap (SegThread SegNoVirt Nothing) space ts kbody + pure $ Op $ SegOp $ SegMap (SegThread SegVirt Nothing) space ts kbody where mkResult (SubExpRes cs se) = Returns ResultMaySimplify cs se diff --git a/src/Futhark/Pass/Flatten/Builtins.hs b/src/Futhark/Pass/Flatten/Builtins.hs index 2c462008b7..c3a00c6549 100644 --- a/src/Futhark/Pass/Flatten/Builtins.hs +++ b/src/Futhark/Pass/Flatten/Builtins.hs @@ -15,19 +15,13 @@ where import Control.Monad.Reader import Control.Monad.State.Strict -import Data.Foldable -import Data.Map qualified as M import Data.Maybe (fromMaybe) import Data.Text qualified as T import Futhark.IR.GPU import Futhark.IR.SOACS import Futhark.MonadFreshNames import Futhark.Pass.ExtractKernels.BlockedKernel (mkSegSpace) -import Futhark.Pass.ExtractKernels.ToGPU - ( scopeForGPU, - soacsLambdaToGPU, - soacsStmToGPU, - ) +import Futhark.Pass.ExtractKernels.ToGPU (soacsLambdaToGPU) import Futhark.Tools builtinName :: T.Text -> Name @@ -50,7 +44,7 @@ genScanomap desc w lam nes m = do op = SegBinOp Commutative lam nes mempty letTupExp desc $ Op $ SegOp $ SegScan lvl space [op] res_t kbody where - lvl = SegThread SegNoVirt Nothing + lvl = SegThread SegVirt Nothing genScan :: String -> SubExp -> Lambda GPU -> [SubExp] -> [VName] -> Builder GPU [VName] genScan desc w lam nes arrs = @@ -119,7 +113,7 @@ genScatter dest n f = do v_t <- subExpType v pure (WriteReturns mempty (Shape [m]) dest [(Slice [DimFix (Var i)], v)], v_t) let kbody = KernelBody () stms [res] - pure $ Op $ SegOp $ SegMap (SegThread SegNoVirt Nothing) space [v_t] kbody + pure $ Op $ SegOp $ SegMap (SegThread SegVirt Nothing) space [v_t] kbody genTabulate :: SubExp -> (SubExp -> Builder GPU [SubExp]) -> Builder GPU (Exp GPU) genTabulate w m = do @@ -130,7 +124,7 @@ genTabulate w m = do ts <- mapM subExpType ses pure (map (Returns ResultMaySimplify mempty) ses, ts) let kbody = KernelBody () stms 
res - pure $ Op $ SegOp $ SegMap (SegThread SegNoVirt Nothing) space ts kbody + pure $ Op $ SegOp $ SegMap (SegThread SegVirt Nothing) space ts kbody genFlags :: SubExp -> VName -> Builder GPU VName genFlags m offsets = do From 4674db5cce1fd9ac278b1b4828a07b8b2b7cb9da Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 13 Apr 2023 21:37:48 +0200 Subject: [PATCH 43/64] More work, including nonfunctional test. --- src/Futhark/Pass/Flatten.hs | 86 ++++++++++++++++++-------- src/Futhark/Pass/Flatten/Distribute.hs | 4 +- tests/flattening/map-nested-deeper.fut | 7 +++ 3 files changed, 70 insertions(+), 27 deletions(-) create mode 100644 tests/flattening/map-nested-deeper.fut diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 7ca1f105f7..3b71cfc160 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -149,15 +149,14 @@ readInputs segments env is = mapM_ onInput onInput (v, DistInputFree arr _) = letBindNames [v] =<< eIndex arr (map eSubExp is) onInput (v, DistInput rt t) = - case M.lookup rt $ distResMap env of - Just (Regular arr) -> + case resVar rt env of + Regular arr -> letBindNames [v] =<< eIndex arr (map eSubExp is) - Just (Irregular (IrregularRep _ _ offsets elems)) -> do + Irregular (IrregularRep _ _ offsets elems) -> do offset <- letSubExp "offset" =<< eIndex offsets (map eSubExp is) num_elems <- letSubExp "num_elems" =<< toExp (product $ map pe64 $ arrayDims t) let slice = Slice [DimSlice offset num_elems (intConst Int64 1)] letBindNames [v] $ BasicOp $ Index elems slice - Nothing -> error $ "readInputs: " <> show rt transformScalarStms :: Segments -> @@ -405,42 +404,77 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = Let (Pat [pe]) aux (BasicOp e) -- Replicates inner dimension for inputs. 
-onMapFreeVar :: Segments -> DistEnv -> DistInputs -> SubExp -> VName -> VName -> Maybe (Builder GPU (VName, VName)) +onMapFreeVar :: + Segments -> + DistEnv -> + DistInputs -> + SubExp -> + VName -> + VName -> + Maybe (Builder GPU (VName, DistInput)) onMapFreeVar segments env inps w segments_per_elem v = do v_inp <- lookup v inps pure . fmap (v,) $ case v_inp of - DistInputFree v' _ -> do - letExp (baseString v <> "_rep_free_reg_inp") <=< segMap (Solo w) $ \(Solo i) -> do - segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] - subExpsRes . pure <$> (letSubExp "v" =<< eIndex v' [eSubExp segment]) - DistInput rt _ -> case resVar rt env of - Irregular r -> - letExp (baseString v <> "_rep_free_irreg_inp") <=< segMap (Solo w) $ \(Solo i) -> do + DistInputFree v' t -> do + fmap (`DistInputFree` t) . letExp (baseString v <> "_rep_free_free_inp") + <=< segMap (Solo w) + $ \(Solo i) -> do segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] - subExpsRes . pure <$> (letSubExp "v" =<< eIndex v [eSubExp segment]) + subExpsRes . pure <$> (letSubExp "v" =<< eIndex v' [eSubExp segment]) + DistInput rt t -> case resVar rt env of + Irregular r -> + -- FIXME: completely bogus; we need to invent a new IrregRep here. + fmap (`DistInputFree` t) . letExp (baseString v <> "_rep_free_irreg_inp") + <=< segMap (Solo w) + $ \(Solo i) -> do + segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] + readInputs segments env [segment] [(v, v_inp)] + subExpsRes . pure <$> (letSubExp "v" =<< eIndex v [eSubExp segment]) + Regular vs -> + fmap (`DistInputFree` t) . letExp (baseString v <> "_rep_free_reg_inp") + <=< segMap (Solo w) + $ \(Solo i) -> do + segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] + subExpsRes . 
pure <$> (letSubExp "v" =<< eIndex vs [eSubExp segment]) onMapInputArr :: Segments -> DistEnv -> DistInputs -> SubExp -> + Param Type -> VName -> - Builder GPU VName -onMapInputArr segments env inps w arr = + Builder GPU DistInput +onMapInputArr segments env inps w p arr = case lookup arr inps of Just v_inp -> case v_inp of - DistInputFree vs _ -> - letExp (baseString vs <> "_flat") . BasicOp $ - Reshape ReshapeArbitrary (Shape [w]) vs - DistInput rt _ -> undefined + DistInputFree vs t -> do + v <- + letExp (baseString vs <> "_flat") . BasicOp $ + Reshape ReshapeArbitrary (Shape [w]) vs + pure $ DistInputFree v t + DistInput rt t -> + case resVar rt env of + Irregular r -> do + elems_t <- lookupType $ irregularElems r + -- If parameter type of the map corresponds to the + -- element type of the value array, we can map it + -- directly. + if stripArray (shapeRank (segmentsShape segments)) elems_t == paramType p + then pure $ DistInputFree (irregularElems r) elems_t + else undefined + Regular vs -> + undefined Nothing -> do arr_row_t <- rowType <$> lookupType arr arr_rep <- letExp (baseString arr <> "_inp_rep") . BasicOp $ Replicate (segmentsShape segments) (Var arr) - letExp (baseString arr <> "_inp_rep_flat") . BasicOp $ - Reshape ReshapeArbitrary (Shape [w] <> arrayShape arr_row_t) arr_rep + v <- + letExp (baseString arr <> "_inp_rep_flat") . BasicOp $ + Reshape ReshapeArbitrary (Shape [w] <> arrayShape arr_row_t) arr_rep + pure $ DistInputFree v arr_row_t transformMap :: Segments -> @@ -455,13 +489,12 @@ transformMap segments env inps pat w arrs map_lam = do ws <- elemArr segments env inps w (ws_flags, ws_offsets, ws_elems) <- doRepIota ws new_segment <- arraySize 0 <$> lookupType ws_elems - arrs' <- mapM (onMapInputArr segments env inps new_segment) arrs + arrs' <- zipWithM (onMapInputArr segments env inps new_segment) (lambdaParams map_lam) arrs (free_replicated, replicated) <- fmap unzip . 
sequence $ mapMaybe (onMapFreeVar segments env inps new_segment ws_elems) $ namesToList (freeIn map_lam) - free_ps <- - mapM (newParam "free_p" . rowType <=< lookupType) replicated + free_ps <- mapM (newParam "free_p" . distInputType) replicated scope <- askScope let substs = M.fromList $ zip free_replicated $ map paramName free_ps map_lam' = @@ -531,7 +564,10 @@ transformDistributed segments (Distributed dstms resmap) = do transformStm :: Scope SOACS -> Stm SOACS -> PassM (Stms GPU) transformStm scope (Let pat _ (Op (Screma w arrs form))) | Just lam <- isMapSOAC form = do - let distributed = distributeMap scope pat w arrs lam + let arrs' = + zipWith DistInputFree arrs $ + map paramType (lambdaParams (scremaLambda form)) + distributed = distributeMap scope pat w arrs' lam m = transformDistributed (NE.singleton w) distributed traceM $ prettyString distributed runReaderT (runBuilder_ m) scope diff --git a/src/Futhark/Pass/Flatten/Distribute.hs b/src/Futhark/Pass/Flatten/Distribute.hs index 034a5df08e..e31e46cd66 100644 --- a/src/Futhark/Pass/Flatten/Distribute.hs +++ b/src/Futhark/Pass/Flatten/Distribute.hs @@ -143,7 +143,7 @@ patInput :: ResTag -> PatElem Type -> (VName, DistInput) patInput tag pe = (patElemName pe, DistInput tag $ patElemType pe) -distributeMap :: Scope rep -> Pat Type -> SubExp -> [VName] -> Lambda SOACS -> Distributed +distributeMap :: Scope rep -> Pat Type -> SubExp -> [DistInput] -> Lambda SOACS -> Distributed distributeMap outer_scope map_pat w arrs lam = let param_inputs = zipWith paramInput (lambdaParams lam) arrs @@ -155,7 +155,7 @@ distributeMap outer_scope map_pat w arrs lam = in Distributed stms $ resultMap avail_inputs stms map_pat $ bodyResult $ lambdaBody lam where bound_outside = namesFromList $ M.keys outer_scope - paramInput p arr = (paramName p, DistInputFree arr $ paramType p) + paramInput p arr = (paramName p, arr) distType t = uncurry (DistType w) $ splitIrregDims bound_outside t distributeStm (ResTag tag, avail_inputs) stm = 
let pat = stmPat stm diff --git a/tests/flattening/map-nested-deeper.fut b/tests/flattening/map-nested-deeper.fut new file mode 100644 index 0000000000..0731c53015 --- /dev/null +++ b/tests/flattening/map-nested-deeper.fut @@ -0,0 +1,7 @@ +-- == +-- input { [5i64,7i64] [5i64,7i64] } +-- output { [23i64, 28i64] } + +def main = map2 (\n x -> + let A = #[opaque] map (+2) (iota n) + in i64.sum (map (\i -> A[i%x]) (iota n))) From 5d6c27fcfcc6783bb9316b192fc67b0629787744 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 15 Apr 2023 11:37:26 +0200 Subject: [PATCH 44/64] Start of work on Replicate. --- src/Futhark/Pass/Flatten.hs | 47 +++++++++++++++++++++++++++++++++ tests/flattening/replicate0.fut | 6 +++++ 2 files changed, 53 insertions(+) create mode 100644 tests/flattening/replicate0.fut diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 3b71cfc160..6179d80335 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -259,6 +259,48 @@ getIrregRep segments env inps v = Replicate (segmentsShape segments) (Var v) mkIrregFromReg segments v' +-- Do 'map2 replicate ns A', where 'A' is an irregular array (and so +-- is the result, obviously). +replicateIrreg :: + Segments -> + DistEnv -> + VName -> + String -> + IrregularRep -> + Builder GPU IrregularRep +replicateIrreg segments env ns desc rep = do + -- This does not change the number of segments - it simply makes + -- each of them larger. 
+ num_segments <- arraySize 0 <$> lookupType ns + + (ns_flags, ns_offsets, ns_elems) <- doRepIota ns + + w <- arraySize 0 <$> lookupType ns_elems + + elems <- letExp (desc <> "_elems") <=< segMap (Solo w) $ \(Solo i) -> do + segment <- letSubExp "segment" =<< eIndex ns_elems [eSubExp i] + v <- letSubExp "v" =<< eIndex (irregularElems rep) [eSubExp segment] + pure $ subExpsRes [v] + + rep_segments <- letExp (desc <> "_segments") <=< segMap (Solo num_segments) $ + \(Solo i) -> do + segment_old <- + letSubExp "segment_old" =<< eIndex (irregularSegments rep) [eSubExp i] + n <- + letSubExp "n" =<< eIndex ns [eSubExp i] + segment_new <- + letSubExp "segment_new" . BasicOp $ + BinOp (Mul Int64 OverflowUndef) segment_old n + pure $ subExpsRes [segment_new] + + pure $ + IrregularRep + { irregularSegments = rep_segments, + irregularFlags = ns_flags, + irregularOffsets = ns_offsets, + irregularElems = elems + } + transformDistBasicOp :: Segments -> DistEnv -> @@ -339,6 +381,11 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = ~+~ sExt it (untyped (pe64 v')) ~*~ primExpFromSubExp (IntType it) s' pure $ insertIrregular ns flags offsets (distResTag res) elems' env + Replicate (Shape [n]) (Var v) -> do + ns <- elemArr segments env inps n + rep <- getIrregRep segments env inps v + rep' <- replicateIrreg segments env ns (baseString v) rep + pure $ insertRep (distResTag res) (Irregular rep') env Copy v -> case lookup v inps of Just (DistInputFree v' _) -> do diff --git a/tests/flattening/replicate0.fut b/tests/flattening/replicate0.fut new file mode 100644 index 0000000000..1e34ee3240 --- /dev/null +++ b/tests/flattening/replicate0.fut @@ -0,0 +1,6 @@ +-- == +-- input { [1i64,2i64] [0, 1] [4,5] } +-- output { [4,5] } + +def main = map3 (\n (i:i32) (x:i32) -> let A = opaque (replicate n x) + in #[unsafe] A[i]) From 29af9460822d304f50a4fb23a8693a9b676087b2 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 16 Apr 2023 22:30:19 +0200 Subject: [PATCH 45/64] Further 
work on replicate. --- src/Futhark/Pass/Flatten.hs | 84 ++++++++++++++++++---------- src/Futhark/Pass/Flatten/Builtins.hs | 2 +- tests/flattening/replicate1.fut | 7 +++ 3 files changed, 62 insertions(+), 31 deletions(-) create mode 100644 tests/flattening/replicate1.fut diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 6179d80335..f608cd5948 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -154,9 +154,18 @@ readInputs segments env is = mapM_ onInput letBindNames [v] =<< eIndex arr (map eSubExp is) Irregular (IrregularRep _ _ offsets elems) -> do offset <- letSubExp "offset" =<< eIndex offsets (map eSubExp is) - num_elems <- letSubExp "num_elems" =<< toExp (product $ map pe64 $ arrayDims t) - let slice = Slice [DimSlice offset num_elems (intConst Int64 1)] - letBindNames [v] $ BasicOp $ Index elems slice + case arrayDims t of + [num_elems] -> do + let slice = Slice [DimSlice offset num_elems (intConst Int64 1)] + letBindNames [v] $ BasicOp $ Index elems slice + _ -> do + num_elems <- + letSubExp "num_elems" =<< toExp (product $ map pe64 $ arrayDims t) + let slice = Slice [DimSlice offset num_elems (intConst Int64 1)] + v_flat <- + letExp (baseString v <> "_float") $ BasicOp $ Index elems slice + letBindNames [v] . BasicOp $ + Reshape ReshapeArbitrary (arrayShape t) v_flat transformScalarStms :: Segments -> @@ -269,35 +278,56 @@ replicateIrreg :: IrregularRep -> Builder GPU IrregularRep replicateIrreg segments env ns desc rep = do - -- This does not change the number of segments - it simply makes - -- each of them larger. + -- Replication does not change the number of segments - it simply + -- makes each of them larger. + num_segments <- arraySize 0 <$> lookupType ns - (ns_flags, ns_offsets, ns_elems) <- doRepIota ns + -- ns multipled with existing segment sizes. 
+ ns_full <- letExp (baseString ns <> "_full") <=< segMap (Solo num_segments) $ + \(Solo i) -> do + n <- + letSubExp "n" =<< eIndex ns [eSubExp i] + old_segment <- + letSubExp "old_segment" =<< eIndex (irregularSegments rep) [eSubExp i] + full_segment <- + letSubExp "new_segment" =<< toExp (pe64 n * pe64 old_segment) + pure $ subExpsRes [full_segment] + + (ns_full_flags, ns_full_offsets, ns_full_elems) <- doRepIota ns_full + (_, _, flat_to_segs) <- doSegIota ns_full - w <- arraySize 0 <$> lookupType ns_elems + w <- arraySize 0 <$> lookupType ns_full_elems elems <- letExp (desc <> "_elems") <=< segMap (Solo w) $ \(Solo i) -> do - segment <- letSubExp "segment" =<< eIndex ns_elems [eSubExp i] - v <- letSubExp "v" =<< eIndex (irregularElems rep) [eSubExp segment] + -- Which segment we are in. + segment_i <- + letSubExp "segment_i" =<< eIndex ns_full_elems [eSubExp i] + -- How much this segment is replicated. + n <- + letSubExp "n_to_rep" =<< eIndex ns [eSubExp segment_i] + -- Size of original segment. + old_segment <- + letSubExp "old_segment" =<< eIndex (irregularSegments rep) [eSubExp i] + -- Index of value inside *new* segment. + j_new <- + letSubExp "j_new" =<< eIndex flat_to_segs [eSubExp i] + -- Index of value inside *old* segment. + j_old <- + letSubExp "j_old" =<< toExp (pe64 j_new `rem` pe64 old_segment) + -- Offset of values in original segment. + offset <- + letSubExp "offset" =<< eIndex (irregularOffsets rep) [eSubExp segment_i] + v <- + letSubExp "v" + =<< eIndex (irregularElems rep) [toExp $ pe64 offset + pe64 j_old] pure $ subExpsRes [v] - rep_segments <- letExp (desc <> "_segments") <=< segMap (Solo num_segments) $ - \(Solo i) -> do - segment_old <- - letSubExp "segment_old" =<< eIndex (irregularSegments rep) [eSubExp i] - n <- - letSubExp "n" =<< eIndex ns [eSubExp i] - segment_new <- - letSubExp "segment_new" . 
BasicOp $ - BinOp (Mul Int64 OverflowUndef) segment_old n - pure $ subExpsRes [segment_new] - pure $ IrregularRep - { irregularSegments = rep_segments, - irregularFlags = ns_flags, - irregularOffsets = ns_offsets, + { irregularSegments = ns_full_elems, + irregularFlags = ns_full_flags, + irregularOffsets = ns_full_offsets, irregularElems = elems } @@ -470,13 +500,7 @@ onMapFreeVar segments env inps w segments_per_elem v = do subExpsRes . pure <$> (letSubExp "v" =<< eIndex v' [eSubExp segment]) DistInput rt t -> case resVar rt env of Irregular r -> - -- FIXME: completely bogus; we need to invent a new IrregRep here. - fmap (`DistInputFree` t) . letExp (baseString v <> "_rep_free_irreg_inp") - <=< segMap (Solo w) - $ \(Solo i) -> do - segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] - readInputs segments env [segment] [(v, v_inp)] - subExpsRes . pure <$> (letSubExp "v" =<< eIndex v [eSubExp segment]) + undefined Regular vs -> fmap (`DistInputFree` t) . letExp (baseString v <> "_rep_free_reg_inp") <=< segMap (Solo w) diff --git a/src/Futhark/Pass/Flatten/Builtins.hs b/src/Futhark/Pass/Flatten/Builtins.hs index c3a00c6549..1fd4ef92d3 100644 --- a/src/Futhark/Pass/Flatten/Builtins.hs +++ b/src/Futhark/Pass/Flatten/Builtins.hs @@ -273,7 +273,7 @@ prefixSumBuiltin = buildingBuiltin $ do funDefAttrs = mempty, funDefName = prefixSumName, funDefRetType = - [Array int64 (Shape [Free $ Var $ paramName np]) Nonunique], + [Array int64 (Shape [Free $ Var $ paramName np]) Unique], funDefParams = [np, nsp], funDefBody = body } diff --git a/tests/flattening/replicate1.fut b/tests/flattening/replicate1.fut new file mode 100644 index 0000000000..9d61bf17d3 --- /dev/null +++ b/tests/flattening/replicate1.fut @@ -0,0 +1,7 @@ +-- Now we are replicating a regular array. 
+-- == +-- input { [1i64,2i64] [0, 1] [[4,5],[5,6]] } +-- output { [[4,5],[5,6]] } + +def main = map3 (\n (i:i32) (x:[2]i32) -> let A = opaque (replicate n x) + in #[unsafe] A[i]) From ffda3c81ee31a9fd1573b72f959a36cc816609fd Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 18 Apr 2023 17:00:55 +0200 Subject: [PATCH 46/64] More work; not really working. --- src/Futhark/Pass/Flatten.hs | 135 ++++++++++++++++--------- src/Futhark/Pass/Flatten/Distribute.hs | 49 +++++++-- tests/flattening/map-nested-deeper.fut | 12 ++- 3 files changed, 136 insertions(+), 60 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index f608cd5948..9788b26733 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -303,12 +303,9 @@ replicateIrreg segments env ns desc rep = do -- Which segment we are in. segment_i <- letSubExp "segment_i" =<< eIndex ns_full_elems [eSubExp i] - -- How much this segment is replicated. - n <- - letSubExp "n_to_rep" =<< eIndex ns [eSubExp segment_i] -- Size of original segment. old_segment <- - letSubExp "old_segment" =<< eIndex (irregularSegments rep) [eSubExp i] + letSubExp "old_segment" =<< eIndex (irregularSegments rep) [eSubExp segment_i] -- Index of value inside *new* segment. 
j_new <- letSubExp "j_new" =<< eIndex flat_to_segs [eSubExp i] @@ -325,7 +322,7 @@ replicateIrreg segments env ns desc rep = do pure $ IrregularRep - { irregularSegments = ns_full_elems, + { irregularSegments = ns_full, irregularFlags = ns_full_flags, irregularOffsets = ns_full_offsets, irregularElems = elems @@ -416,6 +413,13 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = rep <- getIrregRep segments env inps v rep' <- replicateIrreg segments env ns (baseString v) rep pure $ insertRep (distResTag res) (Irregular rep') env + Replicate (Shape [n]) (Constant v) -> do + ns <- elemArr segments env inps n + (flags, offsets, elems) <- + certifying (distCerts inps aux env) $ doSegIota ns + w <- arraySize 0 <$> lookupType elems + elems' <- letExp "rep_const" $ BasicOp $ Replicate (Shape [w]) (Constant v) + pure $ insertIrregular ns flags offsets (distResTag res) elems' env Copy v -> case lookup v inps of Just (DistInputFree v' _) -> do @@ -485,28 +489,34 @@ onMapFreeVar :: Segments -> DistEnv -> DistInputs -> - SubExp -> VName -> + (VName, VName, VName) -> VName -> - Maybe (Builder GPU (VName, DistInput)) -onMapFreeVar segments env inps w segments_per_elem v = do + Maybe (Builder GPU (VName, MapArray IrregularRep)) +onMapFreeVar segments env inps ws (ws_flags, ws_offsets, ws_elems) v = do + let segments_per_elem = ws_elems v_inp <- lookup v inps - pure . fmap (v,) $ case v_inp of - DistInputFree v' t -> do - fmap (`DistInputFree` t) . letExp (baseString v <> "_rep_free_free_inp") - <=< segMap (Solo w) - $ \(Solo i) -> do - segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] - subExpsRes . pure <$> (letSubExp "v" =<< eIndex v' [eSubExp segment]) - DistInput rt t -> case resVar rt env of - Irregular r -> - undefined - Regular vs -> - fmap (`DistInputFree` t) . 
letExp (baseString v <> "_rep_free_reg_inp") - <=< segMap (Solo w) + pure $ do + ws_prod <- arraySize 0 <$> lookupType ws_elems + fmap (v,) $ case v_inp of + DistInputFree v' t -> do + fmap (`MapArray` t) . letExp (baseString v <> "_rep_free_free_inp") + <=< segMap (Solo ws_prod) $ \(Solo i) -> do segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] - subExpsRes . pure <$> (letSubExp "v" =<< eIndex vs [eSubExp segment]) + subExpsRes . pure <$> (letSubExp "v" =<< eIndex v' [eSubExp segment]) + DistInput rt t -> case resVar rt env of + Irregular r -> do + rep <- getIrregRep segments env inps v + -- FIXME, wrong: should be 'map (`replicate` rep) ws + rep' <- replicateIrreg segments env ws (baseString v) rep + pure $ MapOther rep' t + Regular vs -> + fmap (`MapArray` t) . letExp (baseString v <> "_rep_free_reg_inp") + <=< segMap (Solo ws_prod) + $ \(Solo i) -> do + segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] + subExpsRes . pure <$> (letSubExp "v" =<< eIndex vs [eSubExp segment]) onMapInputArr :: Segments -> @@ -515,7 +525,7 @@ onMapInputArr :: SubExp -> Param Type -> VName -> - Builder GPU DistInput + Builder GPU (MapArray t) onMapInputArr segments env inps w p arr = case lookup arr inps of Just v_inp -> @@ -524,7 +534,7 @@ onMapInputArr segments env inps w p arr = v <- letExp (baseString vs <> "_flat") . BasicOp $ Reshape ReshapeArbitrary (Shape [w]) vs - pure $ DistInputFree v t + pure $ MapArray v t DistInput rt t -> case resVar rt env of Irregular r -> do @@ -533,8 +543,9 @@ onMapInputArr segments env inps w p arr = -- element type of the value array, we can map it -- directly. if stripArray (shapeRank (segmentsShape segments)) elems_t == paramType p - then pure $ DistInputFree (irregularElems r) elems_t - else undefined + then pure $ MapArray (irregularElems r) elems_t + else -- Otherwise we need to perform surgery on the metadata. 
+ pure $ MapOther undefined elems_t Regular vs -> undefined Nothing -> do @@ -545,9 +556,14 @@ onMapInputArr segments env inps w p arr = v <- letExp (baseString arr <> "_inp_rep_flat") . BasicOp $ Reshape ReshapeArbitrary (Shape [w] <> arrayShape arr_row_t) arr_rep - pure $ DistInputFree v arr_row_t + pure $ MapArray v arr_row_t -transformMap :: +scopeOfDistInputs :: DistInputs -> Scope GPU +scopeOfDistInputs = scopeOfLParams . map f + where + f (v, inp) = Param mempty v (distInputType inp) + +transformInnerMap :: Segments -> DistEnv -> DistInputs -> @@ -556,25 +572,44 @@ transformMap :: [VName] -> Lambda SOACS -> Builder GPU (VName, VName, VName) -transformMap segments env inps pat w arrs map_lam = do +transformInnerMap segments env inps pat w arrs map_lam = do ws <- elemArr segments env inps w (ws_flags, ws_offsets, ws_elems) <- doRepIota ws new_segment <- arraySize 0 <$> lookupType ws_elems - arrs' <- zipWithM (onMapInputArr segments env inps new_segment) (lambdaParams map_lam) arrs + arrs' <- + zipWithM + (onMapInputArr segments env inps new_segment) + (lambdaParams map_lam) + arrs + let free = freeIn map_lam + free_sizes <- + localScope (scopeOfDistInputs inps) $ + foldMap freeIn <$> mapM lookupType (namesToList free) + let free_and_sizes = namesToList $ free <> free_sizes (free_replicated, replicated) <- fmap unzip . sequence $ - mapMaybe (onMapFreeVar segments env inps new_segment ws_elems) $ - namesToList (freeIn map_lam) - free_ps <- mapM (newParam "free_p" . distInputType) replicated + mapMaybe + (onMapFreeVar segments env inps ws (ws_flags, ws_offsets, ws_elems)) + free_and_sizes + free_ps <- + zipWithM + newParam + (map ((<> "_free") . 
baseString) free_and_sizes) + (map mapArrayRowType replicated) scope <- askScope let substs = M.fromList $ zip free_replicated $ map paramName free_ps map_lam' = - (substituteNames substs map_lam) - { lambdaParams = lambdaParams map_lam <> free_ps - } - distributed = distributeMap scope pat new_segment (arrs' <> replicated) map_lam' - m = transformDistributed (NE.singleton new_segment) distributed - traceM $ unlines ["inner map", prettyString distributed] + substituteNames + substs + ( map_lam + { lambdaParams = free_ps <> lambdaParams map_lam + } + ) + (distributed, arrmap) = + distributeMap scope pat new_segment (replicated <> arrs') map_lam' + m = + transformDistributed arrmap (NE.singleton new_segment) distributed + traceM $ unlines ["inner map distributed", prettyString distributed] addStms =<< runReaderT (runBuilder_ m) scope pure (ws_flags, ws_offsets, ws) @@ -595,7 +630,7 @@ transformDistStm segments env (DistStm inps res stm) = do map_pat <- fmap Pat $ forM (lambdaReturnType map_lam) $ \t -> PatElem <$> newVName "map" <*> pure (t `arrayOfRow` w) (ws_flags, ws_offsets, ws) <- - transformMap segments env inps map_pat w arrs map_lam + transformInnerMap segments env inps map_pat w arrs map_lam let (redout_names, mapout_names) = splitAt (redResults reds) (patNames map_pat) elems' <- @@ -606,7 +641,7 @@ transformDistStm segments env (DistStm inps res stm) = do insertRegulars red_tags elems' $ insertIrregulars ws ws_flags ws_offsets (zip map_tags mapout_names) env | Just map_lam <- isMapSOAC form -> do - (ws_flags, ws_offsets, ws) <- transformMap segments env inps pat w arrs map_lam + (ws_flags, ws_offsets, ws) <- transformInnerMap segments env inps pat w arrs map_lam pure $ insertIrregulars ws ws_flags ws_offsets (zip (map distResTag res) $ patNames pat) env _ -> error $ "Unhandled Stm:\n" ++ prettyString stm @@ -618,9 +653,13 @@ distResCerts env = Certs . 
map f Regular v -> v Irregular {} -> error "resCerts: irregular" -transformDistributed :: Segments -> Distributed -> Builder GPU () -transformDistributed segments (Distributed dstms resmap) = do - env <- foldM (transformDistStm segments) mempty dstms +transformDistributed :: + M.Map ResTag IrregularRep -> + Segments -> + Distributed -> + Builder GPU () +transformDistributed irregs segments (Distributed dstms resmap) = do + env <- foldM (transformDistStm segments) env_initial dstms forM_ (M.toList resmap) $ \(rt, (cs_inps, v, v_t)) -> certifying (distResCerts env cs_inps) $ case resVar rt env of @@ -631,15 +670,17 @@ transformDistributed segments (Distributed dstms resmap) = do let shape = segmentsShape segments <> arrayShape v_t letBindNames [v] $ BasicOp (Reshape ReshapeArbitrary shape (irregularElems irreg)) + where + env_initial = DistEnv {distResMap = M.map Irregular irregs} transformStm :: Scope SOACS -> Stm SOACS -> PassM (Stms GPU) transformStm scope (Let pat _ (Op (Screma w arrs form))) | Just lam <- isMapSOAC form = do let arrs' = - zipWith DistInputFree arrs $ + zipWith MapArray arrs $ map paramType (lambdaParams (scremaLambda form)) - distributed = distributeMap scope pat w arrs' lam - m = transformDistributed (NE.singleton w) distributed + (distributed, _) = distributeMap scope pat w arrs' lam + m = transformDistributed mempty (NE.singleton w) distributed traceM $ prettyString distributed runReaderT (runBuilder_ m) scope transformStm _ stm = pure $ oneStm $ soacsStmToGPU stm diff --git a/src/Futhark/Pass/Flatten/Distribute.hs b/src/Futhark/Pass/Flatten/Distribute.hs index e31e46cd66..bbf7e2d6be 100644 --- a/src/Futhark/Pass/Flatten/Distribute.hs +++ b/src/Futhark/Pass/Flatten/Distribute.hs @@ -1,5 +1,7 @@ module Futhark.Pass.Flatten.Distribute ( distributeMap, + MapArray (..), + mapArrayRowType, ResMap, Distributed (..), DistStm (..), @@ -128,10 +130,10 @@ resultMap avail_inputs stms pat res = mconcat $ map f stms splitIrregDims :: Names -> Type -> 
(Rank, Type) splitIrregDims bound_outside (Array pt shape u) = - let (irreg, reg) = second reverse $ span regDim $ reverse $ shapeDims shape + let (reg, irreg) = second reverse $ span regDim $ reverse $ shapeDims shape in (Rank $ length irreg, Array pt (Shape reg) u) where - regDim (Var v) = v `notNameIn` bound_outside + regDim (Var v) = v `nameIn` bound_outside regDim Constant {} = True splitIrregDims _ t = (mempty, t) @@ -143,19 +145,50 @@ patInput :: ResTag -> PatElem Type -> (VName, DistInput) patInput tag pe = (patElemName pe, DistInput tag $ patElemType pe) -distributeMap :: Scope rep -> Pat Type -> SubExp -> [DistInput] -> Lambda SOACS -> Distributed +-- | The input we are mapping over in 'distributeMap'. +data MapArray t + = -- | A straightforward array passed in to a + -- top-level map. + MapArray VName Type + | -- | Something more exotic - distribution will assign it a + -- 'ResTag', but not do anything else. This is used to + -- distributed nested maps whose inputs are produced in the outer + -- nests. 
+ MapOther t Type + +mapArrayRowType :: MapArray t -> Type +mapArrayRowType (MapArray _ t) = t +mapArrayRowType (MapOther _ t) = t + +distributeMap :: + Scope rep -> + Pat Type -> + SubExp -> + [MapArray t] -> + Lambda SOACS -> + (Distributed, M.Map ResTag t) distributeMap outer_scope map_pat w arrs lam = - let param_inputs = - zipWith paramInput (lambdaParams lam) arrs + let ((tag, arrmap), param_inputs) = + L.mapAccumL paramInput (ResTag 0, mempty) $ + zip (lambdaParams lam) arrs ((_, avail_inputs), stms) = - L.mapAccumL distributeStm (ResTag 0, param_inputs) $ + L.mapAccumL distributeStm (tag, param_inputs) $ stmsToList $ bodyStms $ lambdaBody lam - in Distributed stms $ resultMap avail_inputs stms map_pat $ bodyResult $ lambdaBody lam + in ( Distributed stms $ resultMap avail_inputs stms map_pat $ bodyResult $ lambdaBody lam, + arrmap + ) where bound_outside = namesFromList $ M.keys outer_scope - paramInput p arr = (paramName p, arr) + paramInput (ResTag i, m) (p, MapArray arr t) = + ( (ResTag i, m), + (paramName p, DistInputFree arr $ paramType p) + ) + paramInput (ResTag i, m) (p, MapOther x t) = + ( (ResTag (i + 1), M.insert (ResTag i) x m), + (paramName p, DistInput (ResTag i) $ paramType p) + ) distType t = uncurry (DistType w) $ splitIrregDims bound_outside t distributeStm (ResTag tag, avail_inputs) stm = let pat = stmPat stm diff --git a/tests/flattening/map-nested-deeper.fut b/tests/flattening/map-nested-deeper.fut index 0731c53015..f072791603 100644 --- a/tests/flattening/map-nested-deeper.fut +++ b/tests/flattening/map-nested-deeper.fut @@ -1,7 +1,9 @@ -- == --- input { [5i64,7i64] [5i64,7i64] } --- output { [23i64, 28i64] } +-- input { [5i64,7i64] [[5],[7]] } +-- output { [7,9] } -def main = map2 (\n x -> - let A = #[opaque] map (+2) (iota n) - in i64.sum (map (\i -> A[i%x]) (iota n))) +def main = map2 (\n xs -> + #[unsafe] + let A = #[opaque] replicate n xs + let B = #[opaque] map (\x -> (opaque x)[0]+2) A + in B[0]) From 
024e6c75c8e25f4c8dc3d14087d7c1eb16938852 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 21 May 2023 17:47:01 +0200 Subject: [PATCH 47/64] Handle free irregular arrays in nested map. --- src/Futhark/Pass/Flatten.hs | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 9788b26733..77e86968e0 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -506,10 +506,19 @@ onMapFreeVar segments env inps ws (ws_flags, ws_offsets, ws_elems) v = do segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] subExpsRes . pure <$> (letSubExp "v" =<< eIndex v' [eSubExp segment]) DistInput rt t -> case resVar rt env of - Irregular r -> do - rep <- getIrregRep segments env inps v - -- FIXME, wrong: should be 'map (`replicate` rep) ws - rep' <- replicateIrreg segments env ws (baseString v) rep + Irregular rep -> do + offsets <- letExp (baseString v <> "_rep_free_irreg_offsets") + <=< segMap (Solo ws_prod) + $ \(Solo i) -> do + segment <- letSubExp "segment" =<< eIndex ws_elems [eSubExp i] + subExpsRes . pure <$> (letSubExp "v" =<< eIndex (irregularOffsets rep) [eSubExp segment]) + let rep' = + IrregularRep + { irregularSegments = ws, + irregularFlags = irregularFlags rep, + irregularOffsets = offsets, + irregularElems = irregularElems rep + } pure $ MapOther rep' t Regular vs -> fmap (`MapArray` t) . letExp (baseString v <> "_rep_free_reg_inp") From 29dacc57a9271fdb87d2d455b4f7a03d88eb0552 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 22 May 2023 09:23:02 +0200 Subject: [PATCH 48/64] Another test. 
--- tests/flattening/map-nested-free2d.fut | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 tests/flattening/map-nested-free2d.fut diff --git a/tests/flattening/map-nested-free2d.fut b/tests/flattening/map-nested-free2d.fut new file mode 100644 index 0000000000..57af621dd1 --- /dev/null +++ b/tests/flattening/map-nested-free2d.fut @@ -0,0 +1,9 @@ +-- == +-- input { [5i64,7i64] [5i64,7i64] [3i64,2i64] } +-- output { [3i64, 2i64] } + +def main = map3 (\n m x -> + #[unsafe] + let A = #[opaque] replicate n (replicate m x) + let B = #[opaque] map (\i -> A[i%x,i%x]) (iota n) + in B[0]) From 2b5809bfe6aae33c19cd44ce300f731f1e41abbb Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 22 May 2023 09:49:32 +0200 Subject: [PATCH 49/64] Handle free irregular arrays in nested maps. --- src/Futhark/Pass/Flatten.hs | 44 +++++++++++++++++++------- tests/flattening/map-nested-deeper.fut | 2 +- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 77e86968e0..548d69e31a 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -531,30 +531,50 @@ onMapInputArr :: Segments -> DistEnv -> DistInputs -> - SubExp -> + VName -> Param Type -> VName -> - Builder GPU (MapArray t) -onMapInputArr segments env inps w p arr = + Builder GPU (MapArray IrregularRep) +onMapInputArr segments env inps ii2 p arr = do + ws_prod <- arraySize 0 <$> lookupType ii2 case lookup arr inps of Just v_inp -> case v_inp of DistInputFree vs t -> do v <- letExp (baseString vs <> "_flat") . 
BasicOp $ - Reshape ReshapeArbitrary (Shape [w]) vs + Reshape ReshapeArbitrary (Shape [ws_prod]) vs pure $ MapArray v t - DistInput rt t -> + DistInput rt _ -> case resVar rt env of - Irregular r -> do - elems_t <- lookupType $ irregularElems r + Irregular rep -> do + elems_t <- lookupType $ irregularElems rep -- If parameter type of the map corresponds to the -- element type of the value array, we can map it -- directly. if stripArray (shapeRank (segmentsShape segments)) elems_t == paramType p - then pure $ MapArray (irregularElems r) elems_t - else -- Otherwise we need to perform surgery on the metadata. - pure $ MapOther undefined elems_t + then pure $ MapArray (irregularElems rep) elems_t + else do + -- Otherwise we need to perform surgery on the metadata. + ~[p_segments, p_offsets] <- letTupExp + (baseString (paramName p) <> "_rep_inp_irreg") + <=< segMap (Solo ws_prod) + $ \(Solo i) -> do + segment_i <- + letSubExp "segment" =<< eIndex ii2 [eSubExp i] + segment <- + letSubExp "v" =<< eIndex (irregularSegments rep) [eSubExp segment_i] + offset <- + letSubExp "v" =<< eIndex (irregularOffsets rep) [eSubExp segment_i] + pure $ subExpsRes [segment, offset] + let rep' = + IrregularRep + { irregularElems = irregularElems rep, + irregularFlags = irregularFlags rep, + irregularSegments = p_segments, + irregularOffsets = p_offsets + } + pure $ MapOther rep' elems_t Regular vs -> undefined Nothing -> do @@ -564,7 +584,7 @@ onMapInputArr segments env inps w p arr = Replicate (segmentsShape segments) (Var arr) v <- letExp (baseString arr <> "_inp_rep_flat") . 
BasicOp $ - Reshape ReshapeArbitrary (Shape [w] <> arrayShape arr_row_t) arr_rep + Reshape ReshapeArbitrary (Shape [ws_prod] <> arrayShape arr_row_t) arr_rep pure $ MapArray v arr_row_t scopeOfDistInputs :: DistInputs -> Scope GPU @@ -587,7 +607,7 @@ transformInnerMap segments env inps pat w arrs map_lam = do new_segment <- arraySize 0 <$> lookupType ws_elems arrs' <- zipWithM - (onMapInputArr segments env inps new_segment) + (onMapInputArr segments env inps ws_elems) (lambdaParams map_lam) arrs let free = freeIn map_lam diff --git a/tests/flattening/map-nested-deeper.fut b/tests/flattening/map-nested-deeper.fut index f072791603..f941f80e24 100644 --- a/tests/flattening/map-nested-deeper.fut +++ b/tests/flattening/map-nested-deeper.fut @@ -5,5 +5,5 @@ def main = map2 (\n xs -> #[unsafe] let A = #[opaque] replicate n xs - let B = #[opaque] map (\x -> (opaque x)[0]+2) A + let B = #[opaque] map (\x -> (opaque x)[0]+2i32) A in B[0]) From 9c8351b841e3d4f7043424351787c7f8e70db34b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 30 May 2023 19:31:40 +0200 Subject: [PATCH 50/64] Handle distribution of free and identity results. 
--- src/Futhark/Pass/Flatten.hs | 11 +++++- src/Futhark/Pass/Flatten/Distribute.hs | 55 ++++++++++++++++++++------ 2 files changed, 54 insertions(+), 12 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 548d69e31a..f831a2cac5 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -687,7 +687,8 @@ transformDistributed :: Segments -> Distributed -> Builder GPU () -transformDistributed irregs segments (Distributed dstms resmap) = do +transformDistributed irregs segments dist = do + let Distributed dstms (DistResults resmap reps) = dist env <- foldM (transformDistStm segments) env_initial dstms forM_ (M.toList resmap) $ \(rt, (cs_inps, v, v_t)) -> certifying (distResCerts env cs_inps) $ @@ -699,6 +700,14 @@ transformDistributed irregs segments (Distributed dstms resmap) = do let shape = segmentsShape segments <> arrayShape v_t letBindNames [v] $ BasicOp (Reshape ReshapeArbitrary shape (irregularElems irreg)) + forM_ reps $ \(v, r) -> + case r of + Left se -> + letBindNames [v] $ BasicOp $ Replicate (segmentsShape segments) se + Right (DistInputFree arr _) -> + letBindNames [v] $ BasicOp $ SubExp $ Var arr + Right DistInput {} -> + error "replication of irregular identity result" where env_initial = DistEnv {distResMap = M.map Irregular irregs} diff --git a/src/Futhark/Pass/Flatten/Distribute.hs b/src/Futhark/Pass/Flatten/Distribute.hs index bbf7e2d6be..5f734c5399 100644 --- a/src/Futhark/Pass/Flatten/Distribute.hs +++ b/src/Futhark/Pass/Flatten/Distribute.hs @@ -2,6 +2,8 @@ module Futhark.Pass.Flatten.Distribute ( distributeMap, MapArray (..), mapArrayRowType, + DistResults (..), + DistRep, ResMap, Distributed (..), DistStm (..), @@ -14,10 +16,10 @@ module Futhark.Pass.Flatten.Distribute ) where -import Data.Bifunctor (second) +import Data.Bifunctor import Data.List qualified as L import Data.Map qualified as M -import Data.Maybe (fromMaybe, mapMaybe) +import Data.Maybe import Futhark.IR.SOACS import 
Futhark.Util (nubOrd) import Futhark.Util.Pretty @@ -72,7 +74,15 @@ data DistStm = DistStm -- Third is the element type (i.e. excluding segments). type ResMap = M.Map ResTag ([DistInput], VName, Type) -data Distributed = Distributed [DistStm] ResMap +-- | The results of a map-distribution that were free or identity +-- mapped in the original map function. These correspond to plain +-- replicated arrays. +type DistRep = (VName, Either SubExp DistInput) + +data DistResults = DistResults ResMap [DistRep] + deriving (Eq, Ord, Show) + +data Distributed = Distributed [DistStm] DistResults deriving (Eq, Ord, Show) instance Pretty ResTag where @@ -108,12 +118,16 @@ instance Pretty DistStm where <+> pretty inp instance Pretty Distributed where - pretty (Distributed stms res) = + pretty (Distributed stms (DistResults resmap reps)) = stms' res' where - res' = stack $ map onRes $ M.toList res + res' = stack $ map onRes (M.toList resmap) <> map onRep reps stms' = stack $ map pretty stms onRes (rt, v) = "let" <+> pretty v <+> "=" <+> pretty rt + onRep (v, Left se) = + "let" <+> pretty v <+> "=" <+> "rep" <> parens (pretty se) + onRep (v, Right tag) = + "let" <+> pretty v <+> "=" <+> "rep" <> parens (pretty tag) resultMap :: [(VName, DistInput)] -> [DistStm] -> Pat Type -> Result -> ResMap resultMap avail_inputs stms pat res = mconcat $ map f stms @@ -160,6 +174,22 @@ mapArrayRowType :: MapArray t -> Type mapArrayRowType (MapArray _ t) = t mapArrayRowType (MapOther _ t) = t +-- This is used to handle those results that are constants or lambda +-- parameters. 
+findReps :: [(VName, DistInput)] -> Pat Type -> Lambda SOACS -> [DistRep] +findReps avail_inputs map_pat lam = + mapMaybe f $ zip (patElems map_pat) (bodyResult (lambdaBody lam)) + where + f (pe, SubExpRes _ (Var v)) = + case lookup v avail_inputs of + Nothing -> Just (patElemName pe, Left $ Var v) + Just inp + | v `elem` map paramName (lambdaParams lam) -> + Just (patElemName pe, Right inp) + | otherwise -> Nothing + f (pe, SubExpRes _ (Constant v)) = do + Just (patElemName pe, Left $ Constant v) + distributeMap :: Scope rep -> Pat Type -> @@ -173,23 +203,26 @@ distributeMap outer_scope map_pat w arrs lam = zip (lambdaParams lam) arrs ((_, avail_inputs), stms) = L.mapAccumL distributeStm (tag, param_inputs) $ - stmsToList $ - bodyStms $ - lambdaBody lam - in ( Distributed stms $ resultMap avail_inputs stms map_pat $ bodyResult $ lambdaBody lam, + stmsToList (bodyStms (lambdaBody lam)) + resmap = + resultMap avail_inputs stms map_pat $ + bodyResult (lambdaBody lam) + reps = findReps avail_inputs map_pat lam + in ( Distributed stms $ DistResults resmap reps, arrmap ) where bound_outside = namesFromList $ M.keys outer_scope - paramInput (ResTag i, m) (p, MapArray arr t) = + paramInput (ResTag i, m) (p, MapArray arr _) = ( (ResTag i, m), (paramName p, DistInputFree arr $ paramType p) ) - paramInput (ResTag i, m) (p, MapOther x t) = + paramInput (ResTag i, m) (p, MapOther x _) = ( (ResTag (i + 1), M.insert (ResTag i) x m), (paramName p, DistInput (ResTag i) $ paramType p) ) distType t = uncurry (DistType w) $ splitIrregDims bound_outside t + distributeStm (ResTag tag, avail_inputs) stm = let pat = stmPat stm new_tags = map ResTag $ take (patSize pat) [tag ..] 
From d112f205e7af39780e969825d8fa6360f128fda7 Mon Sep 17 00:00:00 2001 From: cornelius-sevald <54140942+cornelius-sevald@users.noreply.github.com> Date: Thu, 15 Jun 2023 17:49:08 +0200 Subject: [PATCH 51/64] Start function flattening (#1870) * Start function flattening * `cmp-bench-json.py` rewritten in Haskell (Issue #748) (#1860) * Note in CHANGELOG. * Use new tool. * Remove cmp-bench-json.py. * Fix #1863. (#1864) * This is 0.23.1. * Onwards! * Fix typo. * Remove copyCopyToCopy rule. (#1866) This is a very old (5+ years) rule that is much too naive in its handling of memory. We have better optimisations now, that aren't buggy. * Remove SrcLoc from ImportName. Syntactic information does not belong in semantic objects. * Use ImportName consistently. (#1869) Previously some parts of the compiler would use FilePaths directly, and it is ambiguous whether those refer to canonical import names. Now it should be clearer. * futhark-benchmarks: bump * Workaround for tiny /tmp on these servers. * futhark-benchmarks: bump * futhark-benchmarks: bump * futhark-benchmarks: bump * Workaround for temporary ghcup breakage. * Switch to GHC 9.4 in Cabal CI. (#1871) If this does not fix Windows, then I will remove it (again). * Plain values should never be Unique. * No need for this. * Also no setUniqueness here. * futhark-benchmarks: bump * Fix #1874. * Avoid spurious space. * Make consumption an effect on functions, rather than types. (#1873) This is a breaking change, because until now we allowed functions like def f (a: *[]i32, b: []i32) = ... where we could then pass in a tuple where in an application `f (x,y)` the value `x` would be consumed, but not `y`. However, this became increasingly difficult to support as the language grew (and frankly, it was always buggy). With this commit, the syntax above is still permitted, but it is interpreted as def f ((a,b): *([]i32, []i32)) = ... i.e. the single tuple argument is consumed *as a whole*. 
Long term we can also consider amending the syntax or warning about cases where it is misleading, but that is less urgent. I've wanted to make this simplification for a long time, but I always hit various snags. Today I managed to make it work, and the next step will be cleaning up the notion of "uniqueness" in return types as well (it should be the more general notion of "aliases"). * Forgot a test for #1874. * Avoid warnings about "potentially uninitialized" variables. C compilers are (understandably) not smart enough to see that these are never actually used uninitialised. * Make source language Apply AST node multi-argument. (#1875) This is a deviation from the concrete syntax, but humans tend to think of function calls having multiple arguments. Also, the AST had to keep a lot of useless metadata around to express the results of the intermediate applications. And again, it is related to making #1872 more feasible. * Better constant folding for CmpOp PrimExps. This mostly has the effect of making generated code a little neater. * futhark-benchmarks: bump * Add some comments. * More explicit. * Fix #1878. * Forbid access to interpreter. * Ensure no apply-of-apply. The symptom of this being wrong is that defunctionalisation would create duplicate functions. No more! * Handle array results. * Flattening of Copy. * Use Hendrix for CI. (#1862) * First experiment at using Hendrix for CI. * Maybe like this. * Import everything locally. * Try this. * More systems. * Also OpenCL. * Also depend on these. * More readable when split. * Import new CI actions. * Testing with slurm. * Forgot to specify hendrix and the partition flag might also be needed. * The wrong composite actions was included * Trying cuda and opencl on hendrix * Trying to use the composite test action for benchmarks. * Wrong amount of indentation * Forgot to add a |. * Some small changes that will most likely not change things. * trying to use sbatch * switching to titanrtx and used the p flag wrong. 
* Trailing whitespace purge. * Skip these on TITAN X. * Any GPU will work for these. * Trying to run benchmarks without slurmbench.py * Syntax errors * Accidentally used old keyword test. * found another syntax error i think * I think the equality sign broke it * maybe this will work * Used gres wrongly. * Do not use old futhark-benchmarks. * Trying to use srun and cleaned up composite actions. * Add some comments. * More explicit. * Fix #1878. * Forbid access to interpreter. * Ensure no apply-of-apply. The symptom of this being wrong is that defunctionalisation would create duplicate functions. No more! * Revert "Trying to use srun and cleaned up composite actions." This reverts commit 6c4111f5e497d286c87aa84b2aa283bcb24d4bff. * using srun and fixing commit history hopefully? * Adding an 8 hour time limit. * Missing -. * Newer version of futhark-benchmarks * Trying to use `${{ always() }}`. * Revert "Newer version of futhark-benchmarks" because of `${{ always() }}` This reverts commit 965e788207c0d4f3d177d7c7d7e170a38919cc5d. * Hopefully this is the correct version of the futhark-benchmarks * Remove always() --------- Co-authored-by: due * Do not use hendrix except where needed. * Cleanup whitespace. * Matplotlib is handy. * Add job names. * Avoid unnecessary deallocation. * These seem broken. * Style fixes. * Bump GHC. * Not needed anymore. * Seems to fix the nontermination. * Support rev AD of scanomaps and scatters with non-identity lambdas. (#1880) * Fix #1883. * Loop over all dimensions here. * Precompute more chunk counts. This is mostly to track the change in the parallelisation of Replicate in the preceding commit. * Allow arbitrary expressions in size expressions. We still only permit elaboration of expressions that correspond to variables or integer constants. This is a step on the path to realising #1659. * Always forget about the unit tests. * Avoid extra braces when printing. * Oops; fix copy/paste error. * These brackets are necessary. * Fix typo. 
* A few other wording fixes. * A few more text improvements. * Fix error in manifest schema discovered by @Erk-. * Newer action. * Fix invalid link Thanks to @lkuty for noticing. * Use explicit entry. * Fix #1885. * Better style. * Plotting tool. (#1877) Closes #1861. * Make executable. * Remove trailing whitespace. * Final status message. * Use GitHub machines for Python tests. * Generate tuning param definitions in GenericC. (#1890) This is a step towards #1884. Now that GenericC is responsible for all the work (and has all the information), it can generate new API functions. * Record which tuning params are relevant to which entry points. (#1891) This involves extending the manifest and server protocol, and modifying 'futhark autotune' to use this new information. The main advantage (apart from general cleanup) is that we can now tune threshold parameters used in non-inlined functions. * This is 0.24.1. * Onwards! * Fix #1895. * Do not use interpreter. * Incomplete work on nested maps. * More work on nested maps. * Fix #1896. * This goes in tests. * Use Hendrix for A100 jobs. (#1898) * Fail early. * All these SegOps should be virtualised. * Start function flattening * Incomplete work on function lifting * Very rudimentary lifted function results Currently only handles lifting of functions whose return types are scalar typed variables i.e. no constants or arrays. * Work on lifted function results * Further work on lifted function results * Change way return types are lifted * Correctly return constants from lifted functions * Existential size return for lifted functions Merge building of body statements and results for lifted functions. Will probably need to filter out existential size quantifiers before lifting results. * Filter existential sizes from lifted functions Remove existential quantifiers from the return type and result of a function before lifting as I believe their lifted versions aren't needed. 
* Revert "Filter existential sizes from lifted functions" This reverts commit d04ecc5b3aa88923828bf3ecb8209bb121c8dbbe. It might be useful later but for now it complicates things. * Application of lifted functions * Do not lift entry points. * Work in progress match-expression flattening * Fix bug in lifting function parameters Lifting irregular parameters was (wrongly) in the order `[offsets, flags, segments, elements]`. When calling, the arguments were (rightly) given in the order `[segments, flags, offsets, elements]`. * Fix bug in lifting of if-then-else Wrote too many elements in the final scatters. * Make lifted if-then-else a little nicer * Handle irregular inputs to if-expressions * Handle irregular results of if-expressions * Handle general irregular match-expressions * Irregular match-expr: handle empty arrays * Better error messages * Handle free variables in `liftArg` `inputReps` now also gives type information, which is used by `liftArg` to determine if free variables are regular or irregular. * Flatten builtins scans over multi-dim arrays Let scan functions (genScanomap, genScan, genExScan, ...) in the flatten builtins module operate on multi-dimensional arrays. Of note is that `exScanAndSum`, when given a single-dimensional array, will return the # of segments and sum of segment sizes as scalar values and when given a multi-dimensional array will return them as arrays. Also move `segMap` from Flatten.hs to Flatten.Builtins.hs * Make sure flag and elems array have same size When passing flag and elems array to a function, or returning them from a function, resize them to please the type checker. * Replicate free vars in result of lifted functions * Handle free variables in match-expressions Move the common "if a subexp is a constant or free variable, replicate it, and otherwise do a lookup in dist inputs and dist env" code to a function `liftSubExp`. This is used in `liftArg`, `liftResult` and lifting match-expressions. 
* Add tests for lifting functions * Add tests for flattening match-expressions --------- Co-authored-by: Troels Henriksen --- src/Futhark/Pass/Flatten.hs | 434 +++++++++++++++++- src/Futhark/Pass/Flatten/Builtins.hs | 204 ++++++-- src/Futhark/Pass/Flatten/Distribute.hs | 64 +-- src/Futhark/Util.hs | 7 + .../function-lifting/func_const.fut | 22 + .../flattening/function-lifting/func_free.fut | 27 ++ .../function-lifting/func_fully_irreg.fut | 23 + .../function-lifting/func_irreg_input.fut | 17 + .../function-lifting/func_irreg_result.fut | 17 + .../function-lifting/func_irreg_update.fut | 23 + .../flattening/function-lifting/func_mix.fut | 25 + .../function-lifting/func_mix_nested.fut | 31 ++ .../function-lifting/func_simple.fut | 16 + tests/flattening/match-case/if.fut | 17 + .../flattening/match-case/if_fully_irreg.fut | 24 + .../flattening/match-case/if_irreg_input.fut | 17 + .../flattening/match-case/if_irreg_result.fut | 20 + .../match-case/match_fully_irreg.fut | 25 + 18 files changed, 936 insertions(+), 77 deletions(-) create mode 100644 tests/flattening/function-lifting/func_const.fut create mode 100644 tests/flattening/function-lifting/func_free.fut create mode 100644 tests/flattening/function-lifting/func_fully_irreg.fut create mode 100644 tests/flattening/function-lifting/func_irreg_input.fut create mode 100644 tests/flattening/function-lifting/func_irreg_result.fut create mode 100644 tests/flattening/function-lifting/func_irreg_update.fut create mode 100644 tests/flattening/function-lifting/func_mix.fut create mode 100644 tests/flattening/function-lifting/func_mix_nested.fut create mode 100644 tests/flattening/function-lifting/func_simple.fut create mode 100644 tests/flattening/match-case/if.fut create mode 100644 tests/flattening/match-case/if_fully_irreg.fut create mode 100644 tests/flattening/match-case/if_irreg_input.fut create mode 100644 tests/flattening/match-case/if_irreg_result.fut create mode 100644 
tests/flattening/match-case/match_fully_irreg.fut diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index f831a2cac5..56f030b67f 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -15,9 +15,10 @@ import Control.Monad.Reader import Control.Monad.State import Data.Bifunctor (bimap, first, second) import Data.Foldable +import Data.List qualified as L import Data.List.NonEmpty qualified as NE import Data.Map qualified as M -import Data.Maybe (fromMaybe, mapMaybe) +import Data.Maybe (fromMaybe, isNothing, mapMaybe) import Data.Tuple.Solo import Debug.Trace import Futhark.IR.GPU @@ -31,6 +32,7 @@ import Futhark.Pass.Flatten.Distribute import Futhark.Tools import Futhark.Transform.Rename import Futhark.Transform.Substitute +import Futhark.Util (mapEither) import Futhark.Util.IntegralExp import Prelude hiding (div, rem) @@ -108,6 +110,14 @@ segsAndElems env (DistInput rt _ : vs) = Irregular (IrregularRep segments flags offsets elems) -> bimap (mplus $ Just (segments, flags, offsets)) (elems :) $ segsAndElems env vs +-- Mapping from original variable names to their distributed resreps +inputReps :: DistInputs -> DistEnv -> M.Map VName (Type, ResRep) +inputReps inputs env = M.fromList $ map (second getRep) inputs + where + getRep di = case di of + DistInput rt t -> (t, resVar rt env) + DistInputFree v' t -> (t, Regular v') + type Segments = NE.NonEmpty SubExp segmentsShape :: Segments -> Shape @@ -116,19 +126,6 @@ segmentsShape = Shape . toList segmentsDims :: Segments -> [SubExp] segmentsDims = shapeDims . segmentsShape -segMap :: Traversable f => f SubExp -> (f SubExp -> Builder GPU Result) -> Builder GPU (Exp GPU) -segMap segments f = do - gtids <- traverse (const $ newVName "gtid") segments - space <- mkSegSpace $ zip (toList gtids) (toList segments) - ((res, ts), stms) <- collectStms $ localScope (scopeOfSegSpace space) $ do - res <- f $ fmap Var gtids - ts <- mapM (subExpType . 
resSubExp) res - pure (map mkResult res, ts) - let kbody = KernelBody () stms res - pure $ Op $ SegOp $ SegMap (SegThread SegVirt Nothing) space ts kbody - where - mkResult (SubExpRes cs se) = Returns ResultMaySimplify cs se - readInput :: Segments -> DistEnv -> [SubExp] -> DistInputs -> SubExp -> Builder GPU SubExp readInput _ _ _ _ (Constant x) = pure $ Constant x readInput segments env is inputs (Var v) = @@ -672,8 +669,238 @@ transformDistStm segments env (DistStm inps res stm) = do | Just map_lam <- isMapSOAC form -> do (ws_flags, ws_offsets, ws) <- transformInnerMap segments env inps pat w arrs map_lam pure $ insertIrregulars ws ws_flags ws_offsets (zip (map distResTag res) $ patNames pat) env + Let _ _ (Match scrutinees cases defaultCase _) -> do + let [w] = NE.toList segments + + -- Lift the scrutinees. + -- If it's a variable, we know it's a scalar and the lifted version will therefore be a regular array. + lifted_scrutinees <- forM scrutinees $ \scrut -> do + (_, rep) <- liftSubExp segments inps env scrut + case rep of + Regular v' -> pure v' + Irregular {} -> + error $ + "transformDistStm: Non-scalar match scrutinee: " ++ prettyString scrut + -- Cases for tagging values that match the same branch. + -- The default case is the 0'th equivalence class. + let equiv_cases = + zipWith + ( \(Case pat _) n -> + Case pat $ eBody [toExp $ intConst Int64 n] + ) + cases + [1 ..] 
+ let equiv_case_default = eBody [toExp $ intConst Int64 0] + -- Match the scrutinees against the branch cases + equiv_classes <- letExp "equiv_classes" <=< segMap (Solo w) $ \(Solo i) -> do + scruts <- mapM (letSubExp "scruts" <=< flip eIndex [toExp i]) lifted_scrutinees + cls <- letSubExp "cls" =<< eMatch scruts equiv_cases equiv_case_default + pure [subExpRes cls] + let num_cases = fromIntegral $ length cases + 1 + n_cases <- letExp "n_cases" <=< toExp $ intConst Int64 num_cases + -- Partition the indices of the scrutinees by their equivalence class such + -- that (the indices) of the scrutinees belonging to class 0 come first, + -- then those belonging to class 1 and so on. + (partition_sizes, partition_offs, partition_inds) <- doPartition n_cases equiv_classes + inds_t <- lookupType partition_inds + -- Get the indices of each scrutinee by equivalence class + inds <- forM [0 .. num_cases - 1] $ \i -> do + num_elems <- + letSubExp ("size" ++ show i) + =<< eIndex partition_sizes [toExp $ intConst Int64 i] + begin <- + letSubExp ("idx_begin" ++ show i) + =<< eIndex partition_offs [toExp $ intConst Int64 i] + letExp ("inds_branch" ++ show i) $ + BasicOp $ + Index partition_inds $ + fullSlice inds_t [DimSlice begin num_elems (intConst Int64 1)] + + -- Take the elements at index `is` from an input `v`. + let splitInput is v = do + (t, rep) <- liftSubExp segments inps env (Var v) + (t,v,) <$> case rep of + Regular arr -> do + -- In the regular case we just take the elements + -- of the array given by `is` + n <- letSubExp "n" =<< (toExp . arraySize 0 =<< lookupType is) + arr' <- letExp "split_arr" <=< segMap (Solo n) $ \(Solo i) -> do + idx <- letExp "idx" =<< eIndex is [eSubExp i] + subExpsRes . 
pure <$> (letSubExp "arr" =<< eIndex arr [toExp idx]) + pure $ Regular arr' + Irregular (IrregularRep segs flags offsets elems) -> do + -- In the irregular case we take the elements + -- of the `segs` array given by `is` like in the regular case + n <- letSubExp "n" =<< (toExp . arraySize 0 =<< lookupType is) + segs' <- letExp "split_segs" <=< segMap (Solo n) $ \(Solo i) -> do + idx <- letExp "idx" =<< eIndex is [eSubExp i] + subExpsRes . pure <$> (letSubExp "segs" =<< eIndex segs [toExp idx]) + -- From this we calculate the offsets and number of elements + (_, offsets', num_elems) <- exScanAndSum segs' + (_, _, ii1) <- doRepIota segs' + (_, _, ii2) <- doSegIota segs' + -- We then take the elements we need from `elems` and `flags` + -- For each index `i`, we roughly: + -- Get the offset of the segment we want to copy by indexing + -- `offsets` through `is` further through `ii1` i.e. + -- `offset = offsets[is[ii1[i]]]` + -- We then add `ii2[i]` to `offset` + -- and use that to index into `elems` and `flags`. + ~[flags', elems'] <- letTupExp "split_flags_elems" <=< segMap (Solo num_elems) $ \(Solo i) -> do + offset <- letExp "offset" =<< eIndex offsets [eIndex is [eIndex ii1 [eSubExp i]]] + idx <- letExp "idx" =<< eBinOp (Add Int64 OverflowUndef) (toExp offset) (eIndex ii2 [eSubExp i]) + flags_split <- letSubExp "flags" =<< eIndex flags [toExp idx] + elems_split <- letSubExp "elems" =<< eIndex elems [toExp idx] + pure $ subExpsRes [flags_split, elems_split] + pure $ + Irregular $ + IrregularRep + { irregularSegments = segs', + irregularFlags = flags', + irregularOffsets = offsets', + irregularElems = elems' + } + -- Given the indices for which a branch is taken and its body, + -- distribute the statements of the body of that branch. + let distributeBranch is body = do + (ts, vs, reps) <- unzip3 <$> mapM (splitInput is) (namesToList $ freeIn body) + let inputs = do + (v, t, i) <- zip3 vs ts [0 ..] 
+ pure (v, DistInput (ResTag i) t) + let env' = DistEnv $ M.fromList $ zip (map ResTag [0 ..]) reps + scope <- askScope + let (inputs', dstms) = distributeBody scope w inputs body + pure (inputs', env', dstms) + + -- Distribute and lift the branch bodies. + -- We put the default case at the start as it's the 0'th equivalence class + -- and is therefore the first segment after the partition. + let branch_bodies = defaultCase : map (\(Case _ body) -> body) cases + (branch_inputs, branch_envs, branch_dstms) <- unzip3 <$> zipWithM distributeBranch inds branch_bodies + let branch_results = map bodyResult branch_bodies + lifted_bodies <- forM [0 .. num_cases - 1] $ \i -> do + size <- letSubExp "size" =<< eIndex partition_sizes [toExp $ intConst Int64 i] + let inputs = branch_inputs !! fromIntegral i + let env' = branch_envs !! fromIntegral i + let dstms = branch_dstms !! fromIntegral i + let result = branch_results !! fromIntegral i + res' <- liftBody size inputs env' dstms result + subExpsRes <$> mapM (\(SubExpRes _ se) -> letSubExp ("result" ++ show i) =<< toExp se) res' + + let result_types = map ((\(DistType _ _ t) -> t) . distResType) res + branch_reps <- + mapM + ( fmap (resultToResReps result_types) + . mapM (letExp "branch_result" <=< toExp . resSubExp) + ) + lifted_bodies + + -- Write back the regular results of a branch to a (partially) blank space + let scatterRegular space (is, xs) = do + ~(Array _ (Shape [size]) _) <- lookupType xs + letExp "regular_scatter" + =<< genScatter + space + size + ( \gtid -> do + x <- letSubExp "x" =<< eIndex xs [eSubExp gtid] + i <- letExp "i" =<< eIndex is [eSubExp gtid] + pure (i, x) + ) + -- Write back the irregular elements of a branch to a (partially) blank space + -- The `offsets` variable is the offsets of the final result, + -- whereas `irregRep` is the irregular representation of the result of a single branch. 
+ let scatterIrregular offsets space (is, irregRep) = do + let IrregularRep {irregularSegments = segs, irregularElems = elems} = irregRep + (_, _, ii1) <- doRepIota segs + (_, _, ii2) <- doSegIota segs + ~(Array _ (Shape [size]) _) <- lookupType elems + letExp "irregular_scatter" + =<< genScatter + space + size + ( \gtid -> do + x <- letSubExp "x" =<< eIndex elems [eSubExp gtid] + offset <- letExp "offset" =<< eIndex offsets [eIndex is [eIndex ii1 [eSubExp gtid]]] + i <- letExp "i" =<< eBinOp (Add Int64 OverflowUndef) (toExp offset) (eIndex ii2 [eSubExp gtid]) + pure (i, x) + ) + -- Given a single result from each branch as well the *unlifted* + -- result type, merge the results of all branches into a single result. + let mergeResult iss branchesRep resType = + case resType of + -- Regular case + Prim pt -> do + let xs = map (\(Regular v) -> v) branchesRep + let resultType = Array pt (Shape [w]) NoUniqueness + -- Create the blank space for the result + resultSpace <- letExp "blank_res" =<< eBlank resultType + -- Write back the values of each branch to the blank space + result <- foldM scatterRegular resultSpace $ zip iss xs + pure $ Regular result + -- Irregular case + Array pt _ _ -> do + let branchesIrregRep = map (\(Irregular irregRep) -> irregRep) branchesRep + let segsType = Array (IntType Int64) (Shape [w]) NoUniqueness + -- Create a blank space for the 'segs' + segsSpace <- letExp "blank_segs" =<< eBlank segsType + -- Write back the segs of each branch to the blank space + segs <- foldM scatterRegular segsSpace $ zip iss (irregularSegments <$> branchesIrregRep) + (_, offsets, num_elems) <- exScanAndSum segs + let resultType = Array pt (Shape [num_elems]) NoUniqueness + -- Create the blank space for the result + resultSpace <- letExp "blank_res" =<< eBlank resultType + -- Write back the values of each branch to the blank space + elems <- foldM (scatterIrregular offsets) resultSpace $ zip iss branchesIrregRep + flags <- genFlags num_elems offsets + pure $ + 
Irregular $ + IrregularRep + { irregularSegments = segs, + irregularFlags = flags, + irregularOffsets = offsets, + irregularElems = elems + } + Acc {} -> error "transformDistStm: Acc" + Mem {} -> error "transformDistStm: Mem" + + -- Merge the results of the branches and insert the resulting res reps + reps <- zipWithM (mergeResult inds) (L.transpose branch_reps) result_types + pure $ insertReps (zip (map distResTag res) reps) env + Let _ _ (Apply name args rettype s) -> do + let [w] = NE.toList segments + let name' = liftFunName name + let rettype' = liftRetType w rettype + args' <- ((w, Observe) :) . concat <$> mapM (liftArg segments inps env) args + result <- letTupExp (nameToString name' <> "_res") $ Apply name' args' rettype' s + let reps = resultToResReps rettype result + pure $ insertReps (zip (map distResTag res) reps) env _ -> error $ "Unhandled Stm:\n" ++ prettyString stm +-- | This function walks through the *unlifted* result types +-- and uses the *lifted* results to construct the corresponding res reps. +-- +-- See the 'liftResult' function for the opposite process i.e. +-- turning 'ResRep's into results. +resultToResReps :: [TypeBase s u] -> [VName] -> [ResRep] +resultToResReps types results = + snd $ + L.mapAccumL + ( \rs t -> case t of + Prim {} -> + let (v : rs') = rs + rep = Regular v + in (rs', rep) + Array {} -> + let (_ : segs : flags : offsets : elems : rs') = rs + rep = Irregular $ IrregularRep segs flags offsets elems + in (rs', rep) + Acc {} -> error "resultToResReps: Illegal type 'Acc'" + Mem {} -> error "resultToResReps: Illegal type 'Mem'" + ) + results + types + distResCerts :: DistEnv -> [DistInput] -> Certs distResCerts env = Certs . 
map f where @@ -727,6 +954,179 @@ transformStms :: Scope SOACS -> Stms SOACS -> PassM (Stms GPU) transformStms scope stms = fold <$> traverse (transformStm (scope <> scopeOf stms)) stms +-- If the sub-expression is a constant, replicate it to match the shape of `segments` +-- If it's a variable, lookup the variable in the dist inputs and dist env, +-- and if it can't be found it is a free variable, so we replicate it to match the shape of `segments`. +liftSubExp :: Segments -> DistInputs -> DistEnv -> SubExp -> Builder GPU (Type, ResRep) +liftSubExp segments inps env se = case se of + c@(Constant prim) -> + let t = Prim $ primValueType prim + in ((t,) . Regular <$> letExp "lifted_const" (BasicOp $ Replicate (segmentsShape segments) c)) + Var v -> case M.lookup v $ inputReps inps env of + Just (t, Regular v') -> do + (t,) + <$> case t of + Prim {} -> pure $ Regular v' + Array {} -> Irregular <$> mkIrregFromReg segments v' + Acc {} -> error "getRepSubExp: Acc" + Mem {} -> error "getRepSubExp: Mem" + Just (t, Irregular irreg) -> pure (t, Irregular irreg) + Nothing -> do + t <- lookupType v + v' <- letExp "free_replicated" $ BasicOp $ Replicate (segmentsShape segments) (Var v) + (t,) + <$> case t of + Prim {} -> pure $ Regular v' + Array {} -> Irregular <$> mkIrregFromReg segments v' + Acc {} -> error "getRepSubExp: Acc" + Mem {} -> error "getRepSubExp: Mem" + +liftParam :: SubExp -> FParam SOACS -> PassM ([FParam GPU], ResRep) +liftParam w fparam = + case declTypeOf fparam of + Prim pt -> do + p <- + newParam + (desc <> "_lifted") + (arrayOf (Prim pt) (Shape [w]) Nonunique) + pure ([p], Regular $ paramName p) + Array pt _ u -> do + num_elems <- + newParam (desc <> "_num_elems") $ Prim int64 + segments <- + newParam (desc <> "_segments") $ + arrayOf (Prim int64) (Shape [w]) Nonunique + flags <- + newParam (desc <> "_flags") $ + arrayOf (Prim Bool) (Shape [Var (paramName num_elems)]) Nonunique + offsets <- + newParam (desc <> "_offsets") $ + arrayOf (Prim int64) (Shape 
[w]) Nonunique + elems <- + newParam (desc <> "_elems") $ + arrayOf (Prim pt) (Shape [Var (paramName num_elems)]) u + pure + ( [num_elems, segments, flags, offsets, elems], + Irregular $ + IrregularRep + { irregularSegments = paramName segments, + irregularFlags = paramName flags, + irregularOffsets = paramName offsets, + irregularElems = paramName elems + } + ) + Acc {} -> + error "liftParam: Acc" + Mem {} -> + error "liftParam: Mem" + where + desc = baseString (paramName fparam) + +liftArg :: Segments -> DistInputs -> DistEnv -> (SubExp, Diet) -> Builder GPU [(SubExp, Diet)] +liftArg segments inps env (se, d) = do + (_, rep) <- liftSubExp segments inps env se + case rep of + Regular v -> pure [(Var v, d)] + Irregular irreg -> mkIrrep irreg + where + mkIrrep + ( IrregularRep + { irregularSegments = segs, + irregularFlags = flags, + irregularOffsets = offsets, + irregularElems = elems + } + ) = do + t <- lookupType elems + num_elems <- letExp "num_elems" =<< toExp (product $ map pe64 $ arrayDims t) + flags' <- letExp "flags" $ BasicOp $ Reshape ReshapeArbitrary (Shape [Var num_elems]) flags + elems' <- letExp "elems" $ BasicOp $ Reshape ReshapeArbitrary (Shape [Var num_elems]) elems + -- Only apply the original diet to the 'elems' array + let diets = replicate 4 Observe ++ [d] + pure $ zipWith (curry (first Var)) [num_elems, segs, flags', offsets, elems'] diets + +-- Lifts a functions return type such that it matches the lifted functions return type. +liftRetType :: SubExp -> [RetType SOACS] -> [RetType GPU] +liftRetType w = concat . snd . 
L.mapAccumL liftType 0 + where + liftType i rettype = + let lifted = case rettype of + Prim pt -> pure $ arrayOf (Prim pt) (Shape [Free w]) Nonunique + Array pt _ u -> + let num_elems = Prim int64 + segs = arrayOf (Prim int64) (Shape [Free w]) Nonunique + flags = arrayOf (Prim Bool) (Shape [Ext i :: Ext SubExp]) Nonunique + offsets = arrayOf (Prim int64) (Shape [Free w]) Nonunique + elems = arrayOf (Prim pt) (Shape [Ext i :: Ext SubExp]) u + in [num_elems, segs, flags, offsets, elems] + Acc {} -> error "liftRetType: Acc" + Mem {} -> error "liftRetType: Mem" + in (i + length lifted, lifted) + +-- Lift a result of a function. +liftResult :: Segments -> DistInputs -> DistEnv -> SubExpRes -> Builder GPU Result +liftResult segments inps env res = map (SubExpRes mempty . Var) <$> vs + where + vs = do + (_, rep) <- liftSubExp segments inps env (resSubExp res) + case rep of + Regular v -> pure [v] + Irregular irreg -> mkIrrep irreg + mkIrrep + ( IrregularRep + { irregularSegments = segs, + irregularFlags = flags, + irregularOffsets = offsets, + irregularElems = elems + } + ) = do + t <- lookupType elems + num_elems <- letExp "num_elems" =<< toExp (product $ map pe64 $ arrayDims t) + flags' <- letExp "flags" $ BasicOp $ Reshape ReshapeArbitrary (Shape [Var num_elems]) flags + elems' <- letExp "elems" $ BasicOp $ Reshape ReshapeArbitrary (Shape [Var num_elems]) elems + pure [num_elems, segs, flags', offsets, elems'] + +liftBody :: SubExp -> DistInputs -> DistEnv -> [DistStm] -> Result -> Builder GPU Result +liftBody w inputs env dstms result = do + let segments = NE.singleton w + env' <- foldM (transformDistStm segments) env dstms + result' <- mapM (liftResult segments inputs env') result + pure $ concat result' + +liftFunName :: Name -> Name +liftFunName name = name <> "_lifted" + +liftFunDef :: Scope SOACS -> FunDef SOACS -> PassM (FunDef GPU) +liftFunDef const_scope fd = do + let FunDef + { funDefBody = body, + funDefParams = fparams, + funDefRetType = rettype + } = fd + 
wp <- newParam "w" $ Prim int64 + let w = Var $ paramName wp + (fparams', reps) <- mapAndUnzipM (liftParam w) fparams + let inputs = do + (p, i) <- zip fparams [0 ..] + pure (paramName p, DistInput (ResTag i) (paramType p)) + let rettype' = liftRetType w rettype + let (inputs', dstms) = + distributeBody const_scope (Var (paramName wp)) inputs body + env = DistEnv $ M.fromList $ zip (map ResTag [0 ..]) reps + -- Lift the body of the function and get the results + (result, stms) <- + runReaderT + (runBuilder $ liftBody w inputs' env dstms $ bodyResult body) + (const_scope <> scopeOfFParams (concat fparams')) + let name = liftFunName $ funDefName fd + pure $ + fd + { funDefName = name, + funDefBody = Body () stms result, + funDefParams = wp : concat fparams', + funDefRetType = rettype' + } + transformFunDef :: Scope SOACS -> FunDef SOACS -> PassM (FunDef GPU) transformFunDef consts_scope fd = do let FunDef @@ -746,7 +1146,11 @@ transformProg :: Prog SOACS -> PassM (Prog GPU) transformProg prog = do consts' <- transformStms mempty $ progConsts prog funs' <- mapM (transformFunDef $ scopeOf (progConsts prog)) $ progFuns prog - pure $ prog {progConsts = consts', progFuns = flatteningBuiltins <> funs'} + lifted_funs <- + mapM (liftFunDef $ scopeOf (progConsts prog)) $ + filter (isNothing . funDefEntryPoint) $ + progFuns prog + pure $ prog {progConsts = consts', progFuns = flatteningBuiltins <> lifted_funs <> funs'} -- | Transform a SOACS program to a GPU program, using flattening. 
-- | Construct the name of a flattening builtin by prefixing it with
-- @builtin#@.
builtinName :: T.Text -> Name
builtinName = nameFromText . ("builtin#" <>)

segIotaName, repIotaName, prefixSumName, partitionName :: Name
segIotaName = builtinName "segiota"
repIotaName = builtinName "repiota"
prefixSumName = builtinName "prefixsum"
partitionName = builtinName "partition"

-- | Build a 'SegMap' expression over the thread space given by
-- @segments@, with one fresh thread index per dimension.  The
-- callback receives the thread indices (in the same container shape
-- as @segments@) and produces the per-thread results; the statements
-- it adds become the kernel body.
segMap :: (Traversable f) => f SubExp -> (f SubExp -> Builder GPU Result) -> Builder GPU (Exp GPU)
segMap segments f = do
  gtids <- traverse (const $ newVName "gtid") segments
  space <- mkSegSpace $ zip (toList gtids) (toList segments)
  ((res, ts), stms) <- collectStms $ localScope (scopeOfSegSpace space) $ do
    res <- f $ fmap Var gtids
    ts <- mapM (subExpType . resSubExp) res
    pure (map mkResult res, ts)
  let kbody = KernelBody () stms res
  pure $ Op $ SegOp $ SegMap (SegThread SegVirt Nothing) space ts kbody
  where
    mkResult (SubExpRes cs se) = Returns ResultMaySimplify cs se

-- | Scan the given arrays through 'genScanomap': each thread
-- contributes the elements of @arrs@ found at its thread indices.
genScan :: (Traversable f) => String -> f SubExp -> Lambda GPU -> [SubExp] -> [VName] -> Builder GPU [VName]
genScan desc segments lam nes arrs =
  genScanomap desc segments lam nes $ \gtids ->
    forM arrs $ \arr ->
      letSubExp (baseString arr <> "_elem")
        =<< eIndex arr (toList $ fmap eSubExp gtids)

-- | Also known as a prescan.  Each thread contributes the neutral
-- elements when its innermost index is zero, and otherwise the
-- elements of @arrs@ at the preceding innermost index (with the outer
-- indices unchanged).
genExScan :: (Traversable f) => String -> f SubExp -> Lambda GPU -> [SubExp] -> [VName] -> Builder GPU [VName]
genExScan desc segments lam nes arrs =
  genScanomap desc segments lam nes $ \gtids ->
    -- Match totally instead of with an irrefutable 'Just' pattern, so
    -- that an empty thread space is reported with a clear message.
    case unsnoc $ toList gtids of
      Nothing ->
        error "genExScan: thread space must have at least one dimension"
      Just (outerDims, innerDim) ->
        letTupExp' "to_prescan"
          =<< eIf
            (toExp $ pe64 innerDim .==. 0)
            (eBody (map eSubExp nes))
            ( eBody
                ( map
                    (`eIndex` (map toExp outerDims ++ [toExp $ pe64 innerDim - 1]))
                    arrs
                )
            )

-- | Inclusive scan with @i64@ addition over all dimensions of @ns@,
-- via 'genScan'.
genPrefixSum :: String -> VName -> Builder GPU VName
genPrefixSum desc ns = do
  ws <- arrayDims <$> lookupType ns
  add_lam <- binOpLambda (Add Int64 OverflowUndef) int64
  -- A single input array is passed, so a single result is expected.
  head <$> genScan desc ws add_lam [intConst Int64 0] [ns]

-- | Exclusive scan with @i64@ addition over all dimensions of @ns@,
-- via 'genExScan'.
genExPrefixSum :: String -> VName -> Builder GPU VName
genExPrefixSum desc ns = do
  ws <- arrayDims <$> lookupType ns
  add_lam <- binOpLambda (Add Int64 OverflowUndef) int64
  -- A single input array is passed, so a single result is expected.
  head <$> genExScan desc ws add_lam [intConst Int64 0] [ns]
-- | Returns (#segments, segment start offsets, sum of segment sizes).
--
-- Note: If given a multi-dimensional array, @#segments@ and @sum of
-- segment sizes@ will be arrays, not scalars.  @segment start
-- offsets@ will always have the same shape as @ks@.
exScanAndSum :: VName -> Builder GPU (SubExp, VName, SubExp)
exScanAndSum ks = do
  ns <- arrayDims <$> lookupType ks
  -- Split off the innermost dimension once, so that neither a partial
  -- 'init' nor a run-time lookup of the last size is needed below.
  (outer_ns, inner_n) <- case unsnoc ns of
    Nothing ->
      error $ "exScanAndSum: Given non-array argument: " ++ prettyString ks
    Just dims -> pure dims
  -- If `ks` only has a single dimension the size will be a scalar,
  -- otherwise it's an array.
  ns' <-
    letExp "ns" . BasicOp $
      if null outer_ns
        then SubExp inner_n
        else ArrayLit ns (Prim int64)
  -- Check if the innermost dimension is empty.
  is_empty <- letExp "is_empty" =<< toExp (pe64 inner_n .==. 0)
  offsets <- letExp "offsets" =<< toExp =<< genExPrefixSum "offsets" ks
  ms <- letExp "ms" <=< segMap outer_ns $ \gtids -> do
    let idxs = map toExp gtids
    offset <- letExp "offset" =<< eIndex offsets idxs
    k <- letExp "k" =<< eIndex ks idxs
    m <-
      letSubExp "m"
        =<< eIf
          (toExp is_empty)
          (eBody [eSubExp $ intConst Int64 0])
          -- Add last size because 'offsets' is an *exclusive* prefix
          -- sum.
          (eBody [eBinOp (Add Int64 OverflowUndef) (eLast offset) (eLast k)])
    pure [subExpRes m]
  pure (Var ns', offsets, Var ms)

-- | Generate code for partitioning @n@ elements into @k@ equivalence
-- classes, where @cls@ holds the class of each element.  Returns the
-- number of elements in each class, the start offset of each class,
-- and the array produced by scattering each element index to its
-- class-local offset plus the offset of its class.
genPartition :: VName -> VName -> VName -> Builder GPU (VName, VName, VName)
genPartition n k cls = do
  let n' = Var n
      k' = Var k
      dims = [k', n']
  -- Create a `[k][n]` array of flags such that `cls_flags[i][j]`
  -- is equal 1 if the j'th element is a member of equivalence class `i` i.e.
  -- the `i`th row is a flag array for equivalence class `i`.
  cls_flags <-
    letExp "flags" <=< segMap2 dims $ \i j -> do
      c <- letSubExp "c" =<< eIndex cls [toExp j]
      cls_flag <-
        letSubExp "cls_flag"
          =<< eIf
            (toExp $ pe64 i .==. pe64 c)
            (eBody [toExp $ intConst Int64 1])
            (eBody [toExp $ intConst Int64 0])
      pure [subExpRes cls_flag]

  -- Offsets of each of the individual equivalence classes.
  (_, local_offs, class_sizes) <- exScanAndSum cls_flags
  -- The number of elems in each class
  counts <- letExp "counts" =<< toExp class_sizes
  -- Offsets of the whole equivalence classes
  global_offs <- genExPrefixSum "global_offs" counts
  -- Offsets over all of the equivalence classes.
  cls_offs <-
    letExp "cls_offs" <=< segMap2 dims $ \i j -> do
      global_offset <- letExp "global_offset" =<< eIndex global_offs [toExp i]
      offset <-
        letSubExp "offset"
          =<< eBinOp
            (Add Int64 OverflowUndef)
            (eIndex local_offs [toExp i, toExp j])
            (toExp global_offset)
      pure [subExpRes offset]

  scratch <- letExp "scratch" $ BasicOp $ Scratch int64 [n']
  res <- letExp "scatter_res" <=< genScatter scratch n' $ \gtid -> do
    c <- letExp "c" =<< eIndex cls [toExp gtid]
    ind <- letExp "ind" =<< eIndex cls_offs [toExp c, toExp gtid]
    i <- letSubExp "i" =<< toExp gtid
    pure (ind, i)
  pure (counts, global_offs, res)
  where
    -- 'segMap' over a two-dimensional space, with the two thread
    -- indices passed as separate arguments.  The match is total; the
    -- fallback is unreachable as long as exactly two dimensions are
    -- given.
    segMap2 space_dims f =
      segMap space_dims $ \gtids ->
        case gtids of
          [i, j] -> f i j
          _ -> error "genPartition: expected exactly two thread indices"

-- | The partition builtin: a function taking @n@, @k@ and a class
-- array of size @n@, whose body is generated by 'genPartition'.
partitionBuiltin :: FunDef GPU
partitionBuiltin = buildingBuiltin $ do
  np <- newParam "n" $ Prim int64
  kp <- newParam "k" $ Prim int64
  csp <- newParam "cs" $ Array int64 (Shape [Var (paramName np)]) Nonunique
  let params = [np, kp, csp]
      -- A unique one-dimensional @i64@ array sized by a parameter.
      sizedBy p = Array int64 (Shape [Free $ Var $ paramName p]) Unique
  body <-
    localScope (scopeOfFParams params) . buildBody_ $ do
      (counts, offsets, res) <-
        genPartition (paramName np) (paramName kp) (paramName csp)
      pure $ varsRes [counts, offsets, res]
  pure
    FunDef
      { funDefEntryPoint = Nothing,
        funDefAttrs = mempty,
        funDefName = partitionName,
        funDefRetType = [sizedBy kp, sizedBy kp, sizedBy np],
        funDefParams = params,
        funDefBody = body
      }
-- | The builtin functions that flattening relies on; 'transformProg'
-- places them at the front of the program's function list.
flatteningBuiltins :: [FunDef GPU]
flatteningBuiltins =
  [ segIotaBuiltin,
    repIotaBuiltin,
    prefixSumBuiltin,
    partitionBuiltin
  ]

-- | Emit a call to the partition builtin with @k@ classes and class
-- array @cs@.  Returns the three names bound to the results of the
-- call, in the order of the builtin's return type.
doPartition :: VName -> VName -> Builder GPU (VName, VName, VName)
doPartition k cs = do
  cs_t <- lookupType cs
  let n = arraySize 0 cs_t
      typed_args = [(n, Prim int64), (Var k, Prim int64), (Var cs, cs_t)]
      ret_t =
        fromMaybe (error "doPartition: bad application") $
          applyRetType
            (funDefRetType partitionBuiltin)
            (funDefParams partitionBuiltin)
            typed_args
  counts <- newVName "partition_counts"
  offsets <- newVName "partition_offsets"
  res <- newVName "partition_res"
  letBindNames [counts, offsets, res] $
    Apply
      (funDefName partitionBuiltin)
      [(arg, Observe) | (arg, _) <- typed_args]
      ret_t
      (Safe, mempty, mempty)
  pure (counts, offsets, res)

-- | Distribute the statements of a body.  Result tags for the
-- statements are numbered from the number of parameter inputs
-- onwards.  Returns all inputs available after the last statement,
-- together with one 'DistStm' per statement of the body.
distributeBody ::
  Scope rep ->
  SubExp ->
  DistInputs ->
  Body SOACS ->
  (DistInputs, [DistStm])
distributeBody outer_scope w param_inputs body =
  (final_inputs, dstms)
  where
    ((_, final_inputs), dstms) =
      L.mapAccumL step (ResTag (length param_inputs), param_inputs)
        . stmsToList
        $ bodyStms body

    bound_outside = namesFromList $ M.keys outer_scope

    resultType t = uncurry (DistType w) $ splitIrregDims bound_outside t

    -- Handle one statement: allocate a fresh tag per pattern element,
    -- record which of the inputs available so far it needs, and make
    -- its pattern elements available to later statements.
    step (ResTag next, avail) stm =
      ( (ResTag $ next + length tags, avail <> zipWith patInput tags (patElems pat)),
        DistStm deps results stm
      )
      where
        pat = stmPat stm
        tags = map ResTag $ take (patSize pat) [next ..]
        -- Inputs that occur free in the statement itself.
        direct = mapMaybe (freeInput avail) . namesToList $ freeIn stm
        -- Inputs that occur free in the types of those inputs.
        via_types =
          mapMaybe (freeInput avail) . namesToList $
            foldMap (freeIn . distInputType . snd) direct
        deps = nubOrd $ via_types <> direct
        results = zipWith DistResult tags $ map resultType $ patTypes pat
-- | Split the last element from the list, if it exists.
--
-- The outer 'Just' is produced as soon as the list is known to be
-- non-empty, and the initial part is built lazily in a single pass.
unsnoc :: [a] -> Maybe ([a], a)
unsnoc [] = Nothing
unsnoc (x : xs) = Just (go x xs)
  where
    -- @go y rest@ splits the non-empty list @y : rest@.
    go y [] = ([], y)
    go y (z : zs) =
      -- The lazy binding keeps the front of the list available before
      -- the end has been reached.
      let (front, final) = go z zs
       in (y : front, final)
splitFromEnd :: Int -> [a] -> ([a], [a]) splitFromEnd i l = splitAt (length l - i) l diff --git a/tests/flattening/function-lifting/func_const.fut b/tests/flattening/function-lifting/func_const.fut new file mode 100644 index 0000000000..d6102298d5 --- /dev/null +++ b/tests/flattening/function-lifting/func_const.fut @@ -0,0 +1,22 @@ +-- Lifting a function with constants as argument and result +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [7i64, 7i64,10i64,16i64,25i64,37i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar (x : i64) (xs : []i64) : ([]i64, i64) = + let ys = map (x*) xs + in (ys, 7) + +#[noinline] +let foo (x : i64) = + let xs = iota x + let (ys, z) = bar 3 xs + in z + reduce (+) 0 ys + +def main (xs : []i64) = map foo xs + diff --git a/tests/flattening/function-lifting/func_free.fut b/tests/flattening/function-lifting/func_free.fut new file mode 100644 index 0000000000..433ba3156d --- /dev/null +++ b/tests/flattening/function-lifting/func_free.fut @@ -0,0 +1,27 @@ +-- Lifting a function with free variables as argument and result +-- == +-- entry: main +-- input { [ 0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [280i64,294i64,308i64,322i64,336i64,350i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let v1 : []i64 = [5,9,6] + +#[noinline] +let v2 : []i64 = [3,1,4,1,5] + +#[noinline] +let bar (xs : []i64) (y : i64) : (i64, []i64) = + let z = y + reduce (+) 0 xs + in (z, copy v2) + +#[noinline] +let foo (x : i64) = + let (y, zs) = bar v1 x + let z = reduce (+) 0 zs + in (y * z) + +def main (xs : []i64) = map foo xs diff --git a/tests/flattening/function-lifting/func_fully_irreg.fut b/tests/flattening/function-lifting/func_fully_irreg.fut new file mode 100644 index 0000000000..27382dfba8 --- /dev/null +++ b/tests/flattening/function-lifting/func_fully_irreg.fut @@ -0,0 +1,23 @@ +-- Lifting a function with an irregular +-- parameter and
return type +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [0i64, 0i64, 0i64, 3i64, 15i64,45i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar (xs : []i64) : []i64 = + let y = reduce (+) 0 xs + in iota y + +#[noinline] +let foo (x : i64) = + let xs = iota x + let ys = bar xs + in reduce (+) 0 ys + +def main (xs : []i64) = map foo xs + diff --git a/tests/flattening/function-lifting/func_irreg_input.fut b/tests/flattening/function-lifting/func_irreg_input.fut new file mode 100644 index 0000000000..718895b430 --- /dev/null +++ b/tests/flattening/function-lifting/func_irreg_input.fut @@ -0,0 +1,17 @@ +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [0i64, 0i64, 1i64, 3i64, 6i64,10i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar (xs : []i64) : i64 = reduce (+) 0 xs + +#[noinline] +let foo (x : i64) = + let xs = iota x + in bar xs + +def main (xs : []i64) = map foo xs diff --git a/tests/flattening/function-lifting/func_irreg_result.fut b/tests/flattening/function-lifting/func_irreg_result.fut new file mode 100644 index 0000000000..0225943179 --- /dev/null +++ b/tests/flattening/function-lifting/func_irreg_result.fut @@ -0,0 +1,17 @@ +-- == +-- entry: main +-- input { [0i64,1i64,2i64,3i64,4i64, 5i64] } +-- output { [0i64,0i64,1i64,3i64,6i64,10i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar (x : i64) : []i64 = iota x + +#[noinline] +let foo (x : i64) = + let xs = bar x + in reduce (+) 0 xs + +def main (xs : []i64) = map foo xs diff --git a/tests/flattening/function-lifting/func_irreg_update.fut b/tests/flattening/function-lifting/func_irreg_update.fut new file mode 100644 index 0000000000..b329759b91 --- /dev/null +++ b/tests/flattening/function-lifting/func_irreg_update.fut @@ -0,0 +1,23 @@ +-- Lifting a function which consumes its argument +-- == +-- entry: main +-- input { 
[0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [0i64, 0i64, 0i64, 1i64, 2i64, 4i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar [n] (xs : *[n]i64) (z : i64) (ys : [z]i64) : [n]i64 = + let m = n - z + in xs with [m:n] = ys + +#[noinline] +let foo (a : i64) = + let b = a / 2 + let xs = iota a + let ys = iota b :> [b]i64 + let zs = bar xs b ys + in reduce (+) 0 zs + +def main (xs : []i64) = map foo xs diff --git a/tests/flattening/function-lifting/func_mix.fut b/tests/flattening/function-lifting/func_mix.fut new file mode 100644 index 0000000000..cca3be9c17 --- /dev/null +++ b/tests/flattening/function-lifting/func_mix.fut @@ -0,0 +1,25 @@ +-- Lifting a function with both regular and irregular +-- parameters and return types. +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [0i64, 0i64, -1i64, 27i64, 252i64, 1175i64] } +-- input { [5i64, 4i64, 3i64, 2i64, 1i64, 0i64] [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [0i64, 9i64, 9i64, 0i64, 0i64, 0i64] } +-- input { empty([0]i64) empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let bar (y : i64) (xs : []i64) : ([]i64, i64) = + let z = y * reduce (+) 0 xs + in (iota z, z) + +#[noinline] +let foo (a : i64) (b : i64) = + let xs = iota a + let (ys, z) = bar b xs + in reduce (+) 0 ys - z + +def main (as : []i64) (bs : []i64) = map2 foo as bs + diff --git a/tests/flattening/function-lifting/func_mix_nested.fut b/tests/flattening/function-lifting/func_mix_nested.fut new file mode 100644 index 0000000000..0431af1e94 --- /dev/null +++ b/tests/flattening/function-lifting/func_mix_nested.fut @@ -0,0 +1,31 @@ +-- Lifting a function that calls another function +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64] [0i64, 1i64, 2i64, 3i64, 4i64] } +-- output { [0i64, 0i64, 0i64, 52290i64, 21935100i64] } +-- input { [5i64, 4i64, 3i64, 2i64, 1i64, 0i64] [0i64, 1i64, 2i64, 3i64, 4i64, 
5i64] } +-- output { [0i64, 3990i64, 3990i64, 33i64, 0i64, 0i64] } +-- input { empty([0]i64) empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +let baz (xs : []i64) (y : i64) : ([]i64, []i64) = + let z = y * reduce (+) 0 xs + in (iota y, iota z) + +#[noinline] +let bar (y : i64) (xs : []i64) : ([]i64, i64) = + let z = y * reduce (+) 0 xs + let (as, bs) = baz (iota z) z + let a = reduce (+) 0 as + in (bs, a) + +#[noinline] +let foo (a : i64) (b : i64) = + let xs = iota a + let (ys, z) = bar b xs + in reduce (+) 0 ys - z + +def main (as : []i64) (bs : []i64) = map2 foo as bs + diff --git a/tests/flattening/function-lifting/func_simple.fut b/tests/flattening/function-lifting/func_simple.fut new file mode 100644 index 0000000000..22002bbf3e --- /dev/null +++ b/tests/flattening/function-lifting/func_simple.fut @@ -0,0 +1,16 @@ +-- Lifting a simple function +-- == +-- entry: main +-- input { [0i64, 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [1i64, 2i64, 3i64, 4i64, 5i64, 6i64] } +-- input { empty([0]i64) } +-- output { empty([0]i64) } + + +#[noinline] +def bar (x : i64) = x + 1 + +#[noinline] +def foo (x : i64) = bar x + +def main (xs : []i64) = map foo xs diff --git a/tests/flattening/match-case/if.fut b/tests/flattening/match-case/if.fut new file mode 100644 index 0000000000..cc1594d32e --- /dev/null +++ b/tests/flattening/match-case/if.fut @@ -0,0 +1,17 @@ +-- == +-- entry: main +-- nobench input { [-1i64,1i64,-2i64,2i64,-3i64,3i64] } +-- output { [ 1i64,2i64, 4i64,4i64, 9i64,6i64] } +-- nobench input { [-5i64,-3i64,4i64,2i64,0i64,-1i64,3i64,1i64] } +-- output { [25i64, 9i64,8i64,4i64,0i64, 1i64,6i64,2i64] } +-- nobench input { [ 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [ 2i64, 4i64, 6i64, 8i64,10i64] } +-- nobench input { [-1i64,-2i64,-3i64,-4i64,-5i64] } +-- output { [ 1i64, 4i64, 9i64,16i64,25i64] } +-- nobench input { empty([0]i64) } +-- output { empty([0]i64) } + +#[noinline] +let foo (x : i64) = if x < 0 then x * x else x * 2 + +def main [n] 
(xs : [n]i64) = map foo xs diff --git a/tests/flattening/match-case/if_fully_irreg.fut b/tests/flattening/match-case/if_fully_irreg.fut new file mode 100644 index 0000000000..4efbbe1771 --- /dev/null +++ b/tests/flattening/match-case/if_fully_irreg.fut @@ -0,0 +1,24 @@ +-- == +-- entry: main +-- nobench input { [ 2i64, 7i64, 1i64, 8i64, 7i64] } +-- output { [ 2i64, 23i64, 0i64, 31i64, 23i64] } +-- nobench input { [ 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [ 0i64, 2i64, 6i64, 12i64, 20i64] } +-- nobench input { [ 6i64, 7i64, 8i64, 9i64, 10i64] } +-- output { [16i64, 23i64, 31i64, 40i64, 50i64] } +-- nobench input { empty([0]i64) } +-- output { empty([0]i64) } + +#[noinline] +let bar [n] (xs : [n]i64) = + if n <= 5 then (false, xs) + else (true, copy xs with [5] = n) + +#[noinline] +let foo (x : i64) = + let xs = iota x in + let (b, ys) = bar xs + let z = reduce (+) 0 ys + in if b then z else z * 2 + +def main [n] (xs : [n]i64) = map foo xs diff --git a/tests/flattening/match-case/if_irreg_input.fut b/tests/flattening/match-case/if_irreg_input.fut new file mode 100644 index 0000000000..8b4a164b0b --- /dev/null +++ b/tests/flattening/match-case/if_irreg_input.fut @@ -0,0 +1,17 @@ +-- == +-- entry: main +-- nobench input { [-5i64,-3i64,4i64,2i64,0i64,-1i64,3i64,1i64] } +-- output { [-1i64,-1i64,6i64,1i64,0i64,-1i64,3i64,0i64] } +-- nobench input { [ 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [ 0i64, 1i64, 3i64, 6i64,10i64] } +-- nobench input { [-1i64,-2i64,-3i64,-4i64,-5i64] } +-- output { [-1i64,-1i64,-1i64,-1i64,-1i64] } +-- nobench input { empty([0]i64) } +-- output { empty([0]i64) } + +#[noinline] +let foo (x : i64) = + let ys = iota (i64.abs x) + in if x < 0 then -1 else reduce (+) 0 ys + +def main [n] (xs : [n]i64) = map foo xs diff --git a/tests/flattening/match-case/if_irreg_result.fut b/tests/flattening/match-case/if_irreg_result.fut new file mode 100644 index 0000000000..218780ae1e --- /dev/null +++ b/tests/flattening/match-case/if_irreg_result.fut @@ 
-0,0 +1,20 @@ +-- == +-- entry: main +-- nobench input { [ -5i64,-3i64,4i64,2i64,0i64,-1i64,3i64,1i64] } +-- output { [300i64,36i64,6i64,1i64,0i64, 0i64,3i64,0i64] } +-- nobench input { [ 1i64, 2i64, 3i64, 4i64, 5i64] } +-- output { [ 0i64, 1i64, 3i64, 6i64, 10i64] } +-- nobench input { [ 1i64,-2i64,-3i64, -4i64, -5i64] } +-- output { [ 0i64, 6i64,36i64,120i64,300i64] } +-- nobench input { empty([0]i64) } +-- output { empty([0]i64) } + +#[noinline] +let bar (x : i64) = if x < 0 then iota (x*x) else iota x + +#[noinline] +let foo (x : i64) = + let ys = bar x + in reduce (+) 0 ys + +def main [n] (xs : [n]i64) = map foo xs diff --git a/tests/flattening/match-case/match_fully_irreg.fut b/tests/flattening/match-case/match_fully_irreg.fut new file mode 100644 index 0000000000..72f4b6b2c5 --- /dev/null +++ b/tests/flattening/match-case/match_fully_irreg.fut @@ -0,0 +1,25 @@ +-- == +-- entry: main +-- nobench input { [0i64, 0i64, 0i64, 1i64, 1i64, 1i64, 2i64, 2i64, 2i64] [0i64, 1i64, 2i64, 0i64, 1i64, 2i64, 0i64, 1i64, 2i64] } +-- output { [7i64, -5i64, -4i64, 2i64, -1i64, -1i64, 1i64, -1i64, 2i64] } +-- nobench input { [0i64, 0i64, 0i64, 1i64, 1i64, 1i64, 2i64, 2i64, 2i64] [2i64, 2i64, 2i64, 1i64, 1i64, 1i64, 0i64, 0i64, 0i64] } +-- output { [-4i64, -4i64, -4i64, -1i64, -1i64, -1i64, 1i64, 1i64, 1i64] } +-- nobench input { [1i64, 2i64, 3i64] [4i64, 5i64, 6i64] } +-- output { [2i64, 35i64, 135i64] } +-- nobench input { empty([0]i64) empty([0]i64) } +-- output { empty([0]i64) } + +#[noinline] +let foo (x : i64) (y : i64) (zs : []i64) = + let (a, as) = + match (x, y) + case (0,0) -> (3,iota 5) + case (0,b) -> (5,iota b) + case (a,0) -> (a,iota 3) + case (a,b) -> (a*b, zs) + in reduce (+) 0 as - a + +let bar (x : i64) (y : i64) = + let zs = iota (x * y) in foo x y zs + +def main [n] (xs : [n]i64) (ys : [n]i64) = map2 bar xs ys From acb7caf96871f147f92d931bfb0c4592a9b7e8c2 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 16 Sep 2023 16:40:49 +0200 Subject: [PATCH 
-- | The number of segment dimensions.
segmentsRank :: Segments -> Int
segmentsRank = shapeRank . segmentsShape

-- | Given the permutation of a rearrange, the dimensions @dims@ of the
-- array /before/ rearranging, and a flat index @i@ into the
-- rearranged array, compute the flat index of the same element in the
-- original array.
--
-- The rearranged array has dimensions @rearrangeShape perm dims@, so
-- component @j@ of an index into it ranges over original dimension
-- @perm !! j@.  Recovering the original index therefore requires the
-- /inverse/ permutation.  Applying @perm@ a second time is only
-- correct when @perm@ is its own inverse, as for a two-dimensional
-- transposition.
rearrangeFlat :: (IntegralExp num) => [Int] -> [num] -> num -> num
rearrangeFlat perm dims i =
  flattenIndex dims $
    rearrangeShape (rearrangeInverse perm) $
      unflattenIndex (rearrangeShape perm dims) i

-- | Rearrange every segment of an irregular array according to
-- @perm@.  Segments, flags and offsets are reused unchanged; only the
-- element array is rebuilt, by having every position read the element
-- that 'rearrangeFlat' maps it to within its own segment.  @v_t@ is
-- the type whose dimensions are used as the shape of a segment before
-- rearranging.  The first two arguments are currently unused.
rearrangeIrreg ::
  Segments ->
  DistEnv ->
  TypeBase Shape u ->
  [Int] ->
  IrregularRep ->
  Builder GPU IrregularRep
rearrangeIrreg _segments _env v_t perm (IrregularRep shape flags offsets elems) = do
  m <- arraySize 0 <$> lookupType elems
  -- For every element: the segment it belongs to ...
  (_, _, ii1_vss) <- doRepIota shape
  -- ... and its position within that segment.
  (_, _, ii2_vss) <- doSegIota shape
  elems' <- letExp "elems_rearrange" <=< renameExp <=< segMap (Solo m) $
    \(Solo i) -> do
      seg_i <- letSubExp "seg_i" =<< eIndex ii1_vss [eSubExp i]
      offset <- letSubExp "offset" =<< eIndex offsets [eSubExp seg_i]
      in_seg_i <- letSubExp "in_seg_i" =<< eIndex ii2_vss [eSubExp i]
      let v_dims = map pe64 $ arrayDims v_t
          in_seg_is_tr = rearrangeFlat perm v_dims $ pe64 in_seg_i
      v' <-
        letSubExp "v"
          =<< eIndex elems [toExp $ pe64 offset + in_seg_is_tr]
      pure [subExpRes v']
  pure $
    IrregularRep
      { irregularSegments = shape,
        irregularFlags = flags,
        irregularOffsets = offsets,
        irregularElems = elems'
      }
r - 1] ++ map (+ r) perm) v' + pure $ insertRegulars [distResTag res] [v''] env + Nothing -> do + let r = segmentsRank segments + v' <- + letExp (baseString v <> "_tr") . BasicOp $ + Rearrange ([0 .. r - 1] ++ map (+ r) perm) v + pure $ insertRegulars [distResTag res] [v'] env _ -> error $ "Unhandled BasicOp:\n" ++ prettyString e where scalarCase = diff --git a/tests/flattening/dup2d.fut b/tests/flattening/dup2d.fut new file mode 100644 index 0000000000..e33a722966 --- /dev/null +++ b/tests/flattening/dup2d.fut @@ -0,0 +1,7 @@ +-- == +-- input { [[1,2,3],[4,5,6]] } +-- auto output + +def dup = replicate 2 >-> transpose >-> flatten + +entry main (z: [][]i32) = z |> map dup |> dup diff --git a/tests/flattening/rearrange0.fut b/tests/flattening/rearrange0.fut new file mode 100644 index 0000000000..4bbddbfdf5 --- /dev/null +++ b/tests/flattening/rearrange0.fut @@ -0,0 +1,5 @@ +-- == +-- input { [[[1,2],[4,5],[7,8]],[[8,4],[5,1],[7,2]]] } +-- output { [5,1] } + +def main (xsss: [][][]i32) = map (\xs -> (opaque (transpose (opaque xs)))[1,1]) xsss diff --git a/tests/flattening/rearrange1.fut b/tests/flattening/rearrange1.fut new file mode 100644 index 0000000000..371768b399 --- /dev/null +++ b/tests/flattening/rearrange1.fut @@ -0,0 +1,5 @@ +-- == +-- input { [3i64,4i64] } +-- output { [1i64,1i64] } + +def main = map (\n -> ((transpose (replicate (n+1) (iota n))))[1,1]) From d14cf71312697f30ce1bf3c349d36f0d55ed3a1a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 16 Sep 2023 17:35:13 +0200 Subject: [PATCH 53/64] Add failing test. 
--- src/Futhark/Pass/Flatten.hs | 13 ++++++++----- tests/flattening/dup3d.fut | 9 +++++++++ 2 files changed, 17 insertions(+), 5 deletions(-) create mode 100644 tests/flattening/dup3d.fut diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 315a99c341..05aa7ccaef 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -217,8 +217,11 @@ elemArr segments env inps (Var v) elemArr segments _ _ se = do rep <- letExp "rep" $ BasicOp $ Replicate (segmentsShape segments) se dims <- arrayDims <$> lookupType rep - n <- toSubExp "n" $ product $ map pe64 dims - letExp "reshape" $ BasicOp $ Reshape ReshapeArbitrary (Shape [n]) rep + if length dims == 1 + then pure rep + else do + n <- toSubExp "n" $ product $ map pe64 dims + letExp "reshape" $ BasicOp $ Reshape ReshapeArbitrary (Shape [n]) rep mkIrregFromReg :: Segments -> @@ -228,7 +231,7 @@ mkIrregFromReg segments arr = do arr_t <- lookupType arr segment_size <- letSubExp "reg_seg_size" <=< toExp . product . map pe64 $ - drop (shapeRank (segmentsShape segments)) (arrayDims arr_t) + drop (segmentsRank segments) (arrayDims arr_t) segments_arr <- letExp "reg_segments" . BasicOp $ Replicate (segmentsShape segments) segment_size @@ -301,7 +304,7 @@ replicateIrreg segments env ns desc rep = do w <- arraySize 0 <$> lookupType ns_full_elems - elems <- letExp (desc <> "_elems") <=< segMap (Solo w) $ \(Solo i) -> do + elems <- letExp (desc <> "_rep_elems") <=< segMap (Solo w) $ \(Solo i) -> do -- Which segment we are in. segment_i <- letSubExp "segment_i" =<< eIndex ns_full_elems [eSubExp i] @@ -623,7 +626,7 @@ onMapInputArr segments env inps ii2 p arr = do -- If parameter type of the map corresponds to the -- element type of the value array, we can map it -- directly. 
- if stripArray (shapeRank (segmentsShape segments)) elems_t == paramType p + if stripArray (segmentsRank segments) elems_t == paramType p then pure $ MapArray (irregularElems rep) elems_t else do -- Otherwise we need to perform surgery on the metadata. diff --git a/tests/flattening/dup3d.fut b/tests/flattening/dup3d.fut new file mode 100644 index 0000000000..a61343bbdd --- /dev/null +++ b/tests/flattening/dup3d.fut @@ -0,0 +1,9 @@ +-- Currently fails; an array that is too small is produced somehow. I +-- suspect replication. +-- == +-- input { [[[1,2],[3,4]],[[5,6],[7,8]]] } +-- auto output + +def dup = replicate 2 >-> transpose >-> flatten + +def main (z: [2][2][2]i32) = z |> map (map dup) |> map dup |> dup From 4bdf0e3d7117aa6f72cef884af9536267a38be76 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 17 Dec 2023 15:55:26 +0100 Subject: [PATCH 54/64] Fix type annotation. --- src/Futhark/Pass/Flatten/Builtins.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Pass/Flatten/Builtins.hs b/src/Futhark/Pass/Flatten/Builtins.hs index 36c0b2754d..d60c486084 100644 --- a/src/Futhark/Pass/Flatten/Builtins.hs +++ b/src/Futhark/Pass/Flatten/Builtins.hs @@ -128,8 +128,8 @@ genScatter dest n f = do space <- mkSegSpace [(gtid, n)] ((res, v_t), stms) <- collectStms $ localScope (scopeOfSegSpace space) $ do (i, v) <- f $ Var gtid - v_t <- subExpType v - pure (WriteReturns mempty dest [(Slice [DimFix (Var i)], v)], v_t) + dest_t <- lookupType dest + pure (WriteReturns mempty dest [(Slice [DimFix (Var i)], v)], dest_t) let kbody = KernelBody () stms [res] pure $ Op $ SegOp $ SegMap (SegThread SegVirt Nothing) space [v_t] kbody From aae4bd1e14736351a40579052a7af396c0db7e13 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 17 Dec 2023 18:05:07 +0100 Subject: [PATCH 55/64] Better test data. 
--- tests/flattening/dup3d.fut | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/flattening/dup3d.fut b/tests/flattening/dup3d.fut index a61343bbdd..1b8e2a228e 100644 --- a/tests/flattening/dup3d.fut +++ b/tests/flattening/dup3d.fut @@ -1,9 +1,9 @@ -- Currently fails; an array that is too small is produced somehow. I -- suspect replication. -- == --- input { [[[1,2],[3,4]],[[5,6],[7,8]]] } +-- input { [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] } -- auto output -def dup = replicate 2 >-> transpose >-> flatten +def dup = replicate 5 >-> transpose >-> flatten -def main (z: [2][2][2]i32) = z |> map (map dup) |> map dup |> dup +def main (z: [2][3][4]i32) = z |> map (map dup) |> map dup |> dup From 068a16e84540d930a7a10777e2682aae42b5ddd1 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 17 Dec 2023 18:06:19 +0100 Subject: [PATCH 56/64] Fix typo. --- src/Futhark/Pass/Flatten.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 05aa7ccaef..b35aa36aed 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -165,7 +165,7 @@ readInputs segments env is = mapM_ onInput letSubExp "num_elems" =<< toExp (product $ map pe64 $ arrayDims t) let slice = Slice [DimSlice offset num_elems (intConst Int64 1)] v_flat <- - letExp (baseString v <> "_float") $ BasicOp $ Index elems slice + letExp (baseString v <> "_flat") $ BasicOp $ Index elems slice letBindNames [v] . BasicOp $ Reshape ReshapeArbitrary (arrayShape t) v_flat From 647e4fe1ef73a889c7d7da8ac84c8f6965225c1f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 17 Dec 2023 18:06:36 +0100 Subject: [PATCH 57/64] This seems wrong. 
--- src/Futhark/Pass/Flatten.hs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index b35aa36aed..f7dd8b8f70 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -614,11 +614,8 @@ onMapInputArr segments env inps ii2 p arr = do case lookup arr inps of Just v_inp -> case v_inp of - DistInputFree vs t -> do - v <- - letExp (baseString vs <> "_flat") . BasicOp $ - Reshape ReshapeArbitrary (Shape [ws_prod]) vs - pure $ MapArray v t + DistInputFree vs t -> + pure $ MapArray vs t DistInput rt _ -> case resVar rt env of Irregular rep -> do From 8aff0d5677a5335c68f882ce5340c3012e3d4a4b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 17 Dec 2023 18:13:52 +0100 Subject: [PATCH 58/64] New formatting. --- src/Futhark/Pass/Flatten.hs | 8 +++++--- src/Futhark/Pass/Flatten/Distribute.hs | 10 +++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index f7dd8b8f70..336e945d32 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -448,7 +448,7 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = fmap (subExpsRes . pure) . letSubExp "v" <=< toExp $ primExpFromSubExp (IntType it) x' ~+~ sExt it (untyped (pe64 v')) - ~*~ primExpFromSubExp (IntType it) s' + ~*~ primExpFromSubExp (IntType it) s' pure $ insertIrregular ns flags offsets (distResTag res) elems' env Replicate (Shape [n]) (Var v) -> do ns <- elemArr segments env inps n @@ -574,7 +574,8 @@ onMapFreeVar segments env inps ws (ws_flags, ws_offsets, ws_elems) v = do ws_prod <- arraySize 0 <$> lookupType ws_elems fmap (v,) $ case v_inp of DistInputFree v' t -> do - fmap (`MapArray` t) . letExp (baseString v <> "_rep_free_free_inp") + fmap (`MapArray` t) + . 
letExp (baseString v <> "_rep_free_free_inp") <=< segMap (Solo ws_prod) $ \(Solo i) -> do segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] @@ -595,7 +596,8 @@ onMapFreeVar segments env inps ws (ws_flags, ws_offsets, ws_elems) v = do } pure $ MapOther rep' t Regular vs -> - fmap (`MapArray` t) . letExp (baseString v <> "_rep_free_reg_inp") + fmap (`MapArray` t) + . letExp (baseString v <> "_rep_free_reg_inp") <=< segMap (Solo ws_prod) $ \(Solo i) -> do segment <- letSubExp "segment" =<< eIndex segments_per_elem [eSubExp i] diff --git a/src/Futhark/Pass/Flatten/Distribute.hs b/src/Futhark/Pass/Flatten/Distribute.hs index 3e64bd0599..cc48fbc1fe 100644 --- a/src/Futhark/Pass/Flatten/Distribute.hs +++ b/src/Futhark/Pass/Flatten/Distribute.hs @@ -48,14 +48,14 @@ distInputType (DistInput _ t) = t data DistType = DistType + -- | Outer regular size. SubExp - -- ^ Outer regular size. - Rank - -- ^ Irregular dimensions on top (but after the leading regular + -- | Irregular dimensions on top (but after the leading regular -- size). - Type - -- ^ The regular "element type" - in the worst case, at least a + Rank + -- | The regular "element type" - in the worst case, at least a -- scalar. + Type deriving (Eq, Ord, Show) data DistResult = DistResult {distResTag :: ResTag, distResType :: DistType} From 9738c472bd4e5a7be0dd38216006453b79dc6a84 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 17 Dec 2023 18:38:10 +0100 Subject: [PATCH 59/64] Document nomenclature. 
--- src/Futhark/Pass/Flatten.hs | 185 ++++++++++++++++++++++++------------ 1 file changed, 122 insertions(+), 63 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 336e945d32..eb5e2bafef 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -36,6 +36,65 @@ import Futhark.Util (mapEither) import Futhark.Util.IntegralExp import Prelude hiding (div, rem) +-- Note [Representation of Flat Arrays] +-- +-- This flattening implementation uses largely the nomenclature and +-- structure described by Cosmin Oancea. In particular, consider an +-- irregular array 'A' where +-- +-- - A has 'n' segments (outermost dimension). +-- +-- - A has element type 't'. +-- +-- - A has a total of 'm' elements (where 'm' is divisible by 'n', +-- and may indeed be 'm'). +-- +-- Then A is represented by the following arrays: +-- +-- - A_D : [m]t; the "data array". +-- +-- - A_S : [n]i64; the "shape array" giving the size of each segment. +-- +-- - A_F : [m]bool; the "flag array", indicating when an element begins a +-- new segment. +-- +-- - A_O : [n]i64; the offset array, indicating for each segment +-- where it starts in the data (and flag) array. +-- +-- - A_II : [m]t; the "inner indices"; indicating for each element +-- its index within its corresponding segment. +-- +-- Some of these structures can be computed from each other, but +-- conceptually they all coexist. +-- +-- Note that we only consider the *outer* dimension to be the +-- "segments". Also, 't' may actually be an array itself (although in +-- this case, the shape must be invariant to all parallel dimensions). +-- The inner structure is preserved through code, not data. (Or in +-- practice, ad-hoc auxiliary arrays produced by code.) In Cosmin's +-- notation, we maintain only the information for the outermost +-- dimension. 
+-- +-- As an example, consider an irregular array +-- +-- A = [ [], [ [1,2,3], [4], [], [5,6] ], [ [7], [], [8,9,10] ] ] +-- +-- then +-- +-- n = 3 +-- +-- m = 10 +-- +-- A_D = [1,2,3,4,5,6,7,8,9,10] +-- +-- A_S = [0, 6, 4] +-- +-- A_F = [T,F,F,F,F,F,T,F,F,F] +-- +-- A_O = [0, 0, 6] +-- +-- A_II = [1,1,1,1,1,1,2,2,2,2] + data FlattenEnv = FlattenEnv newtype FlattenM a = FlattenM (StateT VNameSource (Reader FlattenEnv) a) @@ -50,10 +109,10 @@ newtype FlattenM a = FlattenM (StateT VNameSource (Reader FlattenEnv) a) data IrregularRep = IrregularRep { -- | Array of size of each segment, type @[]i64@. - irregularSegments :: VName, - irregularFlags :: VName, - irregularOffsets :: VName, - irregularElems :: VName + irregularS :: VName, + irregularF :: VName, + irregularO :: VName, + irregularD :: VName } deriving (Show) @@ -203,17 +262,17 @@ distCerts inps aux env = Certs $ map f $ unCerts $ stmAuxCerts aux Just (DistInput rt _) -> case resVar rt env of Regular vs -> vs - Irregular r -> irregularElems r + Irregular r -> irregularD r -- | Only sensible for variables of segment-invariant type.
elemArr :: Segments -> DistEnv -> DistInputs -> SubExp -> Builder GPU VName elemArr segments env inps (Var v) | Just v_inp <- lookup v inps = case v_inp of - DistInputFree vs _ -> irregularElems <$> mkIrregFromReg segments vs + DistInputFree vs _ -> irregularD <$> mkIrregFromReg segments vs DistInput rt _ -> case resVar rt env of - Irregular r -> pure $ irregularElems r - Regular vs -> irregularElems <$> mkIrregFromReg segments vs + Irregular r -> pure $ irregularD r + Regular vs -> irregularD <$> mkIrregFromReg segments vs elemArr segments _ _ se = do rep <- letExp "rep" $ BasicOp $ Replicate (segmentsShape segments) se dims <- arrayDims <$> lookupType rep @@ -252,10 +311,10 @@ mkIrregFromReg segments arr = do pure [subExpRes offset] pure $ IrregularRep - { irregularSegments = segments_arr, - irregularFlags = flags, - irregularOffsets = offsets, - irregularElems = elems + { irregularS = segments_arr, + irregularF = flags, + irregularO = offsets, + irregularD = elems } -- Get the irregular representation of a var. @@ -294,7 +353,7 @@ replicateIrreg segments env ns desc rep = do n <- letSubExp "n" =<< eIndex ns [eSubExp i] old_segment <- - letSubExp "old_segment" =<< eIndex (irregularSegments rep) [eSubExp i] + letSubExp "old_segment" =<< eIndex (irregularS rep) [eSubExp i] full_segment <- letSubExp "new_segment" =<< toExp (pe64 n * pe64 old_segment) pure $ subExpsRes [full_segment] @@ -310,7 +369,7 @@ replicateIrreg segments env ns desc rep = do letSubExp "segment_i" =<< eIndex ns_full_elems [eSubExp i] -- Size of original segment. old_segment <- - letSubExp "old_segment" =<< eIndex (irregularSegments rep) [eSubExp segment_i] + letSubExp "old_segment" =<< eIndex (irregularS rep) [eSubExp segment_i] -- Index of value inside *new* segment. 
j_new <- letSubExp "j_new" =<< eIndex flat_to_segs [eSubExp i] @@ -319,18 +378,18 @@ replicateIrreg segments env ns desc rep = do letSubExp "j_old" =<< toExp (pe64 j_new `rem` pe64 old_segment) -- Offset of values in original segment. offset <- - letSubExp "offset" =<< eIndex (irregularOffsets rep) [eSubExp segment_i] + letSubExp "offset" =<< eIndex (irregularO rep) [eSubExp segment_i] v <- letSubExp "v" - =<< eIndex (irregularElems rep) [toExp $ pe64 offset + pe64 j_old] + =<< eIndex (irregularD rep) [toExp $ pe64 offset + pe64 j_old] pure $ subExpsRes [v] pure $ IrregularRep - { irregularSegments = ns_full, - irregularFlags = ns_full_flags, - irregularOffsets = ns_full_offsets, - irregularElems = elems + { irregularS = ns_full, + irregularF = ns_full_flags, + irregularO = ns_full_offsets, + irregularD = elems } rearrangeFlat :: (IntegralExp num) => [Int] -> [num] -> num -> num @@ -364,10 +423,10 @@ rearrangeIrreg segments env v_t perm (IrregularRep shape flags offsets elems) = pure [subExpRes v'] pure $ IrregularRep - { irregularSegments = shape, - irregularFlags = flags, - irregularOffsets = offsets, - irregularElems = elems' + { irregularS = shape, + irregularF = flags, + irregularO = offsets, + irregularD = elems' } transformDistBasicOp :: @@ -472,11 +531,11 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = Just (DistInput rt _) -> case resVar rt env of Irregular r -> do - let name = baseString (irregularElems r) <> "_copy" + let name = baseString (irregularD r) <> "_copy" elems_copy <- letExp name . 
BasicOp $ - Replicate mempty (Var $ irregularElems r) - let rep = Irregular $ r {irregularElems = elems_copy} + Replicate mempty (Var $ irregularD r) + let rep = Irregular $ r {irregularD = elems_copy} pure $ insertRep (distResTag res) rep env Regular v' -> do v'' <- @@ -586,13 +645,13 @@ onMapFreeVar segments env inps ws (ws_flags, ws_offsets, ws_elems) v = do <=< segMap (Solo ws_prod) $ \(Solo i) -> do segment <- letSubExp "segment" =<< eIndex ws_elems [eSubExp i] - subExpsRes . pure <$> (letSubExp "v" =<< eIndex (irregularOffsets rep) [eSubExp segment]) + subExpsRes . pure <$> (letSubExp "v" =<< eIndex (irregularO rep) [eSubExp segment]) let rep' = IrregularRep - { irregularSegments = ws, - irregularFlags = irregularFlags rep, - irregularOffsets = offsets, - irregularElems = irregularElems rep + { irregularS = ws, + irregularF = irregularF rep, + irregularO = offsets, + irregularD = irregularD rep } pure $ MapOther rep' t Regular vs -> @@ -621,12 +680,12 @@ onMapInputArr segments env inps ii2 p arr = do DistInput rt _ -> case resVar rt env of Irregular rep -> do - elems_t <- lookupType $ irregularElems rep + elems_t <- lookupType $ irregularD rep -- If parameter type of the map corresponds to the -- element type of the value array, we can map it -- directly. if stripArray (segmentsRank segments) elems_t == paramType p - then pure $ MapArray (irregularElems rep) elems_t + then pure $ MapArray (irregularD rep) elems_t else do -- Otherwise we need to perform surgery on the metadata. 
~[p_segments, p_offsets] <- letTupExp @@ -636,16 +695,16 @@ onMapInputArr segments env inps ii2 p arr = do segment_i <- letSubExp "segment" =<< eIndex ii2 [eSubExp i] segment <- - letSubExp "v" =<< eIndex (irregularSegments rep) [eSubExp segment_i] + letSubExp "v" =<< eIndex (irregularS rep) [eSubExp segment_i] offset <- - letSubExp "v" =<< eIndex (irregularOffsets rep) [eSubExp segment_i] + letSubExp "v" =<< eIndex (irregularO rep) [eSubExp segment_i] pure $ subExpsRes [segment, offset] let rep' = IrregularRep - { irregularElems = irregularElems rep, - irregularFlags = irregularFlags rep, - irregularSegments = p_segments, - irregularOffsets = p_offsets + { irregularD = irregularD rep, + irregularF = irregularF rep, + irregularS = p_segments, + irregularO = p_offsets } pure $ MapOther rep' elems_t Regular vs -> @@ -831,10 +890,10 @@ transformDistStm segments env (DistStm inps res stm) = do pure $ Irregular $ IrregularRep - { irregularSegments = segs', - irregularFlags = flags', - irregularOffsets = offsets', - irregularElems = elems' + { irregularS = segs', + irregularF = flags', + irregularO = offsets', + irregularD = elems' } -- Given the indices for which a branch is taken and its body, -- distribute the statements of the body of that branch. @@ -887,7 +946,7 @@ transformDistStm segments env (DistStm inps res stm) = do -- The `offsets` variable is the offsets of the final result, -- whereas `irregRep` is the irregular representation of the result of a single branch. 
let scatterIrregular offsets space (is, irregRep) = do - let IrregularRep {irregularSegments = segs, irregularElems = elems} = irregRep + let IrregularRep {irregularS = segs, irregularD = elems} = irregRep (_, _, ii1) <- doRepIota segs (_, _, ii2) <- doSegIota segs ~(Array _ (Shape [size]) _) <- lookupType elems @@ -921,7 +980,7 @@ transformDistStm segments env (DistStm inps res stm) = do -- Create a blank space for the 'segs' segsSpace <- letExp "blank_segs" =<< eBlank segsType -- Write back the segs of each branch to the blank space - segs <- foldM scatterRegular segsSpace $ zip iss (irregularSegments <$> branchesIrregRep) + segs <- foldM scatterRegular segsSpace $ zip iss (irregularS <$> branchesIrregRep) (_, offsets, num_elems) <- exScanAndSum segs let resultType = Array pt (Shape [num_elems]) NoUniqueness -- Create the blank space for the result @@ -932,10 +991,10 @@ transformDistStm segments env (DistStm inps res stm) = do pure $ Irregular $ IrregularRep - { irregularSegments = segs, - irregularFlags = flags, - irregularOffsets = offsets, - irregularElems = elems + { irregularS = segs, + irregularF = flags, + irregularO = offsets, + irregularD = elems } Acc {} -> error "transformDistStm: Acc" Mem {} -> error "transformDistStm: Mem" @@ -1011,7 +1070,7 @@ transformDistributed irregs segments dist = do let shape = segmentsShape segments <> arrayShape v_t v_copy <- letExp (baseString v) . 
BasicOp $ - Replicate mempty (Var $ irregularElems irreg) + Replicate mempty (Var $ irregularD irreg) letBindNames [v] $ BasicOp (Reshape ReshapeArbitrary shape v_copy) forM_ reps $ \(v, r) -> @@ -1096,10 +1155,10 @@ liftParam w fparam = ( [num_elems, segments, flags, offsets, elems], Irregular $ IrregularRep - { irregularSegments = paramName segments, - irregularFlags = paramName flags, - irregularOffsets = paramName offsets, - irregularElems = paramName elems + { irregularS = paramName segments, + irregularF = paramName flags, + irregularO = paramName offsets, + irregularD = paramName elems } ) Acc {} -> @@ -1118,10 +1177,10 @@ liftArg segments inps env (se, d) = do where mkIrrep ( IrregularRep - { irregularSegments = segs, - irregularFlags = flags, - irregularOffsets = offsets, - irregularElems = elems + { irregularS = segs, + irregularF = flags, + irregularO = offsets, + irregularD = elems } ) = do t <- lookupType elems @@ -1161,10 +1220,10 @@ liftResult segments inps env res = map (SubExpRes mempty . Var) <$> vs Irregular irreg -> mkIrrep irreg mkIrrep ( IrregularRep - { irregularSegments = segs, - irregularFlags = flags, - irregularOffsets = offsets, - irregularElems = elems + { irregularS = segs, + irregularF = flags, + irregularO = offsets, + irregularD = elems } ) = do t <- lookupType elems From db810705deec64c370fdec680cc4ff7cf87e8a9c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 17 Dec 2023 18:41:01 +0100 Subject: [PATCH 60/64] Use nomenclature. --- src/Futhark/Pass/Flatten.hs | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index eb5e2bafef..a139b52a6d 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -61,8 +61,11 @@ import Prelude hiding (div, rem) -- - A_O : [n]i64; the offset array, indicating for each segment -- where it starts in the data (and flag) array. 
-- --- - A_II : [m]t; the "inner indices"; indicating for each element --- its index within its corresponding segment. +-- - A_II1 : [m]i64; the "segment indices"; a mapping from element +-- index to index of the segment it belongs to. +-- +-- - A_II2 : [m]i64; the "inner indices"; a mapping from element index +-- to index within its corresponding segment. -- -- Some of these structures can be computed from each other, but -- conceptually they all coexist. @@ -93,7 +96,9 @@ import Prelude hiding (div, rem) -- -- A_O = [0, 0, 6] -- --- A_II = [1,1,1,1,1,1,2,2,2,2] +-- A_II1 = [1,1,1,1,1,1,2,2,2,2] +-- +-- A_II2 = [0,1,2,3,4,5,0,1,2,3] data FlattenEnv = FlattenEnv @@ -291,18 +296,18 @@ mkIrregFromReg segments arr = do segment_size <- letSubExp "reg_seg_size" <=< toExp . product . map pe64 $ drop (segmentsRank segments) (arrayDims arr_t) - segments_arr <- + arr_S <- letExp "reg_segments" . BasicOp $ Replicate (segmentsShape segments) segment_size num_elems <- letSubExp "reg_num_elems" <=< toExp $ product $ map pe64 $ arrayDims arr_t - elems <- - letExp "reg_elems" . BasicOp $ + arr_D <- + letExp "reg_D" . BasicOp $ Reshape ReshapeArbitrary (Shape [num_elems]) arr - flags <- letExp "reg_flags" <=< segMap (Solo num_elems) $ \(Solo i) -> do + arr_F <- letExp "reg_F" <=< segMap (Solo num_elems) $ \(Solo i) -> do flag <- letSubExp "flag" <=< toExp $ (pe64 i `rem` pe64 segment_size) .==.
0 pure [subExpRes flag] - offsets <- letExp "reg_offsets" <=< segMap (shapeDims (segmentsShape segments)) $ \is -> do + arr_O <- letExp "reg_O" <=< segMap (shapeDims (segmentsShape segments)) $ \is -> do let flat_seg_i = flattenIndex (map pe64 (shapeDims (segmentsShape segments))) @@ -311,10 +316,10 @@ mkIrregFromReg segments arr = do pure [subExpRes offset] pure $ IrregularRep - { irregularS = segments_arr, - irregularF = flags, - irregularO = offsets, - irregularD = elems + { irregularS = arr_S, + irregularF = arr_F, + irregularO = arr_O, + irregularD = arr_D } -- Get the irregular representation of a var. From eb443bbf8a0da3a3a1114f5c0225ac8a110e887c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 17 Dec 2023 18:46:07 +0100 Subject: [PATCH 61/64] Style fixes. --- src/Futhark/Pass/Flatten.hs | 56 +++++++++++++++---------------------- 1 file changed, 23 insertions(+), 33 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index a139b52a6d..c1b2253c5f 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -218,18 +218,18 @@ readInputs segments env is = mapM_ onInput case resVar rt env of Regular arr -> letBindNames [v] =<< eIndex arr (map eSubExp is) - Irregular (IrregularRep _ _ offsets elems) -> do - offset <- letSubExp "offset" =<< eIndex offsets (map eSubExp is) + Irregular (IrregularRep _ _ v_O v_D) -> do + offset <- letSubExp "offset" =<< eIndex v_O (map eSubExp is) case arrayDims t of [num_elems] -> do let slice = Slice [DimSlice offset num_elems (intConst Int64 1)] - letBindNames [v] $ BasicOp $ Index elems slice + letBindNames [v] $ BasicOp $ Index v_D slice _ -> do num_elems <- letSubExp "num_elems" =<< toExp (product $ map pe64 $ arrayDims t) let slice = Slice [DimSlice offset num_elems (intConst Int64 1)] v_flat <- - letExp (baseString v <> "_flat") $ BasicOp $ Index elems slice + letExp (baseString v <> "_flat") $ BasicOp $ Index v_D slice letBindNames [v] . 
BasicOp $ Reshape ReshapeArbitrary (arrayShape t) v_flat @@ -270,15 +270,15 @@ distCerts inps aux env = Certs $ map f $ unCerts $ stmAuxCerts aux Irregular r -> irregularD r -- | Only sensible for variables of segment-invariant type. -elemArr :: Segments -> DistEnv -> DistInputs -> SubExp -> Builder GPU VName -elemArr segments env inps (Var v) +dataArr :: Segments -> DistEnv -> DistInputs -> SubExp -> Builder GPU VName +dataArr segments env inps (Var v) | Just v_inp <- lookup v inps = case v_inp of DistInputFree vs _ -> irregularD <$> mkIrregFromReg segments vs DistInput rt _ -> case resVar rt env of Irregular r -> pure $ irregularD r Regular vs -> irregularD <$> mkIrregFromReg segments vs -elemArr segments _ _ se = do +dataArr segments _ _ se = do rep <- letExp "rep" $ BasicOp $ Replicate (segmentsShape segments) se dims <- arrayDims <$> lookupType rep if length dims == 1 @@ -494,13 +494,13 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = Iota n (Constant x) (Constant s) Int64 | zeroIsh x, oneIsh s -> do - ns <- elemArr segments env inps n + ns <- dataArr segments env inps n (flags, offsets, elems) <- certifying (distCerts inps aux env) $ doSegIota ns pure $ insertIrregular ns flags offsets (distResTag res) elems env Iota n x s it -> do - ns <- elemArr segments env inps n - xs <- elemArr segments env inps x - ss <- elemArr segments env inps s + ns <- dataArr segments env inps n + xs <- dataArr segments env inps x + ss <- dataArr segments env inps s (flags, offsets, elems) <- certifying (distCerts inps aux env) $ doSegIota ns (_, _, repiota_elems) <- doRepIota ns m <- arraySize 0 <$> lookupType elems @@ -515,12 +515,12 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = ~*~ primExpFromSubExp (IntType it) s' pure $ insertIrregular ns flags offsets (distResTag res) elems' env Replicate (Shape [n]) (Var v) -> do - ns <- elemArr segments env inps n + ns <- dataArr segments env inps n rep <- getIrregRep segments env inps v rep' <- replicateIrreg 
segments env ns (baseString v) rep pure $ insertRep (distResTag res) (Irregular rep') env Replicate (Shape [n]) (Constant v) -> do - ns <- elemArr segments env inps n + ns <- dataArr segments env inps n (flags, offsets, elems) <- certifying (distCerts inps aux env) $ doSegIota ns w <- arraySize 0 <$> lookupType elems @@ -739,7 +739,7 @@ transformInnerMap :: Lambda SOACS -> Builder GPU (VName, VName, VName) transformInnerMap segments env inps pat w arrs map_lam = do - ws <- elemArr segments env inps w + ws <- dataArr segments env inps w (ws_flags, ws_offsets, ws_elems) <- doRepIota ws new_segment <- arraySize 0 <$> lookupType ws_elems arrs' <- @@ -938,15 +938,10 @@ transformDistStm segments env (DistStm inps res stm) = do -- Write back the regular results of a branch to a (partially) blank space let scatterRegular space (is, xs) = do ~(Array _ (Shape [size]) _) <- lookupType xs - letExp "regular_scatter" - =<< genScatter - space - size - ( \gtid -> do - x <- letSubExp "x" =<< eIndex xs [eSubExp gtid] - i <- letExp "i" =<< eIndex is [eSubExp gtid] - pure (i, x) - ) + letExp "regular_scatter" <=< genScatter space size $ \gtid -> do + x <- letSubExp "x" =<< eIndex xs [eSubExp gtid] + i <- letExp "i" =<< eIndex is [eSubExp gtid] + pure (i, x) -- Write back the irregular elements of a branch to a (partially) blank space -- The `offsets` variable is the offsets of the final result, -- whereas `irregRep` is the irregular representation of the result of a single branch. 
@@ -955,16 +950,11 @@ transformDistStm segments env (DistStm inps res stm) = do (_, _, ii1) <- doRepIota segs (_, _, ii2) <- doSegIota segs ~(Array _ (Shape [size]) _) <- lookupType elems - letExp "irregular_scatter" - =<< genScatter - space - size - ( \gtid -> do - x <- letSubExp "x" =<< eIndex elems [eSubExp gtid] - offset <- letExp "offset" =<< eIndex offsets [eIndex is [eIndex ii1 [eSubExp gtid]]] - i <- letExp "i" =<< eBinOp (Add Int64 OverflowUndef) (toExp offset) (eIndex ii2 [eSubExp gtid]) - pure (i, x) - ) + letExp "irregular_scatter" <=< genScatter space size $ \gtid -> do + x <- letSubExp "x" =<< eIndex elems [eSubExp gtid] + offset <- letExp "offset" =<< eIndex offsets [eIndex is [eIndex ii1 [eSubExp gtid]]] + i <- letExp "i" =<< eBinOp (Add Int64 OverflowUndef) (toExp offset) (eIndex ii2 [eSubExp gtid]) + pure (i, x) -- Given a single result from each branch as well the *unlifted* -- result type, merge the results of all branches into a single result. let mergeResult iss branchesRep resType = From 8c224e4c1a76e81dffa6a521ac97b6c1ba9b356d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 17 Dec 2023 19:40:14 +0100 Subject: [PATCH 62/64] More nomenclature. 
--- src/Futhark/Pass/Flatten.hs | 122 ++++++++++++++++++------------------ 1 file changed, 61 insertions(+), 61 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index c1b2253c5f..2ce9e12533 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -363,15 +363,15 @@ replicateIrreg segments env ns desc rep = do letSubExp "new_segment" =<< toExp (pe64 n * pe64 old_segment) pure $ subExpsRes [full_segment] - (ns_full_flags, ns_full_offsets, ns_full_elems) <- doRepIota ns_full + (ns_full_F, ns_full_O, ns_full_D) <- doRepIota ns_full (_, _, flat_to_segs) <- doSegIota ns_full - w <- arraySize 0 <$> lookupType ns_full_elems + w <- arraySize 0 <$> lookupType ns_full_D - elems <- letExp (desc <> "_rep_elems") <=< segMap (Solo w) $ \(Solo i) -> do + elems <- letExp (desc <> "_rep_D") <=< segMap (Solo w) $ \(Solo i) -> do -- Which segment we are in. segment_i <- - letSubExp "segment_i" =<< eIndex ns_full_elems [eSubExp i] + letSubExp "segment_i" =<< eIndex ns_full_D [eSubExp i] -- Size of original segment. old_segment <- letSubExp "old_segment" =<< eIndex (irregularS rep) [eSubExp segment_i] @@ -392,8 +392,8 @@ replicateIrreg segments env ns desc rep = do pure $ IrregularRep { irregularS = ns_full, - irregularF = ns_full_flags, - irregularO = ns_full_offsets, + irregularF = ns_full_F, + irregularO = ns_full_O, irregularD = elems } @@ -475,10 +475,10 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = slice_ns <- mapM (readInput segments env (toList is) inps) $ sliceDims slice fmap varsRes . 
letTupExp "n" <=< toExp $ product $ map pe64 slice_ns (_n, offsets, m) <- exScanAndSum ns - (_, _, repiota_elems) <- doRepIota ns + (_, _, repiota_D) <- doRepIota ns flags <- genFlags m offsets elems <- letExp "elems" <=< renameExp <=< segMap (NE.singleton m) $ \is -> do - segment <- letSubExp "segment" =<< eIndex repiota_elems (toList $ fmap eSubExp is) + segment <- letSubExp "segment" =<< eIndex repiota_D (toList $ fmap eSubExp is) segment_start <- letSubExp "segment_start" =<< eIndex offsets [eSubExp segment] readInputs segments env [segment] inps -- TODO: multidimensional segments @@ -501,19 +501,19 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = ns <- dataArr segments env inps n xs <- dataArr segments env inps x ss <- dataArr segments env inps s - (flags, offsets, elems) <- certifying (distCerts inps aux env) $ doSegIota ns - (_, _, repiota_elems) <- doRepIota ns - m <- arraySize 0 <$> lookupType elems - elems' <- letExp "iota_elems_fixed" <=< segMap (Solo m) $ \(Solo i) -> do - segment <- letSubExp "segment" =<< eIndex repiota_elems [eSubExp i] - v' <- letSubExp "v" =<< eIndex elems [eSubExp i] + (res_F, res_O, res_D) <- certifying (distCerts inps aux env) $ doSegIota ns + (_, _, repiota_D) <- doRepIota ns + m <- arraySize 0 <$> lookupType res_D + res_D' <- letExp "iota_D_fixed" <=< segMap (Solo m) $ \(Solo i) -> do + segment <- letSubExp "segment" =<< eIndex repiota_D [eSubExp i] + v' <- letSubExp "v" =<< eIndex res_D [eSubExp i] x' <- letSubExp "x" =<< eIndex xs [eSubExp segment] s' <- letSubExp "s" =<< eIndex ss [eSubExp segment] fmap (subExpsRes . pure) . 
letSubExp "v" <=< toExp $ primExpFromSubExp (IntType it) x' ~+~ sExt it (untyped (pe64 v')) ~*~ primExpFromSubExp (IntType it) s' - pure $ insertIrregular ns flags offsets (distResTag res) elems' env + pure $ insertIrregular ns res_F res_O (distResTag res) res_D' env Replicate (Shape [n]) (Var v) -> do ns <- dataArr segments env inps n rep <- getIrregRep segments env inps v @@ -521,11 +521,11 @@ transformDistBasicOp segments env (inps, res, pe, aux, e) = pure $ insertRep (distResTag res) (Irregular rep') env Replicate (Shape [n]) (Constant v) -> do ns <- dataArr segments env inps n - (flags, offsets, elems) <- + (res_F, res_O, res_D) <- certifying (distCerts inps aux env) $ doSegIota ns - w <- arraySize 0 <$> lookupType elems - elems' <- letExp "rep_const" $ BasicOp $ Replicate (Shape [w]) (Constant v) - pure $ insertIrregular ns flags offsets (distResTag res) elems' env + w <- arraySize 0 <$> lookupType res_D + res_D' <- letExp "rep_const" $ BasicOp $ Replicate (Shape [w]) (Constant v) + pure $ insertIrregular ns res_F res_O (distResTag res) res_D' env Replicate (Shape []) (Var v) -> case lookup v inps of Just (DistInputFree v' _) -> do @@ -631,11 +631,11 @@ onMapFreeVar :: (VName, VName, VName) -> VName -> Maybe (Builder GPU (VName, MapArray IrregularRep)) -onMapFreeVar segments env inps ws (ws_flags, ws_offsets, ws_elems) v = do - let segments_per_elem = ws_elems +onMapFreeVar segments env inps ws (ws_F, ws_O, ws_data) v = do + let segments_per_elem = ws_data v_inp <- lookup v inps pure $ do - ws_prod <- arraySize 0 <$> lookupType ws_elems + ws_prod <- arraySize 0 <$> lookupType ws_data fmap (v,) $ case v_inp of DistInputFree v' t -> do fmap (`MapArray` t) @@ -646,10 +646,10 @@ onMapFreeVar segments env inps ws (ws_flags, ws_offsets, ws_elems) v = do subExpsRes . 
pure <$> (letSubExp "v" =<< eIndex v' [eSubExp segment]) DistInput rt t -> case resVar rt env of Irregular rep -> do - offsets <- letExp (baseString v <> "_rep_free_irreg_offsets") + offsets <- letExp (baseString v <> "_rep_free_irreg_O") <=< segMap (Solo ws_prod) $ \(Solo i) -> do - segment <- letSubExp "segment" =<< eIndex ws_elems [eSubExp i] + segment <- letSubExp "segment" =<< eIndex ws_data [eSubExp i] subExpsRes . pure <$> (letSubExp "v" =<< eIndex (irregularO rep) [eSubExp segment]) let rep' = IrregularRep @@ -693,7 +693,7 @@ onMapInputArr segments env inps ii2 p arr = do then pure $ MapArray (irregularD rep) elems_t else do -- Otherwise we need to perform surgery on the metadata. - ~[p_segments, p_offsets] <- letTupExp + ~[p_segments, p_O] <- letTupExp (baseString (paramName p) <> "_rep_inp_irreg") <=< segMap (Solo ws_prod) $ \(Solo i) -> do @@ -709,7 +709,7 @@ onMapInputArr segments env inps ii2 p arr = do { irregularD = irregularD rep, irregularF = irregularF rep, irregularS = p_segments, - irregularO = p_offsets + irregularO = p_O } pure $ MapOther rep' elems_t Regular vs -> @@ -740,11 +740,11 @@ transformInnerMap :: Builder GPU (VName, VName, VName) transformInnerMap segments env inps pat w arrs map_lam = do ws <- dataArr segments env inps w - (ws_flags, ws_offsets, ws_elems) <- doRepIota ws - new_segment <- arraySize 0 <$> lookupType ws_elems + (ws_F, ws_O, ws_data) <- doRepIota ws + new_segment <- arraySize 0 <$> lookupType ws_data arrs' <- zipWithM - (onMapInputArr segments env inps ws_elems) + (onMapInputArr segments env inps ws_data) (lambdaParams map_lam) arrs let free = freeIn map_lam @@ -755,7 +755,7 @@ transformInnerMap segments env inps pat w arrs map_lam = do (free_replicated, replicated) <- fmap unzip . 
sequence $ mapMaybe - (onMapFreeVar segments env inps ws (ws_flags, ws_offsets, ws_elems)) + (onMapFreeVar segments env inps ws (ws_F, ws_O, ws_data)) free_and_sizes free_ps <- zipWithM @@ -777,7 +777,7 @@ transformInnerMap segments env inps pat w arrs map_lam = do transformDistributed arrmap (NE.singleton new_segment) distributed traceM $ unlines ["inner map distributed", prettyString distributed] addStms =<< runReaderT (runBuilder_ m) scope - pure (ws_flags, ws_offsets, ws) + pure (ws_F, ws_O, ws) transformDistStm :: Segments -> DistEnv -> DistStm -> Builder GPU DistEnv transformDistStm segments env (DistStm inps res stm) = do @@ -795,20 +795,20 @@ transformDistStm segments env (DistStm inps res stm) = do | Just (reds, map_lam) <- isRedomapSOAC form -> do map_pat <- fmap Pat $ forM (lambdaReturnType map_lam) $ \t -> PatElem <$> newVName "map" <*> pure (t `arrayOfRow` w) - (ws_flags, ws_offsets, ws) <- + (ws_F, ws_O, ws) <- transformInnerMap segments env inps map_pat w arrs map_lam let (redout_names, mapout_names) = splitAt (redResults reds) (patNames map_pat) elems' <- - genSegRed ws ws_flags ws_offsets redout_names $ + genSegRed ws ws_F ws_O redout_names $ singleReduce reds let (red_tags, map_tags) = splitAt (redResults reds) $ map distResTag res pure $ insertRegulars red_tags elems' $ - insertIrregulars ws ws_flags ws_offsets (zip map_tags mapout_names) env + insertIrregulars ws ws_F ws_O (zip map_tags mapout_names) env | Just map_lam <- isMapSOAC form -> do - (ws_flags, ws_offsets, ws) <- transformInnerMap segments env inps pat w arrs map_lam - pure $ insertIrregulars ws ws_flags ws_offsets (zip (map distResTag res) $ patNames pat) env + (ws_F, ws_O, ws) <- transformInnerMap segments env inps pat w arrs map_lam + pure $ insertIrregulars ws ws_F ws_O (zip (map distResTag res) $ patNames pat) env Let _ _ (Match scrutinees cases defaultCase _) -> do let [w] = NE.toList segments @@ -845,7 +845,7 @@ transformDistStm segments env (DistStm inps res stm) = do inds_t 
<- lookupType partition_inds -- Get the indices of each scrutinee by equivalence class inds <- forM [0 .. num_cases - 1] $ \i -> do - num_elems <- + num_data <- letSubExp ("size" ++ show i) =<< eIndex partition_sizes [toExp $ intConst Int64 i] begin <- @@ -854,7 +854,7 @@ transformDistStm segments env (DistStm inps res stm) = do letExp ("inds_branch" ++ show i) $ BasicOp $ Index partition_inds $ - fullSlice inds_t [DimSlice begin num_elems (intConst Int64 1)] + fullSlice inds_t [DimSlice begin num_data (intConst Int64 1)] -- Take the elements at index `is` from an input `v`. let splitInput is v = do @@ -876,7 +876,7 @@ transformDistStm segments env (DistStm inps res stm) = do idx <- letExp "idx" =<< eIndex is [eSubExp i] subExpsRes . pure <$> (letSubExp "segs" =<< eIndex segs [toExp idx]) -- From this we calculate the offsets and number of elements - (_, offsets', num_elems) <- exScanAndSum segs' + (_, offsets', num_data) <- exScanAndSum segs' (_, _, ii1) <- doRepIota segs' (_, _, ii2) <- doSegIota segs' -- We then take the elements we need from `elems` and `flags` @@ -886,7 +886,7 @@ transformDistStm segments env (DistStm inps res stm) = do -- `offset = offsets[is[ii1[i]]]` -- We then add `ii2[i]` to `offset` -- and use that to index into `elems` and `flags`. 
- ~[flags', elems'] <- letTupExp "split_flags_elems" <=< segMap (Solo num_elems) $ \(Solo i) -> do + ~[flags', elems'] <- letTupExp "split_F_data" <=< segMap (Solo num_data) $ \(Solo i) -> do offset <- letExp "offset" =<< eIndex offsets [eIndex is [eIndex ii1 [eSubExp i]]] idx <- letExp "idx" =<< eBinOp (Add Int64 OverflowUndef) (toExp offset) (eIndex ii2 [eSubExp i]) flags_split <- letSubExp "flags" =<< eIndex flags [toExp idx] @@ -976,13 +976,13 @@ transformDistStm segments env (DistStm inps res stm) = do segsSpace <- letExp "blank_segs" =<< eBlank segsType -- Write back the segs of each branch to the blank space segs <- foldM scatterRegular segsSpace $ zip iss (irregularS <$> branchesIrregRep) - (_, offsets, num_elems) <- exScanAndSum segs - let resultType = Array pt (Shape [num_elems]) NoUniqueness + (_, offsets, num_data) <- exScanAndSum segs + let resultType = Array pt (Shape [num_data]) NoUniqueness -- Create the blank space for the result resultSpace <- letExp "blank_res" =<< eBlank resultType -- Write back the values of each branch to the blank space elems <- foldM (scatterIrregular offsets) resultSpace $ zip iss branchesIrregRep - flags <- genFlags num_elems offsets + flags <- genFlags num_data offsets pure $ Irregular $ IrregularRep @@ -1132,22 +1132,22 @@ liftParam w fparam = (arrayOf (Prim pt) (Shape [w]) Nonunique) pure ([p], Regular $ paramName p) Array pt _ u -> do - num_elems <- - newParam (desc <> "_num_elems") $ Prim int64 + num_data <- + newParam (desc <> "_num_data") $ Prim int64 segments <- newParam (desc <> "_segments") $ arrayOf (Prim int64) (Shape [w]) Nonunique flags <- - newParam (desc <> "_flags") $ - arrayOf (Prim Bool) (Shape [Var (paramName num_elems)]) Nonunique + newParam (desc <> "_F") $ + arrayOf (Prim Bool) (Shape [Var (paramName num_data)]) Nonunique offsets <- - newParam (desc <> "_offsets") $ + newParam (desc <> "_O") $ arrayOf (Prim int64) (Shape [w]) Nonunique elems <- - newParam (desc <> "_elems") $ - arrayOf (Prim pt) 
(Shape [Var (paramName num_elems)]) u + newParam (desc <> "_data") $ + arrayOf (Prim pt) (Shape [Var (paramName num_data)]) u pure - ( [num_elems, segments, flags, offsets, elems], + ( [num_data, segments, flags, offsets, elems], Irregular $ IrregularRep { irregularS = paramName segments, @@ -1179,12 +1179,12 @@ liftArg segments inps env (se, d) = do } ) = do t <- lookupType elems - num_elems <- letExp "num_elems" =<< toExp (product $ map pe64 $ arrayDims t) - flags' <- letExp "flags" $ BasicOp $ Reshape ReshapeArbitrary (Shape [Var num_elems]) flags - elems' <- letExp "elems" $ BasicOp $ Reshape ReshapeArbitrary (Shape [Var num_elems]) elems + num_data <- letExp "num_data" =<< toExp (product $ map pe64 $ arrayDims t) + flags' <- letExp "flags" $ BasicOp $ Reshape ReshapeArbitrary (Shape [Var num_data]) flags + elems' <- letExp "elems" $ BasicOp $ Reshape ReshapeArbitrary (Shape [Var num_data]) elems -- Only apply the original diet to the 'elems' array let diets = replicate 4 Observe ++ [d] - pure $ zipWith (curry (first Var)) [num_elems, segs, flags', offsets, elems'] diets + pure $ zipWith (curry (first Var)) [num_data, segs, flags', offsets, elems'] diets -- Lifts a functions return type such that it matches the lifted functions return type. liftRetType :: SubExp -> [RetType SOACS] -> [RetType GPU] @@ -1194,12 +1194,12 @@ liftRetType w = concat . snd . 
L.mapAccumL liftType 0 let lifted = case rettype of Prim pt -> pure $ arrayOf (Prim pt) (Shape [Free w]) Nonunique Array pt _ u -> - let num_elems = Prim int64 + let num_data = Prim int64 segs = arrayOf (Prim int64) (Shape [Free w]) Nonunique flags = arrayOf (Prim Bool) (Shape [Ext i]) Nonunique offsets = arrayOf (Prim int64) (Shape [Free w]) Nonunique elems = arrayOf (Prim pt) (Shape [Ext i]) u - in [num_elems, segs, flags, offsets, elems] + in [num_data, segs, flags, offsets, elems] Acc {} -> error "liftRetType: Acc" Mem {} -> error "liftRetType: Mem" in (i + length lifted, lifted) @@ -1222,10 +1222,10 @@ liftResult segments inps env res = map (SubExpRes mempty . Var) <$> vs } ) = do t <- lookupType elems - num_elems <- letExp "num_elems" =<< toExp (product $ map pe64 $ arrayDims t) - flags' <- letExp "flags" $ BasicOp $ Reshape ReshapeArbitrary (Shape [Var num_elems]) flags - elems' <- letExp "elems" $ BasicOp $ Reshape ReshapeArbitrary (Shape [Var num_elems]) elems - pure [num_elems, segs, flags', offsets, elems'] + num_data <- letExp "num_data" =<< toExp (product $ map pe64 $ arrayDims t) + flags' <- letExp "flags" $ BasicOp $ Reshape ReshapeArbitrary (Shape [Var num_data]) flags + elems' <- letExp "elems" $ BasicOp $ Reshape ReshapeArbitrary (Shape [Var num_data]) elems + pure [num_data, segs, flags', offsets, elems'] liftBody :: SubExp -> DistInputs -> DistEnv -> [DistStm] -> Result -> Builder GPU Result liftBody w inputs env dstms result = do From 98715167d99257bac9c2a24655898778b50b5415 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 17 Dec 2023 19:40:20 +0100 Subject: [PATCH 63/64] Revert "This seems wrong." This reverts commit 647e4fe1ef73a889c7d7da8ac84c8f6965225c1f. 
--- src/Futhark/Pass/Flatten.hs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 2ce9e12533..3afa2e558a 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -680,8 +680,11 @@ onMapInputArr segments env inps ii2 p arr = do case lookup arr inps of Just v_inp -> case v_inp of - DistInputFree vs t -> - pure $ MapArray vs t + DistInputFree vs t -> do + v <- + letExp (baseString vs <> "_flat") . BasicOp $ + Reshape ReshapeArbitrary (Shape [ws_prod]) vs + pure $ MapArray v t DistInput rt _ -> case resVar rt env of Irregular rep -> do From e7ea05c160153cddcfe0386b6756c1d6e4107120 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 17 Dec 2023 19:53:23 +0100 Subject: [PATCH 64/64] Maybe like this. --- src/Futhark/Pass/Flatten.hs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Futhark/Pass/Flatten.hs b/src/Futhark/Pass/Flatten.hs index 3afa2e558a..daf2dd3e84 100644 --- a/src/Futhark/Pass/Flatten.hs +++ b/src/Futhark/Pass/Flatten.hs @@ -681,9 +681,10 @@ onMapInputArr segments env inps ii2 p arr = do Just v_inp -> case v_inp of DistInputFree vs t -> do + let inner_shape = arrayShape $ paramType p v <- letExp (baseString vs <> "_flat") . BasicOp $ - Reshape ReshapeArbitrary (Shape [ws_prod]) vs + Reshape ReshapeArbitrary (Shape [ws_prod] <> inner_shape) vs pure $ MapArray v t DistInput rt _ -> case resVar rt env of