-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c05f031
commit 2f2d6c0
Showing
4 changed files
with
293 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
[deps] | ||
CUDA_SDK_jll = "6cbf2f2e-7e60-5632-ac76-dca2274e0be0" | ||
CUDSS_jll = "4889d778-9329-5762-9fec-0578a5d30366" | ||
Clang = "40e3b903-d033-50b4-a0cc-940c62c95e31" | ||
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" | ||
|
||
[compat] | ||
CUDA_SDK_jll = "12.5.1" | ||
CUDSS_jll = "0.3.0" | ||
julia = "1.6" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Wrapping headers | ||
|
||
This directory contains a script `wrapper.jl` that can be used to automatically | ||
generate wrappers from C headers of NVIDIA cuDSS. This is done using Clang.jl. | ||
|
||
In CUDSS.jl, the wrappers need to know whether pointers passed into the | ||
library point to CPU or GPU memory (i.e. `Ptr` or `CuPtr`). This information is | ||
not available from the headers, and instead should be provided by the developer. | ||
The specific information is embedded in the TOML file `cudss.toml`. | ||
|
||
# Usage | ||
|
||
Either run `julia wrapper.jl` directly, or include it and call the `main()` function. | ||
Be sure to activate the project environment in this folder (`julia --project`), which will install `Clang.jl` and `JuliaFormatter.jl`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
[general] | ||
library_name = "libcudss" | ||
output_file_path = "../src/libcudss2.jl" | ||
|
||
output_ignorelist = [ | ||
# generates bad code | ||
"CUSPARSE_CPP_VERSION", | ||
"CUSPARSE_DEPRECATED_REPLACE_WITH", | ||
"CUSPARSE_DEPRECATED_ENUM_REPLACE_WITH", | ||
# these change often | ||
"CUSPARSE_VERSION", | ||
"CUSPARSE_VER_.*", | ||
] | ||
|
||
[codegen] | ||
use_ccall_macro = true | ||
always_NUL_terminated_string = true | ||
|
||
[api] | ||
checked_rettypes = [ "cudssStatus_t" ] | ||
|
||
[api.cusparseGetVersion] | ||
needs_context = false | ||
|
||
[api.cusparseGetProperty] | ||
needs_context = false | ||
|
||
[api.cusparseGetErrorName] | ||
needs_context = false | ||
|
||
[api.cusparseGetErrorString] | ||
needs_context = false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
using Clang | ||
using Clang.Generators | ||
|
||
using JuliaFormatter | ||
|
||
using CUDA_SDK_jll, CUDSS_jll | ||
|
||
# a pass that removes macro definitions that are also function definitions. | ||
# | ||
# this sometimes happens with NVIDIA's headers, either because of typos, or because they are | ||
# reserving identifiers for future use: | ||
# #define cuStreamGetCaptureInfo_v2 __CUDA_API_PTSZ(cuStreamGetCaptureInfo_v2) | ||
mutable struct AvoidDuplicates <: Clang.AbstractPass end | ||
function (x::AvoidDuplicates)(dag::ExprDAG, options::Dict) | ||
# collect macro definitions | ||
macro_definitions = Dict() | ||
for (i, node) in enumerate(dag.nodes) | ||
if node isa ExprNode{<:AbstractMacroNodeType} | ||
macro_definitions[node.id] = (i, node) | ||
end | ||
end | ||
|
||
# scan function definitions | ||
for (i, node) in enumerate(dag.nodes) | ||
if Generators.is_function(node) && !Generators.is_variadic_function(node) | ||
if haskey(macro_definitions, node.id) | ||
@info "Removing macro definition for $(node.id)" | ||
j, duplicate_node = macro_definitions[node.id] | ||
dag.nodes[j] = ExprNode(node.id, Clang.Generators.Skip(), duplicate_node.cursor, duplicate_node.exprs, duplicate_node.adj) | ||
end | ||
end | ||
end | ||
|
||
return dag | ||
end | ||
|
||
function rewriter!(ctx, options) | ||
for node in get_nodes(ctx.dag) | ||
# remove aliases for function names | ||
# | ||
# when NVIDIA changes the behavior of an API, they version the function | ||
# (`cuFunction_v2`), and sometimes even change function names. To maintain backwards | ||
# compatibility, they ship aliases with their headers such that compiled binaries | ||
# will keep using the old version, and newly-compiled ones will use the developer's | ||
# CUDA version. remove those, since we target multiple CUDA versions. | ||
# | ||
# remove this if we ever decide to support a single supported version of CUDA. | ||
if node isa ExprNode{<:AbstractMacroNodeType} | ||
isempty(node.exprs) && continue | ||
expr = node.exprs[1] | ||
if Meta.isexpr(expr, :const) | ||
expr = expr.args[1] | ||
end | ||
if Meta.isexpr(expr, :(=)) | ||
lhs, rhs = expr.args | ||
if rhs isa Expr && rhs.head == :call | ||
name = string(rhs.args[1]) | ||
if endswith(name, "STRUCT_SIZE") | ||
rhs.head = :macrocall | ||
rhs.args[1] = Symbol("@", name) | ||
insert!(rhs.args, 2, nothing) | ||
end | ||
end | ||
isa(lhs, Symbol) || continue | ||
if Meta.isexpr(rhs, :call) && rhs.args[1] in (:__CUDA_API_PTDS, :__CUDA_API_PTSZ) | ||
rhs = rhs.args[2] | ||
end | ||
isa(rhs, Symbol) || continue | ||
lhs, rhs = String(lhs), String(rhs) | ||
function get_prefix(str) | ||
# cuFooBar -> cu | ||
isempty(str) && return nothing | ||
islowercase(str[1]) || return nothing | ||
for i in 2:length(str) | ||
if isuppercase(str[i]) | ||
return str[1:i-1] | ||
end | ||
end | ||
return nothing | ||
end | ||
lhs_prefix = get_prefix(lhs) | ||
lhs_prefix === nothing && continue | ||
rhs_prefix = get_prefix(rhs) | ||
if lhs_prefix == rhs_prefix | ||
@debug "Removing function alias: `$expr`" | ||
empty!(node.exprs) | ||
end | ||
end | ||
end | ||
|
||
if Generators.is_function(node) && !Generators.is_variadic_function(node) | ||
expr = node.exprs[1] | ||
call_expr = expr.args[2].args[1].args[3] # assumes `use_ccall_macro` is true | ||
|
||
# replace `@ccall` with `@gcsafe_ccall` | ||
expr.args[2].args[1].args[1] = Symbol("@gcsafe_ccall") | ||
|
||
target_expr = call_expr.args[1].args[1] | ||
fn = String(target_expr.args[2].value) | ||
|
||
# look up API options for this function | ||
fn_options = Dict{String,Any}() | ||
templates = Dict{String,Any}() | ||
template_types = nothing | ||
if haskey(options, "api") | ||
names = [fn] | ||
|
||
# _64 aliases are used by CUBLAS with Int64 arguments. they otherwise have | ||
# an idential signature, so we can reuse the same type rewrites. | ||
if endswith(fn, "_64") | ||
push!(names, fn[1:end-3]) | ||
end | ||
|
||
# look for a template rewrite: many libraries have very similar functions, | ||
# e.g., `cublas[SDHCZ]gemm`, for which we can use the same type rewrites | ||
# registered as `cublas𝕏gemm` template with `T` and `S` placeholders. | ||
for name in copy(names), (typcode,(T,S)) in ["S"=>("Cfloat","Cfloat"), | ||
"D"=>("Cdouble","Cdouble"), | ||
"H"=>("Float16","Float16"), | ||
"C"=>("cuComplex","Cfloat"), | ||
"Z"=>("cuDoubleComplex","Cdouble")] | ||
idx = findfirst(typcode, name) | ||
while idx !== nothing | ||
template_name = name[1:idx.start-1] * "𝕏" * name[idx.stop+1:end] | ||
if haskey(options["api"], template_name) | ||
templates[template_name] = ["T" => T, "S" => S] | ||
push!(names, template_name) | ||
end | ||
idx = findnext(typcode, name, idx.stop+1) | ||
end | ||
end | ||
|
||
# the exact name is always checked first, so it's always possible to | ||
# override the type rewrites for a specific function | ||
# (e.g. if a _64 function ever passes a `Ptr{Cint}` index). | ||
for name in names | ||
template_types = get(templates, name, nothing) | ||
if haskey(options["api"], name) | ||
fn_options = options["api"][name] | ||
break | ||
end | ||
end | ||
end | ||
|
||
# rewrite pointer argument types | ||
arg_exprs = call_expr.args[1].args[2:end] | ||
argtypes = get(fn_options, "argtypes", Dict()) | ||
for (arg, typ) in argtypes | ||
i = parse(Int, arg) | ||
i in 1:length(arg_exprs) || error("invalid argtypes for $fn: index $arg is out of bounds") | ||
|
||
# _64 aliases should use Int64 instead of Int32/Cint | ||
if endswith(fn, "_64") | ||
typ = replace(typ, "Cint" => "Int64", "Int32" => "Int64") | ||
end | ||
|
||
# expand type templates | ||
if template_types !== nothing | ||
typ = replace(typ, template_types...) | ||
end | ||
|
||
arg_exprs[i].args[2] = Meta.parse(typ) | ||
end | ||
|
||
# insert `initialize_context()` before each function with a `ccall` | ||
if get(fn_options, "needs_context", true) | ||
pushfirst!(expr.args[2].args, :(initialize_context())) | ||
end | ||
|
||
# insert `@checked` before each function with a `ccall` returning a checked type` | ||
rettyp = call_expr.args[2] | ||
checked_types = if haskey(options, "api") | ||
get(options["api"], "checked_rettypes", Dict()) | ||
else | ||
String[] | ||
end | ||
if rettyp isa Symbol && String(rettyp) in checked_types | ||
node.exprs[1] = Expr(:macrocall, Symbol("@checked"), nothing, expr) | ||
end | ||
end | ||
end | ||
end | ||
|
||
function main() | ||
cuda = joinpath(CUDA_SDK_jll.artifact_dir, "cuda", "include") | ||
@assert CUDA_SDK_jll.is_available() | ||
|
||
cudss = joinpath(CUDSS_jll.artifact_dir, "include") | ||
@assert CUDSS_jll.is_available() | ||
|
||
# function wrap(name, headers; targets=headers, defines=[], include_dirs=[]) | ||
# wrap("cudss", ["$cudss/cudss.h"]; include_dirs=[cuda, cudss]) | ||
args = get_default_args() | ||
append!(args, "-I$cuda", "-I$cudss") | ||
# for define in defines | ||
# if isa(define, Pair) | ||
# append!(args, ["-D", "$(first(define))=$(last(define))"]) | ||
# else | ||
# append!(args, ["-D", "$define"]) | ||
# end | ||
# end | ||
|
||
options = load_options(joinpath(@__DIR__, "$(name).toml")) | ||
|
||
# create context | ||
headers = ["$cudss/cudss.h"] | ||
targets = headers | ||
ctx = create_context(headers, args, options) | ||
|
||
insert!(ctx.passes, 2, AvoidDuplicates()) | ||
|
||
# run generator | ||
build!(ctx, BUILDSTAGE_NO_PRINTING) | ||
|
||
# Only keep the wrapped headers | ||
replace!(get_nodes(ctx.dag)) do node | ||
path = normpath(Clang.get_filename(node.cursor)) | ||
should_wrap = any(targets) do target | ||
occursin(target, path) | ||
end | ||
if !should_wrap | ||
return ExprNode(node.id, Generators.Skip(), node.cursor, Expr[], node.adj) | ||
end | ||
return node | ||
end | ||
|
||
rewriter!(ctx, options) | ||
build!(ctx, BUILDSTAGE_PRINTING_ONLY) | ||
output_file = options["general"]["output_file_path"] | ||
format_file(output_file, YASStyle()) | ||
|
||
return | ||
end | ||
|
||
if abspath(PROGRAM_FILE) == @__FILE__ | ||
main() | ||
end |