Bump the CUDA compat entry to "5.4.0", load CUDA.APIUtils, and build the
idle-handle cache from explicit handle_ctor/handle_dtor callbacks, so that
pop!/push! on idle_handles no longer take do-block closures.

NOTE(review): leading whitespace inside the hunks was reconstructed from a
whitespace-collapsed copy of this patch -- confirm it matches the target tree
(or apply with --ignore-whitespace).

diff --git a/Project.toml b/Project.toml
index dd449e8..cac3e9e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -12,7 +12,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 
 [compat]
 CEnum = "0.4, 0.5"
-CUDA = "5.3.0 - 5.3.5"
+CUDA = "5.4.0"
 CUDSS_jll = "0.2.1"
 julia = "1.6"
 LinearAlgebra = "1.6"
diff --git a/src/CUDSS.jl b/src/CUDSS.jl
index a6a6a68..05f0e43 100644
--- a/src/CUDSS.jl
+++ b/src/CUDSS.jl
@@ -1,6 +1,6 @@
 module CUDSS
 
-using CUDA, CUDA.CUSPARSE
+using CUDA, CUDA.APIUtils, CUDA.CUSPARSE
 using CUDSS_jll
 using LinearAlgebra
 using SparseArrays
diff --git a/src/management.jl b/src/management.jl
index 375a5ac..bcf492b 100644
--- a/src/management.jl
+++ b/src/management.jl
@@ -16,8 +16,20 @@ version() = VersionNumber(cudssGetProperty(CUDA.MAJOR_VERSION),
                           cudssGetProperty(CUDA.MINOR_VERSION),
                           cudssGetProperty(CUDA.PATCH_LEVEL))
 
-# cache for created, but unused handles
-const idle_handles = CUDA.HandleCache{CuContext,cudssHandle_t}()
+## handles
+
+function handle_ctor(ctx)
+    context!(ctx) do
+        cudssCreate()
+    end
+end
+function handle_dtor(ctx, handle)
+    context!(ctx; skip_destroyed=true) do
+        cudssDestroy(handle)
+    end
+end
+
+const idle_handles = HandleCache{CuContext,cudssHandle_t}(handle_ctor, handle_dtor)
 
 function handle()
     cuda = CUDA.active_state()
@@ -30,16 +42,10 @@ function handle()
 
     # get library state
     @noinline function new_state(cuda)
-        new_handle = pop!(idle_handles, cuda.context) do
-            cudssCreate()
-        end
+        new_handle = pop!(idle_handles, cuda.context)
 
         finalizer(current_task()) do task
-            push!(idle_handles, cuda.context, new_handle) do
-                context!(cuda.context; skip_destroyed=true) do
-                    cudssDestroy(new_handle)
-                end
-            end
+            push!(idle_handles, cuda.context, new_handle)
         end
 
         cudssSetStream(new_handle, cuda.stream)