Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Serialisation to properly deal with trivially copyable types #93

Merged
merged 19 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions docs/networking.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@ checks explicitly for an explicit type. Be careful about multiple declarations.
For this partial specialisation three static methods need to be defined.

.. codeblock:: c++
static inline std::vector<char> serialise(const T& in)
static inline std::vector<uint8_t> serialise(const T& in)

static inline T deserialise(const std::vector<char>& in)
static inline T deserialise(const std::vector<uint8_t>& in)

static inline uint64_t hash()

2 changes: 1 addition & 1 deletion src/dsl/word/Network.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ namespace dsl {
static inline std::tuple<std::shared_ptr<NetworkSource>, NetworkData<T>> get(
const threading::Reaction& /*reaction*/) {

auto* data = store::ThreadStore<std::vector<char>>::value;
auto* data = store::ThreadStore<std::vector<uint8_t>>::value;
auto* source = store::ThreadStore<NetworkSource>::value;

if (data && source) {
Expand Down
6 changes: 3 additions & 3 deletions src/dsl/word/UDP.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ namespace dsl {
/// @brief If the packet is valid
bool valid{false};
/// @brief The data that was received
std::vector<char> payload{};
std::vector<uint8_t> payload{};
/// @brief The local address that the packet was received on
util::network::sock_t local{};
/// @brief The remote address that the packet was received from
Expand Down Expand Up @@ -121,7 +121,7 @@ namespace dsl {
Target remote;

/// @brief The data to be sent in the packet
std::vector<char> payload{};
std::vector<uint8_t> payload{};

/**
* @brief Casts this packet to a boolean to check if it is valid
Expand Down Expand Up @@ -350,7 +350,7 @@ namespace dsl {
}

// Allocate max size for a UDP packet
std::vector<char> buffer(65535, 0);
std::vector<uint8_t> buffer(65535, 0);

// Make some variables to hold our message header information
std::array<char, 0x100> cmbuff = {0};
Expand Down
2 changes: 1 addition & 1 deletion src/dsl/word/emit/Network.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace dsl {
/// The hash identifying the type of object
uint64_t hash{0};
/// The serialised data
std::vector<char> payload{};
std::vector<uint8_t> payload{};
/// If the message should be sent reliably
bool reliable{false};
};
Expand Down
2 changes: 1 addition & 1 deletion src/dsl/word/emit/UDP.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ namespace dsl {
}

// Serialise to our payload
std::vector<char> payload = util::serialise::Serialise<DataType>::serialise(*data);
std::vector<uint8_t> payload = util::serialise::Serialise<DataType>::serialise(*data);

// Try to send our payload
if (::sendto(fd,
Expand Down
8 changes: 4 additions & 4 deletions src/extension/NetworkController.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ namespace extension {
network.set_packet_callback([this](const network::NUClearNetwork::NetworkTarget& remote,
const uint64_t& hash,
const bool& reliable,
std::vector<char>&& payload) {
std::vector<uint8_t>&& payload) {
// Construct our NetworkSource information
dsl::word::NetworkSource src{remote.name, remote.target, reliable};

// Move the payload in as we are stealing it
std::vector<char> p(std::move(payload));
std::vector<uint8_t> p(std::move(payload));

// Store in our thread local cache
dsl::store::ThreadStore<std::vector<char>>::value = &p;
dsl::store::ThreadStore<std::vector<uint8_t>>::value = &p;
dsl::store::ThreadStore<dsl::word::NetworkSource>::value = &src;

/* Mutex Scope */ {
Expand All @@ -76,7 +76,7 @@ namespace extension {
}

// Clear our cache
dsl::store::ThreadStore<std::vector<char>>::value = nullptr;
dsl::store::ThreadStore<std::vector<uint8_t>>::value = nullptr;
dsl::store::ThreadStore<dsl::word::NetworkSource>::value = nullptr;
});

Expand Down
24 changes: 12 additions & 12 deletions src/extension/network/NUClearNetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ namespace extension {
*
* @return the data and who it was sent from
*/
std::pair<util::network::sock_t, std::vector<char>> read_socket(fd_t fd) {
std::pair<util::network::sock_t, std::vector<uint8_t>> read_socket(fd_t fd) {

// Allocate a vector that can hold a datagram
std::vector<char> payload(1500);
std::vector<uint8_t> payload(1500);
TrentHouliston marked this conversation as resolved.
Show resolved Hide resolved
iovec iov{};
iov.iov_base = payload.data();
iov.iov_len = static_cast<decltype(iov.iov_len)>(payload.size());
Expand Down Expand Up @@ -82,7 +82,7 @@ namespace extension {
}

void NUClearNetwork::set_packet_callback(
std::function<void(const NetworkTarget&, const uint64_t&, const bool&, std::vector<char>&&)> f) {
std::function<void(const NetworkTarget&, const uint64_t&, const bool&, std::vector<uint8_t>&&)> f) {
packet_callback = std::move(f);
}

Expand Down Expand Up @@ -547,7 +547,7 @@ namespace extension {
}
}

void NUClearNetwork::process_packet(const sock_t& address, std::vector<char>&& payload) {
void NUClearNetwork::process_packet(const sock_t& address, std::vector<uint8_t>&& payload) {

// First validate this is a NUClear network packet we can read (a version 2 NUClear packet)
if (payload.size() >= sizeof(PacketHeader) && payload[0] == '\xE2' && payload[1] == '\x98'
Expand Down Expand Up @@ -671,7 +671,7 @@ namespace extension {
if (it != remote->recent_packets.end() && packet.reliable) {

// Allocate room for the whole ack packet
std::vector<char> r(sizeof(ACKPacket) + (packet.packet_count / 8), 0);
std::vector<uint8_t> r(sizeof(ACKPacket) + (packet.packet_count / 8), 0);
ACKPacket& response = *reinterpret_cast<ACKPacket*>(r.data());
response = ACKPacket();
response.packet_id = packet.packet_id;
Expand Down Expand Up @@ -703,8 +703,8 @@ namespace extension {
if (packet.packet_count == 1) {

// Copy our data into a vector
std::vector<char> out(&packet.data,
&packet.data + payload.size() - sizeof(DataPacket) + 1);
std::vector<uint8_t> out(&packet.data,
&packet.data + payload.size() - sizeof(DataPacket) + 1);

// If this is a reliable packet, send an ack back
if (packet.reliable) {
Expand Down Expand Up @@ -751,7 +751,7 @@ namespace extension {

// A basic ack has room for 8 packets and we need 1 extra byte for each 8
// additional packets
std::vector<char> r(sizeof(NACKPacket) + (packet.packet_count / 8), 0);
std::vector<uint8_t> r(sizeof(NACKPacket) + (packet.packet_count / 8), 0);
NACKPacket& response = *reinterpret_cast<NACKPacket*>(r.data());
response = NACKPacket();
response.packet_id = packet.packet_id;
Expand Down Expand Up @@ -790,7 +790,7 @@ namespace extension {
if (packet.reliable) {
// A basic ack has room for 8 packets and we need 1 extra byte for each 8
// additional packets
std::vector<char> r(sizeof(ACKPacket) + (packet.packet_count / 8), 0);
std::vector<uint8_t> r(sizeof(ACKPacket) + (packet.packet_count / 8), 0);
ACKPacket& response = *reinterpret_cast<ACKPacket*>(r.data());
response = ACKPacket();
response.packet_id = packet.packet_id;
Expand Down Expand Up @@ -825,7 +825,7 @@ namespace extension {
}

// Read in our data
std::vector<char> out;
std::vector<uint8_t> out;
out.reserve(payload_size);
for (auto& p : assembler.second) {
const DataPacket& part = *reinterpret_cast<DataPacket*>(p.second.data());
Expand Down Expand Up @@ -997,7 +997,7 @@ namespace extension {
void NUClearNetwork::send_packet(const sock_t& target,
NUClear::extension::network::DataPacket header,
uint16_t packet_no,
const std::vector<char>& payload,
const std::vector<uint8_t>& payload,
const bool& /*reliable*/) {

// Our packet we are sending
Expand Down Expand Up @@ -1028,7 +1028,7 @@ namespace extension {


void NUClearNetwork::send(const uint64_t& hash,
const std::vector<char>& payload,
const std::vector<uint8_t>& payload,
const std::string& target,
bool reliable) {

Expand Down
16 changes: 8 additions & 8 deletions src/extension/network/NUClearNetwork.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ namespace extension {
std::mutex assemblers_mutex;
/// Storage for fragmented packets while we build them
std::map<uint16_t,
std::pair<std::chrono::steady_clock::time_point, std::map<uint16_t, std::vector<char>>>>
std::pair<std::chrono::steady_clock::time_point, std::map<uint16_t, std::vector<uint8_t>>>>
assemblers{};

/// Struct storing the kalman filter for round trip time
Expand Down Expand Up @@ -130,15 +130,15 @@ namespace extension {
* @param target who we are sending to (blank means everyone)
* @param reliable if the delivery of the data should be ensured
*/
void send(const uint64_t& hash, const std::vector<char>& payload, const std::string& target, bool reliable);
void send(const uint64_t& hash, const std::vector<uint8_t>& payload, const std::string& target, bool reliable);

/**
* @brief Set the callback to use when a data packet is completed
*
* @param f the callback function
*/
void set_packet_callback(
std::function<void(const NetworkTarget&, const uint64_t&, const bool&, std::vector<char>&&)> f);
std::function<void(const NetworkTarget&, const uint64_t&, const bool&, std::vector<uint8_t>&&)> f);

/**
* @brief Set the callback to use when a node joins the network
Expand Down Expand Up @@ -231,7 +231,7 @@ namespace extension {
DataPacket header{};

/// The data to send
std::vector<char> payload{};
std::vector<uint8_t> payload{};
};

/**
Expand All @@ -255,7 +255,7 @@ namespace extension {
* @param address who the packet came from
* @param data the data that was sent in this packet
*/
void process_packet(const sock_t& address, std::vector<char>&& payload);
void process_packet(const sock_t& address, std::vector<uint8_t>&& payload);

/**
* @brief Send an announce packet to our announce address
Expand All @@ -279,7 +279,7 @@ namespace extension {
void send_packet(const sock_t& target,
DataPacket header,
uint16_t packet_no,
const std::vector<char>& payload,
const std::vector<uint8_t>& payload,
const bool& reliable);

/**
Expand Down Expand Up @@ -307,13 +307,13 @@ namespace extension {
uint16_t packet_data_mtu{1000};

// Our announce packet
std::vector<char> announce_packet{};
std::vector<uint8_t> announce_packet{};

/// An atomic source for packet IDs to make sure they are semi unique
std::atomic<uint16_t> packet_id_source{0};

/// The callback to execute when a data packet is completed
std::function<void(const NetworkTarget&, const uint64_t&, const bool&, std::vector<char>&&)>
std::function<void(const NetworkTarget&, const uint64_t&, const bool&, std::vector<uint8_t>&&)>
packet_callback;
/// The callback to execute when a node joins the network
std::function<void(const NetworkTarget&)> join_callback;
Expand Down
23 changes: 14 additions & 9 deletions src/util/serialise/Serialise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#ifndef NUCLEAR_UTIL_SERIALISE_SERIALISE_HPP
#define NUCLEAR_UTIL_SERIALISE_SERIALISE_HPP

#include <cstring>
#include <string>
#include <type_traits>

Expand Down Expand Up @@ -50,13 +51,13 @@ namespace util {
template <typename T>
struct Serialise<T, std::enable_if_t<std::is_trivially_copyable<T>::value, T>> {

static inline std::vector<char> serialise(const T& in) {
std::vector<char> out(sizeof(T));
static inline std::vector<uint8_t> serialise(const T& in) {
std::vector<uint8_t> out(sizeof(T));
std::memcpy(out.data(), &in, sizeof(T));
return out;
}

static inline T deserialise(const std::vector<char>& in) {
static inline T deserialise(const std::vector<uint8_t>& in) {
if (in.size() != sizeof(T)) {
throw std::length_error("Serialised data is not the correct size");
}
Expand All @@ -80,8 +81,8 @@ namespace util {

using V = std::remove_reference_t<iterator_value_type_t<T>>;

static inline std::vector<char> serialise(const T& in) {
std::vector<char> out;
static inline std::vector<uint8_t> serialise(const T& in) {
std::vector<uint8_t> out;
out.reserve(sizeof(V) * size_t(std::distance(std::begin(in), std::end(in))));

for (const V& item : in) {
Expand All @@ -92,7 +93,11 @@ namespace util {
return out;
}

static inline T deserialise(const std::vector<char>& in) {
static inline T deserialise(const std::vector<uint8_t>& in) {

if (in.size() % sizeof(V) != 0) {
throw std::length_error("Serialised data is not the correct size");
}

T out;

Expand All @@ -118,14 +123,14 @@ namespace util {
|| std::is_base_of<::google::protobuf::MessageLite, T>::value,
T>> {

static inline std::vector<char> serialise(const T& in) {
std::vector<char> output(in.ByteSize());
static inline std::vector<uint8_t> serialise(const T& in) {
std::vector<uint8_t> output(in.ByteSize());
in.SerializeToArray(output.data(), output.size());

return output;
}

static inline T deserialise(const std::vector<char>& in) {
static inline T deserialise(const std::vector<uint8_t>& in) {
// Make a buffer
T out;

Expand Down
Loading
Loading