Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:CPU] Add initial thunk serialization. #21262

Merged
merged 1 commit into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 108 additions & 121 deletions xla/backends/cpu/runtime/BUILD

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions xla/backends/cpu/runtime/all_gather_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,22 @@ limitations under the License.
#include <utility>

#include "absl/container/inlined_vector.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/core/collectives/communicator.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down
1 change: 1 addition & 0 deletions xla/backends/cpu/runtime/all_gather_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#define XLA_BACKENDS_CPU_RUNTIME_ALL_GATHER_THUNK_H_

#include <memory>
#include <string>

#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
Expand Down
3 changes: 2 additions & 1 deletion xla/backends/cpu/runtime/all_reduce_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,16 @@ limitations under the License.
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/core/collectives/communicator.h"
#include "xla/primitive_util.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down
3 changes: 3 additions & 0 deletions xla/backends/cpu/runtime/all_reduce_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class AllReduceThunk final : public CollectiveThunk {

tsl::AsyncValueRef<ExecuteEvent> Execute(const ExecuteParams& params) final;

// Returns the reduction kind held in `reduction_kind_`.
ReductionKind reduction_kind() const { return reduction_kind_; }
// Returns the value of the `single_replica_` member.
bool single_replica() const { return single_replica_; }

private:
AllReduceThunk(Info info, ReductionKind reduction_kind, OpParams op_params,
OpBuffers op_buffers, OpResources op_resources,
Expand Down
3 changes: 2 additions & 1 deletion xla/backends/cpu/runtime/all_to_all_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ limitations under the License.
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/core/collectives/communicator.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down
2 changes: 1 addition & 1 deletion xla/backends/cpu/runtime/call_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ limitations under the License.
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/backends/cpu/runtime/thunk_executor.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down
2 changes: 2 additions & 0 deletions xla/backends/cpu/runtime/call_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class CallThunk final : public Thunk {
BufferUses buffer_uses() const final;
ResourceUses resource_uses() const final;

// Read-only access to the executor stored in `called_executor_`. The result
// refers to a member, so it must not be used after this thunk is destroyed.
const ThunkExecutor& called_executor() const { return called_executor_; }

private:
CallThunk(Info info, ThunkExecutor called_executor);

Expand Down
7 changes: 4 additions & 3 deletions xla/backends/cpu/runtime/collective_permute_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand All @@ -33,6 +34,7 @@ limitations under the License.
#include "xla/backends/cpu/collectives/cpu_collectives.h"
#include "xla/backends/cpu/runtime/collective_thunk.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/collective_ops_utils.h"
Expand All @@ -41,9 +43,8 @@ limitations under the License.
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down
4 changes: 4 additions & 0 deletions xla/backends/cpu/runtime/collective_permute_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class CollectivePermuteThunk final : public CollectiveThunk {

tsl::AsyncValueRef<ExecuteEvent> Execute(const ExecuteParams& params) final;

// Read-only access to `source_target_pairs_`. Returned by reference (no copy);
// the reference is only valid while this thunk is alive.
const std::vector<SourceTargetPair>& source_target_pairs() const {
return source_target_pairs_;
}

private:
CollectivePermuteThunk(
Info info, OpParams op_params, OpBuffers op_buffers,
Expand Down
1 change: 1 addition & 0 deletions xla/backends/cpu/runtime/collective_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ limitations under the License.
#include "xla/service/collective_ops_utils.h"
#include "xla/service/computation_placer.h"
#include "xla/service/global_device_id.h"
#include "xla/service/hlo.pb.h"
#include "xla/shape.h"
#include "xla/status_macros.h"
#include "xla/stream_executor/device_memory.h"
Expand Down
5 changes: 5 additions & 0 deletions xla/backends/cpu/runtime/collective_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "absl/container/inlined_vector.h"
Expand Down Expand Up @@ -77,6 +78,10 @@ class CollectiveThunk : public Thunk {

const OpParams& op_params() const { return op_params_; }

// Read-only access to the `op_buffers_` member (returned by reference).
const OpBuffers& op_buffers() const { return op_buffers_; }

// Read-only access to the `op_resources_` member (returned by reference).
const OpResources& op_resources() const { return op_resources_; }

// Resolves operation's device memory from the buffers and buffer allocations.
absl::StatusOr<OpDeviceMemory> GetOpDeviceMemory(const ExecuteParams& params);

Expand Down
5 changes: 3 additions & 2 deletions xla/backends/cpu/runtime/conditional_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,18 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xla/backends/cpu/runtime/thunk.h"
#include "xla/backends/cpu/runtime/thunk_executor.h"
#include "xla/runtime/buffer_use.h"
#include "xla/service/buffer_assignment.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"

namespace xla::cpu {

Expand Down
8 changes: 8 additions & 0 deletions xla/backends/cpu/runtime/conditional_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ class ConditionalThunk final : public Thunk {
BufferUses buffer_uses() const final;
ResourceUses resource_uses() const final;

// Read-only access to `branch_executors_`. Returned by reference (no copy);
// the reference is only valid while this thunk is alive.
const std::vector<ThunkExecutor>& branch_executors() const {
return branch_executors_;
}

// Read-only access to the `branch_index_buffer_` slice (returned by
// reference).
const BufferAllocation::Slice& branch_index_buffer() const {
return branch_index_buffer_;
}

private:
ConditionalThunk(Info info, BufferAllocation::Slice branch_index_buffer,
std::vector<ThunkExecutor> branch_executors);
Expand Down
14 changes: 8 additions & 6 deletions xla/backends/cpu/runtime/convolution_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/backends/cpu/runtime/convolution_thunk.h"

#define EIGEN_USE_THREADS

#include <cstdint>
Expand All @@ -39,9 +38,9 @@ limitations under the License.
#include "xla/status_macros.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down Expand Up @@ -213,7 +212,7 @@ absl::StatusOr<std::unique_ptr<ConvolutionThunk>> ConvolutionThunk::Create(
output_shape, input_batch, input_dims, input_channels, kernel_dims,
kernel_channels, kernel_filters, output_dims, strides, padding_before,
padding_after, base_dilation, window_dilation, feature_group_count,
options));
options, dnums, window));
}

ConvolutionThunk::ConvolutionThunk(
Expand All @@ -229,7 +228,8 @@ ConvolutionThunk::ConvolutionThunk(
const absl::InlinedVector<int64_t, 2>& padding_after,
const absl::InlinedVector<int64_t, 2>& base_dilation,
const absl::InlinedVector<int64_t, 2>& window_dilation,
int64_t feature_group_count, Options options)
int64_t feature_group_count, Options options,
const ConvolutionDimensionNumbers& dnums, const Window& window)
: Thunk(Kind::kConvolution, std::move(info)),
input_buffer_(input_buffer),
input_shape_(input_shape),
Expand All @@ -251,7 +251,9 @@ ConvolutionThunk::ConvolutionThunk(
window_dilation_(window_dilation),
feature_group_count_(feature_group_count),
convolution_rank_(input_dims.size()),
options_(options) {}
options_(options),
dnums_(dnums),
window_(window) {}

tsl::AsyncValueRef<Thunk::ExecuteEvent> ConvolutionThunk::Execute(
const ExecuteParams& params) {
Expand Down
47 changes: 30 additions & 17 deletions xla/backends/cpu/runtime/convolution_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/xla_data.pb.h"

namespace xla::cpu {

Expand All @@ -54,24 +55,34 @@ class ConvolutionThunk final : public Thunk {
{output_buffer_, BufferUse::kWrite}};
}

// Read-only accessors for the values this thunk stores in its members.
//
// Everything except `feature_group_count()` is returned by const reference so
// that calling an accessor does not copy a protobuf message
// (`ConvolutionDimensionNumbers`, `Window`), a `Shape`, or a buffer slice.
// This matches the accessors on the other thunks in this directory. The
// references point at members of this thunk and must not be used after the
// thunk is destroyed; callers that need an independent value can still copy
// (`auto w = thunk.window();`).
const ConvolutionDimensionNumbers& dnums() const { return dnums_; }
const Window& window() const { return window_; }
int64_t feature_group_count() const { return feature_group_count_; }
const Options& options() const { return options_; }
const BufferAllocation::Slice& input_buffer() const { return input_buffer_; }
const Shape& input_shape() const { return input_shape_; }
const BufferAllocation::Slice& kernel_buffer() const { return kernel_buffer_; }
const Shape& kernel_shape() const { return kernel_shape_; }
const BufferAllocation::Slice& output_buffer() const { return output_buffer_; }
const Shape& output_shape() const { return output_shape_; }

private:
ConvolutionThunk(Info info, BufferAllocation::Slice input_buffer,
const Shape& input_shape,
BufferAllocation::Slice kernel_buffer,
const Shape& kernel_shape,
BufferAllocation::Slice output_buffer,
const Shape& output_shape, int64_t input_batch,
const absl::InlinedVector<int64_t, 2>& input_dims,
int64_t input_channels,
const absl::InlinedVector<int64_t, 2>& kernel_dims,
int64_t kernel_channels, int64_t kernel_filters,
const absl::InlinedVector<int64_t, 2>& output_dims,
const absl::InlinedVector<int64_t, 2>& strides,
const absl::InlinedVector<int64_t, 2>& padding_before,
const absl::InlinedVector<int64_t, 2>& padding_after,
const absl::InlinedVector<int64_t, 2>& base_dilation,
const absl::InlinedVector<int64_t, 2>& window_dilation,
int64_t feature_group_count, Options options);
ConvolutionThunk(
Info info, BufferAllocation::Slice input_buffer, const Shape& input_shape,
BufferAllocation::Slice kernel_buffer, const Shape& kernel_shape,
BufferAllocation::Slice output_buffer, const Shape& output_shape,
int64_t input_batch, const absl::InlinedVector<int64_t, 2>& input_dims,
int64_t input_channels,
const absl::InlinedVector<int64_t, 2>& kernel_dims,
int64_t kernel_channels, int64_t kernel_filters,
const absl::InlinedVector<int64_t, 2>& output_dims,
const absl::InlinedVector<int64_t, 2>& strides,
const absl::InlinedVector<int64_t, 2>& padding_before,
const absl::InlinedVector<int64_t, 2>& padding_after,
const absl::InlinedVector<int64_t, 2>& base_dilation,
const absl::InlinedVector<int64_t, 2>& window_dilation,
int64_t feature_group_count, Options options,
const ConvolutionDimensionNumbers& dnums, const Window& window);

void HandleACLConvolution(const ExecuteParams& params,
se::DeviceMemoryBase input,
Expand Down Expand Up @@ -137,6 +148,8 @@ class ConvolutionThunk final : public Thunk {
int64_t feature_group_count_;
int convolution_rank_;
Options options_;
ConvolutionDimensionNumbers dnums_;
Window window_;
};

} // namespace xla::cpu
Expand Down
5 changes: 3 additions & 2 deletions xla/backends/cpu/runtime/copy_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/base/optimization.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
Expand All @@ -42,9 +44,8 @@ limitations under the License.
#include "xla/shape_util.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down
6 changes: 6 additions & 0 deletions xla/backends/cpu/runtime/copy_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ class CopyThunk final : public Thunk {
return {{src_buffer_, BufferUse::kRead}, {dst_buffer_, BufferUse::kWrite}};
}

// Read-only access to the `src_shape_` / `dst_shape_` members (returned by
// reference, no copy).
const Shape& src_shape() const { return src_shape_; }
const Shape& dst_shape() const { return dst_shape_; }

// Read-only access to the `src_buffer_` / `dst_buffer_` slices, the same
// members that `buffer_uses()` reports as read and written respectively.
const BufferAllocation::Slice& src_buffer() const { return src_buffer_; }
const BufferAllocation::Slice& dst_buffer() const { return dst_buffer_; }

private:
CopyThunk(Info info, BufferAllocation::Slice src_buffer,
const Shape& src_shape, BufferAllocation::Slice dst_buffer,
Expand Down
6 changes: 3 additions & 3 deletions xla/backends/cpu/runtime/custom_call_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/base/dynamic_annotations.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand All @@ -51,10 +52,9 @@ limitations under the License.
#include "xla/service/custom_call_target_registry.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "tsl/profiler/lib/traceme.h"

namespace xla::cpu {
Expand Down
5 changes: 5 additions & 0 deletions xla/backends/cpu/runtime/custom_call_thunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ class CustomCallThunk final : public Thunk {

BufferUses buffer_uses() const final;

// Read-only accessors for the stored `target_name_`, `op_buffers_`,
// `api_version_` and `backend_config_` members. All return references to
// members, so the results are only valid while this thunk is alive.
const std::string& target_name() const { return target_name_; }
const OpBuffers& op_buffers() const { return op_buffers_; }
const CustomCallApiVersion& api_version() const { return api_version_; }
const std::string& backend_config() const { return backend_config_; }

private:
CustomCallThunk(Info info, absl::string_view target_name,
OpBuffers op_buffers, CustomCallApiVersion api_version,
Expand Down
Loading
Loading