Skip to content

Commit

Permalink
[XLA:CPU] Add initial thunk serialization.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713322136
  • Loading branch information
Google-ML-Automation committed Jan 10, 2025
1 parent 092b8dd commit 0b8fa9f
Show file tree
Hide file tree
Showing 59 changed files with 1,131 additions and 151 deletions.
182 changes: 105 additions & 77 deletions xla/backends/cpu/runtime/BUILD

Large diffs are not rendered by default.

15 changes: 12 additions & 3 deletions xla/backends/cpu/runtime/all_gather_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,27 @@ limitations under the License.

#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "absl/container/inlined_vector.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
#include "xla/backends/cpu/runtime/collective_thunk.pb.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/core/collectives/communicator.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down Expand Up @@ -89,4 +92,10 @@ tsl::AsyncValueRef<AllGatherThunk::ExecuteEvent> AllGatherThunk::Execute(
});
}

// Produces the all-gather specific part of the serialized collective thunk.
// No fields are populated on the message here, so the result is the
// serialization of a default-constructed AllGatherThunkProto.
absl::StatusOr<std::string> AllGatherThunk::SerializeAsStringCollectiveImpl()
    const {
  return AllGatherThunkProto().SerializeAsString();
}

} // namespace xla::cpu
4 changes: 4 additions & 0 deletions xla/backends/cpu/runtime/all_gather_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#define XLA_BACKENDS_CPU_RUNTIME_ALL_GATHER_THUNK_H_

#include <memory>
#include <string>

#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
Expand All @@ -33,6 +34,9 @@ class AllGatherThunk final : public CollectiveThunk {

tsl::AsyncValueRef<ExecuteEvent> Execute(const ExecuteParams& params) final;

protected:
absl::StatusOr<std::string> SerializeAsStringCollectiveImpl() const final;

private:
AllGatherThunk(Info info, OpParams op_params, OpBuffers op_buffers,
OpResources op_resources);
Expand Down
19 changes: 18 additions & 1 deletion xla/backends/cpu/runtime/all_reduce_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>

#include "absl/container/inlined_vector.h"
Expand All @@ -27,18 +28,21 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
#include "xla/backends/cpu/runtime/collective_thunk.pb.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/core/collectives/communicator.h"
#include "xla/primitive_util.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down Expand Up @@ -115,4 +119,17 @@ tsl::AsyncValueRef<AllReduceThunk::ExecuteEvent> AllReduceThunk::Execute(
return OkExecuteEvent();
}

// Produces the all-reduce specific part of the serialized collective thunk:
// the reduction kind (in its string form, as returned by
// ReductionKindToString) and the single-replica flag.
absl::StatusOr<std::string> AllReduceThunk::SerializeAsStringCollectiveImpl()
    const {
  AllReduceThunkProto payload;
  payload.set_single_replica(single_replica_);
  // The proto setter takes an owning string, so materialize the view first.
  payload.set_reduction_kind(
      std::string(ReductionKindToString(reduction_kind_)));
  return payload.SerializeAsString();
}

} // namespace xla::cpu
3 changes: 3 additions & 0 deletions xla/backends/cpu/runtime/all_reduce_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class AllReduceThunk final : public CollectiveThunk {

tsl::AsyncValueRef<ExecuteEvent> Execute(const ExecuteParams& params) final;

protected:
absl::StatusOr<std::string> SerializeAsStringCollectiveImpl() const final;

private:
AllReduceThunk(Info info, ReductionKind reduction_kind, OpParams op_params,
OpBuffers op_buffers, OpResources op_resources,
Expand Down
11 changes: 10 additions & 1 deletion xla/backends/cpu/runtime/all_to_all_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "xla/backends/cpu/runtime/all_to_all_thunk.h"

#include <memory>
#include <string>
#include <utility>

#include "absl/container/inlined_vector.h"
Expand All @@ -25,15 +26,17 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
#include "xla/backends/cpu/runtime/collective_thunk.pb.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/core/collectives/communicator.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down Expand Up @@ -87,4 +90,10 @@ tsl::AsyncValueRef<AllToAllThunk::ExecuteEvent> AllToAllThunk::Execute(
});
}

// Produces the all-to-all specific part of the serialized collective thunk.
// No fields are populated on the message here, so the result is the
// serialization of a default-constructed AllToAllThunkProto.
absl::StatusOr<std::string> AllToAllThunk::SerializeAsStringCollectiveImpl()
    const {
  return AllToAllThunkProto{}.SerializeAsString();
}

} // namespace xla::cpu
3 changes: 3 additions & 0 deletions xla/backends/cpu/runtime/all_to_all_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class AllToAllThunk final : public CollectiveThunk {

tsl::AsyncValueRef<ExecuteEvent> Execute(const ExecuteParams& params) final;

protected:
absl::StatusOr<std::string> SerializeAsStringCollectiveImpl() const final;

private:
AllToAllThunk(Info info, OpParams op_params, OpBuffers op_buffers,
OpResources op_resources);
Expand Down
13 changes: 12 additions & 1 deletion xla/backends/cpu/runtime/call_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@ limitations under the License.
#include "xla/backends/cpu/runtime/call_thunk.h"

#include <memory>
#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/backends/cpu/runtime/thunk.pb.h"
#include "xla/backends/cpu/runtime/thunk_executor.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand All @@ -46,6 +48,15 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> CallThunk::Execute(
return called_executor_.Execute(params);
}

// Serializes this call thunk into a CallThunkProto string.
//
// The called thunk sequence is serialized first and then parsed back into the
// `called_sequence` field of the proto.
//
// Returns the error from serializing the called sequence, or an Internal error
// if the serialized sequence cannot be parsed into the proto field (the result
// of ParseFromString must not be ignored, otherwise a partially filled proto
// would be emitted as if it were valid).
absl::StatusOr<std::string> CallThunk::SerializeAsStringImpl() const {
  TF_ASSIGN_OR_RETURN(const std::string called_sequence_str,
                      called_executor_.thunk_sequence().SerializeAsString());

  CallThunkProto proto;
  if (!proto.mutable_called_sequence()->ParseFromString(called_sequence_str)) {
    return absl::InternalError(
        "Failed to parse serialized called thunk sequence into CallThunkProto");
  }
  return proto.SerializeAsString();
}

// Returns the buffer uses reported by the wrapped `called_executor_`; the call
// thunk adds none of its own here.
CallThunk::BufferUses CallThunk::buffer_uses() const {
return called_executor_.buffer_uses();
}
Expand Down
3 changes: 3 additions & 0 deletions xla/backends/cpu/runtime/call_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class CallThunk final : public Thunk {
BufferUses buffer_uses() const final;
ResourceUses resource_uses() const final;

protected:
absl::StatusOr<std::string> SerializeAsStringImpl() const final;

private:
CallThunk(Info info, ThunkExecutor called_executor);

Expand Down
22 changes: 19 additions & 3 deletions xla/backends/cpu/runtime/collective_permute_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand All @@ -32,7 +33,9 @@ limitations under the License.
#include "absl/types/span.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
#include "xla/backends/cpu/runtime/collective_thunk.pb.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
Expand All @@ -41,9 +44,8 @@ limitations under the License.
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down Expand Up @@ -144,4 +146,18 @@ CollectivePermuteThunk::Execute(const ExecuteParams& params) {
});
}

// Produces the collective-permute specific part of the serialized collective
// thunk: one SourceTargetPairProto per entry of `source_target_pairs_`, in the
// same order.
absl::StatusOr<std::string>
CollectivePermuteThunk::SerializeAsStringCollectiveImpl() const {
  CollectivePermuteThunkProto payload;

  for (const auto& [source, target] : source_target_pairs_) {
    auto* pair_proto = payload.add_source_target_pairs();
    pair_proto->set_source(source);
    pair_proto->set_target(target);
  }

  return payload.SerializeAsString();
}

} // namespace xla::cpu
3 changes: 3 additions & 0 deletions xla/backends/cpu/runtime/collective_permute_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class CollectivePermuteThunk final : public CollectiveThunk {

tsl::AsyncValueRef<ExecuteEvent> Execute(const ExecuteParams& params) final;

protected:
absl::StatusOr<std::string> SerializeAsStringCollectiveImpl() const final;

private:
CollectivePermuteThunk(
Info info, OpParams op_params, OpBuffers op_buffers,
Expand Down
78 changes: 78 additions & 0 deletions xla/backends/cpu/runtime/collective_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
Expand All @@ -35,6 +36,7 @@ limitations under the License.
#include "xla/backends/cpu/collectives/cpu_clique_key.h"
#include "xla/backends/cpu/collectives/cpu_cliques.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/collective_thunk.pb.h"
#include "xla/backends/cpu/runtime/resource_use.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/core/collectives/communicator.h"
Expand All @@ -44,6 +46,7 @@ limitations under the License.
#include "xla/service/collective_ops_utils.h"
#include "xla/service/computation_placer.h"
#include "xla/service/global_device_id.h"
#include "xla/service/hlo.pb.h"
#include "xla/shape.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_memory.h"
Expand Down Expand Up @@ -76,6 +79,81 @@ Thunk::BufferUses CollectiveThunk::buffer_uses() const {
return uses;
}

// Serializes the collective op parameters (channel-id flag, global device id
// flag, op id and replica groups) into an OpParamsProto string.
//
// Returns InvalidArgument if `use_global_device_ids` holds no value. The proto
// field is written from the contained value, so an empty optional cannot be
// represented; reporting an error avoids calling value() on an empty optional.
absl::StatusOr<std::string> CollectiveThunk::OpParams::SerializeAsString()
    const {
  // TODO(basioli) optional
  if (!use_global_device_ids.has_value()) {
    return absl::InvalidArgumentError(
        "Cannot serialize OpParams: use_global_device_ids is not set");
  }

  OpParamsProto proto;
  proto.set_has_channel_id(has_channel_id);
  proto.set_use_global_device_ids(*use_global_device_ids);
  proto.set_op_id(op_id);

  // The loop variable is deliberately not named `group`, to avoid shadowing
  // the member that is being iterated.
  for (const auto& replica_group : group) {
    ReplicaGroup* replica_group_proto = proto.add_replica_group();
    for (const auto& replica_id : replica_group.replica_ids()) {
      replica_group_proto->add_replica_ids(replica_id);
    }
  }
  return proto.SerializeAsString();
}

// Serializes the collective op resources into an OpResourcesProto string.
//
// When `communicator_resource` is null the `communicator_resource` field is
// left unset instead of dereferencing a null pointer. Returns an Internal
// error if the serialized resource cannot be parsed into the proto field.
absl::StatusOr<std::string> CollectiveThunk::OpResources::SerializeAsString()
    const {
  OpResourcesProto proto;
  // TODO(basioli) pointer -> optional?
  if (communicator_resource != nullptr) {
    const auto communicator_resource_str =
        communicator_resource->ToProto().SerializeAsString();
    if (!proto.mutable_communicator_resource()->ParseFromString(
            communicator_resource_str)) {
      return absl::InternalError(
          "Failed to parse serialized communicator resource into "
          "OpResourcesProto");
    }
  }
  return proto.SerializeAsString();
}

// Serializes the state shared by all collective thunks (op params, op
// resources, and the source/destination buffers with their shapes) together
// with the subclass-specific payload from SerializeAsStringCollectiveImpl().
//
// The subclass payload is stored in the `impl` oneof member that matches the
// thunk kind. A freshly constructed proto has no oneof member set, so the
// dispatch is keyed on kind() rather than on proto.impl_case(), which would
// always report "not set" at this point and fall through to the error branch.
// NOTE(review): relies on Thunk::kind() / Thunk::Kind as declared in thunk.h -
// confirm the enumerator names against that header.
//
// Returns the error of any nested serialization, an Internal error if a nested
// serialized message cannot be parsed back into its proto field, or
// Unimplemented if the thunk kind is not one of the supported collectives.
absl::StatusOr<std::string> CollectiveThunk::SerializeAsStringImpl() const {
  CollectiveThunkProto proto;

  TF_ASSIGN_OR_RETURN(const std::string op_params_str,
                      op_params_.SerializeAsString());
  if (!proto.mutable_op_params()->ParseFromString(op_params_str)) {
    return absl::InternalError(
        "Failed to parse serialized OpParams into CollectiveThunkProto");
  }

  TF_ASSIGN_OR_RETURN(const std::string op_resources_str,
                      op_resources_.SerializeAsString());
  if (!proto.mutable_op_resources()->ParseFromString(op_resources_str)) {
    return absl::InternalError(
        "Failed to parse serialized OpResources into CollectiveThunkProto");
  }

  TF_ASSIGN_OR_RETURN(const std::string impl_string,
                      SerializeAsStringCollectiveImpl());

  for (size_t i = 0; i < op_buffers_.source_buffers.size(); ++i) {
    TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto(
        op_buffers_.source_buffers[i], op_buffers_.source_shapes[i],
        proto.mutable_op_buffers()->add_source_shapes_buffer_slices()));
  }

  for (size_t i = 0; i < op_buffers_.destination_buffers.size(); ++i) {
    TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto(
        op_buffers_.destination_buffers[i], op_buffers_.destination_shapes[i],
        proto.mutable_op_buffers()->add_destination_shapes_buffer_slices()));
  }

  // Select the oneof member from the concrete thunk kind; calling mutable_*()
  // is what actually sets the oneof case on the proto.
  bool impl_parsed = false;
  switch (kind()) {
    case Thunk::Kind::kAllGather:
      impl_parsed =
          proto.mutable_all_gather_thunk()->ParseFromString(impl_string);
      break;
    case Thunk::Kind::kAllReduce:
      impl_parsed =
          proto.mutable_all_reduce_thunk()->ParseFromString(impl_string);
      break;
    case Thunk::Kind::kAllToAll:
      impl_parsed =
          proto.mutable_all_to_all_thunk()->ParseFromString(impl_string);
      break;
    case Thunk::Kind::kReduceScatter:
      impl_parsed =
          proto.mutable_reduce_scatter_thunk()->ParseFromString(impl_string);
      break;
    case Thunk::Kind::kCollectivePermute:
      impl_parsed = proto.mutable_collective_permute_thunk()->ParseFromString(
          impl_string);
      break;
    default:
      return absl::UnimplementedError("SerializeAsStringImpl not implemented");
  }
  if (!impl_parsed) {
    return absl::InternalError(
        "Failed to parse serialized collective implementation into "
        "CollectiveThunkProto");
  }
  return proto.SerializeAsString();
}

// Reports a single resource use: a write use (ResourceUse::Write) of the
// communicator resource held in `op_resources_`.
Thunk::ResourceUses CollectiveThunk::resource_uses() const {
return {ResourceUse::Write(op_resources_.communicator_resource)};
}
Expand Down
Loading

0 comments on commit 0b8fa9f

Please sign in to comment.