diff --git a/src/GraphBLAS-sharp.Backend/Common/ClArray.fs b/src/GraphBLAS-sharp.Backend/Common/ClArray.fs
index ad6b3caf..260a3728 100644
--- a/src/GraphBLAS-sharp.Backend/Common/ClArray.fs
+++ b/src/GraphBLAS-sharp.Backend/Common/ClArray.fs
@@ -202,7 +202,7 @@ module ClArray =
Bitmap.lastOccurrence clContext workGroupSize
let prefixSumExclude =
- PrefixSum.runExcludeInPlace <@ (+) @> clContext workGroupSize
+ ScanInternal.standardExcludeInPlace clContext workGroupSize
fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) ->
@@ -210,7 +210,7 @@ module ClArray =
getUniqueBitmap processor DeviceOnly inputArray
let resultLength =
- (prefixSumExclude processor bitmap 0)
+ (prefixSumExclude processor bitmap)
.ToHostAndFree(processor)
let outputArray =
@@ -314,7 +314,7 @@ module ClArray =
Map.map<'a, int> (Map.chooseBitmap predicate) clContext workGroupSize
let prefixSum =
- PrefixSum.standardExcludeInPlace clContext workGroupSize
+ ScanInternal.standardExcludeInPlace clContext workGroupSize
let assignValues =
assignOption predicate clContext workGroupSize
@@ -410,7 +410,7 @@ module ClArray =
Map.map2<'a, 'b, int> (Map.choose2Bitmap predicate) clContext workGroupSize
let prefixSum =
- PrefixSum.standardExcludeInPlace clContext workGroupSize
+ ScanInternal.standardExcludeInPlace clContext workGroupSize
let assignValues =
assignOption2 predicate clContext workGroupSize
@@ -878,7 +878,7 @@ module ClArray =
mapInPlace ArithmeticOperations.intNotQ clContext workGroupSize
let prefixSum =
- PrefixSum.standardExcludeInPlace clContext workGroupSize
+ ScanInternal.standardExcludeInPlace clContext workGroupSize
let scatter =
Scatter.lastOccurrence clContext workGroupSize
diff --git a/src/GraphBLAS-sharp.Backend/Common/Common.fs b/src/GraphBLAS-sharp.Backend/Common/Common.fs
index ae9839f2..bd99652c 100644
--- a/src/GraphBLAS-sharp.Backend/Common/Common.fs
+++ b/src/GraphBLAS-sharp.Backend/Common/Common.fs
@@ -211,7 +211,7 @@ module Common =
/// Should be a power of 2 and greater than 1.
/// Associative binary operation.
/// Zero element for binary operation.
- let runExcludeInPlace plus = PrefixSum.runExcludeInPlace plus
+ let runExcludeInPlace plus = ScanInternal.runExcludeInPlace plus
///
/// Include in-place prefix sum.
@@ -231,7 +231,8 @@ module Common =
/// ClContext.
/// Should be a power of 2 and greater than 1.
/// Zero element for binary operation.
- let runIncludeInPlace plus = PrefixSum.runIncludeInPlace plus
+ let runIncludeInPlace plus =
+ PrefixSumInternal.runIncludeInPlace plus
///
/// Exclude in-place prefix sum. Array is scanned starting from the end.
@@ -241,7 +242,7 @@ module Common =
/// Should be a power of 2 and greater than 1.
/// Zero element for binary operation.
let runBackwardsExcludeInPlace plus =
- PrefixSum.runBackwardsExcludeInPlace plus
+ PrefixSumInternal.runBackwardsExcludeInPlace plus
///
/// Include in-place prefix sum. Array is scanned starting from the end.
@@ -251,7 +252,7 @@ module Common =
/// Should be a power of 2 and greater than 1.
/// Zero element for binary operation.
let runBackwardsIncludeInPlace plus =
- PrefixSum.runBackwardsIncludeInPlace plus
+ PrefixSumInternal.runBackwardsIncludeInPlace plus
///
/// Exclude in-place prefix sum of integer array with addition operation and start value that is equal to 0.
@@ -267,7 +268,7 @@ module Common =
/// > val sum = [| 4 |]
///
///
- let standardExcludeInPlace = PrefixSum.standardExcludeInPlace
+ let standardExcludeInPlace = ScanInternal.standardExcludeInPlace
///
/// Include in-place prefix sum of integer array with addition operation and start value that is equal to 0.
@@ -285,7 +286,7 @@ module Common =
///
/// ClContext.
/// Should be a power of 2 and greater than 1.
- let standardIncludeInPlace = PrefixSum.standardIncludeInPlace
+ let standardIncludeInPlace = PrefixSumInternal.standardIncludeInPlace
module ByKey =
///
@@ -299,7 +300,8 @@ module Common =
/// > val result = [| 0; 0; 1; 2; 0; 1 |]
///
///
- let sequentialExclude op = PrefixSum.ByKey.sequentialExclude op
+ let sequentialExclude op =
+ PrefixSumInternal.ByKey.sequentialExclude op
///
/// Include scan by key.
@@ -312,7 +314,8 @@ module Common =
/// > val result = [| 1; 1; 2; 3; 1; 2 |]
///
///
- let sequentialInclude op = PrefixSum.ByKey.sequentialInclude op
+ let sequentialInclude op =
+ PrefixSumInternal.ByKey.sequentialInclude op
module Reduce =
///
diff --git a/src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs b/src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs
index a56af91f..c606a087 100644
--- a/src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs
+++ b/src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs
@@ -6,7 +6,7 @@ open GraphBLAS.FSharp.Backend.Quotes
open GraphBLAS.FSharp.Objects.ArraysExtensions
open GraphBLAS.FSharp.Objects.ClCellExtensions
-module PrefixSum =
+module internal PrefixSumInternal =
let private update (opAdd: Expr<'a -> 'a -> 'a>) (clContext: ClContext) workGroupSize =
let update =
@@ -224,6 +224,8 @@ module PrefixSum =
///
/// ClContext.
/// Should be a power of 2 and greater than 1.
+ []
let standardExcludeInPlace (clContext: ClContext) workGroupSize =
let scan =
diff --git a/src/GraphBLAS-sharp.Backend/Common/Scan.fs b/src/GraphBLAS-sharp.Backend/Common/Scan.fs
new file mode 100644
index 00000000..636d89c1
--- /dev/null
+++ b/src/GraphBLAS-sharp.Backend/Common/Scan.fs
@@ -0,0 +1,270 @@
+namespace GraphBLAS.FSharp.Backend.Common
+
+open Brahma.FSharp
+open FSharp.Quotations
+open GraphBLAS.FSharp.Objects.ArraysExtensions
+open GraphBLAS.FSharp.Objects.ClContextExtensions
+
+module internal ScanInternal =
+
+    /// <summary>
+    /// First phase of the scan: each work group exclusively scans its own tile of 2 * blockSize
+    /// values in local memory and writes the tile total to a carry array.
+    /// </summary>
+    /// <param name="opAdd">Associative binary operation; commutativity is not required.</param>
+    /// <param name="zero">Pads the incomplete tile and replaces the tile total before the down-sweep.</param>
+    /// <param name="saveSum">If true, the original last element is copied to the total sum cell.</param>
+    /// <param name="clContext">ClContext.</param>
+    /// <param name="workGroupSize">Not used here: the block size comes from the device limit.</param>
+    /// <returns>Function that enqueues the kernel and returns the carry array (the caller frees it)
+    /// and a flag that is true when the input spans more than one tile.</returns>
+    let private preScan
+        (opAdd: Expr<'a -> 'a -> 'a>)
+        (zero: 'a)
+        (saveSum: bool)
+        (clContext: ClContext)
+        (workGroupSize: int)
+        =
+        let blockSize = min clContext.ClDevice.MaxWorkGroupSize 256
+        let valuesPerBlock = 2 * blockSize
+        let numberOfMemBanks = 32
+
+        // One padding slot per numberOfMemBanks values, matching the index skew done by getIndex
+        let localArraySize = valuesPerBlock + (valuesPerBlock / numberOfMemBanks)
+
+        let getIndex = <@ fun index -> index + (index / numberOfMemBanks) @>
+
+        let preScan =
+            <@ fun (ndRange: Range1D) (valuesLength: int) (valuesBuffer: ClArray<'a>) (carryBuffer: ClArray<'a>) (totalSumCell: ClCell<'a>) ->
+                let gid = ndRange.GlobalID0 / blockSize
+                let lid = ndRange.LocalID0
+                let gstart = gid * blockSize * 2
+
+                let sumValues = localArray<'a> localArraySize
+
+                // Load values; positions past the end of the array are filled with zero
+                if (gstart + lid + blockSize * 0) < valuesLength then
+                    sumValues.[(%getIndex) (lid + blockSize * 0)] <- valuesBuffer.[gstart + lid + blockSize * 0]
+                else
+                    sumValues.[(%getIndex) (lid + blockSize * 0)] <- zero
+
+                if (gstart + lid + blockSize * 1) < valuesLength then
+                    sumValues.[(%getIndex) (lid + blockSize * 1)] <- valuesBuffer.[gstart + lid + blockSize * 1]
+                else
+                    sumValues.[(%getIndex) (lid + blockSize * 1)] <- zero
+
+                // Sweep up
+                let mutable offset = 1
+                let mutable d = blockSize
+
+                while d > 0 do
+                    barrierLocal ()
+
+                    if lid < d then
+                        let ai = (%getIndex) (offset * (2 * lid + 1) - 1)
+                        let bi = (%getIndex) (offset * (2 * lid + 2) - 1)
+                        // ai covers the elements to the left of bi, so it has to be the left operand
+                        sumValues.[bi] <- (%opAdd) sumValues.[ai] sumValues.[bi]
+
+                    offset <- offset * 2
+                    d <- d / 2
+
+                barrierLocal ()
+
+                if lid = 0 then
+                    let ai = (%getIndex) (2 * blockSize - 1)
+                    carryBuffer.[gid] <- sumValues.[ai]
+                    sumValues.[ai] <- zero
+
+                // This thread is the one that overwrites the last element of the array below,
+                // so the original value is saved here for the total sum
+                if saveSum
+                   && (gstart + lid + blockSize * 1 = valuesLength - 1
+                       || gstart + lid + blockSize * 0 = valuesLength - 1) then
+                    totalSumCell.Value <- valuesBuffer.[valuesLength - 1]
+
+                // Sweep down
+                d <- 1
+
+                while d <= blockSize do
+                    barrierLocal ()
+
+                    offset <- offset / 2
+
+                    if lid < d then
+                        let ai = (%getIndex) (offset * (2 * lid + 1) - 1)
+                        let bi = (%getIndex) (offset * (2 * lid + 2) - 1)
+
+                        let tmp = sumValues.[ai]
+                        sumValues.[ai] <- sumValues.[bi]
+                        sumValues.[bi] <- (%opAdd) sumValues.[bi] tmp
+
+                    d <- d * 2
+
+                barrierLocal ()
+
+                if (gstart + lid + blockSize * 0) < valuesLength then
+                    valuesBuffer.[gstart + lid + blockSize * 0] <- sumValues.[(%getIndex) (lid + blockSize * 0)]
+
+                if (gstart + lid + blockSize * 1) < valuesLength then
+                    valuesBuffer.[gstart + lid + blockSize * 1] <- sumValues.[(%getIndex) (lid + blockSize * 1)] @>
+
+        let preScan = clContext.Compile(preScan)
+
+        fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) (totalSum: ClCell<'a>) ->
+            let numberOfGroups =
+                inputArray.Length / valuesPerBlock + (if inputArray.Length % valuesPerBlock = 0 then 0 else 1)
+
+            let carry = clContext.CreateClArrayWithSpecificAllocationMode<'a>(DeviceOnly, numberOfGroups)
+
+            let ndRangePreScan = Range1D.CreateValid(numberOfGroups * blockSize, blockSize)
+
+            let preScanKernel = preScan.GetKernel()
+
+            processor.Post(
+                Msg.MsgSetArguments
+                    (fun () -> preScanKernel.KernelFunc ndRangePreScan inputArray.Length inputArray carry totalSum)
+            )
+
+            processor.Post(Msg.CreateRunMsg<_, _>(preScanKernel))
+
+            carry, numberOfGroups > 1
+
+    /// <summary>
+    /// Second phase of the scan: adds the matching carry value to every element that lies after the first tile.
+    /// </summary>
+    /// <param name="opAdd">Associative binary operation; commutativity is not required.</param>
+    /// <param name="saveSum">If true, the final value of the last element is folded into the total sum cell.</param>
+    /// <param name="workGroupSize">Not used here: the block size comes from the device limit.</param>
+    let private scan (opAdd: Expr<'a -> 'a -> 'a>) (saveSum: bool) (clContext: ClContext) (workGroupSize: int) =
+        let blockSize = min clContext.ClDevice.MaxWorkGroupSize 256
+        let valuesPerBlock = 2 * blockSize
+
+        let scan =
+            <@ fun (ndRange: Range1D) (valuesLength: int) (valuesBuffer: ClArray<'a>) (carryBuffer: ClArray<'a>) (totalSumCell: ClCell<'a>) ->
+                let gid = ndRange.GlobalID0 + 2 * blockSize
+                let cid = gid / (2 * blockSize)
+
+                if gid < valuesLength then
+                    // The carry is the prefix of all preceding tiles, so it has to be the left operand
+                    valuesBuffer.[gid] <- (%opAdd) carryBuffer.[cid] valuesBuffer.[gid]
+
+                // Prefix on the left, saved last element on the right (same order as the single-tile path)
+                if saveSum && gid = valuesLength - 1 then
+                    totalSumCell.Value <- (%opAdd) valuesBuffer.[gid] totalSumCell.Value @>
+
+        let scan = clContext.Compile(scan)
+
+        fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) (carry: ClArray<'a>) (totalSum: ClCell<'a>) ->
+            let numberOfGroups =
+                inputArray.Length / valuesPerBlock + (if inputArray.Length % valuesPerBlock = 0 then 0 else 1)
+
+            let ndRangeScan = Range1D.CreateValid((numberOfGroups - 1) * valuesPerBlock, blockSize)
+
+            let scanKernel = scan.GetKernel()
+
+            processor.Post(
+                Msg.MsgSetArguments(fun () -> scanKernel.KernelFunc ndRangeScan inputArray.Length inputArray carry totalSum)
+            )
+
+            processor.Post(Msg.CreateRunMsg<_, _>(scanKernel))
+
+    /// <summary>
+    /// Exclude in-place prefix sum.
+    /// </summary>
+    /// <param name="plus">Associative binary operation.</param>
+    /// <param name="zero">Zero element for binary operation.</param>
+    /// <param name="clContext">ClContext.</param>
+    /// <param name="workGroupSize">Forwarded to the kernel builders, which take the block size from the device.</param>
+    /// <returns>Function that scans the array in place and returns a cell with the sum of all its elements.</returns>
+    let runExcludeInPlace plus zero (clContext: ClContext) workGroupSize =
+
+        // Used when the array fits into one tile: combines the scanned last element with the saved original one
+        let getTotalSum =
+            <@ fun (ndRange: Range1D) (valuesLength: int) (valuesBuffer: ClArray<'a>) (totalSumCell: ClCell<'a>) ->
+                totalSumCell.Value <- (%plus) valuesBuffer.[valuesLength - 1] totalSumCell.Value @>
+
+        let preScanSaveSum = preScan plus zero true clContext workGroupSize
+        let preScan = preScan plus zero false clContext workGroupSize
+        let scanSaveSum = scan plus true clContext workGroupSize
+        let scan = scan plus false clContext workGroupSize
+        let getTotalSum = clContext.Compile(getTotalSum)
+
+        fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) ->
+
+            let totalSum = clContext.CreateClCell<'a>()
+
+            let carry, needRecursion =
+                preScanSaveSum processor inputArray totalSum
+
+            if not needRecursion then
+                carry.Free processor
+
+                let ndRangeTotalSum = Range1D.CreateValid(1, 1)
+                let getTotalSum = getTotalSum.GetKernel()
+
+                processor.Post(
+                    Msg.MsgSetArguments
+                        (fun () -> getTotalSum.KernelFunc ndRangeTotalSum inputArray.Length inputArray totalSum)
+                )
+
+                processor.Post(Msg.CreateRunMsg<_, _>(getTotalSum))
+            else
+                let mutable carryStack = [ carry; inputArray ]
+                let mutable stop = false
+
+                // Run preScan on the carries until one of them fits into a single tile,
+                // i.e. numberOfGroups = 1, which means that carry is fully scanned
+                while not stop do
+                    let input = carryStack.Head
+                    let carry, needRecursion = preScan processor input totalSum
+
+                    if needRecursion then
+                        carryStack <- carry :: carryStack
+                    else
+                        stop <- true
+                        carry.Free processor
+
+                stop <- false
+
+                // Apply each scanned carry to the array below it on the stack until inputArray is scanned
+                while not stop do
+                    match carryStack with
+                    | carry :: inputCarry :: tail ->
+                        if tail.IsEmpty then
+                            scanSaveSum processor inputCarry carry totalSum
+                            stop <- true
+                        else
+                            scan processor inputCarry carry totalSum
+
+                        carry.Free processor
+                        carryStack <- carryStack.Tail
+                    | _ -> failwith "carryStack always has at least 2 elements"
+
+            totalSum
+
+    /// <summary>
+    /// Exclude in-place prefix sum of integer array with addition operation and start value that is equal to 0.
+    /// </summary>
+    /// <example>
+    /// <code>
+    /// let arr = [| 1; 1; 1; 1 |]
+    /// let sum = standardExcludeInPlace clContext workGroupSize processor arr
+    /// ...
+    /// > val arr = [| 0; 1; 2; 3 |]
+    /// > val sum = [| 4 |]
+    /// </code>
+    /// </example>
+    /// <param name="clContext">ClContext.</param>
+    /// <param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
+    /// <remarks>
+    /// Note that the work group size is taken from the device limit (at most 256) for better performance.
+    /// </remarks>
+    let standardExcludeInPlace (clContext: ClContext) workGroupSize =
+
+        let scan =
+            runExcludeInPlace <@ (+) @> 0 clContext workGroupSize
+
+        fun (processor: MailboxProcessor<_>) (inputArray: ClArray<int>) ->
+
+            scan processor inputArray
diff --git a/src/GraphBLAS-sharp.Backend/Common/Sort/Radix.fs b/src/GraphBLAS-sharp.Backend/Common/Sort/Radix.fs
index b7db92d0..abace8a4 100644
--- a/src/GraphBLAS-sharp.Backend/Common/Sort/Radix.fs
+++ b/src/GraphBLAS-sharp.Backend/Common/Sort/Radix.fs
@@ -156,7 +156,7 @@ module internal Radix =
let count = count clContext workGroupSize mask
let prefixSum =
- PrefixSum.standardExcludeInPlace clContext workGroupSize
+ ScanInternal.standardExcludeInPlace clContext workGroupSize
let scatter = scatter clContext workGroupSize mask
@@ -259,7 +259,7 @@ module internal Radix =
let count = count clContext workGroupSize mask
let prefixSum =
- PrefixSum.standardExcludeInPlace clContext workGroupSize
+ ScanInternal.standardExcludeInPlace clContext workGroupSize
let scatterByKey =
scatterByKey clContext workGroupSize mask
diff --git a/src/GraphBLAS-sharp.Backend/Common/Sum.fs b/src/GraphBLAS-sharp.Backend/Common/Sum.fs
index 94a745c2..73aad03b 100644
--- a/src/GraphBLAS-sharp.Backend/Common/Sum.fs
+++ b/src/GraphBLAS-sharp.Backend/Common/Sum.fs
@@ -529,7 +529,7 @@ module Reduce =
Scatter.lastOccurrence clContext workGroupSize
let prefixSum =
- PrefixSum.standardExcludeInPlace clContext workGroupSize
+ ScanInternal.standardExcludeInPlace clContext workGroupSize
fun (processor: MailboxProcessor<_>) allocationMode (keys: ClArray) (values: ClArray<'a option>) ->
@@ -661,7 +661,7 @@ module Reduce =
Scatter.lastOccurrence clContext workGroupSize
let prefixSum =
- PrefixSum.standardExcludeInPlace clContext workGroupSize
+ ScanInternal.standardExcludeInPlace clContext workGroupSize
fun (processor: MailboxProcessor<_>) allocationMode (resultLength: int) (offsets: ClArray) (keys: ClArray) (values: ClArray<'a>) ->
@@ -940,7 +940,7 @@ module Reduce =
Scatter.lastOccurrence clContext workGroupSize
let prefixSum =
- PrefixSum.standardExcludeInPlace clContext workGroupSize
+ ScanInternal.standardExcludeInPlace clContext workGroupSize
fun (processor: MailboxProcessor<_>) allocationMode (resultLength: int) (offsets: ClArray) (firstKeys: ClArray) (secondKeys: ClArray) (values: ClArray<'a>) ->
diff --git a/src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj b/src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj
index cac6456d..4300c9ec 100644
--- a/src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj
+++ b/src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj
@@ -32,6 +32,7 @@
+    <Compile Include="Common/Scan.fs" />
diff --git a/src/GraphBLAS-sharp.Backend/Operations/SpMSpV.fs b/src/GraphBLAS-sharp.Backend/Operations/SpMSpV.fs
index bca119a1..2ec56f1e 100644
--- a/src/GraphBLAS-sharp.Backend/Operations/SpMSpV.fs
+++ b/src/GraphBLAS-sharp.Backend/Operations/SpMSpV.fs
@@ -65,7 +65,7 @@ module SpMSpV =
inputArray.[i] <- 0 @>
let sum =
- PrefixSum.standardExcludeInPlace clContext workGroupSize
+ ScanInternal.standardExcludeInPlace clContext workGroupSize
let prepareOffsets = clContext.Compile prepareOffsets
diff --git a/tests/GraphBLAS-sharp.Tests/Backend/Common/Scan/PrefixSum.fs b/tests/GraphBLAS-sharp.Tests/Backend/Common/Scan/PrefixSum.fs
index f94b0564..3773eefd 100644
--- a/tests/GraphBLAS-sharp.Tests/Backend/Common/Scan/PrefixSum.fs
+++ b/tests/GraphBLAS-sharp.Tests/Backend/Common/Scan/PrefixSum.fs
@@ -13,33 +13,51 @@ let logger = Log.create "ClArray.PrefixSum.Tests"
let context = defaultContext.ClContext
-let config = Tests.Utils.defaultConfig
+let config =
+ { Tests.Utils.defaultConfig with
+ maxTest = 20
+ startSize = 1
+ endSize = 1000000 }
let wgSize = 128
let q = defaultContext.Queue
-let makeTest plus zero isEqual scan (array: 'a []) =
+let makeTest plus zero isEqual scanInclude scanExclude (array: 'a []) =
if array.Length > 0 then
+ // Exclude
+ let actual, actualSum =
+ let clArray = context.CreateClArray array
+ let (total: ClCell<_>) = scanExclude q clArray
+
+ let actual = clArray.ToHostAndFree q
+ let actualSum = total.ToHostAndFree q
- logger.debug (
- eventX $"Array is %A{array}\n"
- >> setField "array" (sprintf "%A" array)
- )
+ actual, actualSum
+ let expected, expectedSum =
+ array
+ |> Array.mapFold
+ (fun s t ->
+ let a = plus s t
+ s, a)
+ zero
+
+ "Arrays for exclude should be the same"
+ |> Tests.Utils.compareArrays isEqual actual expected
+
+ "Total sums for exclude should be equal"
+ |> Expect.equal actualSum expectedSum
+
+ // Include
let actual, actualSum =
let clArray = context.CreateClArray array
- let (total: ClCell<_>) = scan q clArray zero
+ let (total: ClCell<_>) = scanInclude q clArray zero
let actual = clArray.ToHostAndFree q
let actualSum = total.ToHostAndFree q
actual, actualSum
- logger.debug (
- eventX "Actual is {actual}\n"
- >> setField "actual" (sprintf "%A" actual)
- )
-
let expected, expectedSum =
array
|> Array.mapFold
@@ -48,20 +66,15 @@ let makeTest plus zero isEqual scan (array: 'a []) =
a, a)
zero
- logger.debug (
- eventX "Expected is {expected}\n"
- >> setField "expected" (sprintf "%A" expected)
- )
-
- "Total sums should be equal"
+ "Total sums for include should be equal"
|> Expect.equal actualSum expectedSum
- "Arrays should be the same"
+ "Arrays for include should be the same"
|> Tests.Utils.compareArrays isEqual actual expected
let testFixtures plus plusQ zero isEqual name =
- Common.PrefixSum.runIncludeInPlace plusQ context wgSize
- |> makeTest plus zero isEqual
+ (PrefixSum.runIncludeInPlace plusQ context wgSize, PrefixSum.runExcludeInPlace plusQ zero context wgSize)
+ ||> makeTest plus zero isEqual
|> testPropertyWithConfig config $"Correctness on %s{name}"
let tests =