diff --git a/folly/io/async/fdsock/AsyncFdSocket.cpp b/folly/io/async/fdsock/AsyncFdSocket.cpp index 0db1c1f6dac..7eaaab46f07 100644 --- a/folly/io/async/fdsock/AsyncFdSocket.cpp +++ b/folly/io/async/fdsock/AsyncFdSocket.cpp @@ -54,6 +54,23 @@ AsyncFdSocket::AsyncFdSocket( } #endif +AsyncFdSocket::AsyncFdSocket( + AsyncFdSocket::DoesNotMoveFdSocketState, AsyncSocket* sock) + : AsyncSocket(sock) +#if !defined(_WIN32) + , + readAncillaryDataCob_(this) { + setUpCallbacks(); +} +#else +{ +} +#endif + +AsyncFdSocket::AsyncFdSocket( + AsyncFdSocket::DoesNotMoveFdSocketState tag, AsyncSocket::UniquePtr sock) + : AsyncFdSocket(tag, sock.get()) {} + void AsyncFdSocket::writeChainWithFds( WriteCallback* callback, std::unique_ptr buf, @@ -130,6 +147,19 @@ void AsyncFdSocket::setUpCallbacks() noexcept { AsyncSocket::setReadAncillaryDataCB(&readAncillaryDataCob_); } +void AsyncFdSocket::swapFdReadStateWith(AsyncFdSocket* other) { + // We don't need these write-state assertions to correctly swap read + // state, but since the only use-case is `moveToPlaintext`, they help. + DCHECK_EQ(0, other->allocatedToSendFdsSeqNum_); + DCHECK_EQ(0, other->sentFdsSeqNum_); + DCHECK_EQ(0, other->sendMsgCob_.writeTagToFds_.size()); + + fdsQueue_.swap(other->fdsQueue_); + std::swap(receivedFdsSeqNum_, other->receivedFdsSeqNum_); + // Do NOT swap `readAncillaryDataCob_` since its internal members are not + // "state", but plumbing that does not change. +} + void AsyncFdSocket::releaseIOBuf( std::unique_ptr buf, ReleaseIOBufCallback* callback) { sendMsgCob_.destroyFdsForWriteTag(WriteRequestTag{buf.get()}); diff --git a/folly/io/async/fdsock/AsyncFdSocket.h b/folly/io/async/fdsock/AsyncFdSocket.h index c7a94e942bf..d3d84bb5873 100644 --- a/folly/io/async/fdsock/AsyncFdSocket.h +++ b/folly/io/async/fdsock/AsyncFdSocket.h @@ -73,6 +73,24 @@ class AsyncFdSocket : public AsyncSocket { NetworkSocket fd, const folly::SocketAddress* peerAddress = nullptr); + /** + * EXPERIMENTAL / TEMPORARY: These move-like constructors should not be + * used to go from one AsyncFdSocket to another because this will not + * correctly preserve read & write state. Full move is not implemented + * since its trickier, and was not yet needed -- see `swapFdReadStateWith`. + */ + struct DoesNotMoveFdSocketState {}; + + protected: + _FRIEND_TEST_FOR_ASYNC_FD_SOCKET( + AsyncFdSocketSequenceRoundtripTest, WithDataSize); + // Protected since it's easy to accidentally pass an `AsyncFdSocket` here, + // a scenario that's extremely easy to use incorrectly. + AsyncFdSocket(DoesNotMoveFdSocketState, AsyncSocket*); + + public: + AsyncFdSocket(DoesNotMoveFdSocketState, AsyncSocket::UniquePtr); + /** * `AsyncSocket::writeChain` analog that passes FDs as ancillary data over * the socket (see `man cmsg`). @@ -136,8 +154,19 @@ class AsyncFdSocket : public AsyncSocket { LOG(DFATAL) << "AsyncFdSocket::setReadAncillaryDataCB is forbidden"; } -// This uses no ancillary data callbacks on Windows, they wouldn't compile. +// This class has no ancillary data callbacks on Windows, they wouldn't compile #if !defined(_WIN32) + /** + * EXPERIMENTAL / TEMPORARY: This just does what is required for + * `moveToPlaintext` to support StopTLS. That use-case could later be + * covered by full move-construct or move-assign support, but both would + * be more complex to support. + * + * Swaps "read FDs" state (receive queue & sequence numbers) with `other`. + * DFATALs if `other` had any "write FDs" state. + */ + void swapFdReadStateWith(AsyncFdSocket* other); + protected: void releaseIOBuf( std::unique_ptr, ReleaseIOBufCallback*) override; diff --git a/folly/io/async/fdsock/test/AsyncFdSocketTest.cpp b/folly/io/async/fdsock/test/AsyncFdSocketTest.cpp index 7e229b503c7..c040a4e73b1 100644 --- a/folly/io/async/fdsock/test/AsyncFdSocketTest.cpp +++ b/folly/io/async/fdsock/test/AsyncFdSocketTest.cpp @@ -90,8 +90,9 @@ struct AsyncFdSocketTest : public testing::Test { }()} {} explicit AsyncFdSocketTest(std::array fds) - : sendSock_{&evb_, fds[0]}, recvSock_{&evb_, fds[1]} { - recvSock_.setReadCB(&rcb_); + : sendSock_{&evb_, fds[0]}, + recvSock_(std::make_unique(&evb_, fds[1])) { + recvSock_->setReadCB(&rcb_); } EventBase evb_; @@ -100,7 +101,7 @@ struct AsyncFdSocketTest : public testing::Test { AsyncFdSocket sendSock_; ReadCallback rcb_; // NB: `~AsyncSocket` calls `rcb.readEOF` - AsyncFdSocket recvSock_; + std::unique_ptr recvSock_; }; TEST_F(AsyncFdSocketTest, TestAddSeqNum) { @@ -171,7 +172,7 @@ TEST_P(AsyncFdSocketSimpleRoundtripTest, WithNumFds) { rcb_.verifyData(&data, sizeof(data)); rcb_.clearData(); - checkFdsMatch(sendFds, sendSeqNum, recvSock_.popNextReceivedFds()); + checkFdsMatch(sendFds, sendSeqNum, recvSock_->popNextReceivedFds()); } // Round-trip & verify various numbers of FDs with 1 byte of data. @@ -225,7 +226,7 @@ TEST_F(AsyncFdSocketTest, MultiPartSend) { // FDs are sent with the first send & received by the first receive evb_.loopOnce(); - checkFdsMatch(sendFds, sendSeqNum, recvSock_.popNextReceivedFds()); + checkFdsMatch(sendFds, sendSeqNum, recvSock_->popNextReceivedFds()); EXPECT_EQ(1, sendSock.numWrites_); // Receive the rest of the data. @@ -239,23 +240,23 @@ TEST_F(AsyncFdSocketTest, MultiPartSend) { // There are no more data or FDs evb_.loopOnce(EVLOOP_NONBLOCK); EXPECT_EQ(0, rcb_.dataRead()) << "Leftover reads"; - EXPECT_TRUE(recvSock_.popNextReceivedFds().empty()) << "Extra FDs"; + EXPECT_TRUE(recvSock_->popNextReceivedFds().empty()) << "Extra FDs"; } struct AsyncFdSocketSequenceRoundtripTest : public AsyncFdSocketTest, - public testing::WithParamInterface {}; + public testing::WithParamInterface> {}; TEST_P(AsyncFdSocketSequenceRoundtripTest, WithDataSize) { - size_t dataSize = GetParam(); + auto [swapSocket, dataSize] = GetParam(); // The default `ReadCallback` has special-snowflake buffer management // that's annoying for this test. Secondarily, this exercises the // "ReadVec" path. ReadvCallback rcb(128, 3); // Avoid `readEOF` use-after-stack-scope in `~AsyncSocket`. - SCOPE_EXIT { recvSock_.setReadCB(nullptr); }; - recvSock_.setReadCB(&rcb); + SCOPE_EXIT { recvSock_->setReadCB(nullptr); }; + recvSock_->setReadCB(&rcb); std::queue< std::tuple> @@ -288,6 +289,23 @@ TEST_P(AsyncFdSocketSequenceRoundtripTest, WithDataSize) { // The max expected steps is ~3k: 1234567 / (3 * 128) for (int i = 0; i < 10000 && !sentQueue.empty(); ++i) { evb_.loopOnce(EVLOOP_NONBLOCK); + // Validate that "move from AsyncSocket" and "swap read state" interrupt + // neither the reading of data nor of FDs. + if (swapSocket) { + AsyncFdSocket prevReadStateSock{nullptr}; + prevReadStateSock.swapFdReadStateWith(recvSock_.get()); + + // Test moving the non-FD parts of the socket, while reading. + struct EnableMakeUnique : public AsyncFdSocket { + EnableMakeUnique(AsyncSocket* sock) + : AsyncFdSocket(AsyncFdSocket::DoesNotMoveFdSocketState{}, sock) {} + }; + recvSock_ = std::make_unique(recvSock_.get()); + recvSock_->setReadCB(&rcb); + + // Test moving the FD read state. + recvSock_->swapFdReadStateWith(&prevReadStateSock); + } size_t dataRead = rcb.buf_->computeChainDataLength(); if (!dataRead) { continue; @@ -305,7 +323,7 @@ TEST_P(AsyncFdSocketSequenceRoundtripTest, WithDataSize) { // FDs, which would fail in `checkFdsMatch`. if (!sendFds.empty()) { const auto sendSeqNum = std::get<3>(sentQueue.front()); - checkFdsMatch(sendFds, sendSeqNum, recvSock_.popNextReceivedFds()); + checkFdsMatch(sendFds, sendSeqNum, recvSock_->popNextReceivedFds()); } } @@ -331,14 +349,22 @@ TEST_P(AsyncFdSocketSequenceRoundtripTest, WithDataSize) { EXPECT_TRUE(sentQueue.empty()) << "Stuck reading?"; evb_.loopOnce(EVLOOP_NONBLOCK); EXPECT_EQ(0, rcb.buf_->computeChainDataLength()) << "Leftover reads"; - EXPECT_TRUE(recvSock_.popNextReceivedFds().empty()) << "Extra FDs"; + EXPECT_TRUE(recvSock_->popNextReceivedFds().empty()) << "Extra FDs"; } // Vary the data size to (hopefully) get a variety of chunking behaviors. INSTANTIATE_TEST_SUITE_P( VaryDataSize, AsyncFdSocketSequenceRoundtripTest, - testing::Values(1, 12, 123, 1234, 12345, 123456, 1234567)); + testing::Combine( + testing::Values(false, true), + testing::Values(1, 12, 123, 1234, 12345, 123456, 1234567)), + [](const auto& info) { + return fmt::format( + "{}{}", + std::get<0>(info.param) ? "SwapSocket_" : "", + std::get<1>(info.param)); + }); #endif // !Windows