diff --git a/src/GraphBLAS-sharp.Backend/Common/ClArray.fs b/src/GraphBLAS-sharp.Backend/Common/ClArray.fs index ad6b3caf..260a3728 100644 --- a/src/GraphBLAS-sharp.Backend/Common/ClArray.fs +++ b/src/GraphBLAS-sharp.Backend/Common/ClArray.fs @@ -202,7 +202,7 @@ module ClArray = Bitmap.lastOccurrence clContext workGroupSize let prefixSumExclude = - PrefixSum.runExcludeInPlace <@ (+) @> clContext workGroupSize + ScanInternal.standardExcludeInPlace clContext workGroupSize fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) -> @@ -210,7 +210,7 @@ module ClArray = getUniqueBitmap processor DeviceOnly inputArray let resultLength = - (prefixSumExclude processor bitmap 0) + (prefixSumExclude processor bitmap) .ToHostAndFree(processor) let outputArray = @@ -314,7 +314,7 @@ module ClArray = Map.map<'a, int> (Map.chooseBitmap predicate) clContext workGroupSize let prefixSum = - PrefixSum.standardExcludeInPlace clContext workGroupSize + ScanInternal.standardExcludeInPlace clContext workGroupSize let assignValues = assignOption predicate clContext workGroupSize @@ -410,7 +410,7 @@ module ClArray = Map.map2<'a, 'b, int> (Map.choose2Bitmap predicate) clContext workGroupSize let prefixSum = - PrefixSum.standardExcludeInPlace clContext workGroupSize + ScanInternal.standardExcludeInPlace clContext workGroupSize let assignValues = assignOption2 predicate clContext workGroupSize @@ -878,7 +878,7 @@ module ClArray = mapInPlace ArithmeticOperations.intNotQ clContext workGroupSize let prefixSum = - PrefixSum.standardExcludeInPlace clContext workGroupSize + ScanInternal.standardExcludeInPlace clContext workGroupSize let scatter = Scatter.lastOccurrence clContext workGroupSize diff --git a/src/GraphBLAS-sharp.Backend/Common/Common.fs b/src/GraphBLAS-sharp.Backend/Common/Common.fs index ae9839f2..bd99652c 100644 --- a/src/GraphBLAS-sharp.Backend/Common/Common.fs +++ b/src/GraphBLAS-sharp.Backend/Common/Common.fs @@ -211,7 +211,7 @@ module Common = /// Should be a power of 2 and 
greater than 1. /// Associative binary operation. /// Zero element for binary operation. - let runExcludeInPlace plus = PrefixSum.runExcludeInPlace plus + let runExcludeInPlace plus = ScanInternal.runExcludeInPlace plus /// /// Include in-place prefix sum. @@ -231,7 +231,8 @@ module Common = /// ClContext. /// Should be a power of 2 and greater than 1. /// Zero element for binary operation. - let runIncludeInPlace plus = PrefixSum.runIncludeInPlace plus + let runIncludeInPlace plus = + PrefixSumInternal.runIncludeInPlace plus /// /// Exclude in-place prefix sum. Array is scanned starting from the end. @@ -241,7 +242,7 @@ module Common = /// Should be a power of 2 and greater than 1. /// Zero element for binary operation. let runBackwardsExcludeInPlace plus = - PrefixSum.runBackwardsExcludeInPlace plus + PrefixSumInternal.runBackwardsExcludeInPlace plus /// /// Include in-place prefix sum. Array is scanned starting from the end. @@ -251,7 +252,7 @@ module Common = /// Should be a power of 2 and greater than 1. /// Zero element for binary operation. let runBackwardsIncludeInPlace plus = - PrefixSum.runBackwardsIncludeInPlace plus + PrefixSumInternal.runBackwardsIncludeInPlace plus /// /// Exclude in-place prefix sum of integer array with addition operation and start value that is equal to 0. @@ -267,7 +268,7 @@ module Common = /// > val sum = [| 4 |] /// /// - let standardExcludeInPlace = PrefixSum.standardExcludeInPlace + let standardExcludeInPlace = ScanInternal.standardExcludeInPlace /// /// Include in-place prefix sum of integer array with addition operation and start value that is equal to 0. @@ -285,7 +286,7 @@ module Common = /// /// ClContext. /// Should be a power of 2 and greater than 1. 
- let standardIncludeInPlace = PrefixSum.standardIncludeInPlace + let standardIncludeInPlace = PrefixSumInternal.standardIncludeInPlace module ByKey = /// @@ -299,7 +300,8 @@ module Common = /// > val result = [| 0; 0; 1; 2; 0; 1 |] /// /// - let sequentialExclude op = PrefixSum.ByKey.sequentialExclude op + let sequentialExclude op = + PrefixSumInternal.ByKey.sequentialExclude op /// /// Include scan by key. @@ -312,7 +314,8 @@ module Common = /// > val result = [| 1; 1; 2; 3; 1; 2 |] /// /// - let sequentialInclude op = PrefixSum.ByKey.sequentialInclude op + let sequentialInclude op = + PrefixSumInternal.ByKey.sequentialInclude op module Reduce = /// diff --git a/src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs b/src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs index a56af91f..c606a087 100644 --- a/src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs +++ b/src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs @@ -6,7 +6,7 @@ open GraphBLAS.FSharp.Backend.Quotes open GraphBLAS.FSharp.Objects.ArraysExtensions open GraphBLAS.FSharp.Objects.ClCellExtensions -module PrefixSum = +module internal PrefixSumInternal = let private update (opAdd: Expr<'a -> 'a -> 'a>) (clContext: ClContext) workGroupSize = let update = @@ -224,6 +224,8 @@ module PrefixSum = /// /// ClContext. /// Should be a power of 2 and greater than 1. 
+ [] let standardExcludeInPlace (clContext: ClContext) workGroupSize = let scan = diff --git a/src/GraphBLAS-sharp.Backend/Common/Scan.fs b/src/GraphBLAS-sharp.Backend/Common/Scan.fs new file mode 100644 index 00000000..636d89c1 --- /dev/null +++ b/src/GraphBLAS-sharp.Backend/Common/Scan.fs @@ -0,0 +1,270 @@ +namespace GraphBLAS.FSharp.Backend.Common + +open Brahma.FSharp +open FSharp.Quotations +open GraphBLAS.FSharp.Objects.ArraysExtensions +open GraphBLAS.FSharp.Objects.ClContextExtensions + +module internal ScanInternal = + + let private preScan + (opAdd: Expr<'a -> 'a -> 'a>) + (zero: 'a) + (saveSum: bool) + (clContext: ClContext) + (workGroupSize: int) + = + + let blockSize = + min clContext.ClDevice.MaxWorkGroupSize 256 + + let valuesPerBlock = 2 * blockSize + let numberOfMemBanks = 32 + + let localArraySize = + valuesPerBlock + + (valuesPerBlock / numberOfMemBanks) + + let getIndex = + <@ fun index -> index + (index / numberOfMemBanks) @> + + let preScan = + <@ fun (ndRange: Range1D) (valuesLength: int) (valuesBuffer: ClArray<'a>) (carryBuffer: ClArray<'a>) (totalSumCell: ClCell<'a>) -> + let gid = ndRange.GlobalID0 / blockSize + let lid = ndRange.LocalID0 + let gstart = gid * blockSize * 2 + + let sumValues = localArray<'a> localArraySize + + //Load values + if (gstart + lid + blockSize * 0) < valuesLength then + sumValues.[(%getIndex) (lid + blockSize * 0)] <- valuesBuffer.[gstart + lid + blockSize * 0] + else + sumValues.[(%getIndex) (lid + blockSize * 0)] <- zero + + + if (gstart + lid + blockSize * 1) < valuesLength then + sumValues.[(%getIndex) (lid + blockSize * 1)] <- valuesBuffer.[gstart + lid + blockSize * 1] + else + sumValues.[(%getIndex) (lid + blockSize * 1)] <- zero + + //Sweep up + let mutable offset = 1 + let mutable d = blockSize + + while d > 0 do + barrierLocal () + + if lid < d then + let ai = (%getIndex) (offset * (2 * lid + 1) - 1) + let bi = (%getIndex) (offset * (2 * lid + 2) - 1) + sumValues.[bi] <- (%opAdd) sumValues.[bi] 
sumValues.[ai] + + offset <- offset * 2 + d <- d / 2 + + barrierLocal () + + if lid = 0 then + let ai = (%getIndex) (2 * blockSize - 1) + carryBuffer.[gid] <- sumValues.[ai] + sumValues.[ai] <- zero + + // This condition means this thread will rewrite last element in array + // Saving it here for totalSum + if saveSum + && (gstart + lid + blockSize * 1 = valuesLength - 1 + || gstart + lid + blockSize * 0 = valuesLength - 1) then + totalSumCell.Value <- valuesBuffer.[valuesLength - 1] + + //Sweep down + d <- 1 + + while d <= blockSize do + barrierLocal () + + offset <- offset / 2 + + if lid < d then + let ai = (%getIndex) (offset * (2 * lid + 1) - 1) + let bi = (%getIndex) (offset * (2 * lid + 2) - 1) + + let tmp = sumValues.[ai] + sumValues.[ai] <- sumValues.[bi] + sumValues.[bi] <- (%opAdd) sumValues.[bi] tmp + + d <- d * 2 + + barrierLocal () + + if (gstart + lid + blockSize * 0) < valuesLength then + valuesBuffer.[gstart + lid + blockSize * 0] <- sumValues.[(%getIndex) (lid + blockSize * 0)] + + if (gstart + lid + blockSize * 1) < valuesLength then + valuesBuffer.[gstart + lid + blockSize * 1] <- sumValues.[(%getIndex) (lid + blockSize * 1)] @> + + let preScan = clContext.Compile(preScan) + + fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) (totalSum: ClCell<'a>) -> + let numberOfGroups = + inputArray.Length / valuesPerBlock + + (if inputArray.Length % valuesPerBlock = 0 then + 0 + else + 1) + + let carry = + clContext.CreateClArrayWithSpecificAllocationMode<'a>(DeviceOnly, numberOfGroups) + + let ndRangePreScan = + Range1D.CreateValid(numberOfGroups * blockSize, blockSize) + + let preScanKernel = preScan.GetKernel() + + processor.Post( + Msg.MsgSetArguments + (fun () -> preScanKernel.KernelFunc ndRangePreScan inputArray.Length inputArray carry totalSum) + ) + + processor.Post(Msg.CreateRunMsg<_, _>(preScanKernel)) + + carry, numberOfGroups > 1 + + let private scan (opAdd: Expr<'a -> 'a -> 'a>) (saveSum: bool) (clContext: ClContext) (workGroupSize: 
int) = + + let blockSize = + min clContext.ClDevice.MaxWorkGroupSize 256 + + let valuesPerBlock = 2 * blockSize + + let scan = + <@ fun (ndRange: Range1D) (valuesLength: int) (valuesBuffer: ClArray<'a>) (carryBuffer: ClArray<'a>) (totalSumCell: ClCell<'a>) -> + let gid = ndRange.GlobalID0 + 2 * blockSize + let cid = gid / (2 * blockSize) + + if gid < valuesLength then + valuesBuffer.[gid] <- (%opAdd) valuesBuffer.[gid] carryBuffer.[cid] + + if saveSum && gid = valuesLength - 1 then + totalSumCell.Value <- (%opAdd) totalSumCell.Value valuesBuffer.[gid] @> + + let scan = clContext.Compile(scan) + + fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) (carry: ClArray<'a>) (totalSum: ClCell<'a>) -> + let numberOfGroups = + inputArray.Length / valuesPerBlock + + (if inputArray.Length % valuesPerBlock = 0 then + 0 + else + 1) + + let ndRangeScan = + Range1D.CreateValid((numberOfGroups - 1) * valuesPerBlock, blockSize) + + let scan = scan.GetKernel() + + processor.Post( + Msg.MsgSetArguments(fun () -> scan.KernelFunc ndRangeScan inputArray.Length inputArray carry totalSum) + ) + + processor.Post(Msg.CreateRunMsg<_, _>(scan)) + + let runExcludeInPlace plus zero (clContext: ClContext) workGroupSize = + + let blockSize = + min clContext.ClDevice.MaxWorkGroupSize 256 + + let valuesPerBlock = 2 * blockSize + + let getTotalSum = + <@ fun (ndRange: Range1D) (valuesLength: int) (valuesBuffer: ClArray<'a>) (totalSumCell: ClCell<'a>) -> + totalSumCell.Value <- (%plus) valuesBuffer.[valuesLength - 1] totalSumCell.Value @> + + let preScanSaveSum = + preScan plus zero true clContext workGroupSize + + let preScan = + preScan plus zero false clContext workGroupSize + + let scanSaveSum = scan plus true clContext workGroupSize + let scan = scan plus false clContext workGroupSize + let getTotalSum = clContext.Compile(getTotalSum) + + fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) -> + + let totalSum = clContext.CreateClCell<'a>() + + let carry, needRecursion = + 
preScanSaveSum processor inputArray totalSum + + if not needRecursion then + carry.Free processor + + let ndRangeTotalSum = Range1D.CreateValid(1, 1) + let getTotalSum = getTotalSum.GetKernel() + + processor.Post( + Msg.MsgSetArguments + (fun () -> getTotalSum.KernelFunc ndRangeTotalSum inputArray.Length inputArray totalSum) + ) + + processor.Post(Msg.CreateRunMsg<_, _>(getTotalSum)) + else + let mutable carryStack = [ carry; inputArray ] + let mutable stop = not needRecursion + + // Run preScan for carry until we get fully scanned carry + // If during preScan numberOfGroups = 1 means input is fully scanned + while not stop do + let input = carryStack.Head + let carry, needRecursion = preScan processor input totalSum + + if needRecursion then + carryStack <- carry :: carryStack + else + stop <- true + carry.Free processor + + stop <- false + + // Run scan for each not fully scanned carry until we get inputArray scanned + while not stop do + match carryStack with + | carry :: inputCarry :: tail -> + if tail.IsEmpty then + scanSaveSum processor inputCarry carry totalSum + stop <- true + else + scan processor inputCarry carry totalSum + + carry.Free processor + carryStack <- carryStack.Tail + | _ -> failwith "carryStack always has at least 2 elements" + + totalSum + + /// + /// Exclude in-place prefix sum of integer array with addition operation and start value that is equal to 0. + /// + /// + /// + /// let arr = [| 1; 1; 1; 1 |] + /// let sum = [| 0 |] + /// runExcludeInplace clContext workGroupSize processor arr sum (+) 0 + /// |> ignore + /// ... + /// > val arr = [| 0; 1; 2; 3 |] + /// > val sum = [| 4 |] + /// + /// + /// ClContext. + /// Should be a power of 2 and greater than 1. 
+ /// Note that maximum possible workGroupSize is used for better performance + let standardExcludeInPlace (clContext: ClContext) workGroupSize = + + let scan = + runExcludeInPlace <@ (+) @> 0 clContext workGroupSize + + fun (processor: MailboxProcessor<_>) (inputArray: ClArray) -> + + scan processor inputArray diff --git a/src/GraphBLAS-sharp.Backend/Common/Sort/Radix.fs b/src/GraphBLAS-sharp.Backend/Common/Sort/Radix.fs index b7db92d0..abace8a4 100644 --- a/src/GraphBLAS-sharp.Backend/Common/Sort/Radix.fs +++ b/src/GraphBLAS-sharp.Backend/Common/Sort/Radix.fs @@ -156,7 +156,7 @@ module internal Radix = let count = count clContext workGroupSize mask let prefixSum = - PrefixSum.standardExcludeInPlace clContext workGroupSize + ScanInternal.standardExcludeInPlace clContext workGroupSize let scatter = scatter clContext workGroupSize mask @@ -259,7 +259,7 @@ module internal Radix = let count = count clContext workGroupSize mask let prefixSum = - PrefixSum.standardExcludeInPlace clContext workGroupSize + ScanInternal.standardExcludeInPlace clContext workGroupSize let scatterByKey = scatterByKey clContext workGroupSize mask diff --git a/src/GraphBLAS-sharp.Backend/Common/Sum.fs b/src/GraphBLAS-sharp.Backend/Common/Sum.fs index 94a745c2..73aad03b 100644 --- a/src/GraphBLAS-sharp.Backend/Common/Sum.fs +++ b/src/GraphBLAS-sharp.Backend/Common/Sum.fs @@ -529,7 +529,7 @@ module Reduce = Scatter.lastOccurrence clContext workGroupSize let prefixSum = - PrefixSum.standardExcludeInPlace clContext workGroupSize + ScanInternal.standardExcludeInPlace clContext workGroupSize fun (processor: MailboxProcessor<_>) allocationMode (keys: ClArray) (values: ClArray<'a option>) -> @@ -661,7 +661,7 @@ module Reduce = Scatter.lastOccurrence clContext workGroupSize let prefixSum = - PrefixSum.standardExcludeInPlace clContext workGroupSize + ScanInternal.standardExcludeInPlace clContext workGroupSize fun (processor: MailboxProcessor<_>) allocationMode (resultLength: int) (offsets: ClArray) (keys: 
ClArray) (values: ClArray<'a>) -> @@ -940,7 +940,7 @@ module Reduce = Scatter.lastOccurrence clContext workGroupSize let prefixSum = - PrefixSum.standardExcludeInPlace clContext workGroupSize + ScanInternal.standardExcludeInPlace clContext workGroupSize fun (processor: MailboxProcessor<_>) allocationMode (resultLength: int) (offsets: ClArray) (firstKeys: ClArray) (secondKeys: ClArray) (values: ClArray<'a>) -> diff --git a/src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj b/src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj index cac6456d..4300c9ec 100644 --- a/src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj +++ b/src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj @@ -32,6 +32,7 @@ + diff --git a/src/GraphBLAS-sharp.Backend/Operations/SpMSpV.fs b/src/GraphBLAS-sharp.Backend/Operations/SpMSpV.fs index bca119a1..2ec56f1e 100644 --- a/src/GraphBLAS-sharp.Backend/Operations/SpMSpV.fs +++ b/src/GraphBLAS-sharp.Backend/Operations/SpMSpV.fs @@ -65,7 +65,7 @@ module SpMSpV = inputArray.[i] <- 0 @> let sum = - PrefixSum.standardExcludeInPlace clContext workGroupSize + ScanInternal.standardExcludeInPlace clContext workGroupSize let prepareOffsets = clContext.Compile prepareOffsets diff --git a/tests/GraphBLAS-sharp.Tests/Backend/Common/Scan/PrefixSum.fs b/tests/GraphBLAS-sharp.Tests/Backend/Common/Scan/PrefixSum.fs index f94b0564..3773eefd 100644 --- a/tests/GraphBLAS-sharp.Tests/Backend/Common/Scan/PrefixSum.fs +++ b/tests/GraphBLAS-sharp.Tests/Backend/Common/Scan/PrefixSum.fs @@ -13,33 +13,51 @@ let logger = Log.create "ClArray.PrefixSum.Tests" let context = defaultContext.ClContext -let config = Tests.Utils.defaultConfig +let config = + { Tests.Utils.defaultConfig with + maxTest = 20 + startSize = 1 + endSize = 1000000 } let wgSize = 128 let q = defaultContext.Queue -let makeTest plus zero isEqual scan (array: 'a []) = +let makeTest plus zero isEqual scanInclude scanExclude (array: 'a []) = if array.Length > 0 then + // Exclude + let 
actual, actualSum = + let clArray = context.CreateClArray array + let (total: ClCell<_>) = scanExclude q clArray + + let actual = clArray.ToHostAndFree q + let actualSum = total.ToHostAndFree q - logger.debug ( - eventX $"Array is %A{array}\n" - >> setField "array" (sprintf "%A" array) - ) + actual, actualSum + let expected, expectedSum = + array + |> Array.mapFold + (fun s t -> + let a = plus s t + s, a) + zero + + "Arrays for exclude should be the same" + |> Tests.Utils.compareArrays isEqual actual expected + + "Total sums for exclude should be equal" + |> Expect.equal actualSum expectedSum + + // Include let actual, actualSum = let clArray = context.CreateClArray array - let (total: ClCell<_>) = scan q clArray zero + let (total: ClCell<_>) = scanInclude q clArray zero let actual = clArray.ToHostAndFree q let actualSum = total.ToHostAndFree q actual, actualSum - logger.debug ( - eventX "Actual is {actual}\n" - >> setField "actual" (sprintf "%A" actual) - ) - let expected, expectedSum = array |> Array.mapFold @@ -48,20 +66,15 @@ let makeTest plus zero isEqual scan (array: 'a []) = a, a) zero - logger.debug ( - eventX "Expected is {expected}\n" - >> setField "expected" (sprintf "%A" expected) - ) - - "Total sums should be equal" + "Total sums for include should be equal" |> Expect.equal actualSum expectedSum - "Arrays should be the same" + "Arrays for include should be the same" |> Tests.Utils.compareArrays isEqual actual expected let testFixtures plus plusQ zero isEqual name = - Common.PrefixSum.runIncludeInPlace plusQ context wgSize - |> makeTest plus zero isEqual + (PrefixSum.runIncludeInPlace plusQ context wgSize, PrefixSum.runExcludeInPlace plusQ zero context wgSize) + ||> makeTest plus zero isEqual |> testPropertyWithConfig config $"Correctness on %s{name}" let tests =