diff --git a/xla/backends/cpu/collectives/in_process_communicator.cc b/xla/backends/cpu/collectives/in_process_communicator.cc index b5ab1396e3847..2d4dc88f9ef27 100644 --- a/xla/backends/cpu/collectives/in_process_communicator.cc +++ b/xla/backends/cpu/collectives/in_process_communicator.cc @@ -344,7 +344,7 @@ absl::Status InProcessCommunicator::AllReduce(se::DeviceMemoryBase send_buffer, std::string name = absl::StrCat("all reduce ", key.ToString()); AllReduceParticipant partiticipant{rank_, send_buffer, recv_buffer}; - return RendezvousSingle( + return Rendezvous( name, key, partiticipant, key.num_local_participants, std::bind(AllReduceOp, dtype, count, reduction_kind, std::placeholders::_1)); @@ -362,7 +362,7 @@ absl::Status InProcessCommunicator::CollectivePermute( recv_buffer}; size_t num_bytes = count * primitive_util::ByteWidth(dtype); - return RendezvousSingle( + return Rendezvous( name, key, partiticipant, key.num_local_participants, std::bind(CollectivePermuteOp, num_bytes, std::placeholders::_1)); } @@ -380,7 +380,7 @@ absl::Status InProcessCommunicator::AllToAll( {recv_buffers.begin(), recv_buffers.end()}}; size_t num_bytes = count * primitive_util::ByteWidth(dtype); - return RendezvousSingle( + return Rendezvous( name, key, partiticipant, key.num_local_participants, std::bind(AllToAllOp, num_bytes, std::placeholders::_1)); } @@ -396,7 +396,7 @@ absl::Status InProcessCommunicator::AllGather(se::DeviceMemoryBase send_buffer, AllGatherParticipant partiticipant{rank_, send_buffer, recv_buffer}; size_t num_bytes = count * primitive_util::ByteWidth(dtype); - return RendezvousSingle( + return Rendezvous( name, key, partiticipant, key.num_local_participants, std::bind(AllGatherOp, num_bytes, std::placeholders::_1)); } @@ -411,7 +411,7 @@ absl::Status InProcessCommunicator::ReduceScatter( std::string name = absl::StrCat("reduce scatter ", key.ToString()); ReduceScatterParticipant partiticipant{rank_, send_buffer, recv_buffer}; - return RendezvousSingle( + return Rendezvous( name, key, partiticipant, key.num_local_participants, std::bind(ReduceScatterOp, dtype, count, reduction_kind, std::placeholders::_1)); diff --git a/xla/backends/gpu/collectives/gpu_cliques.cc b/xla/backends/gpu/collectives/gpu_cliques.cc index a06cb864d556f..77398835588d8 100644 --- a/xla/backends/gpu/collectives/gpu_cliques.cc +++ b/xla/backends/gpu/collectives/gpu_cliques.cc @@ -292,7 +292,7 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device, // processes are not able to synchronize device activity. RendezvousArg rendezvous_arg = std::make_pair(device_rank, synchronized); - return RendezvousSingle>( + return Rendezvous>( initialization_rendezvous_name, rendezvous_key, rendezvous_arg, num_local_participants, initialize, WarnStuckTimeout(), TerminateTimeout()); @@ -431,7 +431,7 @@ InitializeGpuClique(GpuCollectives* collectives, se::StreamExecutor* device, rank.value(), clique_key.ToString(), run_id.ToInt(), parent_clique_key.ToString()); - return RendezvousSingle>( + return Rendezvous>( initialization_rendezvous_name, rendezvous_key, rank_pair, num_local_participants, split, WarnStuckTimeout(), TerminateTimeout()); } @@ -466,7 +466,7 @@ absl::StatusOr> AcquireGpuClique( TF_ASSIGN_OR_RETURN( std::shared_ptr clique, - RendezvousSingle>( + Rendezvous>( rendezvous_name, rendezvous_key, num_local_participants, [&] { tsl::profiler::TraceMe trace("LockGpuClique"); diff --git a/xla/debug_options_flags.cc b/xla/debug_options_flags.cc index b78b87c8a15df..b4ea4e0a4a04b 100644 --- a/xla/debug_options_flags.cc +++ b/xla/debug_options_flags.cc @@ -2124,13 +2124,13 @@ void MakeDebugOptionsFlags(std::vector* flag_list, int32_setter_for( &DebugOptions::set_xla_gpu_executable_warn_stuck_timeout_seconds), debug_options->xla_gpu_executable_warn_stuck_timeout_seconds(), - "Set timeout for RendezvousSingle stuck warning")); + "Set timeout for Rendezvous stuck warning")); flag_list->push_back(tsl::Flag( "xla_gpu_executable_terminate_timeout", int32_setter_for( &DebugOptions::set_xla_gpu_executable_terminate_timeout_seconds), debug_options->xla_gpu_executable_terminate_timeout_seconds(), - "Set timeout for RendezvousSingle termination")); + "Set timeout for Rendezvous termination")); flag_list->push_back(tsl::Flag( "xla_gpu_experimental_disable_binary_libraries", bool_setter_for( diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc index 9e0a5bebcc4a2..f656e9691a0d3 100644 --- a/xla/service/gpu/gpu_executable.cc +++ b/xla/service/gpu/gpu_executable.cc @@ -626,7 +626,7 @@ absl::Status RendezvousAfterInitialization( run_options->device_ordinal(), run_options->run_options().run_id().ToInt()); - RendezvousSingle( + Rendezvous( rendezvous_name, rendezvous_key, num_local_participants, absl::Seconds( debug_options diff --git a/xla/service/gpu/runtime/nccl_collective_thunk.cc b/xla/service/gpu/runtime/nccl_collective_thunk.cc index 1c839dcb18c9b..47211e6f437c9 100644 --- a/xla/service/gpu/runtime/nccl_collective_thunk.cc +++ b/xla/service/gpu/runtime/nccl_collective_thunk.cc @@ -479,10 +479,10 @@ absl::Status NcclCollectiveThunk::ExecuteOnStream(const ExecuteParams& params) { "first call to collective operation %d; run_id=%d", config().op_id, params.collective_params->run_id.ToInt()); - RendezvousSingle(first_call_rendezvous_flag_, rendezvous_name, - rendezvous_key, num_local_participants, - /*warn_stuck_timeout=*/absl::Seconds(20), - /*terminate_timeout=*/absl::Seconds(40)); + Rendezvous(first_call_rendezvous_flag_, rendezvous_name, rendezvous_key, + num_local_participants, + /*warn_stuck_timeout=*/absl::Seconds(20), + /*terminate_timeout=*/absl::Seconds(40)); } return absl::OkStatus(); diff --git a/xla/service/gpu/runtime/nccl_collective_thunk.h b/xla/service/gpu/runtime/nccl_collective_thunk.h index acdb18d68a3fc..5b5ba1fcf2699 100644 --- a/xla/service/gpu/runtime/nccl_collective_thunk.h +++ b/xla/service/gpu/runtime/nccl_collective_thunk.h @@ -210,7 +210,7 @@ class NcclCollectiveThunk : public Thunk { // // TODO(ezhulenev): Try to move this flag to NCCL clique as we need to make // sure that all NCCL resources are allocated just once. - RendezvousSingleFlag first_call_rendezvous_flag_; + RendezvousFlag first_call_rendezvous_flag_; }; //===----------------------------------------------------------------------===// diff --git a/xla/service/rendezvous.cc b/xla/service/rendezvous.cc index a22c5537a4d45..e9c88dbca2e5a 100644 --- a/xla/service/rendezvous.cc +++ b/xla/service/rendezvous.cc @@ -137,13 +137,12 @@ inline constexpr int32_t kPending = 0; inline constexpr int32_t kCompleted = std::numeric_limits::max(); } // namespace -RendezvousSingleFlag::RendezvousSingleFlag() : state_(kPending) {} +RendezvousFlag::RendezvousFlag() : state_(kPending) {} -RendezvousSingleFlag::InFlightRendezvous::InFlightRendezvous( - RendezvousSingleFlag* flag) +RendezvousFlag::InFlightRendezvous::InFlightRendezvous(RendezvousFlag* flag) : flag_(flag) {} -RendezvousSingleFlag::InFlightRendezvous::~InFlightRendezvous() { +RendezvousFlag::InFlightRendezvous::~InFlightRendezvous() { if (flag_ == nullptr) return; // Reload state and use CAS to decide if we are the one who @@ -162,11 +161,11 @@ RendezvousSingleFlag::InFlightRendezvous::~InFlightRendezvous() { } } -RendezvousSingleFlag::InFlightRendezvous::operator bool() const { +RendezvousFlag::InFlightRendezvous::operator bool() const { return flag_ != nullptr; } -RendezvousSingleFlag::InFlightRendezvous RendezvousSingleFlag::TryJoin() { +RendezvousFlag::InFlightRendezvous RendezvousFlag::TryJoin() { // If `state_` is `kCompleted` it means that we have at least one completed // rendezvous for this flag and can skip it. if (state_.load() == kCompleted) return InFlightRendezvous(nullptr); @@ -184,8 +183,6 @@ RendezvousSingleFlag::InFlightRendezvous RendezvousSingleFlag::TryJoin() { return InFlightRendezvous(this); } -bool RendezvousSingleFlag::IsCompleted() const { - return state_.load() == kCompleted; -} +bool RendezvousFlag::IsCompleted() const { return state_.load() == kCompleted; } } // namespace xla diff --git a/xla/service/rendezvous.h b/xla/service/rendezvous.h index b19776601e894..ffd4c43100372 100644 --- a/xla/service/rendezvous.h +++ b/xla/service/rendezvous.h @@ -85,14 +85,14 @@ using RendezvousResultType = typename RendezvousResult::Type; // all threads receive the result. Rendezvous must have a human readable name to // make easy to debug stuck and timed out attempts. template -RendezvousResultType RendezvousSingle( +RendezvousResultType Rendezvous( absl::string_view name, const K& key, const V& value, size_t num_threads, Fn fn, absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), absl::Duration terminate_timeout = absl::InfiniteDuration()); // A rendezvous for a group of threads that do not have any value arguments. template -RendezvousResultType RendezvousSingle( +RendezvousResultType Rendezvous( absl::string_view name, const K& key, size_t num_threads, Fn fn, absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), absl::Duration terminate_timeout = absl::InfiniteDuration()); @@ -100,14 +100,13 @@ RendezvousResultType RendezvousSingle( // A rendezvous for a group of threads that do not have any computation to run // and simply acts as a barrier for a group of thread. template -void RendezvousSingle( - absl::string_view name, const K& key, size_t num_threads, - absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), - absl::Duration terminate_timeout = absl::InfiniteDuration()); +void Rendezvous(absl::string_view name, const K& key, size_t num_threads, + absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), + absl::Duration terminate_timeout = absl::InfiniteDuration()); -// An `std::once_flag`-like primitive for executing RendezvousSingle operations. +// An `std::once_flag`-like primitive for executing Rendezvous operations. // -// RendezvousSingleFlag guarantees that all or none participants in a rendezvous +// RendezvousFlag guarantees that all or none participants in a rendezvous // join the rendezvous process and once rendezvous is completed flag marked as // `completed` and all further rendezvous using this flag will be skipped. It // has a weaker than exactly-once guarantee and multiple racing rendezvous can @@ -119,17 +118,17 @@ void RendezvousSingle( // and prefer simpler implementation with weaker guarantees. // // See: https://en.cppreference.com/w/cpp/thread/once_flag -class RendezvousSingleFlag { +class RendezvousFlag { public: - RendezvousSingleFlag(); + RendezvousFlag(); - RendezvousSingleFlag(const RendezvousSingleFlag&) = delete; - RendezvousSingleFlag& operator=(const RendezvousSingleFlag&) = delete; + RendezvousFlag(const RendezvousFlag&) = delete; + RendezvousFlag& operator=(const RendezvousFlag&) = delete; // RAII wrapper to exit from in-flight rendezvous when destructed. class InFlightRendezvous { public: - explicit InFlightRendezvous(RendezvousSingleFlag* flag); + explicit InFlightRendezvous(RendezvousFlag* flag); ~InFlightRendezvous(); InFlightRendezvous(const InFlightRendezvous&) = delete; @@ -138,7 +137,7 @@ class RendezvousSingleFlag { operator bool() const; // NOLINT private: - RendezvousSingleFlag* flag_; + RendezvousFlag* flag_; }; // Returns InFlightRendezvous convertible to `true` if the caller should join @@ -159,8 +158,8 @@ class RendezvousSingleFlag { // rendezvous. If rendezvous will not be executed it will return empty shared // pointer result. template -RendezvousResultType RendezvousSingle( - RendezvousSingleFlag& flag, absl::string_view name, const K& key, +RendezvousResultType Rendezvous( + RendezvousFlag& flag, absl::string_view name, const K& key, size_t num_threads, Fn fn, absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), absl::Duration terminate_timeout = absl::InfiniteDuration()); @@ -169,11 +168,10 @@ RendezvousResultType RendezvousSingle( // not in `completed` state and will switch it to `completed` after finishing a // rendezvous. template -void RendezvousSingle( - RendezvousSingleFlag& flag, absl::string_view name, const K& key, - size_t num_threads, - absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), - absl::Duration terminate_timeout = absl::InfiniteDuration()); +void Rendezvous(RendezvousFlag& flag, absl::string_view name, const K& key, + size_t num_threads, + absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), + absl::Duration terminate_timeout = absl::InfiniteDuration()); //===----------------------------------------------------------------------===// // Internal implementation details. @@ -291,11 +289,10 @@ void AwaitAndLogIfStuck(RendezvousStateSynchronization& state, int32_t id, //===----------------------------------------------------------------------===// template -RendezvousResultType RendezvousSingle(absl::string_view name, const K& key, - const V& value, size_t num_threads, - Fn fn, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { +RendezvousResultType Rendezvous(absl::string_view name, const K& key, + const V& value, size_t num_threads, Fn fn, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { // Check that `fn` is callable with a span of values and returns `R`. static_assert(std::is_invocable_r_v>, "invalid rendezvous function signature"); @@ -319,7 +316,7 @@ RendezvousResultType RendezvousSingle(absl::string_view name, const K& key, tsl::profiler::TraceMe trace([&] { return tsl::profiler::TraceMeEncode( - "RendezvousSingle", + "Rendezvous", {{"num_threads", num_threads}, {"name", name}, {"id", id}}); }); @@ -355,46 +352,44 @@ RendezvousResultType RendezvousSingle(absl::string_view name, const K& key, } template -RendezvousResultType RendezvousSingle(absl::string_view name, const K& key, - size_t num_threads, Fn fn, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { - return RendezvousSingle( +RendezvousResultType Rendezvous(absl::string_view name, const K& key, + size_t num_threads, Fn fn, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { + return Rendezvous( name, key, std::nullopt, num_threads, [fn](auto) { return fn(); }, warn_stuck_timeout, terminate_timeout); } template -void RendezvousSingle(absl::string_view name, const K& key, size_t num_threads, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { - RendezvousSingle( +void Rendezvous(absl::string_view name, const K& key, size_t num_threads, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { + Rendezvous( name, key, std::nullopt, num_threads, [](auto) { return std::nullopt; }, warn_stuck_timeout, terminate_timeout); } template -RendezvousResultType RendezvousSingle(RendezvousSingleFlag& flag, - absl::string_view name, const K& key, - size_t num_threads, Fn fn, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { +RendezvousResultType Rendezvous(RendezvousFlag& flag, absl::string_view name, + const K& key, size_t num_threads, Fn fn, + absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { if (auto in_flight_rendezvous = flag.TryJoin()) { - return RendezvousSingle(name, key, num_threads, std::move(fn), - warn_stuck_timeout, terminate_timeout); + return Rendezvous(name, key, num_threads, std::move(fn), + warn_stuck_timeout, terminate_timeout); } else { return RendezvousResult::Empty(); } } template -void RendezvousSingle(RendezvousSingleFlag& flag, absl::string_view name, - const K& key, size_t num_threads, - absl::Duration warn_stuck_timeout, - absl::Duration terminate_timeout) { +void Rendezvous(RendezvousFlag& flag, absl::string_view name, const K& key, + size_t num_threads, absl::Duration warn_stuck_timeout, + absl::Duration terminate_timeout) { if (auto in_flight_rendezvous = flag.TryJoin()) { - RendezvousSingle(name, key, num_threads, warn_stuck_timeout, - terminate_timeout); + Rendezvous(name, key, num_threads, warn_stuck_timeout, + terminate_timeout); } } diff --git a/xla/service/rendezvous_test.cc b/xla/service/rendezvous_test.cc index 867d24971f078..c47550a63de17 100644 --- a/xla/service/rendezvous_test.cc +++ b/xla/service/rendezvous_test.cc @@ -41,8 +41,7 @@ tsl::thread::ThreadPool CreateThreadPool(int32_t size) { } TEST(RendezvousTest, OneParticipant) { - auto result = - RendezvousSingle("rendezvous_test", 0, 1, [] { return 42; }); + auto result = Rendezvous("rendezvous_test", 0, 1, [] { return 42; }); ASSERT_EQ(*result, 42); } @@ -53,7 +52,7 @@ TEST(RendezvousTest, TwoParticipants) { auto task = [&](int32_t id) { return [&, id] { results[id] = - RendezvousSingle("rendezvous_test", 0, 2, [] { return 42; }); + Rendezvous("rendezvous_test", 0, 2, [] { return 42; }); counter.DecrementCount(); }; }; @@ -81,7 +80,7 @@ TEST(RendezvousTest, TwoParticipantsWithValues) { auto task = [&](int32_t id) { return [&, id] { results[id] = - RendezvousSingle("rendezvous_test", 0, id, 2, accumulate); + Rendezvous("rendezvous_test", 0, id, 2, accumulate); counter.DecrementCount(); }; }; @@ -103,7 +102,7 @@ TEST(RendezvousTest, RepeatRendezvous) { absl::BlockingCounter counter(2); auto task = [&] { - RendezvousSingle("rendezvous_test", i, 2, [] { return 42; }); + Rendezvous("rendezvous_test", i, 2, [] { return 42; }); counter.DecrementCount(); }; @@ -119,8 +118,8 @@ TEST(RendezvousTest, ReturningStatusOr) { auto task = [&](int32_t id) { return [&, id] { - results[id] = RendezvousSingle>( - "rendezvous_test", 0, 2, [] { return 42; }); + results[id] = Rendezvous>("rendezvous_test", 0, 2, + [] { return 42; }); counter.DecrementCount(); }; }; @@ -135,8 +134,8 @@ TEST(RendezvousTest, ReturningStatusOr) { ASSERT_EQ(**results[1], 42); } -TEST(RendezvousTest, RendezvousSingleFlag) { - RendezvousSingleFlag flag; +TEST(RendezvousTest, RendezvousFlag) { + RendezvousFlag flag; auto thread_pool = CreateThreadPool(2); int32_t num_executed = 0; @@ -146,7 +145,7 @@ TEST(RendezvousTest, RendezvousSingleFlag) { auto task = [&](absl::BlockingCounter& counter) { return [&] { - RendezvousSingle( + Rendezvous( flag, "rendezvous_test", 0, 2, [&] { return ++num_executed; }, Timeout(), Terminate()); counter.DecrementCount(); @@ -169,8 +168,8 @@ TEST(RendezvousTest, RendezvousSingleFlag) { ASSERT_EQ(num_executed, 1); } -TEST(RendezvousTest, RendezvousSingleFlagRace) { - RendezvousSingleFlag flag; +TEST(RendezvousTest, RendezvousFlagRace) { + RendezvousFlag flag; static constexpr int32_t kNumRendezvous = 16; static constexpr int32_t kNumThreads = 8; @@ -179,8 +178,8 @@ TEST(RendezvousTest, RendezvousSingleFlagRace) { auto task = [&](int32_t key) { return [&, key] { - RendezvousSingle(flag, "key: " + std::to_string(key), key, kNumThreads, - Timeout(), Terminate()); + Rendezvous(flag, "key: " + std::to_string(key), key, kNumThreads, + Timeout(), Terminate()); }; }; @@ -191,8 +190,8 @@ TEST(RendezvousTest, RendezvousSingleFlagRace) { } } -TEST(RendezvousTest, RendezvousSingleFlagRaceWithBarriers) { - RendezvousSingleFlag flag; +TEST(RendezvousTest, RendezvousFlagRaceWithBarriers) { + RendezvousFlag flag; static constexpr int32_t kNumRendezvous = 16; static constexpr int32_t kNumThreads = 8; @@ -209,8 +208,8 @@ TEST(RendezvousTest, RendezvousSingleFlagRaceWithBarriers) { return [&, key] { participants_ready.DecrementCount(); participants_notification.WaitForNotification(); - RendezvousSingle(flag, "key: " + std::to_string(key), key, kNumThreads, - Timeout(), Terminate()); + Rendezvous(flag, "key: " + std::to_string(key), key, kNumThreads, + Timeout(), Terminate()); participants_done.DecrementCount(); }; }; @@ -238,8 +237,8 @@ static void BM_Rendezvous(benchmark::State& state) { absl::BlockingCounter counter(num_threads); for (int64_t i = 0; i < num_threads; ++i) { thread_pool.Schedule([&] { - RendezvousSingle("rendezvous_test", 0, num_threads, - [] { return 42; }); + Rendezvous("rendezvous_test", 0, num_threads, + [] { return 42; }); counter.DecrementCount(); }); } @@ -256,8 +255,8 @@ static void BM_RendezvousWithValues(benchmark::State& state) { for (int64_t i = 0; i < num_threads; ++i) { thread_pool.Schedule([&] { int32_t value = i; - RendezvousSingle("rendezvous_test", 0, value, num_threads, - [](auto) { return 42; }); + Rendezvous("rendezvous_test", 0, value, num_threads, + [](auto) { return 42; }); counter.DecrementCount(); }); } diff --git a/xla/stream_executor/host/BUILD b/xla/stream_executor/host/BUILD index fc895e3b002a7..cad11a5cbc1a2 100644 --- a/xla/stream_executor/host/BUILD +++ b/xla/stream_executor/host/BUILD @@ -4,14 +4,6 @@ load("//xla:xla.bzl", "xla_cc_test") load("//xla/stream_executor:build_defs.bzl", "stream_executor_friends") load("//xla/tsl:tsl.bzl", "internal_visibility") -load( - "//xla/tsl/platform:build_config_root.bzl", - "if_llvm_aarch32_available", - "if_llvm_aarch64_available", - "if_llvm_powerpc_available", - "if_llvm_system_z_available", - "if_llvm_x86_available", -) load("//xla/tsl/platform:rules_cc.bzl", "cc_library") package( @@ -83,7 +75,6 @@ cc_library( ], deps = [ ":host_event", - ":host_kernel", "//xla/stream_executor:device_memory", "//xla/stream_executor:event", "//xla/stream_executor:kernel", @@ -91,6 +82,7 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/platform:env", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:check", @@ -103,66 +95,6 @@ cc_library( ], ) -cc_library( - name = "host_kernel_c_api", - hdrs = ["host_kernel_c_api.h"], -) - -cc_library( - name = "host_kernel", - srcs = ["host_kernel.cc"], - hdrs = ["host_kernel.h"], - deps = [ - ":host_kernel_c_api", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:kernel", - "//xla/stream_executor:kernel_spec", - "//xla/stream_executor:launch_dim", - "//xla/stream_executor:stream", - "//xla/tsl/concurrency:async_value", - "//xla/tsl/platform:env", - "//xla/tsl/platform:logging", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:env", - ], -) - -xla_cc_test( - name = "host_kernel_test", - srcs = ["host_kernel_test.cc"], - deps = [ - ":host_kernel", - ":host_kernel_c_api", - ":host_platform", - ":jit_host_kernel_function", - ":ptr_host_kernel_function", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:kernel", - "//xla/stream_executor:kernel_spec", - "//xla/stream_executor:launch_dim", - "//xla/stream_executor:platform", - "//xla/stream_executor:platform_manager", - "//xla/stream_executor:stream_executor_h", - "//xla/tsl/concurrency:async_value", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:env", - "@tsl//tsl/platform:platform_port", - "@tsl//tsl/platform:statusor", - "@tsl//tsl/platform:test", - "@tsl//tsl/platform:test_benchmark", - "@tsl//tsl/platform:test_main", - ], -) - cc_library( name = "host_executor", srcs = [ @@ -173,7 +105,6 @@ cc_library( ], deps = [ ":host_event", - ":host_kernel", ":host_stream", "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory", @@ -185,15 +116,14 @@ cc_library( "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_common", + "//xla/tsl/platform:env", "//xla/tsl/platform/profile_utils:profile_utils_cpu_utils", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@tsl//tsl/platform:env", "@tsl//tsl/platform:platform_port", - "@tsl//tsl/platform:statusor", ], alwayslink = True, ) @@ -216,69 +146,3 @@ xla_cc_test( "@tsl//tsl/platform:test_main", ], ) - -cc_library( - name = "ptr_host_kernel_function", - srcs = ["ptr_host_kernel_function.cc"], - hdrs = ["ptr_host_kernel_function.h"], - deps = [ - ":host_executor", - ":host_kernel", - ":host_kernel_c_api", - "//xla/stream_executor:kernel_spec", - "//xla/stream_executor/platform:initialize", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], - alwayslink = True, # static kernel function loader registration -) - -cc_library( - name = "jit_host_kernel_function", - srcs = ["jit_host_kernel_function.cc"], - hdrs = ["jit_host_kernel_function.h"], - deps = [ - ":host_executor", - ":host_kernel", - ":host_kernel_c_api", - "//xla/stream_executor:kernel_spec", - "//xla/stream_executor/platform:initialize", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Analysis", - "@llvm-project//llvm:AsmParser", - "@llvm-project//llvm:Core", - "@llvm-project//llvm:ExecutionEngine", - "@llvm-project//llvm:JITLink", - "@llvm-project//llvm:OrcJIT", - "@llvm-project//llvm:OrcShared", - "@llvm-project//llvm:Passes", - "@llvm-project//llvm:Support", - "@llvm-project//llvm:Target", - "@llvm-project//llvm:TargetParser", - "@llvm-project//llvm:TransformUtils", - "@llvm-project//llvm:ir_headers", - "@tsl//tsl/platform:statusor", - ] + if_llvm_aarch32_available([ - "@llvm-project//llvm:ARMAsmParser", - "@llvm-project//llvm:ARMCodeGen", - ]) + if_llvm_aarch64_available([ - "@llvm-project//llvm:AArch64AsmParser", - "@llvm-project//llvm:AArch64CodeGen", - ]) + if_llvm_powerpc_available([ - "@llvm-project//llvm:PowerPCAsmParser", - "@llvm-project//llvm:PowerPCCodeGen", - ]) + if_llvm_system_z_available([ - "@llvm-project//llvm:SystemZAsmParser", - "@llvm-project//llvm:SystemZCodeGen", - ]) + if_llvm_x86_available([ - "@llvm-project//llvm:X86AsmParser", - "@llvm-project//llvm:X86CodeGen", - ]), - alwayslink = 1, # static kernel function loader registration -) diff --git a/xla/stream_executor/host/host_executor.cc b/xla/stream_executor/host/host_executor.cc index 5beb6ea547579..182166a83e2ca 100644 --- a/xla/stream_executor/host/host_executor.cc +++ b/xla/stream_executor/host/host_executor.cc @@ -22,12 +22,10 @@ limitations under the License. #include #include -#include #include #include #include #include -#include #include "absl/log/check.h" #include "absl/log/log.h" @@ -38,18 +36,15 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/host/host_event.h" -#include "xla/stream_executor/host/host_kernel.h" #include "xla/stream_executor/host/host_stream.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/tsl/platform/profile_utils/cpu_utils.h" +#include "xla/tsl/platform/threadpool.h" #include "tsl/platform/cpu_info.h" -#include "tsl/platform/env.h" #include "tsl/platform/mem.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/threadpool.h" namespace stream_executor { namespace host { @@ -59,16 +54,6 @@ HostStream* AsHostStream(Stream* stream) { return dynamic_cast(stream); } -static std::vector& -KernelFunctionLoaderRegistry() { - static auto* registry = new std::vector(); - return *registry; -} - -void HostExecutor::RegisterKernelFunctionLoader(KernelFunctionLoader loader) { - KernelFunctionLoaderRegistry().push_back(std::move(loader)); -} - absl::Status HostExecutor::Init() { thread_pool_ = std::make_shared( tsl::Env::Default(), "host-executor", tsl::port::MaxParallelism()); @@ -77,18 +62,6 @@ absl::Status HostExecutor::Init() { absl::StatusOr> HostExecutor::LoadKernel( const MultiKernelLoaderSpec& spec) { - auto host_kernel = std::make_unique(thread_pool_); - host_kernel->SetArity(spec.arity()); - - for (auto& loader : KernelFunctionLoaderRegistry()) { - auto loaded = loader(spec); - if (!loaded.has_value()) continue; - - TF_ASSIGN_OR_RETURN(auto kernel_function, *std::move(loaded)); - host_kernel->SetKernelFunction(std::move(kernel_function)); - return std::move(host_kernel); - } - return absl::InternalError("No method of loading host kernel provided"); } diff --git a/xla/stream_executor/host/host_executor.h b/xla/stream_executor/host/host_executor.h index 831cf27727b3a..e38d273946fc4 100644 --- a/xla/stream_executor/host/host_executor.h +++ b/xla/stream_executor/host/host_executor.h @@ -17,7 +17,6 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_HOST_HOST_EXECUTOR_H_ #include -#include #include #include #include @@ -27,7 +26,6 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" -#include "xla/stream_executor/host/host_kernel.h" #include "xla/stream_executor/host_memory_allocation.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" @@ -35,7 +33,7 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_common.h" -#include "tsl/platform/threadpool.h" +#include "xla/tsl/platform/threadpool.h" namespace stream_executor { namespace host { @@ -48,15 +46,6 @@ namespace host { // routines executed under the context of a GPU executor. class HostExecutor : public StreamExecutorCommon { public: - // A function that loads a kernel function from a given spec. If spec is not - // supported it returns an empty optional. - using KernelFunctionLoader = std::function>>( - const MultiKernelLoaderSpec& spec)>; - - // Registers a kernel function loader in a static registry. - static void RegisterKernelFunctionLoader(KernelFunctionLoader loader); - HostExecutor(Platform* platform, int device_ordinal) : StreamExecutorCommon(platform), device_ordinal_(device_ordinal) {} diff --git a/xla/stream_executor/host/host_kernel.cc b/xla/stream_executor/host/host_kernel.cc deleted file mode 100644 index 6ed9ba7a2e0a1..0000000000000 --- a/xla/stream_executor/host/host_kernel.cc +++ /dev/null @@ -1,266 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/stream_executor/host/host_kernel.h" - -#include -#include -#include -#include -#include - -#include "absl/base/optimization.h" -#include "absl/container/inlined_vector.h" -#include "absl/status/status.h" -#include "absl/strings/str_format.h" -#include "absl/types/span.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/host/host_kernel_c_api.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/stream.h" -#include "xla/tsl/concurrency/async_value_ref.h" -#include "xla/tsl/platform/logging.h" -#include "xla/tsl/platform/threadpool.h" - -namespace stream_executor::host { - -using LaunchEvent = HostKernel::LaunchEvent; - -// Non-reference-counted async value ref for host kernels executed inline. -static tsl::AsyncValueRef OkLaunchEvent() { - static tsl::AsyncValueOwningRef* event = [] { - auto* storage = new tsl::internal::AsyncValueStorage(); - return new tsl::AsyncValueOwningRef( - tsl::MakeAvailableAsyncValueRef(*storage)); - }(); - return event->AsRef(); -} - -static absl::InlinedVector ConvertBuffersToKernelArgs( - absl::Span buffers) { - absl::InlinedVector args(buffers.size()); - for (size_t i = 0; i < buffers.size(); ++i) { - args[i].data = const_cast(buffers[i].opaque()); - args[i].size = buffers[i].size(); - } - return args; -} - -namespace { -// Keep a state of an in-flight asynchronous kernel execution on a heap to keep -// it alive until the last task is done. -class HostKernelExecuteState { - public: - HostKernelExecuteState(HostKernel::TaskRunner task_runner, - SE_HOST_Kernel* kernel, ThreadDim thread_dims, - absl::Span args); - ~HostKernelExecuteState(); - - // Calls a task with index `task_index` synchronously. - void CallSync(uint64_t task_index); - - // Calls tasks in the [start_index, end_index) range asynchronously using task - // runner to schedule work. Executes a single task in the caller thread. - void CallAsync(uint64_t start_index, uint64_t end_index); - - tsl::AsyncValueRef event() const { return event_.AsRef(); } - - private: - // Converts linear task index in [0, num_tasks) to (x, y, z) coordinate. We - // assume that `x` is the fastest iterating dimension. - SE_HOST_KernelThread Delinearize(uint64_t task_index); - - HostKernel::TaskRunner task_runner_; - size_t num_tasks_; - - SE_HOST_Kernel* kernel_; - SE_HOST_KernelThreadDim thread_dims_; - absl::InlinedVector args_; - - tsl::CountDownAsyncValueRef event_; -}; -} // namespace - -HostKernel::HostKernel(std::shared_ptr thread_pool) - : thread_pool_(thread_pool) { - // Kernel and arity will be set separately -} - -HostKernel::HostKernel(unsigned arity, SE_HOST_Kernel* kernel, - std::shared_ptr thread_pool) - : function_(std::make_unique(kernel)), - kernel_(function_->kernel()), - arity_(arity), - thread_pool_(thread_pool) {} - -absl::Status HostKernel::Launch( - const ThreadDim& thread_dims, - absl::Span buffers) const { - return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers)); -} - -absl::Status HostKernel::Launch( - const ThreadDim& thread_dims, - absl::Span args) const { - SE_HOST_KernelThreadDim kernel_thread_dims = { - thread_dims.x, - thread_dims.y, - thread_dims.z, - }; - - for (uint64_t z = 0; z < thread_dims.z; ++z) { - for (uint64_t y = 0; y < thread_dims.y; ++y) { - for (uint64_t x = 0; x < thread_dims.x; ++x) { - SE_HOST_KernelThread kernel_thread = {x, y, z}; - - SE_HOST_KernelCallFrame call_frame = { - &kernel_thread_dims, &kernel_thread, args.size(), args.data()}; - - SE_HOST_KernelError* error = (*kernel_)(&call_frame); - - if (ABSL_PREDICT_FALSE(error != nullptr)) { - return absl::InternalError("Failed to call host kernel"); - } - } - } - } - - return absl::OkStatus(); -} - -absl::Status HostKernel::Launch(const ThreadDim& thread_dims, - const BlockDim& block_dims, - const std::optional& cluster_dims, - Stream* stream, const KernelArgs& args) { - if (cluster_dims.has_value()) { - if (cluster_dims->x != 1 || cluster_dims->y != 1 || cluster_dims->z != 1) { - return absl::UnimplementedError("Not implemented for Host"); - } - } - const KernelArgsDeviceMemoryArray* device_mem = - DynCast(&args); - - if (device_mem != nullptr) { - return Launch(thread_dims, device_mem->device_memory_args()); - } - return absl::UnimplementedError( - "Host kernel implements Launch method only for DeviceMemoryArray " - "arguments."); -} - -tsl::AsyncValueRef HostKernel::Launch( - const ThreadDim& thread_dims, absl::Span buffers, - TaskRunner task_runner) const { - return Launch(thread_dims, ConvertBuffersToKernelArgs(buffers), - std::move(task_runner)); -} - -tsl::AsyncValueRef HostKernel::Launch( - const ThreadDim& thread_dims, absl::Span args, - TaskRunner task_runner) const { - size_t num_tasks = thread_dims.x * thread_dims.y * thread_dims.z; - CHECK_GT(num_tasks, 0) << "Number of tasks must be positive"; // Crash Ok - - // Short-circuit launch with a single task and run it in the caller thread. - if (ABSL_PREDICT_TRUE(num_tasks == 1)) { - absl::Status launched = Launch(thread_dims, args); - return ABSL_PREDICT_TRUE(launched.ok()) - ? OkLaunchEvent() - : tsl::MakeErrorAsyncValueRef(std::move(launched)); - } - - // Create host kernel execute state on heap and kick-off execution. - auto state = std::make_unique( - std::move(task_runner), kernel_, thread_dims, args); - state->CallAsync(/*start_index=*/0, /*end_index=*/num_tasks); - - // Move execute state to the execute event callback to ensure that it is kept - // alive while host kernel has pending tasks. - auto execute_event = state->event(); - execute_event.AndThen([state = std::move(state)] {}); - - return execute_event; -} - -HostKernelExecuteState::HostKernelExecuteState( - HostKernel::TaskRunner task_runner, SE_HOST_Kernel kernel, - ThreadDim thread_dims, absl::Span args) - : task_runner_(std::move(task_runner)), - num_tasks_(thread_dims.x * thread_dims.y * thread_dims.z), - kernel_(kernel), - thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}), - args_(args.begin(), args.end()), - event_(num_tasks_) {} - -HostKernelExecuteState::~HostKernelExecuteState() { - auto cnt = event_.count(); - DCHECK_EQ(cnt, 0) << "Host kernel execute state is destroyed before all " - "tasks are completed"; -} - -void HostKernelExecuteState::CallSync(uint64_t task_index) { - CHECK_LT(task_index, num_tasks_) << "Task index out of range"; // Crash OK - - // Do not execute the task if the kernel execution has already failed. - if (ABSL_PREDICT_FALSE(event_.is_error())) { - event_.CountDown(absl::OkStatus()); - return; - } - - SE_HOST_KernelThread kernel_thread = Delinearize(task_index); - SE_HOST_KernelCallFrame call_frame = {&thread_dims_, &kernel_thread, - args_.size(), args_.data()}; - - SE_HOST_KernelError* error = (*kernel_)(&call_frame); - - if (ABSL_PREDICT_TRUE(error == nullptr)) { - event_.CountDown(absl::OkStatus()); - } else { - event_.CountDown(absl::InternalError( - absl::StrFormat("Failed to call host kernel: x=%d, y=%d, z=%d", - kernel_thread.x, kernel_thread.y, kernel_thread.z))); - } -} - -void HostKernelExecuteState::CallAsync(uint64_t start_index, - uint64_t end_index) { - CHECK_LT(start_index, end_index) << "Invalid task index range"; // Crash OK - while (end_index - start_index > 1) { - uint64_t mid_index = (start_index + end_index) / 2; - task_runner_([self = this, mid_index, end_index] { - self->CallAsync(mid_index, end_index); - }); - end_index = mid_index; - } - CallSync(start_index); -} - -SE_HOST_KernelThread HostKernelExecuteState::Delinearize(uint64_t task_index) { - uint64_t stride_z = thread_dims_.y * thread_dims_.x; - uint64_t stride_y = thread_dims_.x; - - uint64_t z = task_index / stride_z; - task_index = task_index % stride_z; - - uint64_t y = task_index / stride_y; - task_index = task_index % stride_y; - - uint64_t x = task_index; - - return SE_HOST_KernelThread{x, y, z}; -} - -} // namespace stream_executor::host diff --git a/xla/stream_executor/host/host_kernel.h b/xla/stream_executor/host/host_kernel.h deleted file mode 100644 index fe62b9071934d..0000000000000 --- a/xla/stream_executor/host/host_kernel.h +++ /dev/null @@ -1,164 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_STREAM_EXECUTOR_HOST_HOST_KERNEL_H_ -#define XLA_STREAM_EXECUTOR_HOST_HOST_KERNEL_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/base/optimization.h" -#include "absl/functional/any_invocable.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/host/host_kernel_c_api.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/launch_dim.h" -#include "xla/tsl/concurrency/async_value_ref.h" -#include "xla/tsl/concurrency/chain.h" -#include "tsl/platform/threadpool.h" - -namespace stream_executor::host { - -class HostExecutor; - -class HostKernel : public Kernel { - public: - using Task = std::function; - using TaskRunner = absl::AnyInvocable; - - // A struct to report completion of the kernel execution. - using LaunchEvent = tsl::Chain; - - // Virtual base class that owns the function behind the host kernel. It can be - // a function in a jit-compiled LLVM module or simply a pointer to the - // in-process function written in C++. HostKernel is responsible for launching - // the kernel function owned by the KernelFunction with given user-provided - // arguments potentially on a thread pool. - class KernelFunction { - public: - virtual ~KernelFunction() = default; - virtual SE_HOST_Kernel* kernel() const = 0; - }; - - // A wrapper around function pointer that implements SE_HOST_Kernel API. - class KernelFunctionPtr final : public KernelFunction { - public: - explicit KernelFunctionPtr(SE_HOST_Kernel* ptr) : ptr_(ptr) {} - SE_HOST_Kernel* kernel() const override { return ptr_; } - - private: - SE_HOST_Kernel* ptr_; // not owned - }; - - // TODO(ezhulenev): Remove this constructor as we prefer to rely on task - // runner as it gives us more flexibility. - explicit HostKernel(std::shared_ptr thread_pool); - - // TODO(tsilytskyi): make this implementation detail private - HostKernel(unsigned arity, SE_HOST_Kernel* kernel, - std::shared_ptr thread_pool = nullptr); - - // Calls the kernel once in the caller thread for a thread dim (0,0,0). - // This is a fast path for small host kernels that have just one thread. - absl::Status CallOnce(absl::Span args) const; - - // Launches the kernel on the current thread by iterating over all threads in - // `thread_dims` and calling the kernel function. - absl::Status Launch(const ThreadDim& thread_dims, - absl::Span buffers) const; - absl::Status Launch(const ThreadDim& thread_dims, - absl::Span args) const; - absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, - const std::optional& cluster_dims, - Stream* stream, const KernelArgs& args) override; - - // Launches the kernel by iterating over all threads in `thread_dims` and - // calling `task_runner` to run individual task (implementation might decide - // to run some of the tasks in the caller thread to save on scheduling - // overheads). It's up to the caller to define where task runner will execute - // the task, i.e., a common case is to launch them on a thread pool. - // - // The returned async value becomes available after all tasks are completed. - // Async value returned in constructed state and the caller can access it to - // get the number of tasks that are expected to be completed. - tsl::AsyncValueRef Launch( - const ThreadDim& thread_dims, absl::Span buffers, - TaskRunner task_runner) const; - tsl::AsyncValueRef Launch( - const ThreadDim& thread_dims, absl::Span args, - TaskRunner task_runner) const; - - // For host platform, we assume that a core is a thread, and we can run at - // most one instance of a kernel on a given thread. - absl::StatusOr GetMaxOccupiedBlocksPerCore(ThreadDim, - size_t) const override { - return 1; - }; - - void SetArity(unsigned arity) { arity_ = arity; }; - unsigned Arity() const override { return arity_; }; - - template >* = nullptr> - void SetKernelFunction(std::unique_ptr function) { - function_ = std::move(function); - kernel_ = function_->kernel(); - } - - private: - std::unique_ptr function_; - SE_HOST_Kernel* kernel_; // pointer to the kernel owned by `function_` - - unsigned arity_; - std::shared_ptr thread_pool_; -}; - -inline ABSL_ATTRIBUTE_ALWAYS_INLINE absl::Status HostKernel::CallOnce( - absl::Span args) const { - constexpr SE_HOST_KernelThreadDim kernel_thread_dims = {1, 1, 1}; - constexpr SE_HOST_KernelThread kernel_thread = {1, 1, 1}; - - SE_HOST_KernelCallFrame call_frame = {&kernel_thread_dims, &kernel_thread, - args.size(), args.data()}; - - SE_HOST_KernelError* error = (*kernel_)(&call_frame); - - if (ABSL_PREDICT_FALSE(error != nullptr)) { - return absl::InternalError("Failed to call host kernel"); - } - - return absl::OkStatus(); -} - -inline const HostKernel* AsHostKernel(const Kernel* kernel) { - return static_cast(kernel); -} - -inline HostKernel* AsHostKernel(Kernel* kernel) { - return static_cast(kernel); -} - -} // namespace stream_executor::host - -#endif // XLA_STREAM_EXECUTOR_HOST_HOST_KERNEL_H_ diff --git a/xla/stream_executor/host/host_kernel_c_api.h b/xla/stream_executor/host/host_kernel_c_api.h deleted file mode 100644 index c51a6a9434800..0000000000000 --- a/xla/stream_executor/host/host_kernel_c_api.h +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_STREAM_EXECUTOR_HOST_HOST_KERNEL_C_API_H_ -#define XLA_STREAM_EXECUTOR_HOST_HOST_KERNEL_C_API_H_ - -#include -#include - -//===----------------------------------------------------------------------===// -// StreamExecutor Host Kernel API -//===----------------------------------------------------------------------===// - -#ifdef __cplusplus -extern "C" { -#endif - -// StreamExecutor host kernel API is an integration point between a codegen -// backend and a runtime. XLA:CPU backend compiles fusion regions to native -// functions (via LLVM backend) that are compatible with a kernel API (and ABI), -// and the runtime is simply invoking them with user buffers and orchestrates -// multi-threaded execution. - -// WARNING: This API does not provide any backward compatibility guarantees as -// today XLA:CPU backend is statically linked and we do not plan to load -// kernels from dynamic libraries. It's defined as C API because we have to -// match it in the codegen backend (built on top of LLVM) and C structs have -// trivial layout that can be expressed as llvm stuct (*). -// -// (*) https://llvm.org/docs/LangRef.html#structure-types - -// Similar to a Gpu backend an XLA:CPU compiler generates a tiled function from -// an HLO fusion where each tile is responsible for computing a part of the -// output. It's up to compiler to chose the tiling strategy, from StreamExecutor -// perspective it's simply an iteration space where each task is independent and -// can be executed concurrently. -typedef struct SE_HOST_KernelDim3 { - uint64_t x; - uint64_t y; - uint64_t z; -} SE_HOST_KernelDim3; - -// Kernel grid size roughly corresponds to a CUDA block size. -typedef struct SE_HOST_KernelDim3 SE_HOST_KernelThreadDim; - -// Kernel grid coordinate roughly corresponds to a CUDA block, with an -// assumption that all kernel invocations can run concurrently. -typedef struct SE_HOST_KernelDim3 SE_HOST_KernelThread; - -// A CPU kernel argument that corresponds to se::DeviceMemoryBase. -typedef struct SE_HOST_KernelArg { - void* data; - size_t size; -} SE_HOST_KernelArg; - -// A CPU kernel call frame. -typedef struct SE_HOST_KernelCallFrame { - const SE_HOST_KernelThreadDim* thread_dims; - const SE_HOST_KernelThread* thread; - - size_t num_args; - const SE_HOST_KernelArg* args; -} SE_HOST_KernelCallFrame; - -// Error reporting for host kernels. NULL means success. -typedef struct SE_HOST_KernelError SE_HOST_KernelError; - -// Host kernel API. -typedef SE_HOST_KernelError* SE_HOST_Kernel( - const SE_HOST_KernelCallFrame* call_frame); - -#ifdef __cplusplus -} -#endif - -#endif // XLA_STREAM_EXECUTOR_HOST_HOST_KERNEL_C_API_H_ diff --git a/xla/stream_executor/host/host_kernel_test.cc b/xla/stream_executor/host/host_kernel_test.cc deleted file mode 100644 index aabcbc185aeb1..0000000000000 --- a/xla/stream_executor/host/host_kernel_test.cc +++ /dev/null @@ -1,311 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/stream_executor/host/host_kernel.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/match.h" -#include "absl/types/span.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/host/host_kernel_c_api.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform_manager.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/tsl/concurrency/async_value_ref.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/platform/cpu_info.h" -#include "tsl/platform/env.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" -#include "tsl/platform/threadpool.h" - -namespace stream_executor::host { - -static SE_HOST_KernelError* AddI32(const SE_HOST_KernelCallFrame* call_frame) { - const SE_HOST_KernelArg& lhs = call_frame->args[0]; - const SE_HOST_KernelArg& rhs = call_frame->args[1]; - const SE_HOST_KernelArg& out = call_frame->args[2]; - - int32_t* lhs_ptr = reinterpret_cast(lhs.data); - int32_t* rhs_ptr = reinterpret_cast(rhs.data); - int32_t* out_ptr = reinterpret_cast(out.data); - - const auto zstep = call_frame->thread_dims->x * call_frame->thread_dims->y; - const auto ystep = call_frame->thread_dims->x; - - uint64_t i = call_frame->thread->x + call_frame->thread->y * ystep + - call_frame->thread->z * zstep; - *(out_ptr + i) = *(lhs_ptr + i) + *(rhs_ptr + i); - - return nullptr; -} - -static const char* llvm_kernel_add = R"( -%SE_HOST_KernelCallFrame = type { ptr, ptr, i64, ptr } -%struct.SE_HOST_KernelArg = type { ptr, i64 } - -define ptr @LlvmAddI32(ptr noundef %0) { - %2 = getelementptr inbounds %SE_HOST_KernelCallFrame, ptr %0, i32 0, i32 3 - %3 = load ptr, ptr %2, align 8 - %4 = getelementptr inbounds %struct.SE_HOST_KernelArg, ptr %3, i64 1 - %5 = getelementptr inbounds %struct.SE_HOST_KernelArg, ptr %3, i64 2 - %6 = load ptr, ptr %3, align 8 - %7 = load ptr, ptr %4, align 8 - %8 = load ptr, ptr %5, align 8 - %9 = getelementptr inbounds %SE_HOST_KernelCallFrame, ptr %0, i32 0, i32 1 - %10 = load ptr, ptr %9, align 8 - %11 = load i64, ptr %10, align 8 - %12 = getelementptr inbounds i32, ptr %6, i64 %11 - %13 = load i32, ptr %12, align 4 - %14 = getelementptr inbounds i32, ptr %7, i64 %11 - %15 = load i32, ptr %14, align 4 - %16 = add nsw i32 %13, %15 - %17 = getelementptr inbounds i32, ptr %8, i64 %11 - store i32 %16, ptr %17, align 4 - ret ptr null -} -)"; - -static absl::StatusOr NewStreamExecutor() { - TF_ASSIGN_OR_RETURN(auto platform, PlatformManager::PlatformWithName("Host")); - TF_ASSIGN_OR_RETURN(auto stream_exec, - platform->ExecutorForDevice(/*ordinal=*/0)); - return stream_exec; -} - -TEST(HostKernelTest, InternalAddition1D) { - auto tp = std::make_shared(tsl::Env::Default(), - "XLAEigen", 2); - - HostKernel kernel(/*arity=*/3, AddI32, tp); - - std::vector lhs = {1, 2, 3, 4}; - std::vector rhs = {5, 6, 7, 8}; - std::vector out = {0, 0, 0, 0}; - - DeviceMemoryBase lhs_mem(lhs.data(), lhs.size() * sizeof(int32_t)); - DeviceMemoryBase rhs_mem(rhs.data(), rhs.size() * sizeof(int32_t)); - DeviceMemoryBase out_mem(out.data(), out.size() * sizeof(int32_t)); - std::vector args = {lhs_mem, rhs_mem, out_mem}; - - TF_ASSERT_OK(kernel.Launch(ThreadDim(4), args)); - - std::vector expected = {6, 8, 10, 12}; - EXPECT_EQ(out, expected); -} - -TEST(HostKernelTest, InternalAddition3D) { - auto tp = std::make_shared(tsl::Env::Default(), - "XLAEigen", 2); - - HostKernel kernel(/*arity=*/3, AddI32, tp); - - // Lets pretend there is a 3-dimensional 2x2x3 data - std::vector lhs = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; - std::vector rhs = {10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}; - std::vector out = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - - DeviceMemoryBase lhs_mem(lhs.data(), lhs.size() * sizeof(int32_t)); - DeviceMemoryBase rhs_mem(rhs.data(), rhs.size() * sizeof(int32_t)); - DeviceMemoryBase out_mem(out.data(), out.size() * sizeof(int32_t)); - std::vector args = {lhs_mem, rhs_mem, out_mem}; - - TF_ASSERT_OK(kernel.Launch(ThreadDim(2, 2, 3), args)); - - std::vector expected = {11, 13, 15, 17, 19, 21, - 23, 25, 27, 29, 31, 33}; - EXPECT_EQ(out, expected); -} - -TEST(HostKernelTest, Addition3D) { - // Lets pretend there is a 3-dimensional 2x2x3 data - std::vector lhs = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; - std::vector rhs = {10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}; - std::vector out = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - - DeviceMemoryBase lhs_mem(lhs.data(), lhs.size() * sizeof(int32_t)); - DeviceMemoryBase rhs_mem(rhs.data(), rhs.size() * sizeof(int32_t)); - DeviceMemoryBase out_mem(out.data(), out.size() * sizeof(int32_t)); - std::vector args = {lhs_mem, rhs_mem, out_mem}; - - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(reinterpret_cast(AddI32), "Addition_kernel"); - - TF_ASSERT_OK_AND_ASSIGN(auto executor, NewStreamExecutor()); - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); - TF_ASSERT_OK_AND_ASSIGN(auto add, executor->LoadKernel(spec)); - - const KernelArgsDeviceMemoryArray kargs{args, /*shared_memory_bytes=*/0}; - TF_ASSERT_OK( - add->Launch(ThreadDim(2, 2, 3), BlockDim(1), stream.get(), kargs)); - - std::vector expected = {11, 13, 15, 17, 19, 21, - 23, 25, 27, 29, 31, 33}; - EXPECT_EQ(out, expected); -} - -TEST(HostKernelTest, JitAddition) { - std::vector lhs = {1, 2, 3, 4}; - std::vector rhs = {5, 6, 7, 8}; - std::vector out = {0, 0, 0, 0}; - - DeviceMemoryBase lhs_mem(lhs.data(), lhs.size() * sizeof(int32_t)); - DeviceMemoryBase rhs_mem(rhs.data(), rhs.size() * sizeof(int32_t)); - DeviceMemoryBase out_mem(out.data(), out.size() * sizeof(int32_t)); - std::vector args = {lhs_mem, rhs_mem, out_mem}; - - MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddLlvmHostKernel(llvm_kernel_add, "LlvmAddI32", "LlvmAddI32", - absl::Span()); - - TF_ASSERT_OK_AND_ASSIGN(auto executor, NewStreamExecutor()); - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); - TF_ASSERT_OK_AND_ASSIGN(auto add, executor->LoadKernel(spec)); - - const KernelArgsDeviceMemoryArray kargs{args, /*shared_memory_bytes=*/0}; - TF_ASSERT_OK(add->Launch(ThreadDim(4), BlockDim(1), stream.get(), kargs)); - - std::vector expected = {6, 8, 10, 12}; - EXPECT_EQ(out, expected); -} - -TEST(HostKernelTest, LaunchAsync) { - auto* no_op = +[](const SE_HOST_KernelCallFrame*) { - return static_cast(nullptr); - }; - - auto thread_pool = std::make_shared( - tsl::Env::Default(), "benchmark", tsl::port::MaxParallelism()); - - std::atomic num_tasks = 0; - - HostKernel::TaskRunner runner = [&](HostKernel::Task task) { - num_tasks.fetch_add(1, std::memory_order_relaxed); - thread_pool->Schedule(std::move(task)); - }; - - HostKernel host_kernel(/*arity=*/0, no_op); - auto event = host_kernel.Launch(ThreadDim(4, 4, 4), - absl::Span(), - std::move(runner)); - - tsl::BlockUntilReady(event); - EXPECT_TRUE(event.IsConcrete()); - EXPECT_EQ(num_tasks.load(std::memory_order_relaxed), 4 * 4 * 4 - 1); -} - -TEST(HostKernelTest, LaunchAsyncError) { - // SE_HOST_KernelError type is not defined so we simply return a non-nullptr - // pointer to signal error to the runtime. - auto* maybe_error = +[](const SE_HOST_KernelCallFrame* call_frame) { - if (call_frame->thread->x == 2 && call_frame->thread->z == 2) { - return reinterpret_cast(0xDEADBEEF); - } - return static_cast(nullptr); - }; - - auto thread_pool = std::make_shared( - tsl::Env::Default(), "benchmark", tsl::port::MaxParallelism()); - - std::atomic num_tasks = 0; - - HostKernel::TaskRunner runner = [&](HostKernel::Task task) { - num_tasks.fetch_add(1, std::memory_order_relaxed); - thread_pool->Schedule(std::move(task)); - }; - - HostKernel host_kernel(/*arity=*/0, maybe_error); - auto event = host_kernel.Launch(ThreadDim(4, 4, 4), - absl::Span(), - std::move(runner)); - - tsl::BlockUntilReady(event); - ASSERT_TRUE(event.IsError()); - EXPECT_TRUE(absl::StrContains(event.GetError().message(), - "Failed to call host kernel:")); - EXPECT_EQ(num_tasks.load(std::memory_order_relaxed), 4 * 4 * 4 - 1); -} - -//===----------------------------------------------------------------------===// -// Performance benchmarks below -//===----------------------------------------------------------------------===// - -// We benchmark HostKernel launch overheads so we use a noop kernel as we are -// only interested on how fast we can launch kernel tasks. -static SE_HOST_KernelError* NoOp(const SE_HOST_KernelCallFrame*) { - return nullptr; -} - -static void BM_HostKernelSyncLaunch(benchmark::State& state) { - int32_t tdim_x = state.range(0); - - HostKernel kernel(/*arity=*/0, NoOp); - absl::Span args; - - for (auto _ : state) { - benchmark::DoNotOptimize(kernel.Launch(ThreadDim(tdim_x), args)); - } -} - -static void BM_HostKernelAsyncLaunch(benchmark::State& state) { - int32_t tdim_x = state.range(0); - - auto thread_pool = std::make_shared( - tsl::Env::Default(), "benchmark", tsl::port::MaxParallelism()); - - auto task_runner = [&thread_pool](HostKernel::Task task) { - thread_pool->Schedule(std::move(task)); - }; - - HostKernel kernel(/*arity=*/0, NoOp); - absl::Span args; - - for (auto _ : state) { - auto event = kernel.Launch(ThreadDim(tdim_x), args, task_runner); - tsl::BlockUntilReady(event); - } -} - -BENCHMARK(BM_HostKernelSyncLaunch) - ->MeasureProcessCPUTime() - ->Arg(1) - ->Arg(4) - ->Arg(8) - ->Arg(16) - ->Arg(32) - ->Arg(64); - -BENCHMARK(BM_HostKernelAsyncLaunch) - ->MeasureProcessCPUTime() - ->Arg(1) - ->Arg(4) - ->Arg(8) - ->Arg(16) - ->Arg(32) - ->Arg(64); - -} // namespace stream_executor::host diff --git a/xla/stream_executor/host/host_stream.cc b/xla/stream_executor/host/host_stream.cc index ee812daad8d97..1ee375dbe5946 100644 --- a/xla/stream_executor/host/host_stream.cc +++ b/xla/stream_executor/host/host_stream.cc @@ -22,7 +22,6 @@ limitations under the License. #include // NOLINT #include #include -#include #include #include @@ -34,13 +33,9 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/host/host_event.h" -#include "xla/stream_executor/host/host_kernel.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_common.h" #include "tsl/platform/denormal.h" -#include "tsl/platform/env.h" #include "tsl/platform/setround.h" namespace stream_executor { diff --git a/xla/stream_executor/host/jit_host_kernel_function.cc b/xla/stream_executor/host/jit_host_kernel_function.cc deleted file mode 100644 index 3291af042faad..0000000000000 --- a/xla/stream_executor/host/jit_host_kernel_function.cc +++ /dev/null @@ -1,456 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/stream_executor/host/jit_host_kernel_function.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Analysis/CGSCCPassManager.h" -#include "llvm/Analysis/LoopAnalysisManager.h" -#include "llvm/AsmParser/Parser.h" -#include "llvm/ExecutionEngine/JITEventListener.h" -#include "llvm/ExecutionEngine/ObjectCache.h" -#include "llvm/ExecutionEngine/Orc/CompileUtils.h" -#include "llvm/ExecutionEngine/Orc/Core.h" -#include "llvm/ExecutionEngine/Orc/CoreContainers.h" -#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" -#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" -#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" -#include "llvm/ExecutionEngine/Orc/LLJIT.h" -#include "llvm/ExecutionEngine/Orc/Mangling.h" -#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" -#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" -#include "llvm/ExecutionEngine/Orc/TaskDispatch.h" -#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" -#include "llvm/ExecutionEngine/SectionMemoryManager.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/PassManager.h" -#include "llvm/Passes/OptimizationLevel.h" -#include "llvm/Passes/PassBuilder.h" -#include "llvm/Support/CodeGen.h" -#include "llvm/Support/Error.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/SourceMgr.h" -#include "llvm/Support/TargetSelect.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/Target/TargetMachine.h" -#include "llvm/TargetParser/Triple.h" -#include "llvm/Transforms/Utils/Cloning.h" -#include "xla/stream_executor/host/host_executor.h" -#include "xla/stream_executor/host/host_kernel.h" -#include "xla/stream_executor/host/host_kernel_c_api.h" -#include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/platform/initialize.h" -#include "tsl/platform/statusor.h" - -namespace stream_executor::host { - -using llvm::Expected; -using llvm::MemoryBuffer; -using llvm::SectionMemoryManager; -using llvm::Triple; - -using llvm::orc::ExecutionSession; -using llvm::orc::ExecutorAddr; -using llvm::orc::InPlaceTaskDispatcher; -using llvm::orc::IRCompileLayer; -using llvm::orc::JITTargetMachineBuilder; -using llvm::orc::RTDyldObjectLinkingLayer; -using llvm::orc::SelfExecutorProcessControl; -using llvm::orc::SimpleCompiler; -using llvm::orc::SymbolMap; -using llvm::orc::ThreadSafeModule; - -namespace { - -// This compiler keeps weak pointers to the TargetMachine and the ObjectCache. -// -// This allows releasing the memory of those objects, even though the LLJIT -// keeps the compiler alive. -// -// We wrote this class based on the code of llvm::orc::ConcurrentIRCompiler. -class WeakCompiler : public IRCompileLayer::IRCompiler { - public: - static llvm::orc::IRSymbolMapper::ManglingOptions - IrManglingOptionsForWeakTargetMachine( - std::weak_ptr weak_target_machine) { - std::shared_ptr target_machine = - weak_target_machine.lock(); - CHECK(target_machine != nullptr) - << "Compiler should not be used after the TargetMachine is destroyed."; - - return llvm::orc::irManglingOptionsFromTargetOptions( - target_machine->Options); - } - - // It's not recommended to allocate the parameters with std::make_shared, - // because that would allocate the object and the control block in one - // allocation, so the weak_ptr would keep alive the memory of the object as - // well. - explicit WeakCompiler(std::weak_ptr weak_target_machine) - : IRCompiler(IrManglingOptionsForWeakTargetMachine(weak_target_machine)), - weak_target_machine_(std::move(weak_target_machine)) {} - - Expected> operator()( - llvm::Module &module) override { - std::shared_ptr target_machine = - weak_target_machine_.lock(); - CHECK(target_machine != nullptr) - << "Compiler should not be used after the TargetMachine is destroyed."; - - SimpleCompiler compiler(*target_machine); - return compiler(module); - } - - private: - std::weak_ptr weak_target_machine_; -}; - -} // namespace - -namespace internal { -// A minimal LLVM ORC JIT compilation stack to jit-compile LLVM modules to -// executable functions. -class ExecutionEngine { - public: - using ExportedFunctionPtr = const SE_HOST_Kernel *; - - // Callback to run optimization passes on the compiled LLVM module. - using OptimizingTransformer = std::function; - - // Callback to construct an optimizing transformer for the given options. - using MakeOptimizingTransformer = - std::function; - - // Options for creating execution engine from an LLVM module. - struct Options { - // User-provided codegen optimization level. - llvm::CodeGenOptLevel opt_level = llvm::CodeGenOptLevel::Default; - - // User-provided target machine specification. - std::shared_ptr target_machine = nullptr; - - // User-provided builder for the optimizing transformer. - MakeOptimizingTransformer make_optimizing_transformer; - - // User-provided memory mapper for allocating memory for executables. - llvm::SectionMemoryManager::MemoryMapper *section_memory_mapper = nullptr; - - // Notify the llvm's global GDB notifications listener. - bool enable_gdb_listener = false; - - // Notify the llvm's global Perf notifications listener. - bool enable_perf_listener = false; - }; - - // Creates a new execution engine by compiling the provided LLVM module to - // a native executable using LLVM ORC stack. - static absl::StatusOr> CreateFromModule( - std::unique_ptr ctx, - std::unique_ptr module, Options options, - absl::Span exported); - - // Returns a pointer to the exported function. - absl::Span exported() const { return exported_; } - - ExportedFunctionPtr exported(unsigned ordinal) const { - return exported_[ordinal]; - } - - // Return a memory buffer with a object file behind this execution engine. Can - // be null if execution engine didn't save the compiled object file. - std::unique_ptr obj_file() const; - - private: - ExecutionEngine(bool enable_gdb_listener, bool enable_perf_listener); - - // We build execution engine on top of the ORC LLJIT API, which owns all - // compiled/loaded object files and does the linking at run time. - // - // TODO(ezhulenev): Instead of keeping LLJIT alive we should be able to keep - // only llvm::orc::JITDylibSP owning main dylib and the object layer owning - // memory-mapped regions holding object files. Once we are done with - // executable compilation this jit is defunct because it holds an expired - // weak_ptr to an llvm::orc::TargetMachine instance. - std::unique_ptr jit_; - - // Pointers to resolved exported functions. Indexed by function ordinal. - std::vector exported_; - - // Object file behind the compiled executable. Can be null. - std::unique_ptr obj_file_; - - llvm::JITEventListener *gdb_listener_ = nullptr; - llvm::JITEventListener *perf_listener_ = nullptr; -}; - -ExecutionEngine::ExecutionEngine(bool enable_gdb_listener, - bool enable_perf_listener) { - if (enable_gdb_listener) - gdb_listener_ = llvm::JITEventListener::createGDBRegistrationListener(); - if (enable_perf_listener) - perf_listener_ = llvm::JITEventListener::createPerfJITEventListener(); -} - -std::unique_ptr ExecutionEngine::obj_file() const { - return obj_file_ ? MemoryBuffer::getMemBuffer(obj_file_->getMemBufferRef()) - : nullptr; -} - -static std::string ToString(const llvm::Error &err) { - std::string str; - llvm::raw_string_ostream(str) << err; - return str; -} - -absl::StatusOr> -ExecutionEngine::CreateFromModule( - std::unique_ptr ctx, - std::unique_ptr module, Options options, - absl::Span exported) { - auto engine = std::unique_ptr(new ExecutionEngine( - options.enable_gdb_listener, options.enable_perf_listener)); - - // We'll need module pointer later to lookup object file in the cache. - llvm::Module *module_ptr = module.get(); - - // Set up the target machine details. - if (!options.target_machine) { - return absl::InternalError("target machine was not provided"); - } - module->setDataLayout(options.target_machine->createDataLayout()); - module->setTargetTriple(options.target_machine->getTargetTriple().str()); - - // Run an optimization pipeline over the LLVM module (alway run with default - // opt level independent of the options). - // - // TODO(ezhulenev): We should have out own optimizing transformer pipelines - // for different Xla backends, e.g. there is absolutely no need to run - // SLV vectorizer for Xla Gpi host side executable. - auto transformer = - options.make_optimizing_transformer(options.target_machine.get()); - if (auto err = transformer(module_ptr)) - return absl::InternalError(absl::StrFormat( - "failed to run optimization pipeline: %s", ToString(err))); - - // Callback to create the object layer with a user-provided section memory - // mapper and JIT event listeners. - auto obj_layer_creator = [&](ExecutionSession &session, const Triple &tt) { - auto obj_layer = std::make_unique( - session, [section_memory_mapper = options.section_memory_mapper]() { - return std::make_unique(section_memory_mapper); - }); - - // Register JIT event listeners if they are enabled. - if (engine->gdb_listener_) - obj_layer->registerJITEventListener(*engine->gdb_listener_); - if (engine->perf_listener_) - obj_layer->registerJITEventListener(*engine->perf_listener_); - - return obj_layer; - }; - - // Callback to compile IR module on demand. - auto compile_function_creator = - [weak_target_machine = std::weak_ptr( - options.target_machine)](JITTargetMachineBuilder) - -> Expected> { - return std::make_unique(weak_target_machine); - }; - - // Use in-process executor process control with in-place task dispatcher. - auto executorProcessControl = SelfExecutorProcessControl::Create( - nullptr, std::make_unique()); - - if (auto err = executorProcessControl.takeError()) { - return absl::InternalError(absl::StrFormat( - "failed to create executor process control: %s", ToString(err))); - } - - static auto *lljit_mu = new absl::Mutex(); - std::optional lljit_lock(lljit_mu); - - // Construct the LLJIT with the given compiler and object linking layers. - auto jit = llvm::orc::LLJITBuilder() - .setCompileFunctionCreator(std::move(compile_function_creator)) - .setObjectLinkingLayerCreator(obj_layer_creator) - .setExecutorProcessControl(std::move(*executorProcessControl)) - .setNumCompileThreads(0) // disable multi-threading - .create(); - - if (auto err = jit.takeError()) { - return absl::InternalError( - absl::StrFormat("failed to construct LLJIT: %s", ToString(err))); - } - - lljit_lock.reset(); - - // Register input module with the LLJIT. - ThreadSafeModule tsm(std::move(module), std::move(ctx)); - if (auto err = (*jit)->addIRModule(std::move(tsm))) { - return absl::InternalError( - absl::StrFormat("failed to add source module: %s", ToString(err))); - } - - llvm::DataLayout data_layout = (*jit)->getDataLayout(); - - // Resolve all exported functions to function pointers. - for (absl::string_view name : exported) { - // Trigger compilation by looking up the exported function. - // TODO(tsilytskyi): - // - Do we need to mangle function name? - // - Do we need to verify/adapt function proto to expected API? - Expected addr = (*jit)->lookup(name); - if (auto err = addr.takeError()) { - return absl::InternalError(absl::StrFormat( - "failed to compile exported function %s: %s", name, ToString(err))); - } - - // Check that we found an address of an exported function. - auto ptr = addr->toPtr(); - if (!ptr) { - return absl::InternalError( - absl::StrFormat("exported function %s resolved to null", name)); - } - - engine->exported_.push_back(ptr); - } - - // Fill remaining fields and return constructed ExecutionEngine to the caller. - engine->jit_ = std::move(*jit); - return std::move(engine); -} - -} // namespace internal - -JitHostKernelFunction::JitHostKernelFunction( - std::unique_ptr exec_engine) - : engine_(std::move(exec_engine)) { - kernel_ = reinterpret_cast(engine_->exported(0)); -}; - -static std::function -MakeOptimizingTransformerForJit(llvm::TargetMachine *targetMachine) { - return [targetMachine](llvm::Module *m) -> llvm::Error { - llvm::LoopAnalysisManager lam; - llvm::FunctionAnalysisManager fam; - llvm::CGSCCAnalysisManager cgam; - llvm::ModuleAnalysisManager mam; - - llvm::PipelineTuningOptions tuningOptions; - // LLVM's loop unrolling isn't well tuned for the loops we emit. Turn it off - // as it consumes compile time with little benefit. - tuningOptions.LoopUnrolling = false; - // Vectorization happens at the MLIR level. - tuningOptions.LoopVectorization = false; - llvm::PassBuilder pb(targetMachine, tuningOptions); - - pb.registerModuleAnalyses(mam); - pb.registerCGSCCAnalyses(cgam); - pb.registerFunctionAnalyses(fam); - pb.registerLoopAnalyses(lam); - pb.crossRegisterProxies(lam, fam, cgam, mam); - - llvm::ModulePassManager mpm; - mpm.addPass(pb.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O2)); - mpm.run(*m, mam); - return llvm::Error::success(); - }; -} - -absl::StatusOr> -JitHostKernelFunction::CreateFromLlvmIr(absl::string_view name, - absl::string_view entry, - absl::string_view ir, - absl::Span options) { - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); - auto llvm_ctx = std::make_unique(); - llvm::SMDiagnostic diagnostic; - llvm::MemoryBufferRef ir_buffer(ir, name); - std::unique_ptr llvm_module = - llvm::parseAssembly(ir_buffer, diagnostic, *llvm_ctx, nullptr); - - // Prepare JIT target machine for code generation. - auto builder = llvm::orc::JITTargetMachineBuilder::detectHost(); - if (!builder) return absl::InternalError(toString(builder.takeError())); - - llvm::Expected> target_machine = - builder->createTargetMachine(); - if (!target_machine) - return absl::InternalError(toString(target_machine.takeError())); - - // Set target triple - llvm_module->setTargetTriple( - llvm::StringRef(target_machine.get()->getTargetTriple().getTriple())); - - // Construct options for the XLA runtime execution engine. - internal::ExecutionEngine::Options engine_options; - engine_options.target_machine = std::move(target_machine.get()); - engine_options.make_optimizing_transformer = MakeOptimizingTransformerForJit; - - std::vector exported = {entry}; - - // Compile input module to the native function. - TF_ASSIGN_OR_RETURN(auto engine, - internal::ExecutionEngine::CreateFromModule( - std::move(llvm_ctx), std::move(llvm_module), - std::move(engine_options), exported)); - - return std::unique_ptr( - new JitHostKernelFunction(std::move(engine))); -} - -static void RegisterJitKernelFunctionLoader() { - using CompiledFunction = std::optional< - absl::StatusOr>>; - - HostExecutor::RegisterKernelFunctionLoader( - [](const MultiKernelLoaderSpec &spec) -> CompiledFunction { - if (!spec.has_llvm_host_kernel()) return std::nullopt; - - const LlvmHostKernel &llvm_host_kernel = spec.llvm_host_kernel(); - absl::string_view name = llvm_host_kernel.kernel_name(); - absl::string_view entry = llvm_host_kernel.entrypoint(); - absl::string_view ir = llvm_host_kernel.ir(); - absl::Span options = llvm_host_kernel.options(); - - return JitHostKernelFunction::CreateFromLlvmIr(name, entry, ir, - options); - }); -} - -} // namespace stream_executor::host - -STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER( - jit_kernel_function_loader, - stream_executor::host::RegisterJitKernelFunctionLoader()); diff --git a/xla/stream_executor/host/jit_host_kernel_function.h b/xla/stream_executor/host/jit_host_kernel_function.h deleted file mode 100644 index 73991761db010..0000000000000 --- a/xla/stream_executor/host/jit_host_kernel_function.h +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_STREAM_EXECUTOR_HOST_JIT_HOST_KERNEL_FUNCTION_H_ -#define XLA_STREAM_EXECUTOR_HOST_JIT_HOST_KERNEL_FUNCTION_H_ - -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/stream_executor/host/host_kernel.h" -#include "xla/stream_executor/host/host_kernel_c_api.h" - -namespace stream_executor::host { - -namespace internal { -class ExecutionEngine; -} - -// A host kernel function compiled from LLVM IR at run time -class JitHostKernelFunction : public HostKernel::KernelFunction { - public: - SE_HOST_Kernel *kernel() const override { return kernel_; } - - static absl::StatusOr> - CreateFromLlvmIr(absl::string_view name, absl::string_view entry, - absl::string_view ir, absl::Span options); - - private: - explicit JitHostKernelFunction( - std::unique_ptr exec_engine); - - std::unique_ptr engine_; - SE_HOST_Kernel *kernel_; -}; - -} // namespace stream_executor::host - -#endif // XLA_STREAM_EXECUTOR_HOST_JIT_HOST_KERNEL_FUNCTION_H_ diff --git a/xla/stream_executor/host/ptr_host_kernel_function.cc b/xla/stream_executor/host/ptr_host_kernel_function.cc deleted file mode 100644 index a327494f98715..0000000000000 --- a/xla/stream_executor/host/ptr_host_kernel_function.cc +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/stream_executor/host/ptr_host_kernel_function.h" - -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/stream_executor/host/host_executor.h" -#include "xla/stream_executor/host/host_kernel.h" -#include "xla/stream_executor/host/host_kernel_c_api.h" -#include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/platform/initialize.h" - -namespace stream_executor::host { - -absl::StatusOr> -PtrHostKernelFunction::CreateFromPtr(SE_HOST_Kernel *kernel, - absl::string_view kernel_name) { - return std::unique_ptr( - new PtrHostKernelFunction(kernel)); -} - -static void RegisterPtrKernelFunctionLoader() { - using CompiledFunction = std::optional< - absl::StatusOr>>; - - HostExecutor::RegisterKernelFunctionLoader( - [](const MultiKernelLoaderSpec &spec) -> CompiledFunction { - if (!spec.has_in_process_symbol()) return std::nullopt; - - return PtrHostKernelFunction::CreateFromPtr( - reinterpret_cast( - spec.in_process_symbol().symbol()), - spec.in_process_symbol().kernel_name()); - }); -} - -} // namespace stream_executor::host - -STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER( - ptr_kernel_function_loader, - stream_executor::host::RegisterPtrKernelFunctionLoader()); diff --git a/xla/stream_executor/host/ptr_host_kernel_function.h b/xla/stream_executor/host/ptr_host_kernel_function.h deleted file mode 100644 index 9a8ce110db1db..0000000000000 --- a/xla/stream_executor/host/ptr_host_kernel_function.h +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_STREAM_EXECUTOR_HOST_PTR_HOST_KERNEL_FUNCTION_H_ -#define XLA_STREAM_EXECUTOR_HOST_PTR_HOST_KERNEL_FUNCTION_H_ - -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/stream_executor/host/host_kernel.h" -#include "xla/stream_executor/host/host_kernel_c_api.h" - -namespace stream_executor::host { - -// A host kernel function compiled together with XLA by a regular C++ compiler. -class PtrHostKernelFunction : public HostKernel::KernelFunction { - public: - SE_HOST_Kernel *kernel() const override { return kernel_; } - - static absl::StatusOr> - CreateFromPtr(SE_HOST_Kernel *kernel, absl::string_view kernel_name); - - private: - explicit PtrHostKernelFunction(SE_HOST_Kernel *kernel) : kernel_(kernel) {} - - SE_HOST_Kernel *kernel_; -}; - -} // namespace stream_executor::host - -#endif // XLA_STREAM_EXECUTOR_HOST_PTR_HOST_KERNEL_FUNCTION_H_ diff --git a/xla/xla.proto b/xla/xla.proto index 448cc49c9d9e7..b13f2ca9b5462 100644 --- a/xla/xla.proto +++ b/xla/xla.proto @@ -1052,7 +1052,7 @@ message DebugOptions { AUTOTUNE_CACHE_MODE_READ = 2; } - // Timeouts for RendezvousSingle stuck warning and termination. + // Timeouts for Rendezvous stuck warning and termination. int32 xla_gpu_executable_warn_stuck_timeout_seconds = 327; int32 xla_gpu_executable_terminate_timeout_seconds = 328;