Skip to content

Commit

Permalink
Add a gen folder
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Jul 4, 2024
1 parent c05f031 commit 2f2d6c0
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 0 deletions.
10 changes: 10 additions & 0 deletions gen/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[deps]
CUDA_SDK_jll = "6cbf2f2e-7e60-5632-ac76-dca2274e0be0"
CUDSS_jll = "4889d778-9329-5762-9fec-0578a5d30366"
Clang = "40e3b903-d033-50b4-a0cc-940c62c95e31"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"

[compat]
CUDA_SDK_jll = "12.5.1"
CUDSS_jll = "0.3.0"
julia = "1.6"
14 changes: 14 additions & 0 deletions gen/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Wrapping headers

This directory contains a script `wrapper.jl` that can be used to automatically
generate wrappers from C headers of NVIDIA cuDSS. This is done using Clang.jl.

In CUDSS.jl, the wrappers need to know whether pointers passed into the
library point to CPU or GPU memory (i.e. `Ptr` or `CuPtr`). This information is
not available from the headers, and instead should be provided by the developer.
The specific information is embedded in the TOML file `cudss.toml`.

# Usage

Either run `julia wrapper.jl` directly, or include it and call the `main()` function.
Be sure to activate the project environment in this folder (`julia --project`), which will install `Clang.jl` and `JuliaFormatter.jl`.
32 changes: 32 additions & 0 deletions gen/cudss.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
[general]
library_name = "libcudss"
output_file_path = "../src/libcudss2.jl"

output_ignorelist = [
# generates bad code
"CUSPARSE_CPP_VERSION",
"CUSPARSE_DEPRECATED_REPLACE_WITH",
"CUSPARSE_DEPRECATED_ENUM_REPLACE_WITH",
# these change often
"CUSPARSE_VERSION",
"CUSPARSE_VER_.*",
]

[codegen]
use_ccall_macro = true
always_NUL_terminated_string = true

[api]
checked_rettypes = [ "cudssStatus_t" ]

[api.cusparseGetVersion]
needs_context = false

[api.cusparseGetProperty]
needs_context = false

[api.cusparseGetErrorName]
needs_context = false

[api.cusparseGetErrorString]
needs_context = false
237 changes: 237 additions & 0 deletions gen/wrapper.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
using Clang
using Clang.Generators

using JuliaFormatter

using CUDA_SDK_jll, CUDSS_jll

# a pass that removes macro definitions that are also function definitions.
#
# this sometimes happens with NVIDIA's headers, either because of typos, or because they are
# reserving identifiers for future use:
# #define cuStreamGetCaptureInfo_v2 __CUDA_API_PTSZ(cuStreamGetCaptureInfo_v2)
mutable struct AvoidDuplicates <: Clang.AbstractPass end
function (x::AvoidDuplicates)(dag::ExprDAG, options::Dict)
# collect macro definitions
macro_definitions = Dict()
for (i, node) in enumerate(dag.nodes)
if node isa ExprNode{<:AbstractMacroNodeType}
macro_definitions[node.id] = (i, node)
end
end

# scan function definitions
for (i, node) in enumerate(dag.nodes)
if Generators.is_function(node) && !Generators.is_variadic_function(node)
if haskey(macro_definitions, node.id)
@info "Removing macro definition for $(node.id)"
j, duplicate_node = macro_definitions[node.id]
dag.nodes[j] = ExprNode(node.id, Clang.Generators.Skip(), duplicate_node.cursor, duplicate_node.exprs, duplicate_node.adj)
end
end
end

return dag
end

function rewriter!(ctx, options)
for node in get_nodes(ctx.dag)
# remove aliases for function names
#
# when NVIDIA changes the behavior of an API, they version the function
# (`cuFunction_v2`), and sometimes even change function names. To maintain backwards
# compatibility, they ship aliases with their headers such that compiled binaries
# will keep using the old version, and newly-compiled ones will use the developer's
# CUDA version. remove those, since we target multiple CUDA versions.
#
# remove this if we ever decide to support a single supported version of CUDA.
if node isa ExprNode{<:AbstractMacroNodeType}
isempty(node.exprs) && continue
expr = node.exprs[1]
if Meta.isexpr(expr, :const)
expr = expr.args[1]
end
if Meta.isexpr(expr, :(=))
lhs, rhs = expr.args
if rhs isa Expr && rhs.head == :call
name = string(rhs.args[1])
if endswith(name, "STRUCT_SIZE")
rhs.head = :macrocall
rhs.args[1] = Symbol("@", name)
insert!(rhs.args, 2, nothing)
end
end
isa(lhs, Symbol) || continue
if Meta.isexpr(rhs, :call) && rhs.args[1] in (:__CUDA_API_PTDS, :__CUDA_API_PTSZ)
rhs = rhs.args[2]
end
isa(rhs, Symbol) || continue
lhs, rhs = String(lhs), String(rhs)
function get_prefix(str)
# cuFooBar -> cu
isempty(str) && return nothing
islowercase(str[1]) || return nothing
for i in 2:length(str)
if isuppercase(str[i])
return str[1:i-1]
end
end
return nothing
end
lhs_prefix = get_prefix(lhs)
lhs_prefix === nothing && continue
rhs_prefix = get_prefix(rhs)
if lhs_prefix == rhs_prefix
@debug "Removing function alias: `$expr`"
empty!(node.exprs)
end
end
end

if Generators.is_function(node) && !Generators.is_variadic_function(node)
expr = node.exprs[1]
call_expr = expr.args[2].args[1].args[3] # assumes `use_ccall_macro` is true

# replace `@ccall` with `@gcsafe_ccall`
expr.args[2].args[1].args[1] = Symbol("@gcsafe_ccall")

target_expr = call_expr.args[1].args[1]
fn = String(target_expr.args[2].value)

# look up API options for this function
fn_options = Dict{String,Any}()
templates = Dict{String,Any}()
template_types = nothing
if haskey(options, "api")
names = [fn]

# _64 aliases are used by CUBLAS with Int64 arguments. they otherwise have
# an idential signature, so we can reuse the same type rewrites.
if endswith(fn, "_64")
push!(names, fn[1:end-3])
end

# look for a template rewrite: many libraries have very similar functions,
# e.g., `cublas[SDHCZ]gemm`, for which we can use the same type rewrites
# registered as `cublas𝕏gemm` template with `T` and `S` placeholders.
for name in copy(names), (typcode,(T,S)) in ["S"=>("Cfloat","Cfloat"),
"D"=>("Cdouble","Cdouble"),
"H"=>("Float16","Float16"),
"C"=>("cuComplex","Cfloat"),
"Z"=>("cuDoubleComplex","Cdouble")]
idx = findfirst(typcode, name)
while idx !== nothing
template_name = name[1:idx.start-1] * "𝕏" * name[idx.stop+1:end]
if haskey(options["api"], template_name)
templates[template_name] = ["T" => T, "S" => S]
push!(names, template_name)
end
idx = findnext(typcode, name, idx.stop+1)
end
end

# the exact name is always checked first, so it's always possible to
# override the type rewrites for a specific function
# (e.g. if a _64 function ever passes a `Ptr{Cint}` index).
for name in names
template_types = get(templates, name, nothing)
if haskey(options["api"], name)
fn_options = options["api"][name]
break
end
end
end

# rewrite pointer argument types
arg_exprs = call_expr.args[1].args[2:end]
argtypes = get(fn_options, "argtypes", Dict())
for (arg, typ) in argtypes
i = parse(Int, arg)
i in 1:length(arg_exprs) || error("invalid argtypes for $fn: index $arg is out of bounds")

# _64 aliases should use Int64 instead of Int32/Cint
if endswith(fn, "_64")
typ = replace(typ, "Cint" => "Int64", "Int32" => "Int64")
end

# expand type templates
if template_types !== nothing
typ = replace(typ, template_types...)
end

arg_exprs[i].args[2] = Meta.parse(typ)
end

# insert `initialize_context()` before each function with a `ccall`
if get(fn_options, "needs_context", true)
pushfirst!(expr.args[2].args, :(initialize_context()))
end

# insert `@checked` before each function with a `ccall` returning a checked type`
rettyp = call_expr.args[2]
checked_types = if haskey(options, "api")
get(options["api"], "checked_rettypes", Dict())
else
String[]
end
if rettyp isa Symbol && String(rettyp) in checked_types
node.exprs[1] = Expr(:macrocall, Symbol("@checked"), nothing, expr)
end
end
end
end

function main()
cuda = joinpath(CUDA_SDK_jll.artifact_dir, "cuda", "include")
@assert CUDA_SDK_jll.is_available()

cudss = joinpath(CUDSS_jll.artifact_dir, "include")
@assert CUDSS_jll.is_available()

# function wrap(name, headers; targets=headers, defines=[], include_dirs=[])
# wrap("cudss", ["$cudss/cudss.h"]; include_dirs=[cuda, cudss])
args = get_default_args()
append!(args, "-I$cuda", "-I$cudss")
# for define in defines
# if isa(define, Pair)
# append!(args, ["-D", "$(first(define))=$(last(define))"])
# else
# append!(args, ["-D", "$define"])
# end
# end

options = load_options(joinpath(@__DIR__, "$(name).toml"))

# create context
headers = ["$cudss/cudss.h"]
targets = headers
ctx = create_context(headers, args, options)

insert!(ctx.passes, 2, AvoidDuplicates())

# run generator
build!(ctx, BUILDSTAGE_NO_PRINTING)

# Only keep the wrapped headers
replace!(get_nodes(ctx.dag)) do node
path = normpath(Clang.get_filename(node.cursor))
should_wrap = any(targets) do target
occursin(target, path)
end
if !should_wrap
return ExprNode(node.id, Generators.Skip(), node.cursor, Expr[], node.adj)
end
return node
end

rewriter!(ctx, options)
build!(ctx, BUILDSTAGE_PRINTING_ONLY)
output_file = options["general"]["output_file_path"]
format_file(output_file, YASStyle())

return
end

if abspath(PROGRAM_FILE) == @__FILE__
main()
end

0 comments on commit 2f2d6c0

Please sign in to comment.