From db1c1bea9f183af4ca07da760d203ebb932afe4a Mon Sep 17 00:00:00 2001 From: Kyle Lucke Date: Mon, 22 Apr 2024 08:58:22 -0700 Subject: [PATCH] Remove test use of StreamExecutor::GetAllocator. PiperOrigin-RevId: 627053448 --- xla/service/generic_transfer_manager_test.cc | 10 +++- xla/service/gpu/fusions/cudnn_test.cc | 12 ++-- .../runtime/address_computation_thunk_test.cc | 33 ++++++----- .../gpu/runtime/command_buffer_cmd_test.cc | 29 ++++++---- .../gpu/runtime/command_buffer_thunk_test.cc | 55 ++++++++++++------- xla/service/gpu/tests/BUILD | 1 + xla/service/gpu/tests/gemm_rewrite_test.cc | 19 +++++-- .../gpu/tests/gpu_too_many_blocks_test.cc | 8 ++- 8 files changed, 104 insertions(+), 63 deletions(-) diff --git a/xla/service/generic_transfer_manager_test.cc b/xla/service/generic_transfer_manager_test.cc index 05eda50a0cdcd..d0235816488e6 100644 --- a/xla/service/generic_transfer_manager_test.cc +++ b/xla/service/generic_transfer_manager_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" @@ -62,18 +63,21 @@ class GenericTransferManagerTest : public ::testing::Test { se::PlatformManager::PlatformWithId(se::host::kHostPlatformId)); TF_ASSERT_OK_AND_ASSIGN(stream_executor_, platform->ExecutorForDevice(0)); TF_ASSERT_OK_AND_ASSIGN(stream_, stream_executor_->CreateStream()); + allocator_ = + std::make_unique(stream_executor_); } ScopedShapedBuffer AllocateBuffer(const Shape& shape) { - auto buffer = transfer_manager_.AllocateScopedShapedBuffer( - shape, stream_executor_->GetAllocator(), - /*device_ordinal=*/0); + auto buffer = + transfer_manager_.AllocateScopedShapedBuffer(shape, allocator_.get(), + /*device_ordinal=*/0); return std::move(buffer.value()); } PackingTransferManager transfer_manager_; se::StreamExecutor* stream_executor_; std::unique_ptr stream_; + std::unique_ptr allocator_; }; TEST_F(GenericTransferManagerTest, TransferLiteralToDevice) { diff --git a/xla/service/gpu/fusions/cudnn_test.cc b/xla/service/gpu/fusions/cudnn_test.cc index ca7d2c2b8d59b..05deee92233df 100644 --- a/xla/service/gpu/fusions/cudnn_test.cc +++ b/xla/service/gpu/fusions/cudnn_test.cc @@ -335,13 +335,13 @@ ENTRY e { backend_config={"fusion_backend_config":{"kind":"__cudnn$fusion","cudnn_fusion_config":{"plan_id":"0"}}} })"; + se::StreamExecutorMemoryAllocator allocator( + backend().default_stream_executor()); // Verify that a command buffer is applied. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr executable, - backend().compiler()->RunBackend( - GetOptimizedModule(kHloText).value(), - backend().default_stream_executor(), - backend().default_stream_executor()->GetAllocator())); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, + backend().compiler()->RunBackend( + GetOptimizedModule(kHloText).value(), + backend().default_stream_executor(), &allocator)); absl::StatusOr filecheck_result = RunFileCheck(executable->module().ToString(), R"( ; CHECK: ENTRY diff --git a/xla/service/gpu/runtime/address_computation_thunk_test.cc b/xla/service/gpu/runtime/address_computation_thunk_test.cc index 887f23d8e5575..66028a9a06aea 100644 --- a/xla/service/gpu/runtime/address_computation_thunk_test.cc +++ b/xla/service/gpu/runtime/address_computation_thunk_test.cc @@ -186,9 +186,9 @@ TEST(AddressComputationThunkTest, SlicedGemm) { // Preparing parameters for thunk execution. ServiceExecutableRunOptions run_options; + se::StreamExecutorMemoryAllocator allocator(executor); BufferAllocations allocations( - {lhs, rhs, out, workspace, lhs_offset_0, lhs_offset_1}, 0, - executor->GetAllocator()); + {lhs, rhs, out, workspace, lhs_offset_0, lhs_offset_1}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -358,9 +358,10 @@ TEST(AddressComputationThunkTest, SlicedNonContiguousGemm) { // Preparing parameters for thunk execution. ServiceExecutableRunOptions run_options; + se::StreamExecutorMemoryAllocator allocator(executor); BufferAllocations allocations({lhs, rhs, out, workspace, lhs_offset_0, lhs_offset_1, rhs_offset_0, rhs_offset_1}, - 0, executor->GetAllocator()); + 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -527,9 +528,10 @@ TEST(AddressComputationThunkTest, MulipleSlicedOperandsGemm) { // Preparing parameters for thunk execution. ServiceExecutableRunOptions run_options; + se::StreamExecutorMemoryAllocator allocator(executor); BufferAllocations allocations({lhs, rhs, out, workspace, lhs_offset_0, lhs_offset_1, rhs_offset_0, rhs_offset_1}, - 0, executor->GetAllocator()); + 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -673,9 +675,9 @@ TEST(AddressComputationThunkTest, SlicedMemcpy) { // Preparing parameters for thunk execution. ServiceExecutableRunOptions run_options; + se::StreamExecutorMemoryAllocator allocator(executor); BufferAllocations allocations( - {src, dst, offset_0, offset_1, offset_2, offset_3}, 0, - executor->GetAllocator()); + {src, dst, offset_0, offset_1, offset_2, offset_3}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -861,10 +863,11 @@ TEST(AddressComputationThunkTest, SlicedOutputMemcpy) { // Preparing parameters for thunk execution. ServiceExecutableRunOptions run_options; + se::StreamExecutorMemoryAllocator allocator(executor); BufferAllocations allocations( {src, dst, src_offset_0, src_offset_1, src_offset_2, src_offset_3, dst_offset_0, dst_offset_1, dst_offset_2, dst_offset_3}, - 0, executor->GetAllocator()); + 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -1024,9 +1027,9 @@ TEST(AddressComputationThunkTest, SlicedGemmArbitraryArgumentOrder) { // Preparing parameters for thunk execution. ServiceExecutableRunOptions run_options; + se::StreamExecutorMemoryAllocator allocator(executor); BufferAllocations allocations( - {workspace, lhs, out, rhs, lhs_offset_0, lhs_offset_1}, 0, - executor->GetAllocator()); + {workspace, lhs, out, rhs, lhs_offset_0, lhs_offset_1}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -1174,10 +1177,11 @@ TEST(AddressComputationThunkTest, SlicedGemmArbitraryNumberOfArguments) { // Preparing parameters for thunk execution. ServiceExecutableRunOptions run_options; + se::StreamExecutorMemoryAllocator allocator(executor); BufferAllocations allocations( {workspace, /*garbage, to be ignored*/ se::DeviceMemoryBase(), out, rhs, lhs_offset_0, lhs_offset_1, /*garbage, to be ignored*/ rhs, lhs}, - 0, executor->GetAllocator()); + 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -1323,9 +1327,10 @@ TEST(AddressComputationThunkTest, SlicedTupledOperandGemm) { // Preparing parameters for thunk execution. ServiceExecutableRunOptions run_options; + se::StreamExecutorMemoryAllocator allocator(executor); BufferAllocations allocations( {lhs_whole_buffer, rhs, out, workspace, lhs_offset_0, lhs_offset_1}, 0, - executor->GetAllocator()); + &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -1506,10 +1511,11 @@ TEST(AddressComputationThunkTest, SlicedMemcpyOOB) { // Preparing parameters for thunk execution. ServiceExecutableRunOptions run_options; + se::StreamExecutorMemoryAllocator allocator(executor); BufferAllocations allocations( {src, dst, src_offset_0, src_offset_1, src_offset_2, src_offset_3, dst_offset_0, dst_offset_1, dst_offset_2, dst_offset_3}, - 0, executor->GetAllocator()); + 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -1675,8 +1681,9 @@ TEST(AddressComputationThunkTest, SlicedOperandsSameBufferGemm) { // Preparing parameters for thunk execution. ServiceExecutableRunOptions run_options; + se::StreamExecutorMemoryAllocator allocator(executor); BufferAllocations allocations({buffer, workspace, lhs_offset_0, lhs_offset_1}, - 0, executor->GetAllocator()); + 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), diff --git a/xla/service/gpu/runtime/command_buffer_cmd_test.cc b/xla/service/gpu/runtime/command_buffer_cmd_test.cc index ff689d8441b2c..6c0cb7a674f22 100644 --- a/xla/service/gpu/runtime/command_buffer_cmd_test.cc +++ b/xla/service/gpu/runtime/command_buffer_cmd_test.cc @@ -204,7 +204,8 @@ TEST(CommandBufferCmdTest, MemcpyCmd) { commands.Emplace(s0, slice_b, slice_a, byte_length); ServiceExecutableRunOptions run_options; - BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({a, b}, 0, &allocator); CommandBufferCmd::StateManager state; @@ -272,7 +273,8 @@ TEST(CommandBufferCmdTest, BarrierCmd) { commands.Emplace(s1, slice_e, slice_d, byte_length); ServiceExecutableRunOptions run_options; - BufferAllocations allocations({a, b, c, d, e}, 0, executor->GetAllocator()); + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({a, b, c, d, e}, 0, &allocator); CommandBufferCmd::StateManager state; @@ -341,7 +343,8 @@ TEST(CommandBufferCmdTest, LaunchCmd) { TF_ASSERT_OK(commands.Initialize({executor, source}, state)); ServiceExecutableRunOptions run_options; - BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({a, b}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -400,7 +403,8 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) { se::DeviceMemoryBase mem0(reinterpret_cast(0x01234567)); se::DeviceMemoryBase mem1(reinterpret_cast(0x12345670)); - BufferAllocations allocations({mem0, mem1}, 0, executor->GetAllocator()); + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({mem0, mem1}, 0, &allocator); // No-op trace callback to count how many times it was called. int64_t num_calls = 0; @@ -423,7 +427,7 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) { // Check that when memory address changes we re-trace the command buffer. se::DeviceMemoryBase mem2(reinterpret_cast(0x23456701)); - allocations = BufferAllocations({mem0, mem2}, 0, executor->GetAllocator()); + allocations = BufferAllocations({mem0, mem2}, 0, &allocator); TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer2, traced_cmd_buffer.GetOrTraceCommandBuffer( @@ -433,7 +437,7 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) { EXPECT_EQ(num_calls, 2); // Check that we keep first command buffer in cache. - allocations = BufferAllocations({mem0, mem1}, 0, executor->GetAllocator()); + allocations = BufferAllocations({mem0, mem1}, 0, &allocator); TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer3, traced_cmd_buffer.GetOrTraceCommandBuffer( @@ -442,7 +446,7 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) { EXPECT_EQ(num_calls, 2); // Check that we trace a new graph when buffer allocation pattern is new. - allocations = BufferAllocations({mem0, mem0}, 0, executor->GetAllocator()); + allocations = BufferAllocations({mem0, mem0}, 0, &allocator); TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer4, traced_cmd_buffer.GetOrTraceCommandBuffer( @@ -452,7 +456,7 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) { EXPECT_EQ(num_calls, 3); // Check that we still keep the previous graph in cache. - allocations = BufferAllocations({mem0, mem1}, 0, executor->GetAllocator()); + allocations = BufferAllocations({mem0, mem1}, 0, &allocator); TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer5, traced_cmd_buffer.GetOrTraceCommandBuffer( @@ -479,12 +483,13 @@ static void BM_GetOrTraceCommandBuffer(benchmark::State& state) { se::DeviceMemoryBase mem0(reinterpret_cast(0x01234567)); se::DeviceMemoryBase mem1(reinterpret_cast(0x12345670)); + se::StreamExecutorMemoryAllocator allocator(executor); std::array allocations = { - BufferAllocations({mem0, mem1}, 0, executor->GetAllocator()), - BufferAllocations({mem1, mem0}, 0, executor->GetAllocator()), - BufferAllocations({mem0, mem0}, 0, executor->GetAllocator()), - BufferAllocations({mem1, mem1}, 0, executor->GetAllocator()), + BufferAllocations({mem0, mem1}, 0, &allocator), + BufferAllocations({mem1, mem0}, 0, &allocator), + BufferAllocations({mem0, mem0}, 0, &allocator), + BufferAllocations({mem1, mem1}, 0, &allocator), }; int32_t index = 0; diff --git a/xla/service/gpu/runtime/command_buffer_thunk_test.cc b/xla/service/gpu/runtime/command_buffer_thunk_test.cc index 4a78b621b3dd4..a9fb8e55fdb5e 100644 --- a/xla/service/gpu/runtime/command_buffer_thunk_test.cc +++ b/xla/service/gpu/runtime/command_buffer_thunk_test.cc @@ -134,8 +134,9 @@ TEST(CommandBufferThunkTest, MemcpyCmd) { // Construct a thunk with command sequence. CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); + se::StreamExecutorMemoryAllocator allocator(executor); ServiceExecutableRunOptions run_options; - BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); + BufferAllocations allocations({a, b}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -189,7 +190,8 @@ TEST(CommandBufferThunkTest, MemzeroCmd) { CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); ServiceExecutableRunOptions run_options; - BufferAllocations allocations({a}, 0, executor->GetAllocator()); + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({a}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -231,7 +233,8 @@ TEST(CommandBufferThunkTest, Memset32Cmd) { CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); ServiceExecutableRunOptions run_options; - BufferAllocations allocations({a}, 0, executor->GetAllocator()); + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({a}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -270,7 +273,8 @@ TEST(CommandBufferThunkTest, Memset32CmdOnDifferentStreams) { CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); ServiceExecutableRunOptions run_options; - BufferAllocations allocations({a}, 0, executor->GetAllocator()); + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({a}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -332,7 +336,8 @@ TEST(CommandBufferThunkTest, MemallocFreeCmdSameThunk) { auto external_allocation = std::make_unique(); - BufferAllocations allocations({a, b, c}, 0, executor->GetAllocator(), + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({a, b, c}, 0, &allocator, external_allocation.get()); ServiceExecutableRunOptions run_options; @@ -395,7 +400,8 @@ TEST(CommandBufferThunkTest, MemallocFreeCmdAcrossThunk) { auto external_allocation = std::make_unique(); - BufferAllocations allocations({a, b, c}, 0, executor->GetAllocator(), + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({a, b, c}, 0, &allocator, external_allocation.get()); ServiceExecutableRunOptions run_options; @@ -462,7 +468,8 @@ TEST(CommandBufferThunkTest, LaunchCmd) { CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); ServiceExecutableRunOptions run_options; - BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({a, b}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -487,7 +494,7 @@ TEST(CommandBufferThunkTest, LaunchCmd) { TF_ASSERT_OK(stream->MemZero(&c, byte_length)); // Update buffer allocation #1 to buffer `c`. - allocations = BufferAllocations({a, c}, 0, executor->GetAllocator()); + allocations = BufferAllocations({a, c}, 0, &allocator); // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -558,7 +565,8 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); ServiceExecutableRunOptions run_options; - BufferAllocations allocations({a, b}, 0, executor->GetAllocator()); + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({a, b}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -583,7 +591,7 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { TF_ASSERT_OK(stream->MemZero(&c, byte_length)); // Update buffer allocation #1 to buffer `c`. - allocations = BufferAllocations({a, c}, 0, executor->GetAllocator()); + allocations = BufferAllocations({a, c}, 0, &allocator); // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -673,8 +681,8 @@ TEST(CommandBufferThunkTest, GemmCmd) { CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); ServiceExecutableRunOptions run_options; - BufferAllocations allocations({lhs, rhs, out, workspace}, 0, - executor->GetAllocator()); + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({lhs, rhs, out, workspace}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -699,8 +707,8 @@ TEST(CommandBufferThunkTest, GemmCmd) { TF_ASSERT_OK(stream->MemZero(&updated_out, out_length)); // Update buffer allocation to updated `out` buffer. - allocations = BufferAllocations({lhs, rhs, updated_out, workspace}, 0, - executor->GetAllocator()); + allocations = + BufferAllocations({lhs, rhs, updated_out, workspace}, 0, &allocator); // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -774,7 +782,8 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); ServiceExecutableRunOptions run_options; - BufferAllocations allocations({a, b, c, d}, 0, executor->GetAllocator()); + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({a, b, c, d}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -806,7 +815,7 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { TF_ASSERT_OK(stream->MemZero(&e, byte_length)); // Update buffer allocation #1 to buffer `c`. - allocations = BufferAllocations({a, b, c, e}, 0, executor->GetAllocator()); + allocations = BufferAllocations({a, b, c, e}, 0, &allocator); // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -889,7 +898,8 @@ TEST(CommandBufferThunkTest, IfCmd) { CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); ServiceExecutableRunOptions run_options; - BufferAllocations allocations({pred, a, b}, 0, executor->GetAllocator()); + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({pred, a, b}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -914,7 +924,7 @@ TEST(CommandBufferThunkTest, IfCmd) { TF_ASSERT_OK(stream->MemZero(&c, byte_length)); // Update buffer allocation #2 to buffer `c`. - allocations = BufferAllocations({pred, a, c}, 0, executor->GetAllocator()); + allocations = BufferAllocations({pred, a, c}, 0, &allocator); // Thunk execution should automatically update underlying command buffer. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -988,7 +998,8 @@ TEST(CommandBufferThunkTest, IfElseCmd) { CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); ServiceExecutableRunOptions run_options; - BufferAllocations allocations({pred, a, b}, 0, executor->GetAllocator()); + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({pred, a, b}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -1077,7 +1088,8 @@ TEST(CommandBufferThunkTest, CaseCmd) { CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); ServiceExecutableRunOptions run_options; - BufferAllocations allocations({index, a, b}, 0, executor->GetAllocator()); + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({index, a, b}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), @@ -1156,7 +1168,8 @@ TEST(CommandBufferThunkTest, ForCmd) { CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo(nullptr)); ServiceExecutableRunOptions run_options; - BufferAllocations allocations({loop_cnt, a, b}, 0, executor->GetAllocator()); + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({loop_cnt, a, b}, 0, &allocator); Thunk::ExecuteParams params = Thunk::ExecuteParams::Create(run_options, allocations, stream.get(), diff --git a/xla/service/gpu/tests/BUILD b/xla/service/gpu/tests/BUILD index 79b6f1e9f14fa..67f4e9479535b 100644 --- a/xla/service/gpu/tests/BUILD +++ b/xla/service/gpu/tests/BUILD @@ -183,6 +183,7 @@ xla_test( "//xla/service/gpu:gpu_executable", "//xla/service/gpu:variant_visitor", "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory_allocator", "//xla/tests:filecheck", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:any_invocable", diff --git a/xla/service/gpu/tests/gemm_rewrite_test.cc b/xla/service/gpu/tests/gemm_rewrite_test.cc index 69e956f62deac..8a8e58923f264 100644 --- a/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -42,6 +42,7 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/test.h" #include "xla/tests/filecheck.h" #include "xla/xla.pb.h" @@ -230,11 +231,12 @@ ENTRY AddDotsFunc { return ParseAndReturnVerifiedModule(hlo_text, config); }; + se::StreamExecutorMemoryAllocator allocator( + backend().default_stream_executor()); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr optimized_module, backend().compiler()->RunHloPasses( - *get_module(), backend().default_stream_executor(), - backend().default_stream_executor()->GetAllocator())); + *get_module(), backend().default_stream_executor(), &allocator)); absl::StatusOr filecheck_result = RunFileCheck(optimized_module->ToString(), @@ -7632,11 +7634,15 @@ class GemmRewriteAllocationTest : public GpuCodegenTest { int expected_number_of_allocations) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, GetOptimizedModule(hlo)); + if (allocator_ == nullptr) { + allocator_ = std::make_unique( + backend().default_stream_executor()); + } TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr executable, - backend().compiler()->RunBackend( - std::move(optimized_module), backend().default_stream_executor(), - backend().default_stream_executor()->GetAllocator())); + backend().compiler()->RunBackend(std::move(optimized_module), + backend().default_stream_executor(), + allocator_.get())); GpuExecutable* gpu_executable = static_cast(executable.get()); absl::Span allocations = @@ -7650,6 +7656,9 @@ class GemmRewriteAllocationTest : public GpuCodegenTest { debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); return debug_options; } + + private: + std::unique_ptr allocator_; }; TEST_F(GemmRewriteAllocationTest, SharedBufferAssignment) { diff --git a/xla/service/gpu/tests/gpu_too_many_blocks_test.cc b/xla/service/gpu/tests/gpu_too_many_blocks_test.cc index 87b2c79723268..9832fb8895a13 100644 --- a/xla/service/gpu/tests/gpu_too_many_blocks_test.cc +++ b/xla/service/gpu/tests/gpu_too_many_blocks_test.cc @@ -51,10 +51,12 @@ ENTRY primitive_computation_mul.8 { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, GetOptimizedModule(hlo_text)); + se::StreamExecutorMemoryAllocator allocator( + backend().default_stream_executor()); absl::StatusOr> failed_executable = - backend().compiler()->RunBackend( - std::move(optimized_module), backend().default_stream_executor(), - backend().default_stream_executor()->GetAllocator()); + backend().compiler()->RunBackend(std::move(optimized_module), + backend().default_stream_executor(), + &allocator); EXPECT_FALSE(failed_executable.ok()); EXPECT_THAT(