Skip to content

Commit

Permalink
[Coordination Service]Allow restartable tasks to connect back to clus…
Browse files Browse the repository at this point in the history
…ter, as long as they have the same local topology as before.

PiperOrigin-RevId: 712613151
  • Loading branch information
ishark authored and Google-ML-Automation committed Jan 10, 2025
1 parent caef8ec commit adf61c5
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 23 deletions.
3 changes: 1 addition & 2 deletions xla/pjrt/distributed/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ xla_cc_test(
":in_memory_key_value_store",
":protocol_proto_cc",
":topology_util",
"//xla:test_helpers",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/status",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:env",
Expand Down Expand Up @@ -115,7 +115,6 @@ cc_library(
":key_value_store_interface",
":protocol_proto_cc",
"//xla:util",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:utils",
"//xla/pjrt/gpu:gpu_topology_proto_cc",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down
88 changes: 77 additions & 11 deletions xla/pjrt/distributed/topology_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ limitations under the License.
#include "xla/pjrt/distributed/topology_util.h"

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <map>
#include <set>
Expand All @@ -28,13 +30,13 @@ limitations under the License.
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/substitute.h"
#include "absl/synchronization/blocking_counter.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/pjrt/distributed/protocol.pb.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/utils.h"
#include "xla/util.h"
#include "tsl/platform/env.h"
Expand All @@ -45,6 +47,34 @@ limitations under the License.

namespace xla {

namespace {
// Returns true iff every compared field of the two device descriptions is
// equal. The fields compared are: name, vendor, local device ordinal, core
// count, device kind, slice index, global device id and compute capability.
bool SameDevice(const DeviceProto& a, const DeviceProto& b) {
  if (a.name() != b.name()) return false;
  if (a.vendor() != b.vendor()) return false;
  if (a.local_device_ordinal() != b.local_device_ordinal()) return false;
  if (a.core_count() != b.core_count()) return false;
  if (a.device_kind() != b.device_kind()) return false;
  if (a.slice_index() != b.slice_index()) return false;
  // The global device id might not be set on a LocalTopologyProto; it is
  // still compared so that two unset (default) values count as equal.
  if (a.global_device_id() != b.global_device_id()) return false;
  return a.compute_capability() == b.compute_capability();
}

// Returns true iff `a` and `b` have the same node id and the same number of
// devices, and the devices match pairwise (by position) under SameDevice().
bool SameLocalTopology(const LocalTopologyProto& a,
                       const LocalTopologyProto& b) {
  if (a.node_id() != b.node_id()) {
    return false;
  }
  const int device_count = a.devices_size();
  if (device_count != b.devices_size()) {
    return false;
  }
  // Devices are compared index by index, so ordering is significant.
  for (int idx = 0; idx < device_count; ++idx) {
    const bool devices_match = SameDevice(a.devices(idx), b.devices(idx));
    if (!devices_match) {
      return false;
    }
  }
  return true;
}

} // namespace

// Exists on Linux systems. Unique per OS kernel restart.
static constexpr char kBootIdPath[] = "/proc/sys/kernel/random/boot_id";

Expand Down Expand Up @@ -160,14 +190,13 @@ GlobalTopologyProto BuildGlobalTopology(
return global_topology;
}

absl::Status ExchangeTopologies(absl::string_view platform, int node_id,
int num_nodes,
absl::Duration get_local_topology_timeout,
absl::Duration get_global_topology_timeout,
KeyValueStoreInterface* kv_store,
const LocalTopologyProto& local_topology,
GlobalTopologyProto* global_topology,
bool assign_global_device_ids) {
absl::Status ExchangeTopologies(
absl::string_view platform, int node_id, int num_nodes,
absl::Duration get_local_topology_timeout,
absl::Duration get_global_topology_timeout,
KeyValueStoreInterface* kv_store, const LocalTopologyProto& local_topology,
GlobalTopologyProto* global_topology, bool assign_global_device_ids,
int64_t pjrt_major_version, int64_t pjrt_minor_version) {
VLOG(3) << "Local Topology for platform" << platform << ":\n"
<< local_topology.DebugString();
if (num_nodes == 1) {
Expand All @@ -179,8 +208,45 @@ absl::Status ExchangeTopologies(absl::string_view platform, int node_id,
return absl::OkStatus();
}
CHECK(kv_store != nullptr);
TF_RETURN_IF_ERROR(kv_store->Set(GetLocalTopologyKey(platform, node_id),
local_topology.SerializeAsString()));
const std::string local_topology_key = GetLocalTopologyKey(platform, node_id);
const std::string serialized_local_topology =
local_topology.SerializeAsString();

bool use_try_get_api = true;
if (pjrt_major_version == 0 && pjrt_minor_version < 61) {
use_try_get_api = false;
}
if (use_try_get_api) {
absl::StatusOr<std::string> existing_local_topology =
kv_store->TryGet(local_topology_key);
printf("existing_local_topology status: %s\n",
existing_local_topology.status().ToString().c_str());

if (existing_local_topology.ok()) {
printf("existing topology found");
// Local topology has been set previously from the same node before
// restart.
LocalTopologyProto existing_local_topology_proto;
existing_local_topology_proto.ParseFromString(*existing_local_topology);
if (!SameLocalTopology(existing_local_topology_proto, local_topology)) {
return absl::InternalError(absl::Substitute(
"Different local topology for node $0 has been set previously, "
"possibly before a restart.\nBefore: $1\nAfter: $2",
node_id, existing_local_topology_proto.DebugString(),
local_topology.DebugString()));
}
} else if (absl::IsNotFound(existing_local_topology.status()) ||
absl::IsDeadlineExceeded(existing_local_topology.status())) {
TF_RETURN_IF_ERROR(kv_store->Set(GetLocalTopologyKey(platform, node_id),
serialized_local_topology));
} else {
return existing_local_topology.status();
}
} else {
// Fallback to Set API as previous behavior.
TF_RETURN_IF_ERROR(kv_store->Set(GetLocalTopologyKey(platform, node_id),
serialized_local_topology));
}

// The lead node gets all local topologies, builds the global topology and
// puts it to the key-value store.
Expand Down
18 changes: 10 additions & 8 deletions xla/pjrt/distributed/topology_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#ifndef XLA_PJRT_DISTRIBUTED_TOPOLOGY_UTIL_H_
#define XLA_PJRT_DISTRIBUTED_TOPOLOGY_UTIL_H_

#include <cstdint>
#include <string>

#include "absl/status/status.h"
Expand All @@ -39,14 +40,15 @@ absl::StatusOr<std::string> GetBootIdString();
// topology in the order they appear in the input. Otherwise leaves the global
// IDs as they were in the local topologies.
// TODO(phawkins): deprecate and remove assign_global_device_ids.
absl::Status ExchangeTopologies(absl::string_view platform, int node_id,
int num_nodes,
absl::Duration get_local_topology_timeout,
absl::Duration get_global_topology_timeout,
KeyValueStoreInterface* kv_store,
const LocalTopologyProto& local_topology,
GlobalTopologyProto* global_topology,
bool assign_global_device_ids);
// TODO(ishark): deprecate and remove pjrt_major_version and
// pjrt_minor_version once the TryGet API is available on OSS GPU.
absl::Status ExchangeTopologies(
absl::string_view platform, int node_id, int num_nodes,
absl::Duration get_local_topology_timeout,
absl::Duration get_global_topology_timeout,
KeyValueStoreInterface* kv_store, const LocalTopologyProto& local_topology,
GlobalTopologyProto* global_topology, bool assign_global_device_ids,
int64_t pjrt_major_version = 0, int64_t pjrt_minor_version = 61);

// Functions below this point are public only for testing.

Expand Down
91 changes: 90 additions & 1 deletion xla/pjrt/distributed/topology_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ limitations under the License.
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "xla/pjrt/distributed/in_memory_key_value_store.h"
#include "xla/pjrt/distributed/protocol.pb.h"
#include "xla/test_helpers.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/statusor.h"
Expand All @@ -31,6 +31,7 @@ limitations under the License.

namespace xla {
namespace {
using tsl::testing::StatusIs;

TEST(TopologyTest, BuildGlobalTopology) {
std::vector<LocalTopologyProto> locals(2);
Expand Down Expand Up @@ -86,6 +87,94 @@ TEST(TopologyTest, ExchangeTopology) {
}
}

// A node that runs the topology exchange a second time with an unchanged
// local topology (simulating a restart) must succeed both times.
TEST(TopologyTest, ExchangeTopology_Twice_Succeeds) {
  const int node_count = 2;
  std::vector<LocalTopologyProto> locals(node_count);
  // Local device ordinals per node, in the order the devices are added.
  const int ordinals[2][2] = {{0, 0}, {0, 1}};
  for (int node = 0; node < node_count; ++node) {
    for (const int ordinal : ordinals[node]) {
      locals[node].add_devices()->set_local_device_ordinal(ordinal);
    }
  }

  InMemoryKeyValueStore kv_store;
  std::vector<GlobalTopologyProto> globals(node_count);
  // Runs one exchange round for `node`. Declared ahead of the thread pool so
  // that it is still alive while the scheduled closures use it.
  const auto exchange = [&](int node) {
    return ExchangeTopologies(
        /*platform=*/"cuda", /*node_id=*/node, node_count,
        /*get_local_topology_timeout=*/absl::Seconds(10),
        /*get_global_topology_timeout=*/absl::Seconds(10), &kv_store,
        locals[node], &globals[node],
        /*assign_global_device_ids=*/true);
  };
  {
    // The pool is scoped so that it is destroyed before the results are
    // inspected below.
    tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "TestPool",
                                        node_count);
    for (int node = 0; node < node_count; ++node) {
      thread_pool.Schedule([&, node] {
        TF_ASSERT_OK(exchange(node));
        // Simulate node 1 restarting and exchanging topologies again.
        if (node == 1) {
          TF_ASSERT_OK(exchange(node));
        }
      });
    }
  }
  // Every node must have received the full two-node, two-devices-per-node
  // global topology.
  for (const GlobalTopologyProto& result : globals) {
    EXPECT_EQ(result.nodes_size(), 2);
    EXPECT_EQ(result.nodes(0).devices_size(), 2);
    EXPECT_EQ(result.nodes(1).devices_size(), 2);
  }
}

// A node that runs the topology exchange a second time with a local topology
// that differs from the one it published first (simulating a restart with
// different devices) must get an Internal error on the second attempt.
TEST(TopologyTest, ExchangeTopology_TwiceWithDifferentLocalTopology_Fails) {
  const int node_count = 2;
  std::vector<LocalTopologyProto> locals(node_count);
  // Local device ordinals per node, in the order the devices are added.
  const int ordinals[2][2] = {{0, 0}, {0, 1}};
  for (int node = 0; node < node_count; ++node) {
    for (const int ordinal : ordinals[node]) {
      locals[node].add_devices()->set_local_device_ordinal(ordinal);
    }
  }

  InMemoryKeyValueStore kv_store;
  std::vector<GlobalTopologyProto> globals(node_count);
  // Runs one exchange round for `node`. Declared ahead of the thread pool so
  // that it is still alive while the scheduled closures use it.
  const auto exchange = [&](int node) {
    return ExchangeTopologies(
        /*platform=*/"cuda", /*node_id=*/node, node_count,
        /*get_local_topology_timeout=*/absl::Seconds(10),
        /*get_global_topology_timeout=*/absl::Seconds(10), &kv_store,
        locals[node], &globals[node],
        /*assign_global_device_ids=*/true);
  };
  {
    tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "TestPool",
                                        node_count);
    for (int node = 0; node < node_count; ++node) {
      thread_pool.Schedule([&, node] {
        TF_ASSERT_OK(exchange(node));
        // Simulate node 1 restarting with different devices.
        if (node == 1) {
          locals[1].add_devices()->set_local_device_ordinal(2);
          // This should fail because the local topology is unexpectedly
          // different from the one published by the first exchange.
          EXPECT_THAT(exchange(node), StatusIs(absl::StatusCode::kInternal));
        }
      });
    }
  }
}

TEST(TopologyTest, BuildGpuTopology) {
std::string slice_0_boot_id = "foo";
std::string slice_1_boot_id = "bar";
Expand Down
10 changes: 9 additions & 1 deletion xla/python/pjrt_ifrt/pjrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,19 @@ absl::StatusOr<GlobalTopology> MakeGlobalTopologyWithLocalTopology(
xla::PjRtClient* pjrt_client, const PjRtClient::CreateOptions& options,
const LocalTopologyProto& local_topology_proto) {
GlobalTopologyProto global_topology_proto;
auto plugin_attributes = pjrt_client->plugin_attributes();
int64_t pjrt_major_version = 0;
int64_t pjrt_minor_version = 0;
if (plugin_attributes.has_value()) {
pjrt_major_version = plugin_attributes->pjrt_c_api_major_version;
pjrt_minor_version = plugin_attributes->pjrt_c_api_minor_version;
}
TF_RETURN_IF_ERROR(ExchangeTopologies(
pjrt_client->platform_name(), options.process_id, options.num_processes,
options.get_local_topology_timeout, options.get_global_topology_timeout,
options.kv_store.get(), local_topology_proto, &global_topology_proto,
/*assign_global_device_ids=*/false));
/*assign_global_device_ids=*/false, pjrt_major_version,
pjrt_minor_version));

std::optional<int> my_process_index;
absl::flat_hash_map<DeviceId, xla::PjRtGlobalDeviceId>
Expand Down

0 comments on commit adf61c5

Please sign in to comment.