Skip to content

Commit

Permalink
refactor: make Transport::ReceivedBytes just return success/fail
Browse files Browse the repository at this point in the history
  • Loading branch information
sipa committed Aug 24, 2023
1 parent bb4aab9 commit 8a3b6f3
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
5 changes: 2 additions & 3 deletions src/net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -690,9 +690,8 @@ bool CNode::ReceiveMsgBytes(Span<const uint8_t> msg_bytes, bool& complete)
nRecvBytes += msg_bytes.size();
while (msg_bytes.size() > 0) {
// absorb network data
int handled = m_transport->ReceivedBytes(msg_bytes);
if (handled < 0) {
// Serious header problem, disconnect from the peer.
if (!m_transport->ReceivedBytes(msg_bytes)) {
// Serious transport problem, disconnect from the peer.
return false;
}

Expand Down
23 changes: 18 additions & 5 deletions src/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,22 @@ class Transport {
virtual bool ReceivedMessageComplete() const = 0;
/** Set the deserialization context version for objects returned by GetReceivedMessage. */
virtual void SetReceiveVersion(int version) = 0;
/** Feed wire bytes to the transport; chops off consumed bytes off front of msg_bytes. */
virtual int ReceivedBytes(Span<const uint8_t>& msg_bytes) = 0;
/** Retrieve a completed message from transport (only when ReceivedMessageComplete). */

/** Feed wire bytes to the transport.
*
* @return false if some bytes were invalid, in which case the transport can't be used anymore.
*
* Consumed bytes are chopped off the front of msg_bytes.
*/
virtual bool ReceivedBytes(Span<const uint8_t>& msg_bytes) = 0;

/** Retrieve a completed message from transport.
*
* This can only be called when ReceivedMessageComplete() is true.
*
* If reject_message=true is returned the message itself is invalid, but (other than false
* returned by ReceivedBytes) the transport is not in an inconsistent state.
*/
virtual CNetMessage GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) = 0;

// 2. Sending side functions, for converting messages into bytes to be sent over the wire.
Expand Down Expand Up @@ -387,7 +400,7 @@ class V1Transport final : public Transport
vRecv.SetVersion(nVersionIn);
}

int ReceivedBytes(Span<const uint8_t>& msg_bytes) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex)
bool ReceivedBytes(Span<const uint8_t>& msg_bytes) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex)
{
AssertLockNotHeld(m_recv_mutex);
LOCK(m_recv_mutex);
Expand All @@ -397,7 +410,7 @@ class V1Transport final : public Transport
} else {
msg_bytes = msg_bytes.subspan(ret);
}
return ret;
return ret >= 0;
}

CNetMessage GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex);
Expand Down
3 changes: 1 addition & 2 deletions src/test/fuzz/p2p_transport_serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ FUZZ_TARGET(p2p_transport_serialization, .init = initialize_p2p_transport_serial
mutable_msg_bytes.insert(mutable_msg_bytes.end(), payload_bytes.begin(), payload_bytes.end());
Span<const uint8_t> msg_bytes{mutable_msg_bytes};
while (msg_bytes.size() > 0) {
const int handled = recv_transport.ReceivedBytes(msg_bytes);
if (handled < 0) {
if (!recv_transport.ReceivedBytes(msg_bytes)) {
break;
}
if (recv_transport.ReceivedMessageComplete()) {
Expand Down

0 comments on commit 8a3b6f3

Please sign in to comment.