Add internal wrapper for cuda driver APIs #2070

84 changes: 84 additions & 0 deletions cudax/include/cuda/experimental/__utility/driver_api.cuh
@@ -0,0 +1,84 @@
//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDAX__UTILITY_DRIVER_API
#define _CUDAX__UTILITY_DRIVER_API

#include <cuda.h>

#include <cuda/std/__exception/cuda_error.h>

// Get the driver function by name using this macro
#define CUDAX_GET_DRIVER_FUNCTION(function_name) \
  reinterpret_cast<decltype(function_name)*>(get_driver_entry_point(#function_name))

namespace cuda::experimental::detail::driver
{
inline void* get_driver_entry_point(const char* name)
{
  void* fn;
  cudaDriverEntryPointQueryResult result;
  cudaGetDriverEntryPoint(name, &fn, cudaEnableDefault, &result);
  if (result != cudaDriverEntryPointSuccess)
  {
    if (result == cudaDriverEntryPointVersionNotSufficent)
    {
      ::cuda::__throw_cuda_error(cudaErrorNotSupported, "Driver does not support this API");
    }
    else
    {
      ::cuda::__throw_cuda_error(cudaErrorUnknown, "Failed to access driver API");
    }
  }
  return fn;
}

template <typename Fn, typename... Args>
inline void call_driver_fn(Fn fn, const char* err_msg, Args... args)
{
  CUresult status = fn(args...);
  if (status != CUDA_SUCCESS)
  {
    ::cuda::__throw_cuda_error(static_cast<cudaError_t>(status), err_msg);
  }
Comment on lines +47 to +50
Collaborator

Do we want something like _CCCL_TRY_CUDA_API?

Contributor

It could also be a function.

Contributor Author

This function should be more or less equivalent to _CCCL_TRY_CUDA_API. Am I missing some key difference here? I would have no issue turning it into a macro instead if that is preferred.

Collaborator

I believe a function is "cleaner" than a macro, but the macro cannot go as we cannot depend on cudax.

Otherwise we would need to move the function into libcu++

Contributor Author

We might need two separate functions/macros because the driver API returns CUresult and the runtime returns cudaError_t.

But these have the same values, so maybe we can add a cast to _CCCL_TRY_CUDA_API and remove this function 🤔

}
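
The thread above floats the idea of sharing one checker between driver and runtime calls by casting the status code. A minimal sketch of that idea follows, written in plain C++ so that it compiles without the CUDA toolkit. Everything in it is an assumption made for illustration: `driver_status` and `runtime_status` are hypothetical stand-ins for `CUresult` and `cudaError_t`, `status_error` stands in for the exception raised by `__throw_cuda_error`, and `check_status` and `call_checked` are invented names, not existing CCCL functions.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <type_traits>

// Hypothetical stand-ins for CUresult and cudaError_t. The sketch assumes,
// as the thread does, that matching errors share the same numeric value.
enum driver_status { driver_success = 0, driver_invalid_value = 1 };
enum runtime_status { runtime_success = 0, runtime_invalid_value = 1 };

// Stand-in for the exception type that carries the error code.
class status_error : public std::runtime_error
{
public:
  status_error(runtime_status code, const char* msg)
      : std::runtime_error(msg)
      , code_(code)
  {}

  runtime_status code() const noexcept
  {
    return code_;
  }

private:
  runtime_status code_;
};

// Accepts either enum and converts it to the runtime enum before reporting.
template <typename Status>
void check_status(Status status, const char* err_msg)
{
  static_assert(std::is_enum<Status>::value, "expected a status enum");
  const auto code = static_cast<runtime_status>(static_cast<int>(status));
  if (code != runtime_success)
  {
    throw status_error(code, err_msg);
  }
}

// Calls fn(args...) and checks whatever status type it returns.
template <typename Fn, typename... Args>
void call_checked(Fn fn, const char* err_msg, Args... args)
{
  check_status(fn(args...), err_msg);
}

// Toy functions standing in for a driver and a runtime entry point.
inline driver_status fake_driver_call(int value)
{
  return value >= 0 ? driver_success : driver_invalid_value;
}

inline runtime_status fake_runtime_call(int value)
{
  return value >= 0 ? runtime_success : runtime_invalid_value;
}

// Both flavours go through the same checker.
inline void example_usage()
{
  call_checked(fake_driver_call, "driver call failed", 1);
  call_checked(fake_runtime_call, "runtime call failed", 1);
}
```

Whether the cast is acceptable in the real code depends on the two CUDA enums actually agreeing for every error that can occur, which the sketch simply assumes.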

inline void ctxPush(CUcontext ctx)
{
  static auto driver_fn = CUDAX_GET_DRIVER_FUNCTION(cuCtxPushCurrent);
  call_driver_fn(driver_fn, "Failed to push context", ctx);
}
Collaborator

why are we dynamically loading these functions instead of including <cuda.h> and linking to libcuda?

Contributor Author

Otherwise we would need to require the -lcuda link flag. Loading dynamically is more in line with the current CUDA runtime, which does not require that flag. There are compatibility reasons why the current CUDA runtime does that, and we probably want the same thing.

Contributor

Directly linking to libcuda.so means that any consuming library would only run on machines with the CUDA driver installed. As a result, any application with runtime logic that dispatches to CUDA or CPU based on hardware support would fail to load when launched on a machine without the CUDA driver.

From a build engineering standpoint, linking to libcuda.so should never happen.

Collaborator

cool thanks. i knew there must be a reason. TIL
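
The point made in this thread, that a binary should still load and run on a machine without the driver, can be illustrated with a toy model. The sketch below is plain C++ and does not touch CUDA. It rests on stated assumptions: `resolve_symbol` is a hypothetical stand-in for a by-name lookup such as `cudaGetDriverEntryPoint`, the `registry` map plays the role of an installed driver, and every function name is invented for illustration.

```cpp
#include <cassert>
#include <map>
#include <string>

using add_fn = int (*)(int, int);

// Stand-in for a function that only exists when the "driver" is installed.
inline int accelerated_add(int a, int b)
{
  return a + b;
}

// Fallback that is always linked into the program.
inline int cpu_add(int a, int b)
{
  return a + b;
}

// Models the symbols exported by an installed driver. An empty map models
// a machine without the driver.
inline std::map<std::string, add_fn>& registry()
{
  static std::map<std::string, add_fn> symbols;
  return symbols;
}

// Hypothetical by-name lookup: returns nullptr instead of failing at load time.
inline add_fn resolve_symbol(const std::string& name)
{
  const auto it = registry().find(name);
  return it == registry().end() ? nullptr : it->second;
}

enum class backend
{
  cpu,
  accelerated
};

struct add_result
{
  int value;
  backend used;
};

// Dispatches at run time based on whether the symbol could be resolved.
inline add_result dispatch_add(int a, int b)
{
  if (add_fn fn = resolve_symbol("accelerated_add"))
  {
    return {fn(a, b), backend::accelerated};
  }
  return {cpu_add(a, b), backend::cpu};
}

// With an empty registry the program still works and uses the CPU path.
inline void example_usage()
{
  registry().clear();
  assert(dispatch_add(1, 2).used == backend::cpu);
}
```

The code in this PR additionally caches each resolved pointer in a function-local `static`; the sketch leaves that out so that both paths stay observable in one process.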


inline void ctxPop()
{
  static auto driver_fn = CUDAX_GET_DRIVER_FUNCTION(cuCtxPopCurrent);
  CUcontext dummy;
  call_driver_fn(driver_fn, "Failed to pop context", &dummy);
}

inline CUcontext ctxGetCurrent()
{
  static auto driver_fn = CUDAX_GET_DRIVER_FUNCTION(cuCtxGetCurrent);
  CUcontext result;
  call_driver_fn(driver_fn, "Failed to get current context", &result);
  return result;
}

inline CUcontext streamGetCtx(CUstream stream)
{
  static auto driver_fn = CUDAX_GET_DRIVER_FUNCTION(cuStreamGetCtx);
  CUcontext result;
  call_driver_fn(driver_fn, "Failed to get context from a stream", stream, &result);
  return result;
}
} // namespace cuda::experimental::detail::driver

#undef CUDAX_GET_DRIVER_FUNCTION
#endif
4 changes: 4 additions & 0 deletions cudax/test/CMakeLists.txt
@@ -57,4 +57,8 @@ foreach(cn_target IN LISTS cudax_TARGETS)
    launch/configuration.cu
  )
  target_compile_options(${test_target} PRIVATE $<$<COMPILE_LANG_AND_ID:CUDA,NVIDIA>:--extended-lambda>)

  Cudax_add_catch2_test(test_target misc_tests ${cn_target}
Collaborator

Suggested change
- Cudax_add_catch2_test(test_target misc_tests ${cn_target}
+ cudax_add_catch2_test(test_target misc_tests ${cn_target}

    utility/driver_api.cu
  )
endforeach()
44 changes: 44 additions & 0 deletions cudax/test/utility/driver_api.cu
@@ -0,0 +1,44 @@
//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA C++ Core Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//
#define LIBCUDACXX_ENABLE_EXCEPTIONS

#include <cuda/experimental/__utility/driver_api.cuh>

#include "../hierarchy/testing_common.cuh"

TEST_CASE("Call each one", "[driver api]")
{
  cudaStream_t stream;
  // Assumes the ctx stack was empty or had one ctx, should be the case unless some other
  // test leaves 2+ ctxs on the stack

  // Pushes the primary context if the stack is empty
  CUDART(cudaStreamCreate(&stream));

  auto ctx = cuda::experimental::detail::driver::ctxGetCurrent();
  CUDAX_REQUIRE(ctx != nullptr);

  cuda::experimental::detail::driver::ctxPop();
  CUDAX_REQUIRE(cuda::experimental::detail::driver::ctxGetCurrent() == nullptr);

  cuda::experimental::detail::driver::ctxPush(ctx);
  CUDAX_REQUIRE(cuda::experimental::detail::driver::ctxGetCurrent() == ctx);

  cuda::experimental::detail::driver::ctxPush(ctx);
  CUDAX_REQUIRE(cuda::experimental::detail::driver::ctxGetCurrent() == ctx);

  cuda::experimental::detail::driver::ctxPop();
  CUDAX_REQUIRE(cuda::experimental::detail::driver::ctxGetCurrent() == ctx);

  auto stream_ctx = cuda::experimental::detail::driver::streamGetCtx(stream);
  CUDAX_REQUIRE(ctx == stream_ctx);

  CUDART(cudaStreamDestroy(stream));
}
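
The sequence of expectations in this test can be followed more easily with a simplified model of a per-thread context stack. The sketch below is plain C++ and only a model: `context_stack` and the `context` alias are hypothetical names, and the behavior it encodes (the current context is the top of the stack, or null when the stack is empty) is an assumption that mirrors what the test asserts rather than a description of the driver's internals.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Opaque handle standing in for CUcontext in this model.
using context = const void*;

class context_stack
{
public:
  // Models ctxPush: the pushed context becomes current.
  void push(context ctx)
  {
    stack_.push_back(ctx);
  }

  // Models ctxPop: removes the current context and fails on an empty stack.
  context pop()
  {
    if (stack_.empty())
    {
      throw std::runtime_error("Failed to pop context");
    }
    context top = stack_.back();
    stack_.pop_back();
    return top;
  }

  // Models ctxGetCurrent: top of the stack, or nullptr when empty.
  context current() const
  {
    return stack_.empty() ? nullptr : stack_.back();
  }

private:
  std::vector<context> stack_;
};

// Replays the test's steps, starting from a stack that holds one context.
inline void example_usage()
{
  int primary = 0;
  context ctx = &primary;
  context_stack stack;
  stack.push(ctx);
  assert(stack.current() == ctx);
  stack.pop();
  assert(stack.current() == nullptr);
}
```

In this model the test's comment about the initial stack state becomes concrete: if two or more contexts were already on the stack, the first pop would leave a non-null current context and the second check would fail.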