Skip to content

Commit

Permalink
Remove test use of StreamExecutor::GetAllocator.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 627053448
  • Loading branch information
klucke authored and copybara-github committed Apr 22, 2024
1 parent 2db109a commit db1c1be
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 63 deletions.
10 changes: 7 additions & 3 deletions xla/service/generic_transfer_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_tree.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/host/host_platform_id.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/stream_executor/stream_executor.h"
Expand Down Expand Up @@ -62,18 +63,21 @@ class GenericTransferManagerTest : public ::testing::Test {
se::PlatformManager::PlatformWithId(se::host::kHostPlatformId));
TF_ASSERT_OK_AND_ASSIGN(stream_executor_, platform->ExecutorForDevice(0));
TF_ASSERT_OK_AND_ASSIGN(stream_, stream_executor_->CreateStream());
allocator_ =
std::make_unique<se::StreamExecutorMemoryAllocator>(stream_executor_);
}

// Test helper: allocates a scoped shaped buffer for `shape` on device
// ordinal 0 using the fixture-owned `allocator_`, rather than
// `stream_executor_->GetAllocator()`.
//
// The result of AllocateScopedShapedBuffer is unwrapped with value() without
// inspecting its status first.
// NOTE(review): presumably a failed allocation terminates the test here --
// confirm against the status type's value() contract.
ScopedShapedBuffer AllocateBuffer(const Shape& shape) {
  auto buffer =
      transfer_manager_.AllocateScopedShapedBuffer(shape, allocator_.get(),
                                                   /*device_ordinal=*/0);
  // Move the buffer out of the wrapper instead of copying it.
  return std::move(buffer.value());
}

PackingTransferManager transfer_manager_;
se::StreamExecutor* stream_executor_;
std::unique_ptr<se::Stream> stream_;
std::unique_ptr<se::DeviceMemoryAllocator> allocator_;
};

TEST_F(GenericTransferManagerTest, TransferLiteralToDevice) {
Expand Down
12 changes: 6 additions & 6 deletions xla/service/gpu/fusions/cudnn_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,13 +335,13 @@ ENTRY e {
backend_config={"fusion_backend_config":{"kind":"__cudnn$fusion","cudnn_fusion_config":{"plan_id":"0"}}}
})";

se::StreamExecutorMemoryAllocator allocator(
backend().default_stream_executor());
// Verify that a command buffer is applied.
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Executable> executable,
backend().compiler()->RunBackend(
GetOptimizedModule(kHloText).value(),
backend().default_stream_executor(),
backend().default_stream_executor()->GetAllocator()));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Executable> executable,
backend().compiler()->RunBackend(
GetOptimizedModule(kHloText).value(),
backend().default_stream_executor(), &allocator));
absl::StatusOr<bool> filecheck_result =
RunFileCheck(executable->module().ToString(), R"(
; CHECK: ENTRY
Expand Down
33 changes: 20 additions & 13 deletions xla/service/gpu/runtime/address_computation_thunk_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,9 @@ TEST(AddressComputationThunkTest, SlicedGemm) {

// Preparing parameters for thunk execution.
ServiceExecutableRunOptions run_options;
se::StreamExecutorMemoryAllocator allocator(executor);
BufferAllocations allocations(
{lhs, rhs, out, workspace, lhs_offset_0, lhs_offset_1}, 0,
executor->GetAllocator());
{lhs, rhs, out, workspace, lhs_offset_0, lhs_offset_1}, 0, &allocator);

Thunk::ExecuteParams params =
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
Expand Down Expand Up @@ -358,9 +358,10 @@ TEST(AddressComputationThunkTest, SlicedNonContiguousGemm) {

// Preparing parameters for thunk execution.
ServiceExecutableRunOptions run_options;
se::StreamExecutorMemoryAllocator allocator(executor);
BufferAllocations allocations({lhs, rhs, out, workspace, lhs_offset_0,
lhs_offset_1, rhs_offset_0, rhs_offset_1},
0, executor->GetAllocator());
0, &allocator);

Thunk::ExecuteParams params =
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
Expand Down Expand Up @@ -527,9 +528,10 @@ TEST(AddressComputationThunkTest, MulipleSlicedOperandsGemm) {

// Preparing parameters for thunk execution.
ServiceExecutableRunOptions run_options;
se::StreamExecutorMemoryAllocator allocator(executor);
BufferAllocations allocations({lhs, rhs, out, workspace, lhs_offset_0,
lhs_offset_1, rhs_offset_0, rhs_offset_1},
0, executor->GetAllocator());
0, &allocator);

Thunk::ExecuteParams params =
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
Expand Down Expand Up @@ -673,9 +675,9 @@ TEST(AddressComputationThunkTest, SlicedMemcpy) {

// Preparing parameters for thunk execution.
ServiceExecutableRunOptions run_options;
se::StreamExecutorMemoryAllocator allocator(executor);
BufferAllocations allocations(
{src, dst, offset_0, offset_1, offset_2, offset_3}, 0,
executor->GetAllocator());
{src, dst, offset_0, offset_1, offset_2, offset_3}, 0, &allocator);

Thunk::ExecuteParams params =
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
Expand Down Expand Up @@ -861,10 +863,11 @@ TEST(AddressComputationThunkTest, SlicedOutputMemcpy) {

// Preparing parameters for thunk execution.
ServiceExecutableRunOptions run_options;
se::StreamExecutorMemoryAllocator allocator(executor);
BufferAllocations allocations(
{src, dst, src_offset_0, src_offset_1, src_offset_2, src_offset_3,
dst_offset_0, dst_offset_1, dst_offset_2, dst_offset_3},
0, executor->GetAllocator());
0, &allocator);

Thunk::ExecuteParams params =
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
Expand Down Expand Up @@ -1024,9 +1027,9 @@ TEST(AddressComputationThunkTest, SlicedGemmArbitraryArgumentOrder) {

// Preparing parameters for thunk execution.
ServiceExecutableRunOptions run_options;
se::StreamExecutorMemoryAllocator allocator(executor);
BufferAllocations allocations(
{workspace, lhs, out, rhs, lhs_offset_0, lhs_offset_1}, 0,
executor->GetAllocator());
{workspace, lhs, out, rhs, lhs_offset_0, lhs_offset_1}, 0, &allocator);

Thunk::ExecuteParams params =
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
Expand Down Expand Up @@ -1174,10 +1177,11 @@ TEST(AddressComputationThunkTest, SlicedGemmArbitraryNumberOfArguments) {

// Preparing parameters for thunk execution.
ServiceExecutableRunOptions run_options;
se::StreamExecutorMemoryAllocator allocator(executor);
BufferAllocations allocations(
{workspace, /*garbage, to be ignored*/ se::DeviceMemoryBase(), out, rhs,
lhs_offset_0, lhs_offset_1, /*garbage, to be ignored*/ rhs, lhs},
0, executor->GetAllocator());
0, &allocator);

Thunk::ExecuteParams params =
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
Expand Down Expand Up @@ -1323,9 +1327,10 @@ TEST(AddressComputationThunkTest, SlicedTupledOperandGemm) {

// Preparing parameters for thunk execution.
ServiceExecutableRunOptions run_options;
se::StreamExecutorMemoryAllocator allocator(executor);
BufferAllocations allocations(
{lhs_whole_buffer, rhs, out, workspace, lhs_offset_0, lhs_offset_1}, 0,
executor->GetAllocator());
&allocator);

Thunk::ExecuteParams params =
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
Expand Down Expand Up @@ -1506,10 +1511,11 @@ TEST(AddressComputationThunkTest, SlicedMemcpyOOB) {

// Preparing parameters for thunk execution.
ServiceExecutableRunOptions run_options;
se::StreamExecutorMemoryAllocator allocator(executor);
BufferAllocations allocations(
{src, dst, src_offset_0, src_offset_1, src_offset_2, src_offset_3,
dst_offset_0, dst_offset_1, dst_offset_2, dst_offset_3},
0, executor->GetAllocator());
0, &allocator);

Thunk::ExecuteParams params =
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
Expand Down Expand Up @@ -1675,8 +1681,9 @@ TEST(AddressComputationThunkTest, SlicedOperandsSameBufferGemm) {

// Preparing parameters for thunk execution.
ServiceExecutableRunOptions run_options;
se::StreamExecutorMemoryAllocator allocator(executor);
BufferAllocations allocations({buffer, workspace, lhs_offset_0, lhs_offset_1},
0, executor->GetAllocator());
0, &allocator);

Thunk::ExecuteParams params =
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
Expand Down
29 changes: 17 additions & 12 deletions xla/service/gpu/runtime/command_buffer_cmd_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ TEST(CommandBufferCmdTest, MemcpyCmd) {
commands.Emplace<MemcpyDeviceToDeviceCmd>(s0, slice_b, slice_a, byte_length);

ServiceExecutableRunOptions run_options;
BufferAllocations allocations({a, b}, 0, executor->GetAllocator());
se::StreamExecutorMemoryAllocator allocator(executor);
BufferAllocations allocations({a, b}, 0, &allocator);

CommandBufferCmd::StateManager state;

Expand Down Expand Up @@ -272,7 +273,8 @@ TEST(CommandBufferCmdTest, BarrierCmd) {
commands.Emplace<MemcpyDeviceToDeviceCmd>(s1, slice_e, slice_d, byte_length);

ServiceExecutableRunOptions run_options;
BufferAllocations allocations({a, b, c, d, e}, 0, executor->GetAllocator());
se::StreamExecutorMemoryAllocator allocator(executor);
BufferAllocations allocations({a, b, c, d, e}, 0, &allocator);

CommandBufferCmd::StateManager state;

Expand Down Expand Up @@ -341,7 +343,8 @@ TEST(CommandBufferCmdTest, LaunchCmd) {
TF_ASSERT_OK(commands.Initialize({executor, source}, state));

ServiceExecutableRunOptions run_options;
BufferAllocations allocations({a, b}, 0, executor->GetAllocator());
se::StreamExecutorMemoryAllocator allocator(executor);
BufferAllocations allocations({a, b}, 0, &allocator);

Thunk::ExecuteParams params =
Thunk::ExecuteParams::Create(run_options, allocations, stream.get(),
Expand Down Expand Up @@ -400,7 +403,8 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) {
se::DeviceMemoryBase mem0(reinterpret_cast<void*>(0x01234567));
se::DeviceMemoryBase mem1(reinterpret_cast<void*>(0x12345670));

BufferAllocations allocations({mem0, mem1}, 0, executor->GetAllocator());
se::StreamExecutorMemoryAllocator allocator(executor);
BufferAllocations allocations({mem0, mem1}, 0, &allocator);

// No-op trace callback to count how many times it was called.
int64_t num_calls = 0;
Expand All @@ -423,7 +427,7 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) {

// Check that when memory address changes we re-trace the command buffer.
se::DeviceMemoryBase mem2(reinterpret_cast<void*>(0x23456701));
allocations = BufferAllocations({mem0, mem2}, 0, executor->GetAllocator());
allocations = BufferAllocations({mem0, mem2}, 0, &allocator);

TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer2,
traced_cmd_buffer.GetOrTraceCommandBuffer(
Expand All @@ -433,7 +437,7 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) {
EXPECT_EQ(num_calls, 2);

// Check that we keep first command buffer in cache.
allocations = BufferAllocations({mem0, mem1}, 0, executor->GetAllocator());
allocations = BufferAllocations({mem0, mem1}, 0, &allocator);

TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer3,
traced_cmd_buffer.GetOrTraceCommandBuffer(
Expand All @@ -442,7 +446,7 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) {
EXPECT_EQ(num_calls, 2);

// Check that we trace a new graph when buffer allocation pattern is new.
allocations = BufferAllocations({mem0, mem0}, 0, executor->GetAllocator());
allocations = BufferAllocations({mem0, mem0}, 0, &allocator);

TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer4,
traced_cmd_buffer.GetOrTraceCommandBuffer(
Expand All @@ -452,7 +456,7 @@ TEST(TracedCommandBuffer, GetOrUpdateCommandBuffer) {
EXPECT_EQ(num_calls, 3);

// Check that we still keep the previous graph in cache.
allocations = BufferAllocations({mem0, mem1}, 0, executor->GetAllocator());
allocations = BufferAllocations({mem0, mem1}, 0, &allocator);

TF_ASSERT_OK_AND_ASSIGN(auto* command_buffer5,
traced_cmd_buffer.GetOrTraceCommandBuffer(
Expand All @@ -479,12 +483,13 @@ static void BM_GetOrTraceCommandBuffer(benchmark::State& state) {

se::DeviceMemoryBase mem0(reinterpret_cast<void*>(0x01234567));
se::DeviceMemoryBase mem1(reinterpret_cast<void*>(0x12345670));
se::StreamExecutorMemoryAllocator allocator(executor);

std::array<BufferAllocations, 4> allocations = {
BufferAllocations({mem0, mem1}, 0, executor->GetAllocator()),
BufferAllocations({mem1, mem0}, 0, executor->GetAllocator()),
BufferAllocations({mem0, mem0}, 0, executor->GetAllocator()),
BufferAllocations({mem1, mem1}, 0, executor->GetAllocator()),
BufferAllocations({mem0, mem1}, 0, &allocator),
BufferAllocations({mem1, mem0}, 0, &allocator),
BufferAllocations({mem0, mem0}, 0, &allocator),
BufferAllocations({mem1, mem1}, 0, &allocator),
};

int32_t index = 0;
Expand Down
Loading

0 comments on commit db1c1be

Please sign in to comment.