diff --git a/xla/backends/cpu/runtime/BUILD b/xla/backends/cpu/runtime/BUILD index c80d349c2ae924..8378e612e147b7 100644 --- a/xla/backends/cpu/runtime/BUILD +++ b/xla/backends/cpu/runtime/BUILD @@ -2,6 +2,7 @@ load("//xla:xla.bzl", "xla_cc_test") load("//xla/service/cpu:build_defs.bzl", "runtime_copts") load("//xla/tsl:tsl.bzl", "if_windows", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "filegroup") +load("//xla/tsl/platform:build_config.bzl", "tf_proto_library") load("//xla/tsl/platform:rules_cc.bzl", "cc_library") package( @@ -133,6 +134,29 @@ cc_library( ], ) +tf_proto_library( + name = "collective_thunk_proto", + srcs = ["collective_thunk.proto"], + make_default_target_header_only = True, + protodeps = [ + "//xla:xla_data_proto", + "//xla/service:hlo_proto", + ], + visibility = ["//visibility:public"], +) + +tf_proto_library( + name = "thunk_proto", + srcs = ["thunk.proto"], + make_default_target_header_only = True, + protodeps = [ + ":collective_thunk_proto", + "//xla:xla_data_proto", + "//xla/service:hlo_proto", + ], + visibility = ["//visibility:public"], +) + cc_library( name = "thunk", srcs = ["thunk.cc"], @@ -140,13 +164,11 @@ cc_library( deps = [ ":buffer_allocations", ":function_library", - ":kernel_c_api", ":resource_use", + ":thunk_proto_cc", "//xla:executable_run_options", - "//xla:util", "//xla/backends/cpu/collectives:cpu_collectives", "//xla/backends/cpu/collectives:in_process_collectives", - "//xla/core/collectives", "//xla/ffi:execution_context", "//xla/runtime:buffer_use", "//xla/service:global_device_id", @@ -157,9 +179,10 @@ cc_library( "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:logging", "//xla/tsl/platform:statusor", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@tsl//tsl/profiler/lib:traceme", "@tsl//tsl/profiler/lib:traceme_encode", @@ -217,14 +240,12 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:errors", "@tsl//tsl/platform:numbers", "@tsl//tsl/profiler/lib:traceme", ], @@ -246,7 +267,6 @@ xla_cc_test( "//xla/service:buffer_assignment", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", - "//xla/tsl/lib/core:status_test_util", "//xla/tsl/platform:env", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", @@ -269,11 +289,11 @@ cc_library( deps = [ ":thunk", ":thunk_executor", + ":thunk_proto_cc", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -285,16 +305,17 @@ cc_library( deps = [ ":thunk", ":thunk_executor", + ":thunk_proto_cc", "//xla:util", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", ], ) @@ -306,6 +327,7 @@ xla_cc_test( ":conditional_thunk", ":resource_use", ":thunk", + ":thunk_proto_cc", ":thunk_testlib", "//xla:shape_util", "//xla/runtime:buffer_use", @@ -326,27 +348,23 @@ cc_library( hdrs = ["all_gather_thunk.h"], deps = [ ":collective_thunk", + ":collective_thunk_proto_cc", ":thunk", "//xla:shape_util", - "//xla:status_macros", - "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/cpu/collectives:cpu_collectives", - "//xla/runtime:buffer_use", + "//xla/core/collectives:communicator", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/tsl/concurrency:async_value", - "@com_google_absl//absl/algorithm:container", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:numbers", - "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -402,6 +420,7 @@ cc_library( deps = [ ":convolution_thunk_internal", ":thunk", + ":thunk_proto_cc", "//xla:executable_run_options", "//xla:shape_util", "//xla:status_macros", @@ -411,6 +430,8 @@ cc_library( "//xla/service/cpu:runtime_conv2d_acl", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", @@ -419,8 +440,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -432,6 +451,7 @@ xla_cc_test( ":buffer_allocations", ":convolution_thunk", ":thunk", + ":thunk_proto_cc", ":thunk_testlib", "//xla:literal", "//xla:literal_util", @@ -454,18 +474,18 @@ cc_library( hdrs = ["all_reduce_thunk.h"], deps = [ ":collective_thunk", + ":collective_thunk_proto_cc", ":thunk", "//xla:shape_util", - "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/cpu/collectives:cpu_collectives", - "//xla/runtime:buffer_use", + "//xla/core/collectives:communicator", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", - "@com_google_absl//absl/algorithm:container", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -473,11 +493,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:numbers", - "@tsl//tsl/platform:statusor", + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -488,27 +504,23 @@ cc_library( hdrs = ["all_to_all_thunk.h"], deps = [ ":collective_thunk", + ":collective_thunk_proto_cc", ":thunk", "//xla:shape_util", - "//xla:status_macros", - "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/cpu/collectives:cpu_collectives", - "//xla/runtime:buffer_use", + "//xla/core/collectives:communicator", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", - "@com_google_absl//absl/algorithm:container", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:numbers", - "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -519,27 +531,25 @@ cc_library( hdrs = ["reduce_scatter_thunk.h"], deps = [ ":collective_thunk", + ":collective_thunk_proto_cc", ":thunk", "//xla:shape_util", - "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/cpu/collectives:cpu_collectives", - "//xla/runtime:buffer_use", + "//xla/core/collectives:communicator", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/tsl/concurrency:async_value", - "@com_google_absl//absl/algorithm:container", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:numbers", - "@tsl//tsl/platform:statusor", + "@com_google_absl//absl/strings:string_view", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -550,30 +560,28 @@ cc_library( hdrs = ["collective_permute_thunk.h"], deps = [ ":collective_thunk", + ":collective_thunk_proto_cc", ":thunk", "//xla:shape_util", "//xla:status_macros", - "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", - "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", "//xla/tsl/concurrency:async_value", - "@com_google_absl//absl/algorithm:container", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:numbers", - "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -583,6 +591,7 @@ cc_library( srcs = ["collective_thunk.cc"], hdrs = ["collective_thunk.h"], deps = [ + ":collective_thunk_proto_cc", ":resource_use", ":thunk", "//xla:shape_util", @@ -599,6 +608,7 @@ cc_library( "//xla/service:collective_ops_utils", "//xla/service:computation_placer", "//xla/service:global_device_id", + "//xla/service:hlo_proto_cc", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", @@ -623,6 +633,7 @@ cc_library( deps = [ ":buffer_allocations", ":thunk", + ":thunk_proto_cc", "//xla:shape_util", "//xla:util", "//xla/pjrt:transpose", @@ -630,15 +641,17 @@ cc_library( "//xla/service:buffer_assignment", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@eigen_archive//:eigen3", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -650,6 +663,7 @@ xla_cc_test( ":buffer_allocations", ":copy_thunk", ":thunk", + ":thunk_proto_cc", ":thunk_testlib", "//xla:literal_util", "//xla:shape_util", @@ -668,6 +682,7 @@ cc_library( hdrs = ["custom_call_thunk.h"], deps = [ ":thunk", + ":thunk_proto_cc", "//xla:shape_util", "//xla:util", "//xla/ffi:attribute_map", @@ -682,9 +697,12 @@ cc_library( "//xla/service:custom_call_target_registry", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -694,9 +712,6 @@ cc_library( "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -707,6 +722,7 @@ cc_library( hdrs = ["dot_lib.h"], deps = [ ":thunk", + ":thunk_proto_cc", "//xla:shape_util", "//xla:status_macros", "//xla:types", @@ -748,17 +764,18 @@ cc_library( deps = [ ":dot_lib", ":thunk", + ":thunk_proto_cc", "//xla:shape_util", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "//xla/tsl/framework/contraction:eigen_contraction_kernel", + "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", - "@com_google_absl//absl/algorithm:container", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", @@ -766,10 +783,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -781,6 +795,7 @@ cc_library( deps = [ ":resource_use", ":thunk", + ":thunk_proto_cc", "//xla:shape_util", "//xla:util", "//xla/runtime:buffer_use", @@ -788,12 +803,13 @@ cc_library( "//xla/service/cpu:cpu_runtime", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -806,6 +822,7 @@ xla_cc_test( ":outfeed_thunk", ":resource_use", ":thunk", + ":thunk_proto_cc", ":thunk_testlib", "//xla:shape_util", "//xla/runtime:buffer_use", @@ -833,11 +850,12 @@ cc_library( "//xla/service:global_device_id", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -869,6 +887,7 @@ cc_library( deps = [ ":resource_use", ":thunk", + ":thunk_proto_cc", "//xla:shape_util", "//xla:util", "//xla/runtime:buffer_use", @@ -876,12 +895,13 @@ cc_library( "//xla/service/cpu:cpu_runtime", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -918,6 +938,7 @@ cc_library( ":kernel", ":kernel_c_api", ":thunk", + ":thunk_proto_cc", "//xla:util", "//xla/backends/cpu/codegen:llvm_ir_kernel_spec", "//xla/runtime:buffer_use", @@ -926,6 +947,7 @@ cc_library( "//xla/stream_executor:launch_dim", "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -940,7 +962,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -973,6 +994,7 @@ cc_library( srcs = ["resource_use.cc"], hdrs = ["resource_use.h"], deps = [ + "//xla:xla_data_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", @@ -997,21 +1019,21 @@ cc_library( hdrs = ["rng_state_thunk.h"], deps = [ ":thunk", + ":thunk_proto_cc", "//xla:util", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:config", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/numeric:int128", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -1023,6 +1045,7 @@ cc_library( deps = [ ":function_library", ":thunk", + ":thunk_proto_cc", "//xla:shape_util", "//xla:util", "//xla/runtime:buffer_use", @@ -1043,7 +1066,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@tsl//tsl/profiler/lib:traceme", ], @@ -1083,17 +1105,20 @@ cc_library( ":buffer_allocations", ":thunk", ":thunk_executor", + ":thunk_proto_cc", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -1126,6 +1151,7 @@ cc_library( hdrs = ["fft_thunk.h"], deps = [ ":thunk", + ":thunk_proto_cc", "//xla:shape_util", "//xla:status_macros", "//xla/runtime:buffer_use", @@ -1135,11 +1161,12 @@ cc_library( "//xla/stream_executor:device_memory", "//xla/stream_executor:stream_executor_h", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:statusor", "@tsl//tsl/profiler/lib:traceme", ], ) @@ -1150,13 +1177,14 @@ cc_library( hdrs = ["topk_thunk.h"], deps = [ ":thunk", + ":thunk_proto_cc", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service/cpu:runtime_topk", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", - "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/backends/cpu/runtime/all_gather_thunk.cc b/xla/backends/cpu/runtime/all_gather_thunk.cc index 82847710d0b75f..8897e17b5d7264 100644 --- a/xla/backends/cpu/runtime/all_gather_thunk.cc +++ b/xla/backends/cpu/runtime/all_gather_thunk.cc @@ -17,24 +17,27 @@ limitations under the License. #include #include +#include #include #include "absl/container/inlined_vector.h" +#include "absl/log/log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/collective_thunk.pb.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/core/collectives/communicator.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { @@ -89,4 +92,10 @@ tsl::AsyncValueRef AllGatherThunk::Execute( }); } +absl::StatusOr AllGatherThunk::SerializeAsStringCollectiveImpl() + const { + AllGatherThunkProto proto; + return proto.SerializeAsString(); +} + } // namespace xla::cpu diff --git a/xla/backends/cpu/runtime/all_gather_thunk.h b/xla/backends/cpu/runtime/all_gather_thunk.h index 2d2dca9a7eac9d..5b04eb3d39b0b5 100644 --- a/xla/backends/cpu/runtime/all_gather_thunk.h +++ b/xla/backends/cpu/runtime/all_gather_thunk.h @@ -17,6 +17,7 @@ limitations under the License. #define XLA_BACKENDS_CPU_RUNTIME_ALL_GATHER_THUNK_H_ #include +#include #include "absl/status/statusor.h" #include "xla/backends/cpu/runtime/collective_thunk.h" @@ -33,6 +34,9 @@ class AllGatherThunk final : public CollectiveThunk { tsl::AsyncValueRef Execute(const ExecuteParams& params) final; + protected: + absl::StatusOr SerializeAsStringCollectiveImpl() const final; + private: AllGatherThunk(Info info, OpParams op_params, OpBuffers op_buffers, OpResources op_resources); diff --git a/xla/backends/cpu/runtime/all_reduce_thunk.cc b/xla/backends/cpu/runtime/all_reduce_thunk.cc index 9c6ac2ead41620..45a8651dd00543 100644 --- a/xla/backends/cpu/runtime/all_reduce_thunk.cc +++ b/xla/backends/cpu/runtime/all_reduce_thunk.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/container/inlined_vector.h" @@ -27,9 +28,12 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/collective_thunk.pb.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/core/collectives/communicator.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" @@ -37,8 +41,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" -#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { @@ -115,4 +119,17 @@ tsl::AsyncValueRef AllReduceThunk::Execute( return OkExecuteEvent(); } +absl::StatusOr AllReduceThunk::SerializeAsStringCollectiveImpl() + const { + AllReduceThunkProto proto; + absl::string_view reduction_kind_as_string_view = + ReductionKindToString(reduction_kind_); + std::string reduction_kind_as_string(reduction_kind_as_string_view.begin(), + reduction_kind_as_string_view.end()); + proto.set_reduction_kind(reduction_kind_as_string); + proto.set_single_replica(single_replica_); + + return proto.SerializeAsString(); +} + } // namespace xla::cpu diff --git a/xla/backends/cpu/runtime/all_reduce_thunk.h b/xla/backends/cpu/runtime/all_reduce_thunk.h index 77866382353e02..25d336b6fb3a1c 100644 --- a/xla/backends/cpu/runtime/all_reduce_thunk.h +++ b/xla/backends/cpu/runtime/all_reduce_thunk.h @@ -34,6 +34,9 @@ class AllReduceThunk final : public CollectiveThunk { tsl::AsyncValueRef Execute(const ExecuteParams& params) final; + protected: + absl::StatusOr SerializeAsStringCollectiveImpl() const final; + private: AllReduceThunk(Info info, ReductionKind reduction_kind, OpParams op_params, OpBuffers op_buffers, OpResources op_resources, diff --git a/xla/backends/cpu/runtime/all_to_all_thunk.cc b/xla/backends/cpu/runtime/all_to_all_thunk.cc index b97ff3409deecc..8109f9e4a7096d 100644 --- a/xla/backends/cpu/runtime/all_to_all_thunk.cc +++ b/xla/backends/cpu/runtime/all_to_all_thunk.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/backends/cpu/runtime/all_to_all_thunk.h" #include +#include #include #include "absl/container/inlined_vector.h" @@ -25,7 +26,9 @@ limitations under the License. #include "absl/strings/str_format.h" #include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/collective_thunk.pb.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/core/collectives/communicator.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/shape.h" @@ -33,7 +36,7 @@ limitations under the License. #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { @@ -87,4 +90,10 @@ tsl::AsyncValueRef AllToAllThunk::Execute( }); } +absl::StatusOr AllToAllThunk::SerializeAsStringCollectiveImpl() + const { + AllToAllThunkProto proto; + return proto.SerializeAsString(); +} + } // namespace xla::cpu diff --git a/xla/backends/cpu/runtime/all_to_all_thunk.h b/xla/backends/cpu/runtime/all_to_all_thunk.h index b58afe94394572..3779dcd9991852 100644 --- a/xla/backends/cpu/runtime/all_to_all_thunk.h +++ b/xla/backends/cpu/runtime/all_to_all_thunk.h @@ -33,6 +33,9 @@ class AllToAllThunk final : public CollectiveThunk { tsl::AsyncValueRef Execute(const ExecuteParams& params) final; + protected: + absl::StatusOr SerializeAsStringCollectiveImpl() const final; + private: AllToAllThunk(Info info, OpParams op_params, OpBuffers op_buffers, OpResources op_resources); diff --git a/xla/backends/cpu/runtime/call_thunk.cc b/xla/backends/cpu/runtime/call_thunk.cc index 0473ad78e40f49..0adf4e1c246797 100644 --- a/xla/backends/cpu/runtime/call_thunk.cc +++ b/xla/backends/cpu/runtime/call_thunk.cc @@ -16,14 +16,16 @@ limitations under the License. #include "xla/backends/cpu/runtime/call_thunk.h" #include +#include #include #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { @@ -46,6 +48,15 @@ tsl::AsyncValueRef CallThunk::Execute( return called_executor_.Execute(params); } +absl::StatusOr CallThunk::SerializeAsStringImpl() const { + CallThunkProto proto; + TF_ASSIGN_OR_RETURN(std::string called_sequence_str, + called_executor_.thunk_sequence().SerializeAsString()); + + proto.mutable_called_sequence()->ParseFromString(called_sequence_str); + return proto.SerializeAsString(); +} + CallThunk::BufferUses CallThunk::buffer_uses() const { return called_executor_.buffer_uses(); } diff --git a/xla/backends/cpu/runtime/call_thunk.h b/xla/backends/cpu/runtime/call_thunk.h index b7addf7297c392..9764f3c83f1f27 100644 --- a/xla/backends/cpu/runtime/call_thunk.h +++ b/xla/backends/cpu/runtime/call_thunk.h @@ -37,6 +37,9 @@ class CallThunk final : public Thunk { BufferUses buffer_uses() const final; ResourceUses resource_uses() const final; + protected: + absl::StatusOr SerializeAsStringImpl() const final; + private: CallThunk(Info info, ThunkExecutor called_executor); diff --git a/xla/backends/cpu/runtime/collective_permute_thunk.cc b/xla/backends/cpu/runtime/collective_permute_thunk.cc index 3e46d388a5f671..ca501252e93457 100644 --- a/xla/backends/cpu/runtime/collective_permute_thunk.cc +++ b/xla/backends/cpu/runtime/collective_permute_thunk.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "absl/log/log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -32,7 +33,9 @@ limitations under the License. #include "absl/types/span.h" #include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/collective_thunk.pb.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" @@ -41,9 +44,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { @@ -144,4 +146,18 @@ CollectivePermuteThunk::Execute(const ExecuteParams& params) { }); } +absl::StatusOr +CollectivePermuteThunk::SerializeAsStringCollectiveImpl() const { + CollectivePermuteThunkProto proto; + + for (const auto& source_target_pair : source_target_pairs_) { + CollectivePermuteThunkProto::SourceTargetPairProto* + source_target_pair_proto = proto.add_source_target_pairs(); + source_target_pair_proto->set_source(source_target_pair.first); + source_target_pair_proto->set_target(source_target_pair.second); + } + + return proto.SerializeAsString(); +} + } // namespace xla::cpu diff --git a/xla/backends/cpu/runtime/collective_permute_thunk.h b/xla/backends/cpu/runtime/collective_permute_thunk.h index 702b2f2b15f3dd..e7d1225f7761db 100644 --- a/xla/backends/cpu/runtime/collective_permute_thunk.h +++ b/xla/backends/cpu/runtime/collective_permute_thunk.h @@ -40,6 +40,9 @@ class CollectivePermuteThunk final : public CollectiveThunk { tsl::AsyncValueRef Execute(const ExecuteParams& params) final; + protected: + absl::StatusOr SerializeAsStringCollectiveImpl() const final; + private: CollectivePermuteThunk( Info info, OpParams op_params, OpBuffers op_buffers, diff --git a/xla/backends/cpu/runtime/collective_thunk.cc b/xla/backends/cpu/runtime/collective_thunk.cc index 35a6f72fb9671d..31e3a87a4a86d1 100644 --- a/xla/backends/cpu/runtime/collective_thunk.cc +++ b/xla/backends/cpu/runtime/collective_thunk.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -35,6 +36,7 @@ limitations under the License. #include "xla/backends/cpu/collectives/cpu_clique_key.h" #include "xla/backends/cpu/collectives/cpu_cliques.h" #include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/backends/cpu/runtime/collective_thunk.pb.h" #include "xla/backends/cpu/runtime/resource_use.h" #include "xla/backends/cpu/runtime/thunk.h" #include "xla/core/collectives/communicator.h" @@ -44,6 +46,7 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" #include "xla/service/global_device_id.h" +#include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" @@ -76,6 +79,81 @@ Thunk::BufferUses CollectiveThunk::buffer_uses() const { return uses; } +absl::StatusOr CollectiveThunk::OpParams::SerializeAsString() + const { + OpParamsProto proto; + proto.set_has_channel_id(has_channel_id); + proto.set_use_global_device_ids( + use_global_device_ids.value()); // TODO(basioli) optional + proto.set_op_id(op_id); + for (const auto& group : group) { + ReplicaGroup* replica_group = proto.add_replica_group(); + for (const auto& device : group.replica_ids()) { + replica_group->add_replica_ids(device); + } + } + return proto.SerializeAsString(); +} + +absl::StatusOr CollectiveThunk::OpResources::SerializeAsString() + const { + OpResourcesProto proto; + // TODO(basioli) pointer -> optional? + const auto& communicator_resource_str = + communicator_resource->ToProto().SerializeAsString(); + proto.mutable_communicator_resource()->ParseFromString( + communicator_resource_str); + return proto.SerializeAsString(); +} + +absl::StatusOr CollectiveThunk::SerializeAsStringImpl() const { + CollectiveThunkProto proto; + + TF_ASSIGN_OR_RETURN(const std::string op_params_str, + op_params_.SerializeAsString()); + proto.mutable_op_params()->ParseFromString(op_params_str); + + TF_ASSIGN_OR_RETURN(const std::string op_resources_str, + op_resources_.SerializeAsString()); + proto.mutable_op_resources()->ParseFromString(op_resources_str); + + TF_ASSIGN_OR_RETURN(const std::string impl_string, + SerializeAsStringCollectiveImpl()); + + for (size_t i = 0; i < op_buffers_.source_buffers.size(); ++i) { + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto( + op_buffers_.source_buffers[i], op_buffers_.source_shapes[i], + proto.mutable_op_buffers()->add_source_shapes_buffer_slices())); + } + + for (size_t i = 0; i < op_buffers_.destination_buffers.size(); ++i) { + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto( + op_buffers_.destination_buffers[i], op_buffers_.destination_shapes[i], + proto.mutable_op_buffers()->add_destination_shapes_buffer_slices())); + } + + switch (proto.impl_case()) { + case CollectiveThunkProto::ImplCase::kAllGatherThunk: + proto.mutable_all_gather_thunk()->ParseFromString(impl_string); + break; + case CollectiveThunkProto::ImplCase::kAllReduceThunk: + proto.mutable_all_reduce_thunk()->ParseFromString(impl_string); + break; + case CollectiveThunkProto::ImplCase::kAllToAllThunk: + proto.mutable_all_to_all_thunk()->ParseFromString(impl_string); + break; + case CollectiveThunkProto::ImplCase::kReduceScatterThunk: + proto.mutable_reduce_scatter_thunk()->ParseFromString(impl_string); + break; + case CollectiveThunkProto::ImplCase::kCollectivePermuteThunk: + proto.mutable_collective_permute_thunk()->ParseFromString(impl_string); + break; + default: + return absl::UnimplementedError("SerializeAsStringImpl not implemented"); + } + return proto.SerializeAsString(); +} + Thunk::ResourceUses CollectiveThunk::resource_uses() const { return {ResourceUse::Write(op_resources_.communicator_resource)}; } diff --git a/xla/backends/cpu/runtime/collective_thunk.h b/xla/backends/cpu/runtime/collective_thunk.h index e226f7ab3834b6..0ec40e5b074fb5 100644 --- a/xla/backends/cpu/runtime/collective_thunk.h +++ b/xla/backends/cpu/runtime/collective_thunk.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/container/inlined_vector.h" @@ -50,6 +51,8 @@ class CollectiveThunk : public Thunk { bool has_channel_id; std::optional use_global_device_ids; std::vector group; + + absl::StatusOr SerializeAsString() const; }; // Source and destination buffers for the collective operation. @@ -64,6 +67,7 @@ class CollectiveThunk : public Thunk { // Resources used by the collective operation. struct OpResources { std::shared_ptr communicator_resource; + absl::StatusOr SerializeAsString() const; }; // Device memory resolved for the collective operation buffers. @@ -84,6 +88,12 @@ class CollectiveThunk : public Thunk { ResourceUses resource_uses() const final; protected: + absl::StatusOr SerializeAsStringImpl() const final; + + virtual absl::StatusOr SerializeAsStringCollectiveImpl() const { + return absl::UnimplementedError("SerializeAsStringImpl not implemented"); + } + // Callback for collective thunk implementations. using Callback = absl::AnyInvocable; diff --git a/xla/backends/cpu/runtime/collective_thunk.proto b/xla/backends/cpu/runtime/collective_thunk.proto new file mode 100644 index 00000000000000..cf4e8703524a07 --- /dev/null +++ b/xla/backends/cpu/runtime/collective_thunk.proto @@ -0,0 +1,71 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.cpu; + +import "xla/service/hlo.proto"; +import "xla/xla_data.proto"; + +message OpParamsProto { + int64 op_id = 1; + bool has_channel_id = 2; + bool use_global_device_ids = 3; // TODO(basioli) optional + repeated ReplicaGroup replica_group = 4; +} + +message OpBuffersProto { + repeated ShapeBufferAllocationSliceProto source_shapes_buffer_slices = 1; + repeated ShapeBufferAllocationSliceProto destination_shapes_buffer_slices = 2; +} + +message OpResourcesProto { + xla.ResourceProto communicator_resource = 1; // TODO(basioli) optional +} + +message AllGatherThunkProto {} // NOTE(basioli) empty for now + +message AllReduceThunkProto { + string reduction_kind = 1; + bool single_replica = 2; +} + +message AllToAllThunkProto {} // NOTE(basioli) empty for now + +message ReduceScatterThunkProto { + string reduction_kind = 1; +} + +message CollectivePermuteThunkProto { + message SourceTargetPairProto { + int64 source = 1; + int64 target = 2; + } + repeated SourceTargetPairProto source_target_pairs = 1; +} + +message CollectiveThunkProto { + OpParamsProto op_params = 1; + OpBuffersProto op_buffers = 2; + OpResourcesProto op_resources = 3; + oneof impl { + AllGatherThunkProto all_gather_thunk = 4; + AllReduceThunkProto all_reduce_thunk = 5; + AllToAllThunkProto all_to_all_thunk = 6; + ReduceScatterThunkProto reduce_scatter_thunk = 7; + CollectivePermuteThunkProto collective_permute_thunk = 8; + } +} diff --git a/xla/backends/cpu/runtime/conditional_thunk.cc b/xla/backends/cpu/runtime/conditional_thunk.cc index 42246dd1d3df51..02d010b19e2577 100644 --- a/xla/backends/cpu/runtime/conditional_thunk.cc +++ b/xla/backends/cpu/runtime/conditional_thunk.cc @@ -17,20 +17,23 @@ limitations under the License. #include #include +#include #include #include +#include "absl/log/log.h" #include "absl/memory/memory.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" namespace xla::cpu { @@ -96,6 +99,21 @@ tsl::AsyncValueRef ConditionalThunk::Execute( branch_index_buffer_.size()); } +absl::StatusOr ConditionalThunk::SerializeAsStringImpl() const { + ConditionalThunkProto proto; + proto.mutable_branch_sequences()->Reserve(branch_executors_.size()); + for (const auto& branch_executor : branch_executors_) { + TF_ASSIGN_OR_RETURN(std::string branch_sequence_str, + branch_executor.thunk_sequence().SerializeAsString()); + proto.add_branch_sequences()->ParseFromString(branch_sequence_str); + } + + TF_ASSIGN_OR_RETURN(std::string branch_index_buffer_str, + branch_index_buffer_.SerializeAsString()); + proto.mutable_branch_index_buffer()->ParseFromString(branch_index_buffer_str); + return proto.SerializeAsString(); +} + ConditionalThunk::BufferUses ConditionalThunk::buffer_uses() const { BufferUses buffer_uses = {BufferUse::Read(branch_index_buffer_)}; for (const auto& branch_executor : branch_executors_) { diff --git a/xla/backends/cpu/runtime/conditional_thunk.h b/xla/backends/cpu/runtime/conditional_thunk.h index 0b01d8517a6ff4..c174354e46e386 100644 --- a/xla/backends/cpu/runtime/conditional_thunk.h +++ b/xla/backends/cpu/runtime/conditional_thunk.h @@ -38,6 +38,9 @@ class ConditionalThunk final : public Thunk { BufferUses buffer_uses() const final; ResourceUses resource_uses() const final; + protected: + absl::StatusOr SerializeAsStringImpl() const final; + private: ConditionalThunk(Info info, BufferAllocation::Slice branch_index_buffer, std::vector branch_executors); diff --git a/xla/backends/cpu/runtime/convolution_thunk.cc b/xla/backends/cpu/runtime/convolution_thunk.cc index a157dcea226aa8..56288dadd27ae7 100644 --- a/xla/backends/cpu/runtime/convolution_thunk.cc +++ b/xla/backends/cpu/runtime/convolution_thunk.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "xla/backends/cpu/runtime/convolution_thunk.h" - #define EIGEN_USE_THREADS #include #include #include +#include #include #include "absl/container/inlined_vector.h" @@ -32,6 +32,7 @@ limitations under the License. #include "unsupported/Eigen/CXX11/Tensor" #include "xla/backends/cpu/runtime/convolution_thunk_internal.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/executable_run_options.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/runtime_conv2d_acl.h" @@ -39,9 +40,9 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { @@ -213,7 +214,32 @@ absl::StatusOr> ConvolutionThunk::Create( output_shape, input_batch, input_dims, input_channels, kernel_dims, kernel_channels, kernel_filters, output_dims, strides, padding_before, padding_after, base_dilation, window_dilation, feature_group_count, - options)); + options, dnums, window)); +} + +absl::StatusOr ConvolutionThunk::SerializeAsStringImpl() const { + ConvolutionThunkProto proto; + + const std::string dnums_as_str = dnums_.SerializeAsString(); + proto.mutable_dimension_numbers()->ParseFromString(dnums_as_str); + + const std::string window_as_str = window_.SerializeAsString(); + proto.mutable_window()->ParseFromString(window_as_str); + + proto.set_feature_group_count(feature_group_count_); + + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto( + input_buffer_, input_shape_, proto.mutable_input_buffer_shape())); + + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto( + output_buffer_, output_shape_, proto.mutable_output_buffer_shape())); + + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto( + kernel_buffer_, kernel_shape_, proto.mutable_kernel_buffer_shape())); + + proto.mutable_options()->set_multi_threaded(options_.multi_threaded); + proto.mutable_options()->set_use_acl(options_.use_acl); + return proto.SerializeAsString(); } ConvolutionThunk::ConvolutionThunk( @@ -229,7 +255,8 @@ ConvolutionThunk::ConvolutionThunk( const absl::InlinedVector& padding_after, const absl::InlinedVector& base_dilation, const absl::InlinedVector& window_dilation, - int64_t feature_group_count, Options options) + int64_t feature_group_count, Options options, + const ConvolutionDimensionNumbers& dnums, const Window& window) : Thunk(Kind::kConvolution, std::move(info)), input_buffer_(input_buffer), input_shape_(input_shape), @@ -251,7 +278,9 @@ ConvolutionThunk::ConvolutionThunk( window_dilation_(window_dilation), feature_group_count_(feature_group_count), convolution_rank_(input_dims.size()), - options_(options) {} + options_(options), + dnums_(dnums), + window_(window) {} tsl::AsyncValueRef ConvolutionThunk::Execute( const ExecuteParams& params) { diff --git a/xla/backends/cpu/runtime/convolution_thunk.h b/xla/backends/cpu/runtime/convolution_thunk.h index 4ee732adb9595d..dac95b9212a1e9 100644 --- a/xla/backends/cpu/runtime/convolution_thunk.h +++ b/xla/backends/cpu/runtime/convolution_thunk.h @@ -54,24 +54,26 @@ class ConvolutionThunk final : public Thunk { {output_buffer_, BufferUse::kWrite}}; } + protected: + absl::StatusOr SerializeAsStringImpl() const final; + private: - ConvolutionThunk(Info info, BufferAllocation::Slice input_buffer, - const Shape& input_shape, - BufferAllocation::Slice kernel_buffer, - const Shape& kernel_shape, - BufferAllocation::Slice output_buffer, - const Shape& output_shape, int64_t input_batch, - const absl::InlinedVector& input_dims, - int64_t input_channels, - const absl::InlinedVector& kernel_dims, - int64_t kernel_channels, int64_t kernel_filters, - const absl::InlinedVector& output_dims, - const absl::InlinedVector& strides, - const absl::InlinedVector& padding_before, - const absl::InlinedVector& padding_after, - const absl::InlinedVector& base_dilation, - const absl::InlinedVector& window_dilation, - int64_t feature_group_count, Options options); + ConvolutionThunk( + Info info, BufferAllocation::Slice input_buffer, const Shape& input_shape, + BufferAllocation::Slice kernel_buffer, const Shape& kernel_shape, + BufferAllocation::Slice output_buffer, const Shape& output_shape, + int64_t input_batch, const absl::InlinedVector& input_dims, + int64_t input_channels, + const absl::InlinedVector& kernel_dims, + int64_t kernel_channels, int64_t kernel_filters, + const absl::InlinedVector& output_dims, + const absl::InlinedVector& strides, + const absl::InlinedVector& padding_before, + const absl::InlinedVector& padding_after, + const absl::InlinedVector& base_dilation, + const absl::InlinedVector& window_dilation, + int64_t feature_group_count, Options options, + const ConvolutionDimensionNumbers& dnums, const Window& window); void HandleACLConvolution(const ExecuteParams& params, se::DeviceMemoryBase input, @@ -137,6 +139,8 @@ class ConvolutionThunk final : public Thunk { int64_t feature_group_count_; int convolution_rank_; Options options_; + ConvolutionDimensionNumbers dnums_; + Window window_; }; } // namespace xla::cpu diff --git a/xla/backends/cpu/runtime/copy_thunk.cc b/xla/backends/cpu/runtime/copy_thunk.cc index 67b4d557256950..c68e237efebf75 100644 --- a/xla/backends/cpu/runtime/copy_thunk.cc +++ b/xla/backends/cpu/runtime/copy_thunk.cc @@ -24,27 +24,31 @@ limitations under the License. #include #include #include +#include #include #include #include "absl/algorithm/container.h" #include "absl/base/optimization.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "unsupported/Eigen/CXX11/Tensor" #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/pjrt/transpose.h" #include "xla/service/buffer_assignment.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { @@ -212,4 +216,13 @@ tsl::AsyncValueRef CopyThunk::Execute( return event; } +absl::StatusOr CopyThunk::SerializeAsStringImpl() const { + CopyThunkProto proto; + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto( + src_buffer_, src_shape_, proto.mutable_src_buffer_shape())); + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto( + dst_buffer_, dst_shape_, proto.mutable_dst_buffer_shape())); + return proto.SerializeAsString(); +} + } // namespace xla::cpu diff --git a/xla/backends/cpu/runtime/copy_thunk.h b/xla/backends/cpu/runtime/copy_thunk.h index ed2cd68df5137a..e1e7bcaa865682 100644 --- a/xla/backends/cpu/runtime/copy_thunk.h +++ b/xla/backends/cpu/runtime/copy_thunk.h @@ -50,6 +50,9 @@ class CopyThunk final : public Thunk { return {{src_buffer_, BufferUse::kRead}, {dst_buffer_, BufferUse::kWrite}}; } + protected: + absl::StatusOr SerializeAsStringImpl() const final; + private: CopyThunk(Info info, BufferAllocation::Slice src_buffer, const Shape& src_shape, BufferAllocation::Slice dst_buffer, diff --git a/xla/backends/cpu/runtime/custom_call_thunk.cc b/xla/backends/cpu/runtime/custom_call_thunk.cc index 974a77522ac77d..a86d6b4af4f13e 100644 --- a/xla/backends/cpu/runtime/custom_call_thunk.cc +++ b/xla/backends/cpu/runtime/custom_call_thunk.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/base/dynamic_annotations.h" #include "absl/container/inlined_vector.h" +#include "absl/log/log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -38,6 +39,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/Support/LLVM.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/attribute_map.h" #include "xla/ffi/call_frame.h" @@ -51,10 +53,9 @@ limitations under the License. #include "xla/service/custom_call_target_registry.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { @@ -224,6 +225,26 @@ tsl::AsyncValueRef CustomCallThunk::Execute( return CallUntypedAPI(params); } +absl::StatusOr CustomCallThunk::SerializeAsStringImpl() const { + CustomCallThunkProto proto; + proto.set_target_name(target_name_); + proto.set_backend_config(backend_config_); + proto.set_api_version(api_version_); + + for (size_t i = 0; i < op_buffers_.arguments_buffers.size(); ++i) { + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto( + op_buffers_.arguments_buffers[i], op_buffers_.arguments_shapes[i], + proto.mutable_op_buffers()->add_arguments_shapes())); + } + + for (size_t i = 0; i < op_buffers_.results_buffers.size(); ++i) { + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto( + op_buffers_.results_buffers[i], op_buffers_.results_shapes[i], + proto.mutable_op_buffers()->add_results_shapes())); + } + return proto.SerializeAsString(); +} + tsl::AsyncValueRef CustomCallThunk::CallTypedFFI( const ExecuteParams& params) { tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); diff --git a/xla/backends/cpu/runtime/custom_call_thunk.h b/xla/backends/cpu/runtime/custom_call_thunk.h index 19b387a8d19cdc..c025e9bf0e1d20 100644 --- a/xla/backends/cpu/runtime/custom_call_thunk.h +++ b/xla/backends/cpu/runtime/custom_call_thunk.h @@ -56,6 +56,9 @@ class CustomCallThunk final : public Thunk { BufferUses buffer_uses() const final; + protected: + absl::StatusOr SerializeAsStringImpl() const final; + private: CustomCallThunk(Info info, absl::string_view target_name, OpBuffers op_buffers, CustomCallApiVersion api_version, diff --git a/xla/backends/cpu/runtime/dot_thunk.cc b/xla/backends/cpu/runtime/dot_thunk.cc index 00bcec6a2df83c..1811e268c82418 100644 --- a/xla/backends/cpu/runtime/dot_thunk.cc +++ b/xla/backends/cpu/runtime/dot_thunk.cc @@ -18,26 +18,27 @@ limitations under the License. #include #include #include +#include #include #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" -#include "absl/types/span.h" #include "xla/backends/cpu/runtime/dot_lib.h" #include "xla/backends/cpu/runtime/thunk.h" -#include "xla/layout_util.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" #include "xla/shape.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { @@ -203,4 +204,19 @@ tsl::AsyncValueRef DotThunk::Execute( return state.AsRef(); } +absl::StatusOr DotThunk::SerializeAsStringImpl() const { + DotThunkProto proto; + *proto.mutable_dot_dimensions() = dot_dimensions_; + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto( + dot_slices_.lhs_buffer, dot_slices_.lhs_shape, + proto.mutable_lhs_buffer_shape())); + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto( + dot_slices_.rhs_buffer, dot_slices_.rhs_shape, + proto.mutable_rhs_buffer_shape())); + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto( + dot_slices_.out_buffer, dot_slices_.out_shape, + proto.mutable_out_buffer_shape())); + return proto.SerializeAsString(); +} + } // namespace xla::cpu diff --git a/xla/backends/cpu/runtime/dot_thunk.h b/xla/backends/cpu/runtime/dot_thunk.h index 15b5b97fd33c22..69846ba0d33a7c 100644 --- a/xla/backends/cpu/runtime/dot_thunk.h +++ b/xla/backends/cpu/runtime/dot_thunk.h @@ -50,6 +50,9 @@ class DotThunk final : public Thunk { BufferUses buffer_uses() const final { return DotBufferUses(dot_slices_); } + protected: + absl::StatusOr SerializeAsStringImpl() const final; + private: DotThunk(Info info, DotDimensionNumbers dot_dimensions, DotSlices dot_slices, DotShape dot_shape, DotCanonicalDims dot_canonical_dims); diff --git a/xla/backends/cpu/runtime/fft_thunk.cc b/xla/backends/cpu/runtime/fft_thunk.cc index b7c898b26d177c..be3b884fd09586 100644 --- a/xla/backends/cpu/runtime/fft_thunk.cc +++ b/xla/backends/cpu/runtime/fft_thunk.cc @@ -16,12 +16,14 @@ limitations under the License. #include #include +#include #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/layout_util.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" @@ -31,7 +33,8 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { @@ -107,6 +110,20 @@ tsl::AsyncValueRef FftThunk::Execute( return OkExecuteEvent(); } +absl::StatusOr FftThunk::SerializeAsStringImpl() const { + FftThunkProto proto; + + proto.set_is_multi_thread_eigen(is_multi_thread_eigen_); + proto.set_fft_type(fft_type_); + proto.mutable_fft_length()->Add(fft_length_.begin(), fft_length_.end()); + + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto( + input_buffer_, input_shape_, proto.mutable_input_buffer_shape())); + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto( + output_buffer_, output_shape_, proto.mutable_output_buffer_shape())); + return proto.SerializeAsString(); +} + Thunk::BufferUses FftThunk::buffer_uses() const { return {{input_buffer_, BufferUse::kRead}, {output_buffer_, BufferUse::kWrite}}; diff --git a/xla/backends/cpu/runtime/fft_thunk.h b/xla/backends/cpu/runtime/fft_thunk.h index 64d4063d828cf7..e5a86f536118dd 100644 --- a/xla/backends/cpu/runtime/fft_thunk.h +++ b/xla/backends/cpu/runtime/fft_thunk.h @@ -47,6 +47,9 @@ class FftThunk final : public Thunk { BufferUses buffer_uses() const final; + protected: + absl::StatusOr SerializeAsStringImpl() const final; + private: // Constructs a thunk for launching an FFT on a host. FftThunk(Info thunk_info, bool is_multi_thread_eigen, int32_t fft_type, diff --git a/xla/backends/cpu/runtime/infeed_thunk.cc b/xla/backends/cpu/runtime/infeed_thunk.cc index e1a601565c69d3..7e96704d015483 100644 --- a/xla/backends/cpu/runtime/infeed_thunk.cc +++ b/xla/backends/cpu/runtime/infeed_thunk.cc @@ -18,22 +18,25 @@ limitations under the License. #include #include #include +#include #include +#include "absl/log/log.h" #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xla/backends/cpu/runtime/resource_use.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/xfeed_manager.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { @@ -98,6 +101,21 @@ tsl::AsyncValueRef InfeedThunk::Execute( return OkExecuteEvent(); } +absl::StatusOr InfeedThunk::SerializeAsStringImpl() const { + InfeedThunkProto proto; + *proto.mutable_infeed_resources()->mutable_consume_token() = + infeed_resources_.consume_token->ToProto(); + *proto.mutable_infeed_resources()->mutable_produce_token() = + infeed_resources_.produce_token->ToProto(); + + for (const InfeedBuffer& infeed_buffer : infeed_buffers_) { + TF_RETURN_IF_ERROR( + SerializeSliceShapeIntoProto(infeed_buffer.slice, infeed_buffer.shape, + proto.add_infeed_buffers_shapes())); + } + return proto.SerializeAsString(); +} + InfeedThunk::BufferUses InfeedThunk::buffer_uses() const { BufferUses buffer_uses; for (const InfeedBuffer& infeed_buffer : infeed_buffers_) { diff --git a/xla/backends/cpu/runtime/infeed_thunk.h b/xla/backends/cpu/runtime/infeed_thunk.h index 1d4225d1ddd008..f92cf93ab4275c 100644 --- a/xla/backends/cpu/runtime/infeed_thunk.h +++ b/xla/backends/cpu/runtime/infeed_thunk.h @@ -51,6 +51,9 @@ class InfeedThunk final : public Thunk { BufferUses buffer_uses() const final; ResourceUses resource_uses() const final; + protected: + absl::StatusOr SerializeAsStringImpl() const final; + private: InfeedThunk(Info info, absl::Span infeed_buffers, InfeedResources infeed_resources); diff --git a/xla/backends/cpu/runtime/kernel_thunk.cc b/xla/backends/cpu/runtime/kernel_thunk.cc index 2578dc1b7c85ac..8ced212999f210 100644 --- a/xla/backends/cpu/runtime/kernel_thunk.cc +++ b/xla/backends/cpu/runtime/kernel_thunk.cc @@ -43,14 +43,15 @@ limitations under the License. #include "xla/backends/cpu/runtime/kernel.h" #include "xla/backends/cpu/runtime/kernel_c_api.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/launch_dim.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" -#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" #define EIGEN_USE_THREADS @@ -379,4 +380,36 @@ absl::StatusOr> KernelThunk::Create( min_alignment); } +absl::StatusOr KernelThunk::SerializeAsStringImpl() const { + KernelThunkProto proto; + + proto.set_kernel_name(kernel_name_); + proto.mutable_thread_dim()->set_x(thread_dim_.x); + proto.mutable_thread_dim()->set_y(thread_dim_.y); + proto.mutable_thread_dim()->set_z(thread_dim_.z); + proto.set_min_alignment(min_alignment_.value()); + + for (const BufferAllocation::Slice& buffer : arguments_buffers_) { + TF_ASSIGN_OR_RETURN(const std::string slice_as_str, + buffer.SerializeAsString()); + proto.add_arguments_buffers()->ParseFromString(slice_as_str); + } + + for (const BufferAllocation::Slice& buffer : results_buffers_) { + TF_ASSIGN_OR_RETURN(const std::string slice_as_str, + buffer.SerializeAsString()); + proto.add_results_buffers()->ParseFromString(slice_as_str); + } + return proto.SerializeAsString(); +} + +template +absl::StatusOr +SmallKernelThunk::SerializeAsStringImpl() const { + KernelThunkProto proto; + // TODO(basioli): how is SmallKernelThunk different from KernelThunk and how + // is it used? + return proto.SerializeAsString(); +} + } // namespace xla::cpu diff --git a/xla/backends/cpu/runtime/kernel_thunk.h b/xla/backends/cpu/runtime/kernel_thunk.h index 173f44420719ab..43890d6380f0c4 100644 --- a/xla/backends/cpu/runtime/kernel_thunk.h +++ b/xla/backends/cpu/runtime/kernel_thunk.h @@ -142,6 +142,9 @@ class SmallKernelThunk final tsl::AsyncValueRef Execute( const Thunk::ExecuteParams& params) final; + + protected: + absl::StatusOr SerializeAsStringImpl() const final; }; // Kernel thunk specialization for dynamic number of arguments and results. @@ -165,6 +168,9 @@ class KernelThunk final : public internal::KernelThunk<> { tsl::AsyncValueRef Execute( const Thunk::ExecuteParams& params) final; + + protected: + absl::StatusOr SerializeAsStringImpl() const final; }; } // namespace xla::cpu diff --git a/xla/backends/cpu/runtime/logical_id_thunk.cc b/xla/backends/cpu/runtime/logical_id_thunk.cc index ace52302dc953d..12c38b071228cd 100644 --- a/xla/backends/cpu/runtime/logical_id_thunk.cc +++ b/xla/backends/cpu/runtime/logical_id_thunk.cc @@ -17,9 +17,12 @@ limitations under the License. #include #include #include +#include #include +#include "absl/log/log.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "xla/backends/cpu/runtime/thunk.h" @@ -30,8 +33,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu::internal { @@ -53,6 +55,18 @@ LogicalIdThunk::Create( new LogicalIdThunk(std::move(info), logical_id_buffer)); } +template +absl::StatusOr +LogicalIdThunk::SerializeAsStringImpl() const { + return absl::UnimplementedError("Not implemented"); + // Maybe another layer of abstraction to take care of the classes that inherit + // this? LogicalIdThunkProto proto; + // // TODO(basioli): do we need these? + // // LogicalIdKind + // // BufferAllocation::Slice logical_id_buffer_ + // return proto.SerializeAsString(); +} + template LogicalIdThunk::LogicalIdThunk( Info info, BufferAllocation::Slice logical_id_buffer) diff --git a/xla/backends/cpu/runtime/logical_id_thunk.h b/xla/backends/cpu/runtime/logical_id_thunk.h index 6a42fe69963d1a..7f047646daa515 100644 --- a/xla/backends/cpu/runtime/logical_id_thunk.h +++ b/xla/backends/cpu/runtime/logical_id_thunk.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include "absl/status/statusor.h" #include "xla/backends/cpu/runtime/thunk.h" @@ -44,6 +45,9 @@ class LogicalIdThunk : public Thunk { BufferUses buffer_uses() const final; + protected: + absl::StatusOr SerializeAsStringImpl() const final; + private: LogicalIdThunk(Info info, BufferAllocation::Slice logical_id_buffer); diff --git a/xla/backends/cpu/runtime/outfeed_thunk.cc b/xla/backends/cpu/runtime/outfeed_thunk.cc index b541953a403dee..2fea4a58c84b78 100644 --- a/xla/backends/cpu/runtime/outfeed_thunk.cc +++ b/xla/backends/cpu/runtime/outfeed_thunk.cc @@ -18,21 +18,25 @@ limitations under the License. #include #include #include +#include #include +#include "absl/log/log.h" #include "absl/memory/memory.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xla/backends/cpu/runtime/resource_use.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/xfeed_manager.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { @@ -97,6 +101,22 @@ tsl::AsyncValueRef OutfeedThunk::Execute( return OkExecuteEvent(); } +absl::StatusOr OutfeedThunk::SerializeAsStringImpl() const { + OutfeedThunkProto proto; + + *proto.mutable_outfeed_resources()->mutable_consume_token() = + outfeed_resources_.consume_token->ToProto(); + *proto.mutable_outfeed_resources()->mutable_produce_token() = + outfeed_resources_.produce_token->ToProto(); + + for (const OutfeedBuffer& outfeed_buffer : outfeed_buffers_) { + TF_RETURN_IF_ERROR( + SerializeSliceShapeIntoProto(outfeed_buffer.slice, outfeed_buffer.shape, + proto.add_outfeed_buffers_shapes())); + } + return proto.SerializeAsString(); +} + OutfeedThunk::BufferUses OutfeedThunk::buffer_uses() const { BufferUses buffer_uses; for (const OutfeedBuffer& outfeed_buffer : outfeed_buffers_) { diff --git a/xla/backends/cpu/runtime/outfeed_thunk.h b/xla/backends/cpu/runtime/outfeed_thunk.h index 74920899255d46..6d5de24f42f098 100644 --- a/xla/backends/cpu/runtime/outfeed_thunk.h +++ b/xla/backends/cpu/runtime/outfeed_thunk.h @@ -50,6 +50,9 @@ class OutfeedThunk final : public Thunk { BufferUses buffer_uses() const final; ResourceUses resource_uses() const final; + protected: + absl::StatusOr SerializeAsStringImpl() const final; + private: OutfeedThunk(Info info, absl::Span outfeed_buffers, OutfeedResources outfeed_resources); diff --git a/xla/backends/cpu/runtime/reduce_scatter_thunk.cc b/xla/backends/cpu/runtime/reduce_scatter_thunk.cc index 570621d6c970eb..0b473caeafad79 100644 --- a/xla/backends/cpu/runtime/reduce_scatter_thunk.cc +++ b/xla/backends/cpu/runtime/reduce_scatter_thunk.cc @@ -17,26 +17,30 @@ limitations under the License. #include #include +#include #include #include "absl/container/inlined_vector.h" +#include "absl/log/log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/collective_thunk.pb.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/core/collectives/communicator.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { @@ -103,4 +107,15 @@ ReduceScatterThunk::Execute(const ExecuteParams& params) { }); } +absl::StatusOr +ReduceScatterThunk::SerializeAsStringCollectiveImpl() const { + ReduceScatterThunkProto proto; + absl::string_view reduction_kind_as_string_view = + ReductionKindToString(reduction_kind_); + std::string reduction_kind_as_string(reduction_kind_as_string_view.begin(), + reduction_kind_as_string_view.end()); + proto.set_reduction_kind(reduction_kind_as_string); + return proto.SerializeAsString(); +} + } // namespace xla::cpu diff --git a/xla/backends/cpu/runtime/reduce_scatter_thunk.h b/xla/backends/cpu/runtime/reduce_scatter_thunk.h index 104d6c354dfa88..5a99fe1d7b0e52 100644 --- a/xla/backends/cpu/runtime/reduce_scatter_thunk.h +++ b/xla/backends/cpu/runtime/reduce_scatter_thunk.h @@ -34,6 +34,9 @@ class ReduceScatterThunk final : public CollectiveThunk { tsl::AsyncValueRef Execute(const ExecuteParams& params) final; + protected: + absl::StatusOr SerializeAsStringCollectiveImpl() const final; + private: ReduceScatterThunk(Info info, ReductionKind reduction_kind, OpParams op_params, OpBuffers op_buffers, diff --git a/xla/backends/cpu/runtime/resource_use.cc b/xla/backends/cpu/runtime/resource_use.cc index a3c03849b5178a..737b11aa7d2575 100644 --- a/xla/backends/cpu/runtime/resource_use.cc +++ b/xla/backends/cpu/runtime/resource_use.cc @@ -21,9 +21,36 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/types/span.h" +#include "xla/xla_data.pb.h" namespace xla::cpu { +Resource::Resource(const ResourceProto& proto) { + if (proto.kind() == ResourceProto::TOKEN) { + kind_ = Kind::kToken; + } else if (proto.kind() == ResourceProto::COLLECTIVE_COMMUNICATOR) { + kind_ = Kind::kCollectiveCommunicator; + } else { + // TODO(basioli) what to do here? + } +} + +ResourceProto Resource::ToProto() const { + ResourceProto proto; + switch (kind_) { + case Kind::kToken: + proto.set_kind(ResourceProto::TOKEN); + break; + case Kind::kCollectiveCommunicator: + proto.set_kind(ResourceProto::COLLECTIVE_COMMUNICATOR); + break; + default: + // TODO(basioli) what to do here? + break; + } + return proto; +} + std::shared_ptr Resource::Create(Kind kind) { return absl::WrapUnique(new Resource(kind)); } diff --git a/xla/backends/cpu/runtime/resource_use.h b/xla/backends/cpu/runtime/resource_use.h index 1442a2895a02bf..c44a447932716f 100644 --- a/xla/backends/cpu/runtime/resource_use.h +++ b/xla/backends/cpu/runtime/resource_use.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/types/span.h" +#include "xla/xla_data.pb.h" namespace xla::cpu { @@ -43,9 +44,11 @@ class Resource { static constexpr Kind kToken = Kind::kToken; static constexpr Kind kCollectiveCommunicator = Kind::kCollectiveCommunicator; + explicit Resource(const ResourceProto& proto); static std::shared_ptr Create(Kind kind); Kind kind() const { return kind_; } + ResourceProto ToProto() const; private: explicit Resource(Kind kind); diff --git a/xla/backends/cpu/runtime/rng_state_thunk.cc b/xla/backends/cpu/runtime/rng_state_thunk.cc index 39a3de9b9429dc..ff93a1031d5413 100644 --- a/xla/backends/cpu/runtime/rng_state_thunk.cc +++ b/xla/backends/cpu/runtime/rng_state_thunk.cc @@ -18,21 +18,23 @@ limitations under the License. #include #include #include +#include #include #include "absl/base/config.h" +#include "absl/log/log.h" #include "absl/memory/memory.h" #include "absl/numeric/int128.h" -#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/service/buffer_assignment.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { @@ -50,6 +52,16 @@ RngGetAndUpdateStateThunk::Create(Info info, new RngGetAndUpdateStateThunk(std::move(info), state_buffer, delta)); } +absl::StatusOr RngGetAndUpdateStateThunk::SerializeAsStringImpl() + const { + RngGetAndUpdateStateThunkProto proto; + proto.set_delta(delta_); + TF_ASSIGN_OR_RETURN(const std::string slice_as_str, + state_buffer_.SerializeAsString()); + proto.mutable_state_buffer()->ParseFromString(slice_as_str); + return proto.SerializeAsString(); +} + RngGetAndUpdateStateThunk::RngGetAndUpdateStateThunk( Info info, BufferAllocation::Slice state_buffer, int64_t delta) : Thunk(Kind::kRngGetAndUpdateState, info), diff --git a/xla/backends/cpu/runtime/rng_state_thunk.h b/xla/backends/cpu/runtime/rng_state_thunk.h index d00bf4523e5dea..92bfc38f144ae2 100644 --- a/xla/backends/cpu/runtime/rng_state_thunk.h +++ b/xla/backends/cpu/runtime/rng_state_thunk.h @@ -43,6 +43,9 @@ class RngGetAndUpdateStateThunk final : public Thunk { return {{state_buffer_, BufferUse::kWrite}}; } + protected: + absl::StatusOr SerializeAsStringImpl() const final; + private: RngGetAndUpdateStateThunk(Info info, BufferAllocation::Slice state_buffer, int64_t delta); diff --git a/xla/backends/cpu/runtime/sort_thunk.cc b/xla/backends/cpu/runtime/sort_thunk.cc index 96534db43b1345..cf12ef0f9d5cca 100644 --- a/xla/backends/cpu/runtime/sort_thunk.cc +++ b/xla/backends/cpu/runtime/sort_thunk.cc @@ -44,6 +44,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/backends/cpu/runtime/function_library.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/layout_util.h" #include "xla/primitive_util.h" #include "xla/runtime/buffer_use.h" @@ -109,6 +110,30 @@ absl::StatusOr> SortThunk::Create( direction)); } +absl::StatusOr SortThunk::SerializeAsStringImpl() const { + SortThunkProto proto; + proto.set_dimension(dimension_); + proto.set_is_stable(is_stable_); + proto.set_comparator_name(comparator_name_); + // TODO(basioli): what about LessThan? + if (direction_.has_value()) { + switch (direction_.value()) { + case SortDirection::kAscending: + proto.set_direction(SortThunkProto::ASCENDING); + break; + case SortDirection::kDescending: + proto.set_direction(SortThunkProto::DESCENDING); + break; + } + } + + for (const Input& input : inputs_) { + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto(input.slice, input.shape, + proto.add_inputs_shapes())); + } + return proto.SerializeAsString(); +} + SortThunk::SortThunk(Info info, absl::Span inputs, int64_t dimension, bool is_stable, LessThan less_than, std::optional direction) diff --git a/xla/backends/cpu/runtime/sort_thunk.h b/xla/backends/cpu/runtime/sort_thunk.h index 6d32ab1ac3c5f6..8061cbb5ea737b 100644 --- a/xla/backends/cpu/runtime/sort_thunk.h +++ b/xla/backends/cpu/runtime/sort_thunk.h @@ -63,6 +63,9 @@ class SortThunk final : public Thunk { BufferUses buffer_uses() const final; + protected: + absl::StatusOr SerializeAsStringImpl() const final; + private: SortThunk(Info info, absl::Span inputs, int64_t dimension, bool is_stable, LessThan less_than, diff --git a/xla/backends/cpu/runtime/thunk.cc b/xla/backends/cpu/runtime/thunk.cc index 96cf954095f20c..40a72cad65a75c 100644 --- a/xla/backends/cpu/runtime/thunk.cc +++ b/xla/backends/cpu/runtime/thunk.cc @@ -22,9 +22,13 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "xla/backends/cpu/collectives/cpu_collectives.h" #include "xla/backends/cpu/collectives/in_process_collectives.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/executable_run_options.h" #include "xla/service/cpu/cpu_executable_run_options.h" #include "xla/service/global_device_id.h" @@ -32,11 +36,20 @@ limitations under the License. #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/lib/traceme_encode.h" namespace xla::cpu { +absl::StatusOr Thunk::Info::SerializeAsString() const { + InfoProto proto; + proto.set_op_name(op_name); + proto.set_module_name(module_name); + proto.set_module_id(module_id); + return proto.SerializeAsString(); +} + absl::string_view Thunk::KindToString(Kind kind) { switch (kind) { case Kind::kAllGather: @@ -85,6 +98,33 @@ absl::string_view Thunk::KindToString(Kind kind) { return "xnn-fusion"; } } + +/*virtual*/ absl::StatusOr Thunk::SerializeAsString() const { + ThunkProto proto; + absl::string_view kind_as_string_view = KindToString(kind_); + std::string kind_as_string(kind_as_string_view.begin(), + kind_as_string_view.end()); + proto.set_kind(kind_as_string); + TF_ASSIGN_OR_RETURN(const std::string info_as_string, + info().SerializeAsString()); + proto.mutable_info()->ParseFromString(info_as_string); + TF_ASSIGN_OR_RETURN(const std::string impl_as_string, + SerializeAsStringImpl()); + + proto.mutable_info()->ParseFromString(info_as_string); + return proto.SerializeAsString(); +} + +absl::StatusOr Thunk::SerializeAsStringImpl() const { + return absl::UnimplementedError("SerializeAsStringImpl is not implemented"); +} + +// TODO(basioli): How are we going to deserialize this polymorphically? +/*static*/ absl::StatusOr> Thunk::FromString( + const std::string& serialized) { + return absl::UnimplementedError("Thunk::FromString is Not implemented yet"); +} + Thunk::Thunk(Kind kind, Info info) : kind_(kind), info_(std::move(info)), @@ -207,4 +247,25 @@ ThunkSequence::ResourceUses ThunkSequence::resource_uses() const { return resource_uses; } +absl::StatusOr ThunkSequence::SerializeAsString() const { + ThunkSequenceProto proto; + proto.mutable_thunks()->Reserve(size()); + for (auto& thunk : *this) { + ThunkProto* thunk_proto = proto.add_thunks(); + + TF_ASSIGN_OR_RETURN(const std::string thunk_as_string, + thunk->SerializeAsString()); + if (!thunk_proto->ParseFromString(thunk_as_string)) { + return absl::InternalError(absl::StrFormat( + "Failed to parse thunk proto:\n %s", thunk_as_string)); + } + } + return proto.SerializeAsString(); +} + +/*static*/ absl::StatusOr> +ThunkSequence::FromString(const std::string& serialized) { + return absl::UnimplementedError("Not implemented yet"); +} + } // namespace xla::cpu diff --git a/xla/backends/cpu/runtime/thunk.h b/xla/backends/cpu/runtime/thunk.h index 2c86db92517745..91c36d53f5e24f 100644 --- a/xla/backends/cpu/runtime/thunk.h +++ b/xla/backends/cpu/runtime/thunk.h @@ -94,6 +94,8 @@ class Thunk { std::string op_name; std::string module_name; int64_t module_id; + + absl::StatusOr SerializeAsString() const; }; using Task = std::function; @@ -132,6 +134,12 @@ class Thunk { Kind kind() const { return kind_; } const Info& info() const { return info_; } + absl::StatusOr SerializeAsString() const; + + // TODO(basioli): How are we going to deserialize this polymorphically? + static absl::StatusOr> FromString( + const std::string& serialized); + static absl::string_view KindToString(Kind kind); // Returns the list of buffers used by a thunk. Thunk executor relies on this @@ -287,6 +295,7 @@ class Thunk { const ExecuteParams& params) = 0; protected: + virtual absl::StatusOr SerializeAsStringImpl() const; // Encodes thunk info into the TraceMe compatible format. std::string TraceMeEncode() const; @@ -337,6 +346,13 @@ class ThunkSequence : public std::vector> { void Append(ThunkSequence other); + absl::StatusOr SerializeAsString() const; + + // TODO(basioli): This will probably require more arguments to actually be + // able to deserialize. + static absl::StatusOr> FromString( + const std::string& serialized); + private: explicit ThunkSequence(std::unique_ptr thunk); }; diff --git a/xla/backends/cpu/runtime/thunk.proto b/xla/backends/cpu/runtime/thunk.proto new file mode 100644 index 00000000000000..caa8a4c67e9241 --- /dev/null +++ b/xla/backends/cpu/runtime/thunk.proto @@ -0,0 +1,192 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.cpu; + +import "xla/backends/cpu/runtime/collective_thunk.proto"; +import "xla/service/hlo.proto"; +import "xla/xla_data.proto"; + +message CallThunkProto { + ThunkSequenceProto called_sequence = 1; +} + +message ConditionalThunkProto { + repeated ThunkSequenceProto branch_sequences = 1; + BufferAllocationSliceProto branch_index_buffer = 2; +} + +message ConvolutionThunkProto { + message Options { + bool multi_threaded = 1; + bool use_acl = 2; + } + ConvolutionDimensionNumbers dimension_numbers = 1; + Window window = 2; + int64 feature_group_count = 3; + ShapeBufferAllocationSliceProto input_buffer_shape = 4; + ShapeBufferAllocationSliceProto kernel_buffer_shape = 5; + ShapeBufferAllocationSliceProto output_buffer_shape = 6; + Options options = 7; +} + +message SortThunkProto { + enum SortDirection { + UNKNOWN = 0; + ASCENDING = 1; + DESCENDING = 2; + } + int64 dimension = 1; + bool is_stable = 2; + SortDirection direction = 3; // TODO(basioli) optional + string comparator_name = 4; + repeated ShapeBufferAllocationSliceProto inputs_shapes = 5; +} + +message XnnFusionThunkProto {} + +message XnnDotThunkProto { + DotDimensionNumbers dot_dimensions = 1; + ShapeBufferAllocationSliceProto lhs_buffer_shape = 2; + ShapeBufferAllocationSliceProto rhs_buffer_shape = 3; + ShapeBufferAllocationSliceProto out_buffer_shape = 4; +} + +message DotThunkProto { + DotDimensionNumbers dot_dimensions = 1; + ShapeBufferAllocationSliceProto lhs_buffer_shape = 2; + ShapeBufferAllocationSliceProto rhs_buffer_shape = 3; + ShapeBufferAllocationSliceProto out_buffer_shape = 4; +} + +message RngGetAndUpdateStateThunkProto { + int64 delta = 1; + BufferAllocationSliceProto state_buffer = 2; +} + +message TopKThunkProto { + int64 batch_size = 1; + int64 input_size = 2; + int64 k = 3; + BufferAllocationSliceProto values_buffer = 4; + BufferAllocationSliceProto output_buffer = 5; + BufferAllocationSliceProto indices_buffer = 6; +} + +message WhileThunkProto { + ThunkSequenceProto cond_sequence = 1; + ThunkSequenceProto body_sequence = 2; + int64 trip_count = 3; // TODO(basioli) optional + BufferAllocationSliceProto cond_buffer = 4; +} + +message KernelThunkProto { + message ThreadDim { + int64 x = 1; + int64 y = 2; + int64 z = 3; + } + string kernel_name = 1; + ThreadDim thread_dim = 2; + // TODO(basioli) maybe optional? NOTE this is a set in C++ + repeated int64 invariant_arguments = 3; + // TODO(basioli) maybe optional? + int64 min_alignment = 4; + repeated BufferAllocationSliceProto arguments_buffers = 5; + repeated BufferAllocationSliceProto results_buffers = 6; +} + +message CopyThunkProto { + ShapeBufferAllocationSliceProto src_buffer_shape = 1; + ShapeBufferAllocationSliceProto dst_buffer_shape = 2; +} + +message FftThunkProto { + bool is_multi_thread_eigen = 1; + int32 fft_type = 2; + repeated int64 fft_length = 3; + ShapeBufferAllocationSliceProto input_buffer_shape = 4; + ShapeBufferAllocationSliceProto output_buffer_shape = 5; +} + +message InfeedThunkProto { + message InfeedResource { + // TODO(basioli) these are pointers in C++, maybe optional? + ResourceProto consume_token = 1; + ResourceProto produce_token = 2; + } + + InfeedResource infeed_resources = 1; + repeated ShapeBufferAllocationSliceProto infeed_buffers_shapes = 2; +} + +message OutfeedThunkProto { + message OutfeedResource { + // TODO(basioli) these are pointers in C++, maybe optional? + ResourceProto consume_token = 1; + ResourceProto produce_token = 2; + } + + OutfeedResource outfeed_resources = 1; + repeated ShapeBufferAllocationSliceProto outfeed_buffers_shapes = 2; +} + +message CustomCallThunkProto { + message OpBuffers { + repeated ShapeBufferAllocationSliceProto arguments_shapes = 1; + repeated ShapeBufferAllocationSliceProto results_shapes = 2; + } + CustomCallApiVersion api_version = 1; + string target_name = 2; + string backend_config = 3; + OpBuffers op_buffers = 4; +} + +message InfoProto { + string op_name = 1; + string module_name = 2; + int64 module_id = 3; +} + +// TODO(kbasioli) will probably need to add more fields here. +message ThunkProto { + string kind = 1; + InfoProto info = 2; + oneof impl { + CallThunkProto call_thunk = 3; + ConditionalThunkProto conditional_thunk = 4; + SortThunkProto sort_thunk = 5; + XnnFusionThunkProto xnn_fusion_thunk = 6; + XnnDotThunkProto xnn_dot_thunk = 7; + DotThunkProto dot_thunk = 8; + RngGetAndUpdateStateThunkProto rng_get_and_update_state_thunk = 9; + TopKThunkProto top_k_thunk = 10; + WhileThunkProto while_thunk = 11; + KernelThunkProto kernel_thunk = 12; + CopyThunkProto copy_thunk = 13; + FftThunkProto fft_thunk = 14; + InfeedThunkProto infeed_thunk = 15; + OutfeedThunkProto outfeed_thunk = 16; + CustomCallThunkProto custom_call_thunk = 17; + ConvolutionThunkProto convolution_thunk = 18; + CollectiveThunkProto collective_thunk = 19; + } +} + +message ThunkSequenceProto { + repeated ThunkProto thunks = 1; +} diff --git a/xla/backends/cpu/runtime/thunk_executor.h b/xla/backends/cpu/runtime/thunk_executor.h index 54b4a4be2ac0c6..d0437f2b14f4f7 100644 --- a/xla/backends/cpu/runtime/thunk_executor.h +++ b/xla/backends/cpu/runtime/thunk_executor.h @@ -95,6 +95,8 @@ class ThunkExecutor { // If any of the thunks failed, the event will be in error state. tsl::AsyncValueRef Execute(const Thunk::ExecuteParams& params); + const ThunkSequence& thunk_sequence() const { return thunk_sequence_; } + absl::Span nodes_defs() const { return nodes_defs_; } const NodeDef& node_def(NodeId id) const { return nodes_defs_[id]; } diff --git a/xla/backends/cpu/runtime/thunk_executor_test.cc b/xla/backends/cpu/runtime/thunk_executor_test.cc index dd315236916dd1..8a51426e309d76 100644 --- a/xla/backends/cpu/runtime/thunk_executor_test.cc +++ b/xla/backends/cpu/runtime/thunk_executor_test.cc @@ -126,6 +126,9 @@ class AddI32Thunk final : public Thunk { BufferUses buffer_uses() const final; ResourceUses resource_uses() const final; + protected: + absl::StatusOr SerializeAsStringImpl() const final; + private: std::vector srcs_; std::vector dsts_; @@ -143,6 +146,11 @@ std::unique_ptr AddI32Thunk::Create( use_shared_resource, inject_error); } +absl::StatusOr AddI32Thunk::SerializeAsStringImpl() const { + // NOTE(basioli) no need for this as it is just a test thunk. + return absl::UnimplementedError("Not implemented"); +} + AddI32Thunk::AddI32Thunk(std::string name, std::vector srcs, std::vector dsts, diff --git a/xla/backends/cpu/runtime/topk_thunk.cc b/xla/backends/cpu/runtime/topk_thunk.cc index 0c72933dc1a3aa..abe481b5b6e3b1 100644 --- a/xla/backends/cpu/runtime/topk_thunk.cc +++ b/xla/backends/cpu/runtime/topk_thunk.cc @@ -17,16 +17,18 @@ limitations under the License. #include #include +#include #include #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/runtime_topk.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" namespace xla::cpu { @@ -69,4 +71,21 @@ tsl::AsyncValueRef TopKThunk::Execute( return OkExecuteEvent(); } +absl::StatusOr TopKThunk::SerializeAsStringImpl() const { + TopKThunkProto proto; + proto.set_batch_size(batch_size_); + proto.set_input_size(input_size_); + proto.set_k(k_); + TF_ASSIGN_OR_RETURN(const std::string values_as_str, + values_buffer_.SerializeAsString()); + proto.mutable_values_buffer()->ParseFromString(values_as_str); + TF_ASSIGN_OR_RETURN(const std::string output_as_str, + output_buffer_.SerializeAsString()); + proto.mutable_output_buffer()->ParseFromString(output_as_str); + TF_ASSIGN_OR_RETURN(const std::string indices_as_str, + indices_buffer_.SerializeAsString()); + proto.mutable_indices_buffer()->ParseFromString(indices_as_str); + return proto.SerializeAsString(); +} + } // namespace xla::cpu diff --git a/xla/backends/cpu/runtime/topk_thunk.h b/xla/backends/cpu/runtime/topk_thunk.h index 7e7fadb03852e7..53b23940349f7e 100644 --- a/xla/backends/cpu/runtime/topk_thunk.h +++ b/xla/backends/cpu/runtime/topk_thunk.h @@ -41,6 +41,9 @@ class TopKThunk final : public Thunk { BufferUse::Write(indices_buffer_)}; } + protected: + absl::StatusOr SerializeAsStringImpl() const final; + private: TopKThunk(Info info, BufferAllocation::Slice values, BufferAllocation::Slice output, BufferAllocation::Slice indices, diff --git a/xla/backends/cpu/runtime/while_thunk.cc b/xla/backends/cpu/runtime/while_thunk.cc index 6c1e81f5dee0d6..e4e4b7461cf858 100644 --- a/xla/backends/cpu/runtime/while_thunk.cc +++ b/xla/backends/cpu/runtime/while_thunk.cc @@ -19,22 +19,26 @@ limitations under the License. #include #include #include +#include #include #include "absl/base/optimization.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" +#include "xla/tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" namespace xla::cpu { @@ -86,6 +90,26 @@ tsl::AsyncValueRef WhileThunk::Execute( return ExecuteWhileLoop(params, condition); } +absl::StatusOr WhileThunk::SerializeAsStringImpl() const { + WhileThunkProto proto; + proto.set_trip_count(trip_count_.value()); + + TF_ASSIGN_OR_RETURN(std::string cond_sequence_str, + cond_executor_.thunk_sequence().SerializeAsString()); + + proto.mutable_cond_sequence()->ParseFromString(cond_sequence_str); + + TF_ASSIGN_OR_RETURN(std::string body_sequence_str, + body_executor_.thunk_sequence().SerializeAsString()); + + proto.mutable_body_sequence()->ParseFromString(body_sequence_str); + + TF_ASSIGN_OR_RETURN(std::string cond_buffer_str, + cond_buffer_.SerializeAsString()); + proto.mutable_cond_buffer()->ParseFromString(cond_buffer_str); + return proto.SerializeAsString(); +} + tsl::AsyncValueRef WhileThunk::ExecuteForLoop( const ExecuteParams& params, int64_t trip_count) { for (int64_t loop_counter = 0; loop_counter < trip_count; ++loop_counter) { diff --git a/xla/backends/cpu/runtime/while_thunk.h b/xla/backends/cpu/runtime/while_thunk.h index c1de07de86ad52..7a2669c95bbdce 100644 --- a/xla/backends/cpu/runtime/while_thunk.h +++ b/xla/backends/cpu/runtime/while_thunk.h @@ -47,6 +47,9 @@ class WhileThunk final : public Thunk { BufferUses buffer_uses() const final; ResourceUses resource_uses() const final; + protected: + absl::StatusOr SerializeAsStringImpl() const final; + private: WhileThunk(Info info, BufferAllocation::Slice cond_buffer, ThunkExecutor cond_executor, ThunkExecutor body_executor, diff --git a/xla/backends/cpu/runtime/xnnpack/BUILD b/xla/backends/cpu/runtime/xnnpack/BUILD index 4b01b7614131ff..9193dbf88aae4d 100644 --- a/xla/backends/cpu/runtime/xnnpack/BUILD +++ b/xla/backends/cpu/runtime/xnnpack/BUILD @@ -135,25 +135,20 @@ cc_library( ":xnn_fusion_thunk", ":xnn_interop", "//xla:shape_util", - "//xla:types", - "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/cpu/runtime:dot_lib", "//xla/backends/cpu/runtime:thunk", + "//xla/backends/cpu/runtime:thunk_proto_cc", "//xla/service:buffer_assignment", - "//xla/tsl/concurrency:async_value", "//xla/tsl/platform:errors", "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", "@XNNPACK", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:statusor", - "@tsl//tsl/profiler/lib:traceme", ], ) diff --git a/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.cc b/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.cc index 92d32d86e2461c..657b4395ff361b 100644 --- a/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.cc +++ b/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.cc @@ -31,14 +31,15 @@ limitations under the License. #include "absl/types/span.h" #include "xla/backends/cpu/runtime/dot_lib.h" #include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" #include "xla/backends/cpu/runtime/xnnpack/xnn_fusion_thunk.h" #include "xla/backends/cpu/runtime/xnnpack/xnn_interop.h" #include "xla/service/buffer_assignment.h" #include "xla/shape.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" namespace xla::cpu { @@ -122,6 +123,21 @@ absl::StatusOr> XnnDotThunk::Create( std::move(dot_shape), std::move(dot_canonical_dims))); } +absl::StatusOr XnnDotThunk::SerializeAsStringImpl() const { + XnnDotThunkProto proto; + *proto.mutable_dot_dimensions() = dot_dimensions_; + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto( + dot_slices_.lhs_buffer, dot_slices_.lhs_shape, + proto.mutable_lhs_buffer_shape())); + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto( + dot_slices_.rhs_buffer, dot_slices_.rhs_shape, + proto.mutable_rhs_buffer_shape())); + TF_RETURN_IF_ERROR(SerializeSliceShapeIntoProto( + dot_slices_.out_buffer, dot_slices_.out_shape, + proto.mutable_out_buffer_shape())); + return proto.SerializeAsString(); +} + static std::vector DotArguments( const DotSlices& slices) { return {XnnFusionThunk::Argument{slices.lhs_buffer, slices.lhs_shape}, diff --git a/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.h b/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.h index b3ae7e88b5e69e..1a96dc5bcbda25 100644 --- a/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.h +++ b/xla/backends/cpu/runtime/xnnpack/xnn_dot_thunk.h @@ -47,6 +47,8 @@ class XnnDotThunk final : public XnnFusionThunk { BufferAllocation::Slice out_buffer, Shape out_shape); protected: + absl::StatusOr SerializeAsStringImpl() const final; + std::string fusion_kind() const final; std::string fusion_description() const final; diff --git a/xla/service/BUILD b/xla/service/BUILD index 7b0dd97c912265..7fbe62d672922d 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -1704,6 +1704,9 @@ cc_library( "//xla/hlo/utils:hlo_live_range", "//xla/service/heap_simulator", "//xla/service/memory_space_assignment", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", @@ -1716,10 +1719,8 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:numbers", - "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/buffer_assignment.cc b/xla/service/buffer_assignment.cc index 12540a782a8dba..f7cb3feafb8625 100644 --- a/xla/service/buffer_assignment.cc +++ b/xla/service/buffer_assignment.cc @@ -31,6 +31,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" @@ -61,11 +62,11 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" #include "tsl/platform/numbers.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -236,6 +237,14 @@ std::string BufferAllocation::Slice::ToString() const { ", offset:", offset_, ", size:", size_, "}"); } +absl::StatusOr BufferAllocation::Slice::SerializeAsString() const { + BufferAllocationSliceProto proto; + proto.set_offset(offset_); + proto.set_size(size_); + proto.set_buffer_allocation_index(allocation_ == nullptr ? -1 : index()); + return proto.SerializeAsString(); +} + BufferAllocation::Slice BufferAllocation::GetSlice( const HloValue& buffer) const { const OffsetSize os = FindOrDie(assigned_buffers_, &buffer); @@ -2297,4 +2306,17 @@ BufferAssigner::CreateAssignment( return std::move(assignment); } +absl::Status SerializeSliceShapeIntoProto( + const BufferAllocation::Slice& slice, const Shape& shape, + ShapeBufferAllocationSliceProto* proto) { + TF_ASSIGN_OR_RETURN(const std::string slice_buffer_str, + slice.SerializeAsString()); + + const std::string shape_str = shape.SerializeAsString(); + + proto->mutable_shape()->ParseFromString(shape_str); + proto->mutable_slice()->ParseFromString(slice_buffer_str); + return absl::OkStatus(); +} + } // namespace xla diff --git a/xla/service/buffer_assignment.h b/xla/service/buffer_assignment.h index 99c08a1f157dce..deeee8b768d480 100644 --- a/xla/service/buffer_assignment.h +++ b/xla/service/buffer_assignment.h @@ -180,7 +180,7 @@ class BufferAllocation { // to identify the memory range that a LogicalBuffer corresponds to. class Slice { public: - Slice() {} + Slice() = default; Slice(const BufferAllocation* allocation, int64_t offset, int64_t size) : allocation_(allocation), offset_(offset), size_(size) {} @@ -216,6 +216,8 @@ class BufferAllocation { std::string ToString() const; + absl::StatusOr SerializeAsString() const; + private: const BufferAllocation* allocation_ = nullptr; int64_t offset_ = 0; @@ -373,7 +375,7 @@ class BufferAllocation { }; // Add stream operators for nicer output of CHECK/RET_CHECK failures. -std::ostream& operator<<(std::ostream& out, const BufferAllocation& s); +std::ostream& operator<<(std::ostream& out, const BufferAllocation& buffer); std::ostream& operator<<(std::ostream& out, const BufferAllocation::Slice& s); // This class encapsulates an assignment of the LogicalBuffers in an XLA @@ -821,6 +823,10 @@ class BufferAssigner { BufferAssigner& operator=(const BufferAssigner&) = delete; }; +absl::Status SerializeSliceShapeIntoProto( + const BufferAllocation::Slice& slice, const Shape& shape, + ShapeBufferAllocationSliceProto* proto); + } // namespace xla #endif // XLA_SERVICE_BUFFER_ASSIGNMENT_H_ diff --git a/xla/service/hlo.proto b/xla/service/hlo.proto index 4858f4153feff0..5721a2947e0826 100644 --- a/xla/service/hlo.proto +++ b/xla/service/hlo.proto @@ -677,6 +677,17 @@ message BufferAllocationProto { repeated Assigned assigned = 9; } +message BufferAllocationSliceProto { + int64 offset = 1; + int64 size = 2; + int64 buffer_allocation_index = 3; +} + +message ShapeBufferAllocationSliceProto { + xla.ShapeProto shape = 1; + BufferAllocationSliceProto slice = 2; +} + // A trace of a HeapSimulator run. message HeapSimulatorTrace { // The trace includes a list of events, where each event describes one action diff --git a/xla/xla_data.proto b/xla/xla_data.proto index 01a6415549b584..398098cde88419 100644 --- a/xla/xla_data.proto +++ b/xla/xla_data.proto @@ -1156,3 +1156,12 @@ message OriginalArrayProto { message OriginalValueProto { repeated OriginalArrayProto leaves = 1; } + +message ResourceProto { + enum Kind { + UNKNOWN = 0; + TOKEN = 1; + COLLECTIVE_COMMUNICATOR = 2; + } + Kind kind = 1; +}