Skip to content

Commit

Permalink
[XLA:CPU] Add initial thunk serialization.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713322136
  • Loading branch information
Google-ML-Automation committed Jan 16, 2025
1 parent 79ffbf1 commit b68b247
Show file tree
Hide file tree
Showing 56 changed files with 1,390 additions and 152 deletions.
199 changes: 121 additions & 78 deletions xla/backends/cpu/runtime/BUILD

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions xla/backends/cpu/runtime/all_gather_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,22 @@ limitations under the License.
#include <utility>

#include "absl/container/inlined_vector.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/core/collectives/communicator.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down
1 change: 1 addition & 0 deletions xla/backends/cpu/runtime/all_gather_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#define XLA_BACKENDS_CPU_RUNTIME_ALL_GATHER_THUNK_H_

#include <memory>
#include <string>

#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
Expand Down
3 changes: 2 additions & 1 deletion xla/backends/cpu/runtime/all_reduce_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,16 @@ limitations under the License.
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/core/collectives/communicator.h"
#include "xla/primitive_util.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down
3 changes: 3 additions & 0 deletions xla/backends/cpu/runtime/all_reduce_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class AllReduceThunk final : public CollectiveThunk {

tsl::AsyncValueRef<ExecuteEvent> Execute(const ExecuteParams& params) final;

ReductionKind reduction_kind() const { return reduction_kind_; }
bool single_replica() const { return single_replica_; }

private:
AllReduceThunk(Info info, ReductionKind reduction_kind, OpParams op_params,
OpBuffers op_buffers, OpResources op_resources,
Expand Down
3 changes: 2 additions & 1 deletion xla/backends/cpu/runtime/all_to_all_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ limitations under the License.
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/core/collectives/communicator.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/cpu/runtime/call_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ limitations under the License.
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/backends/cpu/runtime/thunk_executor.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down
2 changes: 2 additions & 0 deletions xla/backends/cpu/runtime/call_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class CallThunk final : public Thunk {
BufferUses buffer_uses() const final;
ResourceUses resource_uses() const final;

const ThunkExecutor& called_executor() const { return called_executor_; }

private:
CallThunk(Info info, ThunkExecutor called_executor);

Expand Down
7 changes: 4 additions & 3 deletions xla/backends/cpu/runtime/collective_permute_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand All @@ -33,6 +34,7 @@ limitations under the License.
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
Expand All @@ -41,9 +43,8 @@ limitations under the License.
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down
4 changes: 4 additions & 0 deletions xla/backends/cpu/runtime/collective_permute_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class CollectivePermuteThunk final : public CollectiveThunk {

tsl::AsyncValueRef<ExecuteEvent> Execute(const ExecuteParams& params) final;

const std::vector<SourceTargetPair>& source_target_pairs() const {
return source_target_pairs_;
}

private:
CollectivePermuteThunk(
Info info, OpParams op_params, OpBuffers op_buffers,
Expand Down
1 change: 1 addition & 0 deletions xla/backends/cpu/runtime/collective_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ limitations under the License.
#include "xla/service/collective_ops_utils.h"
#include "xla/service/computation_placer.h"
#include "xla/service/global_device_id.h"
#include "xla/service/hlo.pb.h"
#include "xla/shape.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_memory.h"
Expand Down
5 changes: 5 additions & 0 deletions xla/backends/cpu/runtime/collective_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "absl/container/inlined_vector.h"
Expand Down Expand Up @@ -77,6 +78,10 @@ class CollectiveThunk : public Thunk {

const OpParams& op_params() const { return op_params_; }

const OpBuffers& op_buffers() const { return op_buffers_; }

const OpResources& op_resources() const { return op_resources_; }

// Resolves operation's device memory from the buffers and buffer allocations.
absl::StatusOr<OpDeviceMemory> GetOpDeviceMemory(const ExecuteParams& params);

Expand Down
93 changes: 93 additions & 0 deletions xla/backends/cpu/runtime/collective_thunk.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/* Copyright 2022 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla.cpu;

import "xla/service/hlo.proto";
import "xla/xla_data.proto";

option cc_enable_arenas = true;

message ResourceProto {
enum Kind {
UNKNOWN = 0;
TOKEN = 1;
COLLECTIVE_COMMUNICATOR = 2;
}
Kind kind = 1;
}

message BufferAllocationSliceProto {
int64 offset = 1;
int64 size = 2;
int64 buffer_allocation_index = 3;
}

message ShapeBufferAllocationSliceProto {
xla.ShapeProto shape = 1;
BufferAllocationSliceProto slice = 2;
}

message OpParamsProto {
int64 op_id = 1;
bool has_channel_id = 2;
bool use_global_device_ids = 3; // TODO(basioli) optional
repeated ReplicaGroup replica_group = 4;
}

message OpBuffersProto {
repeated ShapeBufferAllocationSliceProto source_shapes_buffer_slices = 1;
repeated ShapeBufferAllocationSliceProto destination_shapes_buffer_slices = 2;
}

message OpResourcesProto {
ResourceProto communicator_resource = 1; // TODO(basioli) optional
}

message AllGatherThunkProto {} // NOTE(basioli) empty for now

message AllReduceThunkProto {
string reduction_kind = 1;
bool single_replica = 2;
}

message AllToAllThunkProto {} // NOTE(basioli) empty for now

message ReduceScatterThunkProto {
string reduction_kind = 1;
}

message CollectivePermuteThunkProto {
message SourceTargetPairProto {
int64 source = 1;
int64 target = 2;
}
repeated SourceTargetPairProto source_target_pairs = 1;
}

message CollectiveThunkProto {
OpParamsProto op_params = 1;
OpBuffersProto op_buffers = 2;
OpResourcesProto op_resources = 3;
oneof impl {
AllGatherThunkProto all_gather_thunk = 4;
AllReduceThunkProto all_reduce_thunk = 5;
AllToAllThunkProto all_to_all_thunk = 6;
ReduceScatterThunkProto reduce_scatter_thunk = 7;
CollectivePermuteThunkProto collective_permute_thunk = 8;
}
}
5 changes: 3 additions & 2 deletions xla/backends/cpu/runtime/conditional_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,18 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/backends/cpu/runtime/thunk_executor.h"
#include "xla/runtime/buffer_use.h"
#include "xla/service/buffer_assignment.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"

namespace xla::cpu {

Expand Down
8 changes: 8 additions & 0 deletions xla/backends/cpu/runtime/conditional_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ class ConditionalThunk final : public Thunk {
BufferUses buffer_uses() const final;
ResourceUses resource_uses() const final;

const std::vector<ThunkExecutor>& branch_executors() const {
return branch_executors_;
}

const BufferAllocation::Slice& branch_index_buffer() const {
return branch_index_buffer_;
}

private:
ConditionalThunk(Info info, BufferAllocation::Slice branch_index_buffer,
std::vector<ThunkExecutor> branch_executors);
Expand Down
14 changes: 8 additions & 6 deletions xla/backends/cpu/runtime/convolution_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/backends/cpu/runtime/convolution_thunk.h"

#define EIGEN_USE_THREADS

#include <cstdint>
Expand All @@ -39,9 +38,9 @@ limitations under the License.
#include "xla/status_macros.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down Expand Up @@ -213,7 +212,7 @@ absl::StatusOr<std::unique_ptr<ConvolutionThunk>> ConvolutionThunk::Create(
output_shape, input_batch, input_dims, input_channels, kernel_dims,
kernel_channels, kernel_filters, output_dims, strides, padding_before,
padding_after, base_dilation, window_dilation, feature_group_count,
options));
options, dnums, window));
}

ConvolutionThunk::ConvolutionThunk(
Expand All @@ -229,7 +228,8 @@ ConvolutionThunk::ConvolutionThunk(
const absl::InlinedVector<int64_t, 2>& padding_after,
const absl::InlinedVector<int64_t, 2>& base_dilation,
const absl::InlinedVector<int64_t, 2>& window_dilation,
int64_t feature_group_count, Options options)
int64_t feature_group_count, Options options,
const ConvolutionDimensionNumbers& dnums, const Window& window)
: Thunk(Kind::kConvolution, std::move(info)),
input_buffer_(input_buffer),
input_shape_(input_shape),
Expand All @@ -251,7 +251,9 @@ ConvolutionThunk::ConvolutionThunk(
window_dilation_(window_dilation),
feature_group_count_(feature_group_count),
convolution_rank_(input_dims.size()),
options_(options) {}
options_(options),
dnums_(dnums),
window_(window) {}

tsl::AsyncValueRef<Thunk::ExecuteEvent> ConvolutionThunk::Execute(
const ExecuteParams& params) {
Expand Down
Loading

0 comments on commit b68b247

Please sign in to comment.