From b4bc55beefda3a0724b0fb83c04b6bbd8dd46c77 Mon Sep 17 00:00:00 2001 From: Teng Li Date: Fri, 18 Jan 2019 02:23:51 -0800 Subject: [PATCH] TCP init method race condition fix (#15684) Summary: This PR fixes a race condition in the TCP init method, where the master rank can exit earlier than the slave ranks, so the TCP daemon thread gets shut down before the other slaves are able to access it. This lets every rank (process) write a special key to the store to mark that it has completed (and is thus about to exit). The master rank (which is the server) will always wait for all the ranks to complete before completing itself. This should fix: https://github.com/pytorch/pytorch/issues/15638 Tested using the repro from https://github.com/pytorch/pytorch/issues/15638, and it works fine. Also, test_distributed and test_c10d should already cover this. I had to change the rendezvous tests in c10d to use a world size of 1, since they are single-process code. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15684 Differential Revision: D13570904 Pulled By: teng-li fbshipit-source-id: 34f3bc471204bbd29320df359347ad5561c6b589 --- test/test_c10d.py | 64 +++++++++++----------------- torch/csrc/distributed/c10d/init.cpp | 2 +- torch/distributed/rendezvous.py | 4 +- torch/lib/c10d/TCPStore.cpp | 62 +++++++++++++++++++++++++-- torch/lib/c10d/TCPStore.hpp | 12 ++++++ torch/lib/c10d/test/TCPStoreTest.cpp | 7 +-- 6 files changed, 101 insertions(+), 50 deletions(-) diff --git a/test/test_c10d.py b/test/test_c10d.py index 329a7be8db2016..dc7efd91370663 100644 --- a/test/test_c10d.py +++ b/test/test_c10d.py @@ -204,7 +204,7 @@ def create_tcp_store(addr): try: port = common.find_free_port() ports.append(port) - return c10d.TCPStore(addr, port, True) + return c10d.TCPStore(addr, port, 1, True) except RuntimeError as error: if str(error) == "Address already in use": continue @@ -226,8 +226,8 @@ def test_address_already_in_use(self): # Use noqa to silence flake8. 
# Need to store in an unused variable here to ensure the first # object is not destroyed before the second object is created. - store1 = c10d.TCPStore(addr, port, True) # noqa: F841 - store2 = c10d.TCPStore(addr, port, True) # noqa: F841 + store1 = c10d.TCPStore(addr, port, 1, True) # noqa: F841 + store2 = c10d.TCPStore(addr, port, 1, True) # noqa: F841 class PrefixTCPStoreTest(TestCase, StoreTestBase): @@ -254,7 +254,7 @@ def test_common_errors(self): raise unittest.SkipTest("C10D is not built with NCCL process group," " skipping test") vars = { - "WORLD_SIZE": "2", + "WORLD_SIZE": "1", "RANK": "0", "MASTER_ADDR": "127.0.0.1", "MASTER_PORT": common.find_free_port(), @@ -287,9 +287,9 @@ def withouts(d, keys): with self.assertRaisesRegex(ValueError, 'WORLD_SIZE expected'): gen = c10d.rendezvous('env://') next(gen) - c10d.init_process_group(backend='nccl', world_size=2) + c10d.init_process_group(backend='nccl', world_size=1) self.assertEqual(c10d.get_rank(), 0) - self.assertEqual(c10d.get_world_size(), 2) + self.assertEqual(c10d.get_world_size(), 1) c10d.destroy_process_group() with Env(without(vars, 'RANK')): @@ -298,19 +298,19 @@ def withouts(d, keys): next(gen) c10d.init_process_group(backend='nccl', rank=0) self.assertEqual(c10d.get_rank(), 0) - self.assertEqual(c10d.get_world_size(), 2) + self.assertEqual(c10d.get_world_size(), 1) c10d.destroy_process_group() with Env(withouts(vars, ['RANK', 'WORLD_SIZE'])): - c10d.init_process_group(backend='nccl', rank=0, world_size=2) + c10d.init_process_group(backend='nccl', rank=0, world_size=1) self.assertEqual(c10d.get_rank(), 0) - self.assertEqual(c10d.get_world_size(), 2) + self.assertEqual(c10d.get_world_size(), 1) c10d.destroy_process_group() with Env(vars): c10d.init_process_group(backend='nccl') self.assertEqual(c10d.get_rank(), 0) - self.assertEqual(c10d.get_world_size(), 2) + self.assertEqual(c10d.get_world_size(), 1) c10d.destroy_process_group() with Env(without(vars, 'MASTER_ADDR')): @@ -324,9 +324,9 @@ def 
withouts(d, keys): next(gen) with Env(without(vars, 'WORLD_SIZE')): - gen = c10d.rendezvous('env://?world_size={}'.format(2)) + gen = c10d.rendezvous('env://?world_size={}'.format(1)) _, _, size = next(gen) - self.assertEqual(size, 2) + self.assertEqual(size, 1) with Env(without(vars, 'RANK')): gen = c10d.rendezvous('env://?rank={}'.format(0)) @@ -334,38 +334,28 @@ def withouts(d, keys): self.assertEqual(rank, 0) with Env(withouts(vars, ['RANK', 'WORLD_SIZE'])): - gen = c10d.rendezvous('env://?rank={}&world_size={}'.format(0, 2)) + gen = c10d.rendezvous('env://?rank={}&world_size={}'.format(0, 1)) _, rank, size = next(gen) self.assertEqual(rank, 0) - self.assertEqual(size, 2) + self.assertEqual(size, 1) @retry_on_address_already_in_use_error def test_nominal(self): - os.environ['WORLD_SIZE'] = '2' + os.environ['WORLD_SIZE'] = '1' os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = str(common.find_free_port()) - # First rank + # Single rank os.environ['RANK'] = '0' gen0 = c10d.rendezvous('env://') store0, rank0, size0 = next(gen0) self.assertEqual(0, rank0) - self.assertEqual(2, size0) + self.assertEqual(1, size0) - # Second rank - os.environ['RANK'] = '1' - gen1 = c10d.rendezvous('env://') - store1, rank1, size1 = next(gen1) - self.assertEqual(1, rank1) - self.assertEqual(2, size1) - - # Set value on both stores store0.set("key0", "value0") - store1.set("key1", "value1") - # Cross check with get - self.assertEqual(b"value0", store1.get("key0")) - self.assertEqual(b"value1", store0.get("key1")) + # check with get + self.assertEqual(b"value0", store0.get("key0")) class RendezvousFileTest(TestCase): @@ -417,23 +407,17 @@ def test_common_errors(self): def test_nominal(self): addr = 'localhost' port = common.find_free_port() - url = 'tcp://%s:%d?world_size=%d' % (addr, port, 2) + url = 'tcp://%s:%d?world_size=%d' % (addr, port, 1) gen0 = c10d.rendezvous(url + "&rank=0") store0, rank0, size0 = next(gen0) self.assertEqual(0, rank0) - self.assertEqual(2, 
size0) - gen1 = c10d.rendezvous(url + "&rank=1") - store1, rank1, size1 = next(gen1) - self.assertEqual(1, rank1) - self.assertEqual(2, size1) + self.assertEqual(1, size0) - # Set value on both stores + # Set value on the single store store0.set("key0", "value0") - store1.set("key1", "value1") - # Cross check with get - self.assertEqual(b"value0", store1.get("key0")) - self.assertEqual(b"value1", store0.get("key1")) + # check with get + self.assertEqual(b"value0", store0.get("key0")) class MultiProcessTestCase(TestCase): diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 9099c565dfd26c..c4a94dd351ad98 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -138,7 +138,7 @@ They are used in specifying strategies for reduction collectives, e.g., .def(py::init()); shared_ptr_class_<::c10d::TCPStore>(module, "TCPStore", store) - .def(py::init()); + .def(py::init()); shared_ptr_class_<::c10d::PrefixStore>(module, "PrefixStore", store) .def(py::init()); diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index f7b37bd3368037..1df2563afd236c 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -92,7 +92,7 @@ def _error(msg): rank = int(query["rank"]) world_size = int(query["world_size"]) start_daemon = rank == 0 - store = TCPStore(result.hostname, result.port, start_daemon) + store = TCPStore(result.hostname, result.port, world_size, start_daemon) yield (store, rank, world_size) # If this configuration is invalidated, there is nothing we can do about it @@ -140,7 +140,7 @@ def _env_error(var): # Now start the TCP store daemon on the rank 0 start_daemon = rank == 0 - store = TCPStore(master_addr, master_port, start_daemon) + store = TCPStore(master_addr, master_port, world_size, start_daemon) yield (store, rank, world_size) # If this configuration is invalidated, there is nothing we can do about it diff --git 
a/torch/lib/c10d/TCPStore.cpp b/torch/lib/c10d/TCPStore.cpp index 523ef3fc4b39c3..242b9a6c93166e 100644 --- a/torch/lib/c10d/TCPStore.cpp +++ b/torch/lib/c10d/TCPStore.cpp @@ -278,10 +278,14 @@ bool TCPStoreDaemon::checkKeys(const std::vector& keys) const { TCPStore::TCPStore( const std::string& masterAddr, PortType masterPort, + int numWorkers, bool isServer) : isServer_(isServer), tcpStoreAddr_(masterAddr), - tcpStorePort_(masterPort) { + tcpStorePort_(masterPort), + numWorkers_(numWorkers), + initKey_("init/"), + regularPrefix_("/") { if (isServer_) { // Opening up the listening socket std::tie(masterListenSocket_, std::ignore) = tcputil::listen(masterPort); @@ -291,6 +295,8 @@ TCPStore::TCPStore( } // Connect to the daemon storeSocket_ = tcputil::connect(tcpStoreAddr_, tcpStorePort_); + + waitForWorkers_(); } TCPStore::~TCPStore() { @@ -303,20 +309,56 @@ TCPStore::~TCPStore() { } } +void TCPStore::waitForWorkers_() { + addHelper_(initKey_, 1); + // Let server block until all workers have completed, this ensures that + // the server daemon thread is always running until the very end + if (isServer_) { + const auto start = std::chrono::steady_clock::now(); + while (true) { + std::vector value = getHelper_(initKey_); + auto buf = reinterpret_cast(value.data()); + auto len = value.size(); + int numWorkersCompleted = std::stoi(std::string(buf, len)); + if (numWorkersCompleted >= numWorkers_) { + break; + } + const auto elapsed = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start); + if (timeout_ != kNoTimeout && elapsed > timeout_) { + break; + } + /* sleep override */ + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } +} + void TCPStore::set(const std::string& key, const std::vector& data) { + std::string regKey = regularPrefix_ + key; tcputil::sendValue(storeSocket_, QueryType::SET); - tcputil::sendString(storeSocket_, key, true); + tcputil::sendString(storeSocket_, regKey, true); tcputil::sendVector(storeSocket_, data); } 
std::vector TCPStore::get(const std::string& key) { - wait({key}); + std::string regKey = regularPrefix_ + key; + return getHelper_(regKey); +} + +std::vector TCPStore::getHelper_(const std::string& key) { + waitHelper_({key}, timeout_); tcputil::sendValue(storeSocket_, QueryType::GET); tcputil::sendString(storeSocket_, key); return tcputil::recvVector(storeSocket_); } int64_t TCPStore::add(const std::string& key, int64_t value) { + std::string regKey = regularPrefix_ + key; + return addHelper_(regKey, value); +} + +int64_t TCPStore::addHelper_(const std::string& key, int64_t value) { tcputil::sendValue(storeSocket_, QueryType::ADD); tcputil::sendString(storeSocket_, key, true); tcputil::sendValue(storeSocket_, value); @@ -328,7 +370,8 @@ bool TCPStore::check(const std::vector& keys) { SizeType nkeys = keys.size(); tcputil::sendBytes(storeSocket_, &nkeys, 1, (nkeys > 0)); for (size_t i = 0; i < nkeys; i++) { - tcputil::sendString(storeSocket_, keys[i], (i != (nkeys - 1))); + std::string regKey = regularPrefix_ + keys[i]; + tcputil::sendString(storeSocket_, regKey, (i != (nkeys - 1))); } auto checkResponse = tcputil::recvValue(storeSocket_); if (checkResponse == CheckResponseType::READY) { @@ -347,6 +390,17 @@ void TCPStore::wait(const std::vector& keys) { void TCPStore::wait( const std::vector& keys, const std::chrono::milliseconds& timeout) { + std::vector regKeys; + regKeys.resize(keys.size()); + for (size_t i = 0; i < keys.size(); ++i) { + regKeys[i] = regularPrefix_ + keys[i]; + } + waitHelper_(regKeys, timeout); +} + +void TCPStore::waitHelper_( + const std::vector& keys, + const std::chrono::milliseconds& timeout) { // Set the socket timeout if there is a wait timeout if (timeout != kNoTimeout) { struct timeval timeoutTV = {.tv_sec = timeout.count() / 1000, diff --git a/torch/lib/c10d/TCPStore.hpp b/torch/lib/c10d/TCPStore.hpp index 891f263c16ca76..86b4d6854add5e 100644 --- a/torch/lib/c10d/TCPStore.hpp +++ b/torch/lib/c10d/TCPStore.hpp @@ -48,6 +48,7 @@ 
class TCPStore : public Store { explicit TCPStore( const std::string& masterAddr, PortType masterPort, + int numWorkers, bool isServer = false); virtual ~TCPStore(); @@ -67,6 +68,13 @@ class TCPStore : public Store { const std::chrono::milliseconds& timeout) override; protected: + int64_t addHelper_(const std::string& key, int64_t value); + std::vector getHelper_(const std::string& key); + void waitHelper_( + const std::vector& keys, + const std::chrono::milliseconds& timeout); + void waitForWorkers_(); + bool isServer_; int storeSocket_ = -1; int masterListenSocket_ = -1; @@ -74,6 +82,10 @@ class TCPStore : public Store { std::string tcpStoreAddr_; PortType tcpStorePort_; + int numWorkers_; + const std::string initKey_; + const std::string regularPrefix_; + // Only needs to be launched as the server std::unique_ptr tcpStoreDaemon_ = nullptr; }; diff --git a/torch/lib/c10d/test/TCPStoreTest.cpp b/torch/lib/c10d/test/TCPStoreTest.cpp index 0808bf148aa130..a81fa65ffaab5c 100644 --- a/torch/lib/c10d/test/TCPStoreTest.cpp +++ b/torch/lib/c10d/test/TCPStoreTest.cpp @@ -8,8 +8,10 @@ #include void testHelper(const std::string& prefix = "") { + const auto numThreads = 16; + const auto numWorkers = numThreads + 1; // server store - c10d::TCPStore serverTCPStore("127.0.0.1", 29500, true); + c10d::TCPStore serverTCPStore("127.0.0.1", 29500, numWorkers, true); c10d::PrefixStore serverStore(prefix, serverTCPStore); // Basic set/get on the server store @@ -22,7 +24,6 @@ void testHelper(const std::string& prefix = "") { // Hammer on TCPStore std::vector threads; - const auto numThreads = 16; const auto numIterations = 1000; c10d::test::Semaphore sem1, sem2; @@ -31,7 +32,7 @@ void testHelper(const std::string& prefix = "") { std::vector> clientStores; for (auto i = 0; i < numThreads; i++) { clientTCPStores.push_back(std::unique_ptr( - new c10d::TCPStore("127.0.0.1", 29500, false))); + new c10d::TCPStore("127.0.0.1", 29500, numWorkers, false))); 
clientStores.push_back(std::unique_ptr( new c10d::PrefixStore(prefix, *clientTCPStores[i]))); }