Skip to content

Commit

Permalink
update all places that use parse_decl_from_source
Browse files Browse the repository at this point in the history
  • Loading branch information
isVoid committed Feb 7, 2025
1 parent ee172f9 commit 33dd101
Show file tree
Hide file tree
Showing 13 changed files with 51 additions and 25 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,15 @@ import numpy as np
source = os.path.join(os.path.dirname(__file__), "demo.cuh")
# Assume your machine has a GPU that supports "sm_80" compute capability,
# parse the header with sm_80 compute capability.
structs, functions, *_ = parse_declarations_from_source(source, [source], "sm_80")
decls = parse_declarations_from_source(source, [source], "sm_80")
shim_writer = MemoryShimWriter(f'#include "{source}"')
# Make Numba bindings from the declarations.
# New type "myfloat16" is a Number type, data model is PrimitiveModel.
myfloat16 = bind_cxx_struct(shim_writer, structs[0], types.Number, PrimitiveModel)
bind_cxx_function(shim_writer, functions[0])
hsqrt = bind_cxx_function(shim_writer, functions[1])
myfloat16 = bind_cxx_struct(shim_writer, decls.structs[0], types.Number, PrimitiveModel)
bind_cxx_function(shim_writer, decls.functions[0])
hsqrt = bind_cxx_function(shim_writer, decls.functions[1])
```

`myfloat16` struct can now be used within Numba:
Expand Down
6 changes: 2 additions & 4 deletions numbast/benchmarks/test_rtc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,9 @@ def simulate_header(request):
tmp.flush()

major, minor = cuda.get_current_device().compute_capability
_, functions, *_ = parse_declarations_from_source(
tmp.name, [tmp.name], f"sm_{major}{minor}"
)
decls = parse_declarations_from_source(tmp.name, [tmp.name], f"sm_{major}{minor}")
shim_writer = MemoryShimWriter(f'#include "{tmp.name}"')
adds = bind_cxx_functions(shim_writer, functions)
adds = bind_cxx_functions(shim_writer, decls.functions)

yield adds, shim_writer

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def cuda_enum(data_folder):

header = data_folder("enum.cuh")

*_, enums = parse_declarations_from_source(header, [header], "sm_50")
decls = parse_declarations_from_source(header, [header], "sm_50")
enums = decls.enums

assert len(enums) == 2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def cuda_function(data_folder):

header = data_folder("function.cuh")

_, functions, *_ = parse_declarations_from_source(header, [header], "sm_50")
decls = parse_declarations_from_source(header, [header], "sm_50")
functions = decls.functions

assert len(functions) == 3

Expand Down
4 changes: 3 additions & 1 deletion numbast/src/numbast/static/tests/test_operator_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def cuda_decls(data_folder):

specs = {"Foo": (Type, StructModel, header)}

structs, functions, *_ = parse_declarations_from_source(header, [header], "sm_50")
decls = parse_declarations_from_source(header, [header], "sm_50")
structs = decls.structs
functions = decls.functions

assert len(structs) == 1

Expand Down
3 changes: 2 additions & 1 deletion numbast/src/numbast/static/tests/test_static_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def cuda_struct(data_folder):

specs = {"__myfloat16": (Number, PrimitiveModel, header)}

structs, *_ = parse_declarations_from_source(header, [header], "sm_50")
decls = parse_declarations_from_source(header, [header], "sm_50")
structs = decls.structs

assert len(structs) == 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def cuda_struct(data_folder):
"MyInt": (Number, PrimitiveModel, header),
}

structs, *_ = parse_declarations_from_source(header, [header], "sm_50")
decls = parse_declarations_from_source(header, [header], "sm_50")
structs = decls.structs

assert len(structs) == 3

Expand Down
16 changes: 9 additions & 7 deletions numbast/src/numbast/tools/static_binding_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,16 @@ def _static_binding_generator(

# TODO: we don't have tests on different compute capabilities for the static binding generator yet.
# This will be added in future PRs.
structs, functions, function_templates, class_templates, typedefs, enums = (
parse_declarations_from_source(
entry_point,
retain_list,
compute_capability=compute_capability,
cudatoolkit_include_dir=CUDA_INCLUDE_PATH,
)
decls = parse_declarations_from_source(
entry_point,
retain_list,
compute_capability=compute_capability,
cudatoolkit_include_dir=CUDA_INCLUDE_PATH,
)
structs = decls.structs
functions = decls.functions
enums = decls.enums
typedefs = decls.typedefs

if log_generates:
log_files_to_generate(functions, structs, enums, typedefs)
Expand Down
3 changes: 2 additions & 1 deletion numbast/tests/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
source = os.path.join(os.path.dirname(__file__), "demo.cuh")
# Assume your machine has a GPU that supports "sm_80" compute capability,
# parse the header with sm_80 compute capability.
structs, functions, *_ = parse_declarations_from_source(source, [source], "sm_80")
decls = parse_declarations_from_source(source, [source], "sm_80")
structs, functions = decls.structs, decls.functions

shim_writer = MemoryShimWriter(f'#include "{source}"')

Expand Down
3 changes: 2 additions & 1 deletion numbast/tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def test_struct_binding_has_correct_LLVM_type():
# This test checks if the bindings of type Foo correctly lowered into
# LLVM type { i32, i32, i32 }.
p = os.path.join(DATA_FOLDER, "sample_struct.cuh")
structs, *_ = parse_declarations_from_source(p, [p], "sm_80")
decls = parse_declarations_from_source(p, [p], "sm_80")
structs = decls.structs
shim_writer = MemoryShimWriter(f"#include {p}")
s = bind_cxx_struct(shim_writer, structs[0], types.Type, StructModel)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,18 @@
curand_files = [h for h in curand_files if os.path.exists(h)]


structs, functions, _, _, typedefs, enums = parse_declarations_from_source(
decls = parse_declarations_from_source(
curand_kernel_h,
curand_files,
f"sm_{COMPUTE_CAPABILITY[0]}{COMPUTE_CAPABILITY[1]}",
cudatoolkit_include_dir=CUDA_INCLUDE_PATH,
)
structs, functions, typedefs, enums = (
decls.structs,
decls.functions,
decls.typedefs,
decls.enums,
)


TYPE_SPECIALIZATION = {
Expand Down
8 changes: 7 additions & 1 deletion numbast_extensions/fp16/src/fp16/fp16_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,18 @@

cuda_bf16 = numbast_extensions.bf16.bf16_bindings.cuda_bf16

structs, functions, _, _, typedefs, enums = parse_declarations_from_source(
decls = parse_declarations_from_source(
cuda_fp16,
[cuda_fp16, cuda_fp16_hpp],
f"sm_{COMPUTE_CAPABILITY[0]}{COMPUTE_CAPABILITY[1]}",
cudatoolkit_include_dir=CUDA_INCLUDE_PATH,
)
structs, functions, typedefs, enums = (
decls.structs,
decls.functions,
decls.typedefs,
decls.enums,
)

TYPE_SPECIALIZATION = {
"__half_raw": types.Number,
Expand Down
8 changes: 7 additions & 1 deletion numbast_extensions/numbast_extensions/bf16/bf16_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,18 @@
cuda_bf16_hpp = os.path.join(CUDA_INCLUDE_PATH, "cuda_bf16.hpp")


structs, functions, _, _, typedefs, _ = parse_declarations_from_source(
decls = parse_declarations_from_source(
cuda_bf16,
[cuda_bf16, cuda_bf16_hpp],
f"sm_{COMPUTE_CAPABILITY[0]}{COMPUTE_CAPABILITY[1]}",
cudatoolkit_include_dir=CUDA_INCLUDE_PATH,
)
structs, functions, typedefs, enums = (
decls.structs,
decls.functions,
decls.typedefs,
decls.enums,
)

TYPE_SPECIALIZATION = {
"__nv_bfloat16_raw": types.Number,
Expand Down

0 comments on commit 33dd101

Please sign in to comment.