Commit: Rebase

shraiysh committed Jan 28, 2025
1 parent 9565742 commit 9a0b89b
Showing 6 changed files with 88 additions and 106 deletions.
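In short, this change folds the three trailing DynamicSliceThunk constructor parameters (temp_modules, indvar_init, indvar_update) into a single optional OffsetAsFunctionOfIndvarModulesMetadata argument that defaults to std::nullopt. The call-site pattern, distilled from the three emitters updated in custom.cc below (variable names are theirs; surrounding code is omitted), looks roughly like this:

// Sketch distilled from the diff below, not a verbatim excerpt. Emitters that
// can evaluate the induction variable on the host populate the optional
// metadata; all other callers leave it at std::nullopt (or omit the argument,
// since it is defaulted).
std::optional<DynamicSliceThunk::OffsetAsFunctionOfIndvarModulesMetadata>
    offset_modules_metadata = std::nullopt;
if (can_compute_indvar_on_host) {
  offset_modules_metadata =
      DynamicSliceThunk::OffsetAsFunctionOfIndvarModulesMetadata{
          /*indvar_init=*/std::move(init_module),
          /*indvar_update=*/std::move(update_module),
          /*extracted_offset_modules=*/std::move(extracted_offset_modules)};
}
thunk = std::make_unique<DynamicSliceThunk>(
    thunk_info, std::make_unique<ThunkSequence>(std::move(seq)),
    std::move(arguments), std::move(fake_allocations),
    std::move(offset_buffer_indices), std::move(orig_shapes),
    std::move(sliced_shapes), std::move(offset_byte_sizes),
    std::move(offset_modules_metadata));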
39 changes: 32 additions & 7 deletions xla/backends/gpu/codegen/custom.cc
@@ -662,13 +662,21 @@ absl::StatusOr<FusionEmissionResult> EmitGemm(
std::vector<std::optional<BufferAllocation::Slice>> arguments{
lhs_slice, rhs_slice, output, workspace};

+    std::optional<DynamicSliceThunk::OffsetAsFunctionOfIndvarModulesMetadata>
+        offset_modules_metadata = std::nullopt;
+    if (can_compute_indvar_on_host) {
+      offset_modules_metadata =
+          DynamicSliceThunk::OffsetAsFunctionOfIndvarModulesMetadata{
+              /*indvar_init=*/std::move(init_module),
+              /*indvar_update=*/std::move(update_module),
+              /*extracted_offset_modules=*/std::move(extracted_offset_modules)};
+    }
thunk = std::make_unique<DynamicSliceThunk>(
thunk_info, std::make_unique<ThunkSequence>(std::move(seq)),
std::move(arguments), std::move(fake_allocations),
std::move(offset_buffer_indices), std::move(orig_shapes),
std::move(sliced_shapes), std::move(offset_byte_sizes),
-        std::move(extracted_offset_modules), std::move(init_module),
-        std::move(update_module));
+        std::move(offset_modules_metadata));
} else {
thunk = std::make_unique<GemmThunk>(thunk_info, std::move(config),
lhs_slice, rhs_slice, output, workspace,
@@ -956,12 +964,20 @@ absl::StatusOr<FusionEmissionResult> EmitCustomCall(
? ffi_thunk(std::move(fake_operands), std::move(fake_results))
: legacy_thunk(std::move(fake_operands), std::move(fake_results)));

+    std::optional<DynamicSliceThunk::OffsetAsFunctionOfIndvarModulesMetadata>
+        offset_modules_metadata = std::nullopt;
+    if (can_compute_indvar_on_host) {
+      offset_modules_metadata =
+          DynamicSliceThunk::OffsetAsFunctionOfIndvarModulesMetadata{
+              /*indvar_init=*/std::move(init_module),
+              /*indvar_update=*/std::move(update_module),
+              /*extracted_offset_modules=*/std::move(extracted_offset_modules)};
+    }
thunk = std::make_unique<DynamicSliceThunk>(
thunk_info, std::make_unique<ThunkSequence>(std::move(seq)),
std::move(arguments), std::move(fake_allocations), std::move(offsets),
std::move(orig_shapes), std::move(sliced_shapes),
-        std::move(offset_byte_sizes), std::move(extracted_offset_modules),
-        std::move(init_module), std::move(update_module));
+        std::move(offset_byte_sizes), std::move(offset_modules_metadata));
} else {
TF_ASSIGN_OR_RETURN(
thunk, found_ffi_handler
@@ -1144,13 +1160,22 @@ absl::StatusOr<FusionEmissionResult> EmitCollective(
// Depending on whether this is a dynamic fusion or not, we wrap the thunk(s)
// within a dynamic-slice thunk.
if (isDynamic) {
+    std::optional<DynamicSliceThunk::OffsetAsFunctionOfIndvarModulesMetadata>
+        offset_modules_metadata = std::nullopt;
+    if (can_compute_indvar_on_host) {
+      offset_modules_metadata =
+          DynamicSliceThunk::OffsetAsFunctionOfIndvarModulesMetadata(
+              /*indvar_init=*/std::move(init_module),
+              /*indvar_update=*/std::move(update_module),
+              /*extracted_offset_modules=*/std::move(extracted_offset_modules));
+    }
std::unique_ptr<Thunk> thunk = std::make_unique<DynamicSliceThunk>(
-        thunk_info, std::make_unique<ThunkSequence>(std::move(seq)),
+        thunk_info,
+        /*embedded_thunk=*/std::make_unique<ThunkSequence>(std::move(seq)),
std::move(arguments), std::move(fake_allocations),
std::move(offset_buffer_indices), std::move(orig_shapes),
std::move(sliced_shapes), std::move(offset_byte_sizes),
-        std::move(extracted_offset_modules), std::move(init_module),
-        std::move(update_module));
+        std::move(offset_modules_metadata));
result.thunks.push_back(std::move(thunk));
} else {
for (auto& thunk : seq) {
2 changes: 1 addition & 1 deletion xla/backends/gpu/runtime/dynamic_slice_thunk.cc
@@ -30,10 +30,10 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "xla/backends/gpu/runtime/sequential_thunk.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/backends/gpu/runtime/while_thunk.h"
#include "xla/hlo/evaluator/hlo_evaluator.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/gpu/buffer_allocations.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
13 changes: 7 additions & 6 deletions xla/backends/gpu/runtime/dynamic_slice_thunk.h
@@ -110,10 +110,8 @@ class DynamicSliceThunk : public Thunk {
std::vector<std::optional<Shape>> orig_shapes,
std::vector<std::optional<Shape>> sliced_shapes,
std::vector<std::optional<uint64_t>> offset_byte_sizes,
-      std::vector<std::unique_ptr<HloModule>> temp_modules,
-      std::unique_ptr<HloModule> indvar_init,
-      std::unique_ptr<HloModule> indvar_update);
-
+      std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
+          offset_as_function_of_indvar_metadata = std::nullopt);
DynamicSliceThunk(const DynamicSliceThunk&) = delete;
DynamicSliceThunk& operator=(const DynamicSliceThunk&) = delete;

@@ -188,8 +186,11 @@ class DynamicSliceThunk : public Thunk {
// A mapping from argument index to the base offset in the `offsets_allocs_`.
std::vector<int64_t> offsets_allocs_base_;

-  std::vector<std::unique_ptr<HloModule>> temp_modules_;
-  std::unique_ptr<HloModule> indvar_init_, indvar_update_;
+  // This structure holds the metadata for offset computations on host. It
+  // stores a single induction variable initialization module, its update module
+  // and the offsets that are a function of the induction variable.
+  std::optional<OffsetAsFunctionOfIndvarModulesMetadata>
+      offset_as_function_of_indvar_metadata_;
};

} // namespace gpu
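The definition of OffsetAsFunctionOfIndvarModulesMetadata is not part of this diff. Judging from the field names used in the braced initializers in custom.cc and from the types of the three parameters it replaces, it presumably looks roughly like the following (an assumption, not an excerpt from the header):

// Assumed layout of the metadata struct (not shown in this diff). Field names
// come from the /*...=*/ annotations in custom.cc; the types mirror the old
// temp_modules, indvar_init, and indvar_update parameters it replaces.
struct OffsetAsFunctionOfIndvarModulesMetadata {
  std::unique_ptr<HloModule> indvar_init;
  std::unique_ptr<HloModule> indvar_update;
  std::vector<std::unique_ptr<HloModule>> extracted_offset_modules;
};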
127 changes: 42 additions & 85 deletions xla/backends/gpu/runtime/dynamic_slice_thunk_test.cc
@@ -151,20 +151,15 @@ TEST_F(DynamicSliceThunkTest, SlicedGemm) {
std::vector<DynamicSliceThunk::Offset> lhs_offsets{slice_lhs_offset_0,
slice_lhs_offset_1};
DynamicSliceThunk thunk(
-      /*thunk_info=*/Thunk::ThunkInfo(),
-      /*embedded_thunk=*/std::make_unique<ThunkSequence>(std::move(seq)),
-      /*arguments=*/{slice_lhs, slice_rhs, slice_out, slice_workspace},
-      /*fake_allocations=*/std::move(fake_allocations),
-      /*offsets=*/{lhs_offsets, std::nullopt, std::nullopt, std::nullopt},
-      /*orig_shapes=*/
+      Thunk::ThunkInfo(), std::make_unique<ThunkSequence>(std::move(seq)),
+      {slice_lhs, slice_rhs, slice_out, slice_workspace},
+      std::move(fake_allocations),
+      {lhs_offsets, std::nullopt, std::nullopt, std::nullopt},
{ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt,
std::nullopt, std::nullopt},
-      /*sliced_shapes=*/
{ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), std::nullopt,
std::nullopt, std::nullopt},
-      /*offset_byte_sizes=*/
-      {sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt},
-      /*temp_modules=*/{}, /*indvar_init=*/nullptr, /*indvar_update=*/nullptr);
+      {sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt});

// Step 2:
// Execute address computation thunk.
@@ -311,22 +306,17 @@ TEST_F(DynamicSliceThunkTest, MulipleSlicedOperandsGemm) {
std::vector<DynamicSliceThunk::Offset> rhs_offsets{slice_rhs_offset_0,
slice_rhs_offset_1};
DynamicSliceThunk thunk(
-      /*thunk_info=*/Thunk::ThunkInfo(),
-      /*embedded_thunk=*/std::make_unique<ThunkSequence>(std::move(seq)),
-      /*arguments=*/{slice_lhs, slice_rhs, slice_out, slice_workspace},
-      /*fake_allocations=*/std::move(fake_allocations),
-      /*offsets=*/{lhs_offsets, rhs_offsets, std::nullopt, std::nullopt},
-      /*orig_shapes=*/
+      Thunk::ThunkInfo(), std::make_unique<ThunkSequence>(std::move(seq)),
+      {slice_lhs, slice_rhs, slice_out, slice_workspace},
+      std::move(fake_allocations),
+      {lhs_offsets, rhs_offsets, std::nullopt, std::nullopt},
{ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}),
ShapeUtil::MakeShape(PrimitiveType::F32, {8, 1}), std::nullopt,
std::nullopt},
-      /*sliced_shapes=*/
{ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}),
ShapeUtil::MakeShape(PrimitiveType::F32, {3, 1}), std::nullopt,
std::nullopt},
-      /*offset_byte_sizes=*/
-      {sizeof(int64_t), sizeof(int64_t), std::nullopt, std::nullopt},
-      /*temp_modules=*/{}, /*indvar_init=*/nullptr, /*indvar_update=*/nullptr);
+      {sizeof(int64_t), sizeof(int64_t), std::nullopt, std::nullopt});

// Step 2:
// Execute address computation thunk.
@@ -493,18 +483,14 @@ TEST_F(DynamicSliceThunkTest, SlicedMemcpy) {
std::vector<DynamicSliceThunk::Offset> slice_offsets{
slice_offset_0, slice_offset_1, slice_offset_2, slice_offset_3};
DynamicSliceThunk thunk(
-      /*thunk_info=*/Thunk::ThunkInfo(),
-      /*embedded_thunk=*/std::make_unique<ThunkSequence>(std::move(seq)),
-      /*arguments=*/{slice_src, slice_dst}, std::move(fake_allocations),
-      /*offsets=*/{slice_offsets, std::nullopt},
-      /*orig_shapes=*/
+      Thunk::ThunkInfo(), std::make_unique<ThunkSequence>(std::move(seq)),
+      {slice_src, slice_dst}, std::move(fake_allocations),
+      {slice_offsets, std::nullopt},
{ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 8}), std::nullopt},
// Make sure to pass a dst shape with the same rank as src shape (i.e.
// original slice result and not bitcasted one)
-      /*sliced_shapes=*/
{ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 8, 8}), std::nullopt},
-      /*offset_byte_sizes=*/{sizeof(int64_t), std::nullopt},
-      /*temp_modules=*/{}, /*indvar_init=*/nullptr, /*indvar_update=*/nullptr);
+      {sizeof(int64_t), std::nullopt});

// Step 2:
// Execute address computation thunk.
@@ -661,20 +647,16 @@ TEST_F(DynamicSliceThunkTest, SlicedOutputMemcpy) {
slice_dst_offset_0, slice_dst_offset_1, slice_dst_offset_2,
slice_dst_offset_3};
DynamicSliceThunk thunk(
-      /*thunk_info=*/Thunk::ThunkInfo(),
-      /*embedded_thunk=*/std::make_unique<ThunkSequence>(std::move(seq)),
-      /*arguments=*/{slice_src, slice_dst}, std::move(fake_allocations),
-      /*offsets=*/{slice_src_offsets, slice_dst_offsets},
-      /*orig_shapes=*/
+      Thunk::ThunkInfo(), std::make_unique<ThunkSequence>(std::move(seq)),
+      {slice_src, slice_dst}, std::move(fake_allocations),
+      {slice_src_offsets, slice_dst_offsets},
{ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 2}),
ShapeUtil::MakeShape(PrimitiveType::S32, {2, 2, 2, 2})},
// Make sure to pass a dst shape with the same rank as src shape (i.e.
// original slice result and not bitcasted one)
-      /*sliced_shapes=*/
{ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 2, 2}),
ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 2, 2})},
-      /*offset_byte_sizes=*/{sizeof(int64_t), sizeof(int64_t)},
-      /*temp_modules=*/{}, /*indvar_init=*/nullptr, /*indvar_update=*/nullptr);
+      {sizeof(int64_t), sizeof(int64_t)});

// Step 2:
// Execute address computation thunk.
@@ -842,20 +824,15 @@ TEST_F(DynamicSliceThunkTest, SlicedGemmArbitraryArgumentOrder) {
std::vector<DynamicSliceThunk::Offset> lhs_offsets{slice_lhs_offset_0,
slice_lhs_offset_1};
DynamicSliceThunk thunk(
-      /*thunk_info=*/Thunk::ThunkInfo(),
-      /*embedded_thunk=*/std::make_unique<ThunkSequence>(std::move(seq)),
-      /*arguments=*/{slice_lhs, slice_rhs, slice_out, slice_workspace},
-      /*fake_allocations=*/std::move(fake_allocations),
-      /*offsets=*/{lhs_offsets, std::nullopt, std::nullopt, std::nullopt},
-      /*orig_shapes=*/
+      Thunk::ThunkInfo(), std::make_unique<ThunkSequence>(std::move(seq)),
+      {slice_lhs, slice_rhs, slice_out, slice_workspace},
+      std::move(fake_allocations),
+      {lhs_offsets, std::nullopt, std::nullopt, std::nullopt},
{ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt,
std::nullopt, std::nullopt},
-      /*sliced_shapes=*/
{ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), std::nullopt,
std::nullopt, std::nullopt},
-      /*offset_byte_sizes=*/
-      {sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt},
-      /*temp_modules=*/{}, /*indvar_init=*/nullptr, /*indvar_update=*/nullptr);
+      {sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt});

// Step 2:
// Execute address computation thunk.
@@ -996,20 +973,15 @@ TEST_F(DynamicSliceThunkTest, SlicedGemmArbitraryNumberOfArguments) {
std::vector<DynamicSliceThunk::Offset> lhs_offsets{slice_lhs_offset_0,
slice_lhs_offset_1};
DynamicSliceThunk thunk(
-      /*thunk_info=*/Thunk::ThunkInfo(),
-      /*embedded_thunk=*/std::make_unique<ThunkSequence>(std::move(seq)),
-      /*arguments=*/{slice_lhs, slice_rhs, slice_out, slice_workspace},
-      /*fake_allocations=*/std::move(fake_allocations),
-      /*offsets=*/{lhs_offsets, std::nullopt, std::nullopt, std::nullopt},
-      /*orig_shapes=*/
+      Thunk::ThunkInfo(), std::make_unique<ThunkSequence>(std::move(seq)),
+      {slice_lhs, slice_rhs, slice_out, slice_workspace},
+      std::move(fake_allocations),
+      {lhs_offsets, std::nullopt, std::nullopt, std::nullopt},
{ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt,
std::nullopt, std::nullopt},
-      /*sliced_shapes=*/
{ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), std::nullopt,
std::nullopt, std::nullopt},
-      /*offset_byte_sizes=*/
-      {sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt},
-      /*temp_modules=*/{}, /*indvar_init=*/nullptr, /*indvar_update=*/nullptr);
+      {sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt});

// Step 2:
// Execute address computation thunk.
@@ -1143,20 +1115,15 @@ TEST_F(DynamicSliceThunkTest, SlicedTupledOperandGemm) {
std::vector<DynamicSliceThunk::Offset> lhs_offsets{slice_lhs_offset_0,
slice_lhs_offset_1};
DynamicSliceThunk thunk(
-      /*thunk_info=*/Thunk::ThunkInfo(),
-      /*embedded_thunk=*/std::make_unique<ThunkSequence>(std::move(seq)),
-      /*arguments=*/{slice_lhs, slice_rhs, slice_out, slice_workspace},
-      /*fake_allocations=*/std::move(fake_allocations),
-      /*offsets=*/{lhs_offsets, std::nullopt, std::nullopt, std::nullopt},
-      /*orig_shapes=*/
+      Thunk::ThunkInfo(), std::make_unique<ThunkSequence>(std::move(seq)),
+      {slice_lhs, slice_rhs, slice_out, slice_workspace},
+      std::move(fake_allocations),
+      {lhs_offsets, std::nullopt, std::nullopt, std::nullopt},
{ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt,
std::nullopt, std::nullopt},
-      /*sliced_shapes=*/
{ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), std::nullopt,
std::nullopt, std::nullopt},
-      /*offset_byte_sizes=*/
-      {sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt},
-      /*temp_modules=*/{}, /*indvar_init=*/nullptr, /*indvar_update=*/nullptr);
+      {sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt});

// Step 2:
// Execute address computation thunk.
@@ -1324,21 +1291,16 @@ TEST_F(DynamicSliceThunkTest, SlicedMemcpyOOB) {
slice_dst_offset_0, slice_dst_offset_1, slice_dst_offset_2,
slice_dst_offset_3};
DynamicSliceThunk thunk(
-      /*thunk_info=*/Thunk::ThunkInfo(),
-      /*embedded_thunk=*/std::make_unique<ThunkSequence>(std::move(seq)),
-      /*arguments=*/{slice_src, slice_dst}, std::move(fake_allocations),
-      /*offsets=*/{slice_src_offsets, slice_dst_offsets},
-      /*orig_shapes=*/
+      Thunk::ThunkInfo(), std::make_unique<ThunkSequence>(std::move(seq)),
+      {slice_src, slice_dst}, std::move(fake_allocations),
+      {slice_src_offsets, slice_dst_offsets},
{ShapeUtil::MakeShape(PrimitiveType::S32, {8, 8, 10, 2}),
ShapeUtil::MakeShape(PrimitiveType::S32, {2, 2, 2, 2})},
// Make sure to pass a dst shape with the same rank as src shape (i.e.
// original slice result and not bitcasted one)
-      /*sliced_shapes=*/
{ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 2, 2}),
ShapeUtil::MakeShape(PrimitiveType::S32, {1, 1, 2, 2})},
-      /*offset_byte_sizes=*/{sizeof(int64_t), sizeof(int64_t)},
-      /*temp_modules=*/{}, /*indvar_init=*/nullptr,
-      /*indvar_update=*/nullptr);
+      {sizeof(int64_t), sizeof(int64_t)});

// Step 2:
// Execute address computation thunk.
@@ -1507,20 +1469,15 @@ TEST_F(DynamicSliceThunkTest, SlicedOperandsSameBufferGemm) {
std::vector<DynamicSliceThunk::Offset> lhs_offsets{slice_lhs_offset_0,
slice_lhs_offset_1};
DynamicSliceThunk thunk(
-      /*thunk_info=*/Thunk::ThunkInfo(),
-      /*embedded_thunk=*/std::make_unique<ThunkSequence>(std::move(seq)),
-      /*arguments=*/{slice_lhs, slice_rhs, slice_out, slice_workspace},
-      /*fake_allocations=*/std::move(fake_allocations),
-      /*offsets=*/{lhs_offsets, std::nullopt, std::nullopt, std::nullopt},
-      /*orig_shapes=*/
+      Thunk::ThunkInfo(), std::make_unique<ThunkSequence>(std::move(seq)),
+      {slice_lhs, slice_rhs, slice_out, slice_workspace},
+      std::move(fake_allocations),
+      {lhs_offsets, std::nullopt, std::nullopt, std::nullopt},
{ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt,
std::nullopt, std::nullopt},
-      /*sliced_shapes=*/
{ShapeUtil::MakeShape(PrimitiveType::F32, {1, 3}), std::nullopt,
std::nullopt, std::nullopt},
-      /*offset_byte_sizes=*/
-      {sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt},
-      /*temp_modules=*/{}, /*indvar_init=*/nullptr, /*indvar_update=*/nullptr);
+      {sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt});

// Step 2:
// Execute address computation thunk.
6 changes: 1 addition & 5 deletions xla/backends/gpu/runtime/for_all_thunks_test.cc
@@ -64,11 +64,7 @@ TEST(ForAllThunksTest, DynamicSliceThunk) {
thunk_sequence->push_back(std::move(thunk));

DynamicSliceThunk dynamic_slice_thunk(
-      /*thunk_info=*/Thunk::ThunkInfo(),
-      /*embedded_thunk=*/std::move(thunk_sequence), /*arguments=*/{},
-      /*fake_allocations=*/{}, /*offsets=*/{}, /*orig_shapes=*/{},
-      /*sliced_shapes=*/{}, /*offset_byte_sizes=*/{}, /*temp_modules=*/{},
-      /*indvar_init=*/nullptr, /*indvar_update=*/nullptr);
+      Thunk::ThunkInfo(), std::move(thunk_sequence), {}, {}, {}, {}, {}, {});
EXPECT_THAT(GetAllThunks(&dynamic_slice_thunk),
// `DynamicSliceThunk` wraps the `embedded_thunk` in a
// `SequentialThunk`, which is why iterate over more than the
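Because the new trailing parameter is defaulted, call sites that never compute offsets on the host can simply drop it, as the updated test above now does, or spell out the default explicitly. A hypothetical sketch of the two equivalent forms under the new signature (variable names here are placeholders):

// Hypothetical equivalents under the new signature; both rely on the defaulted
// std::optional<OffsetAsFunctionOfIndvarModulesMetadata> parameter.
DynamicSliceThunk implicit_default(Thunk::ThunkInfo(), std::move(thunk_sequence),
                                   {}, {}, {}, {}, {}, {});
DynamicSliceThunk explicit_default(Thunk::ThunkInfo(), std::move(other_sequence),
                                   {}, {}, {}, {}, {}, {}, std::nullopt);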
7 changes: 5 additions & 2 deletions xla/service/gpu/ir_emitter_unnested.cc
@@ -1543,8 +1543,11 @@ absl::Status IrEmitterUnnested::EmitFusion(const HloFusionInstruction* instr) {
const HloFusionAnalysis fusion_analysis =
HloFusionAnalysis::Create(*instr, device_info);
VLOG(3) << "IrEmitterUnnested::EmitFusion:start";
-  std::unique_ptr<FusionInterface> emitter = GetFusionEmitter(HloFusionInfo(
-      fusion_analysis, instr, &ir_emitter_context_->buffer_assignment(), *call_graph_));
+  std::unique_ptr<FusionInterface> emitter = GetFusionEmitter(
+      /*fusion_info=*/HloFusionInfo(
+          /*analysis=*/fusion_analysis, instr,
+          /*buffer_assignment=*/&ir_emitter_context_->buffer_assignment(),
+          /*call_graph=*/*call_graph_));
TF_ASSIGN_OR_RETURN(auto result, emitter->Emit(*ir_emitter_context_, *instr));

const ExecutionStreamAssignment& stream_assignment =
