Skip to content

Commit

Permalink
ImpGen should only care about LMADs.
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Aug 8, 2023
1 parent 13ee8c5 commit 584ac4d
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 64 deletions.
57 changes: 28 additions & 29 deletions src/Futhark/CodeGen/ImpGen.hs
Original file line number Diff line number Diff line change
Expand Up @@ -195,17 +195,17 @@ defaultOperations opc =
data MemLoc = MemLoc
{ memLocName :: VName,
memLocShape :: [Imp.DimSize],
memLocIxFun :: IxFun.IxFun (Imp.TExp Int64)
memLocLMAD :: LMAD.LMAD (Imp.TExp Int64)
}
deriving (Eq, Show)

sliceMemLoc :: MemLoc -> Slice (Imp.TExp Int64) -> MemLoc
sliceMemLoc (MemLoc mem shape ixfun) slice =
MemLoc mem shape $ IxFun.slice ixfun slice
sliceMemLoc (MemLoc mem shape lmad) slice =
MemLoc mem shape $ LMAD.slice lmad slice

flatSliceMemLoc :: MemLoc -> FlatSlice (Imp.TExp Int64) -> MemLoc
flatSliceMemLoc (MemLoc mem shape ixfun) slice =
MemLoc mem shape $ IxFun.flatSlice ixfun slice
flatSliceMemLoc (MemLoc mem shape lmad) slice =
MemLoc mem shape $ LMAD.flatSlice lmad slice

data ArrayEntry = ArrayEntry
{ entryArrayLoc :: MemLoc,
Expand Down Expand Up @@ -514,8 +514,8 @@ compileInParam fparam = case paramDec fparam of
pure $ Left $ Imp.ScalarParam name bt
MemMem space ->
pure $ Left $ Imp.MemParam name space
MemArray bt shape _ (ArrayIn mem ixfun) ->
pure $ Right $ ArrayDecl name bt $ MemLoc mem (shapeDims shape) ixfun
MemArray bt shape _ (ArrayIn mem lmad) ->
pure $ Right $ ArrayDecl name bt $ MemLoc mem (shapeDims shape) $ IxFun.ixfunLMAD lmad
MemAcc {} ->
error "Functions may not have accumulator parameters."
where
Expand Down Expand Up @@ -616,9 +616,9 @@ compileExternalValues types orig_rts orig_epts maybe_params = do
mkValueDesc _ signedness (MemArray t shape _ ret) = do
(mem, space) <-
case ret of
ReturnsNewBlock space j _ixfun ->
ReturnsNewBlock space j _lmad ->
pure (nthOut j, space)
ReturnsInBlock mem _ixfun -> do
ReturnsInBlock mem _lmad -> do
space <- entryMemSpace <$> lookupMemory mem
pure (mem, space)
pure $ Imp.ArrayValue mem space t signedness $ map f $ shapeDims shape
Expand Down Expand Up @@ -971,7 +971,7 @@ defCompileBasicOp (Pat [pe]) (ArrayLit es _)
emit $ Imp.DeclareArray static_array t $ Imp.ArrayValues vs
let static_src =
MemLoc static_array [intConst Int64 $ fromIntegral $ length es] $
IxFun.iota [fromIntegral $ length es]
LMAD.iota 0 [fromIntegral $ length es]
addVar static_array $ MemVar Nothing $ MemEntry DefaultSpace
copy t dest_mem static_src
| otherwise =
Expand Down Expand Up @@ -1118,8 +1118,8 @@ memBoundToVarEntry e (MemMem space) =
MemVar e $ MemEntry space
memBoundToVarEntry e (MemAcc acc ispace ts _) =
AccVar e (acc, ispace, ts)
memBoundToVarEntry e (MemArray bt shape _ (ArrayIn mem ixfun)) =
let location = MemLoc mem (shapeDims shape) ixfun
memBoundToVarEntry e (MemArray bt shape _ (ArrayIn mem lmad)) =
let location = MemLoc mem (shapeDims shape) $ IxFun.ixfunLMAD lmad
in ArrayVar
e
ArrayEntry
Expand Down Expand Up @@ -1162,12 +1162,11 @@ dScope ::
ImpM rep r op ()
dScope e = mapM_ (uncurry $ dInfo e) . M.toList

dArray :: VName -> PrimType -> ShapeBase SubExp -> VName -> IxFun -> ImpM rep r op ()
dArray name pt shape mem ixfun =
dArray :: VName -> PrimType -> ShapeBase SubExp -> VName -> LMAD -> ImpM rep r op ()
dArray name pt shape mem lmad =
addVar name $ ArrayVar Nothing $ ArrayEntry location pt
where
location =
MemLoc mem (shapeDims shape) ixfun
location = MemLoc mem (shapeDims shape) lmad

everythingVolatile :: ImpM rep r op a -> ImpM rep r op a
everythingVolatile = local $ \env -> env {envVolatility = Imp.Volatile}
Expand Down Expand Up @@ -1383,30 +1382,30 @@ fullyIndexArray' ::
MemLoc ->
[Imp.TExp Int64] ->
ImpM rep r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray' (MemLoc mem _ ixfun) indices = do
fullyIndexArray' (MemLoc mem _ lmad) indices = do
space <- entryMemSpace <$> lookupMemory mem
pure
( mem,
space,
elements $ IxFun.index ixfun indices
elements $ LMAD.index lmad indices
)

-- More complicated read/write operations that use index functions.

copy :: CopyCompiler rep r op
copy
bt
dst@(MemLoc dst_name _ dst_ixfn@(IxFun.IxFun dst_lmad _))
src@(MemLoc src_name _ src_ixfn@(IxFun.IxFun src_lmad _)) = do
dst@(MemLoc dst_name _ dst_ixfn@dst_lmad)
src@(MemLoc src_name _ src_ixfn@src_lmad) = do
-- If we can statically determine that the two index-functions
-- are equivalent, don't do anything
unless (dst_name == src_name && dst_ixfn `IxFun.equivalent` src_ixfn)
unless (dst_name == src_name && dst_ixfn `LMAD.equivalent` src_ixfn)
$
-- It's also possible that we can dynamically determine that the two
-- index-functions are equivalent.
sUnless
( fromBool (dst_name == src_name)
.&&. IxFun.dynamicEqualsLMAD dst_lmad src_lmad
.&&. LMAD.dynamicEqualsLMAD dst_lmad src_lmad
)
$ do
-- If none of the above is true, actually do the copy
Expand All @@ -1417,8 +1416,8 @@ lmadCopy :: CopyCompiler rep r op
lmadCopy t dstloc srcloc = do
let dstmem = memLocName dstloc
srcmem = memLocName srcloc
dstlmad = IxFun.ixfunLMAD $ memLocIxFun dstloc
srclmad = IxFun.ixfunLMAD $ memLocIxFun srcloc
dstlmad = memLocLMAD dstloc
srclmad = memLocLMAD srcloc
srcspace <- entryMemSpace <$> lookupMemory srcmem
dstspace <- entryMemSpace <$> lookupMemory dstmem
emit $
Expand Down Expand Up @@ -1722,7 +1721,7 @@ sAlloc name size space = do
sAlloc_ name' size space
pure name'

sArray :: String -> PrimType -> ShapeBase SubExp -> VName -> IxFun -> ImpM rep r op VName
sArray :: String -> PrimType -> ShapeBase SubExp -> VName -> LMAD -> ImpM rep r op VName
sArray name bt shape mem ixfun = do
name' <- newVName name
dArray name' bt shape mem ixfun
Expand All @@ -1732,7 +1731,7 @@ sArray name bt shape mem ixfun = do
sArrayInMem :: String -> PrimType -> ShapeBase SubExp -> VName -> ImpM rep r op VName
sArrayInMem name pt shape mem =
sArray name pt shape mem $
IxFun.iota $
LMAD.iota 0 $
map (isInt64 . primExpFromSubExp int64) $
shapeDims shape

Expand All @@ -1741,9 +1740,9 @@ sAllocArrayPerm :: String -> PrimType -> ShapeBase SubExp -> Space -> [Int] -> I
sAllocArrayPerm name pt shape space perm = do
let permuted_dims = rearrangeShape perm $ shapeDims shape
mem <- sAlloc (name ++ "_mem") (typeSize (Array pt shape NoUniqueness)) space
let iota_ixfun = IxFun.iota $ map (isInt64 . primExpFromSubExp int64) permuted_dims
let iota_ixfun = LMAD.iota 0 $ map (isInt64 . primExpFromSubExp int64) permuted_dims
sArray name pt shape mem $
IxFun.permute iota_ixfun $
LMAD.permute iota_ixfun $
rearrangeInverse perm

-- | Uses linear/iota index function.
Expand All @@ -1761,7 +1760,7 @@ sStaticArray name pt vs = do
mem <- newVNameForFun $ name ++ "_mem"
emit $ Imp.DeclareArray mem pt vs
addVar mem $ MemVar Nothing $ MemEntry DefaultSpace
sArray name pt shape mem $ IxFun.iota [fromIntegral num_elems]
sArray name pt shape mem $ LMAD.iota 0 [fromIntegral num_elems]

sWrite :: VName -> [Imp.TExp Int64] -> Imp.Exp -> ImpM rep r op ()
sWrite arr is v = do
Expand Down
12 changes: 5 additions & 7 deletions src/Futhark/CodeGen/ImpGen/GPU/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.MonadFreshNames
import Futhark.Transform.Rename
import Futhark.Util (dropLast, nubOrd, splitFromEnd)
Expand Down Expand Up @@ -1251,7 +1251,7 @@ replicateForType bt = do
shape = Shape [Var num_elems]
function fname [] params $ do
arr <-
sArray "arr" bt shape mem $ IxFun.iota $ map pe64 $ shapeDims shape
sArray "arr" bt shape mem $ LMAD.iota 0 $ map pe64 $ shapeDims shape
sReplicateKernel arr $ Var val

pure fname
Expand All @@ -1262,7 +1262,7 @@ replicateIsFill arr v = do
v_t <- subExpType v
case v_t of
Prim v_t'
| IxFun.isDirect arr_ixfun -> pure $
| LMAD.isDirect arr_ixfun -> pure $
Just $ do
fname <- replicateForType v_t'
emit $
Expand Down Expand Up @@ -1347,9 +1347,7 @@ iotaForType bt = do
function fname [] params $ do
arr <-
sArray "arr" (IntType bt) shape mem $
IxFun.iota $
map pe64 $
shapeDims shape
LMAD.iota 0 (map pe64 (shapeDims shape))
sIotaKernel arr (sExt64 n') x' s' bt

pure fname
Expand All @@ -1364,7 +1362,7 @@ sIota ::
CallKernelGen ()
sIota arr n x s et = do
ArrayEntry (MemLoc arr_mem _ arr_ixfun) _ <- lookupArray arr
if IxFun.isDirect arr_ixfun
if LMAD.isDirect arr_ixfun
then do
fname <- iotaForType et
emit $
Expand Down
14 changes: 6 additions & 8 deletions src/Futhark/CodeGen/ImpGen/GPU/Group.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.Construct (fullSliceNum)
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.MonadFreshNames
import Futhark.Transform.Rename
import Futhark.Util (chunks, mapAccumLM, takeLast)
Expand All @@ -42,9 +42,7 @@ flattenArray k flat arr = do
let flat_shape = Shape $ Var (tvVar flat) : drop k (memLocShape arr_loc)
sArray (baseString arr ++ "_flat") pt flat_shape (memLocName arr_loc) $
fromMaybe (error "flattenArray") $
IxFun.reshape (memLocIxFun arr_loc) $
map pe64 $
shapeDims flat_shape
LMAD.reshape (memLocLMAD arr_loc) (map pe64 $ shapeDims flat_shape)

sliceArray :: Imp.TExp Int64 -> TV Int64 -> VName -> ImpM rep r op VName
sliceArray start size arr = do
Expand All @@ -59,7 +57,7 @@ sliceArray start size arr = do
(elemType arr_t)
(arrayShape arr_t `setOuterDim` Var (tvVar size))
mem
$ IxFun.slice ixfun slice
$ LMAD.slice ixfun slice

-- | @applyLambda lam dests args@ emits code that:
--
Expand Down Expand Up @@ -154,8 +152,8 @@ copyInGroup pt destloc srcloc = do
dest_space <- entryMemSpace <$> lookupMemory (memLocName destloc)
src_space <- entryMemSpace <$> lookupMemory (memLocName srcloc)

let src_ixfun = memLocIxFun srcloc
dims = IxFun.shape src_ixfun
let src_lmad = memLocLMAD srcloc
dims = LMAD.shape src_lmad
rank = length dims

case (dest_space, src_space) of
Expand Down Expand Up @@ -229,7 +227,7 @@ prepareIntraGroupSegHist group_size =

locks_mem <- sAlloc "locks_mem" (typeSize locks_t) $ Space "local"
dArray locks int32 (arrayShape locks_t) locks_mem $
IxFun.iota . map pe64 . arrayDims $
LMAD.iota 0 . map pe64 . arrayDims $
locks_t

sComment "All locks start out unlocked" $
Expand Down
4 changes: 2 additions & 2 deletions src/Futhark/CodeGen/ImpGen/GPU/SegHist.hs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.SegRed (compileSegRed')
import Futhark.Construct (fullSliceNum)
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.MonadFreshNames
import Futhark.Pass.ExplicitAllocations ()
import Futhark.Util (chunks, mapAccumLM, maxinum, splitFromEnd, takeLast)
Expand Down Expand Up @@ -114,7 +114,7 @@ computeHistoUsage space op = do
(elemType dest_t)
subhistos_shape
subhistos_mem
$ IxFun.iota
$ LMAD.iota 0
$ map pe64
$ shapeDims subhistos_shape

Expand Down
6 changes: 2 additions & 4 deletions src/Futhark/CodeGen/ImpGen/GPU/SegRed.hs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Transform.Rename
import Futhark.Util (chunks)
import Futhark.Util.IntegralExp (divUp, quot, rem)
Expand Down Expand Up @@ -142,9 +142,7 @@ intermediateArrays (Count group_size) num_threads (SegBinOp _ red_op nes _) = do
MemArray pt shape _ (ArrayIn mem _) -> do
let shape' = Shape [num_threads] <> shape
sArray "red_arr" pt shape' mem $
IxFun.iota $
map pe64 $
shapeDims shape'
LMAD.iota 0 (map pe64 $ shapeDims shape')
_ -> do
let pt = elemType $ paramType p
shape = Shape [group_size]
Expand Down
6 changes: 3 additions & 3 deletions src/Futhark/CodeGen/ImpGen/GPU/SegScan/SinglePass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Transform.Rename
import Futhark.Util (takeLast)
import Futhark.Util.IntegralExp (IntegralExp (mod, rem), divUp, quot)
Expand Down Expand Up @@ -85,7 +85,7 @@ createLocalArrays (Count groupSize) m types = do
ty
(Shape [groupSize])
localMem
$ IxFun.iotaOffset off' [pe64 groupSize]
$ LMAD.iota off' [pe64 groupSize]

warpscan <- sArrayInMem "warpscan" int8 (Shape [constant (warpSize :: Int64)]) localMem
warpExchanges <-
Expand All @@ -96,7 +96,7 @@ createLocalArrays (Count groupSize) m types = do
ty
(Shape [constant (warpSize :: Int64)])
localMem
$ IxFun.iotaOffset off' [warpSize]
$ LMAD.iota off' [warpSize]

pure (sharedId, transposedArrays, prefixArrays, warpscan, warpExchanges)

Expand Down
6 changes: 2 additions & 4 deletions src/Futhark/CodeGen/ImpGen/GPU/SegScan/TwoPass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Transform.Rename
import Futhark.Util (takeLast)
import Futhark.Util.IntegralExp (divUp, quot, rem)
Expand Down Expand Up @@ -42,9 +42,7 @@ makeLocalArrays (Count group_size) num_threads scans = do
let shape' = Shape [num_threads] <> shape
arr <-
lift . sArray "scan_arr" pt shape' mem $
IxFun.iota $
map pe64 $
shapeDims shape'
LMAD.iota 0 (map pe64 $ shapeDims shape')
pure (arr, [])
_ -> do
let pt = elemType $ paramType p
Expand Down
4 changes: 4 additions & 0 deletions src/Futhark/IR/Mem.hs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ module Futhark.IR.Mem
MemReturn (..),
IxFun,
ExtIxFun,
LMAD,
isStaticIxFun,
ExpReturns,
BodyReturns,
Expand Down Expand Up @@ -250,6 +251,9 @@ instance ST.IndexOp (inner rep) => ST.IndexOp (MemOp inner rep) where
-- | The index function representation used for memory annotations.
type IxFun = IxFun.IxFun (TPrimExp Int64 VName)

-- | The LMAD representation used for memory annotations.
type LMAD = IxFun.LMAD (TPrimExp Int64 VName)

-- | An index function that may contain existential variables.
type ExtIxFun = IxFun.IxFun (TPrimExp Int64 (Ext VName))

Expand Down
11 changes: 4 additions & 7 deletions src/Futhark/IR/Mem/IxFun.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ import Data.Traversable
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Mem.LMAD hiding
( flatSlice,
( equivalent,
flatSlice,
index,
iota,
isDirect,
mkExistential,
permute,
rank,
Expand Down Expand Up @@ -280,9 +282,4 @@ closeEnough ixf1 ixf2 =
-- each pair of LMADs matching in permutation, offsets, and strides.
equivalent :: Eq num => IxFun num -> IxFun num -> Bool
equivalent ixf1 ixf2 =
equivalentLMADs (ixfunLMAD ixf1) (ixfunLMAD ixf2)
where
equivalentLMADs lmad1 lmad2 =
length (LMAD.dims lmad1) == length (LMAD.dims lmad2)
&& LMAD.offset lmad1 == LMAD.offset lmad2
&& map ldStride (LMAD.dims lmad1) == map ldStride (LMAD.dims lmad2)
LMAD.equivalent (ixfunLMAD ixf1) (ixfunLMAD ixf2)
Loading

0 comments on commit 584ac4d

Please sign in to comment.