Skip to content

Commit

Permalink
TCP init method race condition fix (pytorch#15684)
Browse files Browse the repository at this point in the history
Summary:
This PR fixes a race condition in the TCP init method, where the master rank can exit earlier than the slave ranks, so the TCP daemon thread gets shut down before the other slaves are able to access it.

This will let every rank (process) write a special key to the store to mark that it has completed (and is thus about to exit). The master rank (which is the server) will always wait until all the ranks have completed before completing itself.

This should fix: pytorch#15638

Tested using the repro from pytorch#15638, and it works fine. Also, test_distributed and test_c10d should already cover this.

I had to change the rendezvous tests in c10d to use a world size of 1, since they run in a single process.
Pull Request resolved: pytorch#15684

Differential Revision: D13570904

Pulled By: teng-li

fbshipit-source-id: 34f3bc471204bbd29320df359347ad5561c6b589
  • Loading branch information
teng-li authored and facebook-github-bot committed Jan 18, 2019
1 parent aaff2fe commit b4bc55b
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 50 deletions.
64 changes: 24 additions & 40 deletions test/test_c10d.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def create_tcp_store(addr):
try:
port = common.find_free_port()
ports.append(port)
return c10d.TCPStore(addr, port, True)
return c10d.TCPStore(addr, port, 1, True)
except RuntimeError as error:
if str(error) == "Address already in use":
continue
Expand All @@ -226,8 +226,8 @@ def test_address_already_in_use(self):
# Use noqa to silence flake8.
# Need to store in an unused variable here to ensure the first
# object is not destroyed before the second object is created.
store1 = c10d.TCPStore(addr, port, True) # noqa: F841
store2 = c10d.TCPStore(addr, port, True) # noqa: F841
store1 = c10d.TCPStore(addr, port, 1, True) # noqa: F841
store2 = c10d.TCPStore(addr, port, 1, True) # noqa: F841


class PrefixTCPStoreTest(TestCase, StoreTestBase):
Expand All @@ -254,7 +254,7 @@ def test_common_errors(self):
raise unittest.SkipTest("C10D is not built with NCCL process group,"
" skipping test")
vars = {
"WORLD_SIZE": "2",
"WORLD_SIZE": "1",
"RANK": "0",
"MASTER_ADDR": "127.0.0.1",
"MASTER_PORT": common.find_free_port(),
Expand Down Expand Up @@ -287,9 +287,9 @@ def withouts(d, keys):
with self.assertRaisesRegex(ValueError, 'WORLD_SIZE expected'):
gen = c10d.rendezvous('env://')
next(gen)
c10d.init_process_group(backend='nccl', world_size=2)
c10d.init_process_group(backend='nccl', world_size=1)
self.assertEqual(c10d.get_rank(), 0)
self.assertEqual(c10d.get_world_size(), 2)
self.assertEqual(c10d.get_world_size(), 1)
c10d.destroy_process_group()

with Env(without(vars, 'RANK')):
Expand All @@ -298,19 +298,19 @@ def withouts(d, keys):
next(gen)
c10d.init_process_group(backend='nccl', rank=0)
self.assertEqual(c10d.get_rank(), 0)
self.assertEqual(c10d.get_world_size(), 2)
self.assertEqual(c10d.get_world_size(), 1)
c10d.destroy_process_group()

with Env(withouts(vars, ['RANK', 'WORLD_SIZE'])):
c10d.init_process_group(backend='nccl', rank=0, world_size=2)
c10d.init_process_group(backend='nccl', rank=0, world_size=1)
self.assertEqual(c10d.get_rank(), 0)
self.assertEqual(c10d.get_world_size(), 2)
self.assertEqual(c10d.get_world_size(), 1)
c10d.destroy_process_group()

with Env(vars):
c10d.init_process_group(backend='nccl')
self.assertEqual(c10d.get_rank(), 0)
self.assertEqual(c10d.get_world_size(), 2)
self.assertEqual(c10d.get_world_size(), 1)
c10d.destroy_process_group()

with Env(without(vars, 'MASTER_ADDR')):
Expand All @@ -324,48 +324,38 @@ def withouts(d, keys):
next(gen)

with Env(without(vars, 'WORLD_SIZE')):
gen = c10d.rendezvous('env://?world_size={}'.format(2))
gen = c10d.rendezvous('env://?world_size={}'.format(1))
_, _, size = next(gen)
self.assertEqual(size, 2)
self.assertEqual(size, 1)

with Env(without(vars, 'RANK')):
gen = c10d.rendezvous('env://?rank={}'.format(0))
_, rank, _ = next(gen)
self.assertEqual(rank, 0)

with Env(withouts(vars, ['RANK', 'WORLD_SIZE'])):
gen = c10d.rendezvous('env://?rank={}&world_size={}'.format(0, 2))
gen = c10d.rendezvous('env://?rank={}&world_size={}'.format(0, 1))
_, rank, size = next(gen)
self.assertEqual(rank, 0)
self.assertEqual(size, 2)
self.assertEqual(size, 1)

@retry_on_address_already_in_use_error
def test_nominal(self):
os.environ['WORLD_SIZE'] = '2'
os.environ['WORLD_SIZE'] = '1'
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(common.find_free_port())

# First rank
# Single rank
os.environ['RANK'] = '0'
gen0 = c10d.rendezvous('env://')
store0, rank0, size0 = next(gen0)
self.assertEqual(0, rank0)
self.assertEqual(2, size0)
self.assertEqual(1, size0)

# Second rank
os.environ['RANK'] = '1'
gen1 = c10d.rendezvous('env://')
store1, rank1, size1 = next(gen1)
self.assertEqual(1, rank1)
self.assertEqual(2, size1)

# Set value on both stores
store0.set("key0", "value0")
store1.set("key1", "value1")

# Cross check with get
self.assertEqual(b"value0", store1.get("key0"))
self.assertEqual(b"value1", store0.get("key1"))
# check with get
self.assertEqual(b"value0", store0.get("key0"))


class RendezvousFileTest(TestCase):
Expand Down Expand Up @@ -417,23 +407,17 @@ def test_common_errors(self):
def test_nominal(self):
addr = 'localhost'
port = common.find_free_port()
url = 'tcp://%s:%d?world_size=%d' % (addr, port, 2)
url = 'tcp://%s:%d?world_size=%d' % (addr, port, 1)
gen0 = c10d.rendezvous(url + "&rank=0")
store0, rank0, size0 = next(gen0)
self.assertEqual(0, rank0)
self.assertEqual(2, size0)
gen1 = c10d.rendezvous(url + "&rank=1")
store1, rank1, size1 = next(gen1)
self.assertEqual(1, rank1)
self.assertEqual(2, size1)
self.assertEqual(1, size0)

# Set value on both stores
# Set value on the single store
store0.set("key0", "value0")
store1.set("key1", "value1")

# Cross check with get
self.assertEqual(b"value0", store1.get("key0"))
self.assertEqual(b"value1", store0.get("key1"))
# check with get
self.assertEqual(b"value0", store0.get("key0"))


class MultiProcessTestCase(TestCase):
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/distributed/c10d/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ They are used in specifying strategies for reduction collectives, e.g.,
.def(py::init<const std::string&, int>());

shared_ptr_class_<::c10d::TCPStore>(module, "TCPStore", store)
.def(py::init<const std::string&, int, bool>());
.def(py::init<const std::string&, int, int, bool>());

shared_ptr_class_<::c10d::PrefixStore>(module, "PrefixStore", store)
.def(py::init<const std::string&, ::c10d::Store&>());
Expand Down
4 changes: 2 additions & 2 deletions torch/distributed/rendezvous.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _error(msg):
rank = int(query["rank"])
world_size = int(query["world_size"])
start_daemon = rank == 0
store = TCPStore(result.hostname, result.port, start_daemon)
store = TCPStore(result.hostname, result.port, world_size, start_daemon)
yield (store, rank, world_size)

# If this configuration is invalidated, there is nothing we can do about it
Expand Down Expand Up @@ -140,7 +140,7 @@ def _env_error(var):

# Now start the TCP store daemon on the rank 0
start_daemon = rank == 0
store = TCPStore(master_addr, master_port, start_daemon)
store = TCPStore(master_addr, master_port, world_size, start_daemon)
yield (store, rank, world_size)

# If this configuration is invalidated, there is nothing we can do about it
Expand Down
62 changes: 58 additions & 4 deletions torch/lib/c10d/TCPStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,14 @@ bool TCPStoreDaemon::checkKeys(const std::vector<std::string>& keys) const {
TCPStore::TCPStore(
const std::string& masterAddr,
PortType masterPort,
int numWorkers,
bool isServer)
: isServer_(isServer),
tcpStoreAddr_(masterAddr),
tcpStorePort_(masterPort) {
tcpStorePort_(masterPort),
numWorkers_(numWorkers),
initKey_("init/"),
regularPrefix_("/") {
if (isServer_) {
// Opening up the listening socket
std::tie(masterListenSocket_, std::ignore) = tcputil::listen(masterPort);
Expand All @@ -291,6 +295,8 @@ TCPStore::TCPStore(
}
// Connect to the daemon
storeSocket_ = tcputil::connect(tcpStoreAddr_, tcpStorePort_);

waitForWorkers_();
}

TCPStore::~TCPStore() {
Expand All @@ -303,20 +309,56 @@ TCPStore::~TCPStore() {
}
}

// Registers this store instance under initKey_ and, on the server only,
// blocks until the expected number of workers have registered too.
//
// Every instance (server or client) adds 1 to the counter kept under
// initKey_. The server then polls that counter every 10 ms until it reaches
// numWorkers_. Letting the server block here ensures that the server daemon
// thread is always running until the very end.
//
// If timeout_ is set (i.e. differs from kNoTimeout), the server stops
// waiting once more than timeout_ has elapsed. That case is best-effort: the
// function simply returns and no error is raised.
void TCPStore::waitForWorkers_() {
  addHelper_(initKey_, 1);

  // Only the server has to wait for everybody else.
  if (!isServer_) {
    return;
  }

  const auto start = std::chrono::steady_clock::now();
  while (true) {
    // The counter is parsed with std::stoi, i.e. it is expected to come back
    // as decimal text.
    const std::vector<uint8_t> value = getHelper_(initKey_);
    const std::string counterText(value.begin(), value.end());
    const int numWorkersCompleted = std::stoi(counterText);
    if (numWorkersCompleted >= numWorkers_) {
      break;
    }

    // Compare at the clock's native resolution. Truncating the elapsed time
    // to whole seconds first would let the loop run up to one second past
    // the configured timeout.
    const auto elapsed = std::chrono::steady_clock::now() - start;
    if (timeout_ != kNoTimeout && elapsed > timeout_) {
      break;
    }

    /* sleep override */
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
  }
}

void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
std::string regKey = regularPrefix_ + key;
tcputil::sendValue<QueryType>(storeSocket_, QueryType::SET);
tcputil::sendString(storeSocket_, key, true);
tcputil::sendString(storeSocket_, regKey, true);
tcputil::sendVector<uint8_t>(storeSocket_, data);
}

std::vector<uint8_t> TCPStore::get(const std::string& key) {
wait({key});
std::string regKey = regularPrefix_ + key;
return getHelper_(regKey);
}

// Fetches the value stored under `key`, using the key exactly as given (no
// prefix is added here).
//
// First waits on the key through waitHelper_ with the store's timeout_, then
// issues a GET query on storeSocket_ and returns the bytes received in reply.
std::vector<uint8_t> TCPStore::getHelper_(const std::string& key) {
  // Wait on this single key before asking for its value.
  const std::vector<std::string> pendingKeys{key};
  waitHelper_(pendingKeys, timeout_);

  // Request: query type followed by the key.
  tcputil::sendValue<QueryType>(storeSocket_, QueryType::GET);
  tcputil::sendString(storeSocket_, key);

  // Response: the raw value bytes.
  std::vector<uint8_t> payload = tcputil::recvVector<uint8_t>(storeSocket_);
  return payload;
}

// Public add: prepends regularPrefix_ to the caller's key and forwards the
// prefixed key together with `value` to addHelper_, returning its result.
int64_t TCPStore::add(const std::string& key, int64_t value) {
  return addHelper_(regularPrefix_ + key, value);
}

int64_t TCPStore::addHelper_(const std::string& key, int64_t value) {
tcputil::sendValue<QueryType>(storeSocket_, QueryType::ADD);
tcputil::sendString(storeSocket_, key, true);
tcputil::sendValue<int64_t>(storeSocket_, value);
Expand All @@ -328,7 +370,8 @@ bool TCPStore::check(const std::vector<std::string>& keys) {
SizeType nkeys = keys.size();
tcputil::sendBytes<SizeType>(storeSocket_, &nkeys, 1, (nkeys > 0));
for (size_t i = 0; i < nkeys; i++) {
tcputil::sendString(storeSocket_, keys[i], (i != (nkeys - 1)));
std::string regKey = regularPrefix_ + keys[i];
tcputil::sendString(storeSocket_, regKey, (i != (nkeys - 1)));
}
auto checkResponse = tcputil::recvValue<CheckResponseType>(storeSocket_);
if (checkResponse == CheckResponseType::READY) {
Expand All @@ -347,6 +390,17 @@ void TCPStore::wait(const std::vector<std::string>& keys) {
// Public wait with an explicit timeout: prepends regularPrefix_ to each of
// the caller's keys (order preserved) and hands the prefixed keys and
// `timeout` to waitHelper_.
void TCPStore::wait(
    const std::vector<std::string>& keys,
    const std::chrono::milliseconds& timeout) {
  std::vector<std::string> prefixedKeys;
  prefixedKeys.reserve(keys.size());
  for (const auto& userKey : keys) {
    prefixedKeys.push_back(regularPrefix_ + userKey);
  }
  waitHelper_(prefixedKeys, timeout);
}

void TCPStore::waitHelper_(
const std::vector<std::string>& keys,
const std::chrono::milliseconds& timeout) {
// Set the socket timeout if there is a wait timeout
if (timeout != kNoTimeout) {
struct timeval timeoutTV = {.tv_sec = timeout.count() / 1000,
Expand Down
12 changes: 12 additions & 0 deletions torch/lib/c10d/TCPStore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class TCPStore : public Store {
explicit TCPStore(
const std::string& masterAddr,
PortType masterPort,
int numWorkers,
bool isServer = false);

virtual ~TCPStore();
Expand All @@ -67,13 +68,24 @@ class TCPStore : public Store {
const std::chrono::milliseconds& timeout) override;

protected:
int64_t addHelper_(const std::string& key, int64_t value);
std::vector<uint8_t> getHelper_(const std::string& key);
void waitHelper_(
const std::vector<std::string>& keys,
const std::chrono::milliseconds& timeout);
void waitForWorkers_();

bool isServer_;
int storeSocket_ = -1;
int masterListenSocket_ = -1;

std::string tcpStoreAddr_;
PortType tcpStorePort_;

int numWorkers_;
const std::string initKey_;
const std::string regularPrefix_;

// Only needs to be launched as the server
std::unique_ptr<TCPStoreDaemon> tcpStoreDaemon_ = nullptr;
};
Expand Down
7 changes: 4 additions & 3 deletions torch/lib/c10d/test/TCPStoreTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
#include <c10d/TCPStore.hpp>

void testHelper(const std::string& prefix = "") {
const auto numThreads = 16;
const auto numWorkers = numThreads + 1;
// server store
c10d::TCPStore serverTCPStore("127.0.0.1", 29500, true);
c10d::TCPStore serverTCPStore("127.0.0.1", 29500, numWorkers, true);
c10d::PrefixStore serverStore(prefix, serverTCPStore);

// Basic set/get on the server store
Expand All @@ -22,7 +24,6 @@ void testHelper(const std::string& prefix = "") {

// Hammer on TCPStore
std::vector<std::thread> threads;
const auto numThreads = 16;
const auto numIterations = 1000;
c10d::test::Semaphore sem1, sem2;

Expand All @@ -31,7 +32,7 @@ void testHelper(const std::string& prefix = "") {
std::vector<std::unique_ptr<c10d::PrefixStore>> clientStores;
for (auto i = 0; i < numThreads; i++) {
clientTCPStores.push_back(std::unique_ptr<c10d::TCPStore>(
new c10d::TCPStore("127.0.0.1", 29500, false)));
new c10d::TCPStore("127.0.0.1", 29500, numWorkers, false)));
clientStores.push_back(std::unique_ptr<c10d::PrefixStore>(
new c10d::PrefixStore(prefix, *clientTCPStores[i])));
}
Expand Down

0 comments on commit b4bc55b

Please sign in to comment.