Skip to content

Commit

Permalink
Add nvrtc_sm_top_level::add_link_list() (#2781)
Browse files Browse the repository at this point in the history
  • Loading branch information
rwgk authored Nov 14, 2024
1 parent 61bc210 commit d53dd89
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 29 deletions.
14 changes: 14 additions & 0 deletions c/parallel/src/nvrtc/command_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <string>
#include <string_view>
#include <tuple>
#include <vector>

#include <nvJitLink.h>
#include <nvrtc.h>
Expand Down Expand Up @@ -53,6 +54,7 @@ struct nvrtc_ltoir
const char* ltoir;
int ltsz;
};
using nvrtc_ltoir_list = std::vector<nvrtc_ltoir>;
struct nvrtc_jitlink_cleanup
{
nvrtc_cubin& cubin_ref;
Expand Down Expand Up @@ -132,6 +134,13 @@ struct nvrtc_command_list_visitor
check(nvJitLinkAddData(
jitlink.handle, NVJITLINK_INPUT_LTOIR, (const void*) lto.ltoir, (size_t) lto.ltsz, program_name.data()));
}
void execute(const nvrtc_ltoir_list& lto_list)
{
for (auto lto : lto_list)
{
execute(lto);
}
}
void execute(nvrtc_jitlink_cleanup cleanup)
{
auto jitlink_error = nvJitLinkComplete(jitlink.handle);
Expand Down Expand Up @@ -215,6 +224,11 @@ struct nvrtc_sm_top_level
{
return {nvrtc_command_list_append(std::move(cl), std::move(arg))};
}
// Add a list of linkable units (LTO-IR entries) to the whole program.
//
// `arg` is appended to the current command list `cl` through
// nvrtc_command_list_append(); both `cl` and `arg` are moved, so this object's
// `cl` is left in a moved-from state after the call.
// Returns the next builder stage, whose type pack is extended by
// nvrtc_ltoir_list.
// NOTE(review): presumably the list is consumed later by the visitor's
// execute(const nvrtc_ltoir_list&) overload -- confirm against the visitor.
nvrtc_sm_top_level<Tx..., nvrtc_ltoir_list> add_link_list(nvrtc_ltoir_list arg)
{
return {nvrtc_command_list_append(std::move(cl), std::move(arg))};
}

// Execute steps and link unit
nvrtc_cubin finalize_program(uint32_t numLtoOpts, const char** ltoOpts)
Expand Down
52 changes: 23 additions & 29 deletions c/parallel/src/reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,27 @@ extern "C" CCCL_C_API CUresult cccl_device_reduce_build(
constexpr size_t num_lto_args = 2;
const char* lopts[num_lto_args] = {"-lto", arch.c_str()};

auto cl =
// Collect all LTO-IRs to be linked.
nvrtc_ltoir_list ltoir_list;
auto ltoir_list_append = [&ltoir_list](nvrtc_ltoir lto) {
if (lto.ltsz)
{
ltoir_list.push_back(std::move(lto));
}
};
ltoir_list_append({op.ltoir, op.ltoir_size});
if (cccl_iterator_kind_t::iterator == input_it.type)
{
ltoir_list_append({input_it.advance.ltoir, input_it.advance.ltoir_size});
ltoir_list_append({input_it.dereference.ltoir, input_it.dereference.ltoir_size});
}
if (cccl_iterator_kind_t::iterator == output_it.type)
{
ltoir_list_append({output_it.advance.ltoir, output_it.advance.ltoir_size});
ltoir_list_append({output_it.dereference.ltoir, output_it.dereference.ltoir_size});
}

nvrtc_cubin result =
make_nvrtc_command_list()
.add_program(nvrtc_translation_unit{src.c_str(), name})
.add_expression({single_tile_kernel_name})
Expand All @@ -289,34 +309,8 @@ extern "C" CCCL_C_API CUresult cccl_device_reduce_build(
.get_name({single_tile_second_kernel_name, single_tile_second_kernel_lowered_name})
.get_name({reduction_kernel_name, reduction_kernel_lowered_name})
.cleanup_program()
.add_link({op.ltoir, op.ltoir_size});

nvrtc_cubin result{};

if (cccl_iterator_kind_t::iterator == input_it.type && cccl_iterator_kind_t::iterator == output_it.type)
{
result = cl.add_link({input_it.advance.ltoir, input_it.advance.ltoir_size})
.add_link({input_it.dereference.ltoir, input_it.dereference.ltoir_size})
.add_link({output_it.advance.ltoir, output_it.advance.ltoir_size})
.add_link({output_it.dereference.ltoir, output_it.dereference.ltoir_size})
.finalize_program(num_lto_args, lopts);
}
else if (cccl_iterator_kind_t::iterator == input_it.type)
{
result = cl.add_link({input_it.advance.ltoir, input_it.advance.ltoir_size})
.add_link({input_it.dereference.ltoir, input_it.dereference.ltoir_size})
.finalize_program(num_lto_args, lopts);
}
else if (cccl_iterator_kind_t::iterator == output_it.type)
{
result = cl.add_link({output_it.advance.ltoir, output_it.advance.ltoir_size})
.add_link({output_it.dereference.ltoir, output_it.dereference.ltoir_size})
.finalize_program(num_lto_args, lopts);
}
else
{
result = cl.finalize_program(num_lto_args, lopts);
}
.add_link_list(ltoir_list)
.finalize_program(num_lto_args, lopts);

cuLibraryLoadData(&build->library, result.cubin.get(), nullptr, nullptr, 0, nullptr, nullptr, 0);
check(cuLibraryGetKernel(&build->single_tile_kernel, build->library, single_tile_kernel_lowered_name.c_str()));
Expand Down

0 comments on commit d53dd89

Please sign in to comment.