From 686ef1f3ec0a6074c9e202e53c87ef0641670ad1 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Wed, 29 May 2024 11:15:08 -0400 Subject: [PATCH 1/2] Upgrade handle support --- Project.toml | 2 +- src/management.jl | 26 ++++++++++++++++---------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index dd449e8..cac3e9e 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] CEnum = "0.4, 0.5" -CUDA = "5.3.0 - 5.3.5" +CUDA = "5.4.0" CUDSS_jll = "0.2.1" julia = "1.6" LinearAlgebra = "1.6" diff --git a/src/management.jl b/src/management.jl index 375a5ac..bcf492b 100644 --- a/src/management.jl +++ b/src/management.jl @@ -16,8 +16,20 @@ version() = VersionNumber(cudssGetProperty(CUDA.MAJOR_VERSION), cudssGetProperty(CUDA.MINOR_VERSION), cudssGetProperty(CUDA.PATCH_LEVEL)) -# cache for created, but unused handles -const idle_handles = CUDA.HandleCache{CuContext,cudssHandle_t}() +## handles + +function handle_ctor(ctx) + context!(ctx) do + cudssCreate() + end +end +function handle_dtor(ctx, handle) + context!(ctx; skip_destroyed=true) do + cudssDestroy(handle) + end +end + +const idle_handles = HandleCache{CuContext,cudssHandle_t}(handle_ctor, handle_dtor) function handle() cuda = CUDA.active_state() @@ -30,16 +42,10 @@ function handle() # get library state @noinline function new_state(cuda) - new_handle = pop!(idle_handles, cuda.context) do - cudssCreate() - end + new_handle = pop!(idle_handles, cuda.context) finalizer(current_task()) do task - push!(idle_handles, cuda.context, new_handle) do - context!(cuda.context; skip_destroyed=true) do - cudssDestroy(new_handle) - end - end + push!(idle_handles, cuda.context, new_handle) end cudssSetStream(new_handle, cuda.stream) From 6b21cee14236e6bcf8fa7043270f1ac58a2b99f9 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Wed, 29 May 2024 11:27:32 -0400 Subject: [PATCH 2/2] Import CUDA.APIUtils --- src/CUDSS.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/CUDSS.jl b/src/CUDSS.jl index a6a6a68..05f0e43 100644 --- a/src/CUDSS.jl +++ b/src/CUDSS.jl @@ -1,6 +1,6 @@ module CUDSS -using CUDA, CUDA.CUSPARSE +using CUDA, CUDA.APIUtils, CUDA.CUSPARSE using CUDSS_jll using LinearAlgebra using SparseArrays