Skip to content

Commit

Permalink
Merge pull request managarm#691 from no92/bind-interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Dennisbonke authored Jul 9, 2024
2 parents 7c848ad + c39af62 commit 04ec361
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 7 deletions.
13 changes: 9 additions & 4 deletions servers/netserver/src/ip/ip4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ bool Ip4Router::addRoute(Route r) {
return routes.emplace(std::move(r)).second;
}

std::optional<Route> Ip4Router::resolveRoute(uint32_t ip) {
std::optional<Route> Ip4Router::resolveRoute(uint32_t ip, std::shared_ptr<nic::Link> link) {
for (auto i = routes.begin(); i != routes.end(); i++) {
const auto &r = *i;
if (r.network.sameNet(ip)) {
if(link && r.link.lock()->index() != link->index())
continue;

if (r.link.expired()) {
i = routes.erase(i);
continue;
Expand Down Expand Up @@ -230,8 +233,8 @@ async::result<frg::expected<protocols::fs::Error, size_t>> Ip4Socket::sendmsg(vo
}

async::result<std::optional<Ip4TargetInfo>>
Ip4::targetByRemote(uint32_t remote) {
auto oroute = ip4Router().resolveRoute(remote);
Ip4::targetByRemote(uint32_t remote, std::shared_ptr<nic::Link> link) {
auto oroute = ip4Router().resolveRoute(remote, link);
if (!oroute) {
std::cout << "netserver: net unreachable" << std::endl;
co_return std::nullopt;
Expand Down Expand Up @@ -329,7 +332,9 @@ async::result<protocols::fs::Error> Ip4::sendFrame(Ip4TargetInfo ti,

void Ip4::feedPacket(nic::MacAddress, nic::MacAddress,
arch::dma_buffer owner, arch::dma_buffer_view frame, std::weak_ptr<nic::Link> link) {
Ip4Packet hdr;
Ip4Packet hdr{};
hdr.link = link;

if (!hdr.parse(std::move(owner), frame)) {
std::cout << "netserver: runt, or otherwise invalid, ip4 frame received"
<< std::endl;
Expand Down
5 changes: 3 additions & 2 deletions servers/netserver/src/ip/ip4.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ struct Ip4Router {

// false if insertion fails
bool addRoute(Route r);
std::optional<Route> resolveRoute(uint32_t ip);
std::optional<Route> resolveRoute(uint32_t ip, std::shared_ptr<nic::Link> link = {});

inline const std::set<Route> &getRoutes() const {
return routes;
Expand Down Expand Up @@ -106,6 +106,7 @@ class Ip4Packet {
} header;
static_assert(sizeof(header) == 20, "bad header size");
arch::dma_buffer_view data;
std::weak_ptr<nic::Link> link;

inline arch::dma_buffer_view payload() const {
return data.subview(header.ihl * 4);
Expand Down Expand Up @@ -140,7 +141,7 @@ struct Ip4 {
void setLink(CidrAddress addr, std::weak_ptr<nic::Link> link);
std::optional<uint32_t> findLinkIp(uint32_t ipOnNet, nic::Link *link);

async::result<std::optional<Ip4TargetInfo>> targetByRemote(uint32_t);
async::result<std::optional<Ip4TargetInfo>> targetByRemote(uint32_t, std::shared_ptr<nic::Link> link = {});
async::result<protocols::fs::Error> sendFrame(Ip4TargetInfo,
void*, size_t,
uint16_t);
Expand Down
35 changes: 34 additions & 1 deletion servers/netserver/src/ip/tcp4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <arch/variable.hpp>
#include <protocols/fs/server.hpp>
#include <cstring>
#include <format>
#include <iomanip>
#include <random>
#include <fcntl.h>
Expand Down Expand Up @@ -493,6 +494,32 @@ struct Tcp4Socket {
co_return 0;
}

static async::result<frg::expected<protocols::fs::Error>> setSocketOption(void *object,
int layer, int number, std::vector<char> optbuf) {
auto self = static_cast<Tcp4Socket *>(object);

if(layer == SOL_SOCKET && number == SO_BINDTODEVICE) {
std::string ifname{optbuf.data()};

if(ifname.empty()) {
self->boundInterface_ = {};
} else {
auto nic = nic::Link::byName(ifname);

if(!nic)
co_return protocols::fs::Error::illegalArguments;

self->boundInterface_ = nic;
co_return {};
}
}

std::cout << std::format("netserver: unhandled TCP socket setsockopt layer {} number {}\n",
layer, number);

co_return protocols::fs::Error::invalidProtocolOption;
}

constexpr static protocols::fs::FileOperations ops {
.read = &read,
.write = &write,
Expand All @@ -507,6 +534,7 @@ struct Tcp4Socket {
.recvMsg = &recvMsg,
.sendMsg = &sendMsg,
.peername = &peername,
.setSocketOption = &setSocketOption,
};

bool bindAvailable(uint32_t ipAddress = INADDR_ANY) {
Expand Down Expand Up @@ -575,6 +603,8 @@ struct Tcp4Socket {
uint64_t outSeq_ = 0;
uint64_t hupSeq_ = 1;
async::recurring_event pollEvent_;

std::shared_ptr<nic::Link> boundInterface_ = {};
};

async::result<void> Tcp4Socket::flushOutPackets_() {
Expand All @@ -596,7 +626,7 @@ async::result<void> Tcp4Socket::flushOutPackets_() {
localFlushedSn_ = randomSn;

// Construct and transmit the initial SYN packet.
auto targetInfo = co_await ip4().targetByRemote(remoteEp_.ipAddress);
auto targetInfo = co_await ip4().targetByRemote(remoteEp_.ipAddress, boundInterface_);
if (!targetInfo) {
// TODO: Return an error to users.
std::cout << "netserver: Destination unreachable" << std::endl;
Expand Down Expand Up @@ -722,6 +752,9 @@ async::result<void> Tcp4Socket::flushOutPackets_() {
}

void Tcp4Socket::handleInPacket_(TcpPacket packet) {
if(boundInterface_ && boundInterface_->index() != packet.packet->link.lock()->index())
return;

if(connectState_ == ConnectState::sendSyn) {
if(localSettledSn_ == localFlushedSn_) {
std::cout << "netserver: Rejecting packet before SYN is sent [sendSyn]"
Expand Down

0 comments on commit 04ec361

Please sign in to comment.