diff --git a/src/GraphBLAS-sharp.Backend/Common/Common.fs b/src/GraphBLAS-sharp.Backend/Common/Common.fs index ae9839f2..afa9f871 100644 --- a/src/GraphBLAS-sharp.Backend/Common/Common.fs +++ b/src/GraphBLAS-sharp.Backend/Common/Common.fs @@ -24,8 +24,7 @@ module Common = /// > val values = [| 1.9; 2.8; 6.4; 5.5; 4.6; 3.7; 7.3 |] /// /// - let sortKeyValuesInplace<'n, 'a when 'n: comparison> = - Sort.Bitonic.sortKeyValuesInplace<'n, 'a> + let sortKeyValuesInplace<'a> = Sort.Bitonic.sortKeyValuesInplace<'a> module Radix = /// diff --git a/src/GraphBLAS-sharp.Backend/Common/Sort/Bitonic.fs b/src/GraphBLAS-sharp.Backend/Common/Sort/Bitonic.fs index f51e94ce..5aba906b 100644 --- a/src/GraphBLAS-sharp.Backend/Common/Sort/Bitonic.fs +++ b/src/GraphBLAS-sharp.Backend/Common/Sort/Bitonic.fs @@ -1,317 +1,255 @@ namespace GraphBLAS.FSharp.Backend.Common.Sort open Brahma.FSharp -open GraphBLAS.FSharp.Backend.Common +open GraphBLAS.FSharp.Backend module Bitonic = - let private localBegin (clContext: ClContext) workGroupSize = - let processedSize = workGroupSize * 2 + let sortKeyValuesInplace<'a> (clContext: ClContext) (workGroupSize: int) = - let localBegin = - <@ fun (range: Range1D) (rows: ClArray<'n>) (cols: ClArray<'n>) (values: ClArray<'a>) (length: int) -> + let localSize = + Common.Utils.floorToPower2 ( + int (clContext.ClDevice.LocalMemSize) + / (sizeof + sizeof<'a>) + ) + + let maxThreadsPerBlock = + min (clContext.ClDevice.MaxWorkGroupSize) (localSize / 2) - let lid = range.LocalID0 - let gid = range.GlobalID0 + let waveSize = 32 + let maxWorkGroupSize = clContext.ClDevice.MaxWorkGroupSize + + let localStep = + <@ fun (ndRange: Range1D) (rows: ClArray) (cols: ClArray) (vals: ClArray<'a>) (length: int) -> + let gid = ndRange.GlobalID0 + let lid = ndRange.LocalID0 + let workGroupSize = ndRange.LocalWorkSize let groupId = gid / workGroupSize - // 1 рабочая группа обрабатывает 2 * workGroupSize элементов - let localRows = localArray<'n> processedSize - let localCols = localArray<'n> processedSize - let localValues = localArray<'a> processedSize + let offset = groupId * localSize + let border = min (offset + localSize) length + let n = border - offset - let mutable readIdx = processedSize * groupId + lid - let mutable localLength = local () - localLength <- processedSize + let nAligned = + (%Quotes.ArithmeticOperations.ceilToPowerOfTwo) n - // копируем элементы из глобальной памяти в локальную - if readIdx < length then - localRows.[lid] <- rows.[readIdx] - localCols.[lid] <- cols.[readIdx] - localValues.[lid] <- values.[readIdx] + let numberOfThreads = nAligned / 2 - if readIdx = length then - localLength <- lid + let sortedKeys = localArray localSize + let sortedVals = localArray<'a> localSize - readIdx <- readIdx + workGroupSize + let mutable i = lid - if readIdx < length then - localRows.[lid + workGroupSize] <- rows.[readIdx] - localCols.[lid + workGroupSize] <- cols.[readIdx] - localValues.[lid + workGroupSize] <- values.[readIdx] + while i + offset < border do + let key: uint64 = + ((uint64 rows.[i + offset]) <<< 32) + ||| (uint64 cols.[i + offset]) - if readIdx = length then - localLength <- lid + workGroupSize + sortedKeys.[i] <- key + sortedVals.[i] <- vals.[i + offset] + i <- i + workGroupSize barrierLocal () - let mutable segmentLength = 1 - - while segmentLength < processedSize do - segmentLength <- segmentLength <<< 1 - let localLineId = lid % (segmentLength >>> 1) - let localTwinId = segmentLength - localLineId - 1 - let groupLineId = lid / (segmentLength >>> 1) + let mutable segmentSize = 2 - let lineId = - segmentLength * groupLineId + localLineId + while segmentSize <= nAligned do + let segmentSizeHalf = segmentSize / 2 - let twinId = - segmentLength * groupLineId + localTwinId + let mutable tid = lid - if twinId < localLength - && (localRows.[lineId] > localRows.[twinId] - || localRows.[lineId] = localRows.[twinId] - && localCols.[lineId] > localCols.[twinId]) then - let tmpRow = localRows.[lineId] - localRows.[lineId] <- localRows.[twinId] - localRows.[twinId] <- tmpRow + while tid < numberOfThreads do + let segmentId = tid / segmentSizeHalf + let innerId = tid % segmentSizeHalf + let innerIdSibling = segmentSize - innerId - 1 + let i = segmentId * segmentSize + innerId + let j = segmentId * segmentSize + innerIdSibling - let tmpCol = localCols.[lineId] - localCols.[lineId] <- localCols.[twinId] - localCols.[twinId] <- tmpCol + if (i < n && j < n && sortedKeys.[i] > sortedKeys.[j]) then + let tempK = sortedKeys.[i] + sortedKeys.[i] <- sortedKeys.[j] + sortedKeys.[j] <- tempK + let tempV = sortedVals.[i] + sortedVals.[i] <- sortedVals.[j] + sortedVals.[j] <- tempV - let tmpValue = localValues.[lineId] - localValues.[lineId] <- localValues.[twinId] - localValues.[twinId] <- tmpValue + tid <- tid + workGroupSize barrierLocal () - let mutable j = segmentLength >>> 1 - - while j > 1 do - let localLineId = lid % (j >>> 1) - let localTwinId = localLineId + (j >>> 1) - let groupLineId = lid / (j >>> 1) - let lineId = j * groupLineId + localLineId - let twinId = j * groupLineId + localTwinId - - if twinId < localLength - && (localRows.[lineId] > localRows.[twinId] - || localRows.[lineId] = localRows.[twinId] - && localCols.[lineId] > localCols.[twinId]) then - let tmpRow = localRows.[lineId] - localRows.[lineId] <- localRows.[twinId] - localRows.[twinId] <- tmpRow - - let tmpCol = localCols.[lineId] - localCols.[lineId] <- localCols.[twinId] - localCols.[twinId] <- tmpCol - - let tmpValue = localValues.[lineId] - localValues.[lineId] <- localValues.[twinId] - localValues.[twinId] <- tmpValue + let mutable k = segmentSizeHalf / 2 - barrierLocal () - - j <- j >>> 1 - - let mutable writeIdx = processedSize * groupId + lid - - if writeIdx < length then - rows.[writeIdx] <- localRows.[lid] - cols.[writeIdx] <- localCols.[lid] - values.[writeIdx] <- localValues.[lid] + while k > 0 do - writeIdx <- writeIdx + workGroupSize + let mutable tid = lid - if writeIdx < length then - rows.[writeIdx] <- localRows.[lid + workGroupSize] - cols.[writeIdx] <- localCols.[lid + workGroupSize] - values.[writeIdx] <- localValues.[lid + workGroupSize] @> + while tid < numberOfThreads do + let segmentSizeInner = k * 2 + let segmentId = tid / k + let innerId = tid % k + let innerIdSibling = innerId + k + let i = segmentId * segmentSizeInner + innerId - let program = clContext.Compile(localBegin) + let j = + segmentId * segmentSizeInner + innerIdSibling - fun (queue: MailboxProcessor<_>) (rows: ClArray<'n>) (cols: ClArray<'n>) (values: ClArray<'a>) -> + if (i < n && j < n && sortedKeys.[i] > sortedKeys.[j]) then + let tempK = sortedKeys.[i] + sortedKeys.[i] <- sortedKeys.[j] + sortedKeys.[j] <- tempK + let tempV = sortedVals.[i] + sortedVals.[i] <- sortedVals.[j] + sortedVals.[j] <- tempV - let ndRange = - Range1D.CreateValid(Utils.floorToPower2 values.Length, workGroupSize) + tid <- tid + workGroupSize - let kernel = program.GetKernel() + k <- k / 2 + barrierLocal () - queue.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange rows cols values values.Length)) - queue.Post(Msg.CreateRunMsg<_, _>(kernel)) + segmentSize <- segmentSize * 2 + let mutable i = lid - let private globalStep (clContext: ClContext) workGroupSize = + while i + offset < border do + let key = sortedKeys.[i] + rows.[i + offset] <- int (key >>> 32) + cols.[i + offset] <- int key + vals.[i + offset] <- sortedVals.[i] + i <- i + workGroupSize @> let globalStep = - <@ fun (range: Range1D) (rows: ClArray<'n>) (cols: ClArray<'n>) (values: ClArray<'a>) (length: int) (segmentLength: int) (mirror: ClCell) -> - - let mirror = mirror.Value - - let gid = range.GlobalID0 - - let localLineId = gid % (segmentLength >>> 1) - let mutable localTwinId = 0 + <@ fun (ndRange: Range1D) (rows: ClArray) (cols: ClArray) (vals: ClArray<'a>) (length: int) (segmentStart: int) -> + let lid = ndRange.LocalID0 + let workGroupSize = ndRange.LocalWorkSize - if mirror then - localTwinId <- segmentLength - localLineId - 1 - else - localTwinId <- localLineId + (segmentLength >>> 1) + let n = length - let groupLineId = gid / (segmentLength >>> 1) + let nAligned = + (%Quotes.ArithmeticOperations.ceilToPowerOfTwo) n - let lineId = - segmentLength * groupLineId + localLineId + let numberOfThreads = nAligned / 2 - let twinId = - segmentLength * groupLineId + localTwinId + let mutable segmentSize = segmentStart - if twinId < length - && (rows.[lineId] > rows.[twinId] - || rows.[lineId] = rows.[twinId] - && cols.[lineId] > cols.[twinId]) then - let tmpRow = rows.[lineId] - rows.[lineId] <- rows.[twinId] - rows.[twinId] <- tmpRow + while segmentSize <= nAligned do + let segmentSizeHalf = segmentSize / 2 - let tmpCol = cols.[lineId] - cols.[lineId] <- cols.[twinId] - cols.[twinId] <- tmpCol + let mutable tid = lid - let tmpV = values.[lineId] - values.[lineId] <- values.[twinId] - values.[twinId] <- tmpV @> + while tid < numberOfThreads do + let segmentId = tid / segmentSizeHalf + let innerId = tid % segmentSizeHalf + let innerIdSibling = segmentSize - innerId - 1 + let i = segmentId * segmentSize + innerId + let j = segmentId * segmentSize + innerIdSibling - let program = clContext.Compile(globalStep) + if (i < n && j < n) then + let keyI = + ((uint64 rows.[i]) <<< 32) ||| (uint64 cols.[i]) - fun (queue: MailboxProcessor<_>) (rows: ClArray<'n>) (cols: ClArray<'n>) (values: ClArray<'a>) (segmentLength: int) (mirror: bool) -> + let keyJ = + ((uint64 rows.[j]) <<< 32) ||| (uint64 cols.[j]) - let ndRange = - Range1D.CreateValid(Utils.floorToPower2 values.Length, workGroupSize) + if (keyI > keyJ) then + let tempR = rows.[i] + rows.[i] <- rows.[j] + rows.[j] <- tempR + let tempC = cols.[i] + cols.[i] <- cols.[j] + cols.[j] <- tempC + let tempV = vals.[i] + vals.[i] <- vals.[j] + vals.[j] <- tempV - let mirror = clContext.CreateClCell mirror + tid <- tid + workGroupSize - let kernel = program.GetKernel() + barrierGlobal () - queue.Post( - Msg.MsgSetArguments - (fun () -> kernel.KernelFunc ndRange rows cols values values.Length segmentLength mirror) - ) - - queue.Post(Msg.CreateRunMsg<_, _>(kernel)) - queue.Post(Msg.CreateFreeMsg(mirror)) - - - let private localEnd (clContext: ClContext) workGroupSize = - - let processedSize = workGroupSize * 2 - - let localEnd = - <@ fun (range: Range1D) (rows: ClArray<'n>) (cols: ClArray<'n>) (values: ClArray<'a>) (length: int) -> - - let lid = range.LocalID0 - let gid = range.GlobalID0 - let groupId = gid / workGroupSize + let mutable k = segmentSizeHalf / 2 - // 1 рабочая группа обрабатывает 2 * wgSize элементов - let localRows = localArray<'n> processedSize - let localCols = localArray<'n> processedSize - let localValues = localArray<'a> processedSize + while k > 0 do - let mutable readIdx = processedSize * groupId + lid - let mutable localLength = local () - localLength <- processedSize + let mutable tid = lid - // копируем элементы из глобальной памяти в локальную - if readIdx < length then - localRows.[lid] <- rows.[readIdx] - localCols.[lid] <- cols.[readIdx] - localValues.[lid] <- values.[readIdx] + while tid < numberOfThreads do + let segmentSizeInner = k * 2 + let segmentId = tid / k + let innerId = tid % k + let innerIdSibling = innerId + k + let i = segmentId * segmentSizeInner + innerId - if readIdx = length then - localLength <- lid + let j = + segmentId * segmentSizeInner + innerIdSibling - readIdx <- readIdx + workGroupSize + if (i < n && j < n) then + let keyI = + ((uint64 rows.[i]) <<< 32) ||| (uint64 cols.[i]) - if readIdx < length then - localRows.[lid + workGroupSize] <- rows.[readIdx] - localCols.[lid + workGroupSize] <- cols.[readIdx] - localValues.[lid + workGroupSize] <- values.[readIdx] + let keyJ = + ((uint64 rows.[j]) <<< 32) ||| (uint64 cols.[j]) - if readIdx = length then - localLength <- lid + workGroupSize + if (keyI > keyJ) then + let tempR = rows.[i] + rows.[i] <- rows.[j] + rows.[j] <- tempR + let tempC = cols.[i] + cols.[i] <- cols.[j] + cols.[j] <- tempC + let tempV = vals.[i] + vals.[i] <- vals.[j] + vals.[j] <- tempV - barrierLocal () - - let mutable segmentLength = processedSize - let mutable j = segmentLength - - while j > 1 do - let localLineId = lid % (j / 2) - let localTwinId = localLineId + (j / 2) - let groupLineId = lid / (j / 2) - let lineId = j * groupLineId + localLineId - let twinId = j * groupLineId + localTwinId - - if twinId < localLength - && (localRows.[lineId] > localRows.[twinId] - || localRows.[lineId] = localRows.[twinId] - && localCols.[lineId] > localCols.[twinId]) then - let tmpRow = localRows.[lineId] - localRows.[lineId] <- localRows.[twinId] - localRows.[twinId] <- tmpRow - - let tmpCol = localCols.[lineId] - localCols.[lineId] <- localCols.[twinId] - localCols.[twinId] <- tmpCol - - let tmpValue = localValues.[lineId] - localValues.[lineId] <- localValues.[twinId] - localValues.[twinId] <- tmpValue - - barrierLocal () + tid <- tid + workGroupSize - j <- j >>> 1 + k <- k / 2 + barrierGlobal () - let mutable writeIdx = processedSize * groupId + lid + segmentSize <- segmentSize * 2 @> - if writeIdx < length then - rows.[writeIdx] <- localRows.[lid] - cols.[writeIdx] <- localCols.[lid] - values.[writeIdx] <- localValues.[lid] + let localStep = clContext.Compile(localStep) + let globalStep = clContext.Compile(globalStep) - writeIdx <- writeIdx + workGroupSize + fun (queue: MailboxProcessor<_>) (rows: ClArray) (cols: ClArray) (values: ClArray<'a>) -> - if writeIdx < length then - rows.[writeIdx] <- localRows.[lid + workGroupSize] - cols.[writeIdx] <- localCols.[lid + workGroupSize] - values.[writeIdx] <- localValues.[lid + workGroupSize] @> + let size = values.Length - let program = clContext.Compile(localEnd) + if (size = 1) then + () + else if (size <= localSize) then + let numberOfThreads = + Common.Utils.ceilToMultiple waveSize (min size maxThreadsPerBlock) - fun (queue: MailboxProcessor<_>) (rows: ClArray<'n>) (cols: ClArray<'n>) (values: ClArray<'a>) -> + let ndRangeLocal = + Range1D.CreateValid(numberOfThreads, numberOfThreads) - let ndRange = - Range1D.CreateValid(Utils.floorToPower2 values.Length, workGroupSize) + let kernel = localStep.GetKernel() - let kernel = program.GetKernel() + queue.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRangeLocal rows cols values values.Length)) + queue.Post(Msg.CreateRunMsg<_, _>(kernel)) + else + let numberOfGroups = + size / localSize + + (if size % localSize = 0 then 0 else 1) - queue.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange rows cols values values.Length)) - queue.Post(Msg.CreateRunMsg<_, _>(kernel)) + let ndRangeLocal = + Range1D.CreateValid(maxThreadsPerBlock * numberOfGroups, maxThreadsPerBlock) - let sortKeyValuesInplace<'n, 'a when 'n: comparison> (clContext: ClContext) workGroupSize = + let kernelLocal = localStep.GetKernel() - let localBegin = localBegin clContext workGroupSize - let globalStep = globalStep clContext workGroupSize - let localEnd = localEnd clContext workGroupSize + queue.Post( + Msg.MsgSetArguments(fun () -> kernelLocal.KernelFunc ndRangeLocal rows cols values values.Length) + ) - fun (queue: MailboxProcessor<_>) (rows: ClArray<'n>) (cols: ClArray<'n>) (values: ClArray<'a>) -> + queue.Post(Msg.CreateRunMsg<_, _>(kernelLocal)) - let lengthCeiled = Utils.ceilToPower2 values.Length + let ndRangeGlobal = + Range1D.CreateValid(maxWorkGroupSize, maxWorkGroupSize) - let rec loopNested i = - if i > workGroupSize * 2 then - globalStep queue rows cols values i false - loopNested (i >>> 1) + let kernelGlobal = globalStep.GetKernel() - let rec mainLoop segmentLength = - if segmentLength <= lengthCeiled then - globalStep queue rows cols values segmentLength true - loopNested (segmentLength >>> 1) - localEnd queue rows cols values - mainLoop (segmentLength <<< 1) + queue.Post( + Msg.MsgSetArguments + (fun () -> kernelGlobal.KernelFunc ndRangeGlobal rows cols values values.Length (localSize * 2)) + ) - localBegin queue rows cols values - mainLoop (workGroupSize <<< 2) + queue.Post(Msg.CreateRunMsg<_, _>(kernelGlobal)) diff --git a/src/GraphBLAS-sharp.Backend/Quotes/Arithmetic.fs b/src/GraphBLAS-sharp.Backend/Quotes/Arithmetic.fs index 8dc37dec..ec8ad537 100644 --- a/src/GraphBLAS-sharp.Backend/Quotes/Arithmetic.fs +++ b/src/GraphBLAS-sharp.Backend/Quotes/Arithmetic.fs @@ -258,6 +258,15 @@ module ArithmeticOperations = let fst<'a> = <@ fun (x: 'a) (_: 'a) -> Some x @> + let ceilToPowerOfTwo = + <@ fun (x: int) -> + let mutable i = 1 + + while i < x do + i <- i * 2 + + i @> + //PageRank specific let squareOfDifference = <@ fun (x: float32 option) (y: float32 option) -> diff --git a/tests/GraphBLAS-sharp.Tests/Backend/Common/Sort/Bitonic.fs b/tests/GraphBLAS-sharp.Tests/Backend/Common/Sort/Bitonic.fs index d76053a8..c2477bd4 100644 --- a/tests/GraphBLAS-sharp.Tests/Backend/Common/Sort/Bitonic.fs +++ b/tests/GraphBLAS-sharp.Tests/Backend/Common/Sort/Bitonic.fs @@ -22,7 +22,7 @@ module Bitonic = let q = defaultContext.Queue - let makeTest sort (array: ('n * 'n * 'a) []) = + let makeTest sort (array: (int * int * 'a) []) = if array.Length > 0 then let projection (row: 'n) (col: 'n) (_: 'a) = row, col @@ -32,6 +32,8 @@ module Bitonic = ) let rows, cols, vals = Array.unzip3 array + let rows = Array.map abs rows + let cols = Array.map abs cols let clRows = context.CreateClArray rows let clColumns = context.CreateClArray cols @@ -58,11 +60,42 @@ module Bitonic = $"Column arrays should be equal. Actual is \n%A{actualCols}, expected \n%A{expectedCols}, input is \n%A{cols}" |> Utils.compareArrays (=) actualCols expectedCols - $"Value arrays should be equal. Actual is \n%A{actualValues}, expected \n%A{expectedValues}, input is \n%A{vals}" - |> Utils.compareArrays (=) actualValues expectedValues + // Check that for each pair of equal keys values are the same + let mutable i = 1 + + let expected, actual = + new ResizeArray<'a>(), new ResizeArray<'a>() + + expected.Add expectedValues.[0] + actual.Add actualValues.[0] + + while i < expectedValues.Size do + if + not + ( + actualRows.[i - 1] = actualRows.[i] + && actualCols.[i - 1] = actualCols.[i] + ) + then + Expect.sequenceEqual + (actual |> Seq.countBy id) + (actual |> Seq.countBy id) + $"Values for keys %A{actualRows.[i - 1]}, %A{actualCols.[i - 1]} are not the same" + + expected.Clear() + actual.Clear() + + expected.Add expectedValues.[i] + actual.Add actualValues.[i] + i <- i + 1 + + Expect.sequenceEqual + actual + expected + $"Values for keys %A{actualRows.[i - 1]}, %A{actualCols.[i - 1]} are not the same" let testFixtures<'a when 'a: equality> = - Sort.Bitonic.sortKeyValuesInplace context wgSize + Sort.Bitonic.sortKeyValuesInplace<'a> context wgSize |> makeTest |> testPropertyWithConfig config $"Correctness on %A{typeof<'a>}"