Skip to content

Commit

Permalink
Fixes #669: Can now obtain the number of kernels in a module and a co…
Browse files Browse the repository at this point in the history
…ntainer of `kernel_t` wrappers for them
  • Loading branch information
eyalroz committed Sep 6, 2024
1 parent 0f445ac commit b2e023c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/cuda/api/detail/unique_span.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,17 @@ unique_span<T> make_unique_span(size_t size)
return unique_span<T>{ new T[size], size };
}

template <typename T, typename Generator>
unique_span <T> generate_unique_span(size_t size, Generator generator_by_index) noexcept
{
// Q: Do I need to check the alignment here? Perhaps allocate more to ensure alignment?
auto result_data = static_cast<T *>(::operator new(sizeof(T) * size));
for (size_t i = 0; i < size; i++) {
new(&result_data[i]) T(generator_by_index(i));
}
return unique_span<T>{result_data, size};
}

} // namespace cuda

#endif // CUDA_API_WRAPPERS_UNIQUE_SPAN_HPP_
21 changes: 21 additions & 0 deletions src/cuda/api/module.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,27 @@ class module_t {
return { memory::as_pointer(dptr), size };
}

#if CUDA_VERSION >= 12040
size_t get_num_kernels() const
{
unsigned result;
auto status = cuModuleGetFunctionCount(&result, handle_);
throw_if_error_lazy(status, "Failed determining function count for " + module::detail_::identify(*this));
return result;
}

unique_span<kernel_t> get_kernels() const
{
auto num_kernels = get_num_kernels();
if (num_kernels == 0) { return {}; }
auto handles = make_unique_span<kernel::handle_t>(num_kernels);
auto status = cuModuleEnumerateFunctions(handles.data(), num_kernels, handle_);
throw_if_error_lazy(status, "Failed enumerating the kernels in " + module::detail_::identify(*this));
auto gen = [&](size_t i) { return kernel::wrap(device_id_, context_handle_, handles[i]); };
return generate_unique_span<kernel_t>(handles.size(), gen);
}
#endif // CUDA_VERSION >= 12040

// TODO: Implement a surface reference and texture reference class rather than these raw pointers.

#if CUDA_VERSION < 12000
Expand Down

0 comments on commit b2e023c

Please sign in to comment.