From c539e0e2d26c9e0279ee208dbefb528f17159722 Mon Sep 17 00:00:00 2001 From: Pedro Falcato Date: Sat, 25 Jan 2025 04:11:39 +0000 Subject: [PATCH] tcp: Add sndbuf accounting Add per-socket send buffer limits. Also significantly improve segment sending and fix some latent TCP/IP bugs while we're at it. This patch doesn't yet introduce this functionality for UDP or UNIX sockets. Signed-off-by: Pedro Falcato --- kernel/drivers/net/e1000/e1000.cpp | 4 +- kernel/drivers/virtio/network/network.cpp | 4 +- kernel/include/onyx/atomic.h | 10 +- kernel/include/onyx/net/socket.h | 65 ++++- kernel/include/onyx/net/tcp.h | 3 +- kernel/include/onyx/packetbuf.h | 11 +- kernel/include/onyx/page_frag.h | 43 +++ kernel/kernel/mm/Makefile | 3 +- kernel/kernel/mm/page_frag.c | 64 +++++ kernel/kernel/net/ipv4/ipv4.cpp | 4 +- kernel/kernel/net/ipv6/ipv6.cpp | 5 +- kernel/kernel/net/packetbuf.cpp | 38 ++- kernel/kernel/net/socket.cpp | 9 +- kernel/kernel/net/tcp.cpp | 309 +++++++++++++++------- kernel/kernel/net/tcp_input.cpp | 6 +- 15 files changed, 459 insertions(+), 119 deletions(-) create mode 100644 kernel/include/onyx/page_frag.h create mode 100644 kernel/kernel/mm/page_frag.c diff --git a/kernel/drivers/net/e1000/e1000.cpp b/kernel/drivers/net/e1000/e1000.cpp index 09efc913..6a7dbfb3 100644 --- a/kernel/drivers/net/e1000/e1000.cpp +++ b/kernel/drivers/net/e1000/e1000.cpp @@ -290,7 +290,7 @@ struct page_frag_res /* TODO: Put this in an actual header */ -extern "C" struct page_frag_res page_frag_alloc(struct page_frag_alloc_info *inf, size_t size) +extern "C" struct page_frag_res page_frag_alloc2(struct page_frag_alloc_info *inf, size_t size) { assert(size <= PAGE_SIZE); @@ -352,7 +352,7 @@ int e1000_init_rx(struct e1000_device *dev) for (unsigned int i = 0; i < number_rx_desc; i++) { - struct page_frag_res res = page_frag_alloc(&alloc_info, rx_buffer_size); + struct page_frag_res res = page_frag_alloc2(&alloc_info, rx_buffer_size); /* How can this even happen? Keep this here though, as a sanity check */ if (!res.page) panic("OOM allocating rx buffers"); diff --git a/kernel/drivers/virtio/network/network.cpp b/kernel/drivers/virtio/network/network.cpp index 0b34a6ff..7aca8085 100644 --- a/kernel/drivers/virtio/network/network.cpp +++ b/kernel/drivers/virtio/network/network.cpp @@ -30,7 +30,7 @@ struct page_frag_res size_t off; }; -extern "C" struct page_frag_res page_frag_alloc(struct page_frag_alloc_info *inf, size_t size); +extern "C" struct page_frag_res page_frag_alloc2(struct page_frag_alloc_info *inf, size_t size); namespace virtio { @@ -149,7 +149,7 @@ bool network_vdev::setup_rx() for (unsigned int i = 0; i < qsize; i++) { - auto [page, off] = page_frag_alloc(&alloc_info, rx_buf_size); + auto [page, off] = page_frag_alloc2(&alloc_info, rx_buf_size); virtio_allocation_info info; diff --git a/kernel/include/onyx/atomic.h b/kernel/include/onyx/atomic.h index 80bd01ae..1ab30a29 100644 --- a/kernel/include/onyx/atomic.h +++ b/kernel/include/onyx/atomic.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024 Pedro Falcato + * Copyright (c) 2024 - 2025 Pedro Falcato * This file is part of Onyx, and is released under the terms of the GPLv2 License * check LICENSE at the root directory for more information * @@ -26,4 +26,12 @@ __atomic_compare_exchange_n(ptr, &__old, new, false, __ATOMIC_SEQ_CST, __ATOMIC_RELAXED); \ __old; \ }) + +#define cmpxchg_relaxed(ptr, old, new) \ + ({ \ + __auto_type __old = (old); \ + __atomic_compare_exchange_n(ptr, &__old, new, false, __ATOMIC_RELAXED, __ATOMIC_RELAXED); \ + __old; \ + }) + #endif diff --git a/kernel/include/onyx/net/socket.h b/kernel/include/onyx/net/socket.h index 8a489970..a47a7ac2 100644 --- a/kernel/include/onyx/net/socket.h +++ b/kernel/include/onyx/net/socket.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 - 2022 Pedro Falcato + * Copyright (c) 2018 - 2025 Pedro Falcato * This file is part of Onyx, and is released under the terms of the GPLv2 License * check LICENSE at the root directory for more information * @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -56,6 +57,7 @@ struct socket_ops void (*close)(struct socket *); void (*handle_backlog)(struct socket *); short (*poll)(struct socket *, void *poll_file, short events); + void (*write_space)(struct socket *); }; struct socket : public refcountable @@ -79,11 +81,13 @@ struct socket : public refcountable bool broadcast_allowed : 1; bool proto_needs_work : 1 {0}; bool dead : 1 {0}; + bool sndbuf_locked : 1 {0}; struct list_head socket_backlog; unsigned int rx_max_buf; - unsigned int tx_max_buf; + unsigned int sk_sndbuf; + unsigned sk_send_queued; int backlog; unsigned int shutdown_state; @@ -92,17 +96,23 @@ struct socket : public refcountable const struct socket_ops *sock_ops; + /* Socket page frag info - used for allocating wmem */ + struct page_frag_info sock_pfi; + /* Define a default constructor here */ socket() : type{}, proto{}, domain{}, flags{}, sock_err{}, socket_lock{}, bound{}, connected{}, - reuse_addr{false}, rx_max_buf{DEFAULT_RX_MAX_BUF}, tx_max_buf{DEFAULT_TX_MAX_BUF}, + reuse_addr{false}, rx_max_buf{DEFAULT_RX_MAX_BUF}, sk_sndbuf{DEFAULT_TX_MAX_BUF}, shutdown_state{}, rcv_timeout{0}, snd_timeout{0}, sock_ops{} { INIT_LIST_HEAD(&socket_backlog); + pfi_init(&sock_pfi); + sk_send_queued = 0; } virtual ~socket() { + pfi_destroy(&sock_pfi); } short poll(void *poll_file, short events); @@ -326,6 +336,55 @@ int sock_default_getpeername(struct socket *sock, struct sockaddr *addr, socklen int sock_default_shutdown(struct socket *sock, int how); void sock_default_close(struct socket *sock); short sock_default_poll(struct socket *sock, void *poll_file, short events); + +static inline bool sock_may_write(struct socket *sock) +{ + return READ_ONCE(sock->sk_send_queued) < READ_ONCE(sock->sk_sndbuf); +} + +static inline int sock_write_space(struct socket *sock) +{ + return READ_ONCE(sock->sk_sndbuf) - READ_ONCE(sock->sk_send_queued); +} + +static inline bool sock_charge_snd_bytes(struct socket *sock, unsigned int bytes) +{ + unsigned int queued = READ_ONCE(sock->sk_send_queued), new_space, expected; + do + { + expected = queued; + new_space = queued + bytes; + if (new_space > sock->sk_sndbuf) + return false; + queued = cmpxchg_relaxed(&sock->sk_send_queued, expected, new_space); + } while (queued != expected); + return true; +} + +static inline bool sock_charge_pbf(struct socket *sock, struct packetbuf *pbf) +{ + return sock_charge_snd_bytes(sock, pbf->total_len); +} + +static inline void sock_discharge_snd_bytes(struct socket *sock, unsigned int bytes) +{ + unsigned int queued = READ_ONCE(sock->sk_send_queued), new_space, expected; + do + { + expected = queued; + new_space = queued - bytes; + WARN_ON(queued < new_space); + queued = cmpxchg_relaxed(&sock->sk_send_queued, expected, new_space); + } while (queued != expected); + + sock->sock_ops->write_space(sock); +} + +static inline void sock_discharge_pbf(struct socket *sock, struct packetbuf *pbf) +{ + sock_discharge_snd_bytes(sock, pbf->total_len); +} + __END_CDECLS #endif diff --git a/kernel/include/onyx/net/tcp.h b/kernel/include/onyx/net/tcp.h index b1c5153d..559849e4 100644 --- a/kernel/include/onyx/net/tcp.h +++ b/kernel/include/onyx/net/tcp.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 - 2024 Pedro Falcato + * Copyright (c) 2020 - 2025 Pedro Falcato * This file is part of Onyx, and is released under the terms of the GPLv2 License * check LICENSE at the root directory for more information * @@ -207,6 +207,7 @@ struct tcp_socket : public inet_socket unsigned int nr_sacks; int mss_for_ack; + struct list_head accept_queue; struct list_head conn_queue; int connqueue_len; }; diff --git a/kernel/include/onyx/packetbuf.h b/kernel/include/onyx/packetbuf.h index 7ddc64e3..c13b9a6b 100644 --- a/kernel/include/onyx/packetbuf.h +++ b/kernel/include/onyx/packetbuf.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 - 2024 Pedro Falcato + * Copyright (c) 2020 - 2025 Pedro Falcato * This file is part of Onyx, and is released under the terms of the GPLv2 License * check LICENSE at the root directory for more information * @@ -97,10 +97,12 @@ struct packetbuf uint16_t gso_size; uint8_t gso_flags; + u8 nr_vecs; unsigned int needs_csum : 1; unsigned int zero_copy : 1; int domain; + unsigned int total_len; /* The next bytes are always available for protocols. */ #define PACKETBUF_PROTO_SPACE 64 @@ -119,6 +121,9 @@ struct packetbuf struct tcp_packetbuf_info tpi; }; + struct socket *sock; + void (*dtor)(struct packetbuf *pbf); + #ifdef __cplusplus /** * @brief Construct a new default packetbuf object. @@ -130,6 +135,9 @@ struct packetbuf header_length{}, gso_size{}, gso_flags{}, needs_csum{0}, zero_copy{0}, domain{0} { route = {}; + sock = NULL; + dtor = NULL; + nr_vecs = 0; } /** @@ -439,6 +447,7 @@ static inline void pbf_put_ref(struct packetbuf *pbf) typedef unsigned int gfp_t; struct packetbuf *pbf_alloc(gfp_t gfp); +struct packetbuf *pbf_alloc_sk(gfp_t gfp, struct socket *sock, unsigned int len); __END_CDECLS diff --git a/kernel/include/onyx/page_frag.h b/kernel/include/onyx/page_frag.h new file mode 100644 index 00000000..2432b5f1 --- /dev/null +++ b/kernel/include/onyx/page_frag.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2025 Pedro Falcato + * This file is part of Onyx, and is released under the terms of the GPLv2 License + * check LICENSE at the root directory for more information + * + * SPDX-License-Identifier: GPL-2.0-only + */ +#ifndef _ONYX_PAGE_FRAG_H +#define _ONYX_PAGE_FRAG_H + +#include +#include + +__BEGIN_CDECLS + +struct page_frag_info +{ + struct page *page; + unsigned int offset; + unsigned int len; +}; + +struct page_frag +{ + struct page *page; + unsigned int offset; + unsigned int len; +}; + +static inline void pfi_init(struct page_frag_info *pfi) +{ + pfi->page = NULL; + pfi->len = pfi->offset = 0; +} + +int page_frag_alloc(struct page_frag_info *pfi, unsigned int len, gfp_t gfp, + struct page_frag *frag); + +void pfi_destroy(struct page_frag_info *pfi); + +__END_CDECLS + +#endif diff --git a/kernel/kernel/mm/Makefile b/kernel/kernel/mm/Makefile index aeb81080..5b324329 100644 --- a/kernel/kernel/mm/Makefile +++ b/kernel/kernel/mm/Makefile @@ -1,4 +1,5 @@ -mm-y:= bootmem.o page.o pagealloc.o vm_object.o vm.o vmalloc.o reclaim.o anon.o mincore.o page_lru.o swap.o rmap.o slab_cache_pool.o madvise.o +mm-y:= bootmem.o page.o pagealloc.o vm_object.o vm.o vmalloc.o reclaim.o anon.o \ + mincore.o page_lru.o swap.o rmap.o slab_cache_pool.o madvise.o page_frag.o mm-$(CONFIG_KUNIT)+= vm_tests.o mm-$(CONFIG_X86)+= memory.o mm-$(CONFIG_RISCV)+= memory.o diff --git a/kernel/kernel/mm/page_frag.c b/kernel/kernel/mm/page_frag.c new file mode 100644 index 00000000..4128d618 --- /dev/null +++ b/kernel/kernel/mm/page_frag.c @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2025 Pedro Falcato + * This file is part of Onyx, and is released under the terms of the GPLv2 License + * check LICENSE at the root directory for more information + * + * SPDX-License-Identifier: GPL-2.0-only + */ +#include + +#include + +static int page_frag_refill(struct page_frag_info *pfi, unsigned int len, gfp_t gfp) +{ + unsigned int order = pages2order(vm_size_to_pages(len)); + + if (WARN_ON_ONCE(order > 0)) + { + /* TODO: We're missing GFP_COMP support, and without it the refcounting gets all screwed + * up. So reject order > 0 allocations. */ + pr_warn("%s: Asked for %u bytes, which we can't deliver\n", __func__, len); + return -ENOMEM; + } + + if (pfi->page) + page_unref(pfi->page); + + pfi->page = alloc_pages(order, gfp); + if (!pfi->page) + return -ENOMEM; + pfi->offset = 0; + pfi->len = 1UL << (order + PAGE_SHIFT); + return 0; +} + +int page_frag_alloc(struct page_frag_info *pfi, unsigned int len, gfp_t gfp, struct page_frag *frag) +{ + /* Check if we don't have a page already, or if we dont have enough space for the frag */ + if (!pfi->page || pfi->len - pfi->offset < len) + { + if (page_frag_refill(pfi, len, gfp) < 0) + return -ENOMEM; + } + + page_ref(pfi->page); + frag->page = pfi->page; + frag->len = len; + frag->offset = pfi->offset; + pfi->offset += len; + + if (pfi->offset == len) + { + /* Release our ref if someone ate the whole thing. */ + page_unref(pfi->page); + pfi->page = NULL; + } + + return 0; +} + +void pfi_destroy(struct page_frag_info *pfi) +{ + if (pfi->page) + page_unref(pfi->page); +} diff --git a/kernel/kernel/net/ipv4/ipv4.cpp b/kernel/kernel/net/ipv4/ipv4.cpp index 2c8716b9..f4861eb3 100644 --- a/kernel/kernel/net/ipv4/ipv4.cpp +++ b/kernel/kernel/net/ipv4/ipv4.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016 - 2024 Pedro Falcato + * Copyright (c) 2016 - 2025 Pedro Falcato * This file is part of Onyx, and is released under the terms of the GPLv2 License * check LICENSE at the root directory for more information * @@ -431,7 +431,7 @@ int handle_packet(netif *nif, packetbuf *buf) buf->data += iphdr_len; /* Adjust tail to point at the end of the ipv4 packet */ - buf->tail = (unsigned char *) header + ntohs(header->total_len); + buf->tail = cul::min(buf->end, (unsigned char *) header + ntohs(header->total_len)); inet_route route; route.dst_addr.in4.s_addr = header->dest_ip; diff --git a/kernel/kernel/net/ipv6/ipv6.cpp b/kernel/kernel/net/ipv6/ipv6.cpp index 12eb1f07..a8784ef9 100644 --- a/kernel/kernel/net/ipv6/ipv6.cpp +++ b/kernel/kernel/net/ipv6/ipv6.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 - 2024 Pedro Falcato + * Copyright (c) 2020 - 2025 Pedro Falcato * This file is part of Onyx, and is released under the terms of the GPLv2 License * check LICENSE at the root directory for more information * @@ -433,7 +433,8 @@ int handle_packet(netif *nif, packetbuf *buf) buf->data += iphdr_len; /* Adjust tail to point at the end of the ipv4 packet */ - buf->tail = (unsigned char *) header + iphdr_len + ntohs(header->payload_length); + buf->tail = + cul::min(buf->end, (unsigned char *) header + iphdr_len + ntohs(header->payload_length)); inet_route route; route.dst_addr.in6 = header->dst_addr; diff --git a/kernel/kernel/net/packetbuf.cpp b/kernel/kernel/net/packetbuf.cpp index 5456cc14..cb1ffd1f 100644 --- a/kernel/kernel/net/packetbuf.cpp +++ b/kernel/kernel/net/packetbuf.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 - 2024 Pedro Falcato + * Copyright (c) 2020 - 2025 Pedro Falcato * This file is part of Onyx, and is released under the terms of the GPLv2 License * check LICENSE at the root directory for more information * @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -81,6 +82,7 @@ bool pbf_allocate_space(struct packetbuf *pbf, size_t length) pbf->net_header = pbf->transport_header = nullptr; pbf->data = pbf->tail = (unsigned char *) pbf->buffer_start; pbf->end = (unsigned char *) pbf->buffer_start + PAGE_SIZE; + pbf->nr_vecs = nr_pages; return true; } @@ -139,6 +141,8 @@ void *packetbuf::put(unsigned int size) */ packetbuf::~packetbuf() { + if (dtor) + dtor(this); for (auto &v : page_vec) { if (v.page) @@ -177,6 +181,7 @@ packetbuf *packetbuf_clone(packetbuf *original) buf->domain = original->domain; buf->route = original->route; buf->tpi = original->tpi; + buf->nr_vecs = original->nr_vecs; return buf.release(); } @@ -406,6 +411,37 @@ struct packetbuf *pbf_alloc(gfp_t gfp) return pbf; } +struct packetbuf *pbf_alloc_sk(gfp_t gfp, struct socket *sock, unsigned int len) +{ + struct page_frag f; + struct packetbuf *pbf; + + pbf = pbf_alloc(gfp); + if (!pbf) + return NULL; + + len = ALIGN_TO(len, 4); + if (page_frag_alloc(&sock->sock_pfi, len, gfp, &f) < 0) + { + pbf_free(pbf); + return NULL; + } + + pbf->page_vec[0].page = f.page; + pbf->page_vec[0].page_off = f.offset; + pbf->page_vec[0].length = f.len; + + pbf->buffer_start = (char *) PAGE_TO_VIRT(f.page) + f.offset; + pbf->net_header = pbf->transport_header = NULL; + pbf->data = pbf->tail = (unsigned char *) pbf->buffer_start; + pbf->end = (unsigned char *) pbf->buffer_start + f.len; + pbf->sock = sock; + pbf->total_len = sizeof(struct packetbuf) + f.len; + pbf->nr_vecs = 1; + + return pbf; +} + #ifdef CONFIG_KUNIT static ref_guard alloc_pbf(unsigned int length) diff --git a/kernel/kernel/net/socket.cpp b/kernel/kernel/net/socket.cpp index 57fdccde..34f932d9 100644 --- a/kernel/kernel/net/socket.cpp +++ b/kernel/kernel/net/socket.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 - 2022 Pedro Falcato + * Copyright (c) 2018 - 2025 Pedro Falcato * This file is part of Onyx, and is released under the terms of the GPLv2 License * check LICENSE at the root directory for more information * @@ -1014,7 +1014,7 @@ int socket::getsockopt_socket_level(int optname, void *optval, socklen_t *optlen } case SO_SNDBUF: { - return put_option(tx_max_buf, optval, optlen); + return put_option(sk_sndbuf, optval, optlen); } case SO_REUSEADDR: { @@ -1042,7 +1042,7 @@ int socket::setsockopt_socket_level(int optname, const void *optval, socklen_t o if (ex.has_error()) return ex.error(); - rx_max_buf = ex.value(); + rx_max_buf = ex.value() * 2; return 0; } @@ -1052,7 +1052,8 @@ int socket::setsockopt_socket_level(int optname, const void *optval, socklen_t o if (ex.has_error()) return ex.error(); - tx_max_buf = ex.value(); + sk_sndbuf = ex.value() * 2; + sndbuf_locked = true; return 0; } diff --git a/kernel/kernel/net/tcp.cpp b/kernel/kernel/net/tcp.cpp index 813d8488..4007d8ac 100644 --- a/kernel/kernel/net/tcp.cpp +++ b/kernel/kernel/net/tcp.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 - 2024 Pedro Falcato + * Copyright (c) 2020 - 2025 Pedro Falcato * This file is part of Onyx, and is released under the terms of the GPLv2 License * check LICENSE at the root directory for more information * @@ -19,8 +19,20 @@ socket_table tcp_table; const inet_proto tcp_proto{"tcp", &tcp_table}; -u16 tcpv4_calculate_checksum(const tcp_header *header, u16 packet_length, uint32_t srcip, u32 dstip, - bool calc_data) +static inline inetsum_t tcp_data_csum(inetsum_t r, struct packetbuf *pbf) +{ + for (u8 i = 1; i < pbf->nr_vecs; i++) + { + struct page_iov *iov = &pbf->page_vec[i]; + u8 *ptr = ((u8 *) PAGE_TO_VIRT(iov->page)) + iov->page_off; + r = __ipsum_unfolded(ptr, iov->length, r); + } + + return r; +} + +u16 tcpv4_calculate_checksum(const tcp_header *header, u16 packet_length, struct packetbuf *pbf, + uint32_t srcip, u32 dstip, bool calc_data) { u32 proto = ((packet_length + IPPROTO_TCP) << 8); u16 buf[2]; @@ -31,13 +43,16 @@ u16 tcpv4_calculate_checksum(const tcp_header *header, u16 packet_length, uint32 r = __ipsum_unfolded(buf, sizeof(buf), r); if (calc_data) - r = __ipsum_unfolded(header, packet_length, r); + { + r = __ipsum_unfolded(header, pbf->tail - pbf->data, r); + r = tcp_data_csum(r, pbf); + } return ipsum_fold(r); } -u16 tcpv6_calculate_checksum(const tcp_header *header, u16 packet_length, const in6_addr &srcip, - const in6_addr &dstip, bool calc_data) +u16 tcpv6_calculate_checksum(const tcp_header *header, u16 packet_length, struct packetbuf *pbf, + const in6_addr &srcip, const in6_addr &dstip, bool calc_data) { u32 proto = htonl(IPPROTO_TCP); u32 pseudo_len = htonl(packet_length); @@ -49,7 +64,10 @@ u16 tcpv6_calculate_checksum(const tcp_header *header, u16 packet_length, const assert(header->checksum == 0); if (calc_data) - r = __ipsum_unfolded(header, packet_length, r); + { + r = __ipsum_unfolded(header, pbf->tail - pbf->data, r); + r = tcp_data_csum(r, pbf); + } return ipsum_fold(r); } @@ -69,14 +87,15 @@ u16 tcpv6_calculate_checksum(const tcp_header *header, u16 packet_length, const * @return uint16_t The internet checksum */ template -uint16_t tcp_calculate_checksum(const tcp_header *header, uint16_t len, const inet_route::addr &src, - const inet_route::addr &dest, bool do_rest_of_packet = true) +uint16_t tcp_calculate_checksum(const tcp_header *header, uint16_t len, struct packetbuf *pbf, + const inet_route::addr &src, const inet_route::addr &dest, + bool do_rest_of_packet = true) { uint16_t result = 0; if constexpr (domain == AF_INET6) - result = tcpv6_calculate_checksum(header, len, src.in6, dest.in6, do_rest_of_packet); + result = tcpv6_calculate_checksum(header, len, pbf, src.in6, dest.in6, do_rest_of_packet); else - result = tcpv4_calculate_checksum(header, len, src.in4.s_addr, dest.in4.s_addr, + result = tcpv4_calculate_checksum(header, len, pbf, src.in4.s_addr, dest.in4.s_addr, do_rest_of_packet); // Checksum offloading needs an unfolded checksum @@ -110,7 +129,10 @@ static void tcp_init_sock(struct tcp_socket *sock) sock->sack_needs_send = 0; sock->nr_sacks = 0; INIT_LIST_HEAD(&sock->conn_queue); + INIT_LIST_HEAD(&sock->accept_queue); sock->connqueue_len = 0; + /* Default the send window to 4MiB */ + sock->sk_sndbuf = 0x400000; } static __init void tcp_init() @@ -202,6 +224,12 @@ static int tcp_sendpbuf(struct tcp_socket *sock, struct packetbuf *pbf) header_length += tcp_push_options(sock, pbf); } +#if 0 + pr_warn("segment len %u\n", segment_len); + for (int i = 0; i < 3; i++) + pr_warn("vec[%d]: page %p off %u len %u\n", i, pbf->page_vec[i].page, + pbf->page_vec[i].page_off, pbf->page_vec[i].length); +#endif hdr = (struct tcp_header *) pbf_push_header(pbf, sizeof(struct tcp_header)); memset(hdr, 0, sizeof(struct tcp_header)); if (pbf->tpi.ack) @@ -230,7 +258,7 @@ static int tcp_sendpbuf(struct tcp_socket *sock, struct packetbuf *pbf) } hdr->checksum = call_based_on_inet2( - sock, tcp_calculate_checksum, hdr, static_cast(header_length + segment_len), + sock, tcp_calculate_checksum, hdr, static_cast(header_length + segment_len), pbf, sock->route_cache.src_addr, sock->route_cache.dst_addr, need_csum); iflow flow{sock->route_cache, IPPROTO_TCP, sock->effective_domain() == AF_INET6}; @@ -305,7 +333,7 @@ static void tcp_prepare_nondata_header(struct tcp_socket *sock, struct packetbuf } hdr->checksum = - call_based_on_inet2(sock, tcp_calculate_checksum, hdr, header_len, + call_based_on_inet2(sock, tcp_calculate_checksum, hdr, header_len, pbf, sock->route_cache.src_addr, sock->route_cache.dst_addr, need_csum); } @@ -631,7 +659,7 @@ short tcp_poll(struct socket *sock_, void *poll_file, short events) { if (events & POLLIN) { - if (!list_is_empty(&sock->conn_queue)) + if (!list_is_empty(&sock->accept_queue)) avail_events |= POLLIN; else poll_wait_helper(poll_file, &sock->rx_wq); @@ -651,7 +679,7 @@ short tcp_poll(struct socket *sock_, void *poll_file, short events) if (events & POLLOUT) { - if (!(sock->shutdown_state & SHUTDOWN_WR)) + if (!(sock->shutdown_state & SHUTDOWN_WR) && sock_may_write(sock)) avail_events |= POLLOUT; } @@ -688,103 +716,177 @@ static void tcp_prepare_segment(struct tcp_socket *sock, struct packetbuf *pbf) } } -static int tcp_write_alloc(struct tcp_socket *sock, const iovec *vec, size_t vec_len, size_t mss, - size_t skip_first); +static int tcp_write_alloc(struct tcp_socket *sock); -static int tcp_append_write(struct tcp_socket *sock, const struct iovec *vec, size_t vec_len, - size_t mss) +static bool tcp_attempt_merge(struct packetbuf *pbf, struct page_frag *pf) { - size_t read_in_vec = 0; - packetbuf *pbf = NULL; - unsigned int packet_len = 0; - - if (list_is_empty(&sock->output_queue)) - goto alloc_append; + struct page_iov *iov = &pbf->page_vec[pbf->nr_vecs - 1]; - pbf = list_last_entry(&sock->output_queue, struct packetbuf, list_node); - if (!vec_len) - return 0; - - while ((packet_len = pbf_length(pbf)) < mss) + /* Okay, we have the last page_iov. Check if we can merge it, if not, check if we can append the + * page frag. */ + if (likely(iov->page == pf->page && iov->page_off + iov->length == pf->offset)) { - /* OOOH, we've got some room, let's expand! */ - const uint8_t *ubuf = (uint8_t *) vec->iov_base + read_in_vec; - auto len = vec->iov_len - read_in_vec; - unsigned int to_expand = cul::clamp(len, mss - packet_len); - ssize_t st = pbf->expand_buffer(ubuf, to_expand); - - if (st < 0) - return -ENOBUFS; - - pbf->tpi.seq_len += to_expand; - read_in_vec += st; - if (read_in_vec == vec->iov_len) + iov->length += pf->len; + /* We already hold a ref, so drop the new one */ + page_unref(pf->page); + /* And adjust the pbf data area if required */ + if (pbf->nr_vecs == 1) { - vec++; - read_in_vec = 0; - vec_len--; + pbf->tail += pf->len; + pbf->end += pf->len; } - - /* Good, we're finished. */ - if (!vec_len) - return 0; + return true; } -alloc_append: - return tcp_write_alloc(sock, vec, vec_len, mss, read_in_vec); + /* TODO: We can't use the last page_iov for legacy reasons */ + if (unlikely(pbf->nr_vecs >= PBF_PAGE_IOVS - 1)) + return false; + pbf->nr_vecs++; + iov++; + iov->page = pf->page; + iov->page_off = pf->offset; + iov->length = pf->len; + return true; } -static int tcp_write_alloc(struct tcp_socket *sock, const iovec *vec, size_t vec_len, size_t mss, - size_t skip_first) +static u8 *ptr_from_frag(struct page_frag *pf) { - size_t added_from_vec = 0; - size_t vec_nr = 0; - while (vec_len) + return ((u8 *) PAGE_TO_VIRT(pf->page)) + pf->offset; +} + +static int tcp_append_to_segment(struct tcp_socket *tp, struct packetbuf *pbf, + struct iovec_iter *iter) +{ + struct iovec iov; + unsigned int len, to_add; + struct page_frag pf; + int err; + + int write_space = sock_write_space(tp); + if (write_space <= 0) + return -EWOULDBLOCK; + + len = pbf_length(pbf); + + while (len < tp->mss) { - struct packetbuf *pbf = pbf_alloc(GFP_KERNEL); - if (!pbf) + if (iter->empty()) + break; + if (write_space == 0) + break; + + iov = iter->curiovec(); + to_add = min((unsigned int) iov.iov_len, (unsigned int) write_space); + to_add = min(to_add, tp->mss - len); + to_add = min(to_add, (unsigned int) PAGE_SIZE); + + err = page_frag_alloc(&tp->sock_pfi, to_add, GFP_KERNEL, &pf); + if (err) return -ENOBUFS; - size_t iov_len = vec->iov_len; - if (vec_nr == 0) + /* Note: We cannot copy_from_iter because we don't yet know if this fragment will be valid + */ + if (copy_from_user(ptr_from_frag(&pf), iov.iov_base, to_add) < 0) { - // We might be creating a new packet from a partial iov that already filled - // some other packet in the list. - - iov_len -= skip_first; + page_unref(pf.page); + return -EFAULT; } - unsigned long max_payload = cul::clamp(iov_len - added_from_vec, mss); - unsigned long to_alloc = max_payload + PACKET_MAX_HEAD_LENGTH; - - auto ubuf = (const uint8_t *) vec->iov_base + added_from_vec; + if (WARN_ON(!sock_charge_snd_bytes(tp, to_add))) + { + /* This should not happen, since the send buf cannot suddenly shrink while we hold + * the socket lock. */ + page_unref(pf.page); + break; + } - if (!pbf_allocate_space(pbf, to_alloc)) + if (unlikely(!tcp_attempt_merge(pbf, &pf))) { - pbf_free(pbf); - return -ENOBUFS; + page_unref(pf.page); + sock_discharge_snd_bytes(tp, to_add); + break; } - pbf_reserve_headers(pbf, PACKET_MAX_HEAD_LENGTH); - auto st = pbf_expand_buffer(pbf, ubuf, max_payload); + pbf->total_len += pf.len; + len += pf.len; + write_space -= len; + pbf->tpi.seq_len += pf.len; + tp->snd_next += pf.len; + iter->advance(to_add); + } - tcp_prepare_segment(sock, pbf); - assert((size_t) st == max_payload); - added_from_vec += max_payload; - list_add_tail(&pbf->list_node, &sock->output_queue); + return iter->empty() ? 0 : -ENOSPC; +} - if (added_from_vec == iov_len) - { - added_from_vec = 0; - vec_len--; - vec++; - vec_nr++; - } +static int tcp_append_write(struct tcp_socket *sock, struct iovec_iter *iter, size_t mss, int flags) +{ + struct packetbuf *pbf = NULL; + int old_space; + int err; + + while (!iter->empty()) + { + if (!sock_may_write(sock)) + goto wait_for_space; + if (list_is_empty(&sock->output_queue)) + goto alloc_segment; + + pbf = list_last_entry(&sock->output_queue, struct packetbuf, list_node); + + err = tcp_append_to_segment(sock, pbf, iter); + if (err == -EWOULDBLOCK) + goto wait_for_space; + if (err != -ENOSPC) + return err; + + alloc_segment: + err = tcp_write_alloc(sock); + if (err == -EWOULDBLOCK) + goto wait_for_space; + if (err) + return err; + continue; + wait_for_space: + if (flags & MSG_DONTWAIT) + return -EWOULDBLOCK; + /* Try to output */ + tcp_output(sock); + old_space = sock_write_space(sock); + err = wait_for_event_socklocked_interruptible_2(&sock->rx_wq, + sock_write_space(sock) > old_space, sock); + if (err == -EINTR) + return err; } return 0; } +static void tcp_pbf_dtor(struct packetbuf *pbf) +{ + sock_discharge_pbf(pbf->sock, pbf); +} + +static int tcp_write_alloc(struct tcp_socket *sock) +{ + struct packetbuf *pbf = pbf_alloc_sk(GFP_KERNEL, sock, MAX_TCP_HEADER_LENGTH); + if (!pbf) + return -ENOBUFS; + + if (!sock_charge_pbf(sock, pbf)) + { + /* Failed to charge write space, stop. */ + pbf_free(pbf); + return -EAGAIN; + } + + pbf->dtor = tcp_pbf_dtor; + + pbf_reserve_headers(pbf, MAX_TCP_HEADER_LENGTH); + tcp_prepare_segment(sock, pbf); + list_add_tail(&pbf->list_node, &sock->output_queue); + return 0; +} + ssize_t tcp_sendmsg(struct socket *sock_, const msghdr *msg, int flags) { int err; @@ -812,13 +914,14 @@ ssize_t tcp_sendmsg(struct socket *sock_, const msghdr *msg, int flags) if (len < 0) return len; - err = tcp_append_write(sock, msg->msg_iov, msg->msg_iovlen, sock->mss); + iovec_iter iter{{msg->msg_iov, (size_t) msg->msg_iovlen}, (size_t) len, IOVEC_USER}; + err = tcp_append_write(sock, &iter, sock->mss, flags); if (err < 0) - return err; + return len - iter.bytes > 0 ?: err; err = tcp_output(sock); if (err < 0) - return err; + return len > 0 ?: err; return len; } @@ -1037,6 +1140,8 @@ void tcp_destroy_sock(struct tcp_socket *sock) bst_for_every_entry_delete(&sock->out_of_order_tree, pbf, struct packetbuf, bst_node) pbf_put_ref(pbf); + + WARN_ON(sock->sk_send_queued > 0); sock->unref(); } @@ -1113,15 +1218,17 @@ struct socket *tcp_accept(struct socket *sock_, int flags) int err; sock->socket_lock.lock(); - err = wait_for_event_socklocked_interruptible_2(&sock->rx_wq, !list_is_empty(&sock->conn_queue), - sock); + CHECK(sock->state == TCP_STATE_LISTEN); + err = wait_for_event_socklocked_interruptible_2(&sock->rx_wq, + !list_is_empty(&sock->accept_queue), sock); if (err) { sock->socket_lock.unlock(); return NULL; } - new_sock = list_first_entry(&sock->conn_queue, struct tcp_socket, conn_queue); + CHECK(!list_is_empty(&sock->accept_queue)); + new_sock = list_first_entry(&sock->accept_queue, struct tcp_socket, conn_queue); list_remove(&new_sock->conn_queue); sock->socket_lock.unlock(); return new_sock; @@ -1131,6 +1238,13 @@ static int tcp_getsockopt(struct socket *, int level, int optname, void *optval, static int tcp_setsockopt(struct socket *, int level, int optname, const void *optval, socklen_t optlen); +static void tcp_write_space(struct socket *sock) +{ + struct tcp_socket *tp = TCP_SOCK(sock); + /* This looks... overeager to wake up? */ + wait_queue_wake_all(&tp->rx_wq); +} + const struct socket_ops tcp_ops = { .destroy = cpp_destroy, .listen = tcp_listen, @@ -1147,13 +1261,14 @@ const struct socket_ops tcp_ops = { .close = tcp_close, .handle_backlog = tcp_handle_backlog, .poll = tcp_poll, + .write_space = tcp_write_space, }; struct socket *tcp_create_socket(int type) { struct tcp_socket *sock = (struct tcp_socket *) kmem_cache_alloc(tcp_cache, GFP_ATOMIC); - /* Most init is done on the ctor's side, we do tcp-specific init here... What's really important - * is that the core socket is TYPESAFE_BY_RCU. */ + /* Most init is done on the ctor's side, we do tcp-specific init here... What's really + * important is that the core socket is TYPESAFE_BY_RCU. */ if (sock) { new (sock) tcp_socket; @@ -1240,9 +1355,9 @@ int tcp_send_synack(struct tcp_connreq *conn) (TCP_FLAG_SYN | TCP_FLAG_ACK)); hdr->checksum = conn->tc_domain == AF_INET - ? tcpv4_calculate_checksum(hdr, header_len, route->dst_addr.in4.s_addr, + ? tcpv4_calculate_checksum(hdr, header_len, pbf, route->dst_addr.in4.s_addr, route->src_addr.in4.s_addr, true) - : tcpv6_calculate_checksum(hdr, header_len, route->dst_addr.in6, + : tcpv6_calculate_checksum(hdr, header_len, pbf, route->dst_addr.in6, route->src_addr.in6, true); iflow flow{*route, IPPROTO_TCP, conn->tc_domain == AF_INET6}; if (conn->tc_domain == AF_INET) @@ -1310,11 +1425,11 @@ void __tcp_send_rst(struct packetbuf *pbf, u32 seq, u32 ack_nr, int ack) hdr->checksum = pbf->domain == AF_INET - ? tcpv4_calculate_checksum(hdr, sizeof(struct tcp_header), + ? tcpv4_calculate_checksum(hdr, sizeof(struct tcp_header), rst_pbf, other_route->dst_addr.in4.s_addr, other_route->src_addr.in4.s_addr, true) - : tcpv6_calculate_checksum(hdr, sizeof(struct tcp_header), other_route->dst_addr.in6, - other_route->src_addr.in6, true); + : tcpv6_calculate_checksum(hdr, sizeof(struct tcp_header), rst_pbf, + other_route->dst_addr.in6, other_route->src_addr.in6, true); iflow flow{ex.value(), IPPROTO_TCP, pbf->domain == AF_INET6}; if (pbf->domain == AF_INET) ip::v4::send_packet(flow, rst_pbf); diff --git a/kernel/kernel/net/tcp_input.cpp b/kernel/kernel/net/tcp_input.cpp index a1318757..08835cca 100644 --- a/kernel/kernel/net/tcp_input.cpp +++ b/kernel/kernel/net/tcp_input.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 - 2024 Pedro Falcato + * Copyright (c) 2020 - 2025 Pedro Falcato * This file is part of Onyx, and is released under the terms of the GPLv2 License * check LICENSE at the root directory for more information * @@ -795,6 +795,7 @@ static int tcp_input_conn(struct tcp_connreq *conn, struct packetbuf *pbf) sock->dest_addr = conn->tc_dst; sock->src_addr = conn->tc_src; + pr_warn("src %pI4 dst %pI4\n", &sock->src_addr.in4, &sock->dest_addr.in4); sock->ipv4_on_inet6 = on_ipv4_mode; sock->route_cache = cul::move(conn->tc_route); sock->route_cache_valid = 1; @@ -869,11 +870,12 @@ static int tcp_input_conn(struct tcp_connreq *conn, struct packetbuf *pbf) } else { + conn->tc_sock = NULL; list_remove(&conn->tc_list_node); kfree_rcu(conn, tc_rcu_head); /* We can double up this conn_queue as a list node, because sock->conn_queue will never be * in a LISTEN state */ - list_add_tail(&sock->conn_queue, &parent->conn_queue); + list_add_tail(&sock->conn_queue, &parent->accept_queue); wait_queue_wake_all(&parent->rx_wq); }