Skip to content

Commit

Permalink
Merge pull request #36 from GunnarFarneback/cuda_provider_options
Browse files Browse the repository at this point in the history
Support CUDA provider options.
  • Loading branch information
jw3126 authored Apr 13, 2024
2 parents c931b8a + 2a3b74d commit 11e6c3b
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ONNXRunTime"
uuid = "e034b28e-924e-41b2-b98f-d2bbeb830c6a"
authors = ["Jan Weidner <[email protected]> and contributors"]
version = "1.0.0"
version = "1.1.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ julia> import CUDA, cuDNN
julia> ORT.load_inference(path, execution_provider=:cuda)
```

CUDA provider options can be specified:
```
julia> ORT.load_inference(path, execution_provider=:cuda,
provider_options=(;cudnn_conv_algo_search=:HEURISTIC))
```

Memory allocated by a model is eventually automatically released after
it goes out of scope, when the model object is deleted by the garbage
collector. It can also be immediately released with `release(model)`.
Expand Down
20 changes: 19 additions & 1 deletion src/highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using ArgCheck
using LazyArtifacts
using DataStructures: OrderedDict
using DocStringExtensions
import CEnum
################################################################################
##### testdatapath
################################################################################
Expand Down Expand Up @@ -59,10 +60,16 @@ end
"""
function load_inference(path::AbstractString; execution_provider::Symbol=:cpu,
envname::AbstractString="defaultenv",
provider_options::NamedTuple=(;)
)::InferenceSession
api = GetApi(;execution_provider)
env = CreateEnv(api, name=envname)
if execution_provider === :cpu
if !isempty(provider_options)
error("""
No provider options are supported for the CPU execution provider.
""")
end
session_options = CreateSessionOptions(api)
elseif execution_provider === :cuda
CUDAExt = Base.get_extension(@__MODULE__, :CUDAExt)
Expand All @@ -83,7 +90,18 @@ function load_inference(path::AbstractString; execution_provider::Symbol=:cpu,
end
end
session_options = CreateSessionOptions(api)
cuda_options = OrtCUDAProviderOptions()
cuda_options_dict = Dict{Symbol, Any}(pairs(provider_options))
if haskey(cuda_options_dict, :cudnn_conv_algo_search)
# Look up enum values.
value = cuda_options_dict[:cudnn_conv_algo_search]
        if value ∉ first.(CEnum.name_value_pairs(CAPI.OrtCudnnConvAlgoSearch))
error("""
$(value) is not a valid value for :cudnn_conv_algo_search.
""")
end
cuda_options_dict[:cudnn_conv_algo_search] = getfield(CAPI, value)
end
cuda_options = OrtCUDAProviderOptions(; cuda_options_dict...)
SessionOptionsAppendExecutionProvider_CUDA(api, session_options, cuda_options)
else
error("Unsupported execution_provider $execution_provider")
Expand Down
27 changes: 27 additions & 0 deletions test/test_cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,33 @@ using ONNXRunTime: SessionOptionsAppendExecutionProvider_CUDA
y = model((;input=input,), ["output"])
@test y == (output=input .+ 1f0,)
end

@testset "provider options" begin
    path = ORT.testdatapath("Conv1d2.onnx")
    # Column-major fill: channel 1 along dim 3 is (1, 2, 3), channel 2 is (4, 5, 6).
    input = reshape(Float32[1, 4, 2, 5, 3, 6], (1, 2, 3))
    inputs = (; input)
    for algo in (:DEFAULT, :HEURISTIC, :EXHAUSTIVE)
        opts = (; cudnn_conv_algo_search = algo)
        model = ORT.load_inference(path, execution_provider=:cuda,
                                   provider_options=opts)
        out = model(inputs).output
        # Conv1d2.onnx output: first channel is the moving difference result,
        # second channel is all zeros.
        @test out[1, 1, 1] == 1
        @test out[1, 1, 2] == 3
        @test out[1, 1, 3] == 5
        @test all(iszero, out[1, 2, :])
    end
    # An enum name that CAPI.OrtCudnnConvAlgoSearch does not define must be rejected.
    bad_opts = (; cudnn_conv_algo_search = :NEVER_HEARD_OF_THIS)
    @test_throws ErrorException ORT.load_inference(path, execution_provider=:cuda,
                                                   provider_options=bad_opts)
end
end

using ONNXRunTime.CAPI
Expand Down

2 comments on commit 11e6c3b

@jw3126
Copy link
Owner Author

@jw3126 jw3126 commented on 11e6c3b Apr 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/104844

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.1.0 -m "<description of version>" 11e6c3bdedbf93af2fdebed97f6bcfb8e2e22681
git push origin v1.1.0

Please sign in to comment.