diff --git a/include/aws/io/channel_bootstrap.h b/include/aws/io/channel_bootstrap.h index 8e28579df..e65794756 100644 --- a/include/aws/io/channel_bootstrap.h +++ b/include/aws/io/channel_bootstrap.h @@ -178,7 +178,7 @@ struct aws_server_bootstrap { struct aws_socket_channel_bootstrap_options { struct aws_client_bootstrap *bootstrap; const char *host_name; - uint16_t port; + uint32_t port; const struct aws_socket_options *socket_options; const struct aws_tls_connection_options *tls_options; aws_client_bootstrap_on_channel_event_fn *creation_callback; @@ -208,7 +208,7 @@ struct aws_socket_channel_bootstrap_options { struct aws_server_socket_channel_bootstrap_options { struct aws_server_bootstrap *bootstrap; const char *host_name; - uint16_t port; + uint32_t port; const struct aws_socket_options *socket_options; const struct aws_tls_connection_options *tls_options; aws_server_bootstrap_on_accept_channel_setup_fn *incoming_callback; diff --git a/include/aws/io/socket.h b/include/aws/io/socket.h index de4fed356..a6223b05e 100644 --- a/include/aws/io/socket.h +++ b/include/aws/io/socket.h @@ -98,7 +98,7 @@ typedef void(aws_socket_on_readable_fn)(struct aws_socket *socket, int error_cod #endif struct aws_socket_endpoint { char address[AWS_ADDRESS_MAX_LEN]; - uint16_t port; + uint32_t port; }; struct aws_socket { @@ -302,6 +302,22 @@ AWS_IO_API int aws_socket_get_error(struct aws_socket *socket); */ AWS_IO_API bool aws_socket_is_open(struct aws_socket *socket); +/** + * Raises AWS_IO_SOCKET_INVALID_ADDRESS and logs an error if connecting to this port is illegal. + * For example, port must be in range 1-65535 to connect with IPv4. + * These port values would fail eventually in aws_socket_connect(), + * but you can use this function to validate earlier. + */ +AWS_IO_API int aws_socket_validate_port_for_connect(uint32_t port, enum aws_socket_domain domain); + +/** + * Raises AWS_IO_SOCKET_INVALID_ADDRESS and logs an error if binding to this port is illegal. + * For example, port must in range 0-65535 to bind with IPv4. + * These port values would fail eventually in aws_socket_bind(), + * but you can use this function to validate earlier. + */ +AWS_IO_API int aws_socket_validate_port_for_bind(uint32_t port, enum aws_socket_domain domain); + /** * Assigns a random address (UUID) for use with AWS_SOCKET_LOCAL (Unix Domain Sockets). * For use in internal tests only. diff --git a/source/channel_bootstrap.c b/source/channel_bootstrap.c index dfafe02fc..f5e364261 100644 --- a/source/channel_bootstrap.c +++ b/source/channel_bootstrap.c @@ -118,7 +118,7 @@ struct client_connection_args { aws_client_bootstrap_on_channel_event_fn *shutdown_callback; struct client_channel_data channel_data; struct aws_socket_options outgoing_options; - uint16_t outgoing_port; + uint32_t outgoing_port; struct aws_string *host_name; void *user_data; uint8_t addresses_count; @@ -764,14 +764,14 @@ int aws_client_bootstrap_new_socket_channel(struct aws_socket_channel_bootstrap_ } const char *host_name = options->host_name; - uint16_t port = options->port; + uint32_t port = options->port; AWS_LOGF_TRACE( AWS_LS_IO_CHANNEL_BOOTSTRAP, - "id=%p: attempting to initialize a new client channel to %s:%d", + "id=%p: attempting to initialize a new client channel to %s:%u", (void *)bootstrap, host_name, - (int)port); + port); aws_ref_count_init( &client_connection_args->ref_count, @@ -1363,10 +1363,10 @@ struct aws_socket *aws_server_bootstrap_new_socket_listener( AWS_LOGF_INFO( AWS_LS_IO_CHANNEL_BOOTSTRAP, "id=%p: attempting to initialize a new " - "server socket listener for %s:%d", + "server socket listener for %s:%u", (void *)bootstrap_options->bootstrap, bootstrap_options->host_name, - (int)bootstrap_options->port); + bootstrap_options->port); aws_ref_count_init( &server_connection_args->ref_count, diff --git a/source/posix/socket.c b/source/posix/socket.c index 99361106f..dcf0c9d55 100644 --- a/source/posix/socket.c +++ b/source/posix/socket.c @@ -340,18 +340,7 @@ static int s_update_local_endpoint(struct aws_socket *socket) { } else if (address.ss_family == AF_VSOCK) { struct sockaddr_vm *s = (struct sockaddr_vm *)&address; - /* VSOCK port is 32bit, but aws_socket_endpoint.port is only 16bit. - * Hopefully this isn't an issue, since users can only pass in 16bit values. - * But if it becomes an issue, we'll need to make aws_socket_endpoint more flexible */ - if (s->svm_port > UINT16_MAX) { - AWS_LOGF_ERROR( - AWS_LS_IO_SOCKET, - "id=%p fd=%d: aws_socket_endpoint can't deal with VSOCK port > UINT16_MAX", - (void *)socket, - socket->io_handle.data.fd); - return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS); - } - tmp_endpoint.port = (uint16_t)s->svm_port; + tmp_endpoint.port = s->svm_port; snprintf(tmp_endpoint.address, sizeof(tmp_endpoint.address), "%" PRIu32, s->svm_cid); return AWS_OP_SUCCESS; @@ -642,18 +631,22 @@ int aws_socket_connect( return AWS_OP_ERR; } + if (aws_socket_validate_port_for_connect(remote_endpoint->port, socket->options.domain)) { + return AWS_OP_ERR; + } + struct socket_address address; AWS_ZERO_STRUCT(address); socklen_t sock_size = 0; int pton_err = 1; if (socket->options.domain == AWS_SOCKET_IPV4) { pton_err = inet_pton(AF_INET, remote_endpoint->address, &address.sock_addr_types.addr_in.sin_addr); - address.sock_addr_types.addr_in.sin_port = htons(remote_endpoint->port); + address.sock_addr_types.addr_in.sin_port = htons((uint16_t)remote_endpoint->port); address.sock_addr_types.addr_in.sin_family = AF_INET; sock_size = sizeof(address.sock_addr_types.addr_in); } else if (socket->options.domain == AWS_SOCKET_IPV6) { pton_err = inet_pton(AF_INET6, remote_endpoint->address, &address.sock_addr_types.addr_in6.sin6_addr); - address.sock_addr_types.addr_in6.sin6_port = htons(remote_endpoint->port); + address.sock_addr_types.addr_in6.sin6_port = htons((uint16_t)remote_endpoint->port); address.sock_addr_types.addr_in6.sin6_family = AF_INET6; sock_size = sizeof(address.sock_addr_types.addr_in6); } else if (socket->options.domain == AWS_SOCKET_LOCAL) { @@ -664,7 +657,7 @@ int aws_socket_connect( } else if (socket->options.domain == AWS_SOCKET_VSOCK) { pton_err = parse_cid(remote_endpoint->address, &address.sock_addr_types.vm_addr.svm_cid); address.sock_addr_types.vm_addr.svm_family = AF_VSOCK; - address.sock_addr_types.vm_addr.svm_port = (unsigned int)remote_endpoint->port; + address.sock_addr_types.vm_addr.svm_port = remote_endpoint->port; sock_size = sizeof(address.sock_addr_types.vm_addr); #endif } else { @@ -676,21 +669,21 @@ int aws_socket_connect( int errno_value = errno; /* Always cache errno before potential side-effect */ AWS_LOGF_DEBUG( AWS_LS_IO_SOCKET, - "id=%p fd=%d: failed to parse address %s:%d.", + "id=%p fd=%d: failed to parse address %s:%u.", (void *)socket, socket->io_handle.data.fd, remote_endpoint->address, - (int)remote_endpoint->port); + remote_endpoint->port); return aws_raise_error(s_convert_pton_error(pton_err, errno_value)); } AWS_LOGF_DEBUG( AWS_LS_IO_SOCKET, - "id=%p fd=%d: connecting to endpoint %s:%d.", + "id=%p fd=%d: connecting to endpoint %s:%u.", (void *)socket, socket->io_handle.data.fd, remote_endpoint->address, - (int)remote_endpoint->port); + remote_endpoint->port); socket->state = CONNECTING; socket->remote_endpoint = *remote_endpoint; @@ -806,13 +799,17 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint return AWS_OP_ERR; } + if (aws_socket_validate_port_for_bind(local_endpoint->port, socket->options.domain)) { + return AWS_OP_ERR; + } + AWS_LOGF_INFO( AWS_LS_IO_SOCKET, - "id=%p fd=%d: binding to %s:%d.", + "id=%p fd=%d: binding to %s:%u.", (void *)socket, socket->io_handle.data.fd, local_endpoint->address, - (int)local_endpoint->port); + local_endpoint->port); struct socket_address address; AWS_ZERO_STRUCT(address); @@ -820,12 +817,12 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint int pton_err = 1; if (socket->options.domain == AWS_SOCKET_IPV4) { pton_err = inet_pton(AF_INET, local_endpoint->address, &address.sock_addr_types.addr_in.sin_addr); - address.sock_addr_types.addr_in.sin_port = htons(local_endpoint->port); + address.sock_addr_types.addr_in.sin_port = htons((uint16_t)local_endpoint->port); address.sock_addr_types.addr_in.sin_family = AF_INET; sock_size = sizeof(address.sock_addr_types.addr_in); } else if (socket->options.domain == AWS_SOCKET_IPV6) { pton_err = inet_pton(AF_INET6, local_endpoint->address, &address.sock_addr_types.addr_in6.sin6_addr); - address.sock_addr_types.addr_in6.sin6_port = htons(local_endpoint->port); + address.sock_addr_types.addr_in6.sin6_port = htons((uint16_t)local_endpoint->port); address.sock_addr_types.addr_in6.sin6_family = AF_INET6; sock_size = sizeof(address.sock_addr_types.addr_in6); } else if (socket->options.domain == AWS_SOCKET_LOCAL) { @@ -836,7 +833,7 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint } else if (socket->options.domain == AWS_SOCKET_VSOCK) { pton_err = parse_cid(local_endpoint->address, &address.sock_addr_types.vm_addr.svm_cid); address.sock_addr_types.vm_addr.svm_family = AF_VSOCK; - address.sock_addr_types.vm_addr.svm_port = (unsigned int)local_endpoint->port; + address.sock_addr_types.vm_addr.svm_port = local_endpoint->port; sock_size = sizeof(address.sock_addr_types.vm_addr); #endif } else { @@ -848,11 +845,11 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint int errno_value = errno; /* Always cache errno before potential side-effect */ AWS_LOGF_ERROR( AWS_LS_IO_SOCKET, - "id=%p fd=%d: failed to parse address %s:%d.", + "id=%p fd=%d: failed to parse address %s:%u.", (void *)socket, socket->io_handle.data.fd, local_endpoint->address, - (int)local_endpoint->port); + local_endpoint->port); return aws_raise_error(s_convert_pton_error(pton_err, errno_value)); } @@ -882,7 +879,7 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint AWS_LOGF_DEBUG( AWS_LS_IO_SOCKET, - "id=%p fd=%d: successfully bound to %s:%d", + "id=%p fd=%d: successfully bound to %s:%u", (void *)socket, socket->io_handle.data.fd, socket->local_endpoint.address, @@ -996,7 +993,7 @@ static void s_socket_accept_event( new_sock->local_endpoint = socket->local_endpoint; new_sock->state = CONNECTED_READ | CONNECTED_WRITE; - uint16_t port = 0; + uint32_t port = 0; /* get the info on the incoming socket's address */ if (in_addr.ss_family == AF_INET) { diff --git a/source/socket_shared.c b/source/socket_shared.c new file mode 100644 index 000000000..63c640b49 --- /dev/null +++ b/source/socket_shared.c @@ -0,0 +1,75 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include + +#include + +/* common validation for connect() and bind() */ +static int s_socket_validate_port_for_domain(uint32_t port, enum aws_socket_domain domain) { + switch (domain) { + case AWS_SOCKET_IPV4: + case AWS_SOCKET_IPV6: + if (port > UINT16_MAX) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "Invalid port=%u for %s. Cannot exceed 65535", + port, + domain == AWS_SOCKET_IPV4 ? "IPv4" : "IPv6"); + return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS); + } + break; + + case AWS_SOCKET_LOCAL: + /* port is ignored */ + break; + + case AWS_SOCKET_VSOCK: + /* any 32bit port is legal */ + break; + + default: + AWS_LOGF_ERROR(AWS_LS_IO_SOCKET, "Cannot validate port for unknown domain=%d", domain); + return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS); + } + return AWS_OP_SUCCESS; +} + +int aws_socket_validate_port_for_connect(uint32_t port, enum aws_socket_domain domain) { + if (s_socket_validate_port_for_domain(port, domain)) { + return AWS_OP_ERR; + } + + /* additional validation */ + switch (domain) { + case AWS_SOCKET_IPV4: + case AWS_SOCKET_IPV6: + if (port == 0) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, + "Invalid port=%u for %s connections. Must use 1-65535", + port, + domain == AWS_SOCKET_IPV4 ? "IPv4" : "IPv6"); + return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS); + } + break; + + case AWS_SOCKET_VSOCK: + if (port == (uint32_t)-1) { + AWS_LOGF_ERROR( + AWS_LS_IO_SOCKET, "Invalid port for VSOCK connections. Cannot use VMADDR_PORT_ANY (-1U)."); + return aws_raise_error(AWS_IO_SOCKET_INVALID_ADDRESS); + } + break; + + default: + /* no extra validation */ + break; + } + return AWS_OP_SUCCESS; +} + +int aws_socket_validate_port_for_bind(uint32_t port, enum aws_socket_domain domain) { + return s_socket_validate_port_for_domain(port, domain); +} diff --git a/source/windows/iocp/socket.c b/source/windows/iocp/socket.c index b3f899b13..6d879417b 100644 --- a/source/windows/iocp/socket.c +++ b/source/windows/iocp/socket.c @@ -449,6 +449,11 @@ int aws_socket_connect( return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE); } } + + if (aws_socket_validate_port_for_connect(remote_endpoint->port, socket->options.domain)) { + return AWS_OP_ERR; + } + return socket_impl->vtable->connect(socket, remote_endpoint, event_loop, on_connection_result, user_data); } @@ -457,6 +462,11 @@ int aws_socket_bind(struct aws_socket *socket, const struct aws_socket_endpoint socket->state = ERRORED; return aws_raise_error(AWS_IO_SOCKET_ILLEGAL_OPERATION_FOR_STATE); } + + if (aws_socket_validate_port_for_bind(local_endpoint->port, socket->options.domain)) { + return AWS_OP_ERR; + } + struct iocp_socket *socket_impl = socket->impl; return socket_impl->vtable->bind(socket, local_endpoint); } @@ -709,11 +719,11 @@ static int s_ipv4_stream_connection_success(struct aws_socket *socket) { AWS_LOGF_DEBUG( AWS_LS_IO_SOCKET, - "id=%p handle=%p: local endpoint %s:%d", + "id=%p handle=%p: local endpoint %s:%u", (void *)socket, (void *)socket->io_handle.data.handle, socket->local_endpoint.address, - (int)socket->local_endpoint.port); + socket->local_endpoint.port); setsockopt((SOCKET)socket->io_handle.data.handle, SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0); socket->state = CONNECTED_WRITE | CONNECTED_READ; @@ -770,11 +780,11 @@ static int s_ipv6_stream_connection_success(struct aws_socket *socket) { AWS_LOGF_DEBUG( AWS_LS_IO_SOCKET, - "id=%p handle=%p: local endpoint %s:%d", + "id=%p handle=%p: local endpoint %s:%u", (void *)socket, (void *)socket->io_handle.data.handle, socket->local_endpoint.address, - (int)socket->local_endpoint.port); + socket->local_endpoint.port); setsockopt((SOCKET)socket->io_handle.data.handle, SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0); @@ -1050,23 +1060,23 @@ static int s_ipv4_stream_connect( int aws_err = s_convert_pton_error(err); /* call before logging or WSAError may get cleared */ AWS_LOGF_ERROR( AWS_LS_IO_SOCKET, - "id=%p handle=%p: failed to parse address %s:%d.", + "id=%p handle=%p: failed to parse address %s:%u.", (void *)socket, (void *)socket->io_handle.data.handle, remote_endpoint->address, - (int)remote_endpoint->port); + remote_endpoint->port); return aws_raise_error(aws_err); } AWS_LOGF_DEBUG( AWS_LS_IO_SOCKET, - "id=%p handle=%p: connecting to endpoint %s:%d.", + "id=%p handle=%p: connecting to endpoint %s:%u.", (void *)socket, (void *)socket->io_handle.data.handle, remote_endpoint->address, - (int)remote_endpoint->port); + remote_endpoint->port); - addr_in.sin_port = htons(remote_endpoint->port); + addr_in.sin_port = htons((uint16_t)remote_endpoint->port); addr_in.sin_family = AF_INET; /* stupid as hell, we have to bind first*/ @@ -1117,23 +1127,23 @@ static int s_ipv6_stream_connect( int aws_err = s_convert_pton_error(pton_err); /* call before logging or WSAError may get cleared */ AWS_LOGF_ERROR( AWS_LS_IO_SOCKET, - "id=%p handle=%p: failed to parse address %s:%d.", + "id=%p handle=%p: failed to parse address %s:%u.", (void *)socket, (void *)socket->io_handle.data.handle, remote_endpoint->address, - (int)remote_endpoint->port); + remote_endpoint->port); return aws_raise_error(aws_err); } AWS_LOGF_DEBUG( AWS_LS_IO_SOCKET, - "id=%p handle=%p: connecting to endpoint %s:%d.", + "id=%p handle=%p: connecting to endpoint %s:%u.", (void *)socket, (void *)socket->io_handle.data.handle, remote_endpoint->address, - (int)remote_endpoint->port); + remote_endpoint->port); - addr_in6.sin6_port = htons(remote_endpoint->port); + addr_in6.sin6_port = htons((uint16_t)remote_endpoint->port); addr_in6.sin6_family = AF_INET6; return s_tcp_connect( @@ -1244,11 +1254,11 @@ static inline int s_dgram_connect( AWS_LOGF_DEBUG( AWS_LS_IO_SOCKET, - "id=%p handle=%p: connecting to to %s:%d", + "id=%p handle=%p: connecting to to %s:%u", (void *)socket, (void *)socket->io_handle.data.handle, remote_endpoint->address, - (int)remote_endpoint->port); + remote_endpoint->port); int reuse = 1; if (setsockopt((SOCKET)socket->io_handle.data.handle, SOL_SOCKET, SO_REUSEADDR, (char *)&reuse, sizeof(int))) { @@ -1269,11 +1279,11 @@ static inline int s_dgram_connect( int wsa_err = WSAGetLastError(); /* logging may reset error, so cache it */ AWS_LOGF_DEBUG( AWS_LS_IO_SOCKET, - "id=%p handle=%p: Failed to connect to %s:%d with error %d.", + "id=%p handle=%p: Failed to connect to %s:%u with error %d.", (void *)socket, (void *)socket->io_handle.data.handle, remote_endpoint->address, - (int)remote_endpoint->port, + remote_endpoint->port, wsa_err); aws_raise_error(s_determine_socket_error(wsa_err)); goto error; @@ -1285,11 +1295,11 @@ static inline int s_dgram_connect( AWS_LOGF_DEBUG( AWS_LS_IO_SOCKET, - "id=%p handle=%p: local endpoint %s:%d", + "id=%p handle=%p: local endpoint %s:%u", (void *)socket, (void *)socket->io_handle.data.handle, socket->local_endpoint.address, - (int)socket->local_endpoint.port); + socket->local_endpoint.port); if (s_process_tcp_sock_options(socket)) { goto error; @@ -1334,7 +1344,7 @@ static int s_ipv4_dgram_connect( return aws_raise_error(aws_err); } - addr_in.sin_port = htons(remote_endpoint->port); + addr_in.sin_port = htons((uint16_t)remote_endpoint->port); addr_in.sin_family = AF_INET; return s_dgram_connect(socket, remote_endpoint, connect_loop, (struct sockaddr *)&addr_in, sizeof(addr_in)); @@ -1361,7 +1371,7 @@ static int s_ipv6_dgram_connect( return aws_raise_error(aws_err); } - addr_in6.sin6_port = htons(remote_endpoint->port); + addr_in6.sin6_port = htons((uint16_t)remote_endpoint->port); addr_in6.sin6_family = AF_INET6; return s_dgram_connect(socket, remote_endpoint, connect_loop, (struct sockaddr *)&addr_in6, sizeof(addr_in6)); @@ -1406,11 +1416,11 @@ static inline int s_tcp_bind(struct aws_socket *socket, struct sockaddr *sock_ad AWS_LOGF_INFO( AWS_LS_IO_SOCKET, - "id=%p handle=%p: binding to tcp %s:%d", + "id=%p handle=%p: binding to tcp %s:%u", (void *)socket, (void *)socket->io_handle.data.handle, socket->local_endpoint.address, - (int)socket->local_endpoint.port); + socket->local_endpoint.port); socket->state = BOUND; return AWS_OP_SUCCESS; @@ -1431,7 +1441,7 @@ static int s_ipv4_stream_bind(struct aws_socket *socket, const struct aws_socket return aws_raise_error(aws_err); } - addr_in.sin_port = htons(local_endpoint->port); + addr_in.sin_port = htons((uint16_t)local_endpoint->port); addr_in.sin_family = AF_INET; return s_tcp_bind(socket, (struct sockaddr *)&addr_in, sizeof(addr_in)); @@ -1448,7 +1458,7 @@ static int s_ipv6_stream_bind(struct aws_socket *socket, const struct aws_socket return aws_raise_error(aws_err); } - addr_in6.sin6_port = htons(local_endpoint->port); + addr_in6.sin6_port = htons((uint16_t)local_endpoint->port); addr_in6.sin6_family = AF_INET6; return s_tcp_bind(socket, (struct sockaddr *)&addr_in6, sizeof(addr_in6)); @@ -1474,11 +1484,11 @@ static inline int s_udp_bind(struct aws_socket *socket, struct sockaddr *sock_ad AWS_LOGF_INFO( AWS_LS_IO_SOCKET, - "id=%p handle=%p: binding to udp %s:%p", + "id=%p handle=%p: binding to udp %s:%u", (void *)socket, (void *)socket->io_handle.data.handle, socket->local_endpoint.address, - (int)socket->local_endpoint.port); + socket->local_endpoint.port); socket->state = CONNECTED_READ; return AWS_OP_SUCCESS; @@ -1499,7 +1509,7 @@ static int s_ipv4_dgram_bind(struct aws_socket *socket, const struct aws_socket_ return aws_raise_error(aws_err); } - addr_in.sin_port = htons(local_endpoint->port); + addr_in.sin_port = htons((uint16_t)local_endpoint->port); addr_in.sin_family = AF_INET; return s_udp_bind(socket, (struct sockaddr *)&addr_in, sizeof(addr_in)); @@ -1516,7 +1526,7 @@ static int s_ipv6_dgram_bind(struct aws_socket *socket, const struct aws_socket_ return aws_raise_error(aws_err); } - addr_in6.sin6_port = htons(local_endpoint->port); + addr_in6.sin6_port = htons((uint16_t)local_endpoint->port); addr_in6.sin6_family = AF_INET6; return s_udp_bind(socket, (struct sockaddr *)&addr_in6, sizeof(addr_in6)); @@ -1888,7 +1898,7 @@ static void s_tcp_accept_event( do { socket_impl->incoming_socket->state = CONNECTED_WRITE | CONNECTED_READ; - uint16_t port = 0; + uint32_t port = 0; struct sockaddr_storage *in_addr = (struct sockaddr_storage *)socket_impl->accept_buffer; @@ -1917,11 +1927,11 @@ static void s_tcp_accept_event( socket_impl->incoming_socket->remote_endpoint.port = port; AWS_LOGF_INFO( AWS_LS_IO_SOCKET, - "id=%p handle=%p: incoming connection accepted from %s:%d.", + "id=%p handle=%p: incoming connection accepted from %s:%u.", (void *)socket, (void *)socket->io_handle.data.handle, socket_impl->incoming_socket->remote_endpoint.address, - (int)port); + port); u_long non_blocking = 1; ioctlsocket((SOCKET)socket_impl->incoming_socket->io_handle.data.handle, FIONBIO, &non_blocking); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 992e7456b..7f263f850 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -76,6 +76,7 @@ add_net_test_case(cleanup_before_connect_or_timeout_doesnt_explode) add_test_case(cleanup_in_accept_doesnt_explode) add_test_case(cleanup_in_write_cb_doesnt_explode) add_test_case(sock_write_cb_is_async) +add_test_case(socket_validate_port) if(WIN32) add_test_case(local_socket_pipe_connected_race) diff --git a/tests/socket_test.c b/tests/socket_test.c index c654742d2..b23ed084b 100644 --- a/tests/socket_test.c +++ b/tests/socket_test.c @@ -1752,3 +1752,56 @@ static int s_local_socket_pipe_connected_race(struct aws_allocator *allocator, v AWS_TEST_CASE(local_socket_pipe_connected_race, s_local_socket_pipe_connected_race) #endif + +static int s_test_socket_validate_port(struct aws_allocator *allocator, void *ctx) { + (void)allocator; + (void)ctx; + + /* IPv4 - 16bit port, only bind can use 0 */ + ASSERT_SUCCESS(aws_socket_validate_port_for_connect(80, AWS_SOCKET_IPV4)); + ASSERT_SUCCESS(aws_socket_validate_port_for_bind(80, AWS_SOCKET_IPV4)); + + ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_connect(0, AWS_SOCKET_IPV4)); + ASSERT_SUCCESS(aws_socket_validate_port_for_bind(0, AWS_SOCKET_IPV4)); + + ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_connect(0xFFFFFFFF, AWS_SOCKET_IPV4)); + ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_bind(0xFFFFFFFF, AWS_SOCKET_IPV4)); + + /* IPv6 - 16bit port, only bind can use 0 */ + ASSERT_SUCCESS(aws_socket_validate_port_for_connect(80, AWS_SOCKET_IPV6)); + ASSERT_SUCCESS(aws_socket_validate_port_for_bind(80, AWS_SOCKET_IPV6)); + + ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_connect(0, AWS_SOCKET_IPV6)); + ASSERT_SUCCESS(aws_socket_validate_port_for_bind(0, AWS_SOCKET_IPV6)); + + ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_connect(0xFFFFFFFF, AWS_SOCKET_IPV6)); + ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_bind(0xFFFFFFFF, AWS_SOCKET_IPV6)); + + /* VSOCK - 32bit port, only bind can use VMADDR_PORT_ANY (-1U) */ + ASSERT_SUCCESS(aws_socket_validate_port_for_connect(80, AWS_SOCKET_VSOCK)); + ASSERT_SUCCESS(aws_socket_validate_port_for_bind(80, AWS_SOCKET_VSOCK)); + + ASSERT_SUCCESS(aws_socket_validate_port_for_connect(0, AWS_SOCKET_VSOCK)); + ASSERT_SUCCESS(aws_socket_validate_port_for_bind(0, AWS_SOCKET_VSOCK)); + + ASSERT_SUCCESS(aws_socket_validate_port_for_connect(0x7FFFFFFF, AWS_SOCKET_VSOCK)); + ASSERT_SUCCESS(aws_socket_validate_port_for_bind(0x7FFFFFFF, AWS_SOCKET_VSOCK)); + + ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_connect((uint32_t)-1, AWS_SOCKET_VSOCK)); + ASSERT_SUCCESS(aws_socket_validate_port_for_bind((uint32_t)-1, AWS_SOCKET_VSOCK)); + + /* LOCAL - ignores port */ + ASSERT_SUCCESS(aws_socket_validate_port_for_connect(0, AWS_SOCKET_LOCAL)); + ASSERT_SUCCESS(aws_socket_validate_port_for_bind(0, AWS_SOCKET_LOCAL)); + ASSERT_SUCCESS(aws_socket_validate_port_for_connect(80, AWS_SOCKET_LOCAL)); + ASSERT_SUCCESS(aws_socket_validate_port_for_bind(80, AWS_SOCKET_LOCAL)); + ASSERT_SUCCESS(aws_socket_validate_port_for_connect((uint32_t)-1, AWS_SOCKET_LOCAL)); + ASSERT_SUCCESS(aws_socket_validate_port_for_bind((uint32_t)-1, AWS_SOCKET_LOCAL)); + + /* invalid domain should fail */ + ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_connect(80, (enum aws_socket_domain)(-1))); + ASSERT_ERROR(AWS_IO_SOCKET_INVALID_ADDRESS, aws_socket_validate_port_for_bind(80, (enum aws_socket_domain)(-1))); + + return 0; +} +AWS_TEST_CASE(socket_validate_port, s_test_socket_validate_port) diff --git a/tests/tls_handler_test.c b/tests/tls_handler_test.c index b44bc70fa..fb1b6a4d5 100644 --- a/tests/tls_handler_test.c +++ b/tests/tls_handler_test.c @@ -672,7 +672,7 @@ struct default_host_callback_data { static int s_verify_negotiation_fails_helper( struct aws_allocator *allocator, const struct aws_string *host_name, - uint16_t port, + uint32_t port, struct aws_tls_ctx_options *client_ctx_options) { struct aws_tls_ctx *client_ctx = aws_tls_client_ctx_new(allocator, client_ctx_options); @@ -755,7 +755,7 @@ static int s_verify_negotiation_fails_helper( static int s_verify_negotiation_fails( struct aws_allocator *allocator, const struct aws_string *host_name, - uint16_t port, + uint32_t port, void (*context_options_override_fn)(struct aws_tls_ctx_options *)) { aws_io_library_init(allocator); @@ -1019,7 +1019,7 @@ static int s_tls_client_channel_negotiation_error_socket_closed_fn(struct aws_al (void)ctx; const char *host_name = "aws-crt-test-stuff.s3.amazonaws.com"; - uint16_t port = 80; /* Note: intentionally wrong and not 443 */ + uint32_t port = 80; /* Note: intentionally wrong and not 443 */ aws_io_library_init(allocator); @@ -1084,7 +1084,7 @@ AWS_TEST_CASE( static int s_verify_good_host( struct aws_allocator *allocator, const struct aws_string *host_name, - uint16_t port, + uint32_t port, void (*override_tls_options_fn)(struct aws_tls_ctx_options *)) { aws_io_library_init(allocator);